From 24a8dd22098c1e8c19830f70cba32581718bebc9 Mon Sep 17 00:00:00 2001 From: pjpratik <118897289+pjpratik@users.noreply.github.com> Date: Thu, 2 Feb 2023 17:40:15 +0530 Subject: [PATCH 001/391] Update lite.py As a feature request #56615, the _dtypes.int16 to be allowed when 16x8 quantization is not used so that the custom ops returning 16bit outputs can be benefitted. --- tensorflow/lite/python/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 41e5ed4fce2c6b..fc5aa2825d17e2 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -983,7 +983,7 @@ def _validate_inference_input_output_types(self, quant_mode): if quant_mode.is_post_training_int16x8_quantization(): all_types = default_types + [_dtypes.int16] else: - all_types = default_types + [_dtypes.int8, _dtypes.uint8] + all_types = default_types + [_dtypes.int8, _dtypes.uint8,_dtypes.int16] if (self.inference_input_type not in all_types or self.inference_output_type not in all_types): all_types_names = ["tf." + t.name for t in all_types] From 976e29b9dee9150a60a2a5757aecb5b954724918 Mon Sep 17 00:00:00 2001 From: pjpratik <118897289+pjpratik@users.noreply.github.com> Date: Thu, 13 Apr 2023 23:50:21 +0530 Subject: [PATCH 002/391] [lite]Add int64 support for split builtin op This commit adds the support for int64 for the split op. As mentioned in #50636 >int64 is already supported for concat and split_v ops > --- tensorflow/lite/kernels/split.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/kernels/split.cc b/tensorflow/lite/kernels/split.cc index 1491f4bbb98823..67979a7a63908b 100644 --- a/tensorflow/lite/kernels/split.cc +++ b/tensorflow/lite/kernels/split.cc @@ -87,7 +87,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || input_type == kTfLiteInt8 || input_type == kTfLiteInt16 || - input_type == kTfLiteInt32); + input_type == kTfLiteInt32 || input_type == kTfLiteInt64); for (int i = 0; i < NumOutputs(node); ++i) { TfLiteTensor* tensor; TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor)); @@ -158,6 +158,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_SPLIT(int32_t); break; } + case kTfLiteInt64: { + TF_LITE_SPLIT(int64_t); + break; + } default: TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.", TfLiteTypeGetName(op_context.input->type)); From d96cc0ea810fa0e739df23d7e0fd01cee1341f56 Mon Sep 17 00:00:00 2001 From: Eugene Kuznetsov Date: Tue, 18 Apr 2023 10:57:55 +0000 Subject: [PATCH 003/391] Fixing: //tensorflow/core/grappler/optimizers:remapper_test //tensorflow/python:nn_test Reenabling full tests and disabling specific failing subtests: //tensorflow/python/grappler:remapper_test //tensorflow/python/kernel_tests/math_ops:cwise_ops_binary_test //tensorflow/core/kernels/mlir_generated:gpu_binary_ops_test a few others --- tensorflow/core/grappler/optimizers/BUILD | 2 +- .../core/grappler/optimizers/remapper_test.cc | 20 ++++++++++++++++++- tensorflow/core/kernels/BUILD | 2 +- tensorflow/core/kernels/mlir_generated/BUILD | 1 - .../mlir_generated/gpu_binary_ops_test.cc | 2 ++ .../core/kernels/tensor_to_hash_bucket_op.cc | 4 ++-- .../core/kernels/tensor_to_hash_bucket_op.h | 4 ++-- .../tensor_to_hash_bucket_op_gpu.cu.cc | 4 ++-- tensorflow/python/client/BUILD | 1 + .../python/client/session_partial_run_test.py | 1 - 
tensorflow/python/grappler/remapper_test.py | 2 ++ .../kernel_tests/linalg/linalg_grad_test.py | 2 +- .../math_ops/cwise_ops_binary_test.py | 4 +++- .../math_ops/cwise_ops_unary_test.py | 12 ----------- .../segment_reduction_ops_d9m_test.py | 3 --- .../kernel_tests/nn_ops/conv_ops_test.py | 4 ++-- tensorflow/python/ops/init_ops_test.py | 3 --- .../ops/parallel_for/control_flow_ops_test.py | 6 ------ .../profiler/internal/run_metadata_test.py | 3 --- 19 files changed, 38 insertions(+), 42 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 016b209cdd6c68..430044dff83cc6 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -910,7 +910,7 @@ tf_kernel_library( tf_cuda_cc_test( name = "remapper_test", srcs = ["remapper_test.cc"], - tags = ["no_rocm"], + tags = [], deps = [ ":remapper", "//tensorflow/cc:cc_ops", diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc index 4a04413873815c..c7de41eba6cb88 100644 --- a/tensorflow/core/grappler/optimizers/remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/remapper_test.cc @@ -1558,7 +1558,17 @@ TEST_F(RemapperFuseMatMulWithBiasTest, F16) { RunTest(); } -TEST_F(RemapperFuseMatMulWithBiasTest, F32) { RunTest(); } +TEST_F(RemapperFuseMatMulWithBiasTest, F32) { + bool skip_test = false; +#if !defined(GOOGLE_CUDA) + skip_test = true; +#endif + if (skip_test || GetNumAvailableGPUs() == 0) { + GTEST_SKIP() << "Skipping FuseMatMulWithBias with float, which is only " + "supported in CUDA."; + } + RunTest(); +} TEST_F(RemapperFuseMatMulWithBiasTest, Bf16) { #if !defined(ENABLE_MKL) @@ -1767,6 +1777,14 @@ TEST_F(RemapperFuseMatMulWithBiasAndActivationTest, F16) { } TEST_F(RemapperFuseMatMulWithBiasAndActivationTest, F32) { + bool skip_test = false; +#if !defined(GOOGLE_CUDA) + skip_test = true; +#endif + if (skip_test || GetNumAvailableGPUs() == 0) { + GTEST_SKIP() << "Skipping FuseMatMulWithBiasAndActivationTest with float, " + "which is only supported in CUDA."; + } RunTest(); } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 4b4d6798e41f37..a67c5daa3c8f1a 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5458,7 +5458,7 @@ tf_kernel_library( name = "tensor_to_hash_bucket_op", prefix = "tensor_to_hash_bucket_op", deps = STRING_DEPS + if_oss( - if_cuda(["@farmhash_gpu_archive//:farmhash_gpu"]), + if_cuda_or_rocm(["@farmhash_gpu_archive//:farmhash_gpu"]), tf_fingerprint_deps(), ), ) diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index b94f1ffd77aa32..64b29dca6c4cda 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -527,7 +527,6 @@ tf_cuda_cc_test( shard_count = 20, tags = tf_cuda_tests_tags() + [ "no_cuda_asan", # b/173033461 - "no_rocm", # failed since 7de9cf4 ], deps = [ ":base_binary_ops_test", diff --git a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc index 020193d4ce190c..928fb86d698948 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc @@ -683,11 +683,13 @@ GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES( /*test_name=*/UInt64, uint64_t, uint64_t, test::DefaultInput(), test::DefaultInputNonZero(), 
baseline_floor_mod, test::OpsTestConfig().ExpectStrictlyEqual()); +#if !TENSORFLOW_USE_ROCM GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES( FloorMod, /*test_name=*/Half, Eigen::half, Eigen::half, test::DefaultInput(), test::DefaultInputNonZero(), baseline_floor_mod, test::OpsTestConfig()); +#endif GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES( FloorMod, /*test_name=*/Float, float, float, test::DefaultInput(), diff --git a/tensorflow/core/kernels/tensor_to_hash_bucket_op.cc b/tensorflow/core/kernels/tensor_to_hash_bucket_op.cc index 2d6309171afcd4..ed6f675a25090b 100644 --- a/tensorflow/core/kernels/tensor_to_hash_bucket_op.cc +++ b/tensorflow/core/kernels/tensor_to_hash_bucket_op.cc @@ -73,7 +73,7 @@ TF_CALL_INTEGRAL_TYPES(REGISTER_CPU_KERNELS); #undef REGISTER_CPU_KERNELS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU_KERNELS(type) \ REGISTER_KERNEL_BUILDER(Name("_TensorToHashBucketFast") \ @@ -85,6 +85,6 @@ TF_CALL_INTEGRAL_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/tensor_to_hash_bucket_op.h b/tensorflow/core/kernels/tensor_to_hash_bucket_op.h index 90c4f2084f1bea..a681640be2323b 100644 --- a/tensorflow/core/kernels/tensor_to_hash_bucket_op.h +++ b/tensorflow/core/kernels/tensor_to_hash_bucket_op.h @@ -66,13 +66,13 @@ struct LaunchTensorToHashBucket { } }; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM template struct LaunchTensorToHashBucket { void operator()(OpKernelContext* c, const int64_t num_buckets, const T* input, const int num_elems, int64_t* output); }; -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/tensor_to_hash_bucket_op_gpu.cu.cc b/tensorflow/core/kernels/tensor_to_hash_bucket_op_gpu.cu.cc index 6c06778d214582..de47f990ead2ae 100644 --- a/tensorflow/core/kernels/tensor_to_hash_bucket_op_gpu.cu.cc +++ b/tensorflow/core/kernels/tensor_to_hash_bucket_op_gpu.cu.cc @@ -10,7 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -119,4 +119,4 @@ TF_CALL_INTEGRAL_TYPES(REGISTER_FUNCTORS); #undef REGISTER_FUNCTORS } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/python/client/BUILD b/tensorflow/python/client/BUILD index 0a9fedc17725e0..d9c0ad4432b068 100644 --- a/tensorflow/python/client/BUILD +++ b/tensorflow/python/client/BUILD @@ -424,6 +424,7 @@ tf_py_strict_test( tags = [ "no_gpu", "no_windows", + "no_rocm", ], deps = [ ":session", diff --git a/tensorflow/python/client/session_partial_run_test.py b/tensorflow/python/client/session_partial_run_test.py index 075d69e78bc400..79cedb5a2ffdd6 100644 --- a/tensorflow/python/client/session_partial_run_test.py +++ b/tensorflow/python/client/session_partial_run_test.py @@ -26,7 +26,6 @@ from tensorflow.python.platform import googletest from tensorflow.python.training import server_lib - class PartialRunTest(test_util.TensorFlowTestCase): def RunTestPartialRun(self, sess): diff --git a/tensorflow/python/grappler/remapper_test.py b/tensorflow/python/grappler/remapper_test.py index 6d693431f60ea4..91f283c5969792 100644 --- a/tensorflow/python/grappler/remapper_test.py +++ b/tensorflow/python/grappler/remapper_test.py @@ -227,6 +227,8 @@ def test_conv2d_biasadd_act_fusion(self): """Test Conv2D+BiasAdd+Relu fusion.""" if not test_util.is_gpu_available(): self.skipTest('No GPU available') + if test.is_built_with_rocm(): + self.skipTest('ROCm does not support conv biasadd fusion') N, H, W, C = (5, 3, 3, 8) # pylint: disable=invalid-name # The runtime fusion requires the output dims to be 32-bit aligned. 
diff --git a/tensorflow/python/kernel_tests/linalg/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg/linalg_grad_test.py index 3f37a3585101d1..b1ecb18a30807a 100644 --- a/tensorflow/python/kernel_tests/linalg/linalg_grad_test.py +++ b/tensorflow/python/kernel_tests/linalg/linalg_grad_test.py @@ -240,7 +240,7 @@ def Test(self): lambda x: linalg_ops.matrix_inverse(x, adjoint=True), dtype, shape)) - if not test_lib.is_built_with_rocm(): + if True: #not test_lib.is_built_with_rocm(): # TODO(rocm) : # re-enable this test when upstream issues are resolved # see commit msg for details diff --git a/tensorflow/python/kernel_tests/math_ops/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/math_ops/cwise_ops_binary_test.py index 8a1d14be8417ee..922b618f8a8d03 100644 --- a/tensorflow/python/kernel_tests/math_ops/cwise_ops_binary_test.py +++ b/tensorflow/python/kernel_tests/math_ops/cwise_ops_binary_test.py @@ -883,7 +883,9 @@ def testPowNegativeExponentGpu(self): z = math_ops.pow(x, y) self.assertAllEqual(self.evaluate(z), [0, 1, 1, 1, -1]) - def testFloorModInfDenominator(self): + @test.disable_with_predicate( + pred=test.is_built_with_rocm, skip_message="On ROCm this test fails") + def testFloorModfInfDenominator(self): """Regression test for GitHub issue #58369.""" if not test_util.is_gpu_available(): self.skipTest("Requires GPU") diff --git a/tensorflow/python/kernel_tests/math_ops/cwise_ops_unary_test.py b/tensorflow/python/kernel_tests/math_ops/cwise_ops_unary_test.py index 29daaea0b1643a..24c06cedce2443 100644 --- a/tensorflow/python/kernel_tests/math_ops/cwise_ops_unary_test.py +++ b/tensorflow/python/kernel_tests/math_ops/cwise_ops_unary_test.py @@ -445,8 +445,6 @@ def f(x): self._compareBoth(x, compute_f32(np.vectorize(math.erfc)), math_ops.erfc) self._compareBoth(x, compute_f32(np.square), math_ops.square) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="On ROCm this test fails") def testInt8Basic(self): x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int8) self._compareCpu(x, np.abs, math_ops.abs) @@ -455,14 +453,10 @@ def testInt8Basic(self): self._compareBoth(x, np.negative, _NEG) self._compareBoth(x, np.sign, math_ops.sign) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="On ROCm this test fails") def testUInt8Basic(self): x = np.arange(6).reshape(1, 3, 2).astype(np.uint8) self._compareBoth(x, np.square, math_ops.square) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="On ROCm this test fails") def testInt16Basic(self): x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int16) self._compareCpu(x, np.abs, math_ops.abs) @@ -471,8 +465,6 @@ def testInt16Basic(self): self._compareBoth(x, np.negative, _NEG) self._compareBoth(x, np.sign, math_ops.sign) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="On ROCm this test fails") def testUInt16Basic(self): x = np.arange(6).reshape(1, 3, 2).astype(np.uint16) self._compareBoth(x, np.square, math_ops.square) @@ -491,8 +483,6 @@ def testInt32Basic(self): self._compareBothSparse(x, np.square, math_ops.square) self._compareBothSparse(x, np.sign, math_ops.sign) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="On ROCm this test fails") def testUInt32Basic(self): x = np.arange(6).reshape(1, 3, 2).astype(np.uint32) self._compareBoth(x, np.square, math_ops.square) @@ -514,8 +504,6 @@ def testInt64Square(self): self._compareCpu(x, np.square, math_ops.square) self._compareBothSparse(x, np.square, 
math_ops.square) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="On ROCm this test fails") def testUInt64Basic(self): x = np.arange(6).reshape(1, 3, 2).astype(np.uint64) self._compareBoth(x, np.square, math_ops.square) diff --git a/tensorflow/python/kernel_tests/math_ops/segment_reduction_ops_d9m_test.py b/tensorflow/python/kernel_tests/math_ops/segment_reduction_ops_d9m_test.py index fbd5f9501c0933..3c166b86fabc74 100644 --- a/tensorflow/python/kernel_tests/math_ops/segment_reduction_ops_d9m_test.py +++ b/tensorflow/python/kernel_tests/math_ops/segment_reduction_ops_d9m_test.py @@ -89,9 +89,6 @@ def testUnsortedOps(self): result = op(data, segment_ids, num_segments) self.evaluate(result) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, - skip_message="No ROCm support for complex types in segment reduction ops") @test_util.run_cuda_only def testUnsortedOpsComplex(self): for op in [ diff --git a/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py b/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py index 5ae06d0ad401a3..19238f29155b8a 100644 --- a/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py @@ -207,11 +207,11 @@ def _DtypesToTest(self, use_gpu): if use_gpu: # It is important that float32 comes first, since we are using its # gradients as a reference for fp16 gradients. - out = [dtypes.float32] + out = [dtypes.float32, dtypes.bfloat16] if test_util.GpuSupportsHalfMatMulAndConv(): out.append(dtypes.float16) if not test.is_built_with_rocm(): - out.extend([dtypes.float64, dtypes.bfloat16]) + out.extend([dtypes.float64]) return out return [dtypes.float32, dtypes.float64, dtypes.float16, dtypes.bfloat16] diff --git a/tensorflow/python/ops/init_ops_test.py b/tensorflow/python/ops/init_ops_test.py index a0ef239581405a..0d34a764e5f6fd 100644 --- a/tensorflow/python/ops/init_ops_test.py +++ b/tensorflow/python/ops/init_ops_test.py @@ -172,9 +172,6 @@ def test_Orthogonal(self): self._runner( init_ops.Orthogonal(seed=123), tensor_shape, target_mean=0.) 
- @test.disable_with_predicate( - pred=test.is_built_with_rocm, - skip_message='Disable subtest on ROCm due to missing QR op support') @test_util.run_gpu_only def testVariablePlacementWithOrthogonalInitializer(self): with ops.Graph().as_default() as g: diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index cee25369963fda..e62eb4c075fcf1 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -2773,9 +2773,6 @@ def loop_fn(i): (fft_ops.rfft2d,), (fft_ops.rfft3d,), ) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, - skip_message="Disable subtest on ROCm due to rocfft issues") def test_rfft(self, op_func): for dtype in (dtypes.float32, dtypes.float64): x = random_ops.random_uniform([2, 3, 4, 3, 4], dtype=dtype) @@ -2794,9 +2791,6 @@ def loop_fn(i): (fft_ops.irfft2d,), (fft_ops.irfft3d,), ) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, - skip_message="Disable subtest on ROCm due to rocfft issues") def test_irfft(self, op_func): if config.list_physical_devices("GPU"): # TODO(b/149957923): The test is flaky diff --git a/tensorflow/python/profiler/internal/run_metadata_test.py b/tensorflow/python/profiler/internal/run_metadata_test.py index d95dcb79d1e4fd..f5df743995fb86 100644 --- a/tensorflow/python/profiler/internal/run_metadata_test.py +++ b/tensorflow/python/profiler/internal/run_metadata_test.py @@ -112,9 +112,6 @@ class RunMetadataTest(test.TestCase): # work as expected. Since we now run this test with SOFTWARE_TRACE # (see _run_model routine above), this test will / should fail since # GPU device tracers are not enabled - @test.disable_with_predicate( - pred=test.is_built_with_rocm, - skip_message='Test fails on ROCm when run without FULL_TRACE') @test_util.run_deprecated_v1 def testGPU(self): if not test.is_gpu_available(cuda_only=True): From 65a2659287af69e809c804f17c9ac84e087d175a Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Fri, 14 Jul 2023 22:40:45 +0000 Subject: [PATCH 004/391] [TOSA] fix legalization of tfl.quantize legalization of tfl.quantize used to do: - multiply by scale in f32, add zeropoint in f32, cast to output type this led to mismatches because tosa cast (fp to int) used rounding mode: "round to nearest, tie to even", so rounding can be wrong when adding the zeropoint before rounding. 
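For example (illustrative values only): with a scaled input of 2.5 and a zeropoint of 1, adding the zeropoint first and then casting gives cast(3.5) = 4 under tie-to-even, while rounding first gives cast(2.5) = 2 and then 3 after adding the zeropoint.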
this patch changes legalization to do: - multiply by scale in f32, cast to i32, add zeropoint in i32, cast to output type - also, if zeropoint is 0, then skip the "cast to i32" and "add zeropoint" Signed-off-by: Tai Ly Change-Id: Ibaacd24778d83dd2cf1b1d914b287e32f4d775d0 --- .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 38 ++++++++++--------- .../mlir/tosa/transforms/legalize_common.cc | 24 ++++++++---- 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 9d2bc6a3b5ff36..5ac4425d5ad0d6 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -271,18 +271,20 @@ func.func @test_conv3d_bias(%arg0: tensor<10x3x64x64x12xf32>, %arg1: tensor<16x2 // CHECK-LABEL: test_conv3d_qi8( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4x8x21x17x!quant.uniform> // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3x3x17x34xf32>) -> tensor<1x4x8x11x34x!quant.uniform> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.0156862643> : tensor<1x1x1x1x1xf32>}> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.11982894> : tensor<1x1x1x1x1xf32>}> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<-4.000000e+00> : tensor<1x1x1x1x1xf32>}> -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<34xf32>}> -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>}> -// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_0]] -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_7]], %[[VAL_2]] {shift = 0 : i8} -// CHECK: %[[VAL_9:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] -// CHECK: %[[VAL_10:.*]] = tosa.conv3d %[[VAL_8]], %[[VAL_9]], %[[VAL_5]] {dilation = array, pad = array, stride = array} -// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_3]] {shift = 0 : i8} -// CHECK: %[[VAL_12:.*]] = tosa.add %[[VAL_11]], %[[VAL_4]] -// CHECK: %[[VAL_13:.*]] = tosa.cast %[[VAL_12]] +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.0156862643> : tensor<1x1x1x1x1xf32>} +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.11982894> : tensor<1x1x1x1x1xf32>} +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<-4> : tensor<1x1x1x1x1xi32>} +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<34xf32>} +// CHECK-DAG: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>} +// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_0]] +// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]] {shift = 0 : i8} +// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_7]] +// CHECK: %[[VAL_12:.*]] = tosa.conv3d %[[VAL_10]], %[[VAL_11]], %[[VAL_6]] {dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_12]], %[[VAL_4]] {shift = 0 : i8} +// CHECK: %[[VAL_14:.*]] = tosa.cast %[[VAL_13]] +// CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_14]], %[[VAL_5]] +// CHECK: %[[VAL_16:.*]] = tosa.cast %[[VAL_15]] +// CHECK: return %[[VAL_16]] func.func @test_conv3d_qi8(%arg0: tensor<1x4x8x21x17x!quant.uniform>, %arg1: tensor<2x3x3x17x34xf32>) -> (tensor<1x4x8x11x34x!quant.uniform>) { %0 = "tfl.dequantize"(%arg0) : (tensor<1x4x8x21x17x!quant.uniform>) -> tensor<1x4x8x21x17xf32> %2 = "tfl.no_value"() {value} : () -> none @@ -1840,12 +1842,12 @@ func.func @test_one_hot(%arg0: tensor<4x4xi32>, %arg1: tensor, %arg2: tenso // ----- // CHECK-LABEL: 
test_fakequant_with_min_max_args -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<16383.75> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<6.10360876E-5> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR2:.*]] = tosa.mul %arg0, %[[VAR0]] {shift = 0 : i8} -// CHECK-DAG: %[[VAR3:.*]] = tosa.cast %[[VAR2]] -// CHECK-DAG: %[[VAR4:.*]] = tosa.cast %[[VAR3]] -// CHECK-DAG: %[[VAR5:.*]] = tosa.mul %[[VAR4]], %[[VAR1]] {shift = 0 : i8} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<6.10360876E-5> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{value = dense<16383.75> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR3:.*]] = tosa.mul %arg0, %[[VAR2]] {shift = 0 : i8} +// CHECK-DAG: %[[VAR5:.*]] = tosa.cast %[[VAR3]] +// CHECK-DAG: %[[VAR6:.*]] = tosa.cast %[[VAR5]] +// CHECK-DAG: %[[VAR8:.*]] = tosa.mul %[[VAR6]], %[[VAR1]] {shift = 0 : i8} func.func @test_fakequant_with_min_max_args(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = "tfl.quantize"(%arg0) {qtype = tensor<13x21x3x!quant.uniform>} : (tensor<13x21x3xf32>) -> tensor<*x!quant.uniform> %1 = "tfl.dequantize"(%0) : (tensor<*x!quant.uniform>) -> tensor<13x21x3xf32> diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 7fed578c78f86c..2191a421ccd0ff 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -3515,20 +3515,28 @@ std::optional convertQuantizeOp(PatternRewriter& rewriter, Operation* op, ShapedType output_fp_type = output_type.clone(rewriter.getF32Type()); - Value zp_val = - getTosaConstTensorSingleF32(rewriter, op, static_cast(zeropoint)); + auto rank = input_type.getRank(); - auto op1_mul_in = CreateOpAndInfer( + Value result = CreateOpAndInfer( rewriter, op->getLoc(), output_fp_type, input_value, getTosaConstTensorSingleF32(rewriter, op, static_cast(scale)), 0); - auto op2_add_op1 = CreateOpAndInfer( - rewriter, op->getLoc(), output_fp_type, op1_mul_in.getResult(), zp_val); + if (zeropoint != 0) { + // cast to i32 to add zeropoint + ShapedType output_i32_type = output_type.clone(rewriter.getI32Type()); + Value cast_i32 = CreateOpAndInfer(rewriter, op->getLoc(), + output_i32_type, result); - auto op3_cast_op2 = CreateOpAndInfer( - rewriter, op->getLoc(), output_type, op2_add_op1.getResult()); + Value zp_val = getTosaConstTensorSingleI32(rewriter, op, zeropoint); - return op3_cast_op2.getResult(); + result = CreateOpAndInfer(rewriter, op->getLoc(), + output_i32_type, cast_i32, zp_val); + } + + Value final_result = CreateOpAndInfer(rewriter, op->getLoc(), + output_type, result); + + return final_result; } // Lowers Dequantize to a sequence of TOSA dequantization ops. From 36a6dbe0e9b2f178b050016e5deff72ceca3b2a9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 12 Nov 2023 01:05:32 -0800 Subject: [PATCH 005/391] Update GraphDef version to 1678. PiperOrigin-RevId: 581674480 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index d3dec324fe9cff..08e4344206b0d5 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. 
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1677 // Updated: 2023/11/11 +#define TF_GRAPH_DEF_VERSION 1678 // Updated: 2023/11/12 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 2573a40d25d6d8660a2156e04669d6700231ab49 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 12 Nov 2023 01:05:33 -0800 Subject: [PATCH 006/391] compat: Update forward compatibility horizon to 2023-11-12 PiperOrigin-RevId: 581674485 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 00a83eadbe18c0..745c92f1bc0597 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 11) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 12) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 00bafe8170f71efe299549ffbb961e2020598761 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sun, 12 Nov 2023 02:18:08 -0800 Subject: [PATCH 007/391] [stream_executor] Add update mode for nested command buffer launching PiperOrigin-RevId: 581684847 --- .../gpu/runtime3/command_buffer_thunk_test.cc | 19 ++++++++++++++++++- .../cuda/cuda_command_buffer_test.cc | 19 +++++++++++++++++++ .../xla/stream_executor/cuda/cuda_driver.cc | 13 +++++++++++++ .../stream_executor/gpu/gpu_command_buffer.cc | 10 ++++++++-- .../xla/xla/stream_executor/gpu/gpu_driver.h | 6 ++++++ 5 files changed, 64 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc index 95033655ba09a0..12e84df1b0ae59 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc @@ -207,7 +207,7 @@ TEST(CommandBufferThunkTest, GemmCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence commands; commands.Emplace(config.value(), slice_lhs, slice_rhs, slice_out, - /*deterministic*/ true); + /*deterministic=*/true); // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); @@ -227,6 +227,23 @@ TEST(CommandBufferThunkTest, GemmCmd) { stream.ThenMemcpy(dst.data(), out, out_length); ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); + + // Prepare buffer allocation for updating command buffer. + se::DeviceMemory updated_out = executor->AllocateArray(2 * 3); + stream.ThenMemZero(&updated_out, out_length); + + // Update buffer allocation to updated `out` buffer. + allocations = + BufferAllocations({lhs, rhs, updated_out}, 0, executor->GetAllocator()); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + + // Copy `updated_out` data back to host. 
+ std::fill(dst.begin(), dst.end(), 0); + stream.ThenMemcpy(dst.data(), updated_out, out_length); + + ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); } } // namespace xla::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index 73979361e03b2d..949c2cc0a98324 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -195,6 +195,25 @@ TEST(CudaCommandBufferTest, LaunchNestedCommandBuffer) { std::vector expected = {3, 3, 3, 3}; ASSERT_EQ(dst, expected); + + // Prepare argument for graph update: d = 0 + DeviceMemory d = executor->AllocateArray(length, 0); + stream.ThenMemZero(&d, byte_length); + + // Update command buffer to write into `d` buffer by creating a new nested + // command buffer. + nested_cmd = CommandBuffer::Create(executor, nested).value(); + TF_ASSERT_OK(nested_cmd.Launch(add, ThreadDim(), BlockDim(4), a, b, d)); + TF_ASSERT_OK(primary_cmd.Update()); + TF_ASSERT_OK(primary_cmd.AddNestedCommandBuffer(nested_cmd)); + TF_ASSERT_OK(primary_cmd.Finalize()); + + TF_ASSERT_OK(executor->Submit(&stream, primary_cmd)); + + // Copy `d` data back to host. + std::fill(dst.begin(), dst.end(), 42); + stream.ThenMemcpy(dst.data(), d, byte_length); + ASSERT_EQ(dst, expected); } //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index 5f77fc97da1d66..488a1a68f1544a 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -839,6 +839,19 @@ GpuDriver::GraphNodeGetType(CUgraphNode node) { return ::tsl::OkStatus(); } +/*static*/ tsl::Status GpuDriver::GraphExecChildNodeSetParams(CUgraphExec exec, + CUgraphNode node, + CUgraph child) { + VLOG(2) << "Set child node params " << node << " in graph executable " << exec + << "to params contained in " << child; + + RETURN_IF_CUDA_RES_ERROR( + cuGraphExecChildGraphNodeSetParams(exec, node, child), + "Failed to set CUDA graph child node params"); + + return ::tsl::OkStatus(); +} + /* static */ tsl::Status GpuDriver::LaunchKernel( GpuContext* context, absl::string_view kernel_name, CUfunction function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index fc1c56b9c21508..eed1620e628fc8 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -203,12 +203,18 @@ tsl::Status GpuCommandBuffer::AddNestedCommandBuffer( TF_RETURN_IF_ERROR(CheckNotFinalized()); TF_RETURN_IF_ERROR(CheckPrimary()); + GpuGraphHandle child_graph = GpuCommandBuffer::Cast(&nested)->graph(); // Adds a child graph node to the graph under construction. if (state_ == State::kCreate) { absl::Span deps = GetDependencies(); GpuGraphNodeHandle* node = &nodes_.emplace_back(); - return GpuDriver::GraphAddChildNode( - node, graph_, deps, GpuCommandBuffer::Cast(&nested)->graph()); + return GpuDriver::GraphAddChildNode(node, graph_, deps, child_graph); + } + + // Updates child graph node in the executable graph. 
+ if (state_ == State::kUpdate) { + GpuGraphNodeHandle node = nodes_[node_update_idx_++]; + return GpuDriver::GraphExecChildNodeSetParams(exec_, node, child_graph); } return UnsupportedStateError(state_); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index 8dff2fd9724650..d48faa11f7d9db 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -463,6 +463,12 @@ class GpuDriver { absl::Span deps, GpuGraphHandle child); + // Sets the parameters for a child graph node in the given graph exec. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g8f2d9893f6b899f992db1a2942ec03ff + static tsl::Status GraphExecChildNodeSetParams(GpuGraphExecHandle exec, + GpuGraphNodeHandle node, + GpuGraphHandle child); + // Loads ptx_contents with the CUDA driver's PTX JIT and stores the resulting // handle in "module". Any error logs that are produced are logged internally. // (supported on CUDA only) From 2c15e13fa793bd2a2a4aed66965905d40e33a5ac Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sun, 12 Nov 2023 02:51:00 -0800 Subject: [PATCH 008/391] [stream_executor] Fix GpuDriver::GraphExecKernelNodeSetParams argument types for consistency PiperOrigin-RevId: 581690024 --- .../xla/xla/stream_executor/cuda/cuda_driver.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index 488a1a68f1544a..f1387689e0abe8 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -758,14 +758,13 @@ GpuDriver::GraphNodeGetType(CUgraphNode node) { } /*static*/ tsl::Status GpuDriver::GraphExecKernelNodeSetParams( - GpuGraphExecHandle exec, GpuGraphNodeHandle node, - absl::string_view kernel_name, GpuFunctionHandle function, - unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, - unsigned int block_dim_x, unsigned int block_dim_y, + CUgraphExec exec, CUgraphNode node, absl::string_view kernel_name, + CUfunction function, unsigned int grid_dim_x, unsigned int grid_dim_y, + unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, unsigned int block_dim_z, unsigned int shared_mem_bytes, void** kernel_params, void** extra) { - VLOG(2) << "Set kernel node params " << node << " in graph executabe " << exec - << "; kernel: " << kernel_name << "; gdx: " << grid_dim_x + VLOG(2) << "Set kernel node params " << node << " in graph executable " + << exec << "; kernel: " << kernel_name << "; gdx: " << grid_dim_x << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z << " bdx: " << block_dim_x << " bdy: " << block_dim_y << " bdz: " << block_dim_z << "; shmem: " << shared_mem_bytes; From 7d7c89beff530bd8f83484cab5c66c96322d5926 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sun, 12 Nov 2023 10:53:58 -0800 Subject: [PATCH 009/391] [PJRT] Add an overload of GetTfrtCpuClient that takes a CpuClientOptions struct. Change in preparation for adding additional options to the CPU client. 
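A minimal usage sketch of the new overload (field values are illustrative, and TF_ASSIGN_OR_RETURN assumes a Status/StatusOr-returning caller):

  xla::CpuClientOptions options;
  options.asynchronous = true;   // currently ignored
  options.cpu_device_count = 4;  // optional; defaults to --xla_force_host_platform_device_count
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
                      xla::GetTfrtCpuClient(options));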
PiperOrigin-RevId: 581757511 --- .../xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc | 7 ++-- .../xla/xla/pjrt/pjrt_client_test_cpu.cc | 11 +++--- .../xla/xla/pjrt/tfrt_cpu_pjrt_client.cc | 15 ++++---- .../xla/xla/pjrt/tfrt_cpu_pjrt_client.h | 36 +++++++++++++++---- .../xla/xla/pjrt/tfrt_cpu_pjrt_client_test.cc | 30 ++++++++-------- .../xla/xla/python/outfeed_receiver_test.cc | 10 +++--- .../pjrt_ifrt/tfrt_cpu_client_test_lib.cc | 20 +++++------ third_party/xla/xla/python/xla.cc | 4 ++- .../xla/xla/tests/pjrt_cpu_client_registry.cc | 11 +++--- 9 files changed, 83 insertions(+), 61 deletions(-) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc index 1344ab3ccc47b1..89bff439bff6f7 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc @@ -34,9 +34,10 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { // TODO(b/263170683): cpu_device_count should be configurable after config // options can be passed to PJRT_Client_Create. - PJRT_ASSIGN_OR_RETURN( - std::unique_ptr client, - xla::GetTfrtCpuClient(/*asynchronous=*/true, /*cpu_device_count=*/4)); + xla::CpuClientOptions options; + options.cpu_device_count = 4; + PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, + xla::GetTfrtCpuClient(options)); args->client = pjrt::CreateWrapperClient(std::move(client)); return nullptr; } diff --git a/third_party/xla/xla/pjrt/pjrt_client_test_cpu.cc b/third_party/xla/xla/pjrt/pjrt_client_test_cpu.cc index 4bdd19ac29310f..59ef9ff1514472 100644 --- a/third_party/xla/xla/pjrt/pjrt_client_test_cpu.cc +++ b/third_party/xla/xla/pjrt/pjrt_client_test_cpu.cc @@ -20,11 +20,12 @@ namespace xla { namespace { // Register CPU as the backend for tests in pjrt_client_test.cc. -const bool kUnused = - (RegisterTestClientFactory([]() { - return GetTfrtCpuClient(/*asynchronous=*/true, /*cpu_device_count=*/4); - }), - true); +const bool kUnused = (RegisterTestClientFactory([]() { + CpuClientOptions options; + options.cpu_device_count = 4; + return GetTfrtCpuClient(options); + }), + true); } // namespace } // namespace xla diff --git a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc index 5f4080eba14d28..863f46ce2725af 100644 --- a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc @@ -293,23 +293,20 @@ static StatusOr>> GetTfrtCpuDevices( } StatusOr> GetTfrtCpuClient( - bool asynchronous, int cpu_device_count, - int max_inflight_computations_per_device) { + const CpuClientOptions& options) { // Need at least CpuDeviceCount threads to launch one collective. 
+ int cpu_device_count = options.cpu_device_count.value_or(CpuDeviceCount()); size_t num_threads = std::max(DefaultThreadPoolSize(), cpu_device_count); - TF_ASSIGN_OR_RETURN(std::vector> devices, - GetTfrtCpuDevices(cpu_device_count, - max_inflight_computations_per_device)); + TF_ASSIGN_OR_RETURN( + std::vector> devices, + GetTfrtCpuDevices(cpu_device_count, + options.max_inflight_computations_per_device)); return std::unique_ptr(std::make_unique( /*process_index=*/0, std::move(devices), num_threads)); } -StatusOr> GetTfrtCpuClient(bool asynchronous) { - return GetTfrtCpuClient(asynchronous, CpuDeviceCount()); -} - TfrtCpuClient::TfrtCpuClient( int process_index, std::vector> devices, size_t num_threads) diff --git a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h index 70175c021756a7..c744d10ac3a9ea 100644 --- a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h @@ -509,16 +509,38 @@ class TfrtCpuExecutable final : public PjRtLoadedExecutable { bool cheap_computation_; }; -// Creates a CPU client with one Device. For testing purposes, you can set the -// number of devices passing the --xla_force_host_platform_device_count flag to -// the XLA_FLAGS environment variable. -StatusOr> GetTfrtCpuClient(bool asynchronous); +struct CpuClientOptions { + // Does nothing at the moment. Ignored. + bool asynchronous = true; -// Similar to the function above, but you can set the number of devices and max -// number of inflight computations per device explicitly. + // Number of CPU devices. If not provided, the value of + // --xla_force_host_platform_device_count is used. + std::optional cpu_device_count = std::nullopt; + + int max_inflight_computations_per_device = 32; +}; StatusOr> GetTfrtCpuClient( + const CpuClientOptions& options); + +// Deprecated. Use the overload that takes 'options' instead. +inline StatusOr> GetTfrtCpuClient( + bool asynchronous) { + CpuClientOptions options; + options.asynchronous = asynchronous; + return GetTfrtCpuClient(options); +} + +// Deprecated. Use the overload that takes 'options' instead. 
+inline StatusOr> GetTfrtCpuClient( bool asynchronous, int cpu_device_count, - int max_inflight_computations_per_device = 32); + int max_inflight_computations_per_device = 32) { + CpuClientOptions options; + options.asynchronous = asynchronous; + options.cpu_device_count = cpu_device_count; + options.max_inflight_computations_per_device = + max_inflight_computations_per_device; + return GetTfrtCpuClient(options); +} } // namespace xla diff --git a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client_test.cc index 1b6e25b8d870e5..7503f5370e31db 100644 --- a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client_test.cc @@ -66,7 +66,7 @@ ENTRY DonationWithExecutionError() -> f32[2, 2] { ROOT %result = f32[2, 2] get-tuple-element(%custom-call), index=0 })"; - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(/*asynchronous=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions())); TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnUnverifiedModule(kProgram, {})); @@ -105,11 +105,9 @@ TEST(TfrtCpuClientTest, HloSnapshot) { ROOT add = f32[3,2] add(x, y) })"; - TF_ASSERT_OK_AND_ASSIGN( - auto client, - GetTfrtCpuClient(/*asynchronous=*/true, - /*cpu_device_count=*/1, - /*max_inflight_computations_per_device=*/32)); + CpuClientOptions cpu_options; + cpu_options.cpu_device_count = 1; + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(cpu_options)); TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnUnverifiedModule(kProgram, {})); @@ -167,7 +165,7 @@ TEST(TfrtCpuClientTest, HloSnapshot) { } TEST(TfrtCpuClientTest, AsyncTransferRawData) { - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(/*asynchronous=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions())); xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( @@ -187,7 +185,7 @@ TEST(TfrtCpuClientTest, AsyncTransferRawData) { } TEST(TfrtCpuClientTest, AsyncTransferLiteral) { - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(/*asynchronous=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions())); xla::Shape shape = xla::ShapeUtil::MakeShape(F32, {128, 256}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( @@ -203,7 +201,7 @@ TEST(TfrtCpuClientTest, AsyncTransferLiteral) { } TEST(TfrtCpuClientTest, AsyncTransferCallsOnDone) { - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(/*asynchronous=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions())); xla::Shape shape = ShapeUtil::MakeShape(F32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( @@ -221,7 +219,7 @@ TEST(TfrtCpuClientTest, AsyncTransferCallsOnDone) { } TEST(TfrtCpuClientTest, AsyncTransferNeverTransferred) { - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(/*asynchronous=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions())); xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( @@ -236,7 +234,7 @@ TEST(TfrtCpuClientTest, AsyncTransferNeverTransferred) { } TEST(TfrtCpuClientTest, AsyncTransferBufferCount) { - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(/*asynchronous=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto client, 
GetTfrtCpuClient(CpuClientOptions())); xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( @@ -249,7 +247,7 @@ TEST(TfrtCpuClientTest, AsyncTransferBufferCount) { } TEST(TfrtCpuClientTest, AsyncTransferBufferSize) { - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(/*asynchronous=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions())); xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( @@ -258,7 +256,7 @@ TEST(TfrtCpuClientTest, AsyncTransferBufferSize) { } TEST(TfrtCpuClientTest, AsyncTransferDevice) { - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(/*asynchronous=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions())); xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); auto* device = client->addressable_devices()[0]; TF_ASSERT_OK_AND_ASSIGN( @@ -268,7 +266,7 @@ TEST(TfrtCpuClientTest, AsyncTransferDevice) { } TEST(TfrtCpuClientTest, AsyncTransferSetBufferError) { - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(/*asynchronous=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions())); xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( @@ -281,7 +279,7 @@ TEST(TfrtCpuClientTest, AsyncTransferSetBufferError) { } TEST(TfrtCpuClientTest, CreateErrorBuffer) { - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(/*asynchronous=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions())); xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN( auto buffer, client->CreateErrorBuffer(InternalError("foobar"), shape, @@ -292,7 +290,7 @@ TEST(TfrtCpuClientTest, CreateErrorBuffer) { } TEST(TfrtCpuClientTest, AsyncTransferRawDataToSubBuffer) { - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(/*asynchronous=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(CpuClientOptions())); xla::Shape shape = ShapeUtil::MakeShape(U32, {3, 2}); TF_ASSERT_OK_AND_ASSIGN(auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( diff --git a/third_party/xla/xla/python/outfeed_receiver_test.cc b/third_party/xla/xla/python/outfeed_receiver_test.cc index b1a1364e999e85..62539fa6079bea 100644 --- a/third_party/xla/xla/python/outfeed_receiver_test.cc +++ b/third_party/xla/xla/python/outfeed_receiver_test.cc @@ -112,7 +112,7 @@ class Accumulator { TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) { TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, - GetTfrtCpuClient(true)); + GetTfrtCpuClient(CpuClientOptions())); std::vector clients{cpu_client.get()}; auto receiver = std::make_unique(); @@ -145,7 +145,7 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) { TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) { TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, - GetTfrtCpuClient(true)); + GetTfrtCpuClient(CpuClientOptions())); std::vector clients{cpu_client.get()}; auto receiver = std::make_unique(); @@ -190,7 +190,7 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) { TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) { TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, - GetTfrtCpuClient(true)); + GetTfrtCpuClient(CpuClientOptions())); std::vector clients{cpu_client.get()}; auto receiver = std::make_unique(); @@ -233,7 +233,7 @@ 
TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) { TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) { TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, - GetTfrtCpuClient(true)); + GetTfrtCpuClient(CpuClientOptions())); std::vector clients{cpu_client.get()}; auto receiver = std::make_unique(); @@ -267,7 +267,7 @@ TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) { TEST(OutfeedReceiverTest, InvalidConsumerIdError) { TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, - GetTfrtCpuClient(true)); + GetTfrtCpuClient(CpuClientOptions())); std::vector clients{cpu_client.get()}; auto receiver = std::make_unique(); diff --git a/third_party/xla/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc b/third_party/xla/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc index 1a485fc722d6a6..9486ec0156ca8a 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc @@ -24,16 +24,16 @@ namespace xla { namespace ifrt { namespace { -const bool kUnused = - (test_util::RegisterClientFactory( - []() -> StatusOr> { - TF_ASSIGN_OR_RETURN(auto pjrt_client, - xla::GetTfrtCpuClient(/*asynchronous=*/true, - /*cpu_device_count=*/2)); - return std::shared_ptr( - PjRtClient::Create(std::move(pjrt_client))); - }), - true); +const bool kUnused = (test_util::RegisterClientFactory( + []() -> StatusOr> { + CpuClientOptions options; + options.cpu_device_count = 2; + TF_ASSIGN_OR_RETURN(auto pjrt_client, + xla::GetTfrtCpuClient(options)); + return std::shared_ptr( + PjRtClient::Create(std::move(pjrt_client))); + }), + true); } // namespace } // namespace ifrt diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 8d09b08eb11491..06de87623a078c 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -494,8 +494,10 @@ static void Init(py::module_& m) { "get_tfrt_cpu_client", [](bool asynchronous) -> std::shared_ptr { py::gil_scoped_release gil_release; + CpuClientOptions options; + options.asynchronous = asynchronous; std::unique_ptr client = - xla::ValueOrThrow(GetTfrtCpuClient(asynchronous)); + xla::ValueOrThrow(GetTfrtCpuClient(options)); return std::make_shared( ifrt::PjRtClient::Create(std::move(client))); }, diff --git a/third_party/xla/xla/tests/pjrt_cpu_client_registry.cc b/third_party/xla/xla/tests/pjrt_cpu_client_registry.cc index 02ec05afbe3352..1f0eaa43d5819f 100644 --- a/third_party/xla/xla/tests/pjrt_cpu_client_registry.cc +++ b/third_party/xla/xla/tests/pjrt_cpu_client_registry.cc @@ -20,11 +20,12 @@ namespace xla { namespace { // Register a CPU PjRt client for tests. 
-const bool kUnused = - (RegisterPjRtClientTestFactory([]() { - return GetTfrtCpuClient(/*asynchronous=*/true, /*cpu_device_count=*/4); - }), - true); +const bool kUnused = (RegisterPjRtClientTestFactory([]() { + CpuClientOptions options; + options.cpu_device_count = 4; + return GetTfrtCpuClient(options); + }), + true); } // namespace } // namespace xla From 22541314d340038fe72cd69ce5c3b7f6c47bb10e Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sun, 12 Nov 2023 15:06:52 -0800 Subject: [PATCH 010/391] [stream_executor] Nested command buffers should not be updated PiperOrigin-RevId: 581788661 --- .../xla/xla/stream_executor/gpu/gpu_command_buffer.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index eed1620e628fc8..99037894a887df 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -277,12 +277,14 @@ tsl::Status GpuCommandBuffer::Update() { "Command buffer has to be finalized first before it can be updated"); } - // TODO(ezhulenev): Add support for updating nested command buffers. Today - // we only support updating primary command buffers as we need a non null - // executable graph. if (exec_ == nullptr) { + if (mode_ == Mode::kPrimary) + return absl::InternalError( + "Primary command buffers are expected to have executable graphs"); return absl::UnimplementedError( - "Nested command buffer update is not implemented"); + "Nested command buffer update is deliberately not implemented. One " + "should create a new nested command buffer and update the primary one " + "instead"); } VLOG(5) << "Begin primary command buffer update for executable graph " From 5093c1024151080689f846224d1aa823d57856ef Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Nov 2023 01:02:04 -0800 Subject: [PATCH 011/391] compat: Update forward compatibility horizon to 2023-11-13 PiperOrigin-RevId: 581871209 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 745c92f1bc0597..fb919e026268a9 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 12) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 13) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 168e19001b127823b125f9c9660c223bee27bfea Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Nov 2023 01:02:05 -0800 Subject: [PATCH 012/391] Update GraphDef version to 1679. PiperOrigin-RevId: 581871214 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 08e4344206b0d5..12b17cb785e13b 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. 
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1678 // Updated: 2023/11/12 +#define TF_GRAPH_DEF_VERSION 1679 // Updated: 2023/11/13 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 01b930daff1e6437f813a2ceb2ee7bdb174babd3 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 13 Nov 2023 03:57:47 -0800 Subject: [PATCH 013/391] Port exponential compile time heuristic to priority fusion. Our emitters don't deal well with certain types of slicing, leading to very inefficient code. Ideally, we would detect this and just use shared memory, but that requires either using Triton or a new type of emitter, so we just port the heuristic for now. PiperOrigin-RevId: 581908968 --- .../xla/xla/service/gpu/priority_fusion.cc | 24 +++++++++++++++++++ .../xla/xla/service/gpu/priority_fusion.h | 8 +++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index a930f18344134e..ecccc546147c04 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -433,6 +433,23 @@ FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, return fits_budget; } + // Also check that our emitter can handle the fusion node. We currently can + // have exponential time/memory requirements for emitting certain fusion + // kernels, in which case we don't want to fuse. + // TODO(b/119692968): Remove this once we have fixed our fusion emitter. + if (consumer->opcode() == HloOpcode::kFusion) { + if (fusion_node_evaluations_.find(consumer) == + fusion_node_evaluations_.end()) { + // We have no cached results for this fusion node yet. Compute it now. + fusion_node_evaluations_.emplace(consumer, + FusionNodeIndexingEvaluation(consumer)); + } + if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh( + producer)) { + return "the fusion would result in an overly large code duplication"; + } + } + return InstructionFusion::ShouldFuse(consumer, operand_index); } @@ -466,6 +483,13 @@ HloInstruction* GpuPriorityFusion::FuseInstruction( } else { result = InstructionFusion::FuseInstruction(fusion_instruction, producer); } + + // Invalidate cached values that are now invalid. + for (auto* user : fusion_instruction->users()) { + fusion_node_evaluations_.erase(user); + } + fusion_node_evaluations_.erase(fusion_instruction); + return result; } diff --git a/third_party/xla/xla/service/gpu/priority_fusion.h b/third_party/xla/xla/service/gpu/priority_fusion.h index eb729f85235c26..af24b6e9c688f9 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/priority_fusion.h @@ -19,14 +19,13 @@ limitations under the License. #include #include -#include #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/dump.h" +#include "xla/service/fusion_node_indexing_evaluation.h" #include "xla/service/fusion_queue.h" #include "xla/service/gpu/fusion_process_dump.pb.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" @@ -77,6 +76,11 @@ class GpuPriorityFusion : public InstructionFusion { // Proto with structured logs of fusion decisions. Used only for debugging. If // null, logging is disabled. 
std::unique_ptr fusion_process_dump_; + + // Keep track of the number of times each instruction inside a fusion node is + // indexed with different index vectors. + absl::flat_hash_map + fusion_node_evaluations_; }; } // namespace gpu From b3521842d2f664356fc1a0c67b64ddae8d52a3a4 Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Mon, 13 Nov 2023 07:15:40 -0800 Subject: [PATCH 014/391] Reduce Bridge visibility of components to specific targets PiperOrigin-RevId: 581952841 --- tensorflow/compiler/mlir/tf2xla/api/v2/BUILD | 40 +++++++++++--------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD index da491d18a584a2..73880851e7abc1 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD @@ -2,27 +2,11 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +# Please reach out to tf-bridge-team@ before using the TF2XLA bridge. package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ ":__subpackages__", - ":tf2xla_users", - ], -) - -# Please reach out to tf-bridge-team@ before using the TF2XLA bridge. -package_group( - name = "tf2xla_users", - packages = [ - "//learning/brain/mlir/bridge", - "//tensorflow/compiler/mlir/quantization/stablehlo/...", - "//learning/serving/contrib/tfrt/mlir/saved_model_analysis", - "//tensorflow/compiler/mlir/tfrt", - "//tensorflow/compiler/tf2xla", - "//tensorflow/compiler/mlir", - "//tensorflow/compiler/mlir/tfrt/transforms/ifrt", - # Legacy due to where the bridge currently runs. This should go away. 
- "//tensorflow/compiler/mlir/tensorflow/transforms", ], ) @@ -30,6 +14,12 @@ cc_library( name = "legalize_tf", srcs = ["legalize_tf.cc"], hdrs = ["legalize_tf.h"], + visibility = [ + "//learning/brain/google/xla:__pkg__", + "//learning/brain/mlir/bridge:__pkg__", + "//tensorflow/compiler/mlir/quantization/stablehlo:__pkg__", + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:__pkg__", + ], deps = [ ":device_type_proto_cc", "//tensorflow/compiler/jit:flags_headers", @@ -101,12 +91,22 @@ tf_proto_library( name = "device_type_proto", srcs = ["device_type.proto"], cc_api_version = 2, + visibility = [ + "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__", + ], ) cc_library( name = "cluster_tf", srcs = ["cluster_tf.cc"], hdrs = ["cluster_tf.h"], + visibility = [ + "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__", + "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", + "//tensorflow/compiler/mlir/tfrt:__pkg__", + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:__pkg__", + "//tensorflow/compiler/tf2xla:__pkg__", + ], deps = [ ":device_type_proto_cc", ":tf_dialect_to_executor", @@ -167,6 +167,12 @@ cc_library( name = "tf_dialect_to_executor", srcs = ["tf_dialect_to_executor.cc"], hdrs = ["tf_dialect_to_executor.h"], + visibility = [ + "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__", + "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", + "//tensorflow/compiler/mlir/tfrt:__pkg__", + "//tensorflow/compiler/tf2xla:__pkg__", + ], deps = [ "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", From 67f43b9687aca627b7b563c67ff44076be638ea9 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 13 Nov 2023 07:48:19 -0800 Subject: [PATCH 015/391] [PJRT] Delete unused code. * The key-value store implementation was only used by the pre-coordination service distributed service. Since that service is gone, it is now dead. * Remove an unused protocol header. * Remove a number of unused includes. 
PiperOrigin-RevId: 581961350 --- third_party/xla/xla/pjrt/distributed/BUILD | 28 ++-------- .../xla/xla/pjrt/distributed/client.cc | 10 ---- .../xla/pjrt/distributed/key_value_store.cc | 44 --------------- .../xla/pjrt/distributed/key_value_store.h | 53 ------------------- .../xla/xla/pjrt/distributed/protocol.h | 25 --------- .../xla/xla/pjrt/distributed/service.cc | 6 --- .../xla/xla/pjrt/distributed/service.h | 2 +- 7 files changed, 4 insertions(+), 164 deletions(-) delete mode 100644 third_party/xla/xla/pjrt/distributed/key_value_store.cc delete mode 100644 third_party/xla/xla/pjrt/distributed/key_value_store.h delete mode 100644 third_party/xla/xla/pjrt/distributed/protocol.h diff --git a/third_party/xla/xla/pjrt/distributed/BUILD b/third_party/xla/xla/pjrt/distributed/BUILD index 0d968f9736b5d5..875c8f08f58350 100644 --- a/third_party/xla/xla/pjrt/distributed/BUILD +++ b/third_party/xla/xla/pjrt/distributed/BUILD @@ -1,7 +1,7 @@ -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") -load("@local_tsl//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") load("//xla:xla.bzl", "xla_cc_test") +load("@local_tsl//tsl:tsl.default.bzl", "tsl_grpc_cc_dependencies") +load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") licenses(["notice"]) @@ -19,33 +19,12 @@ tf_proto_library( visibility = ["//visibility:public"], ) -cc_library( - name = "protocol", - hdrs = ["protocol.h"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "key_value_store", - srcs = ["key_value_store.cc"], - hdrs = ["key_value_store.h"], - visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - ] + tsl_grpc_cc_dependencies(), -) - cc_library( name = "service", srcs = ["service.cc"], hdrs = ["service.h"], visibility = ["//visibility:public"], deps = [ - ":key_value_store", - ":protocol", ":protocol_cc_grpc_proto", ":topology_util", ":util", @@ -91,7 +70,6 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - ":protocol", ":protocol_cc_grpc_proto", ":util", "//xla:statusor", diff --git a/third_party/xla/xla/pjrt/distributed/client.cc b/third_party/xla/xla/pjrt/distributed/client.cc index 97e4a3e44e21e5..ce63141e05577a 100644 --- a/third_party/xla/xla/pjrt/distributed/client.cc +++ b/third_party/xla/xla/pjrt/distributed/client.cc @@ -16,33 +16,23 @@ limitations under the License. 
#include "xla/pjrt/distributed/client.h" #include -#include // NOLINT #include -#include -#include #include #include #include -#include "absl/synchronization/mutex.h" -#include "absl/synchronization/notification.h" #include "absl/time/time.h" #include "grpcpp/channel.h" -#include "xla/pjrt/distributed/protocol.h" -#include "xla/pjrt/distributed/util.h" -#include "xla/util.h" #include "tsl/distributed_runtime/coordination/coordination_client.h" #include "tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "tsl/distributed_runtime/coordination/coordination_service_error_util.h" #include "tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h" #include "tsl/platform/errors.h" -#include "tsl/platform/random.h" #include "tsl/protobuf/coordination_config.pb.h" #include "tsl/protobuf/coordination_service.pb.h" namespace xla { - class DistributedRuntimeCoordinationServiceClient : public DistributedRuntimeClient { public: diff --git a/third_party/xla/xla/pjrt/distributed/key_value_store.cc b/third_party/xla/xla/pjrt/distributed/key_value_store.cc deleted file mode 100644 index 45151bdd96582f..00000000000000 --- a/third_party/xla/xla/pjrt/distributed/key_value_store.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/pjrt/distributed/key_value_store.h" - -namespace xla { - -KeyValueStore::KeyValueStore() = default; - -::grpc::Status KeyValueStore::Get(const std::string& key, - absl::Duration timeout, std::string* value) { - auto key_is_present = [&]() { - mu_.AssertHeld(); - return entries_.find(key) != entries_.end(); - }; - absl::MutexLock lock(&mu_); - // TODO(phawkins): the synchronization here is very coarse, but probably - // sufficient for its current application. - if (!mu_.AwaitWithTimeout(absl::Condition(&key_is_present), timeout)) { - return ::grpc::Status(::grpc::StatusCode::NOT_FOUND, key); - } - *value = entries_.find(key)->second; - return ::grpc::Status::OK; -} - -::grpc::Status KeyValueStore::Set(const std::string& key, std::string value) { - absl::MutexLock lock(&mu_); - entries_[key] = std::move(value); - return ::grpc::Status::OK; -} - -} // namespace xla diff --git a/third_party/xla/xla/pjrt/distributed/key_value_store.h b/third_party/xla/xla/pjrt/distributed/key_value_store.h deleted file mode 100644 index 4c906c21aa1df9..00000000000000 --- a/third_party/xla/xla/pjrt/distributed/key_value_store.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_ -#define XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_ - -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/time.h" -#include "grpcpp/grpcpp.h" - -namespace xla { - -// A simple blocking key-value store class. -class KeyValueStore { - public: - KeyValueStore(); - - KeyValueStore(const KeyValueStore&) = delete; - KeyValueStore(KeyValueStore&&) = delete; - KeyValueStore& operator=(const KeyValueStore&) = delete; - KeyValueStore&& operator=(KeyValueStore&&) = delete; - - // Looks up `key`. If present, returns its value. If the key is not present, - // waits until `timeout` expires for the key to arrive. If the key does not - // arrive by the expiry of `timeout`, returns NOT_FOUND. - ::grpc::Status Get(const std::string& key, absl::Duration timeout, - std::string* value); - - // Replaces the value of `key` with `value`. - ::grpc::Status Set(const std::string& key, std::string value); - - private: - absl::Mutex mu_; - absl::flat_hash_map entries_ ABSL_GUARDED_BY(mu_); -}; - -} // namespace xla - -#endif // XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_ diff --git a/third_party/xla/xla/pjrt/distributed/protocol.h b/third_party/xla/xla/pjrt/distributed/protocol.h deleted file mode 100644 index 3db718417e176e..00000000000000 --- a/third_party/xla/xla/pjrt/distributed/protocol.h +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PJRT_DISTRIBUTED_PROTOCOL_H_ -#define XLA_PJRT_DISTRIBUTED_PROTOCOL_H_ - -namespace xla { - -inline constexpr int DistributedRuntimeProtocolVersion() { return 3; } - -} // namespace xla - -#endif // XLA_PJRT_DISTRIBUTED_PROTOCOL_H_ diff --git a/third_party/xla/xla/pjrt/distributed/service.cc b/third_party/xla/xla/pjrt/distributed/service.cc index a156de841c2f68..947c7267bd3f30 100644 --- a/third_party/xla/xla/pjrt/distributed/service.cc +++ b/third_party/xla/xla/pjrt/distributed/service.cc @@ -22,17 +22,11 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "grpcpp/server_builder.h" -#include "xla/pjrt/distributed/protocol.h" -#include "xla/pjrt/distributed/topology_util.h" -#include "xla/pjrt/distributed/util.h" -#include "xla/status.h" #include "xla/util.h" #include "tsl/distributed_runtime/coordination/coordination_service.h" #include "tsl/distributed_runtime/rpc/async_service_interface.h" #include "tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" #include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/random.h" #include "tsl/platform/threadpool.h" #include "tsl/protobuf/coordination_config.pb.h" diff --git a/third_party/xla/xla/pjrt/distributed/service.h b/third_party/xla/xla/pjrt/distributed/service.h index 91316fd3310143..ef79e12e363759 100644 --- a/third_party/xla/xla/pjrt/distributed/service.h +++ b/third_party/xla/xla/pjrt/distributed/service.h @@ -23,8 +23,8 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "absl/time/time.h" +#include "grpcpp/grpcpp.h" #include "grpcpp/security/server_credentials.h" -#include "xla/pjrt/distributed/key_value_store.h" #include "xla/pjrt/distributed/protocol.grpc.pb.h" #include "xla/statusor.h" #include "xla/types.h" From 7e97bfddb48c0d5e91f91c787fdfef4f30f7cde6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Nov 2023 08:02:50 -0800 Subject: [PATCH 016/391] Rollback of PR #5782 Rollback change to enable cudnn fused attention by default. PiperOrigin-RevId: 581965108 --- third_party/xla/xla/debug_options_flags.cc | 2 +- third_party/xla/xla/service/gpu/BUILD | 1 - third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc | 2 +- third_party/xla/xla/service/gpu/nvptx_compiler.cc | 2 -- third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc | 2 +- 5 files changed, 3 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index e0c7e38c04c20e..e7b4805aa2a34f 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -85,7 +85,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_fast_math_honor_division(true); // TODO(AyanmoI): Remove this flag when cuDNN FMHA is fully supported. 
- opts.set_xla_gpu_enable_cudnn_fmha(true); + opts.set_xla_gpu_enable_cudnn_fmha(false); opts.set_xla_gpu_fused_attention_use_cudnn_rng(false); diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 512ef8190e18d9..1131c358999380 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3016,7 +3016,6 @@ cc_library( ":gpu_sort_rewriter", ":ir_emission_utils", ":metrics", - ":move_copy_to_users", ":target_constants", ":triangular_solve_rewriter", ":triton_autotuner", diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 69c318f701746d..148977393bc81f 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -3006,7 +3006,7 @@ ENTRY e { ->fused_instructions_computation() ->root_instruction(), GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 32}, {1, 0}), - m::Op().WithShape(BF16, {32, 40}, {1, 0})) + m::Op().WithShape(BF16, {40, 32}, {1, 0})) .WithShape(BF16, {16, 40}, {1, 0}))); } diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index 329dfd14cdfc77..c98f64520c51ad 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -59,7 +59,6 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/metrics.h" -#include "xla/service/gpu/move_copy_to_users.h" #include "xla/service/gpu/target_constants.h" #include "xla/service/gpu/triangular_solve_rewriter.h" #include "xla/service/gpu/triton_autotuner.h" @@ -226,7 +225,6 @@ Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( false); if (debug_options.xla_gpu_normalize_layouts()) { mha_fusion_pipeline.AddPass(); - mha_fusion_pipeline.AddPass>(); mha_fusion_pipeline.AddPass(); } mha_fusion_pipeline.AddPass(/*is_layout_sensitive=*/true); diff --git a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc index ba3513f89a56a2..a05f411fec5f31 100644 --- a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -934,7 +934,7 @@ ENTRY main.4 { if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { MatchOptimizedHlo(hlo_text, R"( -; CHECK: [[GEMM:%[^ ]+]] = (s32[8,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(s8[8,4]{1,0} %{{.*}}, s8[4,4]{0,1} %{{.*}}), custom_call_target="__cublas$gemm", +; CHECK: [[GEMM:%[^ ]+]] = (s32[8,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(s8[8,4]{1,0} %fusion.1, s8[4,4]{0,1} %bitcast.13), custom_call_target="__cublas$gemm", ; CHECK: backend_config={ ; CHECK-DAG: "selected_algorithm":"0" ; CHECK-DAG: "alpha_real":1 From 1f9ffe6c62b9418f6ff131d7b38ae17f3d1df006 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Nov 2023 08:25:14 -0800 Subject: [PATCH 017/391] Update header files for tensorflow.org doc generator. Add `\note` about include paths, and shorten Doxygen module names. 
PiperOrigin-RevId: 581971390 --- .../configuration/c/delegate_plugin.h | 19 +++++++++++---- .../acceleration/configuration/c/gpu_plugin.h | 23 +++++++++++++------ .../configuration/c/xnnpack_plugin.h | 23 +++++++++++++------ tensorflow/lite/core/c/c_api.h | 21 ++++++++++++----- tensorflow/lite/core/c/c_api_opaque.h | 17 +++++++++++++- tensorflow/lite/core/c/c_api_types.h | 22 ++++++++++++------ tensorflow/lite/core/c/common.h | 21 ++++++++++++----- 7 files changed, 107 insertions(+), 39 deletions(-) diff --git a/tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h b/tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h index 900a2666934186..bca5b4d912190f 100644 --- a/tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h +++ b/tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h @@ -12,18 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// NOLINTBEGIN(whitespace/line_length) // WARNING: Users of TensorFlow Lite should not include this file directly, // but should instead include // "third_party/tensorflow/lite/acceleration/configuration/c/delegate_plugin.h". -// Only the TensorFlow Lite implementation itself should include this -// file directly. -// NOLINTEND(whitespace/line_length) +// Only the TensorFlow Lite implementation itself should include this file +// directly. + #ifndef TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_DELEGATE_PLUGIN_H_ #define TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_DELEGATE_PLUGIN_H_ /// C API types for TF Lite delegate plugins. +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// +/// #include "tensorflow/lite/acceleration/configuration/c/delegate_plugin.h" +/// +/// to access the APIs documented on this page. +// NOLINTEND(whitespace/line_length) +// clang-format on + #include "tensorflow/lite/core/c/common.h" #ifdef __cplusplus @@ -32,7 +41,7 @@ extern "C" { // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup delegate_plugin tensorflow/lite/acceleration/configuration/c/delegate_plugin.h +/** \defgroup delegate_plugin lite/acceleration/configuration/c/delegate_plugin.h * @{ */ // NOLINTEND(whitespace/line_length) diff --git a/tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h b/tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h index c30ce4dcdf4452..8ca6d7dd262da4 100644 --- a/tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h +++ b/tensorflow/lite/core/acceleration/configuration/c/gpu_plugin.h @@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// NOLINTBEGIN(whitespace/line_length) -// WARNING: Users of TensorFlow Lite should not include this file directly, -// but should instead include +// WARNING: Users of TensorFlow Lite should not include this file directly, but +// should instead include // "third_party/tensorflow/lite/acceleration/configuration/c/gpu_plugin.h". -// Only the TensorFlow Lite implementation itself should include this -// file directly. 
-// NOLINTEND(whitespace/line_length) +// Only the TensorFlow Lite implementation itself should include this file +// directly. + #ifndef TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_GPU_PLUGIN_H_ #define TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_GPU_PLUGIN_H_ @@ -32,6 +31,16 @@ limitations under the License. /// /// But to provide a C API to access the GPU delegate plugin, we do expose /// some functions, which are declared below. +/// +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// +/// #include "tensorflow/lite/acceleration/configuration/c/gpu_plugin.h" +/// +/// to access the APIs documented on this page. +// NOLINTEND(whitespace/line_length) +// clang-format on #include "tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h" @@ -41,7 +50,7 @@ extern "C" { // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup gpu_plugin tensorflow/lite/acceleration/configuration/c/gpu_plugin.h +/** \defgroup gpu_plugin lite/acceleration/configuration/c/gpu_plugin.h * @{ */ // NOLINTEND(whitespace/line_length) diff --git a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h index fce48ff8622288..e17e11e07fbaac 100644 --- a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h +++ b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h @@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// NOLINTBEGIN(whitespace/line_length) -// WARNING: Users of TensorFlow Lite should not include this file directly, -// but should instead include +// WARNING: Users of TensorFlow Lite should not include this file directly, but +// should instead include // "third_party/tensorflow/lite/acceleration/configuration/c/xnnpack_plugin.h". -// Only the TensorFlow Lite implementation itself should include this -// file directly. -// NOLINTEND(whitespace/line_length) +// Only the TensorFlow Lite implementation itself should include this file +// directly. + #ifndef TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_XNNPACK_PLUGIN_H_ #define TENSORFLOW_LITE_CORE_ACCELERATION_CONFIGURATION_C_XNNPACK_PLUGIN_H_ @@ -32,6 +31,16 @@ limitations under the License. /// /// But to provide a C API to access the XNNPACK delegate plugin, we do expose /// some functions, which are declared below. +/// +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// +/// #include "tensorflow/lite/acceleration/configuration/c/xnnpack_plugin.h" +/// +/// to access the APIs documented on this page. 
+// NOLINTEND(whitespace/line_length) +// clang-format on #include "tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h" @@ -41,7 +50,7 @@ extern "C" { // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup xnnpack_plugin tensorflow/lite/acceleration/configuration/c/xnnpack_plugin.h +/** \defgroup xnnpack_plugin lite/acceleration/configuration/c/xnnpack_plugin.h * @{ */ // NOLINTEND(whitespace/line_length) diff --git a/tensorflow/lite/core/c/c_api.h b/tensorflow/lite/core/c/c_api.h index b98fddf2569744..38433325452e9f 100644 --- a/tensorflow/lite/core/c/c_api.h +++ b/tensorflow/lite/core/c/c_api.h @@ -12,11 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// \warning Note: Users of TensorFlow Lite should not include this file -// directly, but should instead include -// "third_party/tensorflow/lite/c/c_api.h". Only the TensorFlow Lite -// implementation itself should include this -// file directly. +// WARNING: Users of TensorFlow Lite should not include this file directly, but +// should instead include "third_party/tensorflow/lite/c/c_api.h". +// Only the TensorFlow Lite implementation itself should include this file +// directly. #ifndef TENSORFLOW_LITE_CORE_C_C_API_H_ #define TENSORFLOW_LITE_CORE_C_C_API_H_ @@ -76,6 +75,16 @@ limitations under the License. /// TfLiteInterpreterOptionsDelete(options); /// TfLiteModelDelete(model); /// +/// +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// +/// #include "tensorflow/lite/c/c_api.h" +/// +/// to access the APIs documented on this page. +// NOLINTEND(whitespace/line_length) +// clang-format on #ifdef __cplusplus extern "C" { @@ -83,7 +92,7 @@ extern "C" { // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup c_api tensorflow/lite/c/c_api.h +/** \defgroup c_api lite/c/c_api.h * @{ */ // NOLINTEND(whitespace/line_length) diff --git a/tensorflow/lite/core/c/c_api_opaque.h b/tensorflow/lite/core/c/c_api_opaque.h index 06bdc194b221f6..a778205ccc769d 100644 --- a/tensorflow/lite/core/c/c_api_opaque.h +++ b/tensorflow/lite/core/c/c_api_opaque.h @@ -12,6 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// WARNING: Users of TensorFlow Lite should not include this file directly, but +// should instead include "third_party/tensorflow/lite/c/c_api_opaque.h". +// Only the TensorFlow Lite implementation itself should include this file +// directly. + #ifndef TENSORFLOW_LITE_CORE_C_C_API_OPAQUE_H_ #define TENSORFLOW_LITE_CORE_C_C_API_OPAQUE_H_ @@ -36,10 +41,20 @@ extern "C" { /// potentially including non-backwards-compatible changes, on a different /// schedule than for the other TensorFlow Lite APIs. See /// https://www.tensorflow.org/guide/versions#separate_version_number_for_tensorflow_lite_extension_apis. +/// +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// +/// #include "tensorflow/lite/c/c_api_opaque.h" +/// +/// to access the APIs documented on this page. 
+// NOLINTEND(whitespace/line_length) +// clang-format on // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup c_api_opaque tensorflow/lite/c/c_api_opaque.h +/** \defgroup c_api_opaque lite/c/c_api_opaque.h * @{ */ // NOLINTEND(whitespace/line_length) diff --git a/tensorflow/lite/core/c/c_api_types.h b/tensorflow/lite/core/c/c_api_types.h index c1f0c568fcf04a..268c0cc2712ec4 100644 --- a/tensorflow/lite/core/c/c_api_types.h +++ b/tensorflow/lite/core/c/c_api_types.h @@ -12,16 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// WARNING: Users of TensorFlow Lite should not include this file directly, but +// should instead include "third_party/tensorflow/lite/c/c_api_types.h". +// Only the TensorFlow Lite implementation itself should include this file +// directly. /// This file declares types used by the pure C inference API defined in /// c_api.h, some of which are also used in the C++ and C kernel and interpreter /// APIs. - -// WARNING: Users of TensorFlow Lite should not include this file directly, -// but should instead include -// "third_party/tensorflow/lite/c/c_api_types.h". -// Only the TensorFlow Lite implementation itself should include this -// file directly. +/// +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// +/// #include "tensorflow/lite/c/c_api_types.h" +/// +/// to access the APIs documented on this page. +// NOLINTEND(whitespace/line_length) +// clang-format on // IWYU pragma: private, include "third_party/tensorflow/lite/c/c_api_types.h" @@ -36,7 +44,7 @@ extern "C" { // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup c_api_types tensorflow/lite/c/c_api_types.h +/** \defgroup c_api_types lite/c/c_api_types.h * @{ */ // NOLINTEND(whitespace/line_length) diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h index 0ebba76e948f33..62ad3fc85fb5a0 100644 --- a/tensorflow/lite/core/c/common.h +++ b/tensorflow/lite/core/c/common.h @@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// WARNING: Users of TensorFlow Lite should not include this file directly, but +// should instead include "third_party/tensorflow/lite/c/common.h". +// Only the TensorFlow Lite implementation itself should include this file +// directly. /// This file defines common C types and APIs for implementing operations, /// delegates and other constructs in TensorFlow Lite. The actual operations and @@ -32,12 +36,17 @@ limitations under the License. /// /// NOTE: The order of values in these structs are "semi-ABI stable". New values /// should be added only to the end of structs and never reordered. +/// +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// +/// #include "tensorflow/lite/c/common.h" +/// +/// to access the APIs documented on this page. +// NOLINTEND(whitespace/line_length) +// clang-format on -// WARNING: Users of TensorFlow Lite should not include this file directly, -// but should instead include -// "third_party/tensorflow/lite/c/common.h". 
-// Only the TensorFlow Lite implementation itself should include this -// file directly. // IWYU pragma: private, include "third_party/tensorflow/lite/c/common.h" #ifndef TENSORFLOW_LITE_CORE_C_COMMON_H_ @@ -56,7 +65,7 @@ extern "C" { // clang-format off // NOLINTBEGIN(whitespace/line_length) -/** \defgroup common tensorflow/lite/c/common.h +/** \defgroup common lite/c/common.h * @{ */ // NOLINTEND(whitespace/line_length) From fb74e6c67ac3e1d60e10cc895ee5220db2c38d6f Mon Sep 17 00:00:00 2001 From: Ayush Dubey Date: Mon, 13 Nov 2023 09:19:42 -0800 Subject: [PATCH 018/391] Do not mangle async error in various pjrt buffer implementations. Before this change, at various points in pjrt buffer implementations, when the async buffer becomes available as an absl::Status error, the error code is replaced by absl::StatusCode::kInternal, and the error message is also modified. This change propagates the underlying error without modification to the callback waiting on that async value. PiperOrigin-RevId: 581987519 --- third_party/xla/xla/pjrt/BUILD | 1 + third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.cc | 3 +-- .../xla/xla/pjrt/tracked_tfrt_cpu_device_buffer.cc | 9 +++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 60efb3e8671307..448d1ee463a716 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -565,6 +565,7 @@ cc_library( "//xla/runtime:cpu_event", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/concurrency:async_value", diff --git a/third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.cc b/third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.cc index b1e7be85f1c8a4..92de62a2f39476 100644 --- a/third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.cc +++ b/third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.cc @@ -435,8 +435,7 @@ PjRtFuture AbstractTfrtCpuBuffer::ToLiteralHelper( // Errors in src buffer are surfaced to user. for (const auto& av : device_buffer_wait_avs) { if (auto* error = av->GetErrorIfPresent()) { - ready_event.emplace(Internal("Error converting to literal: %s", - error->message())); + ready_event.emplace(*error); return; } } diff --git a/third_party/xla/xla/pjrt/tracked_tfrt_cpu_device_buffer.cc b/third_party/xla/xla/pjrt/tracked_tfrt_cpu_device_buffer.cc index ae7d556a142754..8895aa927f44dd 100644 --- a/third_party/xla/xla/pjrt/tracked_tfrt_cpu_device_buffer.cc +++ b/third_party/xla/xla/pjrt/tracked_tfrt_cpu_device_buffer.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include "absl/base/casts.h" +#include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "xla/runtime/cpu_event.h" #include "tsl/concurrency/async_value_ref.h" @@ -45,7 +46,7 @@ tsl::AsyncValueRef AfterAll( tsl::AsyncValueRef after_all; absl::Mutex mutex; - std::string error_message; + absl::Status error; }; auto after_all = tsl::MakeConstructedAsyncValueRef(); @@ -55,12 +56,12 @@ tsl::AsyncValueRef AfterAll( event.AndThen([state, event = event.AsPtr()]() { if (event.IsError()) { absl::MutexLock lock(&state->mutex); - state->error_message = event.GetError().message(); + state->error = event.GetError(); } if (state->count.fetch_sub(1, std::memory_order_acq_rel) == 1) { - if (!state->error_message.empty()) { - state->after_all.SetError(state->error_message); + if (!state->error.ok()) { + state->after_all.SetError(state->error); } else { state->after_all.SetStateConcrete(); } From 9d12bd26f98d66d174fde8db96243b42ea73c806 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Nov 2023 09:28:10 -0800 Subject: [PATCH 019/391] Integrate LLVM at llvm/llvm-project@b8a062061571 Updates LLVM usage to match [b8a062061571](https://github.com/llvm/llvm-project/commit/b8a062061571) PiperOrigin-RevId: 581989960 --- third_party/llvm/workspace.bzl | 4 ++-- third_party/xla/xla/service/gpu/tests/gpu_index_test.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 840c36045a28cc..5df608affbc631 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "75d6795e420274346b14aca8b6bd49bfe6030eeb" - LLVM_SHA256 = "9f8a9f28d82aef17e58755f30a926dc17a7d48a2c393d547fcf014f62a704590" + LLVM_COMMIT = "b8a062061571b7868013f1fefb891bdaa2da1adc" + LLVM_SHA256 = "b5cafaa5d5f80f25e701a5c5e37898c2b6e0f925db298190c2b694f6e328275d" tf_http_archive( name = name, diff --git a/third_party/xla/xla/service/gpu/tests/gpu_index_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_index_test.cc index 649804c4cd971e..261b2721675403 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_index_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_index_test.cc @@ -144,7 +144,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) { CompileAndVerifyIr(std::move(module), R"( ; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14 -; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64 +; CHECK: %[[idx1:.*]] = zext nneg i{{[0-9]*}} %[[urem1]] to i64 ; CHECK: getelementptr inbounds float, ptr{{( addrspace\(1\))?}} %[[alloc:.*]], i64 %[[idx1]] )", /*match_optimized_ir=*/true); From 0a692312042f8241074d5636f25f9994a00d2780 Mon Sep 17 00:00:00 2001 From: Juhyun Lee Date: Mon, 13 Nov 2023 09:30:18 -0800 Subject: [PATCH 020/391] Fix MSan issue. 
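
The test below builds a TfLiteTensor by hand and, before this change, never set
its `bytes` field; MSan presumably flagged a read of that uninitialized field
when CreateVectorCopyData sized its copy. A minimal sketch of the safer
test-setup pattern, assuming a hypothetical helper (`MakeInt32TestTensor` is not
part of this change, and the include path follows the public
`tensorflow/lite/c/common.h` convention):

```
// Zero-initialize the struct so every field the code under test may read
// (including `bytes`) starts out defined, then fill in what the test needs.
#include <cstdint>

#include "tensorflow/lite/c/common.h"

TfLiteTensor MakeInt32TestTensor(int32_t* data, int num_elements) {
  TfLiteTensor tensor{};  // value-initialization clears `bytes` and friends
  tensor.type = kTfLiteInt32;
  tensor.data.i32 = data;
  tensor.dims = TfLiteIntArrayCreate(1);
  tensor.dims->data[0] = num_elements;
  tensor.bytes = num_elements * sizeof(int32_t);
  return tensor;
}
```

The actual fix below is narrower: it only sets `bytes` explicitly in the
existing test.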
PiperOrigin-RevId: 581990543 --- .../lite/delegates/gpu/common/model_builder_helper_test.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_helper_test.cc index f6478fc5a8ae54..0b6819333d6b39 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_helper_test.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper_test.cc @@ -32,6 +32,7 @@ TEST(ModelBuilderHelperTest, CreateVectorCopyDataDifferentSize) { tflite_tensor.data.i32 = src_data; tflite_tensor.dims = TfLiteIntArrayCreate(1); tflite_tensor.dims->data[0] = 4; + tflite_tensor.bytes = 4 * sizeof(int32_t); int16_t dst[4]; ASSERT_OK(CreateVectorCopyData(tflite_tensor, dst)); From ea7f22aa0b7bede312bcdea0304a69e5793b46a2 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 13 Nov 2023 09:43:23 -0800 Subject: [PATCH 021/391] [xla:gpu] Add a custom fusion kind to HLO fusion operation #6904 `__custom_fusion` is a new kind of a custom backend supported by Fusion instruction, and it allows to map a fusion computation to a hand-written C++ kernel. Example: mapping dot fusion to cutlass gemm kernel ``` HloModule cutlass cutlass_gemm { arg0 = f32[32,64]{1,0} parameter(0) arg1 = f32[64,16]{1,0} parameter(1) ROOT dot = f32[32,16]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} } ENTRY e { arg0 = f32[32, 64]{1,0} parameter(0) arg1 = f32[64, 16]{1,0} parameter(1) ROOT _ = f32[32,16]{1,0} fusion(arg0, arg1), kind=kCustom, calls=cutlass_gemm, backend_config={kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm"}} } ``` PiperOrigin-RevId: 581994280 --- third_party/xla/xla/service/gpu/BUILD | 5 + .../xla/xla/service/gpu/backend_configs.proto | 9 ++ .../xla/service/gpu/hlo_fusion_analysis.cc | 11 +- .../xla/xla/service/gpu/hlo_fusion_analysis.h | 1 + .../xla/xla/service/gpu/ir_emission_utils.h | 5 + .../xla/service/gpu/ir_emitter_unnested.cc | 63 ++++++++ .../xla/xla/service/gpu/ir_emitter_unnested.h | 6 + .../xla/xla/service/gpu/kernel_thunk.cc | 80 ++++++++++ .../xla/xla/service/gpu/kernel_thunk.h | 58 ++++++- third_party/xla/xla/service/gpu/kernels/BUILD | 60 ++++++- .../xla/service/gpu/kernels/custom_fusion.cc | 56 +++++++ .../xla/service/gpu/kernels/custom_fusion.h | 150 ++++++++++++++++++ .../gpu/kernels/cutlass_gemm_fusion.cc | 64 ++++++++ .../gpu/kernels/cutlass_gemm_fusion_test.cc | 68 ++++++++ .../xla/xla/service/gpu/priority_fusion.cc | 1 + .../xla/xla/stream_executor/kernel_spec.h | 10 +- 16 files changed, 637 insertions(+), 10 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/kernels/custom_fusion.cc create mode 100644 third_party/xla/xla/service/gpu/kernels/custom_fusion.h create mode 100644 third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc create mode 100644 third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 1131c358999380..24b429d3a49037 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -323,6 +323,8 @@ cc_library( "//xla/service/gpu/fusions:thunk_util", "//xla/service/gpu/fusions:tiling_util", "//xla/service/gpu/fusions:transpose", + "//xla/service/gpu/kernels:custom_fusion", + "//xla/service/gpu/kernels:custom_kernel", "//xla/service/gpu/runtime3:custom_call_thunk", "//xla/service/gpu/runtime3:fft_thunk", "//xla/service/llvm_ir:buffer_assignment_util", @@ -999,6 
+1001,7 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:stream_pool", "//xla/service:xla_debug_info_manager", + "//xla/service/gpu/kernels:custom_kernel", "//xla/service/gpu/runtime:executable", "//xla/service/gpu/runtime:support", "//xla/service/gpu/runtime3:custom_call_thunk", @@ -1081,6 +1084,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", @@ -3455,6 +3459,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@llvm-project//llvm:ir_headers", "@local_tsl//tsl/platform:macros", diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index c9bdbd009f0005..7f068fb77e1f7f 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -120,6 +120,12 @@ message ReificationCost { double end_to_end_cycles = 1; // Total execution time of the reified op. } +// Backend config for a custom fusion (pre-compiled device kernel implementing a +// fusion computation). +message CustomFusionConfig { + string name = 1; +} + message FusionBackendConfig { // kLoop, kInput, or kOutput (from HloInstruction::FusionKind), or your own // custom string. @@ -136,6 +142,9 @@ message FusionBackendConfig { // present, we use the default Triton config. AutotuneResult.TritonGemmKey triton_gemm_config = 2; + // Only valid when kind == "__custom_fusion". + CustomFusionConfig custom_fusion_config = 4; + // Cost model prediction. ReificationCost reification_cost = 3; } diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index e89497a6139cc5..bb2fe734a63055 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/log/check.h" #include "absl/numeric/bits.h" +#include "absl/status/status.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -308,6 +309,10 @@ bool HloFusionAnalysis::HasConsistentTransposeHeros() const { HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() const { + if (fusion_backend_config_.kind() == kCustomFusionKind) { + return EmitterFusionKind::kCustomFusion; + } + #if GOOGLE_CUDA if (fusion_backend_config_.kind() == kTritonGemmFusionKind || fusion_backend_config_.kind() == kTritonSoftmaxFusionKind) { @@ -388,8 +393,12 @@ StatusOr HloFusionAnalysis::GetLaunchDimensions() { return CalculateLaunchDimensions(root_shape, *device_info_, {unroll_factor, /*few_waves=*/false}); } + case EmitterFusionKind::kCustomFusion: + return absl::UnimplementedError( + "GetLaunchDimensions is not implemented for custom fusions"); case EmitterFusionKind::kTriton: - return Unimplemented("GetLaunchDimensions"); + return absl::UnimplementedError( + "GetLaunchDimensions is not implemented for Triton fusions"); } } diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h index 8f6a245a6578d4..c07819db2d3a15 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h @@ -40,6 +40,7 @@ class HloFusionAnalysis { // The type of emitted fusion. enum class EmitterFusionKind { kLoop, + kCustomFusion, kTriton, kReduction, kTranspose, diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index e72068bb8209b1..6ae8a2a61df7f3 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -48,6 +49,10 @@ bool IsMatrixMultiplication(const HloInstruction& dot); inline constexpr int64_t WarpSize() { return 32; } +// Fusions that implemented with pre-compiled device kernels have +// FusionBackendConfig.kind requel to this string. +inline constexpr absl::string_view kCustomFusionKind = "__custom_fusion"; + // Fusions that use Triton have FusionBackendConfig.kind equal to this string. inline constexpr absl::string_view kTritonGemmFusionKind = "__triton_gemm"; diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 0b023ca4c66dc9..d8e2909f193737 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -117,6 +117,8 @@ limitations under the License. 
#include "xla/service/gpu/ir_emitter_nested.h" #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/kernel_thunk.h" +#include "xla/service/gpu/kernels/custom_fusion.h" +#include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/nccl_all_gather_thunk.h" @@ -318,6 +320,17 @@ StatusOr> BuildKernelThunkForFusion( /*shmem_bytes=*/0); } +StatusOr> BuildCustomKernelThunkForFusion( + IrEmitterContext& ir_emitter_context, const HloFusionInstruction* fusion, + kernel::CustomKernel custom_kernel) { + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context.buffer_assignment(), fusion)); + + return std::make_unique( + fusion, std::move(custom_kernel), std::move(kernel_arguments.args())); +} + // Derives the number of warps to use for processing a Triton Softmax fusion. int DeriveNumWarpsFromTritonSoftmaxComputation( const HloComputation* computation) { @@ -2101,6 +2114,11 @@ Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr, } case HloFusionAnalysis::EmitterFusionKind::kScatter: return EmitScatter(instr, nullptr, fusion_analysis); + case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: { + TF_ASSIGN_OR_RETURN(auto backend_config, + instr->backend_config()); + return EmitCustomFusion(instr, backend_config.custom_fusion_config()); + } default: return FailedPrecondition( "Fusion type not supported by the HLO emitter."); @@ -2165,6 +2183,11 @@ Status IrEmitterUnnested::EmitFusion( #endif LOG(FATAL) << "Unsupported fusion kind: " << backend_config.kind(); } + case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: + if (!backend_config.has_custom_fusion_config()) + return absl::InternalError( + "custom fusion is missing custom fusion config"); + return EmitCustomFusion(fusion, backend_config.custom_fusion_config()); case HloFusionAnalysis::EmitterFusionKind::kScatter: return EmitScatter(fusion, fusion_op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kInputSlices: @@ -3190,6 +3213,46 @@ Status IrEmitterUnnested::EmitScatter(const HloFusionInstruction* fusion, return OkStatus(); } +Status IrEmitterUnnested::EmitCustomFusion(const HloFusionInstruction* fusion, + const CustomFusionConfig& config) { + VLOG(3) << "Lower HLO fusion to a custom fusion " << config.name(); + + auto* registry = kernel::CustomFusionRegistry::Default(); + auto* custom_fusion = registry->Lookup(config.name()); + + // If custom fusion is not found it means that some of the build targets might + // not be statically linked into the binary. + if (custom_fusion == nullptr) { + return absl::InternalError(absl::StrCat( + "Custom fusion ", config.name(), " not found in a default registry.")); + } + + // Load custom kernels that can implement a fusion computation. + TF_ASSIGN_OR_RETURN( + std::vector kernels, + custom_fusion->LoadKernels(fusion->fused_instructions_computation())); + + // This should never happen, it means that compilation pipeline created a + // fusion operation that is not supported by a given custom fusion. + if (kernels.empty()) { + return absl::InternalError( + absl::StrCat("Custom fusion ", config.name(), + " returned empty custom kernels for a fused computation")); + } + + // TODO(ezhulenev): Add support for auto tuning to select the best kernel. 
+ if (kernels.size() != 1) { + return absl::InternalError("Expected exactly one custom kernel"); + } + + TF_ASSIGN_OR_RETURN( + auto thunk, BuildCustomKernelThunkForFusion(*ir_emitter_context_, fusion, + std::move(kernels[0]))); + + AddThunkToThunkSequence(std::move(thunk)); + return OkStatus(); +} + Status IrEmitterUnnested::EmitOp( mlir::Operation* op, const absl::flat_hash_map& diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h index 12f64b9db2241c..3fc2255488b55a 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h @@ -43,6 +43,7 @@ limitations under the License. #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/statusor.h" namespace xla { @@ -356,6 +357,11 @@ class IrEmitterUnnested : public IrEmitter { mlir::lmhlo::FusionOp fusion_op, HloFusionAnalysis& fusion_analysis); + // Emits kernel thunk for a custom fusion implemented with hand written custom + // device kernels. + Status EmitCustomFusion(const HloFusionInstruction* fusion, + const CustomFusionConfig& config); + // Builds a kernel thunk for a non-fusion operation, without reuse. // // All input and output tensors of `op` are passed to the kernel. diff --git a/third_party/xla/xla/service/gpu/kernel_thunk.cc b/third_party/xla/xla/service/gpu/kernel_thunk.cc index f02b51fd5408e6..1dfb672ef65c8e 100644 --- a/third_party/xla/xla/service/gpu/kernel_thunk.cc +++ b/third_party/xla/xla/service/gpu/kernel_thunk.cc @@ -22,21 +22,32 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/buffer_assignment.h" #include "xla/service/gpu/kernel_arguments.h" +#include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/thunk.h" +#include "xla/status.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" namespace xla { namespace gpu { namespace { +//===----------------------------------------------------------------------===// +// KernelThunk +//===----------------------------------------------------------------------===// + mlir::Value RemoveTransformingOperations(mlir::Value value) { mlir::Operation* defining_op = value.getDefiningOp(); if (auto cast_op = llvm::isa kernel_arguments) + : Thunk(Kind::kKernel, Thunk::ThunkInfo::WithProfileAnnotation(instr)), + custom_kernel_(std::move(custom_kernel)) { + args_.reserve(kernel_arguments.size()); + written_.reserve(kernel_arguments.size()); + for (const auto& kernel_argument : kernel_arguments) { + if (!kernel_argument.first_with_same_slice().has_value()) { + args_.push_back(kernel_argument.slice()); + written_.push_back(kernel_argument.written()); + } + } +} + +std::string CustomKernelThunk::ToStringExtra(int indent) const { + // TODO(ezhulenev): Add `name` to a custom kernel and add pretty printing for + // custom kernel launch dimensions. 
+ return absl::StrFormat(", kernel = %s, launch dimensions = %s", "", + ""); +} + +Status CustomKernelThunk::Initialize(se::StreamExecutor* executor, + ExecutableSource src) { + absl::MutexLock lock(&mutex_); + + auto it = kernel_cache_.find(executor); + if (kernel_cache_.end() == it) { + auto kernel = std::make_unique(executor); + TF_RETURN_IF_ERROR( + executor->GetKernel(custom_kernel_.kernel_spec(), kernel.get())); + kernel_cache_.emplace(executor, std::move(kernel)); + } + + return OkStatus(); +} + +Status CustomKernelThunk::ExecuteOnStream(const ExecuteParams& params) { + se::StreamExecutor* executor = params.stream->parent(); + + const se::Kernel* kernel = [&] { + absl::MutexLock lock(&mutex_); + return kernel_cache_[executor].get(); + }(); + + VLOG(3) << "Launching " << kernel->name(); + + absl::InlinedVector buffer_args; + for (const BufferAllocation::Slice& arg : args_) { + se::DeviceMemoryBase buf = params.buffer_allocations->GetDeviceAddress(arg); + VLOG(3) << " Arg: alloc #" << arg.index() << ", offset: " << arg.offset() + << ": " << buf.opaque() << " (" << buf.size() << "B)"; + buffer_args.push_back(buf); + } + + if (VLOG_IS_ON(100)) { + PrintBufferContents(params.stream, buffer_args); + } + + se::KernelArgsDeviceMemoryArray args(buffer_args, + custom_kernel_.shared_memory_bytes()); + return executor->Launch(params.stream, custom_kernel_.thread_dims(), + custom_kernel_.block_dims(), *kernel, args); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/kernel_thunk.h b/third_party/xla/xla/service/gpu/kernel_thunk.h index 24dd5d05d89501..be458cbcc494ea 100644 --- a/third_party/xla/xla/service/gpu/kernel_thunk.h +++ b/third_party/xla/xla/service/gpu/kernel_thunk.h @@ -16,11 +16,15 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_KERNEL_THUNK_H_ #define XLA_SERVICE_GPU_KERNEL_THUNK_H_ +#include #include #include +#include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -28,16 +32,29 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/kernel_arguments.h" +#include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/thunk.h" +#include "xla/status.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" +#include "xla/types.h" // IWYU pragma: keep namespace xla { namespace gpu { class GpuExecutable; +// TODO(ezhulenev): Unify KernelThunk and CustomKernelThunk as they are very +// similar. XLA:GPU should use more of kernel loading APIs provided by +// StreamExecutor out of the box and less custom kernel loading solutions. +// +// Today KernelThunk is required for lowering to XLA runtime, and +// CustomKernelThunk is only supported for thunk execution. + +//===----------------------------------------------------------------------===// +// KernelThunk +//===----------------------------------------------------------------------===// + // This class stores everything that StreamExecutor needs for launching a // kernel. It implements the ExecuteOnStream interface for GpuExecutable to // invoke the corresponding kernel. @@ -104,10 +121,45 @@ class KernelThunk : public Thunk { // mlir::Value(s) corresponding to the buffer slice arguments. 
std::vector values_; + // Loaded kernels for each `StreamExecutor`. mutable absl::Mutex mutex_; + absl::flat_hash_map> + kernel_cache_ ABSL_GUARDED_BY(mutex_); +}; + +//===----------------------------------------------------------------------===// +// CustomKernelThunk +//===----------------------------------------------------------------------===// - // Loaded kernels for each `StreamExecutor`. Requires pointer stability of - // values. +// CustomKernelThunk loads and executes kernels defined by a custom kernel +// (which in practice means hand written CUDA C++ kernel), instead of a kernel +// compiled by XLA and loaded from an executable source. +class CustomKernelThunk : public Thunk { + public: + CustomKernelThunk(const HloInstruction* instr, + kernel::CustomKernel custom_kernel, + absl::Span kernel_arguments); + + std::string ToStringExtra(int indent) const override; + + Status Initialize(se::StreamExecutor* executor, + ExecutableSource src) override; + Status ExecuteOnStream(const ExecuteParams& params) override; + + private: + // Buffer slices passed to the kernel as arguments. + std::vector args_; + + // args_[i] is written iff (written_[i] == true). + std::vector written_; + + // mlir::Value(s) corresponding to the buffer slice arguments. + std::vector values_; + + kernel::CustomKernel custom_kernel_; + + // Loaded kernels for each `StreamExecutor`. + mutable absl::Mutex mutex_; absl::flat_hash_map> kernel_cache_ ABSL_GUARDED_BY(mutex_); }; diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 3f96561227402a..43edb6503fd86c 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -4,10 +4,35 @@ load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configu package( default_visibility = ["//visibility:public"], - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) +package_group( + name = "friends", + includes = ["//xla:friends"], +) + +cc_library( + name = "custom_fusion", + srcs = ["custom_fusion.cc"], + hdrs = ["custom_fusion.h"], + visibility = ["//visibility:public"], + deps = [ + ":custom_kernel", + "//xla:status", + "//xla:statusor", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "custom_kernel", srcs = ["custom_kernel.cc"], @@ -24,6 +49,39 @@ cc_library( # copybara:uncomment_begin(google-only) # # TODO(ezhulenev): We currently do not have a CUTLASS dependency in open source BUILD. 
# +# cc_library( +# name = "cutlass_gemm_fusion", +# srcs = ["cutlass_gemm_fusion.cc"], +# deps = [ +# ":custom_fusion", +# ":custom_kernel", +# ":cutlass_gemm_kernel", +# "@com_google_absl//absl/status", +# "//xla:statusor", +# "//xla:xla_data_proto_cc", +# "//xla/hlo/ir:hlo", +# "@local_tsl//tsl/platform:errors", +# "@local_tsl//tsl/platform:logging", +# "@local_tsl//tsl/platform:statusor", +# ], +# alwayslink = 1, # static fusion registration +# ) +# +# xla_test( +# name = "cutlass_gemm_fusion_test", +# srcs = ["cutlass_gemm_fusion_test.cc"], +# backends = ["gpu"], +# deps = [ +# ":cutlass_gemm_fusion", +# "@com_google_absl//absl/strings", +# "//xla:debug_options_flags", +# "//xla:error_spec", +# "//xla/tests:hlo_test_base", +# "@local_tsl//tsl/platform:test", +# "@local_tsl//tsl/platform:test_main", +# ], +# ) +# # cuda_library( # name = "cutlass_gemm_kernel", # srcs = ["cutlass_gemm_kernel.cu.cc"], diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion.cc b/third_party/xla/xla/service/gpu/kernels/custom_fusion.cc new file mode 100644 index 00000000000000..02e7b60c6dbe10 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion.cc @@ -0,0 +1,56 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/custom_fusion.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "xla/status.h" + +namespace xla::gpu::kernel { + +//===----------------------------------------------------------------------===// +// CustomFusionRegistry +//===----------------------------------------------------------------------===// + +CustomFusionRegistry* CustomFusionRegistry::Default() { + static auto* registry = new CustomFusionRegistry(); + return registry; +} + +Status CustomFusionRegistry::Register(std::string name, + std::unique_ptr fusion) { + absl::MutexLock lock(&mutex_); + if (auto it = registry_.try_emplace(name, std::move(fusion)); it.second) + return OkStatus(); + return absl::InternalError( + absl::StrCat("Custom fusion ", name, " already registered.")); +} + +CustomFusion* CustomFusionRegistry::Lookup(std::string_view name) const { + absl::MutexLock lock(&mutex_); + if (auto it = registry_.find(name); it != registry_.end()) + return it->second.get(); + return nullptr; +} + +} // namespace xla::gpu::kernel diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion.h b/third_party/xla/xla/service/gpu/kernels/custom_fusion.h new file mode 100644 index 00000000000000..9110e20ddf5e2f --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion.h @@ -0,0 +1,150 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_H_
+#define XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_H_
+
+#include
+#include
+#include
+#include
+
+#include "absl/base/attributes.h"
+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/synchronization/mutex.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/service/gpu/kernels/custom_fusion.h"
+#include "xla/service/gpu/kernels/custom_kernel.h"
+#include "xla/status.h"
+#include "xla/statusor.h"
+#include "tsl/platform/logging.h"
+
+namespace xla::gpu::kernel {
+
+//===----------------------------------------------------------------------===//
+// CustomFusion
+//===----------------------------------------------------------------------===//
+
+// Custom fusion is a mechanism for registering custom kernels corresponding to
+// HLO fusions.
+//
+// Example: row-major mixed dtype gemm with fused bitcast
+//
+//   %gemm (parameter_0: s8[19,17], parameter_1: f16[15,19]) -> f16[15,17] {
+//     %parameter_1 = f16[15,19]{1,0} parameter(1)
+//     %parameter_0 = s8[19,17]{1,0} parameter(0)
+//     %cp1.1 = f16[19,17]{1,0} convert(%parameter_0)
+//     ROOT %r.1 = f16[15,17]{1,0} dot(%parameter_1, %cp1.1),
+//       lhs_contracting_dims={1},
+//       rhs_contracting_dims={0}
+//   }
+//
+//   ENTRY %e (p0: f16[15,19], p1: s8[19,17]) -> f16[15,17] {
+//     %p1 = s8[19,17]{1,0} parameter(1)
+//     %p0 = f16[15,19]{1,0} parameter(0)
+//     ROOT %gemm = f16[15,17]{1,0} fusion(%p1, %p0), kind=kCustom,
+//
+//   }
+//
+// XLA:GPU has multiple strategies for executing this fusion on device:
+//
+// (1) cuBLAS library call: a lot of simple gemm operations are supported by
+//     cuBLAS out of the box. However, some combinations of parameter casting
+//     and epilogue fusion are not supported, which means that XLA has to form
+//     smaller fusions or use code generation to compile a device kernel.
+//
+// (2) Triton: XLA:GPU uses Triton to codegen gemm fusions into device kernels
+//     (PTX and CUBIN for NVIDIA GPUs).
+//
+// (3) Custom fusion is another mechanism to execute a fusion on device, which
+//     relies on pre-compiled libraries of custom kernels authored by CUDA C++
+//     experts. A custom fusion implements one particular fusion pattern (e.g.
+//     type casting plus a dot operation like in the example above) with custom
+//     kernels that XLA has to choose from at run time based on auto tuning.
+//
+//     In practice a custom fusion is almost always implemented with multiple
+//     kernels, because input shapes are not known at compile time, and a
+//     custom fusion has multiple kernels with different tiling schemes.
+//
+// What differentiates custom fusions from custom calls is that a custom fusion
+// should be implemented with a device kernel, and this allows XLA:GPU to treat
+// a custom fusion just like any other device kernel: it's launched as a regular
+// KernelThunk and automatically captured into command buffers.
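+//
+// For a concrete illustration (sketched from the CUTLASS gemm fusion added in
+// this change, see cutlass_gemm_fusion.cc below): an implementation subclasses
+// CustomFusion, is registered under a name, and fusion instructions select it
+// by that name through their backend config:
+//
+//   class CutlassGemmFusion : public CustomFusion {
+//    public:
+//     StatusOr<std::vector<CustomKernel>> LoadKernels(
+//         const HloComputation* computation) const final;  // match & load
+//   };
+//
+//   XLA_REGISTER_CUSTOM_FUSION("cutlass_gemm", CutlassGemmFusion);
+//
+//   // fusion(...), kind=kCustom, backend_config={kind: "__custom_fusion",
+//   //   custom_fusion_config: {"name":"cutlass_gemm"}}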
+//
+// Custom calls (registered with XLA:FFI), on the other hand, give much more
+// flexibility, and can be implemented as a combination of non-trivial host
+// side code plus multiple kernel launches or library calls.
+//
+// Also, XLA:FFI offers a stable C API that allows registering external
+// functions loaded from dynamic libraries compiled with a different toolchain
+// or XLA version. Custom fusion integration relies on the C++ ABI and static
+// linking.
+//
+// TODO(ezhulenev): It should be possible to lower `stablehlo.custom_call`
+// operations to custom fusions, albeit with a static linking restriction.
+class CustomFusion {
+ public:
+  virtual ~CustomFusion() = default;
+
+  // Loads kernels implementing `hlo_computation`.
+  virtual StatusOr> LoadKernels(
+      const HloComputation* computation) const = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// CustomFusionRegistry
+//===----------------------------------------------------------------------===//
+
+// Custom fusion registry is a mapping from a custom fusion name to the custom
+// fusion implementation, and the XLA compiler uses this registry to lower
+// fusion operations to kernels when emitting thunks.
+class CustomFusionRegistry {
+ public:
+  // Returns a pointer to a default custom fusion registry, which is a global
+  // static registry.
+  static CustomFusionRegistry* Default();
+
+  // Registers a custom fusion in the registry. Returns an error if a fusion
+  // with the given name is already registered.
+  Status Register(std::string name, std::unique_ptr fusion);
+
+  // Looks up a custom fusion by name. Returns nullptr if it's not found.
+  CustomFusion* Lookup(std::string_view name) const;
+
+ private:
+  mutable absl::Mutex mutex_;
+  absl::flat_hash_map> registry_
+      ABSL_GUARDED_BY(mutex_);
+};
+
+}  // namespace xla::gpu::kernel
+
+#define XLA_REGISTER_CUSTOM_FUSION(NAME, FUSION) \
+  XLA_REGISTER_CUSTOM_FUSION_(NAME, FUSION, __COUNTER__)
+
+#define XLA_REGISTER_CUSTOM_FUSION_(NAME, FUSION, N) \
+  XLA_REGISTER_CUSTOM_FUSION__(NAME, FUSION, N)
+
+#define XLA_REGISTER_CUSTOM_FUSION__(NAME, FUSION, N) \
+  ABSL_ATTRIBUTE_UNUSED static const bool \
+      xla_custom_fusion_##N##_registered_ = [] { \
+        ::xla::Status status = \
+            ::xla::gpu::kernel::CustomFusionRegistry::Default()->Register( \
+                NAME, std::make_unique()); \
+        if (!status.ok()) LOG(ERROR) << status; \
+        return status.ok(); \
+      }()
+
+#endif  // XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_H_
diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc
new file mode 100644
index 00000000000000..ce8e43208f1a09
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc
@@ -0,0 +1,64 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include +#include +#include + +#include "absl/status/status.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/kernels/custom_fusion.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/cutlass_gemm_kernel.h" +#include "xla/statusor.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu::kernel { + +class CutlassGemmFusion : public CustomFusion { + public: + StatusOr> LoadKernels( + const HloComputation* computation) const final { + // TODO(ezhulenev): This is the most basic check to pass a single test we + // have today. Expand it to properly check all invariants of a dot + // instruction supported by CUTLASS gemm kernels. + auto* dot = DynCast(computation->root_instruction()); + if (dot == nullptr) + return absl::InternalError( + "cutlass_gemm requires ROOT operation to be a dot"); + + PrimitiveType dtype = dot->shape().element_type(); + if (dtype != PrimitiveType::F32) + return absl::InternalError("Unsupported element type"); + + auto& lhs_shape = dot->operand(0)->shape(); + auto& rhs_shape = dot->operand(1)->shape(); + + size_t m = lhs_shape.dimensions(0); + size_t k = lhs_shape.dimensions(1); + size_t n = rhs_shape.dimensions(1); + + TF_ASSIGN_OR_RETURN(auto kernel, GetCutlassGemmKernel(dtype, m, n, k)); + return std::vector{std::move(kernel)}; + } +}; + +} // namespace xla::gpu::kernel + +XLA_REGISTER_CUSTOM_FUSION("cutlass_gemm", + ::xla::gpu::kernel::CutlassGemmFusion); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc new file mode 100644 index 00000000000000..afb5ac568983a8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -0,0 +1,68 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/debug_options_flags.h" +#include "xla/error_spec.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla::gpu::kernel { + +class CutlassFusionTest : public HloTestBase { + // Custom fusions are not supported by XLA runtime. 
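+  // (CustomKernelThunk currently runs only on the thunk execution path, so
+  // the override below turns off the XLA runtime executable for this test.)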
+ DebugOptions GetDebugOptionsForTest() override { + auto debug_options = GetDebugOptionsFromFlags(); + debug_options.set_xla_gpu_enable_xla_runtime_executable(false); + return debug_options; + } +}; + +TEST_F(CutlassFusionTest, SimpleF32Gemm) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_text_cublas = R"( + HloModule cublas + + ENTRY e { + arg0 = f32[32, 64]{1,0} parameter(0) + arg1 = f32[64, 16]{1,0} parameter(1) + gemm = (f32[32,16]{1,0}, s8[0]{0}) custom-call(arg0, arg1), + custom_call_target="__cublas$gemm", + backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + ROOT get-tuple-element = f32[32,16]{1,0} get-tuple-element((f32[32,16]{1,0}, s8[0]{0}) gemm), index=0 + })"; + + const char* hlo_text_custom_fusion = R"( + HloModule cutlass + + cutlass_gemm { + arg0 = f32[32,64]{1,0} parameter(0) + arg1 = f32[64,16]{1,0} parameter(1) + ROOT dot = f32[32,16]{1,0} dot(arg0, arg1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY e { + arg0 = f32[32, 64]{1,0} parameter(0) + arg1 = f32[64, 16]{1,0} parameter(1) + ROOT _ = f32[32,16]{1,0} fusion(arg0, arg1), kind=kCustom, calls=cutlass_gemm, + backend_config={kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm"}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion, + error_spec, /*run_hlo_passes=*/false)); +} + +} // namespace xla::gpu::kernel diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index ecccc546147c04..069514a435fa9d 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -466,6 +466,7 @@ HloInstruction::FusionKind GpuPriorityFusion::ChooseKind( case HloFusionAnalysis::EmitterFusionKind::kLoop: return HloInstruction::FusionKind::kLoop; case HloFusionAnalysis::EmitterFusionKind::kTriton: + case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: return HloInstruction::FusionKind::kCustom; case HloFusionAnalysis::EmitterFusionKind::kReduction: case HloFusionAnalysis::EmitterFusionKind::kTranspose: diff --git a/third_party/xla/xla/stream_executor/kernel_spec.h b/third_party/xla/xla/stream_executor/kernel_spec.h index 5e5bfe89d66353..7e7aef9d3887a2 100644 --- a/third_party/xla/xla/stream_executor/kernel_spec.h +++ b/third_party/xla/xla/stream_executor/kernel_spec.h @@ -333,15 +333,15 @@ class MultiKernelLoaderSpec { } private: - std::unique_ptr + std::shared_ptr in_process_symbol_; // In process symbol pointer. - std::unique_ptr + std::shared_ptr cuda_ptx_on_disk_; // PTX text that resides in a file. - std::unique_ptr + std::shared_ptr cuda_cubin_on_disk_; // Binary CUDA program in a file. - std::unique_ptr + std::shared_ptr cuda_cubin_in_memory_; // Binary CUDA program in memory. - std::unique_ptr + std::shared_ptr cuda_ptx_in_memory_; // PTX text that resides in memory. // Number of parameters that the kernel takes. 
(This is nicer to have in a From c7df906a310e30826a377d469477672b68e11d76 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Mon, 13 Nov 2023 10:10:52 -0800 Subject: [PATCH 022/391] Update tensorflow/lite/python/lite.py --- tensorflow/lite/python/lite.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 1ef88ab4c2ffe2..e4cf86e5227525 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -1087,8 +1087,6 @@ def _validate_inference_input_output_types(self, quant_mode): all_types = default_types + [_dtypes.int16] else: all_types = default_types + [_dtypes.int8, _dtypes.uint8, _dtypes.int16] - if (self.inference_input_type not in all_types or - self.inference_output_type not in all_types): all_types = default_types + [_dtypes.int8, _dtypes.uint8] if ( self.inference_input_type not in all_types From a354ed36ca567c864b62c4987417399cefd93052 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Mon, 13 Nov 2023 10:11:26 -0800 Subject: [PATCH 023/391] Update tensorflow/lite/python/lite.py --- tensorflow/lite/python/lite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index e4cf86e5227525..5ec6f6de6bda57 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -1087,7 +1087,6 @@ def _validate_inference_input_output_types(self, quant_mode): all_types = default_types + [_dtypes.int16] else: all_types = default_types + [_dtypes.int8, _dtypes.uint8, _dtypes.int16] - all_types = default_types + [_dtypes.int8, _dtypes.uint8] if ( self.inference_input_type not in all_types or self.inference_output_type not in all_types From ee3787df25a3a90f8b13696abf2343bfbe5cbd43 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 13 Nov 2023 10:08:06 -0800 Subject: [PATCH 024/391] Remove Windows and Darwinn functionality from cuda_configure.bzl. Remove @cub_archive (shipped with CUDA toolkit since v11). 
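For context on the @cub_archive note above: from CUDA 11 onwards the CUB headers ship
with the toolkit itself, so translation units can include them without the external
archive. A minimal, illustrative sketch only (not part of this patch; assumes an nvcc
from a CUDA 11+ toolkit):

  // cub_sum.cu -- uses toolkit-provided CUB; no @cub_archive needed on CUDA 11+.
  #include <cub/cub.cuh>
  #include <cuda_runtime.h>
  #include <cstdio>

  int main() {
    const int n = 4;
    float h_in[n] = {1.f, 2.f, 3.f, 4.f};
    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    // CUB's two-phase pattern: query scratch size, allocate, then reduce.
    void* d_temp = nullptr;
    size_t temp_bytes = 0;
    cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);
    cudaMalloc(&d_temp, temp_bytes);
    cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);

    float h_out = 0.f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("sum = %f\n", h_out);  // expect 10
    cudaFree(d_temp);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
  }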
PiperOrigin-RevId: 582002499 --- tensorflow/opensource_only.files | 2 + tensorflow/tools/lib_package/BUILD | 2 + tensorflow/tools/pip_package/BUILD | 1 + tensorflow/workspace2.bzl | 10 + third_party/gpus/check_cuda_libs.py | 7 +- .../windows/msvc_wrapper_for_nvcc.py.tpl | 256 +++++++++++++ third_party/gpus/cuda/BUILD.tpl | 42 +-- third_party/gpus/cuda/BUILD.windows.tpl | 230 ++++++++++++ third_party/gpus/cuda_configure.bzl | 338 ++++++++++++++++-- third_party/gpus/find_cuda_config.py | 46 ++- .../xla/third_party/tsl/opensource_only.files | 2 + .../tsl/third_party/gpus/check_cuda_libs.py | 7 +- .../windows/msvc_wrapper_for_nvcc.py.tpl | 256 +++++++++++++ .../tsl/third_party/gpus/cuda/BUILD.tpl | 42 +-- .../third_party/gpus/cuda/BUILD.windows.tpl | 230 ++++++++++++ .../tsl/third_party/gpus/cuda_configure.bzl | 338 ++++++++++++++++-- .../tsl/third_party/gpus/find_cuda_config.py | 46 ++- .../tsl/platform/default/build_config/BUILD | 14 +- 18 files changed, 1747 insertions(+), 122 deletions(-) create mode 100644 third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl create mode 100644 third_party/gpus/cuda/BUILD.windows.tpl create mode 100644 third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl create mode 100644 third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 68f19f8d488a63..606141b99f24a4 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -236,7 +236,9 @@ tf_staging/third_party/gpus/crosstool/BUILD: tf_staging/third_party/gpus/crosstool/LICENSE: tf_staging/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl: tf_staging/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl: +tf_staging/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl: tf_staging/third_party/gpus/cuda/BUILD.tpl: +tf_staging/third_party/gpus/cuda/BUILD.windows.tpl: tf_staging/third_party/gpus/cuda/BUILD: tf_staging/third_party/gpus/cuda/LICENSE: tf_staging/third_party/gpus/cuda/build_defs.bzl.tpl: diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 0a712456f4e609..513b271be55508 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -165,6 +165,7 @@ genrule( ], "//conditions:default": [], }) + if_cuda([ + "@cub_archive//:LICENSE.TXT", "@local_config_nccl//:LICENSE", ]) + if_mkl([ "//third_party/mkl_dnn:LICENSE", @@ -207,6 +208,7 @@ genrule( ], "//conditions:default": [], }) + if_cuda([ + "@cub_archive//:LICENSE.TXT", "@local_config_nccl//:LICENSE", ]) + if_mkl([ "//third_party/mkl_dnn:LICENSE", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index b926de8e53952a..8b83ce23ab5ef6 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -237,6 +237,7 @@ filegroup( ], "//conditions:default": [], }) + if_cuda([ + "@cub_archive//:LICENSE.TXT", "@local_config_nccl//:LICENSE", ]) + if_mkl([ "//third_party/mkl_dnn:LICENSE", diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 65074788800b78..0cc7b5e1c5aae2 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -589,6 +589,16 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/google/pprof/archive/83db2b799d1f74c40857232cb5eb4c60379fe6c2.tar.gz"), ) + # The CUDA 11 toolkit ships with CUB. 
We should be able to delete this rule + # once TF drops support for CUDA 10. + tf_http_archive( + name = "cub_archive", + build_file = "//third_party:cub.BUILD", + sha256 = "162514b3cc264ac89d91898b58450190b8192e2af1142cf8ccac2d59aa160dda", + strip_prefix = "cub-1.9.9", + urls = tf_mirror_urls("https://github.com/NVlabs/cub/archive/1.9.9.zip"), + ) + tf_http_archive( name = "nvtx_archive", build_file = "//third_party:nvtx.BUILD", diff --git a/third_party/gpus/check_cuda_libs.py b/third_party/gpus/check_cuda_libs.py index b7d98ef2581157..afd6380b0ac203 100644 --- a/third_party/gpus/check_cuda_libs.py +++ b/third_party/gpus/check_cuda_libs.py @@ -23,6 +23,7 @@ """ import os import os.path +import platform import subprocess import sys @@ -38,6 +39,10 @@ class ConfigError(Exception): pass +def _is_windows(): + return platform.system() == "Windows" + + def check_cuda_lib(path, check_soname=True): """Tests if a library exists on disk and whether its soname matches the filename. @@ -52,7 +57,7 @@ def check_cuda_lib(path, check_soname=True): if not os.path.isfile(path): raise ConfigError("No library found under: " + path) objdump = which("objdump") - if check_soname and objdump is not None: + if check_soname and objdump is not None and not _is_windows(): # Decode is necessary as in py3 the return type changed from str to bytes output = subprocess.check_output([objdump, "-p", path]).decode("utf-8") output = [line for line in output.splitlines() if "SONAME" in line] diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl new file mode 100644 index 00000000000000..c46e09484fdfad --- /dev/null +++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl @@ -0,0 +1,256 @@ +#!/usr/bin/env python +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Crosstool wrapper for compiling CUDA programs with nvcc on Windows. + +DESCRIPTION: + This script is the Windows version of //third_party/gpus/crosstool/crosstool_wrapper_is_not_gcc +""" + +from argparse import ArgumentParser +import os +import subprocess +import re +import sys +import tempfile + +# Template values set by cuda_autoconf. +CPU_COMPILER = ('%{cpu_compiler}') +GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}') + +NVCC_PATH = '%{nvcc_path}' +NVCC_VERSION = '%{cuda_version}' +NVCC_TEMP_DIR = "%{nvcc_tmp_dir}" + +def Log(s): + print('gpus/crosstool: {0}'.format(s)) + + +def GetOptionValue(argv, option): + """Extract the list of values for option from options. + + Args: + option: The option whose value to extract. + + Returns: + 1. A list of values, either directly following the option, + (eg., /opt val1 val2) or values collected from multiple occurrences of + the option (eg., /opt val1 /opt val2). + 2. The leftover options. 
+ """ + + parser = ArgumentParser(prefix_chars='-/') + parser.add_argument(option, nargs='*', action='append') + option = option.lstrip('-/').replace('-', '_') + args, leftover = parser.parse_known_args(argv) + if args and vars(args)[option]: + return (sum(vars(args)[option], []), leftover) + return ([], leftover) + +def _update_options(nvcc_options): + if NVCC_VERSION in ("7.0",): + return nvcc_options + + update_options = { "relaxed-constexpr" : "expt-relaxed-constexpr" } + return [ update_options[opt] if opt in update_options else opt + for opt in nvcc_options ] + +def GetNvccOptions(argv): + """Collect the -nvcc_options values from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + + Returns: + 1. The string that can be passed directly to nvcc. + 2. The leftover options. + """ + + parser = ArgumentParser() + parser.add_argument('-nvcc_options', nargs='*', action='append') + + args, leftover = parser.parse_known_args(argv) + + if args.nvcc_options: + options = _update_options(sum(args.nvcc_options, [])) + return (['--' + a for a in options], leftover) + return ([], leftover) + + +def InvokeNvcc(argv, log=False): + """Call nvcc with arguments assembled from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + log: True if logging is requested. + + Returns: + The return value of calling os.system('nvcc ' + args) + """ + + src_files = [f for f in argv if + re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] + if len(src_files) == 0: + raise Error('No source files found for cuda compilation.') + + out_file = [ f for f in argv if f.startswith('/Fo') ] + if len(out_file) != 1: + raise Error('Please specify exactly one output file for cuda compilation.') + out = ['-o', out_file[0][len('/Fo'):]] + + nvcc_compiler_options, argv = GetNvccOptions(argv) + + opt_option, argv = GetOptionValue(argv, '/O') + opt = ['-g'] + if (len(opt_option) > 0 and opt_option[0] != 'd'): + opt = ['-O2'] + + include_options, argv = GetOptionValue(argv, '/I') + includes = ["-I " + include for include in include_options] + + defines, argv = GetOptionValue(argv, '/D') + defines = [ + '-D' + define + for define in defines + if 'BAZEL_CURRENT_REPOSITORY' not in define + ] + + undefines, argv = GetOptionValue(argv, '/U') + undefines = ['-U' + define for define in undefines] + + fatbin_options, argv = GetOptionValue(argv, '-Xcuda-fatbinary') + fatbin_options = ['--fatbin-options=' + option for option in fatbin_options] + + # The rest of the unrecognized options should be passed to host compiler + host_compiler_options = [option for option in argv if option not in (src_files + out_file)] + + m_options = ["-m64"] + + nvccopts = ['-D_FORCE_INLINES'] + compute_capabilities, argv = GetOptionValue(argv, "--cuda-gpu-arch") + for capability in compute_capabilities: + capability = capability[len('sm_'):] + nvccopts += [ + r'-gencode=arch=compute_%s,"code=sm_%s"' % (capability, capability) + ] + compute_capabilities, argv = GetOptionValue(argv, '--cuda-include-ptx') + for capability in compute_capabilities: + capability = capability[len('sm_'):] + nvccopts += [ + r'-gencode=arch=compute_%s,"code=compute_%s"' % (capability, capability) + ] + _, argv = GetOptionValue(argv, '--no-cuda-include-ptx') + + # nvcc doesn't respect the INCLUDE and LIB env vars from MSVC, + # so we explicity specify the system include paths and library search paths. 
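+  # For example (illustrative paths only), INCLUDE="C:\VS\include;C:\SDK\include"
+  # is forwarded to nvcc as --system-include="C:\VS\include"
+  # --system-include="C:\SDK\include", and each LIB entry becomes a
+  # --library-path option.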
+ if 'INCLUDE' in os.environ: + nvccopts += [('--system-include="%s"' % p) for p in os.environ['INCLUDE'].split(";")] + if 'LIB' in os.environ: + nvccopts += [('--library-path="%s"' % p) for p in os.environ['LIB'].split(";")] + + nvccopts += nvcc_compiler_options + nvccopts += undefines + nvccopts += defines + nvccopts += m_options + nvccopts += fatbin_options + nvccopts += ['--compiler-options=' + ",".join(host_compiler_options)] + nvccopts += ['-x', 'cu'] + opt + includes + out + ['-c'] + src_files + # Specify a unique temp directory for nvcc to generate intermediate files, + # then Bazel can ignore files under NVCC_TEMP_DIR during dependency check + # http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver + # Different actions are sharing NVCC_TEMP_DIR, so we cannot remove it if the directory already exists. + if os.path.isfile(NVCC_TEMP_DIR): + os.remove(NVCC_TEMP_DIR) + if not os.path.exists(NVCC_TEMP_DIR): + os.makedirs(NVCC_TEMP_DIR) + # Provide a unique dir for each compiling action to avoid conflicts. + tempdir = tempfile.mkdtemp(dir = NVCC_TEMP_DIR) + nvccopts += ['--keep', '--keep-dir', tempdir] + # Force C++17 dialect (note, everything in just one string!) + nvccopts += ['--std c++17'] + if log: + Log([NVCC_PATH] + nvccopts) + + # Store command line options in a file to avoid hitting the character limit. + optsfile = tempfile.NamedTemporaryFile(mode='w', dir=tempdir, delete=False) + optsfile.write("\n".join(nvccopts)) + optsfile.close() + + proc = subprocess.Popen([NVCC_PATH, "--options-file", optsfile.name], + stdout=sys.stdout, + stderr=sys.stderr, + env=os.environ.copy(), + shell=True) + proc.wait() + return proc.returncode + +def ExpandParamsFileForArgv(): + new_argv = [] + for arg in sys.argv: + if arg.startswith("@"): + with open(arg.strip("@")) as f: + new_argv.extend([l.strip() for l in f.readlines()]) + else: + new_argv.append(arg) + + sys.argv = new_argv + +def ProcessFlagForCommandFile(flag): + if flag.startswith("/D") or flag.startswith("-D"): + # We need to re-escape /DFOO="BAR" as /DFOO=\"BAR\", so that we get + # `#define FOO "BAR"` after expansion as a string literal define + if flag.endswith('"') and not flag.endswith('\\"'): + flag = '\\"'.join(flag.split('"', 1)) + flag = '\\"'.join(flag.rsplit('"', 1)) + return flag + return flag + +def main(): + ExpandParamsFileForArgv() + parser = ArgumentParser() + parser.add_argument('-x', nargs=1) + parser.add_argument('--cuda_log', action='store_true') + args, leftover = parser.parse_known_args(sys.argv[1:]) + + if args.x and args.x[0] == 'cuda': + if args.cuda_log: Log('-x cuda') + if args.cuda_log: Log('using nvcc') + return InvokeNvcc(leftover, log=args.cuda_log) + + # Strip our flags before passing through to the CPU compiler for files which + # are not -x cuda. We can't just pass 'leftover' because it also strips -x. + # We not only want to pass -x to the CPU compiler, but also keep it in its + # relative location in the argv list (the compiler is actually sensitive to + # this). + cpu_compiler_flags = [flag for flag in sys.argv[1:] + if not flag.startswith(('--cuda_log')) + and not flag.startswith(('-nvcc_options'))] + output = [flag for flag in cpu_compiler_flags if flag.startswith("/Fo")] + + # Store command line options in a file to avoid hitting the character limit. 
+ if len(output) == 1: + commandfile_path = output[0][3:] + ".msvc_params" + commandfile = open(commandfile_path, "w") + cpu_compiler_flags = [ProcessFlagForCommandFile(flag) for flag in cpu_compiler_flags] + commandfile.write("\n".join(cpu_compiler_flags)) + commandfile.close() + return subprocess.call([CPU_COMPILER, "@" + commandfile_path]) + else: + return subprocess.call([CPU_COMPILER] + cpu_compiler_flags) + return subprocess.call([CPU_COMPILER] + cpu_compiler_flags) + +if __name__ == '__main__': + sys.exit(main()) diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index 700e040a88eeca..90a18b90de048c 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -61,23 +61,23 @@ cuda_header_library( cc_library( name = "cudart_static", - srcs = ["cuda/lib/libcudart_static.a"], + srcs = ["cuda/lib/%{cudart_static_lib}"], linkopts = [ "-ldl", - "-lrt", "-lpthread", + %{cudart_static_linkopt} ], ) cc_library( name = "cuda_driver", - srcs = ["cuda/lib/libcuda.so"], + srcs = ["cuda/lib/%{cuda_driver_lib}"], ) cc_library( name = "cudart", - srcs = glob(["cuda/lib/libcudart.so.*"]), - data = glob(["cuda/lib/libcudart.so.*"]), + srcs = ["cuda/lib/%{cudart_lib}"], + data = ["cuda/lib/%{cudart_lib}"], linkstatic = 1, ) @@ -128,30 +128,30 @@ cuda_header_library( cc_library( name = "cublas", - srcs = glob(["cuda/lib/libcublas.so.*"]), - data = glob(["cuda/lib/libcublas.so.*"]), + srcs = ["cuda/lib/%{cublas_lib}"], + data = ["cuda/lib/%{cublas_lib}"], linkstatic = 1, ) cc_library( name = "cublasLt", - srcs = glob(["cuda/lib/libcublasLt.so.*"]), - data = glob(["cuda/lib/libcublasLt.so.*"]), + srcs = ["cuda/lib/%{cublasLt_lib}"], + data = ["cuda/lib/%{cublasLt_lib}"], linkstatic = 1, ) cc_library( name = "cusolver", - srcs = glob(["cuda/lib/libcusolver.so.*"]), - data = glob(["cuda/lib/libcusolver.so.*"]), + srcs = ["cuda/lib/%{cusolver_lib}"], + data = ["cuda/lib/%{cusolver_lib}"], linkopts = ["-lgomp"], linkstatic = 1, ) cc_library( name = "cudnn", - srcs = glob(["cuda/lib/libcudnn.so.*"]), - data = glob(["cuda/lib/libcudnn.so.*"]), + srcs = ["cuda/lib/%{cudnn_lib}"], + data = ["cuda/lib/%{cudnn_lib}"], linkstatic = 1, ) @@ -165,15 +165,15 @@ cc_library( cc_library( name = "cufft", - srcs = glob(["cuda/lib/libcufft.so.*"]), - data = glob(["cuda/lib/libcufft.so.*"]), + srcs = ["cuda/lib/%{cufft_lib}"], + data = ["cuda/lib/%{cufft_lib}"], linkstatic = 1, ) cc_library( name = "curand", - srcs = glob(["cuda/lib/libcurand.so.*"]), - data = glob(["cuda/lib/libcurand.so.*"]), + srcs = ["cuda/lib/%{curand_lib}"], + data = ["cuda/lib/%{curand_lib}"], linkstatic = 1, ) @@ -192,7 +192,7 @@ cc_library( alias( name = "cub_headers", - actual = ":cuda_headers", + actual = "%{cub_actual}", ) cuda_header_library( @@ -213,13 +213,13 @@ cuda_header_library( cc_library( name = "cupti_dsos", - data = glob(["cuda/lib/libcupti.so.*"]), + data = ["cuda/lib/%{cupti_lib}"], ) cc_library( name = "cusparse", - srcs = glob(["cuda/lib/libcusparse.so.*"]), - data = glob(["cuda/lib/libcusparse.so.*"]), + srcs = ["cuda/lib/%{cusparse_lib}"], + data = ["cuda/lib/%{cusparse_lib}"], linkopts = ["-lgomp"], linkstatic = 1, ) diff --git a/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/gpus/cuda/BUILD.windows.tpl new file mode 100644 index 00000000000000..f20ecbd654bf6f --- /dev/null +++ b/third_party/gpus/cuda/BUILD.windows.tpl @@ -0,0 +1,230 @@ +load(":build_defs.bzl", "cuda_header_library") +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") 
+load("@bazel_skylib//lib:selects.bzl", "selects") + +licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +# Config setting whether TensorFlow is built with CUDA support using clang. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_clang. +selects.config_setting_group( + name = "using_clang", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_clang", + ], +) + +# Config setting whether TensorFlow is built with CUDA support using nvcc. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_nvcc. +selects.config_setting_group( + name = "using_nvcc", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_nvcc", + ], +) + +# Equivalent to using_clang && -c opt. +selects.config_setting_group( + name = "using_clang_opt", + match_all = [ + ":using_clang", + ":_opt", + ], +) + +config_setting( + name = "_opt", + values = {"compilation_mode": "opt"}, +) + +# Provides CUDA headers for '#include "third_party/gpus/cuda/include/cuda.h"' +# All clients including TensorFlow should use these directives. +cuda_header_library( + name = "cuda_headers", + hdrs = [ + "cuda/cuda_config.h", + ":cuda-include", + ], + include_prefix = "third_party/gpus", + includes = [ + ".", # required to include cuda/cuda/cuda_config.h as cuda/config.h + "cuda/include", + ], +) + +cc_import( + name = "cudart_static", + # /WHOLEARCHIVE:cudart_static.lib will cause a + # "Internal error during CImplib::EmitThunk" error. + # Treat this library as interface library to avoid being whole archived when + # linking a DLL that depends on this. + # TODO(pcloudy): Remove this rule after b/111278841 is resolved. 
+ interface_library = "cuda/lib/%{cudart_static_lib}", + system_provided = 1, +) + +cc_import( + name = "cuda_driver", + interface_library = "cuda/lib/%{cuda_driver_lib}", + system_provided = 1, +) + +cc_import( + name = "cudart", + interface_library = "cuda/lib/%{cudart_lib}", + system_provided = 1, +) + +cuda_header_library( + name = "cublas_headers", + hdrs = [":cublas-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cublas/include"], + strip_include_prefix = "cublas/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "cusolver_headers", + hdrs = [":cusolver-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cusolver/include"], + strip_include_prefix = "cusolver/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "cufft_headers", + hdrs = [":cufft-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cufft/include"], + strip_include_prefix = "cufft/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "cusparse_headers", + hdrs = [":cusparse-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cusparse/include"], + strip_include_prefix = "cusparse/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "curand_headers", + hdrs = [":curand-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["curand/include"], + strip_include_prefix = "curand/include", + deps = [":cuda_headers"], +) + +cc_import( + name = "cublas", + interface_library = "cuda/lib/%{cublas_lib}", + system_provided = 1, +) + +cc_import( + name = "cublasLt", + interface_library = "cuda/lib/%{cublasLt_lib}", + system_provided = 1, +) + +cc_import( + name = "cusolver", + interface_library = "cuda/lib/%{cusolver_lib}", + system_provided = 1, +) + +cc_import( + name = "cudnn", + interface_library = "cuda/lib/%{cudnn_lib}", + system_provided = 1, +) + +cc_library( + name = "cudnn_header", + hdrs = [":cudnn-include"], + include_prefix = "third_party/gpus/cudnn", + strip_include_prefix = "cudnn/include", + deps = [":cuda_headers"], +) + +cc_import( + name = "cufft", + interface_library = "cuda/lib/%{cufft_lib}", + system_provided = 1, +) + +cc_import( + name = "curand", + interface_library = "cuda/lib/%{curand_lib}", + system_provided = 1, +) + +cc_library( + name = "cuda", + deps = [ + ":cublas", + ":cublasLt", + ":cuda_headers", + ":cudart", + ":cudnn", + ":cufft", + ":curand", + ], +) + +alias( + name = "cub_headers", + actual = "%{cub_actual}", +) + +cuda_header_library( + name = "cupti_headers", + hdrs = [":cuda-extras"], + include_prefix = "third_party/gpus", + includes = ["cuda/extras/CUPTI/include/"], + deps = [":cuda_headers"], +) + +cc_import( + name = "cupti_dsos", + interface_library = "cuda/lib/%{cupti_lib}", + system_provided = 1, +) + +cc_import( + name = "cusparse", + interface_library = "cuda/lib/%{cusparse_lib}", + system_provided = 1, +) + +cc_library( + name = "libdevice_root", + data = [":cuda-nvvm"], +) + +bzl_library( + name = "build_defs_bzl", + srcs = ["build_defs.bzl"], + deps = [ + "@bazel_skylib//lib:selects", + ], +) + +py_library( + name = "cuda_config_py", + srcs = ["cuda/cuda_config.py"], +) + +%{copy_rules} diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 4a10e9b2aa74a8..e73e41a0c383a2 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -27,14 +27,27 @@ """ 
load("//third_party/clang_toolchain:download_clang.bzl", "download_clang") +load( + "@bazel_tools//tools/cpp:lib_cc_configure.bzl", + "escape_string", + "get_env_var", +) +load( + "@bazel_tools//tools/cpp:windows_cc_configure.bzl", + "find_msvc_tool", + "find_vc_path", + "setup_vc_env_vars", +) load( "//third_party/remote_config:common.bzl", "config_repo_label", "err_out", "execute", "get_bash_bin", + "get_cpu_value", "get_host_environ", "get_python_bin", + "is_windows", "raw_exec", "read_dir", "realpath", @@ -83,7 +96,16 @@ def verify_build_defines(params): "host_compiler_warnings", "linker_bin_path", "compiler_deps", + "msvc_cl_path", + "msvc_env_include", + "msvc_env_lib", + "msvc_env_path", + "msvc_env_tmp", + "msvc_lib_path", + "msvc_link_path", + "msvc_ml_path", "unfiltered_compile_flags", + "win_compiler_deps", ]: if ("%{" + param + "}") not in params: missing.append(param) @@ -97,11 +119,102 @@ def verify_build_defines(params): ".", ) +def _get_nvcc_tmp_dir_for_windows(repository_ctx): + """Return the Windows tmp directory for nvcc to generate intermediate source files.""" + escaped_tmp_dir = escape_string( + get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace( + "\\", + "\\\\", + ), + ) + return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir" + +def _get_msvc_compiler(repository_ctx): + vc_path = find_vc_path(repository_ctx) + return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/") + +def _get_win_cuda_defines(repository_ctx): + """Return CROSSTOOL defines for Windows""" + + # If we are not on Windows, return fake vaules for Windows specific fields. + # This ensures the CROSSTOOL file parser is happy. + if not is_windows(repository_ctx): + return { + "%{msvc_env_tmp}": "msvc_not_used", + "%{msvc_env_path}": "msvc_not_used", + "%{msvc_env_include}": "msvc_not_used", + "%{msvc_env_lib}": "msvc_not_used", + "%{msvc_cl_path}": "msvc_not_used", + "%{msvc_ml_path}": "msvc_not_used", + "%{msvc_link_path}": "msvc_not_used", + "%{msvc_lib_path}": "msvc_not_used", + } + + vc_path = find_vc_path(repository_ctx) + if not vc_path: + auto_configure_fail( + "Visual C++ build tools not found on your machine." 
+ + "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using", + ) + return {} + + env = setup_vc_env_vars(repository_ctx, vc_path) + escaped_paths = escape_string(env["PATH"]) + escaped_include_paths = escape_string(env["INCLUDE"]) + escaped_lib_paths = escape_string(env["LIB"]) + escaped_tmp_dir = escape_string( + get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace( + "\\", + "\\\\", + ), + ) + + msvc_cl_path = "windows/msvc_wrapper_for_nvcc.bat" + msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace( + "\\", + "/", + ) + msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace( + "\\", + "/", + ) + msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace( + "\\", + "/", + ) + + # nvcc will generate some temporary source files under %{nvcc_tmp_dir} + # The generated files are guaranteed to have unique name, so they can share + # the same tmp directory + escaped_cxx_include_directories = [ + _get_nvcc_tmp_dir_for_windows(repository_ctx), + "C:\\\\botcode\\\\w", + ] + for path in escaped_include_paths.split(";"): + if path: + escaped_cxx_include_directories.append(path) + + return { + "%{msvc_env_tmp}": escaped_tmp_dir, + "%{msvc_env_path}": escaped_paths, + "%{msvc_env_include}": escaped_include_paths, + "%{msvc_env_lib}": escaped_lib_paths, + "%{msvc_cl_path}": msvc_cl_path, + "%{msvc_ml_path}": msvc_ml_path, + "%{msvc_link_path}": msvc_link_path, + "%{msvc_lib_path}": msvc_lib_path, + "%{cxx_builtin_include_directories}": to_list_of_strings( + escaped_cxx_include_directories, + ), + } + # TODO(dzc): Once these functions have been factored out of Bazel's # cc_configure.bzl, load them from @bazel_tools instead. # BEGIN cc_configure common functions. def find_cc(repository_ctx, use_cuda_clang): """Find the C++ compiler.""" + if is_windows(repository_ctx): + return _get_msvc_compiler(repository_ctx) if use_cuda_clang: target_cc_name = "clang" @@ -252,9 +365,10 @@ def _cuda_include_path(repository_ctx, cuda_config): Returns: A list of the gcc host compiler include directories. """ - nvcc_path = repository_ctx.path( - "%s/bin/nvcc" % cuda_config.cuda_toolkit_path, - ) + nvcc_path = repository_ctx.path("%s/bin/nvcc%s" % ( + cuda_config.cuda_toolkit_path, + ".exe" if cuda_config.cpu_value == "Windows" else "", + )) # The expected exit code of this command is non-zero. Bazel remote execution # only caches commands with zero exit code. So force a zero exit code. @@ -315,6 +429,10 @@ def matches_version(environ_version, detected_version): return False return True +_NVCC_VERSION_PREFIX = "Cuda compilation tools, release " + +_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR" + def compute_capabilities(repository_ctx): """Returns a list of strings representing cuda compute capabilities. @@ -357,11 +475,12 @@ def compute_capabilities(repository_ctx): return capabilities -def lib_name(base_name, version = None, static = False): +def lib_name(base_name, cpu_value, version = None, static = False): """Constructs the platform-specific name of a library. Args: base_name: The name of the library, such as "cudart" + cpu_value: The name of the host operating system. version: The version of the library. static: True the library is static or False if it is a shared object. @@ -369,20 +488,29 @@ def lib_name(base_name, version = None, static = False): The platform-specific name of the library. """ version = "" if not version else "." 
+ version - if static: - return "lib%s.a" % base_name - return "lib%s.so%s" % (base_name, version) + if cpu_value in ("Linux", "FreeBSD"): + if static: + return "lib%s.a" % base_name + return "lib%s.so%s" % (base_name, version) + elif cpu_value == "Windows": + return "%s.lib" % base_name + elif cpu_value == "Darwin": + if static: + return "lib%s.a" % base_name + return "lib%s%s.dylib" % (base_name, version) + else: + auto_configure_fail("Invalid cpu_value: %s" % cpu_value) -def _lib_path(lib, basedir, version, static): - file_name = lib_name(lib, version, static) +def _lib_path(lib, cpu_value, basedir, version, static): + file_name = lib_name(lib, cpu_value, version, static) return "%s/%s" % (basedir, file_name) def _should_check_soname(version, static): return version and not static -def _check_cuda_lib_params(lib, basedir, version, static = False): +def _check_cuda_lib_params(lib, cpu_value, basedir, version, static = False): return ( - _lib_path(lib, basedir, version, static), + _lib_path(lib, cpu_value, basedir, version, static), _should_check_soname(version, static), ) @@ -402,6 +530,8 @@ def _check_cuda_libs(repository_ctx, script_path, libs): all_paths = [path for path, _ in libs] checked_paths = execute(repository_ctx, [python_bin, "-c", cmd]).stdout.splitlines() + # Filter out empty lines from splitting on '\r\n' on Windows + checked_paths = [path for path in checked_paths if len(path) > 0] if all_paths != checked_paths: auto_configure_fail("Error with installed CUDA libs. Expected '%s'. Actual '%s'." % (all_paths, checked_paths)) @@ -419,62 +549,86 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config): Returns: Map of library names to structs of filename and path. """ + cpu_value = cuda_config.cpu_value + stub_dir = "" if is_windows(repository_ctx) else "/stubs" + check_cuda_libs_params = { "cuda": _check_cuda_lib_params( "cuda", - cuda_config.config["cuda_library_dir"] + "/stubs", + cpu_value, + cuda_config.config["cuda_library_dir"] + stub_dir, version = None, + static = False, ), "cudart": _check_cuda_lib_params( "cudart", + cpu_value, cuda_config.config["cuda_library_dir"], cuda_config.cudart_version, + static = False, ), "cudart_static": _check_cuda_lib_params( "cudart_static", + cpu_value, cuda_config.config["cuda_library_dir"], cuda_config.cudart_version, static = True, ), "cublas": _check_cuda_lib_params( "cublas", + cpu_value, cuda_config.config["cublas_library_dir"], cuda_config.cublas_version, + static = False, ), "cublasLt": _check_cuda_lib_params( "cublasLt", + cpu_value, cuda_config.config["cublas_library_dir"], cuda_config.cublas_version, + static = False, ), "cusolver": _check_cuda_lib_params( "cusolver", + cpu_value, cuda_config.config["cusolver_library_dir"], cuda_config.cusolver_version, + static = False, ), "curand": _check_cuda_lib_params( "curand", + cpu_value, cuda_config.config["curand_library_dir"], cuda_config.curand_version, + static = False, ), "cufft": _check_cuda_lib_params( "cufft", + cpu_value, cuda_config.config["cufft_library_dir"], cuda_config.cufft_version, + static = False, ), "cudnn": _check_cuda_lib_params( "cudnn", + cpu_value, cuda_config.config["cudnn_library_dir"], cuda_config.cudnn_version, + static = False, ), "cupti": _check_cuda_lib_params( "cupti", + cpu_value, cuda_config.config["cupti_library_dir"], cuda_config.cupti_version, + static = False, ), "cusparse": _check_cuda_lib_params( "cusparse", + cpu_value, cuda_config.config["cusparse_library_dir"], cuda_config.cusparse_version, + static = False, ), } @@ -484,6 
+638,10 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config): paths = {filename: v[0] for (filename, v) in check_cuda_libs_params.items()} return paths +def _cudart_static_linkopt(cpu_value): + """Returns additional platform-specific linkopts for cudart.""" + return "" if cpu_value == "Darwin" else "\"-lrt\"," + # TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl, # and nccl_configure.bzl. def find_cuda_config(repository_ctx, cuda_libraries): @@ -510,34 +668,37 @@ def _get_cuda_config(repository_ctx): cudart_version: The CUDA runtime version on the system. cudnn_version: The version of cuDNN on the system. compute_capabilities: A list of the system's CUDA compute capabilities. + cpu_value: The name of the host operating system. """ config = find_cuda_config(repository_ctx, ["cuda", "cudnn"]) + cpu_value = get_cpu_value(repository_ctx) toolkit_path = config["cuda_toolkit_path"] + is_windows = cpu_value == "Windows" cuda_version = config["cuda_version"].split(".") cuda_major = cuda_version[0] cuda_minor = cuda_version[1] - cuda_version = "%s.%s" % (cuda_major, cuda_minor) - cudnn_version = "%s" % config["cudnn_version"] + cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor) + cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"] if int(cuda_major) >= 11: # The libcudart soname in CUDA 11.x is versioned as 11.0 for backward compatability. if int(cuda_major) == 11: - cudart_version = "11.0" + cudart_version = "64_110" if is_windows else "11.0" cupti_version = cuda_version else: - cudart_version = "%s" % cuda_major + cudart_version = ("64_%s" if is_windows else "%s") % cuda_major cupti_version = cudart_version - cublas_version = "%s" % config["cublas_version"].split(".")[0] - cusolver_version = "%s" % config["cusolver_version"].split(".")[0] - curand_version = "%s" % config["curand_version"].split(".")[0] - cufft_version = "%s" % config["cufft_version"].split(".")[0] - cusparse_version = "%s" % config["cusparse_version"].split(".")[0] + cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0] + cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0] + curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0] + cufft_version = ("64_%s" if is_windows else "%s") % config["cufft_version"].split(".")[0] + cusparse_version = ("64_%s" if is_windows else "%s") % config["cusparse_version"].split(".")[0] elif (int(cuda_major), int(cuda_minor)) >= (10, 1): # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc. # It changed from 'x.y' to just 'x' in CUDA 10.1. - cuda_lib_version = "%s" % cuda_major + cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major cudart_version = cuda_version cupti_version = cuda_version cublas_version = cuda_lib_version @@ -567,6 +728,7 @@ def _get_cuda_config(repository_ctx): cusparse_version = cusparse_version, cudnn_version = cudnn_version, compute_capabilities = compute_capabilities(repository_ctx), + cpu_value = cpu_value, config = config, ) @@ -612,6 +774,8 @@ error_gpu_disabled() """ def _create_dummy_repository(repository_ctx): + cpu_value = get_cpu_value(repository_ctx) + # Set up BUILD file for cuda/. 
_tpl( repository_ctx, @@ -626,6 +790,23 @@ def _create_dummy_repository(repository_ctx): repository_ctx, "cuda:BUILD", { + "%{cuda_driver_lib}": lib_name("cuda", cpu_value), + "%{cudart_static_lib}": lib_name( + "cudart_static", + cpu_value, + static = True, + ), + "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), + "%{cudart_lib}": lib_name("cudart", cpu_value), + "%{cublas_lib}": lib_name("cublas", cpu_value), + "%{cublasLt_lib}": lib_name("cublasLt", cpu_value), + "%{cusolver_lib}": lib_name("cusolver", cpu_value), + "%{cudnn_lib}": lib_name("cudnn", cpu_value), + "%{cufft_lib}": lib_name("cufft", cpu_value), + "%{curand_lib}": lib_name("curand", cpu_value), + "%{cupti_lib}": lib_name("cupti", cpu_value), + "%{cusparse_lib}": lib_name("cusparse", cpu_value), + "%{cub_actual}": ":cuda_headers", "%{copy_rules}": """ filegroup(name="cuda-include") filegroup(name="cublas-include") @@ -644,9 +825,20 @@ filegroup(name="cudnn-include") repository_ctx.file("cuda/cuda/include/cublas.h") repository_ctx.file("cuda/cuda/include/cudnn.h") repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h") - repository_ctx.file("cuda/cuda/lib/libcuda.so") - repository_ctx.file("cuda/cuda/lib/libcudart_static.a") repository_ctx.file("cuda/cuda/nvml/include/nvml.h") + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cuda", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudart", cpu_value)) + repository_ctx.file( + "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value), + ) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cufft", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cupti", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusparse", cpu_value)) # Set up cuda_config.h, which is used by # tensorflow/compiler/xla/stream_executor/dso_loader.cc. @@ -710,7 +902,7 @@ def make_copy_files_rule(repository_ctx, name, srcs, outs): cmd = \"""%s \""", )""" % (name, "\n".join(outs), " && \\\n".join(cmds)) -def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir): +def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir, exceptions = None): """Returns a rule to recursively copy a directory. If exceptions is not None, it must be a list of files or directories in 'src_dir'; these will be excluded from copying. @@ -718,18 +910,27 @@ def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir): src_dir = _norm_path(src_dir) out_dir = _norm_path(out_dir) outs = read_dir(repository_ctx, src_dir) + post_cmd = "" + if exceptions != None: + outs = [x for x in outs if not any([ + x.startswith(src_dir + "/" + y) + for y in exceptions + ])] outs = [(' "%s",' % out.replace(src_dir, out_dir)) for out in outs] # '@D' already contains the relative path for a single file, see # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)" + if exceptions != None: + for x in exceptions: + post_cmd += " ; rm -fR " + out_dir + "/" + x return """genrule( name = "%s", outs = [ %s ], - cmd = \"""cp -rLf "%s/." 
"%s/" \""", -)""" % (name, "\n".join(outs), src_dir, out_dir) + cmd = \"""cp -rLf "%s/." "%s/" %s\""", +)""" % (name, "\n".join(outs), src_dir, out_dir, post_cmd) def _flag_enabled(repository_ctx, flag_name): return get_host_environ(repository_ctx, flag_name) == "1" @@ -758,6 +959,22 @@ def _compute_cuda_extra_copts(repository_ctx, compute_capabilities): def _tpl_path(repository_ctx, filename): return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % filename)) +def _basename(repository_ctx, path_str): + """Returns the basename of a path of type string. + + This method is different from path.basename in that it also works if + the host platform is different from the execution platform + i.e. linux -> windows. + """ + + num_chars = len(path_str) + is_win = is_windows(repository_ctx) + for i in range(num_chars): + r_i = num_chars - 1 - i + if (is_win and path_str[r_i] == "\\") or path_str[r_i] == "/": + return path_str[r_i + 1:] + return path_str + def _create_local_cuda_repository(repository_ctx): """Creates the repository containing files set up to build with CUDA.""" @@ -766,14 +983,15 @@ def _create_local_cuda_repository(repository_ctx): # can easily lead to a O(n^2) runtime in the number of labels. # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778 tpl_paths = {filename: _tpl_path(repository_ctx, filename) for filename in [ - "cuda:BUILD", "cuda:build_defs.bzl", "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc", + "crosstool:windows/msvc_wrapper_for_nvcc.py", "crosstool:BUILD", "crosstool:cc_toolchain_config.bzl", "cuda:cuda_config.h", "cuda:cuda_config.py", ]} + tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD") cuda_config = _get_cuda_config(repository_ctx) @@ -885,7 +1103,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_lib_outs = [] for path in cuda_libs.values(): cuda_lib_srcs.append(path) - cuda_lib_outs.append("cuda/lib/" + path.rpartition("/")[-1]) + cuda_lib_outs.append("cuda/lib/" + _basename(repository_ctx, path)) copy_rules.append(make_copy_files_rule( repository_ctx, name = "cuda-lib", @@ -894,7 +1112,11 @@ def _create_local_cuda_repository(repository_ctx): )) # copy files mentioned in third_party/nccl/build_defs.bzl.tpl - bin_files = ["crt/link.stub", "bin2c", "fatbinary", "nvlink", "nvprune"] + file_ext = ".exe" if is_windows(repository_ctx) else "" + bin_files = ( + ["crt/link.stub"] + + [f + file_ext for f in ["bin2c", "fatbinary", "nvlink", "nvprune"]] + ) copy_rules.append(make_copy_files_rule( repository_ctx, name = "cuda-bin", @@ -902,7 +1124,7 @@ def _create_local_cuda_repository(repository_ctx): outs = ["cuda/bin/" + f for f in bin_files], )) - # Select the headers based on the cuDNN version. + # Select the headers based on the cuDNN version (strip '64_' for Windows). 
cudnn_headers = ["cudnn.h"] if cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "8": cudnn_headers += [ @@ -943,10 +1165,27 @@ def _create_local_cuda_repository(repository_ctx): }, ) + cub_actual = "@cub_archive//:cub" + if int(cuda_config.cuda_version_major) >= 11: + cub_actual = ":cuda_headers" + repository_ctx.template( "cuda/BUILD", tpl_paths["cuda:BUILD"], { + "%{cuda_driver_lib}": _basename(repository_ctx, cuda_libs["cuda"]), + "%{cudart_static_lib}": _basename(repository_ctx, cuda_libs["cudart_static"]), + "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value), + "%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]), + "%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]), + "%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]), + "%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]), + "%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]), + "%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]), + "%{curand_lib}": _basename(repository_ctx, cuda_libs["curand"]), + "%{cupti_lib}": _basename(repository_ctx, cuda_libs["cupti"]), + "%{cusparse_lib}": _basename(repository_ctx, cuda_libs["cusparse"]), + "%{cub_actual}": cub_actual, "%{copy_rules}": "\n".join(copy_rules), }, ) @@ -1009,10 +1248,12 @@ def _create_local_cuda_repository(repository_ctx): """ cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(host_compiler_includes) cuda_defines["%{compiler_deps}"] = ":empty" + cuda_defines["%{win_compiler_deps}"] = ":empty" repository_ctx.file( "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", "", ) + repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "") else: cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" cuda_defines["%{host_compiler_warnings}"] = "" @@ -1035,8 +1276,10 @@ def _create_local_cuda_repository(repository_ctx): if not is_cuda_clang: cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\"" - nvcc_path = "%s/nvcc" % cuda_config.config["cuda_binary_dir"] + file_ext = ".exe" if is_windows(repository_ctx) else "" + nvcc_path = "%s/nvcc%s" % (cuda_config.config["cuda_binary_dir"], file_ext) cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc" + cuda_defines["%{win_compiler_deps}"] = ":windows_msvc_wrapper_files" wrapper_defines = { "%{cpu_compiler}": str(cc), @@ -1044,12 +1287,26 @@ def _create_local_cuda_repository(repository_ctx): "%{nvcc_path}": nvcc_path, "%{host_compiler_path}": str(cc), "%{use_clang_compiler}": str(is_nvcc_and_clang), + "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx), } repository_ctx.template( "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"], wrapper_defines, ) + repository_ctx.file( + "crosstool/windows/msvc_wrapper_for_nvcc.bat", + content = "@echo OFF\n{} -B external/local_config_cuda/crosstool/windows/msvc_wrapper_for_nvcc.py %*".format( + get_python_bin(repository_ctx), + ), + ) + repository_ctx.template( + "crosstool/windows/msvc_wrapper_for_nvcc.py", + tpl_paths["crosstool:windows/msvc_wrapper_for_nvcc.py"], + wrapper_defines, + ) + + cuda_defines.update(_get_win_cuda_defines(repository_ctx)) verify_build_defines(cuda_defines) @@ -1180,6 +1437,21 @@ def _cuda_autoconf_impl(repository_ctx): repository_ctx.symlink(build_file, "BUILD") +# For @bazel_tools//tools/cpp:windows_cc_configure.bzl +_MSVC_ENVVARS = [ + "BAZEL_VC", + 
"BAZEL_VC_FULL_VERSION", + "BAZEL_VS", + "BAZEL_WINSDK_FULL_VERSION", + "VS90COMNTOOLS", + "VS100COMNTOOLS", + "VS110COMNTOOLS", + "VS120COMNTOOLS", + "VS140COMNTOOLS", + "VS150COMNTOOLS", + "VS160COMNTOOLS", +] + _ENVIRONS = [ _GCC_HOST_COMPILER_PATH, _GCC_HOST_COMPILER_PREFIX, @@ -1198,7 +1470,7 @@ _ENVIRONS = [ "TMP", "TMPDIR", "TF_CUDA_PATHS", -] +] + _MSVC_ENVVARS remote_cuda_configure = repository_rule( implementation = _create_local_cuda_repository, diff --git a/third_party/gpus/find_cuda_config.py b/third_party/gpus/find_cuda_config.py index 78292c7b40237a..b88694af5c014d 100644 --- a/third_party/gpus/find_cuda_config.py +++ b/third_party/gpus/find_cuda_config.py @@ -29,6 +29,8 @@ If TF_CUDA_PATHS is not specified, a OS specific default is used: Linux: /usr/local/cuda, /usr, and paths from 'ldconfig -p'. + Windows: CUDA_PATH environment variable, or + C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\* For backwards compatibility, some libraries also use alternative base directories from other environment variables if they are specified. List of @@ -54,6 +56,7 @@ import io import os import glob +import platform import re import subprocess import sys @@ -70,6 +73,18 @@ class ConfigError(Exception): pass +def _is_linux(): + return platform.system() == "Linux" + + +def _is_windows(): + return platform.system() == "Windows" + + +def _is_macos(): + return platform.system() == "Darwin" + + def _matches_version(actual_version, required_version): """Checks whether some version meets the requirements. @@ -119,6 +134,8 @@ def _cartesian_product(first, second): def _get_ld_config_paths(): """Returns all directories from 'ldconfig -p'.""" + if not _is_linux(): + return [] ldconfig_path = which("ldconfig") or "/sbin/ldconfig" output = subprocess.check_output([ldconfig_path, "-p"]) pattern = re.compile(".* => (.*)") @@ -139,6 +156,13 @@ def _get_default_cuda_paths(cuda_version): elif not "." 
in cuda_version: cuda_version = cuda_version + ".*" + if _is_windows(): + return [ + os.environ.get( + "CUDA_PATH", + "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v%s\\" % + cuda_version) + ] return ["/usr/local/cuda-%s" % cuda_version, "/usr/local/cuda", "/usr", "/usr/local/cudnn"] + _get_ld_config_paths() @@ -188,8 +212,14 @@ def _find_file(base_paths, relative_paths, filepattern): def _find_library(base_paths, library_name, required_version): """Returns first valid path to the requested library.""" - filepattern = ".".join(["lib" + library_name, "so"] + - required_version.split(".")[:1]) + "*" + if _is_windows(): + filepattern = library_name + ".lib" + elif _is_macos(): + filepattern = "%s*.dylib" % (".".join(["lib" + library_name] + + required_version.split(".")[:1])) + else: + filepattern = ".".join(["lib" + library_name, "so"] + + required_version.split(".")[:1]) + "*" return _find_file(base_paths, _library_paths(), filepattern) @@ -238,7 +268,7 @@ def get_nvcc_version(path): return match.group(1) return None - nvcc_name = "nvcc" + nvcc_name = "nvcc.exe" if _is_windows() else "nvcc" nvcc_path, nvcc_version = _find_versioned_file(base_paths, [ "", "bin", @@ -528,6 +558,14 @@ def _get_legacy_path(env_name, default=[]): return _list_from_env(env_name, default) +def _normalize_path(path): + """Returns normalized path, with forward slashes on Windows.""" + path = os.path.realpath(path) + if _is_windows(): + path = path.replace("\\", "/") + return path + + def find_cuda_config(): """Returns a dictionary of CUDA library and header file paths.""" libraries = [argv.lower() for argv in sys.argv[1:]] @@ -596,7 +634,7 @@ def find_cuda_config(): for k, v in result.items(): if k.endswith("_dir") or k.endswith("_path"): - result[k] = os.path.realpath(v) + result[k] = _normalize_path(v) return result diff --git a/third_party/xla/third_party/tsl/opensource_only.files b/third_party/xla/third_party/tsl/opensource_only.files index e4974e79805725..f2f2b14eba7be9 100644 --- a/third_party/xla/third_party/tsl/opensource_only.files +++ b/third_party/xla/third_party/tsl/opensource_only.files @@ -29,7 +29,9 @@ third_party/gpus/crosstool/BUILD: third_party/gpus/crosstool/LICENSE: third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl: third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl: +third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl: third_party/gpus/cuda/BUILD.tpl: +third_party/gpus/cuda/BUILD.windows.tpl: third_party/gpus/cuda/BUILD: third_party/gpus/cuda/LICENSE: third_party/gpus/cuda/build_defs.bzl.tpl: diff --git a/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py b/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py index b7d98ef2581157..afd6380b0ac203 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py +++ b/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py @@ -23,6 +23,7 @@ """ import os import os.path +import platform import subprocess import sys @@ -38,6 +39,10 @@ class ConfigError(Exception): pass +def _is_windows(): + return platform.system() == "Windows" + + def check_cuda_lib(path, check_soname=True): """Tests if a library exists on disk and whether its soname matches the filename. 
@@ -52,7 +57,7 @@ def check_cuda_lib(path, check_soname=True): if not os.path.isfile(path): raise ConfigError("No library found under: " + path) objdump = which("objdump") - if check_soname and objdump is not None: + if check_soname and objdump is not None and not _is_windows(): # Decode is necessary as in py3 the return type changed from str to bytes output = subprocess.check_output([objdump, "-p", path]).decode("utf-8") output = [line for line in output.splitlines() if "SONAME" in line] diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl new file mode 100644 index 00000000000000..c46e09484fdfad --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl @@ -0,0 +1,256 @@ +#!/usr/bin/env python +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Crosstool wrapper for compiling CUDA programs with nvcc on Windows. + +DESCRIPTION: + This script is the Windows version of //third_party/gpus/crosstool/crosstool_wrapper_is_not_gcc +""" + +from argparse import ArgumentParser +import os +import subprocess +import re +import sys +import tempfile + +# Template values set by cuda_autoconf. +CPU_COMPILER = ('%{cpu_compiler}') +GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}') + +NVCC_PATH = '%{nvcc_path}' +NVCC_VERSION = '%{cuda_version}' +NVCC_TEMP_DIR = "%{nvcc_tmp_dir}" + +def Log(s): + print('gpus/crosstool: {0}'.format(s)) + + +def GetOptionValue(argv, option): + """Extract the list of values for option from options. + + Args: + option: The option whose value to extract. + + Returns: + 1. A list of values, either directly following the option, + (eg., /opt val1 val2) or values collected from multiple occurrences of + the option (eg., /opt val1 /opt val2). + 2. The leftover options. + """ + + parser = ArgumentParser(prefix_chars='-/') + parser.add_argument(option, nargs='*', action='append') + option = option.lstrip('-/').replace('-', '_') + args, leftover = parser.parse_known_args(argv) + if args and vars(args)[option]: + return (sum(vars(args)[option], []), leftover) + return ([], leftover) + +def _update_options(nvcc_options): + if NVCC_VERSION in ("7.0",): + return nvcc_options + + update_options = { "relaxed-constexpr" : "expt-relaxed-constexpr" } + return [ update_options[opt] if opt in update_options else opt + for opt in nvcc_options ] + +def GetNvccOptions(argv): + """Collect the -nvcc_options values from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + + Returns: + 1. The string that can be passed directly to nvcc. + 2. The leftover options. 
+ """ + + parser = ArgumentParser() + parser.add_argument('-nvcc_options', nargs='*', action='append') + + args, leftover = parser.parse_known_args(argv) + + if args.nvcc_options: + options = _update_options(sum(args.nvcc_options, [])) + return (['--' + a for a in options], leftover) + return ([], leftover) + + +def InvokeNvcc(argv, log=False): + """Call nvcc with arguments assembled from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + log: True if logging is requested. + + Returns: + The return value of calling os.system('nvcc ' + args) + """ + + src_files = [f for f in argv if + re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)] + if len(src_files) == 0: + raise Error('No source files found for cuda compilation.') + + out_file = [ f for f in argv if f.startswith('/Fo') ] + if len(out_file) != 1: + raise Error('Please specify exactly one output file for cuda compilation.') + out = ['-o', out_file[0][len('/Fo'):]] + + nvcc_compiler_options, argv = GetNvccOptions(argv) + + opt_option, argv = GetOptionValue(argv, '/O') + opt = ['-g'] + if (len(opt_option) > 0 and opt_option[0] != 'd'): + opt = ['-O2'] + + include_options, argv = GetOptionValue(argv, '/I') + includes = ["-I " + include for include in include_options] + + defines, argv = GetOptionValue(argv, '/D') + defines = [ + '-D' + define + for define in defines + if 'BAZEL_CURRENT_REPOSITORY' not in define + ] + + undefines, argv = GetOptionValue(argv, '/U') + undefines = ['-U' + define for define in undefines] + + fatbin_options, argv = GetOptionValue(argv, '-Xcuda-fatbinary') + fatbin_options = ['--fatbin-options=' + option for option in fatbin_options] + + # The rest of the unrecognized options should be passed to host compiler + host_compiler_options = [option for option in argv if option not in (src_files + out_file)] + + m_options = ["-m64"] + + nvccopts = ['-D_FORCE_INLINES'] + compute_capabilities, argv = GetOptionValue(argv, "--cuda-gpu-arch") + for capability in compute_capabilities: + capability = capability[len('sm_'):] + nvccopts += [ + r'-gencode=arch=compute_%s,"code=sm_%s"' % (capability, capability) + ] + compute_capabilities, argv = GetOptionValue(argv, '--cuda-include-ptx') + for capability in compute_capabilities: + capability = capability[len('sm_'):] + nvccopts += [ + r'-gencode=arch=compute_%s,"code=compute_%s"' % (capability, capability) + ] + _, argv = GetOptionValue(argv, '--no-cuda-include-ptx') + + # nvcc doesn't respect the INCLUDE and LIB env vars from MSVC, + # so we explicity specify the system include paths and library search paths. + if 'INCLUDE' in os.environ: + nvccopts += [('--system-include="%s"' % p) for p in os.environ['INCLUDE'].split(";")] + if 'LIB' in os.environ: + nvccopts += [('--library-path="%s"' % p) for p in os.environ['LIB'].split(";")] + + nvccopts += nvcc_compiler_options + nvccopts += undefines + nvccopts += defines + nvccopts += m_options + nvccopts += fatbin_options + nvccopts += ['--compiler-options=' + ",".join(host_compiler_options)] + nvccopts += ['-x', 'cu'] + opt + includes + out + ['-c'] + src_files + # Specify a unique temp directory for nvcc to generate intermediate files, + # then Bazel can ignore files under NVCC_TEMP_DIR during dependency check + # http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver + # Different actions are sharing NVCC_TEMP_DIR, so we cannot remove it if the directory already exists. 
+ if os.path.isfile(NVCC_TEMP_DIR): + os.remove(NVCC_TEMP_DIR) + if not os.path.exists(NVCC_TEMP_DIR): + os.makedirs(NVCC_TEMP_DIR) + # Provide a unique dir for each compiling action to avoid conflicts. + tempdir = tempfile.mkdtemp(dir = NVCC_TEMP_DIR) + nvccopts += ['--keep', '--keep-dir', tempdir] + # Force C++17 dialect (note, everything in just one string!) + nvccopts += ['--std c++17'] + if log: + Log([NVCC_PATH] + nvccopts) + + # Store command line options in a file to avoid hitting the character limit. + optsfile = tempfile.NamedTemporaryFile(mode='w', dir=tempdir, delete=False) + optsfile.write("\n".join(nvccopts)) + optsfile.close() + + proc = subprocess.Popen([NVCC_PATH, "--options-file", optsfile.name], + stdout=sys.stdout, + stderr=sys.stderr, + env=os.environ.copy(), + shell=True) + proc.wait() + return proc.returncode + +def ExpandParamsFileForArgv(): + new_argv = [] + for arg in sys.argv: + if arg.startswith("@"): + with open(arg.strip("@")) as f: + new_argv.extend([l.strip() for l in f.readlines()]) + else: + new_argv.append(arg) + + sys.argv = new_argv + +def ProcessFlagForCommandFile(flag): + if flag.startswith("/D") or flag.startswith("-D"): + # We need to re-escape /DFOO="BAR" as /DFOO=\"BAR\", so that we get + # `#define FOO "BAR"` after expansion as a string literal define + if flag.endswith('"') and not flag.endswith('\\"'): + flag = '\\"'.join(flag.split('"', 1)) + flag = '\\"'.join(flag.rsplit('"', 1)) + return flag + return flag + +def main(): + ExpandParamsFileForArgv() + parser = ArgumentParser() + parser.add_argument('-x', nargs=1) + parser.add_argument('--cuda_log', action='store_true') + args, leftover = parser.parse_known_args(sys.argv[1:]) + + if args.x and args.x[0] == 'cuda': + if args.cuda_log: Log('-x cuda') + if args.cuda_log: Log('using nvcc') + return InvokeNvcc(leftover, log=args.cuda_log) + + # Strip our flags before passing through to the CPU compiler for files which + # are not -x cuda. We can't just pass 'leftover' because it also strips -x. + # We not only want to pass -x to the CPU compiler, but also keep it in its + # relative location in the argv list (the compiler is actually sensitive to + # this). + cpu_compiler_flags = [flag for flag in sys.argv[1:] + if not flag.startswith(('--cuda_log')) + and not flag.startswith(('-nvcc_options'))] + output = [flag for flag in cpu_compiler_flags if flag.startswith("/Fo")] + + # Store command line options in a file to avoid hitting the character limit. 
+ if len(output) == 1: + commandfile_path = output[0][3:] + ".msvc_params" + commandfile = open(commandfile_path, "w") + cpu_compiler_flags = [ProcessFlagForCommandFile(flag) for flag in cpu_compiler_flags] + commandfile.write("\n".join(cpu_compiler_flags)) + commandfile.close() + return subprocess.call([CPU_COMPILER, "@" + commandfile_path]) + else: + return subprocess.call([CPU_COMPILER] + cpu_compiler_flags) + return subprocess.call([CPU_COMPILER] + cpu_compiler_flags) + +if __name__ == '__main__': + sys.exit(main()) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl index 700e040a88eeca..90a18b90de048c 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl @@ -61,23 +61,23 @@ cuda_header_library( cc_library( name = "cudart_static", - srcs = ["cuda/lib/libcudart_static.a"], + srcs = ["cuda/lib/%{cudart_static_lib}"], linkopts = [ "-ldl", - "-lrt", "-lpthread", + %{cudart_static_linkopt} ], ) cc_library( name = "cuda_driver", - srcs = ["cuda/lib/libcuda.so"], + srcs = ["cuda/lib/%{cuda_driver_lib}"], ) cc_library( name = "cudart", - srcs = glob(["cuda/lib/libcudart.so.*"]), - data = glob(["cuda/lib/libcudart.so.*"]), + srcs = ["cuda/lib/%{cudart_lib}"], + data = ["cuda/lib/%{cudart_lib}"], linkstatic = 1, ) @@ -128,30 +128,30 @@ cuda_header_library( cc_library( name = "cublas", - srcs = glob(["cuda/lib/libcublas.so.*"]), - data = glob(["cuda/lib/libcublas.so.*"]), + srcs = ["cuda/lib/%{cublas_lib}"], + data = ["cuda/lib/%{cublas_lib}"], linkstatic = 1, ) cc_library( name = "cublasLt", - srcs = glob(["cuda/lib/libcublasLt.so.*"]), - data = glob(["cuda/lib/libcublasLt.so.*"]), + srcs = ["cuda/lib/%{cublasLt_lib}"], + data = ["cuda/lib/%{cublasLt_lib}"], linkstatic = 1, ) cc_library( name = "cusolver", - srcs = glob(["cuda/lib/libcusolver.so.*"]), - data = glob(["cuda/lib/libcusolver.so.*"]), + srcs = ["cuda/lib/%{cusolver_lib}"], + data = ["cuda/lib/%{cusolver_lib}"], linkopts = ["-lgomp"], linkstatic = 1, ) cc_library( name = "cudnn", - srcs = glob(["cuda/lib/libcudnn.so.*"]), - data = glob(["cuda/lib/libcudnn.so.*"]), + srcs = ["cuda/lib/%{cudnn_lib}"], + data = ["cuda/lib/%{cudnn_lib}"], linkstatic = 1, ) @@ -165,15 +165,15 @@ cc_library( cc_library( name = "cufft", - srcs = glob(["cuda/lib/libcufft.so.*"]), - data = glob(["cuda/lib/libcufft.so.*"]), + srcs = ["cuda/lib/%{cufft_lib}"], + data = ["cuda/lib/%{cufft_lib}"], linkstatic = 1, ) cc_library( name = "curand", - srcs = glob(["cuda/lib/libcurand.so.*"]), - data = glob(["cuda/lib/libcurand.so.*"]), + srcs = ["cuda/lib/%{curand_lib}"], + data = ["cuda/lib/%{curand_lib}"], linkstatic = 1, ) @@ -192,7 +192,7 @@ cc_library( alias( name = "cub_headers", - actual = ":cuda_headers", + actual = "%{cub_actual}", ) cuda_header_library( @@ -213,13 +213,13 @@ cuda_header_library( cc_library( name = "cupti_dsos", - data = glob(["cuda/lib/libcupti.so.*"]), + data = ["cuda/lib/%{cupti_lib}"], ) cc_library( name = "cusparse", - srcs = glob(["cuda/lib/libcusparse.so.*"]), - data = glob(["cuda/lib/libcusparse.so.*"]), + srcs = ["cuda/lib/%{cusparse_lib}"], + data = ["cuda/lib/%{cusparse_lib}"], linkopts = ["-lgomp"], linkstatic = 1, ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl new file mode 100644 index 00000000000000..f20ecbd654bf6f --- /dev/null +++ 
b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl @@ -0,0 +1,230 @@ +load(":build_defs.bzl", "cuda_header_library") +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//lib:selects.bzl", "selects") + +licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +# Config setting whether TensorFlow is built with CUDA support using clang. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_clang. +selects.config_setting_group( + name = "using_clang", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_clang", + ], +) + +# Config setting whether TensorFlow is built with CUDA support using nvcc. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_nvcc. +selects.config_setting_group( + name = "using_nvcc", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_nvcc", + ], +) + +# Equivalent to using_clang && -c opt. +selects.config_setting_group( + name = "using_clang_opt", + match_all = [ + ":using_clang", + ":_opt", + ], +) + +config_setting( + name = "_opt", + values = {"compilation_mode": "opt"}, +) + +# Provides CUDA headers for '#include "third_party/gpus/cuda/include/cuda.h"' +# All clients including TensorFlow should use these directives. +cuda_header_library( + name = "cuda_headers", + hdrs = [ + "cuda/cuda_config.h", + ":cuda-include", + ], + include_prefix = "third_party/gpus", + includes = [ + ".", # required to include cuda/cuda/cuda_config.h as cuda/config.h + "cuda/include", + ], +) + +cc_import( + name = "cudart_static", + # /WHOLEARCHIVE:cudart_static.lib will cause a + # "Internal error during CImplib::EmitThunk" error. + # Treat this library as interface library to avoid being whole archived when + # linking a DLL that depends on this. + # TODO(pcloudy): Remove this rule after b/111278841 is resolved. 
+ interface_library = "cuda/lib/%{cudart_static_lib}", + system_provided = 1, +) + +cc_import( + name = "cuda_driver", + interface_library = "cuda/lib/%{cuda_driver_lib}", + system_provided = 1, +) + +cc_import( + name = "cudart", + interface_library = "cuda/lib/%{cudart_lib}", + system_provided = 1, +) + +cuda_header_library( + name = "cublas_headers", + hdrs = [":cublas-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cublas/include"], + strip_include_prefix = "cublas/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "cusolver_headers", + hdrs = [":cusolver-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cusolver/include"], + strip_include_prefix = "cusolver/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "cufft_headers", + hdrs = [":cufft-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cufft/include"], + strip_include_prefix = "cufft/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "cusparse_headers", + hdrs = [":cusparse-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["cusparse/include"], + strip_include_prefix = "cusparse/include", + deps = [":cuda_headers"], +) + +cuda_header_library( + name = "curand_headers", + hdrs = [":curand-include"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["curand/include"], + strip_include_prefix = "curand/include", + deps = [":cuda_headers"], +) + +cc_import( + name = "cublas", + interface_library = "cuda/lib/%{cublas_lib}", + system_provided = 1, +) + +cc_import( + name = "cublasLt", + interface_library = "cuda/lib/%{cublasLt_lib}", + system_provided = 1, +) + +cc_import( + name = "cusolver", + interface_library = "cuda/lib/%{cusolver_lib}", + system_provided = 1, +) + +cc_import( + name = "cudnn", + interface_library = "cuda/lib/%{cudnn_lib}", + system_provided = 1, +) + +cc_library( + name = "cudnn_header", + hdrs = [":cudnn-include"], + include_prefix = "third_party/gpus/cudnn", + strip_include_prefix = "cudnn/include", + deps = [":cuda_headers"], +) + +cc_import( + name = "cufft", + interface_library = "cuda/lib/%{cufft_lib}", + system_provided = 1, +) + +cc_import( + name = "curand", + interface_library = "cuda/lib/%{curand_lib}", + system_provided = 1, +) + +cc_library( + name = "cuda", + deps = [ + ":cublas", + ":cublasLt", + ":cuda_headers", + ":cudart", + ":cudnn", + ":cufft", + ":curand", + ], +) + +alias( + name = "cub_headers", + actual = "%{cub_actual}", +) + +cuda_header_library( + name = "cupti_headers", + hdrs = [":cuda-extras"], + include_prefix = "third_party/gpus", + includes = ["cuda/extras/CUPTI/include/"], + deps = [":cuda_headers"], +) + +cc_import( + name = "cupti_dsos", + interface_library = "cuda/lib/%{cupti_lib}", + system_provided = 1, +) + +cc_import( + name = "cusparse", + interface_library = "cuda/lib/%{cusparse_lib}", + system_provided = 1, +) + +cc_library( + name = "libdevice_root", + data = [":cuda-nvvm"], +) + +bzl_library( + name = "build_defs_bzl", + srcs = ["build_defs.bzl"], + deps = [ + "@bazel_skylib//lib:selects", + ], +) + +py_library( + name = "cuda_config_py", + srcs = ["cuda/cuda_config.py"], +) + +%{copy_rules} diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl index 62df7cc0124568..ff2f2f41091fe8 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl +++ 
b/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl @@ -27,14 +27,27 @@ """ load("//third_party/clang_toolchain:download_clang.bzl", "download_clang") +load( + "@bazel_tools//tools/cpp:lib_cc_configure.bzl", + "escape_string", + "get_env_var", +) +load( + "@bazel_tools//tools/cpp:windows_cc_configure.bzl", + "find_msvc_tool", + "find_vc_path", + "setup_vc_env_vars", +) load( "//third_party/remote_config:common.bzl", "config_repo_label", "err_out", "execute", "get_bash_bin", + "get_cpu_value", "get_host_environ", "get_python_bin", + "is_windows", "raw_exec", "read_dir", "realpath", @@ -83,7 +96,16 @@ def verify_build_defines(params): "host_compiler_warnings", "linker_bin_path", "compiler_deps", + "msvc_cl_path", + "msvc_env_include", + "msvc_env_lib", + "msvc_env_path", + "msvc_env_tmp", + "msvc_lib_path", + "msvc_link_path", + "msvc_ml_path", "unfiltered_compile_flags", + "win_compiler_deps", ]: if ("%{" + param + "}") not in params: missing.append(param) @@ -97,11 +119,102 @@ def verify_build_defines(params): ".", ) +def _get_nvcc_tmp_dir_for_windows(repository_ctx): + """Return the Windows tmp directory for nvcc to generate intermediate source files.""" + escaped_tmp_dir = escape_string( + get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace( + "\\", + "\\\\", + ), + ) + return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir" + +def _get_msvc_compiler(repository_ctx): + vc_path = find_vc_path(repository_ctx) + return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/") + +def _get_win_cuda_defines(repository_ctx): + """Return CROSSTOOL defines for Windows""" + + # If we are not on Windows, return fake vaules for Windows specific fields. + # This ensures the CROSSTOOL file parser is happy. + if not is_windows(repository_ctx): + return { + "%{msvc_env_tmp}": "msvc_not_used", + "%{msvc_env_path}": "msvc_not_used", + "%{msvc_env_include}": "msvc_not_used", + "%{msvc_env_lib}": "msvc_not_used", + "%{msvc_cl_path}": "msvc_not_used", + "%{msvc_ml_path}": "msvc_not_used", + "%{msvc_link_path}": "msvc_not_used", + "%{msvc_lib_path}": "msvc_not_used", + } + + vc_path = find_vc_path(repository_ctx) + if not vc_path: + auto_configure_fail( + "Visual C++ build tools not found on your machine." 
+ + "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using", + ) + return {} + + env = setup_vc_env_vars(repository_ctx, vc_path) + escaped_paths = escape_string(env["PATH"]) + escaped_include_paths = escape_string(env["INCLUDE"]) + escaped_lib_paths = escape_string(env["LIB"]) + escaped_tmp_dir = escape_string( + get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace( + "\\", + "\\\\", + ), + ) + + msvc_cl_path = "windows/msvc_wrapper_for_nvcc.bat" + msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace( + "\\", + "/", + ) + msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace( + "\\", + "/", + ) + msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace( + "\\", + "/", + ) + + # nvcc will generate some temporary source files under %{nvcc_tmp_dir} + # The generated files are guaranteed to have unique name, so they can share + # the same tmp directory + escaped_cxx_include_directories = [ + _get_nvcc_tmp_dir_for_windows(repository_ctx), + "C:\\\\botcode\\\\w", + ] + for path in escaped_include_paths.split(";"): + if path: + escaped_cxx_include_directories.append(path) + + return { + "%{msvc_env_tmp}": escaped_tmp_dir, + "%{msvc_env_path}": escaped_paths, + "%{msvc_env_include}": escaped_include_paths, + "%{msvc_env_lib}": escaped_lib_paths, + "%{msvc_cl_path}": msvc_cl_path, + "%{msvc_ml_path}": msvc_ml_path, + "%{msvc_link_path}": msvc_link_path, + "%{msvc_lib_path}": msvc_lib_path, + "%{cxx_builtin_include_directories}": to_list_of_strings( + escaped_cxx_include_directories, + ), + } + # TODO(dzc): Once these functions have been factored out of Bazel's # cc_configure.bzl, load them from @bazel_tools instead. # BEGIN cc_configure common functions. def find_cc(repository_ctx, use_cuda_clang): """Find the C++ compiler.""" + if is_windows(repository_ctx): + return _get_msvc_compiler(repository_ctx) if use_cuda_clang: target_cc_name = "clang" @@ -252,9 +365,10 @@ def _cuda_include_path(repository_ctx, cuda_config): Returns: A list of the gcc host compiler include directories. """ - nvcc_path = repository_ctx.path( - "%s/bin/nvcc" % cuda_config.cuda_toolkit_path, - ) + nvcc_path = repository_ctx.path("%s/bin/nvcc%s" % ( + cuda_config.cuda_toolkit_path, + ".exe" if cuda_config.cpu_value == "Windows" else "", + )) # The expected exit code of this command is non-zero. Bazel remote execution # only caches commands with zero exit code. So force a zero exit code. @@ -315,6 +429,10 @@ def matches_version(environ_version, detected_version): return False return True +_NVCC_VERSION_PREFIX = "Cuda compilation tools, release " + +_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR" + def compute_capabilities(repository_ctx): """Returns a list of strings representing cuda compute capabilities. @@ -357,11 +475,12 @@ def compute_capabilities(repository_ctx): return capabilities -def lib_name(base_name, version = None, static = False): +def lib_name(base_name, cpu_value, version = None, static = False): """Constructs the platform-specific name of a library. Args: base_name: The name of the library, such as "cudart" + cpu_value: The name of the host operating system. version: The version of the library. static: True the library is static or False if it is a shared object. @@ -369,20 +488,29 @@ def lib_name(base_name, version = None, static = False): The platform-specific name of the library. """ version = "" if not version else "." 
+ version - if static: - return "lib%s.a" % base_name - return "lib%s.so%s" % (base_name, version) + if cpu_value in ("Linux", "FreeBSD"): + if static: + return "lib%s.a" % base_name + return "lib%s.so%s" % (base_name, version) + elif cpu_value == "Windows": + return "%s.lib" % base_name + elif cpu_value == "Darwin": + if static: + return "lib%s.a" % base_name + return "lib%s%s.dylib" % (base_name, version) + else: + auto_configure_fail("Invalid cpu_value: %s" % cpu_value) -def _lib_path(lib, basedir, version, static): - file_name = lib_name(lib, version, static) +def _lib_path(lib, cpu_value, basedir, version, static): + file_name = lib_name(lib, cpu_value, version, static) return "%s/%s" % (basedir, file_name) def _should_check_soname(version, static): return version and not static -def _check_cuda_lib_params(lib, basedir, version, static = False): +def _check_cuda_lib_params(lib, cpu_value, basedir, version, static = False): return ( - _lib_path(lib, basedir, version, static), + _lib_path(lib, cpu_value, basedir, version, static), _should_check_soname(version, static), ) @@ -402,6 +530,8 @@ def _check_cuda_libs(repository_ctx, script_path, libs): all_paths = [path for path, _ in libs] checked_paths = execute(repository_ctx, [python_bin, "-c", cmd]).stdout.splitlines() + # Filter out empty lines from splitting on '\r\n' on Windows + checked_paths = [path for path in checked_paths if len(path) > 0] if all_paths != checked_paths: auto_configure_fail("Error with installed CUDA libs. Expected '%s'. Actual '%s'." % (all_paths, checked_paths)) @@ -419,62 +549,86 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config): Returns: Map of library names to structs of filename and path. """ + cpu_value = cuda_config.cpu_value + stub_dir = "" if is_windows(repository_ctx) else "/stubs" + check_cuda_libs_params = { "cuda": _check_cuda_lib_params( "cuda", - cuda_config.config["cuda_library_dir"] + "/stubs", + cpu_value, + cuda_config.config["cuda_library_dir"] + stub_dir, version = None, + static = False, ), "cudart": _check_cuda_lib_params( "cudart", + cpu_value, cuda_config.config["cuda_library_dir"], cuda_config.cudart_version, + static = False, ), "cudart_static": _check_cuda_lib_params( "cudart_static", + cpu_value, cuda_config.config["cuda_library_dir"], cuda_config.cudart_version, static = True, ), "cublas": _check_cuda_lib_params( "cublas", + cpu_value, cuda_config.config["cublas_library_dir"], cuda_config.cublas_version, + static = False, ), "cublasLt": _check_cuda_lib_params( "cublasLt", + cpu_value, cuda_config.config["cublas_library_dir"], cuda_config.cublas_version, + static = False, ), "cusolver": _check_cuda_lib_params( "cusolver", + cpu_value, cuda_config.config["cusolver_library_dir"], cuda_config.cusolver_version, + static = False, ), "curand": _check_cuda_lib_params( "curand", + cpu_value, cuda_config.config["curand_library_dir"], cuda_config.curand_version, + static = False, ), "cufft": _check_cuda_lib_params( "cufft", + cpu_value, cuda_config.config["cufft_library_dir"], cuda_config.cufft_version, + static = False, ), "cudnn": _check_cuda_lib_params( "cudnn", + cpu_value, cuda_config.config["cudnn_library_dir"], cuda_config.cudnn_version, + static = False, ), "cupti": _check_cuda_lib_params( "cupti", + cpu_value, cuda_config.config["cupti_library_dir"], cuda_config.cupti_version, + static = False, ), "cusparse": _check_cuda_lib_params( "cusparse", + cpu_value, cuda_config.config["cusparse_library_dir"], cuda_config.cusparse_version, + static = False, ), } @@ -484,6 
+638,10 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config): paths = {filename: v[0] for (filename, v) in check_cuda_libs_params.items()} return paths +def _cudart_static_linkopt(cpu_value): + """Returns additional platform-specific linkopts for cudart.""" + return "" if cpu_value == "Darwin" else "\"-lrt\"," + # TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl, # and nccl_configure.bzl. def find_cuda_config(repository_ctx, cuda_libraries): @@ -510,34 +668,37 @@ def _get_cuda_config(repository_ctx): cudart_version: The CUDA runtime version on the system. cudnn_version: The version of cuDNN on the system. compute_capabilities: A list of the system's CUDA compute capabilities. + cpu_value: The name of the host operating system. """ config = find_cuda_config(repository_ctx, ["cuda", "cudnn"]) + cpu_value = get_cpu_value(repository_ctx) toolkit_path = config["cuda_toolkit_path"] + is_windows = cpu_value == "Windows" cuda_version = config["cuda_version"].split(".") cuda_major = cuda_version[0] cuda_minor = cuda_version[1] - cuda_version = "%s.%s" % (cuda_major, cuda_minor) - cudnn_version = "%s" % config["cudnn_version"] + cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor) + cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"] if int(cuda_major) >= 11: # The libcudart soname in CUDA 11.x is versioned as 11.0 for backward compatability. if int(cuda_major) == 11: - cudart_version = "11.0" + cudart_version = "64_110" if is_windows else "11.0" cupti_version = cuda_version else: - cudart_version = "%s" % cuda_major + cudart_version = ("64_%s" if is_windows else "%s") % cuda_major cupti_version = cudart_version - cublas_version = "%s" % config["cublas_version"].split(".")[0] - cusolver_version = "%s" % config["cusolver_version"].split(".")[0] - curand_version = "%s" % config["curand_version"].split(".")[0] - cufft_version = "%s" % config["cufft_version"].split(".")[0] - cusparse_version = "%s" % config["cusparse_version"].split(".")[0] + cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0] + cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0] + curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0] + cufft_version = ("64_%s" if is_windows else "%s") % config["cufft_version"].split(".")[0] + cusparse_version = ("64_%s" if is_windows else "%s") % config["cusparse_version"].split(".")[0] elif (int(cuda_major), int(cuda_minor)) >= (10, 1): # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc. # It changed from 'x.y' to just 'x' in CUDA 10.1. - cuda_lib_version = "%s" % cuda_major + cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major cudart_version = cuda_version cupti_version = cuda_version cublas_version = cuda_lib_version @@ -567,6 +728,7 @@ def _get_cuda_config(repository_ctx): cusparse_version = cusparse_version, cudnn_version = cudnn_version, compute_capabilities = compute_capabilities(repository_ctx), + cpu_value = cpu_value, config = config, ) @@ -612,6 +774,8 @@ error_gpu_disabled() """ def _create_dummy_repository(repository_ctx): + cpu_value = get_cpu_value(repository_ctx) + # Set up BUILD file for cuda/. 
_tpl( repository_ctx, @@ -626,6 +790,23 @@ def _create_dummy_repository(repository_ctx): repository_ctx, "cuda:BUILD", { + "%{cuda_driver_lib}": lib_name("cuda", cpu_value), + "%{cudart_static_lib}": lib_name( + "cudart_static", + cpu_value, + static = True, + ), + "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), + "%{cudart_lib}": lib_name("cudart", cpu_value), + "%{cublas_lib}": lib_name("cublas", cpu_value), + "%{cublasLt_lib}": lib_name("cublasLt", cpu_value), + "%{cusolver_lib}": lib_name("cusolver", cpu_value), + "%{cudnn_lib}": lib_name("cudnn", cpu_value), + "%{cufft_lib}": lib_name("cufft", cpu_value), + "%{curand_lib}": lib_name("curand", cpu_value), + "%{cupti_lib}": lib_name("cupti", cpu_value), + "%{cusparse_lib}": lib_name("cusparse", cpu_value), + "%{cub_actual}": ":cuda_headers", "%{copy_rules}": """ filegroup(name="cuda-include") filegroup(name="cublas-include") @@ -644,9 +825,20 @@ filegroup(name="cudnn-include") repository_ctx.file("cuda/cuda/include/cublas.h") repository_ctx.file("cuda/cuda/include/cudnn.h") repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h") - repository_ctx.file("cuda/cuda/lib/libcuda.so") - repository_ctx.file("cuda/cuda/lib/libcudart_static.a") repository_ctx.file("cuda/cuda/nvml/include/nvml.h") + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cuda", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudart", cpu_value)) + repository_ctx.file( + "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value), + ) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cufft", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cupti", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusparse", cpu_value)) # Set up cuda_config.h, which is used by # tensorflow/compiler/xla/stream_executor/dso_loader.cc. @@ -710,7 +902,7 @@ def make_copy_files_rule(repository_ctx, name, srcs, outs): cmd = \"""%s \""", )""" % (name, "\n".join(outs), " && \\\n".join(cmds)) -def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir): +def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir, exceptions = None): """Returns a rule to recursively copy a directory. If exceptions is not None, it must be a list of files or directories in 'src_dir'; these will be excluded from copying. @@ -718,18 +910,27 @@ def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir): src_dir = _norm_path(src_dir) out_dir = _norm_path(out_dir) outs = read_dir(repository_ctx, src_dir) + post_cmd = "" + if exceptions != None: + outs = [x for x in outs if not any([ + x.startswith(src_dir + "/" + y) + for y in exceptions + ])] outs = [(' "%s",' % out.replace(src_dir, out_dir)) for out in outs] # '@D' already contains the relative path for a single file, see # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)" + if exceptions != None: + for x in exceptions: + post_cmd += " ; rm -fR " + out_dir + "/" + x return """genrule( name = "%s", outs = [ %s ], - cmd = \"""cp -rLf "%s/." 
"%s/" \""", -)""" % (name, "\n".join(outs), src_dir, out_dir) + cmd = \"""cp -rLf "%s/." "%s/" %s\""", +)""" % (name, "\n".join(outs), src_dir, out_dir, post_cmd) def _flag_enabled(repository_ctx, flag_name): return get_host_environ(repository_ctx, flag_name) == "1" @@ -758,6 +959,22 @@ def _compute_cuda_extra_copts(repository_ctx, compute_capabilities): def _tpl_path(repository_ctx, filename): return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % filename)) +def _basename(repository_ctx, path_str): + """Returns the basename of a path of type string. + + This method is different from path.basename in that it also works if + the host platform is different from the execution platform + i.e. linux -> windows. + """ + + num_chars = len(path_str) + is_win = is_windows(repository_ctx) + for i in range(num_chars): + r_i = num_chars - 1 - i + if (is_win and path_str[r_i] == "\\") or path_str[r_i] == "/": + return path_str[r_i + 1:] + return path_str + def _create_local_cuda_repository(repository_ctx): """Creates the repository containing files set up to build with CUDA.""" @@ -766,14 +983,15 @@ def _create_local_cuda_repository(repository_ctx): # can easily lead to a O(n^2) runtime in the number of labels. # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778 tpl_paths = {filename: _tpl_path(repository_ctx, filename) for filename in [ - "cuda:BUILD", "cuda:build_defs.bzl", "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc", + "crosstool:windows/msvc_wrapper_for_nvcc.py", "crosstool:BUILD", "crosstool:cc_toolchain_config.bzl", "cuda:cuda_config.h", "cuda:cuda_config.py", ]} + tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD") cuda_config = _get_cuda_config(repository_ctx) @@ -885,7 +1103,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_lib_outs = [] for path in cuda_libs.values(): cuda_lib_srcs.append(path) - cuda_lib_outs.append("cuda/lib/" + path.rpartition("/")[-1]) + cuda_lib_outs.append("cuda/lib/" + _basename(repository_ctx, path)) copy_rules.append(make_copy_files_rule( repository_ctx, name = "cuda-lib", @@ -894,7 +1112,11 @@ def _create_local_cuda_repository(repository_ctx): )) # copy files mentioned in third_party/nccl/build_defs.bzl.tpl - bin_files = ["crt/link.stub", "bin2c", "fatbinary", "nvlink", "nvprune"] + file_ext = ".exe" if is_windows(repository_ctx) else "" + bin_files = ( + ["crt/link.stub"] + + [f + file_ext for f in ["bin2c", "fatbinary", "nvlink", "nvprune"]] + ) copy_rules.append(make_copy_files_rule( repository_ctx, name = "cuda-bin", @@ -902,7 +1124,7 @@ def _create_local_cuda_repository(repository_ctx): outs = ["cuda/bin/" + f for f in bin_files], )) - # Select the headers based on the cuDNN version. + # Select the headers based on the cuDNN version (strip '64_' for Windows). 
cudnn_headers = ["cudnn.h"] if cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "8": cudnn_headers += [ @@ -943,10 +1165,27 @@ def _create_local_cuda_repository(repository_ctx): }, ) + cub_actual = "@cub_archive//:cub" + if int(cuda_config.cuda_version_major) >= 11: + cub_actual = ":cuda_headers" + repository_ctx.template( "cuda/BUILD", tpl_paths["cuda:BUILD"], { + "%{cuda_driver_lib}": _basename(repository_ctx, cuda_libs["cuda"]), + "%{cudart_static_lib}": _basename(repository_ctx, cuda_libs["cudart_static"]), + "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value), + "%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]), + "%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]), + "%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]), + "%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]), + "%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]), + "%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]), + "%{curand_lib}": _basename(repository_ctx, cuda_libs["curand"]), + "%{cupti_lib}": _basename(repository_ctx, cuda_libs["cupti"]), + "%{cusparse_lib}": _basename(repository_ctx, cuda_libs["cusparse"]), + "%{cub_actual}": cub_actual, "%{copy_rules}": "\n".join(copy_rules), }, ) @@ -1009,10 +1248,12 @@ def _create_local_cuda_repository(repository_ctx): """ cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(host_compiler_includes) cuda_defines["%{compiler_deps}"] = ":empty" + cuda_defines["%{win_compiler_deps}"] = ":empty" repository_ctx.file( "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", "", ) + repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "") else: cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" cuda_defines["%{host_compiler_warnings}"] = "" @@ -1035,8 +1276,10 @@ def _create_local_cuda_repository(repository_ctx): if not is_cuda_clang: cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\"" - nvcc_path = "%s/nvcc" % cuda_config.config["cuda_binary_dir"] + file_ext = ".exe" if is_windows(repository_ctx) else "" + nvcc_path = "%s/nvcc%s" % (cuda_config.config["cuda_binary_dir"], file_ext) cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc" + cuda_defines["%{win_compiler_deps}"] = ":windows_msvc_wrapper_files" wrapper_defines = { "%{cpu_compiler}": str(cc), @@ -1044,12 +1287,26 @@ def _create_local_cuda_repository(repository_ctx): "%{nvcc_path}": nvcc_path, "%{host_compiler_path}": str(cc), "%{use_clang_compiler}": str(is_nvcc_and_clang), + "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx), } repository_ctx.template( "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"], wrapper_defines, ) + repository_ctx.file( + "crosstool/windows/msvc_wrapper_for_nvcc.bat", + content = "@echo OFF\n{} -B external/local_config_cuda/crosstool/windows/msvc_wrapper_for_nvcc.py %*".format( + get_python_bin(repository_ctx), + ), + ) + repository_ctx.template( + "crosstool/windows/msvc_wrapper_for_nvcc.py", + tpl_paths["crosstool:windows/msvc_wrapper_for_nvcc.py"], + wrapper_defines, + ) + + cuda_defines.update(_get_win_cuda_defines(repository_ctx)) verify_build_defines(cuda_defines) @@ -1180,6 +1437,21 @@ def _cuda_autoconf_impl(repository_ctx): repository_ctx.symlink(build_file, "BUILD") +# For @bazel_tools//tools/cpp:windows_cc_configure.bzl +_MSVC_ENVVARS = [ + "BAZEL_VC", + 
"BAZEL_VC_FULL_VERSION", + "BAZEL_VS", + "BAZEL_WINSDK_FULL_VERSION", + "VS90COMNTOOLS", + "VS100COMNTOOLS", + "VS110COMNTOOLS", + "VS120COMNTOOLS", + "VS140COMNTOOLS", + "VS150COMNTOOLS", + "VS160COMNTOOLS", +] + _ENVIRONS = [ _GCC_HOST_COMPILER_PATH, _GCC_HOST_COMPILER_PREFIX, @@ -1198,7 +1470,7 @@ _ENVIRONS = [ "TMP", "TMPDIR", "TF_CUDA_PATHS", -] +] + _MSVC_ENVVARS remote_cuda_configure = repository_rule( implementation = _create_local_cuda_repository, diff --git a/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py b/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py index 78292c7b40237a..b88694af5c014d 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py +++ b/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py @@ -29,6 +29,8 @@ If TF_CUDA_PATHS is not specified, a OS specific default is used: Linux: /usr/local/cuda, /usr, and paths from 'ldconfig -p'. + Windows: CUDA_PATH environment variable, or + C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\* For backwards compatibility, some libraries also use alternative base directories from other environment variables if they are specified. List of @@ -54,6 +56,7 @@ import io import os import glob +import platform import re import subprocess import sys @@ -70,6 +73,18 @@ class ConfigError(Exception): pass +def _is_linux(): + return platform.system() == "Linux" + + +def _is_windows(): + return platform.system() == "Windows" + + +def _is_macos(): + return platform.system() == "Darwin" + + def _matches_version(actual_version, required_version): """Checks whether some version meets the requirements. @@ -119,6 +134,8 @@ def _cartesian_product(first, second): def _get_ld_config_paths(): """Returns all directories from 'ldconfig -p'.""" + if not _is_linux(): + return [] ldconfig_path = which("ldconfig") or "/sbin/ldconfig" output = subprocess.check_output([ldconfig_path, "-p"]) pattern = re.compile(".* => (.*)") @@ -139,6 +156,13 @@ def _get_default_cuda_paths(cuda_version): elif not "." 
in cuda_version: cuda_version = cuda_version + ".*" + if _is_windows(): + return [ + os.environ.get( + "CUDA_PATH", + "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v%s\\" % + cuda_version) + ] return ["/usr/local/cuda-%s" % cuda_version, "/usr/local/cuda", "/usr", "/usr/local/cudnn"] + _get_ld_config_paths() @@ -188,8 +212,14 @@ def _find_file(base_paths, relative_paths, filepattern): def _find_library(base_paths, library_name, required_version): """Returns first valid path to the requested library.""" - filepattern = ".".join(["lib" + library_name, "so"] + - required_version.split(".")[:1]) + "*" + if _is_windows(): + filepattern = library_name + ".lib" + elif _is_macos(): + filepattern = "%s*.dylib" % (".".join(["lib" + library_name] + + required_version.split(".")[:1])) + else: + filepattern = ".".join(["lib" + library_name, "so"] + + required_version.split(".")[:1]) + "*" return _find_file(base_paths, _library_paths(), filepattern) @@ -238,7 +268,7 @@ def get_nvcc_version(path): return match.group(1) return None - nvcc_name = "nvcc" + nvcc_name = "nvcc.exe" if _is_windows() else "nvcc" nvcc_path, nvcc_version = _find_versioned_file(base_paths, [ "", "bin", @@ -528,6 +558,14 @@ def _get_legacy_path(env_name, default=[]): return _list_from_env(env_name, default) +def _normalize_path(path): + """Returns normalized path, with forward slashes on Windows.""" + path = os.path.realpath(path) + if _is_windows(): + path = path.replace("\\", "/") + return path + + def find_cuda_config(): """Returns a dictionary of CUDA library and header file paths.""" libraries = [argv.lower() for argv in sys.argv[1:]] @@ -596,7 +634,7 @@ def find_cuda_config(): for k, v in result.items(): if k.endswith("_dir") or k.endswith("_path"): - result[k] = os.path.realpath(v) + result[k] = _normalize_path(v) return result diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/build_config/BUILD b/third_party/xla/third_party/tsl/tsl/platform/default/build_config/BUILD index 93f35c45c0569d..2d6dfda0028a1b 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/build_config/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/default/build_config/BUILD @@ -117,10 +117,16 @@ cc_library( data = [ "@local_config_cuda//cuda:cudart", ], - linkopts = [ - "-Wl,-rpath,../local_config_cuda/cuda/lib64", - "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib64", - ], + linkopts = select({ + "//tsl:macos": [ + "-Wl,-rpath,../local_config_cuda/cuda/lib", + "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib", + ], + "//conditions:default": [ + "-Wl,-rpath,../local_config_cuda/cuda/lib64", + "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib64", + ], + }), visibility = ["//visibility:public"], deps = [ "@local_config_cuda//cuda:cudart", From 60a80cc80cbb59942e27f45596fd016837732437 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Nov 2023 10:08:48 -0800 Subject: [PATCH 025/391] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/fa14cd8fbac47b3545e91b387df41d18262ead38. 
PiperOrigin-RevId: 582002759 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 4ca750a18cce0f..d92da02faa0943 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "39a30d24eb74fbfaa24ab3d5917bf01715add6a0" - TFRT_SHA256 = "3bf036c9ce3a86805deb34b3ea2e7658428b050501bd86762db265d462a3e8cd" + TFRT_COMMIT = "fa14cd8fbac47b3545e91b387df41d18262ead38" + TFRT_SHA256 = "1a8771c039520824dd66b404bf56bb4a387089fd1497b1ceace52e2bf3ce35f2" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 4ca750a18cce0f..d92da02faa0943 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "39a30d24eb74fbfaa24ab3d5917bf01715add6a0" - TFRT_SHA256 = "3bf036c9ce3a86805deb34b3ea2e7658428b050501bd86762db265d462a3e8cd" + TFRT_COMMIT = "fa14cd8fbac47b3545e91b387df41d18262ead38" + TFRT_SHA256 = "1a8771c039520824dd66b404bf56bb4a387089fd1497b1ceace52e2bf3ce35f2" tf_http_archive( name = "tf_runtime", From 3c8de558d91c60d01903966535c25770ed90daf7 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 13 Nov 2023 10:27:36 -0800 Subject: [PATCH 026/391] [xla:gpu] Add a pass for pattern matching custom fusions This is initial version of the pass that only supports custom fusions consisting of a single HLO instruction. 
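As a rough usage sketch (the DotGemmPattern name and the surrounding setup are illustrative; the test added in this change registers an equivalent SimpleGemmPattern, and the types are the ones introduced here):

  // Matches a single dot instruction and tags it with a custom fusion config.
  class DotGemmPattern : public kernel::CustomFusionPattern {
   public:
    std::optional<Match> TryMatch(HloInstruction* instr) const override {
      if (instr->opcode() != HloOpcode::kDot) return std::nullopt;
      CustomFusionConfig config;
      config.set_name("dot_gemm");    // must match a registered custom kernel
      return Match{config, {instr}};  // single-instruction matches only, for now
    }
  };

  // Outline every match into a kCustom fusion with a __custom_fusion backend
  // config; `module` and `execution_threads` come from the caller.
  kernel::CustomFusionPatternRegistry patterns;
  patterns.Emplace<DotGemmPattern>();
  CustomFusionRewriter rewriter(&patterns);
  TF_ASSIGN_OR_RETURN(bool changed, rewriter.Run(module, execution_threads));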
PiperOrigin-RevId: 582009307 --- third_party/xla/xla/service/gpu/BUILD | 33 +++++ .../xla/service/gpu/custom_fusion_rewriter.cc | 127 ++++++++++++++++++ .../xla/service/gpu/custom_fusion_rewriter.h | 79 +++++++++++ .../gpu/custom_fusion_rewriter_test.cc | 87 ++++++++++++ third_party/xla/xla/service/gpu/kernels/BUILD | 11 ++ .../gpu/kernels/custom_fusion_pattern.cc | 41 ++++++ .../gpu/kernels/custom_fusion_pattern.h | 74 ++++++++++ 7 files changed, 452 insertions(+) create mode 100644 third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc create mode 100644 third_party/xla/xla/service/gpu/custom_fusion_rewriter.h create mode 100644 third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc create mode 100644 third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc create mode 100644 third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 24b429d3a49037..20b36c1ec5204d 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2662,6 +2662,39 @@ xla_cc_test( ], ) +cc_library( + name = "custom_fusion_rewriter", + srcs = ["custom_fusion_rewriter.cc"], + hdrs = ["custom_fusion_rewriter.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla:statusor", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu/kernels:custom_fusion_pattern", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "custom_fusion_rewriter_test", + srcs = ["custom_fusion_rewriter_test.cc"], + deps = [ + ":custom_fusion_rewriter", + "//xla/hlo/ir:hlo", + "//xla/service/gpu/kernels:custom_fusion_pattern", + "//xla/tests:hlo_test_base", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "fusion_pipeline", srcs = ["fusion_pipeline.cc"], diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc new file mode 100644 index 00000000000000..382cb50d410ce6 --- /dev/null +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc @@ -0,0 +1,127 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/custom_fusion_rewriter.h" + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/kernels/custom_fusion_pattern.h" +#include "xla/statusor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" + +namespace xla::gpu { + +using xla::gpu::kernel::CustomFusionPattern; +using xla::gpu::kernel::CustomFusionPatternRegistry; + +CustomFusionRewriter::CustomFusionRewriter( + const CustomFusionPatternRegistry* patterns) + : patterns_(patterns) {} + +// Creates custom fusion computation and moves all matched instructions into it. +static StatusOr CreateFusionBody( + HloModule* module, const CustomFusionPattern::Match& match) { + HloComputation::Builder builder(match.config.name()); + + // We do not currently support matching custom fusions with more than one + // instruction. + HloInstruction* root = match.instructions[0]; + + // Fusion computation parameters inferred from a matched instruction. + absl::InlinedVector parameters; + for (HloInstruction* operand : root->operands()) { + parameters.push_back(builder.AddInstruction( + HloInstruction::CreateParameter(parameters.size(), operand->shape(), + absl::StrCat("p", parameters.size())))); + } + + builder.AddInstruction(root->CloneWithNewOperands(root->shape(), parameters)); + + return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false); +} + +static StatusOr CreateFusionInstruction( + HloModule* module, const CustomFusionPattern::Match& match, + HloComputation* body) { + // We'll be replacing the root operation of a custom fusion with a fusion + // instruction calling fusion computation. + HloInstruction* fusion_root = match.instructions[0]; + HloComputation* fusion_parent = fusion_root->parent(); + + HloInstruction* fusion = + fusion_parent->AddInstruction(HloInstruction::CreateFusion( + fusion_root->shape(), HloInstruction::FusionKind::kCustom, + fusion_root->operands(), body)); + + // Assign unique name to a new fusion instruction. + module->SetAndUniquifyInstrName(fusion, match.config.name()); + + // Set backends config to a matched custom fusion config. + FusionBackendConfig backend_config; + backend_config.set_kind("__custom_fusion"); + *backend_config.mutable_custom_fusion_config() = match.config; + TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(backend_config))); + + // Replace fusion root with a fusion instruction. + TF_RETURN_IF_ERROR(fusion_parent->ReplaceInstruction(fusion_root, fusion)); + + return fusion; +} + +StatusOr CustomFusionRewriter::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + std::vector matches; + + // Collect all potential custom fusion matches in the module. 
+ for (HloComputation* computation : module->computations()) { + for (HloInstruction* instr : computation->instructions()) { + auto matched = patterns_->Match(instr); + matches.insert(matches.end(), matched.begin(), matched.end()); + } + } + + if (matches.empty()) return false; + + for (const CustomFusionPattern::Match& match : matches) { + if (match.instructions.size() != 1) + return absl::InternalError( + "Custom fusions with multiple instruction are not yet supported"); + + TF_ASSIGN_OR_RETURN(HloComputation * fusion_body, + CreateFusionBody(module, match)); + + TF_ASSIGN_OR_RETURN(HloInstruction * fusion, + CreateFusionInstruction(module, match, fusion_body)); + + VLOG(5) << "Added a fusion instruction: " << fusion->name() + << " for custom fusion " << match.config.name() + << " (instruction count = " << match.instructions.size() << ")"; + } + + return true; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h new file mode 100644 index 00000000000000..2db45ea0c8558f --- /dev/null +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h @@ -0,0 +1,79 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_CUSTOM_FUSION_REWRITER_H_ +#define XLA_SERVICE_GPU_CUSTOM_FUSION_REWRITER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/service/gpu/kernels/custom_fusion_pattern.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/statusor.h" + +namespace xla::gpu { + +// Pattern matches HLO instruction to custom fusions (hand written CUDA C++ +// kernels, e.g. custom GEMMs implemented with CUTLASS) and rewrites them into +// fusion instructions and fusion computations. 
+// +// Example: pattern matching dot operation into CUTLASS gemm +// +// ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] { +// %p0 = f16[15,19]{1,0} parameter(0) +// %p1 = f16[19,17]{1,0} parameter(1) +// ROOT %r = f16[15,17]{1,0} dot(%p0, %p1), +// lhs_contracting_dims={1}, rhs_contracting_dims={0} +// } +// +// After the pass: +// +// %cutlass_gemm (p0: f16[19,17], p1: f16[15,19]) -> f16[15,17] { +// %p0 = f16[15,19]{1,0} parameter(0) +// %p1 = f16[19,17]{1,0} parameter(1) +// ROOT %r = f16[15,17]{1,0} dot(%p0, %p1), +// lhs_contracting_dims={1}, rhs_contracting_dims={0} +// } +// +// ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] { +// %p0 = f16[15,19]{1,0} parameter(0) +// %p1 = f16[19,17]{1,0} parameter(1) +// ROOT %r = f16[15,17]{1,0} fusion(%p0, %p1), kind=kCustom, +// calls==cutlass_gemm, +// backend_config={kind: "__custom_fusion", +// custom_fusion_config: {"name":"cutlass_gemm"}} +// } +// +class CustomFusionRewriter : public HloModulePass { + public: + explicit CustomFusionRewriter( + const kernel::CustomFusionPatternRegistry* patterns); + + absl::string_view name() const override { return "custom-fusion-rewriter"; } + + StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const kernel::CustomFusionPatternRegistry* patterns_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_CUSTOM_FUSION_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc new file mode 100644 index 00000000000000..12395d0316bfb1 --- /dev/null +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/custom_fusion_rewriter.h" + +#include +#include + +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/kernels/custom_fusion_pattern.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// Simple pattern matchers for testing custom fusion rewriter. 
+//===----------------------------------------------------------------------===// + +class SimpleGemmPattern : public kernel::CustomFusionPattern { + public: + std::optional TryMatch(HloInstruction* instr) const override { + if (auto* dot = DynCast(instr)) { + CustomFusionConfig config; + config.set_name("simple_gemm"); + return Match{config, {instr}}; + } + return std::nullopt; + } +}; + +//===----------------------------------------------------------------------===// + +class CustomFusionRewriterTest : public HloTestBase {}; + +TEST_F(CustomFusionRewriterTest, SimpleGemm) { + const char* hlo = R"( + HloModule test + + ENTRY %main (p0: f16[15,19], p1: f16[19,17]) -> f16[15,17] { + %p0 = f16[15,19]{1,0} parameter(0) + %p1 = f16[19,17]{1,0} parameter(1) + ROOT %r = f16[15,17]{1,0} dot(%p0, %p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + + const char* expected = R"( + ; CHECK: %simple_gemm {{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f16[15,19]{1,0} parameter(0) + ; CHECK: [[P1:%[^ ]+]] = f16[19,17]{1,0} parameter(1) + ; CHECK: ROOT [[DOT:%[^ ]+]] = f16[15,17]{1,0} dot([[P0]], [[P1]]), + ; CEHCK: lhs_contracting_dims={1}, rhs_contracting_dims={0} + ; CHECK: } + + ; CHECK: ENTRY %main {{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f16[15,17]{1,0} fusion + ; CHECK: kind=kCustom, calls=%simple_gemm, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"simple_gemm"} + ; CHECK: } + ; CHECK: } + )"; + + kernel::CustomFusionPatternRegistry patterns; + patterns.Emplace(); + + CustomFusionRewriter pass(&patterns); + RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 43edb6503fd86c..489e49c8d2b6a7 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -33,6 +33,17 @@ cc_library( ], ) +cc_library( + name = "custom_fusion_pattern", + srcs = ["custom_fusion_pattern.cc"], + hdrs = ["custom_fusion_pattern.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", + ], +) + cc_library( name = "custom_kernel", srcs = ["custom_kernel.cc"], diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc new file mode 100644 index 00000000000000..304d58045d993d --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc @@ -0,0 +1,41 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/kernels/custom_fusion_pattern.h" + +#include +#include +#include + +#include "xla/hlo/ir/hlo_instruction.h" + +namespace xla::gpu::kernel { + +std::vector CustomFusionPatternRegistry::Match( + HloInstruction* instr) const { + std::vector matches; + for (auto& pattern : patterns_) { + if (auto matched = pattern->TryMatch(instr); matched.has_value()) + matches.push_back(std::move(*matched)); + } + return matches; +} + +void CustomFusionPatternRegistry::Add( + std::unique_ptr pattern) { + patterns_.push_back(std::move(pattern)); +} + +} // namespace xla::gpu::kernel diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h new file mode 100644 index 00000000000000..43b1428b0ce4cd --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h @@ -0,0 +1,74 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_PATTERN_H_ +#define XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_PATTERN_H_ + +#include +#include +#include +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/backend_configs.pb.h" + +namespace xla::gpu::kernel { + +//===----------------------------------------------------------------------===// +// CustomFusionPattern +//===----------------------------------------------------------------------===// + +// Custom fusion pattern matches HLO instruction to custom kernels. +class CustomFusionPattern { + public: + virtual ~CustomFusionPattern() = default; + + struct Match { + CustomFusionConfig config; + std::vector instructions; + }; + + // Returns custom fusion config and a list of instructions that matched to a + // custom fusion (one or more custom kernels). Custom fusion pass will outline + // matched instructions into a custom fusion operation if possible. + // + // TODO(ezhulenev): Today the last instruction defines custom fusion root + // (results), however we need to add support for custom fusion that can return + // intermediate result, and custom fusions that require an extra workspace. 
+ virtual std::optional TryMatch(HloInstruction *instr) const = 0; +}; + +//===----------------------------------------------------------------------===// +// CustomFusionPatternRegistry +//===----------------------------------------------------------------------===// + +class CustomFusionPatternRegistry { + public: + std::vector Match(HloInstruction *instr) const; + + void Add(std::unique_ptr pattern); + + template > + void Emplace() { + (Add(std::make_unique()), ...); + } + + private: + std::vector> patterns_; +}; + +} // namespace xla::gpu::kernel + +#endif // XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_PATTERN_H_ From 6afe016cb5b358a9040c8c786d33de2c10519d7c Mon Sep 17 00:00:00 2001 From: Robert David Date: Mon, 13 Nov 2023 10:57:47 -0800 Subject: [PATCH 027/391] Refactor `SimpleMemoryArena`, extract code that deals with the resizable aligned buffer to a separate class. PiperOrigin-RevId: 582019459 --- tensorflow/lite/simple_memory_arena.cc | 119 ++++++++++++++----------- tensorflow/lite/simple_memory_arena.h | 62 +++++++++---- 2 files changed, 110 insertions(+), 71 deletions(-) diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index 0f107be50691dc..f0b5f281985539 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -19,14 +19,17 @@ limitations under the License. #include #include +#include #include #include #include #include +#include #include #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/macros.h" + #ifdef TF_LITE_TENSORFLOW_PROFILER #include "tensorflow/lite/tensorflow_profiler_logger.h" #endif // TF_LITE_TENSORFLOW_PROFILER @@ -43,6 +46,56 @@ T AlignTo(size_t alignment, T offset) { namespace tflite { +bool ResizableAlignedBuffer::Resize(size_t new_size) { + const size_t new_allocation_size = RequiredAllocationSize(new_size); + if (new_allocation_size <= allocation_size_) { + // Skip reallocation when resizing down. + return false; + } +#ifdef TF_LITE_TENSORFLOW_PROFILER + PauseHeapMonitoring(/*pause=*/true); + OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), + new_allocation_size); +#endif + auto new_buffer = std::make_unique(new_allocation_size); + char* new_aligned_ptr = reinterpret_cast( + AlignTo(alignment_, reinterpret_cast(new_buffer.get()))); + if (new_size > 0 && allocation_size_ > 0) { + // Copy data when both old and new buffers are bigger than 0 bytes. 
+ const size_t new_alloc_alignment_adjustment = + new_aligned_ptr - new_buffer.get(); + const size_t old_alloc_alignment_adjustment = aligned_ptr_ - buffer_.get(); + const size_t copy_amount = + std::min(allocation_size_ - old_alloc_alignment_adjustment, + new_allocation_size - new_alloc_alignment_adjustment); + memcpy(new_aligned_ptr, aligned_ptr_, copy_amount); + } +#ifdef TF_LITE_TENSORFLOW_PROFILER + if (allocation_size_ > 0) { + OnTfLiteArenaDealloc(subgraph_index_, + reinterpret_cast(this), + allocation_size_); + } +#endif + buffer_ = std::move(new_buffer); + allocation_size_ = new_allocation_size; + aligned_ptr_ = new_aligned_ptr; +#ifdef TF_LITE_TENSORFLOW_PROFILER + PauseHeapMonitoring(/*pause=*/false); +#endif + return true; +} + +void ResizableAlignedBuffer::Release() { +#ifdef TF_LITE_TENSORFLOW_PROFILER + OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), + allocation_size_); +#endif + buffer_.reset(); + allocation_size_ = 0; + aligned_ptr_ = nullptr; +} + void SimpleMemoryArena::PurgeAfter(int32_t node) { for (int i = 0; i < active_allocs_.size(); ++i) { if (active_allocs_[i].first_node > node) { @@ -90,7 +143,7 @@ TfLiteStatus SimpleMemoryArena::Allocate( TfLiteContext* context, size_t alignment, size_t size, int32_t tensor, int32_t first_node, int32_t last_node, ArenaAllocWithUsageInterval* new_alloc) { - TF_LITE_ENSURE(context, alignment <= arena_alignment_); + TF_LITE_ENSURE(context, alignment <= underlying_buffer_.GetAlignment()); new_alloc->tensor = tensor; new_alloc->first_node = first_node; new_alloc->last_node = last_node; @@ -141,48 +194,12 @@ TfLiteStatus SimpleMemoryArena::Allocate( } TfLiteStatus SimpleMemoryArena::Commit(bool* arena_reallocated) { - size_t required_size = RequiredBufferSize(); - if (required_size > underlying_buffer_size_) { - *arena_reallocated = true; -#ifdef TF_LITE_TENSORFLOW_PROFILER - PauseHeapMonitoring(/*pause=*/true); - OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), - required_size); -#endif - char* new_alloc = new char[required_size]; - char* new_underlying_buffer_aligned_ptr = reinterpret_cast( - AlignTo(arena_alignment_, reinterpret_cast(new_alloc))); - - // If the arena had been previously allocated, copy over the old memory. - // Since Alloc pointers are offset based, they will remain valid in the new - // memory block. - if (high_water_mark_ > 0 && underlying_buffer_size_ > 0) { - size_t copy_amount = std::min( - underlying_buffer_.get() + underlying_buffer_size_ - - underlying_buffer_aligned_ptr_, - new_alloc + required_size - new_underlying_buffer_aligned_ptr); - memcpy(new_underlying_buffer_aligned_ptr, underlying_buffer_aligned_ptr_, - copy_amount); - } - -#ifdef TF_LITE_TENSORFLOW_PROFILER - if (underlying_buffer_size_ > 0) { - OnTfLiteArenaDealloc(subgraph_index_, - reinterpret_cast(this), - underlying_buffer_size_); - } -#endif - underlying_buffer_.reset(new_alloc); - underlying_buffer_size_ = required_size; - underlying_buffer_aligned_ptr_ = new_underlying_buffer_aligned_ptr; -#ifdef TF_LITE_TENSORFLOW_PROFILER - PauseHeapMonitoring(/*pause=*/false); -#endif - } else { - *arena_reallocated = false; - } + // Resize the arena to the high water mark (calculated by Allocate), retaining + // old contents and alignment in the process. Since Alloc pointers are offset + // based, they will remain valid in the new memory block. + *arena_reallocated = underlying_buffer_.Resize(high_water_mark_); committed_ = true; - return underlying_buffer_ != nullptr ? 
kTfLiteOk : kTfLiteError; + return kTfLiteOk; } TfLiteStatus SimpleMemoryArena::ResolveAlloc( @@ -190,12 +207,12 @@ TfLiteStatus SimpleMemoryArena::ResolveAlloc( char** output_ptr) { TF_LITE_ENSURE(context, committed_); TF_LITE_ENSURE(context, output_ptr != nullptr); - TF_LITE_ENSURE(context, - underlying_buffer_size_ >= (alloc.offset + alloc.size)); + TF_LITE_ENSURE(context, underlying_buffer_.GetAllocationSize() >= + (alloc.offset + alloc.size)); if (alloc.size == 0) { *output_ptr = nullptr; } else { - *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; + *output_ptr = underlying_buffer_.GetPtr() + alloc.offset; } return kTfLiteOk; } @@ -209,13 +226,7 @@ TfLiteStatus SimpleMemoryArena::ClearPlan() { TfLiteStatus SimpleMemoryArena::ReleaseBuffer() { committed_ = false; -#ifdef TF_LITE_TENSORFLOW_PROFILER - OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), - underlying_buffer_size_); -#endif - underlying_buffer_size_ = 0; - underlying_buffer_aligned_ptr_ = nullptr; - underlying_buffer_.reset(); + underlying_buffer_.Release(); return kTfLiteOk; } @@ -227,8 +238,8 @@ TFLITE_ATTRIBUTE_WEAK void DumpArenaInfo( void SimpleMemoryArena::DumpDebugInfo( const std::string& name, const std::vector& execution_plan) const { - tflite::DumpArenaInfo(name, execution_plan, underlying_buffer_size_, - active_allocs_); + tflite::DumpArenaInfo(name, execution_plan, + underlying_buffer_.GetAllocationSize(), active_allocs_); } } // namespace tflite diff --git a/tensorflow/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h index 0e527df9ac98b1..05bb52e6a225e4 100644 --- a/tensorflow/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -55,6 +55,44 @@ struct ArenaAllocWithUsageInterval { } }; +class ResizableAlignedBuffer { + public: + explicit ResizableAlignedBuffer(size_t alignment, int subgraph_index) + : allocation_size_(0), + alignment_(alignment), + subgraph_index_(subgraph_index) { + // To silence unused private member warning, only used with + // TF_LITE_TENSORFLOW_PROFILER + (void)subgraph_index_; + } + + // Resizes the buffer to make sure new_size bytes fit in the buffer. Keeps + // alignment and any existing the data. Returns true when any external + // pointers into the data array need to be adjusted (the buffer was moved). + bool Resize(size_t new_size); + // Releases any allocated memory. + void Release(); + + // Pointer to the data array. + char* GetPtr() const { return aligned_ptr_; } + // Size of the allocation (NOT of the data array). + size_t GetAllocationSize() const { return allocation_size_; } + // Alignment of the data array. + size_t GetAlignment() const { return alignment_; } + + private: + size_t RequiredAllocationSize(size_t data_array_size) const { + return data_array_size + alignment_ - 1; + } + + std::unique_ptr buffer_; + size_t allocation_size_; + size_t alignment_; + char* aligned_ptr_; + + int subgraph_index_; +}; + // This small class is responsible for allocating, deallocating and reusing // dynamic memory from a common underlying buffer. 
The arena can be used in // scenarios when the pattern of memory allocations and deallocations is @@ -63,11 +101,9 @@ struct ArenaAllocWithUsageInterval { class SimpleMemoryArena { public: explicit SimpleMemoryArena(size_t arena_alignment, int subgraph_index = 0) - : subgraph_index_(subgraph_index), - committed_(false), - arena_alignment_(arena_alignment), + : committed_(false), high_water_mark_(0), - underlying_buffer_size_(0), + underlying_buffer_(arena_alignment, subgraph_index), active_allocs_() {} // Delete all allocs. This should be called when allocating the first node of @@ -99,10 +135,6 @@ class SimpleMemoryArena { int32_t tensor, int32_t first_node, int32_t last_node, ArenaAllocWithUsageInterval* new_alloc); - inline size_t RequiredBufferSize() { - return high_water_mark_ + arena_alignment_ - 1; - } - TfLiteStatus Commit(bool* arena_reallocated); TfLiteStatus ResolveAlloc(TfLiteContext* context, @@ -119,10 +151,12 @@ class SimpleMemoryArena { // again until Commit() is called & tensor allocations are resolved. TfLiteStatus ReleaseBuffer(); - size_t GetBufferSize() const { return underlying_buffer_size_; } + size_t GetBufferSize() const { + return underlying_buffer_.GetAllocationSize(); + } std::intptr_t BasePointer() const { - return reinterpret_cast(underlying_buffer_aligned_ptr_); + return reinterpret_cast(underlying_buffer_.GetPtr()); } // Dumps the memory allocation information of this memory arena (which could @@ -142,16 +176,10 @@ class SimpleMemoryArena { void DumpDebugInfo(const std::string& name, const std::vector& execution_plan) const; - protected: - int subgraph_index_; - private: bool committed_; - size_t arena_alignment_; size_t high_water_mark_; - std::unique_ptr underlying_buffer_; - size_t underlying_buffer_size_; - char* underlying_buffer_aligned_ptr_; + ResizableAlignedBuffer underlying_buffer_; std::vector active_allocs_; }; From 2aa6bd5b04751a7100afff411ab2abbd58398c7b Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Mon, 13 Nov 2023 11:33:20 -0800 Subject: [PATCH 028/391] Fix poorly worded comment. PiperOrigin-RevId: 582031434 --- third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h index 93cf5a6dba06d9..09b3689274438a 100644 --- a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h +++ b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h @@ -284,7 +284,8 @@ class DfsHloVisitorWithDefaultBase delete; }; -// Users should use these type aliases which are only two valid instantiations. +// Users should use one of these two type aliases, which are the only two valid +// instantiations of DfsHloVisitorWithDefaultBase. using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase; using ConstDfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase; From c2ea26cdf2afaa5c9b6c334fb251b442885ceafd Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Mon, 13 Nov 2023 11:54:33 -0800 Subject: [PATCH 029/391] [xla:gpu] NFC: Decouple construction of FusionEmissionResult from EmitFusion #6528 EmitFusion should call a subroutine that construct a FusionEmissionResult, then add the thunks from the result to the executable. This is needed by command buffer emission because it only needs FusionEmissionResult as a intermediate result. 
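Concretely, the per-kind helpers now hand their thunks back to the caller; a condensed sketch of the resulting EmitFusion tail (dispatch on the fusion kind and error handling elided):

  // Each helper (EmitTritonFusion, EmitScatter, EmitCustomFusion, or the
  // HloFusionAnalysis-based emitters) now returns the thunks it built.
  TF_ASSIGN_OR_RETURN(FusionEmissionResult emission_result,
                      EmitScatter(instr, /*fusion_op=*/nullptr, fusion_analysis));

  // Only the caller decides where the thunks go: EmitFusion appends them to
  // the thunk sequence, while a command-buffer emitter can consume the
  // FusionEmissionResult directly.
  for (std::unique_ptr<Thunk>& thunk : emission_result.thunks) {
    AddThunkToThunkSequence(std::move(thunk));
  }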
PiperOrigin-RevId: 582038375 --- .../xla/service/gpu/ir_emitter_unnested.cc | 148 +++++++++++------- .../xla/xla/service/gpu/ir_emitter_unnested.h | 17 +- 2 files changed, 101 insertions(+), 64 deletions(-) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index d8e2909f193737..d1cfa6012679f5 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -1938,7 +1938,7 @@ static Status ProcessFusionForConversion(mlir::Region* region, } #if GOOGLE_CUDA -Status IrEmitterUnnested::EmitTritonFusion( +StatusOr IrEmitterUnnested::EmitTritonFusion( const HloFusionAnalysis& hlo_fusion_analysis, const HloFusionInstruction* fusion, mlir::Operation* op) { // Note: In this method we can't use `BuildKernelThunk` as usual, @@ -2052,10 +2052,13 @@ Status IrEmitterUnnested::EmitTritonFusion( } else { fusion_op = op; } - AddThunkToThunkSequence(std::make_unique( + + FusionEmissionResult result; + result.thunks.emplace_back(std::make_unique( fusion_op, kernel->kernel_name, kernel_arguments.args(), kernel->launch_dimensions, kernel->shmem_bytes)); - return OkStatus(); + + return result; } #endif // GOOGLE_CUDA @@ -2089,35 +2092,63 @@ bool IsSpecializedLoopFusion( Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr, HloFusionAnalysis& fusion_analysis) { - std::unique_ptr emitter; + FusionEmissionResult emission_result; switch (fusion_analysis.GetEmitterFusionKind()) { - case HloFusionAnalysis::EmitterFusionKind::kInputSlices: - emitter = std::make_unique(fusion_analysis); + case HloFusionAnalysis::EmitterFusionKind::kInputSlices: { + auto emitter = std::make_unique(fusion_analysis); + TF_ASSIGN_OR_RETURN( + emission_result, + emitter->Emit(*ir_emitter_context_, elemental_emitter_, nullptr, + *instr, kernel_reuse_cache_, &b_)); break; - case HloFusionAnalysis::EmitterFusionKind::kLoop: + } + case HloFusionAnalysis::EmitterFusionKind::kLoop: { // TODO(anlunx): Support MemcpyFusion and InPlaceDymaicUpdateSlice. 
- emitter = std::make_unique(fusion_analysis); + auto emitter = std::make_unique(fusion_analysis); + TF_ASSIGN_OR_RETURN( + emission_result, + emitter->Emit(*ir_emitter_context_, elemental_emitter_, nullptr, + *instr, kernel_reuse_cache_, &b_)); break; - case HloFusionAnalysis::EmitterFusionKind::kTranspose: - emitter = std::make_unique(fusion_analysis); + } + case HloFusionAnalysis::EmitterFusionKind::kTranspose: { + auto emitter = std::make_unique(fusion_analysis); + TF_ASSIGN_OR_RETURN( + emission_result, + emitter->Emit(*ir_emitter_context_, elemental_emitter_, nullptr, + *instr, kernel_reuse_cache_, &b_)); break; - case HloFusionAnalysis::EmitterFusionKind::kReduction: - emitter = std::make_unique(fusion_analysis); + } + case HloFusionAnalysis::EmitterFusionKind::kReduction: { + auto emitter = std::make_unique(fusion_analysis); + TF_ASSIGN_OR_RETURN( + emission_result, + emitter->Emit(*ir_emitter_context_, elemental_emitter_, nullptr, + *instr, kernel_reuse_cache_, &b_)); break; + } case HloFusionAnalysis::EmitterFusionKind::kTriton: { TF_ASSIGN_OR_RETURN(auto backend_config, instr->backend_config()); #if GOOGLE_CUDA - return EmitTritonFusion(fusion_analysis, instr, nullptr); + TF_ASSIGN_OR_RETURN(emission_result, + EmitTritonFusion(fusion_analysis, instr, nullptr)); + break; #endif LOG(FATAL) << "Unsupported fusion kind: " << backend_config.kind(); } - case HloFusionAnalysis::EmitterFusionKind::kScatter: - return EmitScatter(instr, nullptr, fusion_analysis); + case HloFusionAnalysis::EmitterFusionKind::kScatter: { + TF_ASSIGN_OR_RETURN(emission_result, + EmitScatter(instr, nullptr, fusion_analysis)); + break; + } case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: { TF_ASSIGN_OR_RETURN(auto backend_config, instr->backend_config()); - return EmitCustomFusion(instr, backend_config.custom_fusion_config()); + TF_ASSIGN_OR_RETURN( + emission_result, + EmitCustomFusion(instr, backend_config.custom_fusion_config())); + break; } default: return FailedPrecondition( @@ -2125,9 +2156,6 @@ Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr, break; } - TF_ASSIGN_OR_RETURN(FusionEmissionResult emission_result, - emitter->Emit(*ir_emitter_context_, elemental_emitter_, - nullptr, *instr, kernel_reuse_cache_, &b_)); for (auto& thunk : emission_result.thunks) { AddThunkToThunkSequence(std::move(thunk)); } @@ -2161,42 +2189,47 @@ Status IrEmitterUnnested::EmitFusion( TF_ASSIGN_OR_RETURN(auto fusion_analysis, HloFusionAnalysis::Create(fusion, &device_info)); - auto emitter = GetFusionEmitter( - fusion_analysis, ir_emitter_context_->allocations(), fusion_op); - if (emitter != std::nullopt) { - TF_ASSIGN_OR_RETURN( - auto emission_result, - (*emitter)->Emit(*ir_emitter_context_, elemental_emitter_, fusion_op, - *fusion, kernel_reuse_cache_, &b_)); - for (auto& thunk : emission_result.thunks) { - AddThunkToThunkSequence(std::move(thunk)); - } - return OkStatus(); - } - - // Dispatch to the fusion specific emitter. 
+ FusionEmissionResult emission_result; auto emitter_fusion_kind = fusion_analysis.GetEmitterFusionKind(); switch (emitter_fusion_kind) { + case HloFusionAnalysis::EmitterFusionKind::kInputSlices: + case HloFusionAnalysis::EmitterFusionKind::kLoop: + case HloFusionAnalysis::EmitterFusionKind::kReduction: + case HloFusionAnalysis::EmitterFusionKind::kTranspose: { + std::optional> emitter = + GetFusionEmitter(fusion_analysis, ir_emitter_context_->allocations(), + fusion_op); + if (emitter == std::nullopt) { + return FailedPrecondition( + "Fusion should have been handled by GetFusionEmitter."); + } + TF_ASSIGN_OR_RETURN( + emission_result, + (*emitter)->Emit(*ir_emitter_context_, elemental_emitter_, fusion_op, + *fusion, kernel_reuse_cache_, &b_)); + break; + } case HloFusionAnalysis::EmitterFusionKind::kTriton: { #if GOOGLE_CUDA - return EmitTritonFusion(fusion_analysis, fusion, fusion_op); + TF_ASSIGN_OR_RETURN(emission_result, + EmitTritonFusion(fusion_analysis, fusion, fusion_op)); + break; #endif LOG(FATAL) << "Unsupported fusion kind: " << backend_config.kind(); } + case HloFusionAnalysis::EmitterFusionKind::kScatter: { + TF_ASSIGN_OR_RETURN(emission_result, + EmitScatter(fusion, fusion_op, fusion_analysis)); + break; + } case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: - if (!backend_config.has_custom_fusion_config()) - return absl::InternalError( - "custom fusion is missing custom fusion config"); - return EmitCustomFusion(fusion, backend_config.custom_fusion_config()); - case HloFusionAnalysis::EmitterFusionKind::kScatter: - return EmitScatter(fusion, fusion_op, fusion_analysis); - case HloFusionAnalysis::EmitterFusionKind::kInputSlices: - case HloFusionAnalysis::EmitterFusionKind::kLoop: - case HloFusionAnalysis::EmitterFusionKind::kReduction: - case HloFusionAnalysis::EmitterFusionKind::kTranspose: - return FailedPrecondition( - "Fusion should have been handled by GetFusionEmitter."); + LOG(FATAL) << "kCustomFusion is not supported by JitRt runtime"; + } + + for (auto& thunk : emission_result.thunks) { + AddThunkToThunkSequence(std::move(thunk)); } + return OkStatus(); } Status IrEmitterUnnested::AssertNonDeterminismIsOkay( @@ -3151,9 +3184,9 @@ Status IrEmitterUnnested::EmitTargetElementLoop( return InternalError("This should be unreachable"); } -Status IrEmitterUnnested::EmitScatter(const HloFusionInstruction* fusion, - mlir::lmhlo::FusionOp fusion_op, - HloFusionAnalysis& fusion_analysis) { +StatusOr IrEmitterUnnested::EmitScatter( + const HloFusionInstruction* fusion, mlir::lmhlo::FusionOp fusion_op, + HloFusionAnalysis& fusion_analysis) { auto* fused_computation = fusion->fused_instructions_computation(); auto* root = fused_computation->root_instruction(); @@ -3204,17 +3237,19 @@ Status IrEmitterUnnested::EmitScatter(const HloFusionInstruction* fusion, return EmitScatter(desc, launch_dimensions); }; - TF_ASSIGN_OR_RETURN(auto thunk, + TF_ASSIGN_OR_RETURN(std::unique_ptr kernel_thunk, BuildKernelThunkForFusion( *ir_emitter_context_, kernel_reuse_cache_, fusion, fusion_op, fused_computation, launch_dimensions, /*discriminator=*/"scatter", builder_fn, &b_)); - AddThunkToThunkSequence(std::move(thunk)); - return OkStatus(); + + FusionEmissionResult result; + result.thunks.push_back(std::move(kernel_thunk)); + return result; } -Status IrEmitterUnnested::EmitCustomFusion(const HloFusionInstruction* fusion, - const CustomFusionConfig& config) { +StatusOr IrEmitterUnnested::EmitCustomFusion( + const HloFusionInstruction* fusion, const CustomFusionConfig& config) { VLOG(3) << 
"Lower HLO fusion to a custom fusion " << config.name(); auto* registry = kernel::CustomFusionRegistry::Default(); @@ -3249,8 +3284,9 @@ Status IrEmitterUnnested::EmitCustomFusion(const HloFusionInstruction* fusion, auto thunk, BuildCustomKernelThunkForFusion(*ir_emitter_context_, fusion, std::move(kernels[0]))); - AddThunkToThunkSequence(std::move(thunk)); - return OkStatus(); + FusionEmissionResult result; + result.thunks.push_back(std::move(thunk)); + return result; } Status IrEmitterUnnested::EmitOp( diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h index 3fc2255488b55a..484910eef12de2 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h @@ -33,6 +33,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/elemental_ir_emitter.h" +#include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/fusions/tiling_util.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter.h" @@ -141,9 +142,9 @@ class IrEmitterUnnested : public IrEmitter { Status EmitCublasLtMatmulThunkF8(mlir::Operation* op); Status EmitConvolutionReorderThunk(mlir::Operation* op); Status EmitNormThunk(mlir::Operation* op); - Status EmitTritonFusion(const HloFusionAnalysis& hlo_fusion_analysis, - const HloFusionInstruction* fusion, - mlir::Operation* op); + StatusOr EmitTritonFusion( + const HloFusionAnalysis& hlo_fusion_analysis, + const HloFusionInstruction* fusion, mlir::Operation* op); Status EmitFusedMHAThunk(mlir::Operation* op); Status EmitFusedMHABackwardThunk(mlir::Operation* op); #endif // GOOGLE_CUDA @@ -353,14 +354,14 @@ class IrEmitterUnnested : public IrEmitter { Status EmitScatter(const ScatterDescriptor& desc, const LaunchDimensions& launch_dimensions); - Status EmitScatter(const HloFusionInstruction* fusion, - mlir::lmhlo::FusionOp fusion_op, - HloFusionAnalysis& fusion_analysis); + StatusOr EmitScatter( + const HloFusionInstruction* fusion, mlir::lmhlo::FusionOp fusion_op, + HloFusionAnalysis& fusion_analysis); // Emits kernel thunk for a custom fusion implemented with hand written custom // device kernels. - Status EmitCustomFusion(const HloFusionInstruction* fusion, - const CustomFusionConfig& config); + StatusOr EmitCustomFusion( + const HloFusionInstruction* fusion, const CustomFusionConfig& config); // Builds a kernel thunk for a non-fusion operation, without reuse. // From b71841568ef65ad78c6251fc893c1050a16d1da6 Mon Sep 17 00:00:00 2001 From: Samuel Agyakwa Date: Mon, 13 Nov 2023 11:56:49 -0800 Subject: [PATCH 030/391] [PJRT C API] Support passing platform_name as an option in PJRT GPU plugin. 
PiperOrigin-RevId: 582038998 --- .../xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 12 ++- .../xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc | 77 +++++++++++++++++++ third_party/xla/xla/python/xla_client.py | 2 +- 3 files changed, 87 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 3dca0efe9547d2..5a62033b67c28d 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -54,7 +54,8 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { args->num_options); const auto kExpectedOptionNameAndTypes = absl::flat_hash_map( - {{"allocator", PJRT_NamedValue_Type::PJRT_NamedValue_kString}, + {{"platform_name", PJRT_NamedValue_Type::PJRT_NamedValue_kString}, + {"allocator", PJRT_NamedValue_Type::PJRT_NamedValue_kString}, {"memory_fraction", PJRT_NamedValue_Type::PJRT_NamedValue_kFloat}, {"preallocate", PJRT_NamedValue_Type::PJRT_NamedValue_kBool}, {"visible_devices", @@ -64,6 +65,11 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { PJRT_RETURN_IF_ERROR( ValidateCreateOptions(create_options, kExpectedOptionNameAndTypes)); + std::optional platform_name; + if (auto it = create_options.find("platform_name"); + it != create_options.end()) { + platform_name.emplace(std::get(it->second)); + } xla::GpuAllocatorConfig allocator_config; if (auto it = create_options.find("allocator"); it != create_options.end()) { auto allocator_name = std::get(it->second); @@ -108,8 +114,8 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, xla::GetStreamExecutorGpuClient( /*asynchronous=*/true, allocator_config, node_id, - num_nodes, visible_devices, - /*platform_name=*/std::nullopt, true, + num_nodes, visible_devices, platform_name, + /*should_stage_host_to_device_transfers=*/true, pjrt::ToCppKeyValueGetCallback( args->kv_get_callback, args->kv_get_user_arg), pjrt::ToCppKeyValuePutCallback( diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc index ded7bf4e5c81f0..561c9d99c4e429 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -333,6 +333,83 @@ TEST(PjrtCApiGpuAllocatorTest, InvalidAllocatorOptionsParsing) { api->PJRT_Error_Destroy(&error_destroy_args); } +TEST(PjrtCApiPlatformNameTest, AvailablePlatformName) { + auto api = GetPjrtApi(); + std::string expected_platform_name_for_cuda = "cuda"; + std::string expected_platform_name_for_rocm = "rocm"; + absl::flat_hash_map options = { + {"platform_name", static_cast("gpu")}, + {"allocator", static_cast("default")}, + {"visible_devices", xla::PjRtValueType(std::vector{0, 1})}, + }; + TF_ASSERT_OK_AND_ASSIGN( + std::vector c_options, + ::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30)); + PJRT_Client_Create_Args create_arg; + create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; + create_arg.priv = nullptr; + create_arg.client = nullptr; + create_arg.create_options = c_options.data(); + create_arg.num_options = c_options.size(); + PJRT_Error* error = api->PJRT_Client_Create(&create_arg); + EXPECT_EQ(error, nullptr) << error->status.message(); + + PJRT_Client_PlatformName_Args platform_name_args; + platform_name_args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE; + platform_name_args.priv = nullptr; + platform_name_args.client = create_arg.client; + + 
PJRT_Error* platform_name_error = + api->PJRT_Client_PlatformName(&platform_name_args); + EXPECT_EQ(platform_name_error, nullptr); +#if TENSORFLOW_USE_ROCM + EXPECT_EQ(platform_name_args.platform_name, expected_platform_name_for_rocm); +#else + EXPECT_EQ(platform_name_args.platform_name, expected_platform_name_for_cuda); +#endif + + PJRT_Client_Destroy_Args destroy_args; + destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE; + destroy_args.priv = nullptr; + destroy_args.client = create_arg.client; + + PJRT_Error* destroy_error = api->PJRT_Client_Destroy(&destroy_args); + CHECK_EQ(destroy_error, nullptr); +} + +TEST(PjrtCApiPlatformNameTest, UnavailablePlatformName) { + auto api = GetPjrtApi(); + absl::flat_hash_map options = { + {"platform_name", static_cast("invalid_platform_name")}, + {"allocator", static_cast("default")}, + {"visible_devices", xla::PjRtValueType(std::vector{0, 1})}, + }; + TF_ASSERT_OK_AND_ASSIGN( + std::vector c_options, + ::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30)); + PJRT_Client_Create_Args create_arg; + create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE; + create_arg.priv = nullptr; + create_arg.client = nullptr; + create_arg.create_options = c_options.data(); + create_arg.num_options = c_options.size(); + PJRT_Error* error = api->PJRT_Client_Create(&create_arg); + EXPECT_NE(error, nullptr); + EXPECT_THAT(error->status, + ::tsl::testing::StatusIs( + absl::StatusCode::kNotFound, + testing::StartsWith("Could not find registered platform with " + "name: \"invalid_platform_name\". " + "Available platform names are:"))); + + PJRT_Error_Destroy_Args error_destroy_args; + error_destroy_args.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE; + error_destroy_args.priv = nullptr; + error_destroy_args.error = error; + + api->PJRT_Error_Destroy(&error_destroy_args); +} + void TestCustomCall() {} TEST(PjrtCApiGpuPrivTest, CustomCall) { diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index e2b201d2f7234f..6239cce3b38351 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -205,7 +205,7 @@ def generate_pjrt_gpu_plugin_options( options = {} if visible_devices != 'all': options['visible_devices'] = [int(x) for x in visible_devices.split(',')] - + options['platform_name'] = 'cuda' allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '') preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '') From bd14fa4c71495e6934ab5f4002fce9baeeeb6fd8 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 13 Nov 2023 12:14:52 -0800 Subject: [PATCH 031/391] [PJRT] Refactor topology exchange helpers out of the GPU client and into a common utility library. Change in preparation for adding topology sharing on CPU as well. Also simplify the GPU initialization code: always use the distributed code path, even in the non-distributed case. We can't actually delete BuildLocalDevices because TensorFlow uses it. 
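A minimal sketch of calling the shared helper from each node (node_id, num_nodes and the kv_get/kv_put key-value callbacks are assumed to come from the distributed client; the single-device topology is illustrative):

  LocalTopologyProto local;
  local.set_boot_id(GetBootIdString().value());  // unique per OS kernel restart
  DeviceProto* device = local.add_devices();
  device->set_local_device_ordinal(0);

  GlobalTopologyProto global;
  TF_RETURN_IF_ERROR(ExchangeTopologies(
      /*platform=*/"cuda", node_id, num_nodes,
      /*get_local_topology_timeout=*/absl::Seconds(10),
      /*get_global_topology_timeout=*/absl::Seconds(10),
      kv_get, kv_put, local, &global));
  // `global` now lists every node's devices with global device IDs assigned.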
PiperOrigin-RevId: 582045318 --- third_party/xla/xla/pjrt/distributed/BUILD | 19 ++ .../xla/xla/pjrt/distributed/topology_util.cc | 145 +++++++++- .../xla/xla/pjrt/distributed/topology_util.h | 27 +- .../pjrt/distributed/topology_util_test.cc | 62 +++++ .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 251 ++++++------------ 5 files changed, 329 insertions(+), 175 deletions(-) diff --git a/third_party/xla/xla/pjrt/distributed/BUILD b/third_party/xla/xla/pjrt/distributed/BUILD index 875c8f08f58350..03be34e444d48b 100644 --- a/third_party/xla/xla/pjrt/distributed/BUILD +++ b/third_party/xla/xla/pjrt/distributed/BUILD @@ -54,7 +54,14 @@ xla_cc_test( deps = [ ":protocol_proto_cc", ":topology_util", + "//xla:status", + "//xla:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], @@ -118,9 +125,21 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protocol_proto_cc", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla/pjrt:pjrt_client", + "//xla/pjrt:utils", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/pjrt/distributed/topology_util.cc b/third_party/xla/xla/pjrt/distributed/topology_util.cc index cd6f17e2d81673..fa3e26b6595e30 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util.cc +++ b/third_party/xla/xla/pjrt/distributed/topology_util.cc @@ -15,13 +15,113 @@ limitations under the License. #include "xla/pjrt/distributed/topology_util.h" +#include #include +#include +#include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/blocking_counter.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/pjrt/distributed/protocol.pb.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/utils.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/threadpool.h" namespace xla { +// Exists on Linux systems. Unique per OS kernel restart. +static constexpr char kBootIdPath[] = "/proc/sys/kernel/random/boot_id"; + +// Retrieve content of /proc/sys/kernel/random/boot_id as a string. +// Note that procfs file may have file size 0 which throws off generic file +// readers such as tsl::ReadFileToString. 
+StatusOr GetBootIdString() { + std::string boot_id_str; +#ifdef __linux__ + std::ifstream file(kBootIdPath); + if (!file) { + return NotFound("%s not found.", kBootIdPath); + } + std::string line; + while (std::getline(file, line)) { + absl::StripAsciiWhitespace(&line); + absl::StrAppend(&boot_id_str, line); + } +#endif + return boot_id_str; +} + +static std::string GetLocalTopologyKey(std::string_view platform, int node_id) { + return absl::StrCat("local_topology/", platform, "/", node_id); +} + +static std::string GetGlobalTopologyKey(std::string_view platform) { + return absl::StrCat("global_topology/", platform); +} + +static StatusOr> GetAllLocalTopologies( + std::string_view platform, int num_nodes, + const PjRtClient::KeyValueGetCallback& kv_get, absl::Duration timeout) { + std::vector> local_topology_strs(num_nodes); + + // TODO(ezhulenev): Should a thread pool become a function argument? + tsl::thread::ThreadPool thread_pool( + tsl::Env::Default(), "GetAllLocalTopologies", DefaultThreadPoolSize()); + + absl::BlockingCounter blocking_counter(num_nodes); + absl::Mutex mu; + for (int i = 0; i < num_nodes; i++) { + thread_pool.Schedule([&, i] { + StatusOr local_topology_str = + kv_get(GetLocalTopologyKey(platform, i), timeout); + { + absl::MutexLock lock(&mu); + local_topology_strs[i] = local_topology_str; + } + blocking_counter.DecrementCount(); + }); + } + blocking_counter.Wait(); + + std::vector error_messages; + std::vector local_topologies; + int max_num_failed_message = 10; + int failed_count = 0; + for (const StatusOr& str : local_topology_strs) { + if (str.ok()) { + LocalTopologyProto local; + local.ParseFromString(*str); + local_topologies.push_back(local); + } else { + error_messages.push_back( + absl::StrCat("Error ", ++failed_count, ": ", str.status().message())); + if (failed_count > max_num_failed_message) { + break; + } + } + } + if (error_messages.empty()) { + return local_topologies; + } + return absl::InternalError( + absl::StrCat("Getting local topologies failed: ", + absl::StrJoin(error_messages, "\n\n"))); +} + // Steals the contents of `local_topologies`. GlobalTopologyProto BuildGlobalTopology( absl::Span local_topologies) { @@ -32,7 +132,7 @@ GlobalTopologyProto BuildGlobalTopology( absl::flat_hash_map boot_id_to_slice_index; for (LocalTopologyProto& local : local_topologies) { // Every new boot_id seen is treated as a new host/slice. 
- absl::string_view boot_id = local.boot_id(); + std::string_view boot_id = local.boot_id(); auto [it, inserted] = boot_id_to_slice_index.try_emplace(boot_id, next_slice_index); if (inserted) { @@ -54,4 +154,47 @@ GlobalTopologyProto BuildGlobalTopology( return global_topology; } +Status ExchangeTopologies(std::string_view platform, int node_id, int num_nodes, + absl::Duration get_local_topology_timeout, + absl::Duration get_global_topology_timeout, + const PjRtClient::KeyValueGetCallback& kv_get, + const PjRtClient::KeyValuePutCallback& kv_put, + const LocalTopologyProto& local_topology, + GlobalTopologyProto* global_topology) { + VLOG(3) << "Local Topology for platform" << platform << ":\n" + << local_topology.DebugString(); + if (num_nodes == 1) { + LocalTopologyProto* topology = global_topology->add_nodes(); + *topology = local_topology; + for (DeviceProto& device : *topology->mutable_devices()) { + device.set_global_device_id(device.local_device_ordinal()); + } + return absl::OkStatus(); + } + + TF_RETURN_IF_ERROR(kv_put(GetLocalTopologyKey(platform, node_id), + local_topology.SerializeAsString())); + + // The lead node gets all local topologies, builds the global topology and + // puts it to the key-value store. + std::string global_topology_key = GetGlobalTopologyKey(platform); + if (node_id == 0) { + TF_ASSIGN_OR_RETURN(std::vector local_topologies, + GetAllLocalTopologies(platform, num_nodes, kv_get, + get_local_topology_timeout)); + *global_topology = + BuildGlobalTopology(absl::Span(local_topologies)); + TF_RETURN_IF_ERROR( + kv_put(global_topology_key, global_topology->SerializeAsString())); + } else { + TF_ASSIGN_OR_RETURN( + std::string global_topology_str, + kv_get(global_topology_key, get_global_topology_timeout)); + global_topology->ParseFromString(global_topology_str); + } + VLOG(3) << "Global topology for platform " << platform << ":\n" + << global_topology->DebugString(); + return absl::OkStatus(); +} + } // namespace xla diff --git a/third_party/xla/xla/pjrt/distributed/topology_util.h b/third_party/xla/xla/pjrt/distributed/topology_util.h index ad2d08274546d5..10e3732c71f670 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util.h +++ b/third_party/xla/xla/pjrt/distributed/topology_util.h @@ -16,13 +16,38 @@ limitations under the License. #ifndef XLA_PJRT_DISTRIBUTED_TOPOLOGY_UTIL_H_ #define XLA_PJRT_DISTRIBUTED_TOPOLOGY_UTIL_H_ +#include +#include + +#include "absl/time/time.h" #include "absl/types/span.h" #include "xla/pjrt/distributed/protocol.pb.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/status.h" +#include "xla/statusor.h" namespace xla { +// Retrieve content of /proc/sys/kernel/random/boot_id as a string. +// Empty on non-Linux platforms. +StatusOr GetBootIdString(); + +// Performs a distributed exchange of topologies using a KV store. Each process +// provides its local topology, and the local topologies are exchanged to +// form a global topology. +Status ExchangeTopologies(std::string_view platform, int node_id, int num_nodes, + absl::Duration get_local_topology_timeout, + absl::Duration get_global_topology_timeout, + const PjRtClient::KeyValueGetCallback& kv_get, + const PjRtClient::KeyValuePutCallback& kv_put, + const LocalTopologyProto& local_topology, + GlobalTopologyProto* global_topology); + +// Functions below this point are public only for testing. + // Given a LocalTopologyProto object from each node, builds a -// GlobalTopologyProto that describes all nodes. +// GlobalTopologyProto that describes all nodes. 
Steals the contents of the +// LocalTopologyProtos. GlobalTopologyProto BuildGlobalTopology( absl::Span local_topologies); diff --git a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc index 97192906ad610a..fd2b1a87709f50 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc +++ b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc @@ -15,11 +15,20 @@ limitations under the License. #include "xla/pjrt/distributed/topology_util.h" +#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" #include "xla/pjrt/distributed/protocol.pb.h" +#include "xla/status.h" +#include "xla/statusor.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" #include "tsl/platform/test.h" +#include "tsl/platform/threadpool.h" namespace xla { namespace { @@ -42,5 +51,58 @@ TEST(TopologyTest, BuildGlobalTopology) { EXPECT_EQ(global.nodes()[1].devices_size(), 2); } +TEST(TopologyTest, ExchangeTopology) { + int num_nodes = 2; + std::vector locals(num_nodes); + DeviceProto* d0 = locals[0].add_devices(); + d0->set_local_device_ordinal(0); + DeviceProto* d1 = locals[0].add_devices(); + d1->set_local_device_ordinal(0); + DeviceProto* d2 = locals[1].add_devices(); + d2->set_local_device_ordinal(0); + DeviceProto* d3 = locals[1].add_devices(); + d3->set_local_device_ordinal(1); + + absl::Mutex mu; + absl::flat_hash_map kv; + + auto kv_get = [&](const std::string& key, + absl::Duration timeout) -> xla::StatusOr { + absl::MutexLock lock(&mu); + auto ready = [&]() { return kv.contains(key); }; + if (mu.AwaitWithTimeout(absl::Condition(&ready), timeout)) { + return kv[key]; + } + return absl::NotFoundError("key not found"); + }; + + auto kv_put = [&](const std::string& key, + const std::string& value) -> xla::Status { + absl::MutexLock lock(&mu); + kv[key] = value; + return absl::OkStatus(); + }; + + std::vector globals(num_nodes); + { + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "TestPool", + num_nodes); + for (int i = 0; i < num_nodes; i++) { + thread_pool.Schedule([&, i] { + TF_ASSERT_OK(ExchangeTopologies( + /*platform=*/"cuda", /*node_id=*/i, num_nodes, + /*get_local_topology_timeout=*/ + absl::Seconds(10), /*get_global_topology_timeout=*/ + absl::Seconds(10), kv_get, kv_put, locals[i], &globals[i])); + }); + } + } + for (const GlobalTopologyProto& global : globals) { + EXPECT_EQ(global.nodes_size(), 2); + EXPECT_EQ(global.nodes()[0].devices_size(), 2); + EXPECT_EQ(global.nodes()[1].devices_size(), 2); + } +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 132671a59ad96d..87ec53368e8f82 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -15,12 +15,12 @@ limitations under the License. 
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" -#include #include #include #include #include #include +#include #include #include @@ -585,21 +585,6 @@ StreamExecutorGpuClient::LoadSerialized(absl::string_view serialized, return absl::InternalError("LoadSerialized only works with cuda or rocm."); } -std::vector> BuildLocalDevices( - std::map> local_device_states, - int node_id) { - std::vector> devices; - for (auto& ordinal_and_device : local_device_states) { - const se::DeviceDescription& description = - ordinal_and_device.second->executor()->GetDeviceDescription(); - auto device = std::make_unique( - ordinal_and_device.first, std::move(ordinal_and_device.second), - description.name(), description.device_vendor(), node_id); - devices.push_back(std::move(device)); - } - return devices; -} - StatusOr> StreamExecutorGpuClient::Load( std::unique_ptr executable) { auto se_executable = absl::WrapUnique( @@ -762,90 +747,14 @@ GetStreamExecutorGpuDeviceAllocator( return std::move(allocator); } -// Exists on Linux systems. Unique per OS kernel restart. -static constexpr char kBootIdPath[] = "/proc/sys/kernel/random/boot_id"; - -// Retrieve content of /proc/sys/kernel/random/boot_id as a string. -// Note that procfs file may have file size 0 which throws off generic file -// readers such as tsl::ReadFileToString. -StatusOr GetBootIdString() { - std::string boot_id_str; -#ifdef __linux__ - std::ifstream file(kBootIdPath); - if (!file) { - return NotFound("%s not found.", kBootIdPath); - } - std::string line; - while (std::getline(file, line)) { - absl::StripAsciiWhitespace(&line); - absl::StrAppend(&boot_id_str, line); - } -#endif - return boot_id_str; -} - -static std::string GetLocalTopologyKey(int node_id) { - return absl::StrCat("local_topology:", node_id); -} - -static std::string GetGlobalTopologyKey() { return "global_topology"; } - -static StatusOr> GetAllLocalTopologies( - int num_nodes, const PjRtClient::KeyValueGetCallback& kv_get, - absl::Duration timeout) { - std::vector> local_topology_strs(num_nodes); - - // TODO(ezhulenev): Should a thread pool become a function argument? 
- tsl::thread::ThreadPool thread_pool( - tsl::Env::Default(), "GetAllLocalTopologies", DefaultThreadPoolSize()); - - absl::BlockingCounter blocking_counter(num_nodes); - absl::Mutex mu; - for (int i = 0; i < num_nodes; i++) { - thread_pool.Schedule([&, i] { - StatusOr local_topology_str = - kv_get(GetLocalTopologyKey(i), timeout); - { - absl::MutexLock lock(&mu); - local_topology_strs[i] = local_topology_str; - } - blocking_counter.DecrementCount(); - }); - } - blocking_counter.Wait(); - - std::vector error_messages; - std::vector local_topologies; - int max_num_failed_message = 10; - int failed_count = 0; - for (const StatusOr& str : local_topology_strs) { - if (str.ok()) { - LocalTopologyProto local; - local.ParseFromString(*str); - local_topologies.push_back(local); - } else { - error_messages.push_back( - absl::StrCat("Error ", ++failed_count, ": ", str.status().message())); - if (failed_count > max_num_failed_message) { - break; - } - } - } - if (error_messages.empty()) { - return local_topologies; - } - return absl::InternalError( - absl::StrCat("Getting local topologies failed: ", - absl::StrJoin(error_messages, "\n\n"))); -} - Status BuildDistributedDevices( + std::string_view platform_name, std::map> local_device_states, int node_id, int num_nodes, std::vector>* devices, gpu::GpuExecutableRunOptions* gpu_executable_run_options, - PjRtClient::KeyValueGetCallback kv_get, - PjRtClient::KeyValuePutCallback kv_put, + const PjRtClient::KeyValueGetCallback& kv_get, + const PjRtClient::KeyValuePutCallback& kv_put, absl::Duration get_local_topology_timeout = absl::Minutes(2), absl::Duration get_global_topology_timeout = absl::Minutes(5)) { LocalTopologyProto local_topology; @@ -869,28 +778,12 @@ Status BuildDistributedDevices( device_proto->set_name(desc->name()); device_proto->set_vendor(desc->device_vendor()); } - VLOG(3) << "GPU Local Topology:\n" << local_topology.DebugString(); - TF_RETURN_IF_ERROR( - kv_put(GetLocalTopologyKey(node_id), local_topology.SerializeAsString())); GlobalTopologyProto global_topology; - // The lead node gets all local topologies, builds the global topology and - // puts it to the key-value store. 
- if (node_id == 0) { - TF_ASSIGN_OR_RETURN( - std::vector local_topologies, - GetAllLocalTopologies(num_nodes, kv_get, get_local_topology_timeout)); - global_topology = - BuildGlobalTopology(absl::Span(local_topologies)); - TF_RETURN_IF_ERROR( - kv_put(GetGlobalTopologyKey(), global_topology.SerializeAsString())); - } else { - TF_ASSIGN_OR_RETURN( - std::string global_topology_str, - kv_get(GetGlobalTopologyKey(), get_global_topology_timeout)); - global_topology.ParseFromString(global_topology_str); - } - VLOG(3) << "GPU Global Topology:\n" << global_topology.DebugString(); + TF_RETURN_IF_ERROR(ExchangeTopologies( + platform_name, node_id, num_nodes, get_local_topology_timeout, + get_global_topology_timeout, kv_get, kv_put, local_topology, + &global_topology)); std::map gpu_device_ids; absl::flat_hash_map device_to_node; @@ -920,12 +813,14 @@ Status BuildDistributedDevices( gpu_executable_run_options->set_gpu_global_device_ids( std::move(gpu_device_ids)); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - auto nccl_id_store = - std::make_shared(node_id, device_to_node, kv_get, kv_put); - gpu_executable_run_options->set_nccl_unique_id_callback( - [nccl_id_store](const gpu::NcclCliqueKey& key) { - return nccl_id_store->GetNcclUniqueId(key); - }); + if (num_nodes > 1) { + auto nccl_id_store = + std::make_shared(node_id, device_to_node, kv_get, kv_put); + gpu_executable_run_options->set_nccl_unique_id_callback( + [nccl_id_store](const gpu::NcclCliqueKey& key) { + return nccl_id_store->GetNcclUniqueId(key); + }); + } #endif // GOOGLE_CUDA return OkStatus(); } @@ -981,6 +876,12 @@ StatusOr> GetStreamExecutorGpuClient( bool should_stage_host_to_device_transfers, PjRtClient::KeyValueGetCallback kv_get, PjRtClient::KeyValuePutCallback kv_put, bool enable_mock_nccl) { +#if TENSORFLOW_USE_ROCM + auto pjrt_platform_name = xla::RocmName(); +#else // TENSORFLOW_USE_ROCM + auto pjrt_platform_name = xla::CudaName(); +#endif // TENSORFLOW_USE_ROCM + TF_ASSIGN_OR_RETURN(LocalClient * xla_client, GetGpuXlaClient(platform_name, allowed_devices)); std::map> local_device_states; @@ -998,61 +899,50 @@ StatusOr> GetStreamExecutorGpuClient( if (enable_mock_nccl) { gpu_run_options->set_enable_mock_nccl_collectives(); } - if (num_nodes > 1) { - absl::flat_hash_map device_maps; - absl::Mutex mu; - if (enable_mock_nccl) { - kv_get = [&device_maps, &mu, &num_nodes]( - const std::string& k, - absl::Duration timeout) -> xla::StatusOr { - std::string result; - { - absl::MutexLock lock(&mu); - if (device_maps.contains(k)) { - result = device_maps[k]; - } else { - int device_id; - std::vector tokens = absl::StrSplit(k, ':'); - if (tokens.size() != 2 || - !absl::SimpleAtoi(tokens[1], &device_id)) { - device_id = num_nodes - 1; - } - // Return fake local topology with device_id info back. 
- xla::LocalTopologyProto local; - local.set_boot_id("fake_boot_id"); - local.set_node_id(device_id); - xla::DeviceProto* device = local.add_devices(); - device->set_global_device_id(device_id); - device->set_name("fake_device"); - device->set_vendor("fake_vendor"); - result = local.SerializeAsString(); + absl::flat_hash_map device_maps; + absl::Mutex mu; + if (enable_mock_nccl) { + kv_get = [&device_maps, &mu, &num_nodes]( + const std::string& k, + absl::Duration timeout) -> xla::StatusOr { + std::string result; + { + absl::MutexLock lock(&mu); + if (device_maps.contains(k)) { + result = device_maps[k]; + } else { + int device_id; + std::vector tokens = absl::StrSplit(k, ':'); + if (tokens.size() != 2 || !absl::SimpleAtoi(tokens[1], &device_id)) { + device_id = num_nodes - 1; } + // Return fake local topology with device_id info back. + xla::LocalTopologyProto local; + local.set_boot_id("fake_boot_id"); + local.set_node_id(device_id); + xla::DeviceProto* device = local.add_devices(); + device->set_global_device_id(device_id); + device->set_name("fake_device"); + device->set_vendor("fake_vendor"); + result = local.SerializeAsString(); } - return result; - }; - kv_put = [&device_maps, &mu](const std::string& k, - const std::string& v) -> xla::Status { - { - absl::MutexLock lock(&mu); - device_maps[k] = v; - } - return xla::OkStatus(); - }; - } - TF_RET_CHECK(kv_get != nullptr); - TF_RET_CHECK(kv_put != nullptr); - TF_RETURN_IF_ERROR(BuildDistributedDevices( - std::move(local_device_states), node_id, num_nodes, &devices, - gpu_run_options.get(), kv_get, kv_put)); - } else { - devices = BuildLocalDevices(std::move(local_device_states), node_id); + } + return result; + }; + kv_put = [&device_maps, &mu](const std::string& k, + const std::string& v) -> xla::Status { + { + absl::MutexLock lock(&mu); + device_maps[k] = v; + } + return xla::OkStatus(); + }; } - -#if TENSORFLOW_USE_ROCM - auto pjrt_platform_name = xla::RocmName(); -#else // TENSORFLOW_USE_ROCM - auto pjrt_platform_name = xla::CudaName(); -#endif // TENSORFLOW_USE_ROCM + TF_RET_CHECK(num_nodes == 1 || kv_get != nullptr); + TF_RET_CHECK(num_nodes == 1 || kv_put != nullptr); + TF_RETURN_IF_ERROR(BuildDistributedDevices( + pjrt_platform_name, std::move(local_device_states), node_id, num_nodes, + &devices, gpu_run_options.get(), kv_get, kv_put)); return std::unique_ptr(std::make_unique( pjrt_platform_name, xla_client, std::move(devices), @@ -1070,4 +960,19 @@ absl::StatusOr StreamExecutorGpuTopologyDescription::Serialize() return result; } +std::vector> BuildLocalDevices( + std::map> local_device_states, + int node_id) { + std::vector> devices; + for (auto& ordinal_and_device : local_device_states) { + const se::DeviceDescription& description = + ordinal_and_device.second->executor()->GetDeviceDescription(); + auto device = std::make_unique( + ordinal_and_device.first, std::move(ordinal_and_device.second), + description.name(), description.device_vendor(), node_id); + devices.push_back(std::move(device)); + } + return devices; +} + } // namespace xla From 1c5e88a84c70f8c3b29879634a5ea9247cfa8ac0 Mon Sep 17 00:00:00 2001 From: Ziyin Huang Date: Mon, 13 Nov 2023 12:25:13 -0800 Subject: [PATCH 032/391] Make the logging the max_ids/unique_ids default in the sparse core xla ops. This will make debugging easier. 
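For context (illustrative snippet, not part of this patch; the variables are assumed to be in scope): VLOG(3) output is only emitted when verbosity is raised, e.g. with --v=3 or a matching --vmodule, while LOG(INFO) is written to the default INFO log, so the per-table limits become visible without any extra flags.

    // Silent in a default run; requires --v=3 or --vmodule=sparse_core_xla_ops=3.
    VLOG(3) << "table_name = '" << table_name_
            << "', max_ids = " << max_ids_per_partition;
    // Emitted by default, so the per-table limits appear in standard job logs.
    LOG(INFO) << "table_name = '" << table_name_
              << "', max_ids = " << max_ids_per_partition;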
PiperOrigin-RevId: 582048387 --- .../core/tpu/kernels/sparse_core_xla_ops.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc index 39b459709849b1..3d5e8642b2e58e 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc @@ -264,10 +264,10 @@ class XlaSparseDenseMatmulWithCsrInputOp : public XlaOpKernel { ->Set(max_ids_per_partition); max_unique_ids_per_partition_gauge_->GetCell(device_name_, table_name_) ->Set(max_unique_ids_per_partition); - VLOG(3) << "XlaSparseDenseMatmulWithCsrInputOp: " - << "table_name = '" << table_name_ - << "', max_ids = " << max_ids_per_partition - << ", max_uniques = " << max_unique_ids_per_partition; + LOG(INFO) << "Lowering XlaSparseDenseMatmulWithCsrInputOp to HLO: " + << "table_name = '" << table_name_ + << "', max_ids = " << max_ids_per_partition + << ", max_uniques = " << max_unique_ids_per_partition; OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape( "num_minibatches_per_physical_sparse_core")), @@ -427,10 +427,10 @@ class XlaSparseDenseMatmulGradWithCsrInputBase : public XlaOpKernel { ctx, GetMaxIdsAndUniquesExternal( "", table_name_, per_sparse_core_batch_size, feature_width, &max_ids_per_partition, &max_unique_ids_per_partition)); - VLOG(3) << "XlaSparseDenseMatmulWithCsrInputOp: " - << "table_name = '" << table_name_ - << "', max_ids = " << max_ids_per_partition - << ", max_uniques = " << max_unique_ids_per_partition; + LOG(INFO) << "Lowering XlaSparseDenseMatmulGradWithCsrInputOp to HLO: " + << "table_name = '" << table_name_ + << "', max_ids = " << max_ids_per_partition + << ", max_uniques = " << max_unique_ids_per_partition; xla::XlaComputation optimizer = build_optimizer_computation(feature_width); From 3ee57d097813c5f92397da14f9401ceed623d25d Mon Sep 17 00:00:00 2001 From: Chase Riley Roberts Date: Mon, 13 Nov 2023 12:34:59 -0800 Subject: [PATCH 033/391] PR #6719: Added py::bytes wrappings to backend config strings in custom partitioner Imported from GitHub PR https://github.com/openxla/xla/pull/6719 Fixes #6718 There was an issue where the registered custom partitioning would pass a normal python string for the backend configuration string. However, python will try to convert this string to utf-8 first. If your configuration string contains non valid utf-8 bytes, this conversion fails with a somewhat cryptic error. This PR fixes that issue by wrapping all of the strings in `py::bytes` instead. 
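To illustrate the failure mode being fixed, a minimal pybind11 sketch; the function and callback names are hypothetical, and only the py::bytes wrapping versus the implicit std::string-to-str conversion is the point:

    #include <string>
    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    void CallPartitioner(py::function partition, const std::string& config) {
      // Passing `config` directly goes through the std::string -> py::str
      // caster, which decodes UTF-8 and raises on arbitrary binary bytes.
      // Wrapping it in py::bytes hands the buffer to Python unmodified.
      partition(py::bytes(config));
    }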
Copybara import of the project: -- 2709b98beb1eb61eb8030973098ae56ae6c19992 by Thenerdstation : Added py::bytes wrappings Merging this change closes #6719 PiperOrigin-RevId: 582051076 --- third_party/xla/xla/python/custom_call_sharding.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/python/custom_call_sharding.cc b/third_party/xla/xla/python/custom_call_sharding.cc index 6b0cdd2ba1f964..318164a87375a6 100644 --- a/third_party/xla/xla/python/custom_call_sharding.cc +++ b/third_party/xla/xla/python/custom_call_sharding.cc @@ -122,7 +122,7 @@ class PyCustomCallPartitioner : public CustomCallPartitioner { auto py_result = partition_(GetArgShapes(instruction), GetArgShardings(instruction), instruction->shape(), instruction->sharding(), - instruction->raw_backend_config_string()); + py::bytes(instruction->raw_backend_config_string())); const XlaComputation* computation = nullptr; // Kept alive by py_result. std::vector arg_shardings; @@ -190,9 +190,9 @@ class PyCustomCallPartitioner : public CustomCallPartitioner { // The user is used when the custom call returns a Tuple and // the user is a get-tuple-element. In this case we must update only // part of the sharding spec. - auto result = py::cast( - prop_user_sharding_(sharding, instruction->shape(), - instruction->raw_backend_config_string())); + auto result = py::cast(prop_user_sharding_( + sharding, instruction->shape(), + py::bytes(instruction->raw_backend_config_string()))); return result; } std::optional InferShardingFromOperands( @@ -202,7 +202,7 @@ class PyCustomCallPartitioner : public CustomCallPartitioner { py::gil_scoped_acquire gil; auto py_result = infer_sharding_from_operands_( arg_shapes, arg_shardings, instruction->shape(), - instruction->raw_backend_config_string()); + py::bytes(instruction->raw_backend_config_string())); if (py_result.is_none()) { return std::nullopt; } From d1a8716e7e4f9159e6e17df24fd64510f1c3fa09 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 13 Nov 2023 13:00:46 -0800 Subject: [PATCH 034/391] [PJRT] Remove unused EnumerateDevices API. The users of EnumerateDevice() were refactored to use the KV-store instead as part of the PJRT plugin work. There are no remaining users. Also rename "enumerate_devices_timeout" to "cluster_register_timeout" which matches the name given in the coordination service code. Bump the default timeout to one hour, which also matches the coordination service default. 
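For callers migrating off EnumerateDevices(), the replacement pattern (condensed from the updated client_server_test in this change; `client`, `node_id`, `num_nodes`, and `local_topology` are assumed to be in scope) is to adapt the client's key-value API and call ExchangeTopologies():

    auto kv_get = [&](const std::string& k,
                      absl::Duration timeout) -> xla::StatusOr<std::string> {
      return client->BlockingKeyValueGet(k, timeout);
    };
    auto kv_put = [&](const std::string& k,
                      const std::string& v) -> xla::Status {
      return client->KeyValueSet(k, v);
    };
    GlobalTopologyProto global_topology;
    TF_RETURN_IF_ERROR(ExchangeTopologies(
        "cuda", node_id, num_nodes,
        /*get_local_topology_timeout=*/absl::Minutes(1),
        /*get_global_topology_timeout=*/absl::Minutes(1),
        kv_get, kv_put, local_topology, &global_topology));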
PiperOrigin-RevId: 582058387 --- third_party/xla/xla/pjrt/distributed/BUILD | 6 +-- .../xla/xla/pjrt/distributed/client.cc | 22 --------- third_party/xla/xla/pjrt/distributed/client.h | 11 +---- .../pjrt/distributed/client_server_test.cc | 46 ++++++++++++++++-- .../xla/xla/pjrt/distributed/service.cc | 47 ++----------------- .../xla/xla/pjrt/distributed/service.h | 6 +-- third_party/xla/xla/python/xla.cc | 10 ++-- .../xla/xla/python/xla_extension/__init__.pyi | 2 +- 8 files changed, 57 insertions(+), 93 deletions(-) diff --git a/third_party/xla/xla/pjrt/distributed/BUILD b/third_party/xla/xla/pjrt/distributed/BUILD index 03be34e444d48b..5fba7fba3099ee 100644 --- a/third_party/xla/xla/pjrt/distributed/BUILD +++ b/third_party/xla/xla/pjrt/distributed/BUILD @@ -12,10 +12,7 @@ package( tf_proto_library( name = "protocol_proto", srcs = ["protocol.proto"], - has_services = 1, cc_api_version = 2, - create_grpc_library = True, - use_grpc_namespace = True, visibility = ["//visibility:public"], ) @@ -25,7 +22,6 @@ cc_library( hdrs = ["service.h"], visibility = ["//visibility:public"], deps = [ - ":protocol_cc_grpc_proto", ":topology_util", ":util", "//xla:status", @@ -77,7 +73,6 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - ":protocol_cc_grpc_proto", ":util", "//xla:statusor", "//xla:types", @@ -152,6 +147,7 @@ xla_cc_test( ":distributed", ":protocol_proto_cc", ":service", + ":topology_util", "//xla:protobuf_util", "//xla:status_macros", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/pjrt/distributed/client.cc b/third_party/xla/xla/pjrt/distributed/client.cc index ce63141e05577a..b3027eaf8b6169 100644 --- a/third_party/xla/xla/pjrt/distributed/client.cc +++ b/third_party/xla/xla/pjrt/distributed/client.cc @@ -45,8 +45,6 @@ class DistributedRuntimeCoordinationServiceClient xla::Status Connect() override; xla::Status Shutdown() override; - xla::Status EnumerateDevices(const LocalTopologyProto& local_topology, - GlobalTopologyProto* global_topology) override; xla::StatusOr BlockingKeyValueGet( std::string key, absl::Duration timeout) override; xla::StatusOr>> @@ -133,26 +131,6 @@ xla::Status DistributedRuntimeCoordinationServiceClient::Shutdown() { return s; } -xla::Status DistributedRuntimeCoordinationServiceClient::EnumerateDevices( - const LocalTopologyProto& local_topology, - GlobalTopologyProto* global_topology) { - LocalTopologyProto local_device = local_topology; - local_device.set_node_id(task_id_); - tensorflow::DeviceInfo devices; - devices.mutable_device()->Add()->PackFrom(local_device); - // Client sends LocalTopologyProto. - Status s = coord_agent_->WaitForAllTasks(devices); - if (!s.ok()) return s; - // Server responds with GlobalTopologyProto (refer to service.cc for details). - tensorflow::DeviceInfo global_devices = coord_agent_->GetClusterDeviceInfo(); - if (global_devices.device_size() != 1) { - return tsl::errors::Internal( - "Unexpected cluster device response from EnumerateDevices()."); - } - global_devices.device().Get(0).UnpackTo(global_topology); - return OkStatus(); -} - xla::StatusOr DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet( std::string key, absl::Duration timeout) { diff --git a/third_party/xla/xla/pjrt/distributed/client.h b/third_party/xla/xla/pjrt/distributed/client.h index 18c15f107d2694..6ee9ce976df557 100644 --- a/third_party/xla/xla/pjrt/distributed/client.h +++ b/third_party/xla/xla/pjrt/distributed/client.h @@ -24,7 +24,6 @@ limitations under the License. 
#include "absl/time/time.h" #include "grpcpp/channel.h" -#include "xla/pjrt/distributed/protocol.grpc.pb.h" #include "xla/statusor.h" #include "xla/types.h" #include "tsl/platform/env.h" @@ -101,8 +100,8 @@ class DistributedRuntimeClient { // Connects to the master, and blocks until all clients have successfully // connected. - // Not thread-safe, i.e., calls to Connect()/Shutdown()/EnumerateDevices() - // must be serialized by some other means. + // Not thread-safe, i.e., calls to Connect()/Shutdown() must be serialized by + // some other means. virtual xla::Status Connect() = 0; // Reports to the master that the client is ready to shutdown, and blocks @@ -110,12 +109,6 @@ class DistributedRuntimeClient { // Not thread-safe. virtual xla::Status Shutdown() = 0; - // Blocking enumeration of global devices. Used by the GPU platform. - // Not thread-safe. - virtual xla::Status EnumerateDevices( - const LocalTopologyProto& local_topology, - GlobalTopologyProto* global_topology) = 0; - // The following APIs are thread-safe. // Key-value store API. diff --git a/third_party/xla/xla/pjrt/distributed/client_server_test.cc b/third_party/xla/xla/pjrt/distributed/client_server_test.cc index 12df3a01475a42..efd3f7cda98e59 100644 --- a/third_party/xla/xla/pjrt/distributed/client_server_test.cc +++ b/third_party/xla/xla/pjrt/distributed/client_server_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/protocol.pb.h" #include "xla/pjrt/distributed/service.h" +#include "xla/pjrt/distributed/topology_util.h" #include "xla/protobuf_util.h" #include "xla/status_macros.h" #include "tsl/lib/core/status_test_util.h" @@ -215,7 +216,20 @@ TEST_F(ClientServerTest, ConnectAndEnumerateDevices) { n.WaitForNotification(); // Sleep a short while for the other thread to send their device info first. absl::SleepFor(absl::Seconds(1)); - TF_RETURN_IF_ERROR(client->EnumerateDevices(locals[0], &topology)); + + auto kv_get = [&](const std::string& k, + absl::Duration timeout) -> xla::StatusOr { + return client->BlockingKeyValueGet(k, timeout); + }; + auto kv_put = [&](const std::string& k, + const std::string& v) -> xla::Status { + return client->KeyValueSet(k, v); + }; + TF_RETURN_IF_ERROR( + ExchangeTopologies("cuda", /*node_id=*/0, /*num_nodes=*/2, + /*get_local_topology_timeout=*/absl::Minutes(1), + /*get_global_topology_timeout=*/absl::Minutes(1), + kv_get, kv_put, locals[0], &topology)); TF_RET_CHECK( xla::protobuf_util::ProtobufEquals(topology, expected_topology)) << topology.DebugString(); @@ -236,7 +250,19 @@ TEST_F(ClientServerTest, ConnectAndEnumerateDevices) { // We cannot send the notification after the call since there is a barrier // within the call that would cause a deadlock. 
n.Notify(); - TF_RETURN_IF_ERROR(client->EnumerateDevices(locals[1], &topology)); + auto kv_get = [&](const std::string& k, + absl::Duration timeout) -> xla::StatusOr { + return client->BlockingKeyValueGet(k, timeout); + }; + auto kv_put = [&](const std::string& k, + const std::string& v) -> xla::Status { + return client->KeyValueSet(k, v); + }; + TF_RETURN_IF_ERROR( + ExchangeTopologies("cuda", /*node_id=*/1, /*num_nodes=*/2, + /*get_local_topology_timeout=*/absl::Minutes(1), + /*get_global_topology_timeout=*/absl::Minutes(1), + kv_get, kv_put, locals[1], &topology)); TF_RET_CHECK( xla::protobuf_util::ProtobufEquals(topology, expected_topology)) << topology.DebugString(); @@ -290,7 +316,19 @@ TEST_F(ClientServerTest, EnumerateElevenDevices) { auto client = GetClient(node_id); GlobalTopologyProto topology; TF_RETURN_IF_ERROR(client->Connect()); - TF_RETURN_IF_ERROR(client->EnumerateDevices(locals[node_id], &topology)); + auto kv_get = [&](const std::string& k, + absl::Duration timeout) -> xla::StatusOr { + return client->BlockingKeyValueGet(k, timeout); + }; + auto kv_put = [&](const std::string& k, + const std::string& v) -> xla::Status { + return client->KeyValueSet(k, v); + }; + TF_RETURN_IF_ERROR( + ExchangeTopologies("cuda", /*node_id=*/node_id, num_nodes, + /*get_local_topology_timeout=*/absl::Minutes(1), + /*get_global_topology_timeout=*/absl::Minutes(1), + kv_get, kv_put, locals[node_id], &topology)); TF_RET_CHECK( xla::protobuf_util::ProtobufEquals(topology, expected_topology)) << topology.DebugString(); @@ -515,7 +553,7 @@ TEST_F(ClientServerTest, ConnectEventuallyTimesOutIfAClientDoesNotShowUp) { int num_nodes = 3; absl::Duration timeout = absl::Milliseconds(100); CoordinationServiceImpl::Options service_options; - service_options.enumerate_devices_timeout = timeout; + service_options.cluster_register_timeout = timeout; service_options.shutdown_timeout = timeout; StartService(num_nodes, service_options); diff --git a/third_party/xla/xla/pjrt/distributed/service.cc b/third_party/xla/xla/pjrt/distributed/service.cc index 947c7267bd3f30..37d4c651544b64 100644 --- a/third_party/xla/xla/pjrt/distributed/service.cc +++ b/third_party/xla/xla/pjrt/distributed/service.cc @@ -17,9 +17,8 @@ limitations under the License. #include #include +#include -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "grpcpp/server_builder.h" #include "xla/util.h" @@ -40,7 +39,7 @@ std::unique_ptr EnableCoordinationService( config.set_service_type("standalone"); config.set_service_leader(absl::StrCat("/job:", job_name, "/task:0")); config.set_cluster_register_timeout_in_ms( - absl::ToInt64Milliseconds(options.enumerate_devices_timeout)); + absl::ToInt64Milliseconds(options.cluster_register_timeout)); config.set_heartbeat_timeout_in_ms(absl::ToInt64Milliseconds( options.heartbeat_interval * options.max_missing_heartbeats)); config.set_shutdown_barrier_timeout_in_ms( @@ -51,46 +50,6 @@ std::unique_ptr EnableCoordinationService( job->set_num_tasks(options.num_nodes); auto service = tsl::CoordinationServiceInterface::EnableCoordinationService( options.env, config, /*cache=*/nullptr); - // Convert list of local devices to global device message as EnumerateDevies() - // response. - service->SetDeviceAggregationFunction( - [](const tensorflow::DeviceInfo& raw_global_devices) { - xla::GlobalTopologyProto global_topology; - int global_device_id = 0; - // Assign local devices of the same host to the same slice_index. 
- int next_slice_index = 0; - absl::flat_hash_map boot_id_to_slice_index; - // Unwrap result to local device proto. - for (const auto& device : raw_global_devices.device()) { - xla::LocalTopologyProto local_topology; - // Note that tensorflow::DeviceInfo.device is xla.LocalTopologyProto! - device.UnpackTo(&local_topology); - // Every new boot_id seen is treated as a new host/slice. - absl::string_view boot_id = local_topology.boot_id(); - auto [it, inserted] = - boot_id_to_slice_index.try_emplace(boot_id, next_slice_index); - if (inserted) { - ++next_slice_index; - } - // Set deterministic global ids. - for (xla::DeviceProto& device : *local_topology.mutable_devices()) { - device.set_global_device_id(global_device_id++); - device.set_slice_index(it->second); - } - *global_topology.mutable_nodes()->Add() = local_topology; - } - if (VLOG_IS_ON(10)) { - for (auto it = boot_id_to_slice_index.begin(); - it != boot_id_to_slice_index.end(); ++it) { - LOG(INFO) << "BuildGlobalTopology boot_id_to_slice_index " - << it->first << "->" << it->second; - } - } - // Wrap result back in DeviceInfo proto. - tensorflow::DeviceInfo global_devices; - global_devices.mutable_device()->Add()->PackFrom(global_topology); - return global_devices; - }); return service; } } // namespace @@ -110,7 +69,7 @@ CoordinationServiceImpl::CoordinationServiceImpl( auto* grpc_coord_service = static_cast(coord_rpc_service_.get()); grpc_coord_service->SetCoordinationServiceInstance(coord_service_.get()); - LOG(INFO) << "Experimental coordination service is enabled."; + LOG(INFO) << "Coordination service is enabled."; } CoordinationServiceImpl::~CoordinationServiceImpl() { diff --git a/third_party/xla/xla/pjrt/distributed/service.h b/third_party/xla/xla/pjrt/distributed/service.h index ef79e12e363759..d2e308d0f8a71c 100644 --- a/third_party/xla/xla/pjrt/distributed/service.h +++ b/third_party/xla/xla/pjrt/distributed/service.h @@ -25,7 +25,7 @@ limitations under the License. #include "absl/time/time.h" #include "grpcpp/grpcpp.h" #include "grpcpp/security/server_credentials.h" -#include "xla/pjrt/distributed/protocol.grpc.pb.h" +#include "grpcpp/server_builder.h" #include "xla/statusor.h" #include "xla/types.h" #include "tsl/distributed_runtime/coordination/coordination_service.h" @@ -53,9 +53,9 @@ class CoordinationServiceImpl { // coordinator concludes that a client has vanished. int max_missing_heartbeats = 10; - // How long should we wait for all clients to call EnumerateDevices() before + // How long should we wait for all clients to call Connect() before // giving up? - absl::Duration enumerate_devices_timeout = absl::Seconds(60); + absl::Duration cluster_register_timeout = absl::Minutes(60); // How long should we wait for all clients to call Shutdown() before giving // up and returning a failure? 
diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 06de87623a078c..a4db3deef34730 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -899,7 +899,7 @@ static void Init(py::module_& m) { [](std::string address, int num_nodes, std::optional heartbeat_interval, std::optional max_missing_heartbeats, - std::optional enumerate_devices_timeout, + std::optional cluster_register_timeout, std::optional shutdown_timeout) -> std::unique_ptr { CoordinationServiceImpl::Options options; @@ -910,9 +910,9 @@ static void Init(py::module_& m) { if (max_missing_heartbeats.has_value()) { options.max_missing_heartbeats = *max_missing_heartbeats; } - if (enumerate_devices_timeout.has_value()) { - options.enumerate_devices_timeout = - absl::Seconds(*enumerate_devices_timeout); + if (cluster_register_timeout.has_value()) { + options.cluster_register_timeout = + absl::Seconds(*cluster_register_timeout); } if (shutdown_timeout.has_value()) { options.shutdown_timeout = absl::Seconds(*shutdown_timeout); @@ -924,7 +924,7 @@ static void Init(py::module_& m) { py::arg("address"), py::arg("num_nodes"), py::kw_only(), py::arg("heartbeat_interval") = std::nullopt, py::arg("max_missing_heartbeats") = std::nullopt, - py::arg("enumerate_devices_timeout") = std::nullopt, + py::arg("cluster_register_timeout") = std::nullopt, py::arg("shutdown_timeout") = std::nullopt); m.def( diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi index 63df53d5ba3710..19752e1c593903 100644 --- a/third_party/xla/xla/python/xla_extension/__init__.pyi +++ b/third_party/xla/xla/python/xla_extension/__init__.pyi @@ -679,7 +679,7 @@ def get_distributed_runtime_service( num_nodes: int, heartbeat_interval: Optional[int] = ..., max_missing_heartbeats: Optional[int] = ..., - enumerate_devices_timeout: Optional[int] = ..., + cluster_register_timeout: Optional[int] = ..., shutdown_timeout: Optional[int] = ...) -> DistributedRuntimeService: ... def get_distributed_runtime_client( address: str, From f25c22bc494573699c05ee7c38b6154da8c8fc62 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Nov 2023 13:21:36 -0800 Subject: [PATCH 035/391] Support per-channel quantization for data movement ops This is needed for per-channel quantized weight. Skip op/result type check in the pattern because some op may change for per-channel quantized type. E.g. Quantization axis may change after broadcast_in_dims. PiperOrigin-RevId: 582064121 --- .../bridge/convert_mhlo_quant_to_int.cc | 23 +----- .../bridge/convert-mhlo-quant-to-int.mlir | 72 ++++++++++++------- 2 files changed, 46 insertions(+), 49 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc index 2ff1ba9200261d..a5a425d540dd30 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc @@ -1160,6 +1160,7 @@ class ConvertUniformQuantizedConvolutionOp // This pattern lowers a generic MHLO op for uq->int. // This pattern essentially just performs type change, with no algorithm change. +// TODO: b/310685906 - Add operand/result type validations. 
class ConvertGenericOp : public ConversionPattern { public: explicit ConvertGenericOp(MLIRContext *ctx) @@ -1174,28 +1175,6 @@ class ConvertGenericOp : public ConversionPattern { return failure(); } - // Check that all operands and result uq types are the same. - llvm::SmallVector uq_types; - for (auto result_type : op->getResultTypes()) { - auto type = - getElementTypeOrSelf(result_type).dyn_cast(); - if (type) { - uq_types.push_back(type); - } - } - for (auto operand : op->getOperands()) { - auto type = getElementTypeOrSelf(operand.getType()) - .dyn_cast(); - if (type) { - uq_types.push_back(type); - } - } - for (auto type : uq_types) { - if (type != uq_types.front()) { - return failure(); - } - } - // Determine new result type: use storage type for uq types; use original // type otherwise. llvm::SmallVector new_result_types; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir index 65c8497aa9a41a..e022bcd81c447a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir @@ -1660,6 +1660,21 @@ func.func @broadcast( // ----- +// CHECK-LABEL: func @broadcast_per_channel +func.func @broadcast_per_channel( + %arg0: tensor<2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> { + // CHECK: "mhlo.broadcast_in_dim" + // CHECK-SAME: broadcast_dimensions = dense<3> : tensor<1xi64> + // CHECK-SAME: (tensor<2xi32>) -> tensor<128x26x26x2xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<3> : tensor<1xi64>}: ( + tensor<2x!quant.uniform> + ) -> tensor<128x26x26x2x!quant.uniform> + return %0 : tensor<128x26x26x2x!quant.uniform> +} + +// ----- + // CHECK-LABEL: func @max func.func @max( %arg0: tensor<1x2x!quant.uniform> @@ -1675,6 +1690,21 @@ func.func @max( // ----- +// CHECK-LABEL: func @max_per_channel +func.func @max_per_channel( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: mhlo.maximum + // CHECK-SAME: tensor<1x2xi8> + %0 = "mhlo.maximum"(%arg0, %arg0) : ( + tensor<1x2x!quant.uniform>, + tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> + return %0 : tensor<1x2x!quant.uniform> +} + +// ----- + // CHECK-LABEL: func @min func.func @min( %arg0: tensor<1x2x!quant.uniform> @@ -1690,6 +1720,21 @@ func.func @min( // ----- +// CHECK-LABEL: func @min_per_channel +func.func @min_per_channel( + %arg0: tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> { + // CHECK: mhlo.minimum + // CHECK-SAME: tensor<1x2xi8> + %0 = "mhlo.minimum"(%arg0, %arg0) : ( + tensor<1x2x!quant.uniform>, + tensor<1x2x!quant.uniform> + ) -> tensor<1x2x!quant.uniform> + return %0 : tensor<1x2x!quant.uniform> +} + +// ----- + // CHECK-LABEL: func @function(%arg0: tensor<1x2xi8>) -> tensor<1x2xi8> func.func @function( %arg0: tensor<1x2x!quant.uniform> @@ -1697,30 +1742,3 @@ func.func @function( // CHECK: return %arg0 : tensor<1x2xi8> return %arg0 : tensor<1x2x!quant.uniform> } - -// ----- - -func.func @min_mix_uq_type1( - %arg0: tensor<1x2x!quant.uniform>, - %arg1: tensor<1x2x!quant.uniform> - ) -> tensor<1x2x!quant.uniform> { - // expected-error@+1 {{failed to legalize operation 'mhlo.minimum' that was explicitly marked illegal}} - %0 = "mhlo.minimum"(%arg0, %arg1) : ( - tensor<1x2x!quant.uniform>, - tensor<1x2x!quant.uniform> - ) -> tensor<1x2x!quant.uniform> - 
return %0 : tensor<1x2x!quant.uniform> -} - -// ----- - -func.func @min_mix_uq_type2( - %arg0: tensor<1x2x!quant.uniform> - ) -> tensor<1x2x!quant.uniform> { - // expected-error@+1 {{failed to legalize operation 'mhlo.minimum' that was explicitly marked illegal}} - %0 = "mhlo.minimum"(%arg0, %arg0) : ( - tensor<1x2x!quant.uniform>, - tensor<1x2x!quant.uniform> - ) -> tensor<1x2x!quant.uniform> - return %0 : tensor<1x2x!quant.uniform> -} From 8643a51affbb474216b27b3e4832eca55f692c95 Mon Sep 17 00:00:00 2001 From: looi Date: Mon, 13 Nov 2023 21:30:59 +0000 Subject: [PATCH 036/391] ConvGeneric: fix local mem reads when pointers not supported Before this change, the weights_cache is populated but then global mem read is performed, wasting the weights_cache. --- tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc b/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc index 72e54dd21c94f5..664e024c77a6d0 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc @@ -886,7 +886,8 @@ std::string ConvGeneric::GenerateConv(const GpuInfo& gpu_info, std::to_string(s * 4 + ch + shared_offset); std::string w_val; if (conv_params.AreWeightsBuffer()) { - if (gpu_info.SupportsPointersInKernels()) { + if (need_local_mem || + gpu_info.SupportsPointersInKernels()) { w_val = "weights_cache[" + weight_id + "]"; } else { w_val = "args.weights.Read(filters_offset + " + @@ -926,7 +927,7 @@ std::string ConvGeneric::GenerateConv(const GpuInfo& gpu_info, std::string weight_id = std::to_string(s * 4 + i + shared_offset); if (conv_params.AreWeightsBuffer()) { - if (gpu_info.SupportsPointersInKernels()) { + if (need_local_mem || gpu_info.SupportsPointersInKernels()) { F[i] = "weights_cache[" + weight_id + "]"; } else { F[i] = @@ -1113,7 +1114,7 @@ std::string ConvGeneric::GenerateConv(const GpuInfo& gpu_info, c += " if (DST_S + " + sind + " >= args.dst_tensor.Slices()) return;\n"; c += " {\n"; if (conv_params.AreWeightsBuffer() && - gpu_info.SupportsPointersInKernels()) { + (need_local_mem || gpu_info.SupportsPointersInKernels())) { c += " FLT4 bias_val = TO_FLT4(weights_cache[" + sind + "]);\n"; } else { c += " FLT4 bias_val = args.biases.Read(DST_S + " + sind + ");\n"; From 15055e80dad9b245e283df61a0b49907cc5dc91c Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 13 Nov 2023 14:21:33 -0800 Subject: [PATCH 037/391] [PJRT] Allow the compiler to choose executable input and output layouts. This is done by adding a new "mhlo.layout_mode" attribute to arguments and results in the input MLIR module to the compiler. The layout can either be the default (i.e. compact shape), compiler chooses, or caller chooses by specifying a layout. Prior to this change, the caller could specify layouts via CompileOptions.argument_layouts and CompileOptions.executable_build_options.result_layout, and otherwise the default compact layout would be used. This change both introduces the third compiler-chooses option, and embeds this info in the IR instead of as a separate compile option. The old options still work while we update callers and take precedence if specified. 
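As a quick reference, a hand-written sketch of how the three "mhlo.layout_mode" attribute values map onto the LayoutMode helper introduced below; the wrapper function is hypothetical and exists only to host the status macros:

    #include "xla/pjrt/layout_mode.h"
    #include "xla/status.h"
    #include "tsl/platform/statusor.h"

    xla::Status LayoutModeExamples() {
      using xla::LayoutMode;
      // "default": the backend's compact layout is used.
      TF_ASSIGN_OR_RETURN(LayoutMode d, LayoutMode::FromString("default"));
      CHECK(d.mode == LayoutMode::Mode::kDefault);
      // "auto": the layout is left unset so the compiler may choose it.
      TF_ASSIGN_OR_RETURN(LayoutMode a, LayoutMode::FromString("auto"));
      CHECK(a.mode == LayoutMode::Mode::kAuto);
      // Any other value is parsed as an explicit per-device layout.
      TF_ASSIGN_OR_RETURN(LayoutMode u, LayoutMode::FromString("{0,1,2}"));
      CHECK(u.mode == LayoutMode::Mode::kUserSpecified);
      return xla::OkStatus();
    }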
PiperOrigin-RevId: 582082768 --- third_party/xla/xla/pjrt/BUILD | 18 ++ third_party/xla/xla/pjrt/layout_mode.cc | 71 +++++ third_party/xla/xla/pjrt/layout_mode.h | 67 +++++ third_party/xla/xla/pjrt/utils.cc | 249 ++++++++++++++++++ third_party/xla/xla/pjrt/utils.h | 31 +++ third_party/xla/xla/python/xla_client_test.py | 234 ++++++++++++++++ 6 files changed, 670 insertions(+) create mode 100644 third_party/xla/xla/pjrt/layout_mode.cc create mode 100644 third_party/xla/xla/pjrt/layout_mode.h diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 448d1ee463a716..29bd730a136c16 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -302,6 +302,7 @@ cc_library( ":pjrt_device_description", ":pjrt_executable", "//xla/client:xla_computation", + "//xla/service:hlo_parser", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -344,6 +345,7 @@ cc_library( hdrs = ["utils.h"], visibility = ["//visibility:public"], deps = [ + ":layout_mode", "//xla:shape_util", "//xla:status", "//xla:status_macros", @@ -359,6 +361,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", @@ -366,6 +370,20 @@ cc_library( ], ) +cc_library( + name = "layout_mode", + srcs = ["layout_mode.cc"], + hdrs = ["layout_mode.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla:shape_util", + "//xla:status", + "//xla:statusor", + "//xla/service:hlo_parser", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "metrics", srcs = ["metrics.cc"], diff --git a/third_party/xla/xla/pjrt/layout_mode.cc b/third_party/xla/xla/pjrt/layout_mode.cc new file mode 100644 index 00000000000000..458144d50865cf --- /dev/null +++ b/third_party/xla/xla/pjrt/layout_mode.cc @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/pjrt/layout_mode.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "xla/layout.h" +#include "xla/service/hlo_parser.h" +#include "xla/status.h" +#include "xla/statusor.h" + +namespace xla { + +LayoutMode::LayoutMode(Mode layout_mode, std::optional layout) + : mode(layout_mode), user_layout(std::move(layout)) { + if (mode == Mode::kUserSpecified) { + CHECK(user_layout) << "Must pass layout to LayoutMode constructor when " + "mode == kUserSpecified"; + } else { + CHECK(!user_layout) << "Only pass layout to LayoutMode constructor " + "if mode == kUserSpecified"; + } +} + +std::string LayoutMode::ToString() const { + switch (mode) { + case Mode::kDefault: + return "default"; + case Mode::kUserSpecified: + CHECK(user_layout); + return user_layout->ToString(); + case Mode::kAuto: + return "auto"; + } +} + +StatusOr LayoutMode::FromString(std::string s) { + if (s == "default") { + return LayoutMode(Mode::kDefault); + } + if (s == "auto") { + return LayoutMode(Mode::kAuto); + } + // LayoutMode is user-specified; parse Layout string + StatusOr layout = ParseLayout(s); + if (!layout.ok()) { + Status new_status(layout.status().code(), + absl::StrCat("Error parsing user-specified layout mode '", + s, "': ", layout.status().message())); + return new_status; + } + return LayoutMode(*layout); +} + +} // namespace xla diff --git a/third_party/xla/xla/pjrt/layout_mode.h b/third_party/xla/xla/pjrt/layout_mode.h new file mode 100644 index 00000000000000..792bc3de9f9a10 --- /dev/null +++ b/third_party/xla/xla/pjrt/layout_mode.h @@ -0,0 +1,67 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_LAYOUT_MODE_H_ +#define XLA_PJRT_LAYOUT_MODE_H_ + +#include + +#include "xla/layout.h" +#include "xla/shape.h" +#include "xla/statusor.h" + +namespace xla { + +// Helper struct for specifying how to choose the layout for a value in a +// program to be compiled (e.g. a computation argument). +// +// The source of truth for this info is the "mhlo.layout_mode" string attribute +// of input MLIR modules. This struct can help manage the attribute. The +// ToString and FromString methods can be used to convert between this struct +// and the "mhlo.layout_mode" string attr. +struct LayoutMode { + enum class Mode { + // Use the default compact layout. + kDefault = 0, + // Use `layout`. + kUserSpecified, + // Let compiler choose layout. + kAuto + }; + Mode mode = Mode::kDefault; + + // Only set iff layout_mode == kUserSpecified. This is the layout of the + // per-device data, i.e. if the computation is sharded, the caller must choose + // both the sharding and layout for this value such that they're compatible. 
+ std::optional user_layout; + + LayoutMode() = default; + explicit LayoutMode(Mode layout_mode, + std::optional layout = std::nullopt); + explicit LayoutMode(const Layout& layout) + : LayoutMode(Mode::kUserSpecified, layout) {} + explicit LayoutMode(const Shape& shape_with_layout) + : LayoutMode(Mode::kUserSpecified, shape_with_layout.layout()) {} + + // Produces a human-readable string representing this LayoutMode. Is also in + // the correct format for the "mhlo.layout_mode" attribute. + std::string ToString() const; + // Parses a string produced by LayoutMode::ToString() or Layout::ToString(). + static StatusOr FromString(std::string s); +}; + +} // namespace xla + +#endif // XLA_PJRT_LAYOUT_MODE_H_ diff --git a/third_party/xla/xla/pjrt/utils.cc b/third_party/xla/xla/pjrt/utils.cc index 0171f6286da7ab..3967c91ffe002b 100644 --- a/third_party/xla/xla/pjrt/utils.cc +++ b/third_party/xla/xla/pjrt/utils.cc @@ -29,6 +29,9 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "xla/client/executable_build_options.h" #include "xla/client/xla_computation.h" #include "xla/hlo/ir/hlo_computation.h" @@ -36,6 +39,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/layout_util.h" +#include "xla/pjrt/layout_mode.h" #include "xla/primitive_util.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" @@ -167,6 +171,251 @@ Status ParseDeviceAssignmentCompileOptions( return OkStatus(); } +// Helper method that takes an ArrayAttr of DictionaryAttrs for each arg or +// result of a function, and looks for "mhlo.layout_mode". `all_attrs` can be +// nullptr. `num_values` is the number of arguments or results. +static StatusOr> MlirAttrsToLayoutModes( + mlir::ArrayAttr all_attrs, size_t num_values) { + if (all_attrs == nullptr) { + return std::vector(num_values); + } + if (all_attrs.size() != num_values) { + return InvalidArgument( + "MlirAttrsToLayoutModes got unexpected number of attributes: %d, " + "expected: %d", + all_attrs.size(), num_values); + } + + std::vector result; + result.reserve(all_attrs.size()); + for (const mlir::Attribute& dict_attr : all_attrs) { + mlir::StringAttr attr = + dict_attr.cast().getAs( + "mhlo.layout_mode"); + if (attr != nullptr) { + TF_ASSIGN_OR_RETURN(LayoutMode mode, + LayoutMode::FromString(attr.getValue().str())); + result.emplace_back(std::move(mode)); + } else { + result.emplace_back(); + } + } + return result; +} + +// Helper function for getting default LayoutModes for tupled arguments or +// outputs. Returns nullopt if the arguments/outputs are not tupled. Raises an +// error if layout modes are requested on tupled values. +static StatusOr>> GetTupleLayoutModes( + mlir::ArrayRef types, mlir::ArrayAttr all_attrs) { + if (types.size() != 1 || !llvm::isa(types[0])) { + return std::nullopt; + } + if (all_attrs != nullptr) { + if (all_attrs.size() != 1) { + return InvalidArgument( + "GetTupleLayoutModes expected single tuple attr, got %d attrs", + all_attrs.size()); + } + mlir::StringAttr attr = + all_attrs.begin()->cast().getAs( + "mhlo.layout_mode"); + if (attr != nullptr) { + return Unimplemented("mhlo.layout_mode not supported with tupled values"); + } + } + // Use default layout for all outputs. 
+ return std::vector(types[0].cast().size()); +} + +StatusOr> GetArgLayoutModes(mlir::ModuleOp module) { + mlir::func::FuncOp main = module.lookupSymbol("main"); + if (main == nullptr) { + return InvalidArgument( + "GetArgLayoutModes passed module without main function"); + } + + // Special case: tupled arguments + TF_ASSIGN_OR_RETURN(std::optional> maybe_result, + GetTupleLayoutModes(main.getFunctionType().getInputs(), + main.getAllArgAttrs())); + if (maybe_result) return *maybe_result; + + return MlirAttrsToLayoutModes(main.getAllArgAttrs(), main.getNumArguments()); +} + +StatusOr> GetOutputLayoutModes(mlir::ModuleOp module) { + mlir::func::FuncOp main = module.lookupSymbol("main"); + if (main == nullptr) { + return InvalidArgument( + "GetOutputLayoutModes passed module without main function"); + } + + // Special case: tupled outputs + TF_ASSIGN_OR_RETURN(std::optional> maybe_tuple_result, + GetTupleLayoutModes(main.getFunctionType().getResults(), + main.getAllResultAttrs())); + if (maybe_tuple_result) return *maybe_tuple_result; + + return MlirAttrsToLayoutModes(main.getAllResultAttrs(), main.getNumResults()); +} + +static StatusOr LayoutModeToXlaShape( + const LayoutMode& layout_mode, const Shape& unsharded_shape, + const Shape& sharded_shape, + std::function(Shape)> + choose_compact_layout_for_shape_function) { + if (unsharded_shape.IsToken() || unsharded_shape.IsOpaque()) { + return unsharded_shape; + } + if (!unsharded_shape.IsArray() || !sharded_shape.IsArray()) { + return InvalidArgument( + "LayoutModeToXlaShape must be passed array shapes, got " + "unsharded_shape: %s, sharded_shape: %s", + unsharded_shape.ToString(), sharded_shape.ToString()); + } + // For sharded computations, XLA expects the layout to specified as the global + // shape with the sharded layout. + Shape result = unsharded_shape; + LayoutUtil::ClearLayout(&result); + switch (layout_mode.mode) { + case LayoutMode::Mode::kDefault: { + TF_ASSIGN_OR_RETURN( + Shape layout, + choose_compact_layout_for_shape_function(sharded_shape)); + *result.mutable_layout() = layout.layout(); + break; + } + case LayoutMode::Mode::kUserSpecified: { + CHECK(layout_mode.user_layout); + *result.mutable_layout() = *layout_mode.user_layout; + break; + } + case LayoutMode::Mode::kAuto: { + // Don't set any layout on `result`. + break; + } + } + return result; +} + +StatusOr, Shape>> LayoutModesToXlaShapes( + const XlaComputation& computation, std::vector arg_layout_modes, + std::vector out_layout_modes, + std::function(Shape)> + choose_compact_layout_for_shape_function) { + // Compute sharded argument and output shapes. + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(auto sharded_shapes, + GetShardedProgramShapes(computation, program_shape)); + + // Untuple if necessary. + bool args_tupled = program_shape.parameters_size() == 1 && + program_shape.parameters(0).IsTuple(); + const std::vector& unsharded_arg_shapes = + args_tupled ? program_shape.parameters(0).tuple_shapes() + : program_shape.parameters(); + const std::vector& sharded_arg_shapes = + args_tupled ? sharded_shapes.first[0].tuple_shapes() + : sharded_shapes.first; + + bool out_tupled = program_shape.result().IsTuple(); + const std::vector& unsharded_out_shapes = + out_tupled ? program_shape.result().tuple_shapes() + : std::vector{program_shape.result()}; + const std::vector& sharded_out_shapes = + out_tupled ? 
sharded_shapes.second.tuple_shapes() + : std::vector{sharded_shapes.second}; + + if (unsharded_arg_shapes.size() != arg_layout_modes.size()) { + return InvalidArgument( + "LayoutModesToXlaShapes got mismatched number of arguments and layout " + "modes (%d vs %d)", + unsharded_arg_shapes.size(), arg_layout_modes.size()); + } + if (sharded_arg_shapes.size() != arg_layout_modes.size()) { + return InvalidArgument( + "LayoutModesToXlaShapes got mismatched number of sharded arguments and " + "layout modes (%d vs %d)", + sharded_arg_shapes.size(), arg_layout_modes.size()); + } + if (unsharded_out_shapes.size() != out_layout_modes.size()) { + return InvalidArgument( + "LayoutModesToXlaShapes got mismatched number of outputs and layout " + "modes (%d vs %d)", + unsharded_out_shapes.size(), out_layout_modes.size()); + } + if (sharded_out_shapes.size() != out_layout_modes.size()) { + return InvalidArgument( + "LayoutModesToXlaShapes got mismatched number of sharded outputs and " + "layout modes (%d vs %d)", + sharded_out_shapes.size(), out_layout_modes.size()); + } + + // Convert each LayoutMode to an xla::Shape with the appropriate Layout set or + // unset. + std::vector flat_arg_layouts; + flat_arg_layouts.reserve(arg_layout_modes.size()); + for (int i = 0; i < arg_layout_modes.size(); ++i) { + TF_ASSIGN_OR_RETURN( + Shape layout, + LayoutModeToXlaShape(arg_layout_modes[i], unsharded_arg_shapes[i], + sharded_arg_shapes[i], + choose_compact_layout_for_shape_function)); + flat_arg_layouts.emplace_back(std::move(layout)); + } + std::vector flat_out_layouts; + flat_out_layouts.reserve(out_layout_modes.size()); + for (int i = 0; i < out_layout_modes.size(); ++i) { + TF_ASSIGN_OR_RETURN( + Shape layout, + LayoutModeToXlaShape(out_layout_modes[i], unsharded_out_shapes[i], + sharded_out_shapes[i], + choose_compact_layout_for_shape_function)); + flat_out_layouts.emplace_back(std::move(layout)); + } + + // Tuple final shapes if necessary. + std::vector arg_layouts = + args_tupled + ? std::vector{ShapeUtil::MakeTupleShape(flat_arg_layouts)} + : std::move(flat_arg_layouts); + Shape out_layout = out_tupled ? ShapeUtil::MakeTupleShape(flat_out_layouts) + : flat_out_layouts[0]; + + return std::pair, Shape>{std::move(arg_layouts), + std::move(out_layout)}; +} + +StatusOr, std::vector>> +LayoutModesToXla(const XlaComputation& computation, + std::vector arg_layout_modes, + std::vector out_layout_modes, + std::function(Shape)> + choose_compact_layout_for_shape_function, + ExecutableBuildOptions& build_options) { + TF_ASSIGN_OR_RETURN( + auto pair, + LayoutModesToXlaShapes(computation, arg_layout_modes, out_layout_modes, + choose_compact_layout_for_shape_function)); + std::vector& arg_layouts = pair.first; + Shape& out_layout = pair.second; + + // Generate result vector of pointers + std::vector arg_layout_pointers; + arg_layout_pointers.reserve(arg_layouts.size()); + for (int i = 0; i < arg_layouts.size(); ++i) { + arg_layout_pointers.push_back(&arg_layouts[i]); + } + + // Update build_options + build_options.set_result_layout(out_layout); + + return std::pair, std::vector>{ + std::move(arg_layouts), std::move(arg_layout_pointers)}; +} + Status DetermineArgumentLayoutsFromCompileOptions( const XlaComputation& computation, std::function(Shape)> diff --git a/third_party/xla/xla/pjrt/utils.h b/third_party/xla/xla/pjrt/utils.h index 8d47cad3bc5aff..ae6129b1e94adb 100644 --- a/third_party/xla/xla/pjrt/utils.h +++ b/third_party/xla/xla/pjrt/utils.h @@ -24,9 +24,11 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "xla/client/executable_build_options.h" #include "xla/client/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/pjrt/layout_mode.h" #include "xla/service/computation_placer.h" #include "xla/shape.h" #include "xla/status.h" @@ -44,6 +46,35 @@ Status ParseDeviceAssignmentCompileOptions( int* num_replicas, int* num_partitions, std::shared_ptr* device_assignment); +// Returns the LayoutMode for each argument of the main function in the +// module. Checks for the "mhlo.layout_mode" attr, and if not present, assumes +// LayoutMode::Mode::kDefault. +StatusOr> GetArgLayoutModes(mlir::ModuleOp module); +// Returns the LayoutMode for each output of the main function in the +// module. Checks for the "mhlo.layout_mode" attr, and if not present, assumes +// LayoutMode::Mode::kDefault. +StatusOr> GetOutputLayoutModes(mlir::ModuleOp module); + +// Returns (arg shapes, output shape) with properly-set Layouts that can +// be passed to XLA to reflect arg_layout_modes and out_layout_modes. +StatusOr, Shape>> LayoutModesToXlaShapes( + const XlaComputation& computation, std::vector arg_layout_modes, + std::vector out_layout_modes, + std::function(Shape)> + choose_compact_layout_for_shape_function); + +// Generates useful data structures for communciating desired layouts to XLA: +// * Returns a vector of argument xla::Shapes with properly-set Layouts +// * Returns vector of pointers to those Shapes to create HloModuleConfig +// * Modifies `build_options` to have the correct result_layout set or unset +StatusOr, std::vector>> +LayoutModesToXla(const XlaComputation& computation, + std::vector arg_layout_modes, + std::vector out_layout_modes, + std::function(Shape)> + choose_compact_layout_for_shape_function, + ExecutableBuildOptions& build_options); + // Returns pointers to the argument layouts given an XlaComputation and // ExecutableBuildOptions. Status DetermineArgumentLayoutsFromCompileOptions( diff --git a/third_party/xla/xla/python/xla_client_test.py b/third_party/xla/xla/python/xla_client_test.py index 5d7bd289098ee0..511b6bf13b2074 100644 --- a/third_party/xla/xla/python/xla_client_test.py +++ b/third_party/xla/xla/python/xla_client_test.py @@ -571,6 +571,55 @@ def testGetOutputLayouts(self): @unittest.skipIf(pathways or pathways_ifrt, "not implemented") def testSetArgumentLayouts(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Hand-edited version of: + # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)), + # np.int32(42), + # np.ones(10)) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "{0,1,2}"}, + %arg1: tensor {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "{}"}, + %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "{0}"}) + -> (tensor<1024x8x128xf32> {jax.result_info = "[0]"}, + tensor {jax.result_info = "[1]"}, + tensor<10xf32> {jax.result_info = "[2]"}) { + return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor, tensor<10xf32> + } +} + """ + executable = self.backend.compile(module_str) + + # Check input layouts. 
+ input_layouts = executable.get_parameter_layouts() + self.assertLen(input_layouts, 3) + self.assertEqual(input_layouts[0].minor_to_major(), (0, 1, 2)) + self.assertEqual(input_layouts[1].minor_to_major(), ()) + self.assertEqual(input_layouts[2].minor_to_major(), (0,)) + + # Compile a version with default arg0 layout so we can make sure we + # actually set it above. + default_executable = self.backend.compile( + module_str.replace('"{0,1,2}"', '"default"') + ) + self.assertNotEqual( + input_layouts[0].minor_to_major(), + default_executable.get_parameter_layouts()[0].minor_to_major()) + + @unittest.skipIf(pathways or pathways_ifrt, "not implemented") + def testSetArgumentLayoutsLegacy(self): + """Tests setting the arg layouts with compile_options (deprecated). + + New code should use the mhlo.layout_mode string attr on parameters. + """ # Create computation with custom input layouts. c = self._NewComputation() param_count = 0 @@ -609,6 +658,191 @@ def MakeArg(shape, dtype, layout): self.assertEqual(actual.minor_to_major(), expected.layout().minor_to_major()) + @unittest.skipIf(pathways or pathways_ifrt, "not implemented") + def testSetOutputLayouts(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Hand-edited version of: + # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)), + # np.int32(42), + # np.ones(10)) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"}, + %arg1: tensor {mhlo.sharding = "{replicated}"}, + %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}"}) + -> (tensor<1024x8x128xf32> {jax.result_info = "[0]", + mhlo.layout_mode = "{0,1,2}"}, + tensor {jax.result_info = "[1]", + mhlo.layout_mode = "{}"}, + tensor<10xf32> {jax.result_info = "[2]", + mhlo.layout_mode = "{0}"}) { + return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor, tensor<10xf32> + } +} + """ + executable = self.backend.compile(module_str) + + # Check output layouts. + output_layouts = executable.get_output_layouts() + self.assertLen(output_layouts, 3) + self.assertEqual(output_layouts[0].minor_to_major(), (0, 1, 2)) + self.assertEqual(output_layouts[1].minor_to_major(), ()) + self.assertEqual(output_layouts[2].minor_to_major(), (0,)) + + # Compile a version with default first output layout so we can make sure + # we actually set it above. + default_executable = self.backend.compile( + module_str.replace('"{0,1,2}"', '"default"') + ) + self.assertNotEqual( + output_layouts[0].minor_to_major(), + default_executable.get_output_layouts()[0].minor_to_major()) + + @unittest.skipIf(pathways, "not implemented") + def SetLayoutsSharded(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Hand-edited version of: + # sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) + # x = jax.device_put(np.ones((1024, 128)), sharding.reshape(4, 2)) + # jax.jit(lambda x, y: x + y, out_shardings=sharding)(x, 1.) + # + # This also lightly tests mixed default + user-specified input layouts. 
+ module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 8 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x128xf32> {mhlo.sharding = "{devices=[4,2]0,1,2,3,4,5,6,7}", + mhlo.layout_mode = "{0,1}"}, + %arg1: tensor {mhlo.sharding = "{replicated}"}) + -> (tensor<1024x128xf32> {jax.result_info = "", + mhlo.sharding = "{devices=[4,2]0,1,2,3,4,5,6,7}", + mhlo.layout_mode = "{0,1}"}) { + %0 = stablehlo.convert %arg1 : tensor + %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor) -> tensor<1024x128xf32> + %2 = stablehlo.add %arg0, %1 : tensor<1024x128xf32> + return %2 : tensor<1024x128xf32> + } +} + """ + executable = self.backend.compile(module_str) + + # Check input layouts. + input_layouts = executable.get_parameter_layouts() + self.assertLen(input_layouts, 2) + self.assertEqual(input_layouts[0].minor_to_major(), (0, 1)) + self.assertEqual(input_layouts[1].minor_to_major(), ()) + + # Check output layout. + output_layouts = executable.get_output_layouts() + self.assertLen(output_layouts, 1) + self.assertEqual(input_layouts[0].minor_to_major(), (0, 1)) + + # Compile a version with default layouts so we can make sure we actually + # set it above. + default_executable = self.backend.compile( + module_str.replace('"{0,1}"', '"default"') + ) + self.assertNotEqual( + input_layouts[0].minor_to_major(), + default_executable.get_parameter_layouts()[0].minor_to_major()) + self.assertNotEqual( + output_layouts[0].minor_to_major(), + default_executable.get_output_layouts()[0].minor_to_major()) + + @unittest.skipIf(pathways or pathways_ifrt, "not implemented") + def testAutoArgumentLayouts(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Hand-edited version of: + # jax.numpy.einsum("...a,ahd->...hd", ...) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x1024xf32> {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "auto"}, + %arg1: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "auto"}) + -> (tensor<1024x8x128xf32> {jax.result_info = ""}) { + %0 = stablehlo.dot_general %arg0, %arg1, + contracting_dims = [1] x [0], + precision = [DEFAULT, DEFAULT] : (tensor<1024x1024xf32>, + tensor<1024x8x128xf32>) + -> tensor<1024x8x128xf32> + return %0 : tensor<1024x8x128xf32> + } +} +""" + executable = self.backend.compile(module_str) + + # Check input layouts. + input_layouts = executable.get_parameter_layouts() + self.assertEqual(input_layouts[0].minor_to_major(), (1, 0)) + self.assertEqual(input_layouts[1].minor_to_major(), (2, 0, 1)) + + # Compile a version with default layouts so we can make sure the compiler + # is actually choosing above. + default_executable = self.backend.compile( + module_str.replace('"auto"', '"default"') + ) + # We expect the compiler to choose a non-default layout for the second + # (1024,8,128) argument. + self.assertNotEqual( + input_layouts[1].minor_to_major(), + default_executable.get_parameter_layouts()[1].minor_to_major(), + ) + + @unittest.skipIf(pathways or pathways_ifrt, "not implemented") + def testAutoOutputLayouts(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Generated with jax.numpy.einsum("...a,ahd->...hd", ...) 
+ module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x1024xf32> {mhlo.sharding = "{replicated}"}, + %arg1: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"}) + -> (tensor<1024x8x128xf32> {jax.result_info = "", + mhlo.layout_mode = "auto"}) { + %0 = stablehlo.dot_general %arg0, %arg1, + contracting_dims = [1] x [0], + precision = [DEFAULT, DEFAULT] : (tensor<1024x1024xf32>, + tensor<1024x8x128xf32>) + -> tensor<1024x8x128xf32> + return %0 : tensor<1024x8x128xf32> + } +} +""" + executable = self.backend.compile(module_str) + + # Check output layout + output_layout, = executable.get_output_layouts() + self.assertEqual(output_layout.minor_to_major(), (2, 0, 1)) + + # Compile a version with default layouts so we can make sure the compiler + # is actually choosing above. + default_executable = self.backend.compile( + module_str.replace('"auto"', '"default"') + ) + # We expect the compiler to choose a non-default output layout. + self.assertNotEqual( + output_layout.minor_to_major(), + default_executable.get_output_layouts()[0].minor_to_major(), + ) + tests.append(LayoutsTest) class BufferTest(ComputationTest): From 29a39af6b58597a9361649f179639e1354da2147 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Nov 2023 15:23:04 -0800 Subject: [PATCH 038/391] Re-enable layering_check for target. PiperOrigin-RevId: 582101537 --- tensorflow/python/saved_model/BUILD | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 967dfd181fc179..3ae6d97eb18ef8 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -915,7 +915,6 @@ tf_python_pybind_extension( # "//tensorflow:windows": [], # }), # static_deps = tf_python_pybind_static_deps(), - features = ["-layering_check"], pytype_srcs = [ "pywrap_saved_model/__init__.pyi", "pywrap_saved_model/constants.pyi", @@ -929,17 +928,27 @@ tf_python_pybind_extension( "//tensorflow/python/training:__subpackages__", ], deps = [ - ":pywrap_saved_model_headers", # placeholder for index annotation deps "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "//tensorflow/cc/experimental/libexport:save", + "//tensorflow/cc/saved_model:constants", + "//tensorflow/cc/saved_model:fingerprinting", + "//tensorflow/cc/saved_model:metrics", "//tensorflow/cc/saved_model:reader", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:path", "//tensorflow/python/lib/core:pybind11_status", "@pybind11", + "@pybind11_abseil//pybind11_abseil:absl_casters", "@pybind11_abseil//pybind11_abseil:status_casters", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", - ], + ] + if_google([ + "//tensorflow/tools/proto_splitter:merge", + ]), ) tf_py_strict_test( From a6e31197a7ac83d1c550af6d08475d870487f722 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 13 Nov 2023 15:29:54 -0800 Subject: [PATCH 039/391] Disable failing test on JAX GPU CI runs and clean up build script PiperOrigin-RevId: 582103251 --- third_party/xla/.kokoro/jax/build.sh | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/third_party/xla/.kokoro/jax/build.sh b/third_party/xla/.kokoro/jax/build.sh index 95f5fbe1e623a2..18a4c388ac4b37 100644 --- 
a/third_party/xla/.kokoro/jax/build.sh +++ b/third_party/xla/.kokoro/jax/build.sh @@ -51,7 +51,6 @@ prelude() { # Install bazel update_bazel_linux - chmod +x "${KOKORO_GFILE_DIR}/bazel_wrapper.py" cd jax } @@ -64,11 +63,11 @@ build_and_test_on_rbe_cpu() { --override_repository=xla="${KOKORO_ARTIFACTS_DIR}"/github/xla \ --config=avx_posix \ --config=mkl_open_source_only \ - --config="$NOCUDA_RBE_CONFIG_NAME" \ + --config="rbe_cpu_linux_py312" \ --config=tensorflow_testing_rbe_linux \ --test_env=JAX_NUM_GENERATED_CASES=25 \ - //tests:cpu_tests //tests:backend_independent_tests \ - --test_output=errors + --test_output=errors \ + -- //tests:cpu_tests //tests:backend_independent_tests } build_and_test_on_rbe_gpu() { @@ -77,30 +76,29 @@ build_and_test_on_rbe_gpu() { # we need to add `--remote_instance_name` and `--bes_instance_name`. Why this # is only needed for gpu is still a mystery. + + # TODO(ddunleavy): reenable `LaxTest.testBitcastConvertType1` bazel \ test \ --verbose_failures=true \ - //tests:gpu_tests //tests:backend_independent_tests \ --override_repository=xla="${KOKORO_ARTIFACTS_DIR}"/github/xla \ --config=avx_posix \ --config=mkl_open_source_only \ - --config="$CUDA_RBE_CONFIG_NAME" \ + --config="rbe_linux_cuda12.2_nvcc_py3.9" \ --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ --test_output=errors \ --test_env=JAX_SKIP_SLOW_TESTS=1 \ --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ - --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ + --test_env=JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow|LaxTest.testBitcastConvertType1" \ --test_tag_filters=-multiaccelerator \ --remote_instance_name=projects/tensorflow-testing/instances/default_instance \ - --bes_instance_name="tensorflow-testing" + --bes_instance_name="tensorflow-testing" \ + -- //tests:gpu_tests //tests:backend_independent_tests } # Generate a templated results file to make output accessible to everyone "$KOKORO_ARTIFACTS_DIR"/github/xla/.kokoro/generate_index_html.sh "$KOKORO_ARTIFACTS_DIR"/index.html -NOCUDA_RBE_CONFIG_NAME="rbe_cpu_linux_py312" -CUDA_RBE_CONFIG_NAME="rbe_linux_cuda12.2_nvcc_py3.9" - prelude if is_linux_gpu_job ; then From 91cf4fd5a63d970f580bb95baf0324cf6ccd2081 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 13 Nov 2023 16:34:04 -0800 Subject: [PATCH 040/391] Cleanup: Remove `StaticRangeQuantizationTest.test_matmul_ptq_model_stablehlo`. This is a duplicate test of the one at `.../stablehlo/python/integration_test/quantize_model_test.py`. Test cases related to stablehlo will be maintained here instead. 
PiperOrigin-RevId: 582122026 --- .../integration_test/quantize_model_test.py | 62 ------------------- 1 file changed, 62 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py index 4a1fd7148b0c15..35b8814a4605d1 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py @@ -2518,68 +2518,6 @@ def data_gen() -> repr_dataset.RepresentativeDataset: else: self.assertAllClose(new_outputs, expected_outputs, atol=0.13) - @test_util.run_in_graph_and_eager_modes - def test_matmul_ptq_model_stablehlo(self): - activation_fn = None - has_bias = False - batch_sizes = ([], []) - target_opset = quant_opts_pb2.STABLEHLO - - lhs_batch_size, rhs_batch_size = batch_sizes - input_shape = (*lhs_batch_size, 1, 1024) - filter_shape = (*rhs_batch_size, 1024, 3) - static_input_shape = [dim if dim is not None else 2 for dim in input_shape] - model = self._create_matmul_model( - input_shape, - filter_shape, - self._input_saved_model_path, - has_bias, - activation_fn, - ) - rng = np.random.default_rng(seed=1234) - - input_data = ops.convert_to_tensor( - rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( - np.float32 - ) - ) - expected_outputs = model.matmul(input_data) - - def data_gen() -> repr_dataset.RepresentativeDataset: - for _ in range(100): - yield { - 'input_tensor': rng.uniform( - low=0.0, high=1.0, size=static_input_shape - ).astype(np.float32) - } - - quantization_options = quant_opts_pb2.QuantizationOptions( - quantization_method=quant_opts_pb2.QuantizationMethod( - preset_method=_PresetMethod.METHOD_STATIC_RANGE_INT8 - ), - tags={tag_constants.SERVING}, - signature_keys=['serving_default'], - op_set=target_opset, - ) - converted_model = quantize_model.quantize( - self._input_saved_model_path, - self._output_saved_model_path, - quantization_options, - representative_dataset=data_gen(), - ) - - self.assertIsNotNone(converted_model) - self.assertCountEqual( - converted_model.signatures._signatures.keys(), {'serving_default'} - ) - - new_outputs = converted_model.signatures['serving_default']( - input_tensor=ops.convert_to_tensor(input_data) - ) - # Tests that the quantized graph outputs similar values. The rtol value is - # arbitrary. 
- self.assertAllClose(new_outputs, expected_outputs, rtol=0.02) - @parameterized.named_parameters( { 'testcase_name': 'with_biasadd', From 5a0cddbc2cce85ac0f4daff1e7761b4b9498c61a Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 13 Nov 2023 16:35:11 -0800 Subject: [PATCH 041/391] Add `py::arg`s for `pywrap_funciton_lib.cc` PiperOrigin-RevId: 582122276 --- .../mlir/quantization/tensorflow/python/BUILD | 1 + .../tensorflow/python/pywrap_function_lib.cc | 13 +++++++++--- .../tensorflow/python/quantize_model.py | 20 +++++++++---------- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index bb2c38d4971141..a7bc257909d8b2 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -171,6 +171,7 @@ tf_python_pybind_extension( "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/core/protobuf:for_core_protos_cc", "//tensorflow/python/lib/core:pybind11_lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:string_view", "@pybind11", ], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc index 6525ee59bf7014..48464cef4341b5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc @@ -15,12 +15,14 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "pybind11/cast.h" // from @pybind11 #include "pybind11/detail/common.h" // from @pybind11 #include "pybind11/pybind11.h" // from @pybind11 #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" // IWYU pragma: keep #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/python/lib/core/pybind11_lib.h" @@ -64,6 +66,11 @@ PYBIND11_MODULE(pywrap_function_lib, m) { m, "PyFunctionLibrary") .def(py::init<>()) .def("assign_ids_to_custom_aggregator_ops", - &PyFunctionLibrary::AssignIdsToCustomAggregatorOps) - .def("save_exported_model", &PyFunctionLibrary::SaveExportedModel); + &PyFunctionLibrary::AssignIdsToCustomAggregatorOps, + py::arg("exported_model_serialized")) + .def("save_exported_model", &PyFunctionLibrary::SaveExportedModel, + py::arg("dst_saved_model_path"), + py::arg("exported_model_serialized"), + py::arg("src_saved_model_path"), py::arg("tags"), + py::arg("serialized_signature_def_map")); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index d3abb29977d874..959d20f4148f3d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -731,11 +731,11 @@ def _run_static_range_ptq( # TODO: b/296916287 - Create a separate function for saving unquantized # dump model py_function_library.save_exported_model( - 
quant_opts.debugger_options.unquantized_dump_model_path, - exported_model.SerializeToString(), - src_saved_model_path, - quant_opts.tags, - signature_def_map_serialized, + dst_saved_model_path=quant_opts.debugger_options.unquantized_dump_model_path, + exported_model_serialized=exported_model.SerializeToString(), + src_saved_model_path=src_saved_model_path, + tags=quant_opts.tags, + serialized_signature_def_map=signature_def_map_serialized, ) _change_dump_tensor_file_name(graph_def) @@ -744,11 +744,11 @@ def _run_static_range_ptq( # TODO: b/309601030 - Integrate model functionality to # `quantize_ptq_model_pre_calibration`. py_function_library.save_exported_model( - calibrated_model_path, - exported_model.SerializeToString(), - pre_calib_output_model_path, - quant_opts.tags, - signature_def_map_serialized, + dst_saved_model_path=calibrated_model_path, + exported_model_serialized=exported_model.SerializeToString(), + src_saved_model_path=pre_calib_output_model_path, + tags=quant_opts.tags, + serialized_signature_def_map=signature_def_map_serialized, ) logging.info('Running post-training quantization post-calibration step.') From 45889deaa34cd5982f10410bf60eb88f6e66bb0e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Nov 2023 16:44:15 -0800 Subject: [PATCH 042/391] If TensorBoard fails during reading Debugger V2 files then their names will be logged. PiperOrigin-RevId: 582124510 --- tensorflow/python/debug/lib/debug_events_reader.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/debug/lib/debug_events_reader.py b/tensorflow/python/debug/lib/debug_events_reader.py index 706823b799b14a..2b38a4ca4d34ff 100644 --- a/tensorflow/python/debug/lib/debug_events_reader.py +++ b/tensorflow/python/debug/lib/debug_events_reader.py @@ -109,8 +109,15 @@ def _load_metadata_files(self): wall_times.append(debug_event.wall_time) run_ids.append(debug_event.debug_metadata.tfdbg_run_id) tensorflow_versions.append( - debug_event.debug_metadata.tensorflow_version) + debug_event.debug_metadata.tensorflow_version + ) file_versions.append(debug_event.debug_metadata.file_version) + except Exception as e: + raise errors.DataLossError( + None, + None, + "Error reading tfdbg metadata from paths %s" % metadata_paths, + ) from e finally: reader.close() self._starting_wall_time = wall_times[0] From d86ef5049be84d8c95b7edd3bae9441b84d8010f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Nov 2023 16:52:02 -0800 Subject: [PATCH 043/391] Produces a repack summary that reports the # of successful repacks. 
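This adds a num_repacks_successful_ counter next to the existing num_repacks_
counter in AlternateMemoryBestFitHeap, incremented whenever a repack attempt
returns success, and logs the totals at VLOG(1) once allocation finishes, for
example (the counts here are illustrative):

  Repack summary: 2 succeeded out of 3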
PiperOrigin-RevId: 582126484 --- .../xla/service/memory_space_assignment/BUILD | 8 +++++++ .../memory_space_assignment.cc | 21 +++++++++++++++++++ .../memory_space_assignment.h | 9 ++++++++ 3 files changed, 38 insertions(+) diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD index e16247c4eddcec..475f9c67cdfd2a 100644 --- a/third_party/xla/xla/service/memory_space_assignment/BUILD +++ b/third_party/xla/xla/service/memory_space_assignment/BUILD @@ -44,13 +44,18 @@ cc_library( "//xla:debug_options_flags", "//xla:shape_util", "//xla:status", + "//xla:status_macros", "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:buffer_value", + "//xla/service:call_graph", "//xla/service:heap_simulator", + "//xla/service:hlo_alias_analysis", + "//xla/service:hlo_buffer", "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_dataflow_analysis", "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", "//xla/service:time_utils", @@ -58,9 +63,12 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:errors", diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc index d176392d025d9a..a968b304940d4d 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -36,17 +37,30 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" +#include "absl/log/check.h" #include "absl/memory/memory.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/buffer_value.h" +#include "xla/service/call_graph.h" #include "xla/service/heap_simulator.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/repacking.h" #include "xla/service/memory_space_assignment/tuning_utils.h" @@ -56,6 +70,7 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" +#include "xla/status_macros.h" #include "xla/util.h" #include "tsl/platform/casts.h" #include "tsl/platform/errors.h" @@ -3835,6 +3850,9 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { ImportRepackedAllocations(); --retry_number; } + if (*repack_status) { + ++num_repacks_successful_; + } } else { // Check if any of the allocation sites are inefficient. If so, get rid // of the pending allocation, require all of the inefficient sites in @@ -3884,6 +3902,9 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { } } + VLOG(1) << "Repack summary: " << num_repacks_successful_ + << " succeeded out of " << num_repacks_; + VLOG(3) << "Debug buffer info: "; XLA_VLOG_LINES(3, buffer_info_str_); VLOG(3) << "Debug allocation info: "; diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h index 59be577d338287..371b2c24708135 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_SPACE_ASSIGNMENT_H_ #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_SPACE_ASSIGNMENT_H_ +#include #include #include #include @@ -35,19 +36,26 @@ limitations under the License. #include "absl/container/btree_map.h" #endif #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" +#include "xla/service/call_graph.h" #include "xla/service/heap_simulator.h" #include "xla/service/hlo.pb.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_buffer.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" #include "xla/service/memory_space_assignment/repacking.h" #include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/statusor.h" +#include "xla/util.h" namespace xla { namespace memory_space_assignment { @@ -2703,6 +2711,7 @@ class AlternateMemoryBestFitHeap // for aliased allocations. std::list repack_allocation_blocks_; int64_t num_repacks_ = 0; + int64_t num_repacks_successful_ = 0; std::vector> pending_chunks_; std::vector pending_async_copies_; std::vector> From dbd227e9c758b76e3fa2cdfefe9472629c24712d Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 13 Nov 2023 16:52:53 -0800 Subject: [PATCH 044/391] Add a new quantization function at `stablehlo/python/quantization.py`. 
PiperOrigin-RevId: 582126707 --- .../mlir/quantization/stablehlo/python/BUILD | 12 +++- .../integration_test/quantize_model_test.py | 59 +++++++++++++++---- .../stablehlo/python/quantization.py | 49 +++++++++++++++ 3 files changed, 106 insertions(+), 14 deletions(-) create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index b34c9574d90568..30c6f2b0359a38 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -22,6 +22,15 @@ package( licenses = ["notice"], ) +pytype_strict_library( + name = "quantization", + srcs = ["quantization.py"], + deps = [ + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model", + ], +) + pytype_strict_library( name = "quantize_model_test_base", testonly = 1, @@ -47,13 +56,14 @@ tf_py_strict_test( name = "quantize_model_test", srcs = ["integration_test/quantize_model_test.py"], deps = [ + ":quantization", ":quantize_model_test_base", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", - "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model", "//tensorflow/compiler/mlir/quantization/tensorflow/python:representative_dataset", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:test_lib", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/saved_model:load", "//tensorflow/python/saved_model:tag_constants", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index 91132777488a02..b8236b6b989016 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -18,13 +18,14 @@ from absl.testing import parameterized import numpy as np +from tensorflow.compiler.mlir.quantization.stablehlo.python import quantization from tensorflow.compiler.mlir.quantization.stablehlo.python.integration_test import quantize_model_test_base from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2 -from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.platform import test +from tensorflow.python.saved_model import load from tensorflow.python.saved_model import tag_constants # Type aliases for quantization method protobuf enums. 
@@ -73,14 +74,13 @@ def test_matmul_ptq_model( has_bias, activation_fn, ) - rng = np.random.default_rng(seed=1235) + rng = np.random.default_rng(seed=1235) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( np.float32 ) ) - expected_outputs = model.matmul(input_data) def data_gen() -> repr_dataset.RepresentativeDataset: for _ in range(100): @@ -90,34 +90,67 @@ def data_gen() -> repr_dataset.RepresentativeDataset: ).astype(np.float32) } - quantization_options = quant_opts_pb2.QuantizationOptions( + dataset_path = self.create_tempfile('tfrecord').full_path + path_map = {'serving_default': dataset_path} + repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save( + {'serving_default': data_gen()} + ) + + config = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( preset_method=_PresetMethod.METHOD_STATIC_RANGE_INT8 ), tags={tag_constants.SERVING}, signature_keys=['serving_default'], op_set=target_opset, + representative_datasets={ + 'serving_default': quant_opts_pb2.RepresentativeDatasetFile( + tfrecord_file_path=dataset_path + ) + }, ) - converted_model = quantize_model.quantize( + quantization.quantize_saved_model( self._input_saved_model_path, self._output_saved_model_path, - quantization_options, - representative_dataset=data_gen(), + config, ) - self.assertIsNotNone(converted_model) - self.assertCountEqual( - converted_model.signatures._signatures.keys(), {'serving_default'} - ) + expected_outputs = model.matmul(input_data) - new_outputs = converted_model.signatures['serving_default']( + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + + new_outputs = root.signatures['serving_default']( input_tensor=ops.convert_to_tensor(input_data) ) # Tests that the quantized graph outputs similar values. The rtol value is # arbitrary. - # TODO(b/309674337): Fix the large numerical errors. + # TODO: b/309674337 - Fix the large numerical errors. self.assertAllClose(new_outputs, expected_outputs, rtol=0.3) + def test_when_preset_not_srq_raise_error(self): + self._create_matmul_model( + input_shape=(1, 1024), + weight_shape=(1024, 3), + saved_model_path=self._input_saved_model_path, + ) + + config = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + preset_method=_PresetMethod.METHOD_NO_QUANTIZE + ), + tags={tag_constants.SERVING}, + signature_keys=['serving_default'], + op_set=quant_opts_pb2.STABLEHLO, + ) + + with self.assertRaisesRegex(ValueError, 'only supports static-range PTQ'): + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py new file mode 100644 index 00000000000000..5eefe2b94bbb7e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/quantization.py @@ -0,0 +1,49 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""StableHLO Quantizer.""" +from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model + + +# TODO: b/310594193 - Export API to pip package. +def quantize_saved_model( + src_saved_model_path: str, + dst_saved_model_path: str, + config: quant_opts_pb2.QuantizationOptions, +) -> None: + """Quantizes a saved model. + + Args: + src_saved_model_path: Path to the directory for the source SavedModel. + dst_saved_model_path: Path to the directory for the destination SavedModel. + config: Quantization configuration. + + Raises: + ValueError: When `config` was not configured for static-range PTQ + single representative dataset. + """ + if not ( + config.quantization_method.preset_method + == quant_opts_pb2.QuantizationMethod.PresetMethod.METHOD_STATIC_RANGE_INT8 + and len(config.representative_datasets) == 1 + ): + raise ValueError( + '`quantize_saved_model` currently only supports static-range PTQ with a' + ' single signature.' + ) + + # TODO: b/307624867 - Remove TF Quantizer dependency and replace it with + # StableHLO Quantizer components. + quantize_model.quantize(src_saved_model_path, dst_saved_model_path, config) From 0409e7335be729cff06251b13a1418093ed5ed32 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 13 Nov 2023 16:54:34 -0800 Subject: [PATCH 045/391] Make the CPU backend participate in distributed initialization. The main effect of this change is that CPU devices end up with a unique global ID and the correct process index. PiperOrigin-RevId: 582127068 --- third_party/xla/xla/pjrt/BUILD | 2 + .../xla/xla/pjrt/tfrt_cpu_pjrt_client.cc | 62 +++-- .../xla/xla/pjrt/tfrt_cpu_pjrt_client.h | 27 +- third_party/xla/xla/python/BUILD | 1 + third_party/xla/xla/python/xla.cc | 27 +- third_party/xla/xla/python/xla_client.py | 157 +++++++---- third_party/xla/xla/python/xla_client.pyi | 6 +- .../xla/xla/python/xla_extension/__init__.pyi | 261 +++++++++++------- 8 files changed, 360 insertions(+), 183 deletions(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 29bd730a136c16..abbb5886b81cdd 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -678,6 +678,7 @@ cc_library( "//xla/client:executable_build_options", "//xla/client:xla_computation", "//xla/hlo/ir:hlo", + "//xla/pjrt/distributed:topology_util", "//xla/runtime:cpu_event", "//xla/service:buffer_assignment", "//xla/service:compiler", @@ -706,6 +707,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", # TODO(zhangqiaorjc): Remove if use TFRT threadpool. 
"@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc index 863f46ce2725af..12e5cce95e42d2 100644 --- a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc @@ -41,6 +41,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/time/time.h" #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -57,6 +58,7 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/pjrt/abstract_tfrt_cpu_buffer.h" #include "xla/pjrt/compile_options.pb.h" +#include "xla/pjrt/distributed/topology_util.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" @@ -236,7 +238,11 @@ class TfrtCpuAsyncHostToDeviceTransferManager } // namespace -TfrtCpuDeviceDescription::TfrtCpuDeviceDescription(int id) : id_(id) { +TfrtCpuDeviceDescription::TfrtCpuDeviceDescription(int id, int process_index, + int local_hardware_id) + : id_(id), + process_index_(process_index), + local_hardware_id_(local_hardware_id) { debug_string_ = absl::StrCat("TFRT_CPU_", id); to_string_ = absl::StrCat("CpuDevice(id=", id, ")"); } @@ -253,8 +259,9 @@ absl::string_view TfrtCpuDeviceDescription::ToString() const { return to_string_; } -TfrtCpuDevice::TfrtCpuDevice(int id, int max_inflight_computations) - : description_(id), +TfrtCpuDevice::TfrtCpuDevice(int id, int process_index, int local_hardware_id, + int max_inflight_computations) + : description_(id, process_index, local_hardware_id), max_inflight_computations_semaphore_( /*capacity=*/max_inflight_computations) {} @@ -281,30 +288,47 @@ static int CpuDeviceCount() { return GetDebugOptionsFromFlags().xla_force_host_platform_device_count(); } -static StatusOr>> GetTfrtCpuDevices( - int cpu_device_count, int max_inflight_computations_per_device) { - std::vector> devices; - for (int i = 0; i < cpu_device_count; ++i) { - auto device = std::make_unique( - /*id=*/i, max_inflight_computations_per_device); - devices.push_back(std::move(device)); - } - return std::move(devices); -} - StatusOr> GetTfrtCpuClient( const CpuClientOptions& options) { // Need at least CpuDeviceCount threads to launch one collective. 
int cpu_device_count = options.cpu_device_count.value_or(CpuDeviceCount()); size_t num_threads = std::max(DefaultThreadPoolSize(), cpu_device_count); - TF_ASSIGN_OR_RETURN( - std::vector> devices, - GetTfrtCpuDevices(cpu_device_count, - options.max_inflight_computations_per_device)); + LocalTopologyProto local_topology; + local_topology.set_node_id(options.node_id); + std::string boot_id_str; + auto boot_id_str_or_status = GetBootIdString(); + if (!boot_id_str_or_status.ok()) { + LOG(INFO) << boot_id_str_or_status.status(); + } else { + boot_id_str = boot_id_str_or_status.value(); + } + local_topology.set_boot_id(boot_id_str); + for (int i = 0; i < cpu_device_count; ++i) { + DeviceProto* device_proto = local_topology.add_devices(); + device_proto->set_local_device_ordinal(i); + device_proto->set_name("cpu"); + } + + GlobalTopologyProto global_topology; + TF_RETURN_IF_ERROR( + ExchangeTopologies("cpu", options.node_id, options.num_nodes, + absl::Minutes(2), absl::Minutes(5), options.kv_get, + options.kv_put, local_topology, &global_topology)); + + std::vector> devices; + for (const LocalTopologyProto& node : global_topology.nodes()) { + for (const DeviceProto& device_proto : node.devices()) { + auto device = std::make_unique( + /*id=*/device_proto.global_device_id(), node.node_id(), + device_proto.local_device_ordinal(), + options.max_inflight_computations_per_device); + devices.push_back(std::move(device)); + } + } return std::unique_ptr(std::make_unique( - /*process_index=*/0, std::move(devices), num_threads)); + /*process_index=*/options.node_id, std::move(devices), num_threads)); } TfrtCpuClient::TfrtCpuClient( diff --git a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h index c744d10ac3a9ea..e9543ab92e93e7 100644 --- a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h @@ -63,11 +63,13 @@ namespace xla { class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { public: - explicit TfrtCpuDeviceDescription(int id); + TfrtCpuDeviceDescription(int id, int process_index, int local_hardware_id); int id() const override { return id_; } - int process_index() const override { return 0; } + int process_index() const override { return process_index_; } + + int local_hardware_id() const { return local_hardware_id_; } absl::string_view device_kind() const override; @@ -82,6 +84,8 @@ class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { private: int id_; + int process_index_; + int local_hardware_id_; std::string debug_string_; std::string to_string_; absl::flat_hash_map attributes_ = {}; @@ -89,7 +93,8 @@ class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { class TfrtCpuDevice final : public PjRtDevice { public: - explicit TfrtCpuDevice(int id, int max_inflight_computations = 32); + explicit TfrtCpuDevice(int id, int process_index, int local_hardware_id, + int max_inflight_computations = 32); const TfrtCpuDeviceDescription& description() const override { return description_; @@ -106,8 +111,9 @@ class TfrtCpuDevice final : public PjRtDevice { return process_index() == client()->process_index(); } - // Used as `device_ordinal`. 
- int local_hardware_id() const override { return id(); } + int local_hardware_id() const override { + return description_.local_hardware_id(); + } Status TransferToInfeed(const LiteralSlice& literal) override; @@ -518,6 +524,17 @@ struct CpuClientOptions { std::optional cpu_device_count = std::nullopt; int max_inflight_computations_per_device = 32; + + // Number of distributed nodes. node_id, kv_get, and kv_put are ignored if + // this is set to 1. + int num_nodes = 1; + + // My node ID. + int node_id = 0; + + // KV store primitives for sharing topology information. + PjRtClient::KeyValueGetCallback kv_get = nullptr; + PjRtClient::KeyValuePutCallback kv_put = nullptr; }; StatusOr> GetTfrtCpuClient( const CpuClientOptions& options); diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index c0a5938032acd7..c7367a7f35ba59 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -1153,6 +1153,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@local_config_python//:python_headers", # buildcleaner: keep "//xla:literal", diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index a4db3deef34730..b2269a423e5393 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -41,6 +41,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/time/time.h" #include "absl/types/span.h" #include "pybind11/attr.h" // from @pybind11 #include "pybind11/cast.h" // from @pybind11 @@ -492,16 +493,38 @@ static void Init(py::module_& m) { m.def( "get_tfrt_cpu_client", - [](bool asynchronous) -> std::shared_ptr { + [](bool asynchronous, + std::shared_ptr distributed_client, + int node_id, int num_nodes) -> std::shared_ptr { py::gil_scoped_release gil_release; CpuClientOptions options; + if (distributed_client != nullptr) { + std::string key_prefix = "cpu:"; + options.kv_get = + [distributed_client, key_prefix]( + const std::string& k, + absl::Duration timeout) -> xla::StatusOr { + return distributed_client->BlockingKeyValueGet( + absl::StrCat(key_prefix, k), timeout); + }; + options.kv_put = [distributed_client, key_prefix]( + const std::string& k, + const std::string& v) -> xla::Status { + return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), + v); + }; + options.node_id = node_id; + options.num_nodes = num_nodes; + } + options.asynchronous = asynchronous; std::unique_ptr client = xla::ValueOrThrow(GetTfrtCpuClient(options)); return std::make_shared( ifrt::PjRtClient::Create(std::move(client))); }, - py::arg("asynchronous") = true); + py::arg("asynchronous") = true, py::arg("distributed_client") = nullptr, + py::arg("node_id") = 0, py::arg("num_nodes") = 1); m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { xla::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); return pjrt_api.ok(); diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index 6239cce3b38351..eba39f47830c89 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -48,7 +48,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. 
-_version = 213 +_version = 214 # Version number for MLIR:Python components. mlir_api_version = 54 @@ -63,11 +63,18 @@ _NameValueMapping = Mapping[str, Union[str, int, List[int], float, bool]] -def make_cpu_client() -> ...: - register_custom_call_handler( - 'cpu', _xla.register_custom_call_target +def make_cpu_client( + distributed_client=None, + node_id=0, + num_nodes=1, +) -> ...: + register_custom_call_handler('cpu', _xla.register_custom_call_target) + return _xla.get_tfrt_cpu_client( + asynchronous=True, + distributed_client=distributed_client, + node_id=node_id, + num_nodes=num_nodes, ) - return _xla.get_tfrt_cpu_client(asynchronous=True) def make_gpu_client( @@ -97,12 +104,8 @@ def make_gpu_client( if memory_fraction: config.memory_fraction = float(memory_fraction) config.preallocate = preallocate not in ('0', 'false', 'False') - register_custom_call_handler( - 'CUDA', _xla.register_custom_call_target - ) - register_custom_call_handler( - 'ROCM', _xla.register_custom_call_target - ) + register_custom_call_handler('CUDA', _xla.register_custom_call_target) + register_custom_call_handler('ROCM', _xla.register_custom_call_target) return _xla.get_gpu_client( asynchronous=True, @@ -224,6 +227,7 @@ def generate_pjrt_gpu_plugin_options( class OpMetadata: """Python representation of a xla.OpMetadata protobuf.""" + __slots__ = ('op_type', 'op_name', 'source_file', 'source_line') def __init__(self, op_type='', op_name='', source_file='', source_line=0): @@ -238,10 +242,8 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): full_filename, lineno = inspect.stack()[skip_frames][1:3] filename = os.path.basename(full_filename) return OpMetadata( - op_type=op_type, - op_name=op_name, - source_file=filename, - source_line=lineno) + op_type=op_type, op_name=op_name, source_file=filename, source_line=lineno + ) PrimitiveType = _xla.PrimitiveType @@ -380,7 +382,8 @@ def convert(pyval): if isinstance(pyval, tuple): if layout is not None: raise NotImplementedError( - 'shape_from_pyval does not support layouts for tuple shapes') + 'shape_from_pyval does not support layouts for tuple shapes' + ) return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) else: return Shape.array_shape(pyval.dtype, np.shape(pyval), layout) @@ -478,8 +481,9 @@ class PaddingType(enum.Enum): SAME = 2 -def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, - window_strides): +def window_padding_type_to_pad_values( + padding_type, lhs_dims, rhs_dims, window_strides +): """Maps PaddingType or string to pad values (list of pairs of ints).""" if not isinstance(padding_type, (str, PaddingType)): msg = 'padding_type must be str or PaddingType, got {}.' 
@@ -501,7 +505,8 @@ def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, pad_sizes = [ max((out_size - 1) * stride + filter_size - in_size, 0) for out_size, stride, filter_size, in_size in zip( - out_shape, window_strides, rhs_dims, lhs_dims) + out_shape, window_strides, rhs_dims, lhs_dims + ) ] return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] else: @@ -605,6 +610,7 @@ def register_custom_call_handler(platform: str, handler: Any) -> None: class PaddingConfigDimension: """Python representation of a xla.PaddingConfigDimension protobuf.""" + __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding') edge_padding_low: int @@ -619,6 +625,7 @@ def __init__(self): class PaddingConfig: """Python representation of a xla.PaddingConfig protobuf.""" + __slots__ = ('dimensions',) def __init__(self): @@ -652,8 +659,13 @@ def make_padding_config( class DotDimensionNumbers: """Python representation of a xla.DotDimensionNumbers protobuf.""" - __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions', - 'lhs_batch_dimensions', 'rhs_batch_dimensions') + + __slots__ = ( + 'lhs_contracting_dimensions', + 'rhs_contracting_dimensions', + 'lhs_batch_dimensions', + 'rhs_batch_dimensions', + ) def __init__(self): self.lhs_contracting_dimensions = [] @@ -663,9 +675,10 @@ def __init__(self): def make_dot_dimension_numbers( - dimension_numbers: Union[DotDimensionNumbers, - Tuple[Tuple[List[int], List[int]], - Tuple[List[int], List[int]]]] + dimension_numbers: Union[ + DotDimensionNumbers, + Tuple[Tuple[List[int], List[int]], Tuple[List[int], List[int]]], + ] ) -> DotDimensionNumbers: """Builds a DotDimensionNumbers object from a specification. @@ -692,11 +705,18 @@ def make_dot_dimension_numbers( class ConvolutionDimensionNumbers: """Python representation of a xla.ConvolutionDimensionNumbers protobuf.""" - __slots__ = ('input_batch_dimension', 'input_feature_dimension', - 'input_spatial_dimensions', 'kernel_input_feature_dimension', - 'kernel_output_feature_dimension', 'kernel_spatial_dimensions', - 'output_batch_dimension', 'output_feature_dimension', - 'output_spatial_dimensions') + + __slots__ = ( + 'input_batch_dimension', + 'input_feature_dimension', + 'input_spatial_dimensions', + 'kernel_input_feature_dimension', + 'kernel_output_feature_dimension', + 'kernel_spatial_dimensions', + 'output_batch_dimension', + 'output_feature_dimension', + 'output_spatial_dimensions', + ) def __init__(self): self.input_batch_dimension = 0 @@ -711,30 +731,32 @@ def __init__(self): def make_convolution_dimension_numbers( - dimension_numbers: Union[None, ConvolutionDimensionNumbers, Tuple[str, str, - str]], - num_spatial_dimensions: int) -> ConvolutionDimensionNumbers: + dimension_numbers: Union[ + None, ConvolutionDimensionNumbers, Tuple[str, str, str] + ], + num_spatial_dimensions: int, +) -> ConvolutionDimensionNumbers: """Builds a ConvolutionDimensionNumbers object from a specification. Args: dimension_numbers: optional, either a ConvolutionDimensionNumbers object or - a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of - length N+2 identifying by position: (1) batch dimensions in lhs, rhs, and - the output with the character 'N', (2) feature dimensions in lhs and the - output with the character 'C', (3) input and output feature dimensions - in rhs with the characters 'I' and 'O' respectively, and (4) spatial - dimension correspondences between lhs, rhs, and the output using any - distinct characters. 
For example, to indicate dimension numbers - consistent with the Conv operation with two spatial dimensions, one - could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate - dimension numbers consistent with the TensorFlow Conv2D operation, one - could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of - convolution dimension specification, window strides are associated with - spatial dimension character labels according to the order in which the - labels appear in the rhs_spec string, so that window_strides[0] is - matched with the dimension corresponding to the first character - appearing in rhs_spec that is not 'I' or 'O'. By default, use the same - dimension numbering as Conv and ConvWithGeneralPadding. + a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of length + N+2 identifying by position: (1) batch dimensions in lhs, rhs, and the + output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions in + rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers consistent + with the Conv operation with two spatial dimensions, one could use + ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate dimension + numbers consistent with the TensorFlow Conv2D operation, one could use + ('NHWC', 'HWIO', 'NHWC'). When using the latter form of convolution + dimension specification, window strides are associated with spatial + dimension character labels according to the order in which the labels + appear in the rhs_spec string, so that window_strides[0] is matched with + the dimension corresponding to the first character appearing in rhs_spec + that is not 'I' or 'O'. By default, use the same dimension numbering as + Conv and ConvWithGeneralPadding. num_spatial_dimensions: the number of spatial dimensions. 
Returns: @@ -764,18 +786,26 @@ def make_convolution_dimension_numbers( dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') dimension_numbers.kernel_spatial_dimensions.extend( - i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'} + ) dimension_numbers.input_spatial_dimensions.extend( - sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(lhs_spec[i]))) + sorted( + (i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]), + ) + ) dimension_numbers.output_spatial_dimensions.extend( - sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(out_spec[i]))) + sorted( + (i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]), + ) + ) return dimension_numbers class PrecisionConfig: """Python representation of a xla.PrecisionConfig protobuf.""" + __slots__ = ('operand_precision',) Precision = _xla.PrecisionConfig_Precision @@ -786,8 +816,13 @@ def __init__(self): class GatherDimensionNumbers: """Python representation of a xla.GatherDimensionNumbers protobuf.""" - __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map', - 'index_vector_dim') + + __slots__ = ( + 'offset_dims', + 'collapsed_slice_dims', + 'start_index_map', + 'index_vector_dim', + ) def __init__(self): self.offset_dims = [] @@ -798,8 +833,13 @@ def __init__(self): class ScatterDimensionNumbers: """Python representation of a xla.ScatterDimensionNumbers protobuf.""" - __slots__ = ('update_window_dims', 'inserted_window_dims', - 'scatter_dims_to_operand_dims', 'index_vector_dim') + + __slots__ = ( + 'update_window_dims', + 'inserted_window_dims', + 'scatter_dims_to_operand_dims', + 'index_vector_dim', + ) def __init__(self): self.update_window_dims = [] @@ -810,6 +850,7 @@ def __init__(self): class ReplicaGroup: """Python representation of a xla.ReplicaGroup protobuf.""" + __slots__ = ('replica_ids',) def __init__(self): diff --git a/third_party/xla/xla/python/xla_client.pyi b/third_party/xla/xla/python/xla_client.pyi index 2eb82ec094f2d7..04e22b8e7cf417 100644 --- a/third_party/xla/xla/python/xla_client.pyi +++ b/third_party/xla/xla/python/xla_client.pyi @@ -81,7 +81,11 @@ def shape_from_pyval(pyval: Any, layout: Sequence[int] | None = None) -> Any: .. def heap_profile(client: Client) -> bytes: ... -def make_cpu_client() -> Client: +def make_cpu_client( + distributed_client: Optional[DistributedRuntimeClient] = ..., + node_id: int = ..., + num_nodes: int = ..., +) -> Client: ... def make_gpu_client( diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi index 19752e1c593903..ed841b38217f96 100644 --- a/third_party/xla/xla/python/xla_extension/__init__.pyi +++ b/third_party/xla/xla/python/xla_extension/__init__.pyi @@ -20,14 +20,26 @@ import inspect import types import typing from typing import ( - Any, Callable, ClassVar, Dict, Iterator, List, Optional, Sequence, Tuple, - Type, TypeVar, Union, overload) + Any, + Callable, + ClassVar, + Dict, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) import numpy as np -from . import ops from . import jax_jit from . import mlir +from . import ops from . import outfeed_receiver from . import pmap_lib from . 
import profiler @@ -88,7 +100,8 @@ class Shape: type: Union[np.dtype, PrimitiveType], dims_seq: Any = ..., layout_seq: Any = ..., - dynamic_dimensions: Optional[List[bool]] = ...) -> Shape: ... + dynamic_dimensions: Optional[List[bool]] = ..., + ) -> Shape: ... @staticmethod def token_shape() -> Shape: ... @staticmethod @@ -136,7 +149,7 @@ class XlaComputation: def get_hlo_module(self) -> HloModule: ... def program_shape(self) -> ProgramShape: ... def as_serialized_hlo_module_proto(self) -> bytes: ... - def as_hlo_text(self, print_large_constants: bool=False) -> str: ... + def as_hlo_text(self, print_large_constants: bool = False) -> str: ... def as_hlo_dot_graph(self) -> str: ... def hash(self) -> int: ... def as_hlo_module(self) -> HloModule: ... @@ -176,10 +189,11 @@ class HloModule: @property def name(self) -> str: ... def to_string(self, options: HloPrintOptions = ...) -> str: ... - def as_serialized_hlo_module_proto(self)-> bytes: ... + def as_serialized_hlo_module_proto(self) -> bytes: ... @staticmethod def from_serialized_hlo_module_proto( - serialized_hlo_module_proto: bytes) -> HloModule: ... + serialized_hlo_module_proto: bytes, + ) -> HloModule: ... def computations(self) -> List[HloComputation]: ... class HloModuleGroup: @@ -191,10 +205,9 @@ class HloModuleGroup: def hlo_module_to_dot_graph(hlo_module: HloModule) -> str: ... def hlo_module_from_text(hlo_module_text: str) -> HloModule: ... - def hlo_module_cost_analysis( - client: Client, - module: HloModule) -> Dict[str, float]: ... + client: Client, module: HloModule +) -> Dict[str, float]: ... class XlaOp: ... @@ -214,7 +227,8 @@ class XlaBuilder: self, __output_index: Sequence[int], __param_number: int, - __param_index: Sequence[int]) -> None: ... + __param_index: Sequence[int], + ) -> None: ... class DeviceAssignment: @staticmethod @@ -238,12 +252,18 @@ class CompileOptions: profile_version: int device_assignment: Optional[DeviceAssignment] compile_portable_executable: bool - env_option_overrides: List[Tuple[str,str]] - -def register_custom_call_target(fn_name: str, capsule: Any, platform: str) -> _Status: ... -def register_custom_call_partitioner(name: str, prop_user_sharding: Callable, - partition: Callable, infer_sharding_from_operands: Callable, - can_side_effecting_have_replicated_sharding: bool) -> None: ... + env_option_overrides: List[Tuple[str, str]] + +def register_custom_call_target( + fn_name: str, capsule: Any, platform: str +) -> _Status: ... +def register_custom_call_partitioner( + name: str, + prop_user_sharding: Callable, + partition: Callable, + infer_sharding_from_operands: Callable, + can_side_effecting_have_replicated_sharding: bool, +) -> None: ... def encode_inspect_sharding_callback(handler: Any) -> bytes: ... class DebugOptions: @@ -349,11 +369,16 @@ class HloSharding: @staticmethod def from_string(sharding: str) -> HloSharding: ... @staticmethod - def tuple_sharding(shape: Shape, shardings: Sequence[HloSharding]) -> HloSharding: ... + def tuple_sharding( + shape: Shape, shardings: Sequence[HloSharding] + ) -> HloSharding: ... @staticmethod - def iota_tile(dims: Sequence[int], reshape_dims: Sequence[int], - transpose_perm: Sequence[int], - subgroup_types: Sequence[OpSharding.Type]) -> HloSharding: ... + def iota_tile( + dims: Sequence[int], + reshape_dims: Sequence[int], + transpose_perm: Sequence[int], + subgroup_types: Sequence[OpSharding.Type], + ) -> HloSharding: ... @staticmethod def replicate() -> HloSharding: ... 
@staticmethod @@ -415,16 +440,17 @@ class Memory: class GpuAllocatorConfig: class Kind(enum.IntEnum): - DEFAULT: int - PLATFORM: int - BFC: int - CUDA_ASYNC: int + DEFAULT: int + PLATFORM: int + BFC: int + CUDA_ASYNC: int def __init__( self, kind: Kind = ..., memory_fraction: float = ..., - preallocate: bool = ...) -> None: ... + preallocate: bool = ..., + ) -> None: ... class HostBufferSemantics(enum.IntEnum): IMMUTABLE_ONLY_DURING_CALL: HostBufferSemantics @@ -450,61 +476,78 @@ class Client: argument: Any, device: Optional[Device] = ..., force_copy: bool = ..., - host_buffer_semantics: HostBufferSemantics = ...) -> ArrayImpl: ... + host_buffer_semantics: HostBufferSemantics = ..., + ) -> ArrayImpl: ... def make_cross_host_receive_buffers( - self, - shapes: Sequence[Shape], - device: Device) -> List[Tuple[ArrayImpl, bytes]]: ... + self, shapes: Sequence[Shape], device: Device + ) -> List[Tuple[ArrayImpl, bytes]]: ... def compile( self, computation: Union[str, bytes], - compile_options: CompileOptions = ..., host_callbacks: Sequence[Any] = ...) -> LoadedExecutable: ... + compile_options: CompileOptions = ..., + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... def serialize_executable(self, executable: LoadedExecutable) -> bytes: ... def deserialize_executable( - self, serialized: bytes, - options: Optional[CompileOptions], host_callbacks: Sequence[Any] = ...) -> LoadedExecutable: ... + self, + serialized: bytes, + options: Optional[CompileOptions], + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... def heap_profile(self) -> bytes: ... def defragment(self) -> _Status: ... def get_emit_python_callback_descriptor( - self, callable: Callable, operand_shapes: Sequence[Shape], - results_shapes: Sequence[Shape]) -> Tuple[Any, Any]: ... + self, + callable: Callable, + operand_shapes: Sequence[Shape], + results_shapes: Sequence[Shape], + ) -> Tuple[Any, Any]: ... def make_python_callback_from_host_send_and_recv( - self, callable: Callable, operand_shapes: Sequence[Shape], - result_shapes: Sequence[Shape], send_channel_ids: Sequence[int], - recv_channel_ids: Sequence[int], serializer: Optional[Callable] = ...) -> Any: ... + self, + callable: Callable, + operand_shapes: Sequence[Shape], + result_shapes: Sequence[Shape], + send_channel_ids: Sequence[int], + recv_channel_ids: Sequence[int], + serializer: Optional[Callable] = ..., + ) -> Any: ... def __getattr__(self, name: str) -> Any: ... -def get_tfrt_cpu_client(asynchronous: bool = ...) -> Client: ... +def get_tfrt_cpu_client( + asynchronous: bool = ..., + distributed_client: Optional[DistributedRuntimeClient] = ..., + node_id: int = ..., + num_nodes: int = ..., +) -> Client: ... def get_gpu_client( asynchronous: bool = ..., allocator_config: GpuAllocatorConfig = ..., distributed_client: Optional[DistributedRuntimeClient] = ..., node_id: int = ..., + num_nodes: int = ..., allowed_devices: Optional[Any] = ..., platform_name: Optional[str] = ..., - mock:Optional[bool]=...) -> Client:... + mock: Optional[bool] = ..., +) -> Client: ... def get_mock_gpu_client( asynchronous: bool = ..., allocator_config: GpuAllocatorConfig = ..., distributed_client: Optional[DistributedRuntimeClient] = ..., node_id: int = ..., allowed_devices: Optional[Any] = ..., - platform_name: Optional[str] = ...) -> Client:... + platform_name: Optional[str] = ..., +) -> Client: ... 
def get_c_api_client( platform_name: str, options: Dict[str, Union[str, int, List[int], float, bool]], distributed_client: Optional[DistributedRuntimeClient] = ..., ) -> Client: ... - def get_default_c_api_topology( platform_name: str, topology_name: str, options: Dict[str, Union[str, int, List[int], float]], -) -> DeviceTopology: - ... -def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: - ... - +) -> DeviceTopology: ... +def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: ... def load_pjrt_plugin(platform_name: str, library_path: str) -> _Status: ... def pjrt_plugin_loaded(plugin_name: str) -> bool: ... def pjrt_plugin_initialized(plugin_name: str) -> bool: ... @@ -541,10 +584,14 @@ ArrayImpl = Any # traceback: Traceback # _HAS_DYNAMIC_ATTRIBUTES: bool = ... -def copy_array_to_devices_with_sharding(self: ArrayImpl, devices: List[Device], sharding: Any) -> ArrayImpl: ... - +def copy_array_to_devices_with_sharding( + self: ArrayImpl, devices: List[Device], sharding: Any +) -> ArrayImpl: ... def batched_device_put( - aval: Any, sharding: Any, shards: Sequence[Any], devices: List[Device], + aval: Any, + sharding: Any, + shards: Sequence[Any], + devices: List[Device], committed: bool = True, ) -> ArrayImpl: ... @@ -553,11 +600,8 @@ def check_and_canonicalize_memory_kind( memory_kind: Optional[str], device_list: DeviceList) -> Optional[str]: ... def array_result_handler( - aval: Any, - sharding: Any, - committed: bool, - _skip_checks: bool = ...) -> Callable: - ... + aval: Any, sharding: Any, committed: bool, _skip_checks: bool = ... +) -> Callable: ... class Token: def block_until_ready(self): ... @@ -569,7 +613,9 @@ class ShardedToken: class ExecuteResults: def __len__(self) -> int: ... def disassemble_into_single_device_arrays(self) -> List[List[ArrayImpl]]: ... - def disassemble_prefix_into_single_device_arrays(self, n: int) -> List[List[ArrayImpl]]: ... + def disassemble_prefix_into_single_device_arrays( + self, n: int + ) -> List[List[ArrayImpl]]: ... def consume_with_handlers(self, handlers: List[Callable]) -> List[Any]: ... def consume_token(self) -> ShardedToken: ... @@ -581,18 +627,17 @@ class LoadedExecutable: def delete(self) -> None: ... def execute(self, arguments: Sequence[ArrayImpl]) -> List[ArrayImpl]: ... def execute_with_token( - self, - arguments: Sequence[ArrayImpl]) -> Tuple[List[ArrayImpl], Token]: - ... + self, arguments: Sequence[ArrayImpl] + ) -> Tuple[List[ArrayImpl], Token]: ... def execute_sharded_on_local_devices( - self, - arguments: Sequence[List[ArrayImpl]]) -> List[List[ArrayImpl]]: ... + self, arguments: Sequence[List[ArrayImpl]] + ) -> List[List[ArrayImpl]]: ... def execute_sharded_on_local_devices_with_tokens( - self, - arguments: Sequence[List[ArrayImpl]]) -> Tuple[List[List[ArrayImpl]], ShardedToken]: ... + self, arguments: Sequence[List[ArrayImpl]] + ) -> Tuple[List[List[ArrayImpl]], ShardedToken]: ... def execute_sharded( - self, - arguments: Sequence[List[ArrayImpl]], with_tokens: bool = ...) -> ExecuteResults: ... + self, arguments: Sequence[List[ArrayImpl]], with_tokens: bool = ... + ) -> ExecuteResults: ... def hlo_modules(self) -> List[HloModule]: ... def get_output_memory_kinds(self) -> List[List[str]]: ... def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... @@ -625,14 +670,18 @@ class DeviceTopology: def __getattr__(self, name: str) -> Any: ... def buffer_to_dlpack_managed_tensor( - buffer: ArrayImpl, - stream: int | None = None) -> Any: ... 
+ buffer: ArrayImpl, stream: int | None = None +) -> Any: ... def dlpack_managed_tensor_to_buffer( - tensor: Any, device: Device, stream: int | None) -> ArrayImpl: ... + tensor: Any, device: Device, stream: int | None +) -> ArrayImpl: ... + # Legacy overload def dlpack_managed_tensor_to_buffer( - tensor: Any, cpu_backend: Optional[Client] = ..., - gpu_backend: Optional[Client] = ...) -> ArrayImpl: ... + tensor: Any, + cpu_backend: Optional[Client] = ..., + gpu_backend: Optional[Client] = ..., +) -> ArrayImpl: ... # === BEGIN py_traceback.cc @@ -651,12 +700,12 @@ class Traceback: def __str__(self) -> str: ... def as_python_traceback(self) -> Any: ... def raw_frames(self) -> Tuple[List[types.CodeType], List[int]]: ... - @staticmethod def code_addr2line(code: types.CodeType, lasti: int) -> int: ... @staticmethod - def code_addr2location(code: types.CodeType, - lasti: int) -> Tuple[int, int, int, int]: ... + def code_addr2location( + code: types.CodeType, lasti: int + ) -> Tuple[int, int, int, int]: ... def replace_thread_exc_traceback(traceback: Any): ... @@ -664,16 +713,20 @@ def replace_thread_exc_traceback(traceback: Any): ... class DistributedRuntimeService: def shutdown(self) -> None: ... + class DistributedRuntimeClient: def connect(self) -> _Status: ... def shutdown(self) -> _Status: ... def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> _Status: ... - def blocking_key_value_get_bytes(self, key: str, timeout_in_ms: int) -> _Status: ... + def blocking_key_value_get_bytes( + self, key: str, timeout_in_ms: int + ) -> _Status: ... def key_value_dir_get(self, key: str) -> _Status: ... def key_value_dir_get_bytes(self, key: str) -> _Status: ... def key_value_set(self, key: str, value: str) -> _Status: ... - def key_value_delete(self, key:str) -> _Status: ... + def key_value_delete(self, key: str) -> _Status: ... def wait_at_barrier(self, barrier_id: str, timeout_in_ms: int) -> _Status: ... + def get_distributed_runtime_service( address: str, num_nodes: int, @@ -690,17 +743,16 @@ def get_distributed_runtime_client( heartbeat_interval: Optional[int] = ..., max_missing_heartbeats: Optional[int] = ..., missed_heartbeat_callback: Optional[Any] = ..., - shutdown_on_destruction: Optional[bool] = ...) -> DistributedRuntimeClient: ... + shutdown_on_destruction: Optional[bool] = ..., +) -> DistributedRuntimeClient: ... class PreemptionSyncManager: def initialize(self, client: DistributedRuntimeClient) -> _Status: ... def reached_sync_point(self, step_counter: int) -> bool: ... -def create_preemption_sync_manager() -> PreemptionSyncManager: ... +def create_preemption_sync_manager() -> PreemptionSyncManager: ... def collect_garbage() -> None: ... - def is_optimized_build() -> bool: ... - def json_to_pprof_profile(json: str) -> bytes: ... def pprof_profile_to_json(proto: bytes) -> str: ... @@ -714,8 +766,9 @@ class PmapFunction: def _cache_size(self) -> int: ... def _cache_clear(self) -> None: ... -def weakref_lru_cache(cache_context_fn: Callable, call: Callable, maxsize=...): - ... +def weakref_lru_cache( + cache_context_fn: Callable, call: Callable, maxsize=... +): ... class DeviceList: def __init__(self, device_assignment: Tuple[Device, ...]): ... @@ -738,13 +791,18 @@ class DeviceList: def memory_kinds(self) -> Tuple[str, ...]: ... class Sharding: ... - class XLACompatibleSharding(Sharding): ... 
class NamedSharding(XLACompatibleSharding): - def __init__(self, mesh: Any, spec: Any, *, memory_kind: Optional[str] = None, - _parsed_pspec: Any = None, - _manual_axes: frozenset[Any] = frozenset()): ... + def __init__( + self, + mesh: Any, + spec: Any, + *, + memory_kind: Optional[str] = None, + _parsed_pspec: Any = None, + _manual_axes: frozenset[Any] = frozenset(), + ): ... mesh: Any spec: Any _memory_kind: Optional[str] @@ -759,15 +817,21 @@ class SingleDeviceSharding(XLACompatibleSharding): _internal_device_list: DeviceList class PmapSharding(XLACompatibleSharding): - def __init__(self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec): ... + def __init__( + self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec + ): ... devices: List[Any] sharding_spec: pmap_lib.ShardingSpec _internal_device_list: DeviceList class GSPMDSharding(XLACompatibleSharding): - def __init__(self, devices: Sequence[Device], - op_sharding: Union[OpSharding, HloSharding], - *, memory_kind: Optional[str] = None): ... + def __init__( + self, + devices: Sequence[Device], + op_sharding: Union[OpSharding, HloSharding], + *, + memory_kind: Optional[str] = None, + ): ... _devices: Tuple[Device, ...] _hlo_sharding: HloSharding _memory_kind: Optional[str] @@ -786,12 +850,16 @@ class PjitFunctionCache: @staticmethod def clear_all(): ... -def pjit(function_name: str, fun: Optional[Callable], cache_miss: Callable, - static_argnums: Sequence[int], static_argnames: Sequence[str], - donate_argnums: Sequence[int], - pytree_registry: pytree.PyTreeRegistry, - cache: Optional[PjitFunctionCache] = ..., - ) -> PjitFunction: ... +def pjit( + function_name: str, + fun: Optional[Callable], + cache_miss: Callable, + static_argnums: Sequence[int], + static_argnames: Sequence[str], + donate_argnums: Sequence[int], + pytree_registry: pytree.PyTreeRegistry, + cache: Optional[PjitFunctionCache] = ..., +) -> PjitFunction: ... class HloPassInterface: @property @@ -813,9 +881,6 @@ class TupleSimplifer(HloPassInterface): def __init__(self) -> None: ... def is_asan() -> bool: ... - def is_msan() -> bool: ... - def is_tsan() -> bool: ... - def is_sanitized() -> bool: ... From 84251124c87b706bf883169b4e839c7e2f1f714d Mon Sep 17 00:00:00 2001 From: pizzud Date: Mon, 13 Nov 2023 18:04:42 -0800 Subject: [PATCH 046/391] [xla_compile][NFC] Extract the compilation and file-writing to a library. The library can be (and is) tested. 
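As a rough usage sketch of the extracted entry points (not part of this change; the wrapper function, its name, and the exact include list are illustrative, while CompileExecutable and WriteResultFile are taken from the new xla_compile_lib.h added below):

  #include <memory>
  #include <optional>
  #include <string>
  #include <utility>

  #include "xla/hlo/ir/hlo_module.h"
  #include "xla/service/xla_compile_result.pb.h"
  #include "xla/tools/xla_compile_lib.h"
  #include "xla/util.h"
  #include "tsl/platform/errors.h"
  #include "tsl/platform/statusor.h"

  namespace xla {

  // Hypothetical caller: compile an already-loaded HloModule for CPU
  // (always AOT) and optionally persist the timing/result proto.
  StatusOr<std::string> CompileAndRecord(std::unique_ptr<HloModule> module,
                                         const std::string& result_file,
                                         TimerStats& stats,
                                         CompilationResult& result) {
    // For platform "gpu", a Compiler::TargetConfig may be supplied to
    // compile ahead-of-time without an attached device.
    TF_ASSIGN_OR_RETURN(std::string serialized,
                        CompileExecutable(std::move(module), "cpu",
                                          /*target_config=*/std::nullopt));
    if (!result_file.empty()) {
      TF_RETURN_IF_ERROR(WriteResultFile(result_file, stats, result));
    }
    return serialized;
  }

  }  // namespace xla
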
PiperOrigin-RevId: 582142162 --- third_party/xla/xla/service/BUILD | 14 +- .../xla/xla/service/xla_compile_main.cc | 154 ++--------------- third_party/xla/xla/tools/BUILD | 89 ++++++++++ third_party/xla/xla/tools/xla_compile_lib.cc | 162 ++++++++++++++++++ third_party/xla/xla/tools/xla_compile_lib.h | 50 ++++++ .../xla/xla/tools/xla_compile_lib_test.cc | 141 +++++++++++++++ 6 files changed, 465 insertions(+), 145 deletions(-) create mode 100644 third_party/xla/xla/tools/xla_compile_lib.cc create mode 100644 third_party/xla/xla/tools/xla_compile_lib.h create mode 100644 third_party/xla/xla/tools/xla_compile_lib_test.cc diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index fdda60624227d6..aed7d55f90f0ae 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -7222,27 +7222,21 @@ xla_cc_binary( "//xla:debug_options_flags", "//xla:statusor", "//xla:util", - "//xla/hlo/ir:hlo_module_group", "//xla/mlir_hlo", "//xla/pjrt:mlir_to_hlo", "//xla/service:cpu_plugin", - "//xla/service/cpu:cpu_compiler", - "//xla/service/cpu:cpu_executable", "//xla/service/gpu:autotuner_util", "//xla/service/gpu:gpu_symbol_repository", - "//xla/stream_executor", - "//xla/stream_executor:device_memory_allocator", "//xla/tools:hlo_module_loader", + "//xla/tools:xla_compile_lib", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_time", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", @@ -7250,7 +7244,6 @@ xla_cc_binary( "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/platform:types", "@local_tsl//tsl/util:command_line_flags", - "@local_tsl//tsl/util/proto:proto_utils", "@stablehlo//:register", ] + if_cuda_is_configured([ "//xla/service/gpu:executable_proto_cc", @@ -7481,3 +7474,8 @@ tf_proto_library( ], visibility = ["//visibility:public"], ) + +exports_files( + ["xla_aot_compile_test_gpu_target_config.prototxt"], + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/xla/service/xla_compile_main.cc b/third_party/xla/xla/service/xla_compile_main.cc index 71dc1cb00f6e8a..f35c3082215566 100644 --- a/third_party/xla/xla/service/xla_compile_main.cc +++ b/third_party/xla/xla/service/xla_compile_main.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -25,7 +24,6 @@ limitations under the License. #include "absl/cleanup/cleanup.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" -#include "absl/synchronization/mutex.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -35,24 +33,18 @@ limitations under the License. 
#include "stablehlo/dialect/Register.h" // from @stablehlo #include "xla/autotune_results.pb.h" #include "xla/debug_options_flags.h" -#include "xla/hlo/ir/hlo_module_group.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/service/compiler.h" -#include "xla/service/cpu/cpu_compiler.h" -#include "xla/service/cpu/cpu_executable.h" -#include "xla/service/executable.h" #include "xla/service/export_hlo.h" #include "xla/service/hlo_module_config.h" #include "xla/service/symbol_repository.h" #include "xla/service/xla_compile_result.pb.h" #include "xla/statusor.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/stream_executor_pimpl.h" #include "xla/tools/hlo_module_loader.h" +#include "xla/tools/xla_compile_lib.h" #include "xla/util.h" #include "tsl/platform/env.h" -#include "tsl/platform/env_time.h" #include "tsl/platform/errors.h" #include "tsl/platform/init_main.h" #include "tsl/platform/path.h" @@ -60,19 +52,10 @@ limitations under the License. #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/types.h" #include "tsl/util/command_line_flags.h" -#include "tsl/util/proto/proto_utils.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/service/gpu/autotuner_util.h" -#include "xla/service/gpu/executable.pb.h" -#include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/gpu_symbol_repository.h" -#include "xla/stream_executor/gpu/gpu_init.h" -#endif -#if GOOGLE_CUDA -#include "xla/service/gpu/nvptx_compiler.h" -#elif TENSORFLOW_USE_ROCM -#include "xla/service/gpu/amdgpu_compiler.h" #endif namespace xla { @@ -100,75 +83,6 @@ const char kUsageHeader[] = "understood by that repository." "\n"; -StatusOr AotCompileCpuExecutable( - std::unique_ptr hlo_module) { - cpu::CpuCompiler cpu_compiler; - TF_ASSIGN_OR_RETURN( - std::unique_ptr cpu_executable, - cpu_compiler.CompileXlaRuntimeCpuExecutable(std::move(hlo_module))); - TF_ASSIGN_OR_RETURN(std::unique_ptr aot_result, - cpu_compiler.Export(cpu_executable.get())); - TF_ASSIGN_OR_RETURN(std::string result, aot_result->SerializeAsString()); - return result; -} - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -StatusOr CompileGpuExecutable( - std::unique_ptr hlo_module, - const std::optional target_config) { - const bool aot = target_config.has_value(); - -#if GOOGLE_CUDA - auto gpu_compiler = gpu::NVPTXCompiler(); -#elif TENSORFLOW_USE_ROCM - auto gpu_compiler = gpu::AMDGPUCompiler(); -#endif - Compiler::CompileOptions compile_options; - - stream_executor::StreamExecutor* stream_executor = nullptr; - std::unique_ptr allocator; - if (aot) { - compile_options.target_config = *target_config; - } else { - TF_RETURN_IF_ERROR(stream_executor::ValidateGPUMachineManager()); - TF_ASSIGN_OR_RETURN( - stream_executor, - stream_executor::GPUMachineManager()->ExecutorForDevice(0)); - allocator = - std::make_unique( - stream_executor); - compile_options.device_allocator = allocator.get(); - } - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_after_opt, - gpu_compiler.RunHloPasses(std::move(hlo_module), stream_executor, - compile_options)); - - if (aot) { - auto module_group = - std::make_unique(std::move(module_after_opt)); - - AotCompilationOptions aot_options(gpu_compiler.PlatformId()); - aot_options.set_target_config(*target_config); - - TF_ASSIGN_OR_RETURN( - std::vector> aot_results, - gpu_compiler.CompileAheadOfTime(std::move(module_group), aot_options)); - TF_ASSIGN_OR_RETURN(std::string result, - aot_results[0]->SerializeAsString()); - return result; - } - - 
TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - gpu_compiler.RunBackend(std::move(module_after_opt), stream_executor, - compile_options)); - return executable->module().ToString(); -} - -#endif - xla::StatusOr> LoadModule( const std::string& module_path) { auto format = std::string(tsl::io::Extension(module_path)); @@ -207,28 +121,6 @@ xla::StatusOr> LoadModule( return HloModule::CreateFromProto(hlo_module_proto, config); } -Status MaybeWriteResultFile(const std::string& result_output_file, - TimerStats& stats, - CompilationResult& compilation_result) { - if (result_output_file.empty()) { - return absl::OkStatus(); - } - absl::MutexLock ml(&stats.stats_mutex); - const double secs = std::floor(stats.cumulative_secs); - const double nanos = - (stats.cumulative_secs - secs) * tsl::EnvTime::kSecondsToNanos; - google::protobuf::Duration duration; - duration.set_seconds(secs); - duration.set_nanos(nanos); - - *compilation_result.mutable_perf_stats()->mutable_compilation_duration() = - duration; - *compilation_result.mutable_perf_stats()->mutable_total_duration() = duration; - - return tsl::WriteBinaryProto(tsl::Env::Default(), result_output_file, - compilation_result); -} - Status XlaCompileMain( const std::string& module_path, const std::string& output_path, const std::string& platform, const std::string& gpu_target_config_path, @@ -264,21 +156,14 @@ Status XlaCompileMain( absl::Cleanup cleanup([&] { // Make sure we stop the timer if compilation failed. timer.StopAndLog(); - TF_QCHECK_OK( - MaybeWriteResultFile(result_output_file, stats, compilation_result)); + if (!result_output_file.empty()) { + TF_QCHECK_OK( + WriteResultFile(result_output_file, stats, compilation_result)); + } }); // Run AOT compilation. - std::string result; - if (platform == "cpu") { - auto compile_result = AotCompileCpuExecutable(std::move(hlo_module)); - if (!compile_result.ok()) { - *compilation_result.mutable_status() = - tsl::StatusToProto(compile_result.status()); - return compile_result.status(); - } - result = *compile_result; -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - } else if (platform == "gpu") { + std::optional cfg = std::nullopt; + if (platform == "gpu") { if (!gpu_target_config_path.empty()) { // Parse GpuTargetConfig. std::string gpu_target_config_string; @@ -295,30 +180,25 @@ Status XlaCompileMain( target_config = std::make_unique(gpu_target_config_proto); +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (!autotune_results_path.empty()) { TF_RETURN_IF_ERROR(gpu::AutotunerUtil::LoadAutotuneResultsFromFile( autotune_results_path)); } +#endif } - std::optional cfg = - (use_attached_device) ? std::nullopt - : std::make_optional(*std::move(target_config)); - auto compile_result = CompileGpuExecutable(std::move(hlo_module), cfg); - if (!compile_result.ok()) { - *compilation_result.mutable_status() = - tsl ::StatusToProto(compile_result.status()); - return compile_result.status(); - } - result = *compile_result; -#endif - } else { - return Unimplemented("platform %s not supported", platform); + cfg = (use_attached_device) ? 
std::nullopt + : std::make_optional(*std::move(target_config)); + } + auto result = CompileExecutable(std::move(hlo_module), platform, cfg); + if (!result.ok()) { + *compilation_result.mutable_status() = tsl::StatusToProto(result.status()); + return result.status(); } - timer.StopAndLog(); TF_RETURN_IF_ERROR( - tsl::WriteStringToFile(tsl::Env::Default(), output_path, result)); + tsl::WriteStringToFile(tsl::Env::Default(), output_path, *result)); if (wait_for_uploads) { MaybeWaitForUploads(); diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index efe3e9504ba30e..057facad20511a 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -1,5 +1,6 @@ # Tools and utilities that aid in XLA development and usage. +load("//xla/tests:build_defs.bzl", "xla_test") load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load( @@ -8,9 +9,12 @@ load( "xla_cc_test", "xla_py_proto_library", ) +load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load( "@local_tsl//tsl:tsl.bzl", "if_cuda_or_rocm", + "tsl_gpu_library", ) load("@local_tsl//tsl:tsl.default.bzl", "filegroup") load( @@ -18,6 +22,10 @@ load( "tf_proto_library", ) load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load( + "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) package( default_visibility = ["//visibility:public"], @@ -616,3 +624,84 @@ xla_cc_binary( "@local_tsl//tsl/platform:platform_port", ], ) + +tsl_gpu_library( + name = "xla_compile_lib", + srcs = ["xla_compile_lib.cc"], + hdrs = ["xla_compile_lib.h"], + cuda_deps = [ + ], + defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), + visibility = ["//visibility:public"], + deps = [ + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_module_group", + "//xla/service:compiler", + "//xla/service:executable", + "//xla/service:xla_compile_result_proto_cc_impl", + "//xla/service/cpu:cpu_compiler", + "//xla/service/cpu:cpu_executable", + "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_time", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ] + if_cuda_is_configured([ + "//xla/service/gpu:nvptx_compiler", + "//xla/service/gpu:nvptx_compiler_impl", + ]) + if_rocm_is_configured([ + "//xla/service/gpu:amdgpu_compiler", + "//xla/service/gpu:amdgpu_compiler_impl", + ]) + if_gpu_is_configured([ + "//xla/service/gpu:executable_proto_cc", + "//xla/service/gpu:gpu_compiler", + "//xla/stream_executor/gpu:gpu_init", + ]), +) + +xla_test( + name = "xla_compile_lib_test", + srcs = ["xla_compile_lib_test.cc"], + backend_tags = { + "gpu": [ + "requires-gpu-nvidia", + "config-cuda-only", + ], + }, + backends = [ + "cpu", + "gpu", + ], + data = [ + ":data/add.hlo", + "//xla/service:xla_aot_compile_test_gpu_target_config.prototxt", + ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), + deps = [ + ":xla_compile_lib", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:platform_util", + "//xla/service:xla_compile_result_proto_cc_impl", + "//xla/stream_executor:device_description_proto_cc", + 
"//xla/tests:hlo_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_time", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + ], +) diff --git a/third_party/xla/xla/tools/xla_compile_lib.cc b/third_party/xla/xla/tools/xla_compile_lib.cc new file mode 100644 index 00000000000000..4c26660a6b22fb --- /dev/null +++ b/third_party/xla/xla/tools/xla_compile_lib.cc @@ -0,0 +1,162 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tools/xla_compile_lib.h" + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/duration.pb.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/service/compiler.h" +#include "xla/service/cpu/cpu_compiler.h" +#include "xla/service/cpu/cpu_executable.h" +#include "xla/service/executable.h" +#include "xla/service/xla_compile_result.pb.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/env_time.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "xla/service/gpu/executable.pb.h" +#include "xla/stream_executor/gpu/gpu_init.h" +#endif +#if GOOGLE_CUDA +#include "xla/service/gpu/nvptx_compiler.h" +#elif TENSORFLOW_USE_ROCM +#include "xla/service/gpu/amdgpu_compiler.h" +#endif + +namespace xla { + +static StatusOr AotCompileCpuExecutable( + std::unique_ptr hlo_module) { + cpu::CpuCompiler cpu_compiler; + TF_ASSIGN_OR_RETURN( + std::unique_ptr cpu_executable, + cpu_compiler.CompileXlaRuntimeCpuExecutable(std::move(hlo_module))); + TF_ASSIGN_OR_RETURN(std::unique_ptr aot_result, + cpu_compiler.Export(cpu_executable.get())); + return aot_result->SerializeAsString(); +} + +static StatusOr CompileGpuExecutable( + std::unique_ptr hlo_module, + std::optional target_config) { +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + const bool aot = target_config.has_value(); + +#if GOOGLE_CUDA + auto gpu_compiler = gpu::NVPTXCompiler(); +#elif TENSORFLOW_USE_ROCM + auto gpu_compiler = gpu::AMDGPUCompiler(); +#endif + + Compiler::CompileOptions compile_options; + + stream_executor::StreamExecutor* stream_executor = nullptr; + std::unique_ptr allocator; + if (aot) { + compile_options.target_config = *target_config; + 
} else { + TF_RETURN_IF_ERROR(stream_executor::ValidateGPUMachineManager()); + TF_ASSIGN_OR_RETURN( + stream_executor, + stream_executor::GPUMachineManager()->ExecutorForDevice(0)); + allocator = + std::make_unique( + stream_executor); + compile_options.device_allocator = allocator.get(); + } + + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_after_opt, + gpu_compiler.RunHloPasses(std::move(hlo_module), stream_executor, + compile_options)); + + if (aot) { + auto module_group = + std::make_unique(std::move(module_after_opt)); + + AotCompilationOptions aot_options(gpu_compiler.PlatformId()); + aot_options.set_target_config(*target_config); + + TF_ASSIGN_OR_RETURN( + std::vector> aot_results, + gpu_compiler.CompileAheadOfTime(std::move(module_group), aot_options)); + TF_ASSIGN_OR_RETURN(std::string result, + aot_results[0]->SerializeAsString()); + return result; + } + + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + gpu_compiler.RunBackend(std::move(module_after_opt), stream_executor, + compile_options)); + return executable->module().ToString(); +#else + LOG(ERROR) << "Neither ROCm nor CUDA present; returning empty."; + return ""; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} + +StatusOr CompileExecutable( + std::unique_ptr hlo_module, absl::string_view platform, + std::optional target_config) { + if (platform == "cpu") { + return AotCompileCpuExecutable(std::move(hlo_module)); + } else if (platform == "gpu") { + return CompileGpuExecutable(std::move(hlo_module), target_config); + } + + return absl::UnimplementedError( + absl::StrCat("platform", platform, " is not supported")); +} + +Status WriteResultFile(const std::string& result_output_file, TimerStats& stats, + CompilationResult& compilation_result) { + if (result_output_file.empty()) { + return absl::OkStatus(); + } + absl::MutexLock ml(&stats.stats_mutex); + const double secs = std::floor(stats.cumulative_secs); + const double nanos = + (stats.cumulative_secs - secs) * tsl::EnvTime::kSecondsToNanos; + google::protobuf::Duration duration; + duration.set_seconds(secs); + duration.set_nanos(nanos); + + *compilation_result.mutable_perf_stats()->mutable_compilation_duration() = + duration; + *compilation_result.mutable_perf_stats()->mutable_total_duration() = duration; + + return tsl::WriteBinaryProto(tsl::Env::Default(), result_output_file, + compilation_result); +} + +} // namespace xla diff --git a/third_party/xla/xla/tools/xla_compile_lib.h b/third_party/xla/xla/tools/xla_compile_lib.h new file mode 100644 index 00000000000000..63176f10faf63d --- /dev/null +++ b/third_party/xla/xla/tools/xla_compile_lib.h @@ -0,0 +1,50 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_TOOLS_XLA_COMPILE_LIB_H_ +#define XLA_TOOLS_XLA_COMPILE_LIB_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/compiler.h" +#include "xla/service/xla_compile_result.pb.h" +#include "xla/util.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +// Compiles the provided module for the given platform, either "cpu" or "gpu". +// When compiling for GPU, if the target config is provided, the compilation +// will be AOT. If it is not provided, an attached GPU will be used. When +// compiling for CPU, the compilation will always be AOT. +// +// This is the expected entry point to the compilation functionality. +StatusOr CompileExecutable( + std::unique_ptr hlo_module, absl::string_view platform, + std::optional target_config); + +// Merges the measured duration into compilation_result and writes +// compilation_result to result_output_file in the wire format. +Status WriteResultFile(const std::string& result_output_file, TimerStats& stats, + CompilationResult& compilation_result); + +} // namespace xla + +#endif // XLA_TOOLS_XLA_COMPILE_LIB_H_ diff --git a/third_party/xla/xla/tools/xla_compile_lib_test.cc b/third_party/xla/xla/tools/xla_compile_lib_test.cc new file mode 100644 index 00000000000000..05d86570ecee0e --- /dev/null +++ b/third_party/xla/xla/tools/xla_compile_lib_test.cc @@ -0,0 +1,141 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/tools/xla_compile_lib.h" + +#include +#include +#include +#include + +#include "google/protobuf/duration.pb.h" +#include +#include +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/platform_util.h" +#include "xla/service/xla_compile_result.pb.h" +#include "xla/stream_executor/device_description.pb.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_macros.h" +#include "xla/util.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/env_time.h" +#include "tsl/platform/path.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/protobuf/error_codes.pb.h" + +namespace xla { +namespace { + +using ::testing::IsEmpty; +using ::testing::Not; +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; +using ::tsl::testing::StatusIs; + +#if XLA_TEST_BACKEND_CPU +static constexpr absl::string_view kPlatformName = "Host"; +#elif XLA_TEST_BACKEND_GPU +static constexpr absl::string_view kPlatformName = "CUDA"; +#endif // XLA_TEST_BACKEND_CPU + +class XlaCompileLibTest : public HloTestBase { + protected: + XlaCompileLibTest() + : HloTestBase(*PlatformUtil::GetPlatform(std::string(kPlatformName)), + GetReferencePlatform()) {} + void SetUp() override { + const std::string hlo_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), + "tools", "data", "add.hlo"); + std::string hlo; + TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(), hlo_path, &hlo)); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo)); + } + + std::unique_ptr module_; +}; + +TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(CompilesForCpu)) { + EXPECT_THAT(CompileExecutable(std::move(module_), "cpu", std::nullopt), + IsOkAndHolds(Not(IsEmpty()))); +} + +TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithDevice)) { + EXPECT_THAT(CompileExecutable(std::move(module_), "gpu", std::nullopt), + IsOkAndHolds(Not(IsEmpty()))); +} + +TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithoutDevice)) { + const std::string target_config_path = + tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", + "xla_aot_compile_test_gpu_target_config.prototxt"); + stream_executor::GpuTargetConfigProto target_config; + TF_ASSERT_OK(tsl::ReadTextProto(tsl::Env::Default(), target_config_path, + &target_config)); + EXPECT_THAT(CompileExecutable(std::move(module_), "gpu", std::nullopt), + IsOkAndHolds(Not(IsEmpty()))); +} + +TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(ErrorsOnUnexpectedPlatform)) { + EXPECT_THAT(CompileExecutable(nullptr, "tpu", std::nullopt), + StatusIs(tsl::error::UNIMPLEMENTED)); +} + +TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(WriteResultFilePropagatesErrors)) { + TimerStats stats; + CompilationResult result; + EXPECT_THAT(WriteResultFile("/does/not/exist", stats, result), Not(IsOk())); +} + +TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(WriteResultFileWritesTheFile)) { + std::string result_output_file; + ASSERT_TRUE(tsl::Env::Default()->LocalTempFilename(&result_output_file)); + + TimerStats stats; + { + absl::MutexLock ml(&stats.stats_mutex); + stats.cumulative_secs = 5.5; + stats.max_secs = 5.5; + } + + CompilationResult result; + google::protobuf::Duration duration; + duration.set_seconds(5); + duration.set_nanos(0.5 * tsl::EnvTime::kSecondsToNanos); + 
*result.mutable_perf_stats()->mutable_compilation_duration() = duration; + *result.mutable_perf_stats()->mutable_total_duration() = duration; + + TF_ASSERT_OK(WriteResultFile(result_output_file, stats, result)); + + CompilationResult got_result; + TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_output_file, + &got_result)); + // Sadly EqualsProto isn't OSS, so we inspect a few fields manually. + // See googletest#1761 and b/229726259. + EXPECT_EQ(5, got_result.perf_stats().compilation_duration().seconds()); + EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos, + got_result.perf_stats().compilation_duration().nanos()); + EXPECT_EQ(5, got_result.perf_stats().total_duration().seconds()); + EXPECT_EQ(0.5 * tsl::EnvTime::kSecondsToNanos, + got_result.perf_stats().total_duration().nanos()); +} + +} // namespace +} // namespace xla From 4db8e9e1c5e5cbb6db457935ef5d0616ab94f08b Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 13 Nov 2023 18:15:25 -0800 Subject: [PATCH 047/391] Bump jax test dependency to v0.4.1 PiperOrigin-RevId: 582144297 --- tensorflow/tools/ci_build/release/requirements_mac.txt | 2 +- tensorflow/tools/ci_build/release/requirements_ubuntu.txt | 2 +- .../tools/tf_sig_build_dockerfiles/devel.requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/tools/ci_build/release/requirements_mac.txt b/tensorflow/tools/ci_build/release/requirements_mac.txt index 39349a3f3a6aa2..aa08a8c8db45e3 100644 --- a/tensorflow/tools/ci_build/release/requirements_mac.txt +++ b/tensorflow/tools/ci_build/release/requirements_mac.txt @@ -8,5 +8,5 @@ twine ~= 3.6.0 setuptools # Test dependencies which don't exist on Windows -jax ~= 0.3.24 +jax ~= 0.4.1 jaxlib ~= 0.4.1 diff --git a/tensorflow/tools/ci_build/release/requirements_ubuntu.txt b/tensorflow/tools/ci_build/release/requirements_ubuntu.txt index 8d7122076fcd91..db2e1ee8b47fca 100644 --- a/tensorflow/tools/ci_build/release/requirements_ubuntu.txt +++ b/tensorflow/tools/ci_build/release/requirements_ubuntu.txt @@ -5,5 +5,5 @@ PyYAML ~= 6.0 # Test dependencies which don't exist on Windows -jax ~= 0.3.14 +jax ~= 0.4.1 jaxlib ~= 0.4.1; platform.machine != 'aarch64' diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt b/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt index 871f650caa810b..4a899fb3504e11 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt @@ -42,7 +42,7 @@ scipy ~= 1.7.2; python_version < '3.11' scipy ~= 1.9.2; python_version == '3.11' # Earliest version for Python 3.11 scipy ~= 1.11.3; python_version >= '3.12' # Earliest version for Python 3.12 # Required for TFLite import from JAX tests -jax ~= 0.3.25; python_version <= '3.11' +jax ~= 0.4.1; python_version <= '3.11' jaxlib ~= 0.4.1; python_version <= '3.11' # Earliest version for Python 3.11 # Needs to be addressed. Unblocked 2.4 branchcut cl/338377048 PyYAML ~= 6.0 From f71b8d7ac55b6cfd39229d091131cc6d4e3ae39b Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 13 Nov 2023 20:18:28 -0800 Subject: [PATCH 048/391] Support native serialization version 9 when serializing StableHLO op to XlaCallModuleOp. 
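For illustration, the pass now stamps every XlaCallModuleOp it creates with serialization version 9 and a CPU-only platform list; the snippet below is condensed from the change (op and variable names as they appear in the pass, with TPU platform support left as a TODO):

  // New constants in the pass.
  constexpr int64_t kDefaultVersion = 9;     // native serialization version
  constexpr StringRef kPlatformCpu = "CPU";

  // Generated ops now carry `version = 9` and `platforms = ["CPU"]`.
  auto platforms = ArrayAttr::get(ctx, {StringAttr::get(ctx, kPlatformCpu)});
  auto xla_call_module_op = builder.create<TF::XlaCallModuleOp>(
      module_op.getLoc(), /*output=*/result_types, /*args=*/inputs,
      /*version=*/kDefaultVersion, /*module=*/"",
      /*Sout=*/ArrayAttr::get(ctx, shape_attrs),
      /*dim_args_spec=*/empty_array_attr, platforms,
      /*function_list=*/empty_array_attr,
      /*has_token_input_output=*/false,
      /*disabled_checks=*/empty_array_attr);
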
PiperOrigin-RevId: 582168880 --- ..._main_function_with_xla_call_module_ops.cc | 15 +++++--- ...ain_function_with_xla_call_module_ops.mlir | 36 +++++++++++-------- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc index 5bf8ba7ec07657..829831224a1b35 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "llvm/ADT/STLExtras.h" @@ -42,6 +43,11 @@ namespace { constexpr StringRef kQuantizeTargetOpAttr = "tf_quant.composite_function"; +// Default version number for native serialization. +constexpr int64_t kDefaultVersion = 9; +// Default platform for XlaCallModuleOp. +constexpr StringRef kPlatformCpu = "CPU"; + class ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass : public impl:: ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPassBase< @@ -163,14 +169,15 @@ void CreateXlaCallModuleOp(ArrayRef inputs, ArrayRef outputs, tf_type::ShapeAttr::get(ctx, result_type.cast())); } auto empty_array_attr = ArrayAttr::get(ctx, {}); - // TODO - b/303363466: Allow XlaCallModuleOp with versions >5. + // TODO - b/310291615: Support platforms = ["TPU"]. + auto platforms = ArrayAttr::get(ctx, {StringAttr::get(ctx, kPlatformCpu)}); + auto xla_call_module_op = builder.create( module_op.getLoc(), /*output=*/result_types, /*args=*/inputs, - /*version=*/5, /*module=*/"", + /*version=*/kDefaultVersion, /*module=*/"", /*Sout=*/ArrayAttr::get(ctx, shape_attrs), - /*dim_args_spec=*/empty_array_attr, - /*platforms=*/empty_array_attr, + /*dim_args_spec=*/empty_array_attr, platforms, /*function_list=*/empty_array_attr, /*has_token_input_output=*/false, /*disabled_checks=*/empty_array_attr); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir index 3d04c72dec7f7e..60454fcde06389 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir @@ -1,6 +1,9 @@ -// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-replace-stablehlo-ops-in-main-function-with-xla-call-module-ops | FileCheck %s +// RUN: stablehlo-quant-opt %s -split-input-file \ +// RUN: -stablehlo-replace-stablehlo-ops-in-main-function-with-xla-call-module-ops \ +// RUN: | FileCheck %s -// Modules with "main" or "serving_default" should properly run this pass and convert subgraphs into XLACallModuleOp. +// Modules with "main" or "serving_default" should properly run this pass and +// convert subgraphs into XLACallModuleOp. 
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { @@ -20,23 +23,23 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> %1 = stablehlo.constant dense<1.000000e+03> : tensor<1x3xf32> %2 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - %3 = "tf.XlaCallModule"(%2, %0, %1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %3 = "tf.XlaCallModule"(%2, %0, %1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> %4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> %5 = stablehlo.constant dense<1.000000e+03> : tensor<3x64xf32> %6 = stablehlo.constant dense<1.000000e+03> : tensor<1x64xf32> %7 = "tf.CustomAggregator"(%4) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> - %8 = "tf.XlaCallModule"(%7, %5, %6) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x3xf32>, tensor<3x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> + %8 = "tf.XlaCallModule"(%7, %5, %6) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x3xf32>, tensor<3x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> %9 = "tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x64xf32>) -> tensor<1x64xf32> return %9 : tensor<1x64xf32> } // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}}> {_entry_function = @_stablehlo_main_1 // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 
0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - // CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable"} // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_0]]) // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}}> {_entry_function = @_stablehlo_main_0 // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]] = "tf.CustomAggregator"(%[[CUSTOM_AGGREGATOR_1]]) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> - // CHECK: %[[XLA_CALL_MODULE_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1" + // CHECK: %[[XLA_CALL_MODULE_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _tfl_quant_trait = "fully_quantizable"} // CHECK: %[[CUSTOM_AGGREGATOR_3:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_1:.*]]) // CHECK: return %[[CUSTOM_AGGREGATOR_3]] : tensor<1x64xf32> // CHECK: } @@ -63,6 +66,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // ----- +// Tests that the subgraph in serving_default excluding the tf.Identity is +// converted to a single XlaCallModuleOp. 
+ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1654 : i32}, tf_saved_model.semantics} { // CHECK: func private @_stablehlo_main_0 @@ -85,8 +91,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p return %5 : tensor<1x1024xf32> } - // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"(%arg0) <{Sout = [#tf_type.shape<1x1024>] - // CHECK-SAME: _entry_function = @_stablehlo_main_0 + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"(%arg0) <{Sout = [#tf_type.shape<1x1024>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP]]) // CHECK: return %[[IDENTITY]] // CHECK } @@ -95,8 +100,10 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // ----- +// Tests that the first stablehlo.constant is converted to XlaCallModuleOp. + module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { - // CHECK: func private @_stablehlo_main_ + // CHECK: func private @_stablehlo_main_0 // CHECK: %[[CONSTANT:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> // CHECK: return %[[CONSTANT:.*]] // CHECK: } @@ -105,12 +112,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p func.func @serving_default(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> %1 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - %2 = "tf.XlaCallModule"(%1, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %2 = "tf.XlaCallModule"(%1, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> %3 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> return %3 : tensor<1x3xf32> } - // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_ + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}, module = "", platforms 
= ["CPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0} // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) @@ -127,7 +134,8 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // ----- -// Tests to confirm that the StableHLO graph is not replaced if "main" or "serving_default" function is in the module. +// Tests to confirm that the StableHLO graph is not replaced if "main" or +// "serving_default" function is not in the module. module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { // CHECK-NOT: func private @_stablehlo_main_ @@ -136,14 +144,14 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p func.func @random_name(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> %1 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - %2 = "tf.XlaCallModule"(%1, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %2 = "tf.XlaCallModule"(%1, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> %3 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32> return %3 : tensor<1x3xf32> } // CHECK: %[[CONSTANT:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> - // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[XLA_CALL_MODULE_EXTRACTED_FROM_SUBGRAPH:.*]]) <{Sout = [#tf_type.shape<1x3>], 
{{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[XLA_CALL_MODULE_EXTRACTED_FROM_SUBGRAPH:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) // CHECK: return %[[CUSTOM_AGGREGATOR_1]] // CHECK: } From 41c7bbfe14b7f312fbc916655293fa724dd59906 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 13 Nov 2023 20:50:09 -0800 Subject: [PATCH 049/391] [xla:gpu] NFC: Remove nested kernel namespace from custom kernels/fusions CustomKernel and CustomFusion are already unique enough, no need to put them into a unique namespace. PiperOrigin-RevId: 582174192 --- .../xla/service/gpu/custom_fusion_rewriter.cc | 3 --- .../xla/service/gpu/custom_fusion_rewriter.h | 5 ++--- .../gpu/custom_fusion_rewriter_test.cc | 4 ++-- .../xla/service/gpu/ir_emitter_unnested.cc | 6 +++--- .../xla/xla/service/gpu/kernel_thunk.cc | 2 +- .../xla/xla/service/gpu/kernel_thunk.h | 5 ++--- .../xla/service/gpu/kernels/custom_fusion.cc | 4 ++-- .../xla/service/gpu/kernels/custom_fusion.h | 20 +++++++++---------- .../gpu/kernels/custom_fusion_pattern.cc | 4 ++-- .../gpu/kernels/custom_fusion_pattern.h | 4 ++-- .../xla/service/gpu/kernels/custom_kernel.cc | 4 ++-- .../xla/service/gpu/kernels/custom_kernel.h | 4 ++-- .../gpu/kernels/cutlass_gemm_fusion.cc | 10 +++++----- .../gpu/kernels/cutlass_gemm_fusion_test.cc | 4 ++-- 14 files changed, 37 insertions(+), 42 deletions(-) diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc index 382cb50d410ce6..4c80f7636a06c8 100644 --- a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc @@ -33,9 +33,6 @@ limitations under the License. 
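// Sketch, not part of the patch: at a call site this rename only drops the nested
// namespace (the fusion name here is illustrative):
//   before: xla::gpu::kernel::CustomFusionRegistry::Default()->Lookup("cutlass_gemm");
//   after:  xla::gpu::CustomFusionRegistry::Default()->Lookup("cutlass_gemm");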
namespace xla::gpu { -using xla::gpu::kernel::CustomFusionPattern; -using xla::gpu::kernel::CustomFusionPatternRegistry; - CustomFusionRewriter::CustomFusionRewriter( const CustomFusionPatternRegistry* patterns) : patterns_(patterns) {} diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h index 2db45ea0c8558f..2d7ce59207e3c0 100644 --- a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h @@ -61,8 +61,7 @@ namespace xla::gpu { // class CustomFusionRewriter : public HloModulePass { public: - explicit CustomFusionRewriter( - const kernel::CustomFusionPatternRegistry* patterns); + explicit CustomFusionRewriter(const CustomFusionPatternRegistry* patterns); absl::string_view name() const override { return "custom-fusion-rewriter"; } @@ -71,7 +70,7 @@ class CustomFusionRewriter : public HloModulePass { const absl::flat_hash_set& execution_threads) override; private: - const kernel::CustomFusionPatternRegistry* patterns_; + const CustomFusionPatternRegistry* patterns_; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc index 12395d0316bfb1..b2a8c8656c590e 100644 --- a/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc @@ -31,7 +31,7 @@ namespace xla::gpu { // Simple pattern matchers for testing custom fusion rewriter. //===----------------------------------------------------------------------===// -class SimpleGemmPattern : public kernel::CustomFusionPattern { +class SimpleGemmPattern : public CustomFusionPattern { public: std::optional TryMatch(HloInstruction* instr) const override { if (auto* dot = DynCast(instr)) { @@ -77,7 +77,7 @@ TEST_F(CustomFusionRewriterTest, SimpleGemm) { ; CHECK: } )"; - kernel::CustomFusionPatternRegistry patterns; + CustomFusionPatternRegistry patterns; patterns.Emplace(); CustomFusionRewriter pass(&patterns); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index d1cfa6012679f5..7f4cf6dff9728b 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -322,7 +322,7 @@ StatusOr> BuildKernelThunkForFusion( StatusOr> BuildCustomKernelThunkForFusion( IrEmitterContext& ir_emitter_context, const HloFusionInstruction* fusion, - kernel::CustomKernel custom_kernel) { + CustomKernel custom_kernel) { TF_ASSIGN_OR_RETURN( auto kernel_arguments, KernelArguments::Create(ir_emitter_context.buffer_assignment(), fusion)); @@ -3252,7 +3252,7 @@ StatusOr IrEmitterUnnested::EmitCustomFusion( const HloFusionInstruction* fusion, const CustomFusionConfig& config) { VLOG(3) << "Lower HLO fusion to a custom fusion " << config.name(); - auto* registry = kernel::CustomFusionRegistry::Default(); + auto* registry = CustomFusionRegistry::Default(); auto* custom_fusion = registry->Lookup(config.name()); // If custom fusion is not found it means that some of the build targets might @@ -3264,7 +3264,7 @@ StatusOr IrEmitterUnnested::EmitCustomFusion( // Load custom kernels that can implement a fusion computation. 
TF_ASSIGN_OR_RETURN( - std::vector kernels, + std::vector kernels, custom_fusion->LoadKernels(fusion->fused_instructions_computation())); // This should never happen, it means that compilation pipeline created a diff --git a/third_party/xla/xla/service/gpu/kernel_thunk.cc b/third_party/xla/xla/service/gpu/kernel_thunk.cc index 1dfb672ef65c8e..c5fcc33231607c 100644 --- a/third_party/xla/xla/service/gpu/kernel_thunk.cc +++ b/third_party/xla/xla/service/gpu/kernel_thunk.cc @@ -173,7 +173,7 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { //===----------------------------------------------------------------------===// CustomKernelThunk::CustomKernelThunk( - const HloInstruction* instr, kernel::CustomKernel custom_kernel, + const HloInstruction* instr, CustomKernel custom_kernel, absl::Span kernel_arguments) : Thunk(Kind::kKernel, Thunk::ThunkInfo::WithProfileAnnotation(instr)), custom_kernel_(std::move(custom_kernel)) { diff --git a/third_party/xla/xla/service/gpu/kernel_thunk.h b/third_party/xla/xla/service/gpu/kernel_thunk.h index be458cbcc494ea..ef857b52f9ccff 100644 --- a/third_party/xla/xla/service/gpu/kernel_thunk.h +++ b/third_party/xla/xla/service/gpu/kernel_thunk.h @@ -136,8 +136,7 @@ class KernelThunk : public Thunk { // compiled by XLA and loaded from an executable source. class CustomKernelThunk : public Thunk { public: - CustomKernelThunk(const HloInstruction* instr, - kernel::CustomKernel custom_kernel, + CustomKernelThunk(const HloInstruction* instr, CustomKernel custom_kernel, absl::Span kernel_arguments); std::string ToStringExtra(int indent) const override; @@ -156,7 +155,7 @@ class CustomKernelThunk : public Thunk { // mlir::Value(s) corresponding to the buffer slice arguments. std::vector values_; - kernel::CustomKernel custom_kernel_; + CustomKernel custom_kernel_; // Loaded kernels for each `StreamExecutor`. mutable absl::Mutex mutex_; diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion.cc b/third_party/xla/xla/service/gpu/kernels/custom_fusion.cc index 02e7b60c6dbe10..c35ecfa833fe38 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion.cc @@ -26,7 +26,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/status.h" -namespace xla::gpu::kernel { +namespace xla::gpu { //===----------------------------------------------------------------------===// // CustomFusionRegistry @@ -53,4 +53,4 @@ CustomFusion* CustomFusionRegistry::Lookup(std::string_view name) const { return nullptr; } -} // namespace xla::gpu::kernel +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion.h b/third_party/xla/xla/service/gpu/kernels/custom_fusion.h index 9110e20ddf5e2f..9311e43e630402 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_fusion.h +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion.h @@ -32,7 +32,7 @@ limitations under the License. 
#include "xla/statusor.h" #include "tsl/platform/logging.h" -namespace xla::gpu::kernel { +namespace xla::gpu { //===----------------------------------------------------------------------===// // CustomFusion @@ -129,7 +129,7 @@ class CustomFusionRegistry { ABSL_GUARDED_BY(mutex_); }; -} // namespace xla::gpu::kernel +} // namespace xla::gpu #define XLA_REGISTER_CUSTOM_FUSION(NAME, FUSION) \ XLA_REGISTER_CUSTOM_FUSION_(NAME, FUSION, __COUNTER__) @@ -137,14 +137,14 @@ class CustomFusionRegistry { #define XLA_REGISTER_CUSTOM_FUSION_(NAME, FUSION, N) \ XLA_REGISTER_CUSTOM_FUSION__(NAME, FUSION, N) -#define XLA_REGISTER_CUSTOM_FUSION__(NAME, FUSION, N) \ - ABSL_ATTRIBUTE_UNUSED static const bool \ - xla_custom_fusion_##N##_registered_ = [] { \ - ::xla::Status status = \ - ::xla::gpu::kernel::CustomFusionRegistry::Default()->Register( \ - NAME, std::make_unique()); \ - if (!status.ok()) LOG(ERROR) << status; \ - return status.ok(); \ +#define XLA_REGISTER_CUSTOM_FUSION__(NAME, FUSION, N) \ + ABSL_ATTRIBUTE_UNUSED static const bool \ + xla_custom_fusion_##N##_registered_ = [] { \ + ::xla::Status status = \ + ::xla::gpu::CustomFusionRegistry::Default()->Register( \ + NAME, std::make_unique()); \ + if (!status.ok()) LOG(ERROR) << status; \ + return status.ok(); \ }() #endif // XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc index 304d58045d993d..99f3fcb6274870 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc @@ -21,7 +21,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" -namespace xla::gpu::kernel { +namespace xla::gpu { std::vector CustomFusionPatternRegistry::Match( HloInstruction* instr) const { @@ -38,4 +38,4 @@ void CustomFusionPatternRegistry::Add( patterns_.push_back(std::move(pattern)); } -} // namespace xla::gpu::kernel +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h index 43b1428b0ce4cd..f040d647490add 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h @@ -24,7 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/backend_configs.pb.h" -namespace xla::gpu::kernel { +namespace xla::gpu { //===----------------------------------------------------------------------===// // CustomFusionPattern @@ -69,6 +69,6 @@ class CustomFusionPatternRegistry { std::vector> patterns_; }; -} // namespace xla::gpu::kernel +} // namespace xla::gpu #endif // XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_PATTERN_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc b/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc index 192de82b180e72..0bc2de781bbe3b 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc +++ b/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" -namespace xla::gpu::kernel { +namespace xla::gpu { CustomKernel::CustomKernel(se::MultiKernelLoaderSpec kernel_spec, se::BlockDim block_dims, se::ThreadDim thread_dims, @@ -44,4 +44,4 @@ size_t CustomKernel::shared_memory_bytes() const { return shared_memory_bytes_; } -} // namespace xla::gpu::kernel +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/kernels/custom_kernel.h b/third_party/xla/xla/service/gpu/kernels/custom_kernel.h index 1767c82ac5d376..85b1a2a9e639c3 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_kernel.h +++ b/third_party/xla/xla/service/gpu/kernels/custom_kernel.h @@ -25,7 +25,7 @@ limitations under the License. // it's included into all device kernel implementations, and we want to minimize // the number of (very expensive!) recompilations. -namespace xla::gpu::kernel { +namespace xla::gpu { namespace se = ::stream_executor; // NOLINT // Custom kernel is a mechanism for plugging pre-compiled device kernels into @@ -60,6 +60,6 @@ class CustomKernel { size_t shared_memory_bytes_; }; -} // namespace xla::gpu::kernel +} // namespace xla::gpu #endif // XLA_SERVICE_GPU_KERNELS_CUSTOM_KERNEL_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index ce8e43208f1a09..6fb46343652591 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -28,7 +28,7 @@ limitations under the License. #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" -namespace xla::gpu::kernel { +namespace xla::gpu { class CutlassGemmFusion : public CustomFusion { public: @@ -53,12 +53,12 @@ class CutlassGemmFusion : public CustomFusion { size_t k = lhs_shape.dimensions(1); size_t n = rhs_shape.dimensions(1); - TF_ASSIGN_OR_RETURN(auto kernel, GetCutlassGemmKernel(dtype, m, n, k)); + TF_ASSIGN_OR_RETURN(auto kernel, + kernel::GetCutlassGemmKernel(dtype, m, n, k)); return std::vector{std::move(kernel)}; } }; -} // namespace xla::gpu::kernel +} // namespace xla::gpu -XLA_REGISTER_CUSTOM_FUSION("cutlass_gemm", - ::xla::gpu::kernel::CutlassGemmFusion); +XLA_REGISTER_CUSTOM_FUSION("cutlass_gemm", ::xla::gpu::CutlassGemmFusion); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index afb5ac568983a8..e9f23f5604aa1b 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" -namespace xla::gpu::kernel { +namespace xla::gpu { class CutlassFusionTest : public HloTestBase { // Custom fusions are not supported by XLA runtime. 
@@ -65,4 +65,4 @@ TEST_F(CutlassFusionTest, SimpleF32Gemm) { error_spec, /*run_hlo_passes=*/false)); } -} // namespace xla::gpu::kernel +} // namespace xla::gpu From f01d901e43834b6f81107ec573d0372dd1bd42a2 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 13 Nov 2023 21:08:44 -0800 Subject: [PATCH 050/391] Enable macOS Arm64 nightly builds PiperOrigin-RevId: 582177793 --- .bazelrc | 36 ++++++++---- ci/official/envs/ci_default | 1 + ci/official/envs/continuous_macos_arm64_py310 | 2 +- ci/official/envs/continuous_macos_arm64_py311 | 2 +- ci/official/envs/continuous_macos_arm64_py39 | 2 +- ci/official/envs/nightly_macos_arm64_py310 | 5 +- ci/official/envs/nightly_macos_arm64_py311 | 5 +- ci/official/envs/nightly_macos_arm64_py312 | 5 +- ci/official/envs/nightly_macos_arm64_py39 | 7 ++- .../utilities/rename_and_verify_wheels.sh | 2 +- ci/official/utilities/setup.sh | 5 ++ ci/official/utilities/setup_macos.sh | 57 +++++++++++++++++++ third_party/xla/.bazelrc | 36 ++++++++---- third_party/xla/third_party/tsl/.bazelrc | 36 ++++++++---- 14 files changed, 153 insertions(+), 48 deletions(-) create mode 100644 ci/official/utilities/setup_macos.sh diff --git a/.bazelrc b/.bazelrc index 99147608138fd6..cf45e783817917 100644 --- a/.bazelrc +++ b/.bazelrc @@ -692,19 +692,27 @@ build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda build:release_cpu_macos --config=avx_linux test:release_cpu_macos --config=release_base -# Build configs for macOS ARM CPUs +# Base build configs for macOS +build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer +build:release_macos_base --define=no_nccl_support=true --output_filter=^$ + +# Build configs for macOS Arm64 +build:release_macos_arm64 --config=release_macos_base build:release_macos_arm64 --cpu=darwin_arm64 -# Set DEVELOPER_DIR to select a version of Xcode. -build:release_macos_arm64 --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer -build:release_macos_arm64 --define=no_nccl_support=true -# Suppress all warning messages -build:release_macos_arm64 --output_filter=^$ -# Disable MKL build:release_macos_arm64 --define=tensorflow_mkldnn_contraction_kernel=0 # Target Moneterey as the minimum compatible OS version build:release_macos_arm64 --macos_minimum_os=12.0 build:release_macos_arm64 --action_env MACOSX_DEPLOYMENT_TARGET=12.0 +# Base test configs for macOS +test:release_macos_base --verbose_failures=true --local_test_jobs=HOST_CPUS +test:release_macos_base --test_timeout=300,450,1200,3600 --test_output=errors +test:release_macos_base --build_tests_only --keep_going +test:release_macos_base --flaky_test_attempts=3 + +# Test configs for macOS Arm64 +test:release_macos_arm64 --config=release_macos_base + # TODO(kanglan): Update windows configs after b/289091160 is fixed build:release_cpu_windows --config=avx_win build:release_cpu_windows --define=no_tensorflow_py_deps=true @@ -723,10 +731,14 @@ build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compil # Use --config=tf_public_cache to try and use the TensorFlow public build cache # to build TensorFlow. Look at ci/official/envs to find which types of jobs -# push to the cache. +# push to the cache. For macOS, use --config=tf_public_macos_cache build:tf_public_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false # Cache pushes are limited to TF's CI system. 
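# Example only, not introduced by this patch: the macOS Arm64 configs added
# above compose like any other named config; the test target is illustrative.
#   bazel test --config=release_macos_arm64 //tensorflow/python/framework:ops_test
# CI jobs additionally layer the macOS remote cache config defined just below.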
build:tf_public_cache_push --config=tf_public_cache --remote_upload_local_results=true --google_default_credentials +# Public cache for macOS builds +build:tf_public_macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false +# Cache pushes are limited to TF's CI system. +build:tf_public_macos_cache_push --config=tf_public_macos_cache --remote_upload_local_results=true --google_default_credentials # END TF CACHE HELPER OPTIONS # BEGIN TF TEST SUITE OPTIONS @@ -757,8 +769,8 @@ test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorf # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_wheel_test_filters --test_lang_filters=py -test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... -//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test +test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. @@ -781,6 +793,6 @@ test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorf # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py -test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... 
-//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test +test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium +test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # END TF TEST SUITE OPTIONS diff --git a/ci/official/envs/ci_default b/ci/official/envs/ci_default index d73522d431a058..890069df93c982 100644 --- a/ci/official/envs/ci_default +++ b/ci/official/envs/ci_default @@ -30,3 +30,4 @@ TFCI_WHL_AUDIT_PLAT= TFCI_WHL_BAZEL_TEST_ENABLE=1 TFCI_WHL_SIZE_LIMIT= TFCI_WHL_SIZE_LIMIT_ENABLE=1 +TFCI_PYENV_INSTALL_LOCAL_ENABLE= diff --git a/ci/official/envs/continuous_macos_arm64_py310 b/ci/official/envs/continuous_macos_arm64_py310 index 4a0b98f6841b8a..42b304729bc113 100644 --- a/ci/official/envs/continuous_macos_arm64_py310 +++ b/ci/official/envs/continuous_macos_arm64_py310 @@ -1,4 +1,4 @@ -TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 TFCI_DOCKER_ENABLE=0 TFCI_PYTHON_VERSION=3.10 diff --git a/ci/official/envs/continuous_macos_arm64_py311 b/ci/official/envs/continuous_macos_arm64_py311 index 8b94b11c71e1ea..8a47d7ef15c863 100644 --- a/ci/official/envs/continuous_macos_arm64_py311 +++ b/ci/official/envs/continuous_macos_arm64_py311 @@ -1,4 +1,4 @@ -TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 TFCI_DOCKER_ENABLE=0 TFCI_PYTHON_VERSION=3.11 diff --git a/ci/official/envs/continuous_macos_arm64_py39 b/ci/official/envs/continuous_macos_arm64_py39 index 5799e1ca97fd00..da892742753048 100644 --- a/ci/official/envs/continuous_macos_arm64_py39 +++ b/ci/official/envs/continuous_macos_arm64_py39 @@ -1,4 +1,4 @@ -TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 TFCI_DOCKER_ENABLE=0 TFCI_PYTHON_VERSION=3.9 diff --git a/ci/official/envs/nightly_macos_arm64_py310 b/ci/official/envs/nightly_macos_arm64_py310 index 9ceb6d869ff60f..38ae51e9e2f587 100644 --- a/ci/official/envs/nightly_macos_arm64_py310 +++ b/ci/official/envs/nightly_macos_arm64_py310 @@ -1,10 +1,11 @@ source ci/official/envs/disable_all_uploads -TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION' +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu 
--nightly_flag' +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.10 TFCI_UPLOAD_WHL_GCS_ENABLE=1 TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M +TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION diff --git a/ci/official/envs/nightly_macos_arm64_py311 b/ci/official/envs/nightly_macos_arm64_py311 index e294cd3c76bd02..218b292dd41f3f 100644 --- a/ci/official/envs/nightly_macos_arm64_py311 +++ b/ci/official/envs/nightly_macos_arm64_py311 @@ -1,10 +1,11 @@ source ci/official/envs/disable_all_uploads -TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION' +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag' +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.11 TFCI_UPLOAD_WHL_GCS_ENABLE=1 TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M +TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION diff --git a/ci/official/envs/nightly_macos_arm64_py312 b/ci/official/envs/nightly_macos_arm64_py312 index cca86340d3ceff..7d89c9f31118d4 100644 --- a/ci/official/envs/nightly_macos_arm64_py312 +++ b/ci/official/envs/nightly_macos_arm64_py312 @@ -1,10 +1,11 @@ source ci/official/envs/disable_all_uploads -TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION' +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag' +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.12 TFCI_UPLOAD_WHL_GCS_ENABLE=1 TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M +TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION diff --git a/ci/official/envs/nightly_macos_arm64_py39 b/ci/official/envs/nightly_macos_arm64_py39 index 35c0adf0155a32..4a3c24353b69c1 100644 --- a/ci/official/envs/nightly_macos_arm64_py39 +++ b/ci/official/envs/nightly_macos_arm64_py39 @@ -1,10 +1,13 @@ source ci/official/envs/disable_all_uploads -TFCI_BAZEL_COMMON_ARGS=(--config release_macos_arm64 --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION' +# TODO(srnitin): Add resultstore config once the macOS builds have the right +# permissions +TFCI_BAZEL_COMMON_ARGS="--config release_macos_arm64 --config tf_public_macos_cache_push --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 -TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag' +TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.9 TFCI_UPLOAD_WHL_GCS_ENABLE=1 TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M +TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION diff --git a/ci/official/utilities/rename_and_verify_wheels.sh b/ci/official/utilities/rename_and_verify_wheels.sh index 067f264a56ec60..8bd55c8f6aaf23 100755 --- a/ci/official/utilities/rename_and_verify_wheels.sh +++ 
b/ci/official/utilities/rename_and_verify_wheels.sh @@ -23,7 +23,7 @@ set -euxo pipefail cd "$TFCI_OUTPUT_DIR" -if [[ "$(ls *.whl | wc -l)" != "1" ]]; then +if [[ "$(ls *.whl | wc -l | tr -d ' ')" != "1" ]]; then echo "Error: $TFCI_OUTPUT_DIR should contain exactly one .whl file." exit 1 fi diff --git a/ci/official/utilities/setup.sh b/ci/official/utilities/setup.sh index 330d69f4681145..8c004aee8b8141 100755 --- a/ci/official/utilities/setup.sh +++ b/ci/official/utilities/setup.sh @@ -82,6 +82,11 @@ else fi fi +# Mac builds have some specific setup needs. See setup_macos.sh for details +if [[ "${OSTYPE}" =~ darwin* ]]; then + source ./ci/official/utilities/setup_macos.sh +fi + # Force-disable uploads if the job initiator is not Kokoro # This is temporary: it's currently standard practice for employees to # run nightly jobs for testing purposes. We're aiming to move away from diff --git a/ci/official/utilities/setup_macos.sh b/ci/official/utilities/setup_macos.sh new file mode 100644 index 00000000000000..ab76d22cfaeece --- /dev/null +++ b/ci/official/utilities/setup_macos.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# macOS specific setup for all TF scripts. +# + +# Mac version of Core utilities differ in usage. Since our scripts are written +# with the GNU style, we need to set GNU utilities to be default on Mac. +if [[ -n "$(which grealpath)" ]] && [[ -n "$(which gstat)" ]]; then + alias realpath=grealpath + alias stat=gstat + # By default, aliases are only expanded in interactive shells, which means + # that they are not substituted for their corresponding commands in shell + # scripts. By setting "expand_aliases", we enable alias expansion in + # non-interactive shells as well. + shopt -s expand_aliases +else + echo '==TFCI==: Error: Cannot find path to grealpath or gstat' + echo 'TF CI scripts require GNU core utilties to be installed. Please make' + echo 'sure they are present on your system and try again.' + exit 1 +fi + +if [[ -n "${KOKORO_JOB_NAME}" ]]; then + # Mac builds need ~150 GB of disk space to be able to run all the tests. By + # default, Kokoro runs the Bazel commands in a partition that does not have + # enough free space so we need to set TEST_TMPDIR explicitly. + mkdir -p /Volumes/BuildData/bazel_output + export TEST_TMPDIR=/Volumes/BuildData/bazel_output + + # Before uploading the nightly and release wheels, we install them in a + # virtual environment and run some smoke tests on it. The Kokoro Mac VMs + # only have Python 3.11 installed so we need to install the other Python + # versions manually. 
+ if [[ -n "${TFCI_BUILD_PIP_PACKAGE_ARGS}" ]] && [[ "${TFCI_PYENV_INSTALL_LOCAL_ENABLE}" != 3.11 ]]; then + pyenv install "${TFCI_PYENV_INSTALL_LOCAL_ENABLE}" + pyenv local "${TFCI_PYENV_INSTALL_LOCAL_ENABLE}" + fi +elif [[ "${TFCI_WHL_BAZEL_TEST_ENABLE}" == 1 ]]; then + echo '==TFCI==: Note: Mac builds need ~150 GB of disk space to be able to' + echo 'run all the tests. Please make sure your system has enough disk space' + echo 'You can control where Bazel stores test artifacts by setting the' + echo '`TEST_TMPDIR` environment variable.' +fi \ No newline at end of file diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index 99147608138fd6..cf45e783817917 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -692,19 +692,27 @@ build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda build:release_cpu_macos --config=avx_linux test:release_cpu_macos --config=release_base -# Build configs for macOS ARM CPUs +# Base build configs for macOS +build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer +build:release_macos_base --define=no_nccl_support=true --output_filter=^$ + +# Build configs for macOS Arm64 +build:release_macos_arm64 --config=release_macos_base build:release_macos_arm64 --cpu=darwin_arm64 -# Set DEVELOPER_DIR to select a version of Xcode. -build:release_macos_arm64 --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer -build:release_macos_arm64 --define=no_nccl_support=true -# Suppress all warning messages -build:release_macos_arm64 --output_filter=^$ -# Disable MKL build:release_macos_arm64 --define=tensorflow_mkldnn_contraction_kernel=0 # Target Moneterey as the minimum compatible OS version build:release_macos_arm64 --macos_minimum_os=12.0 build:release_macos_arm64 --action_env MACOSX_DEPLOYMENT_TARGET=12.0 +# Base test configs for macOS +test:release_macos_base --verbose_failures=true --local_test_jobs=HOST_CPUS +test:release_macos_base --test_timeout=300,450,1200,3600 --test_output=errors +test:release_macos_base --build_tests_only --keep_going +test:release_macos_base --flaky_test_attempts=3 + +# Test configs for macOS Arm64 +test:release_macos_arm64 --config=release_macos_base + # TODO(kanglan): Update windows configs after b/289091160 is fixed build:release_cpu_windows --config=avx_win build:release_cpu_windows --define=no_tensorflow_py_deps=true @@ -723,10 +731,14 @@ build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compil # Use --config=tf_public_cache to try and use the TensorFlow public build cache # to build TensorFlow. Look at ci/official/envs to find which types of jobs -# push to the cache. +# push to the cache. For macOS, use --config=tf_public_macos_cache build:tf_public_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false # Cache pushes are limited to TF's CI system. build:tf_public_cache_push --config=tf_public_cache --remote_upload_local_results=true --google_default_credentials +# Public cache for macOS builds +build:tf_public_macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false +# Cache pushes are limited to TF's CI system. 
+build:tf_public_macos_cache_push --config=tf_public_macos_cache --remote_upload_local_results=true --google_default_credentials # END TF CACHE HELPER OPTIONS # BEGIN TF TEST SUITE OPTIONS @@ -757,8 +769,8 @@ test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorf # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_wheel_test_filters --test_lang_filters=py -test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... -//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test +test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. @@ -781,6 +793,6 @@ test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorf # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py -test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... -//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test +test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium +test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... 
-//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # END TF TEST SUITE OPTIONS diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index 99147608138fd6..cf45e783817917 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -692,19 +692,27 @@ build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda build:release_cpu_macos --config=avx_linux test:release_cpu_macos --config=release_base -# Build configs for macOS ARM CPUs +# Base build configs for macOS +build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer +build:release_macos_base --define=no_nccl_support=true --output_filter=^$ + +# Build configs for macOS Arm64 +build:release_macos_arm64 --config=release_macos_base build:release_macos_arm64 --cpu=darwin_arm64 -# Set DEVELOPER_DIR to select a version of Xcode. -build:release_macos_arm64 --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer -build:release_macos_arm64 --define=no_nccl_support=true -# Suppress all warning messages -build:release_macos_arm64 --output_filter=^$ -# Disable MKL build:release_macos_arm64 --define=tensorflow_mkldnn_contraction_kernel=0 # Target Moneterey as the minimum compatible OS version build:release_macos_arm64 --macos_minimum_os=12.0 build:release_macos_arm64 --action_env MACOSX_DEPLOYMENT_TARGET=12.0 +# Base test configs for macOS +test:release_macos_base --verbose_failures=true --local_test_jobs=HOST_CPUS +test:release_macos_base --test_timeout=300,450,1200,3600 --test_output=errors +test:release_macos_base --build_tests_only --keep_going +test:release_macos_base --flaky_test_attempts=3 + +# Test configs for macOS Arm64 +test:release_macos_arm64 --config=release_macos_base + # TODO(kanglan): Update windows configs after b/289091160 is fixed build:release_cpu_windows --config=avx_win build:release_cpu_windows --define=no_tensorflow_py_deps=true @@ -723,10 +731,14 @@ build:no_tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compil # Use --config=tf_public_cache to try and use the TensorFlow public build cache # to build TensorFlow. Look at ci/official/envs to find which types of jobs -# push to the cache. +# push to the cache. For macOS, use --config=tf_public_macos_cache build:tf_public_cache --remote_cache="https://storage.googleapis.com/tensorflow-devinfra-bazel-cache/september2022" --remote_upload_local_results=false # Cache pushes are limited to TF's CI system. build:tf_public_cache_push --config=tf_public_cache --remote_upload_local_results=true --google_default_credentials +# Public cache for macOS builds +build:tf_public_macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false +# Cache pushes are limited to TF's CI system. 
+build:tf_public_macos_cache_push --config=tf_public_macos_cache --remote_upload_local_results=true --google_default_credentials # END TF CACHE HELPER OPTIONS # BEGIN TF TEST SUITE OPTIONS @@ -757,8 +769,8 @@ test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorf # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_wheel_test_filters --test_lang_filters=py -test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... -//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test +test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. @@ -781,6 +793,6 @@ test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorf # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py -test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... -//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test +test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium +test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... 
-//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # END TF TEST SUITE OPTIONS From 6ae0bb40a48e4e793689966a4baff1791a30ee8e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Nov 2023 21:22:38 -0800 Subject: [PATCH 051/391] Internal Code Change PiperOrigin-RevId: 582180328 --- tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD index e72c58bdd6b846..4b2b0576430bd1 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD @@ -70,7 +70,6 @@ td_library( ], includes = ["."], visibility = [ - # copybara:uncomment "//learning/brain/tfrt/mlir:__subpackages__", # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", ], deps = [ From 6ccd6e3ca7c15432f3657e1d253ce5a71e3e7a88 Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Mon, 13 Nov 2023 22:53:48 -0800 Subject: [PATCH 052/391] Do not create new buffer during copying an empty FallbackTensor. The ImmutableTensorBuffer cannot handle empty root buffer properly and in that case, we don't need to create a new buffer as there is no atomic operations involved in empty tensors. PiperOrigin-RevId: 582196733 --- tensorflow/core/tfrt/utils/BUILD | 1 + tensorflow/core/tfrt/utils/fallback_tensor.h | 5 +++-- tensorflow/core/tfrt/utils/fallback_tensor_test.cc | 10 ++++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/tfrt/utils/BUILD b/tensorflow/core/tfrt/utils/BUILD index f9879e84d4562d..7517e07928f87b 100644 --- a/tensorflow/core/tfrt/utils/BUILD +++ b/tensorflow/core/tfrt/utils/BUILD @@ -270,6 +270,7 @@ tf_cc_test( deps = [ ":fallback_tensor", "//tensorflow/core/common_runtime:dma_helper", + "//tensorflow/core/framework:tensor_shape", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/core/tfrt/utils/fallback_tensor.h b/tensorflow/core/tfrt/utils/fallback_tensor.h index 393e06e75f4a90..0856117d2b7a09 100644 --- a/tensorflow/core/tfrt/utils/fallback_tensor.h +++ b/tensorflow/core/tfrt/utils/fallback_tensor.h @@ -64,7 +64,7 @@ class FallbackTensor { FallbackTensor(const FallbackTensor& other) { *this = other; } FallbackTensor& operator=(const FallbackTensor& other) { tsl::profiler::TraceMe trace_me("FallbackTensor::Copy"); - if (!other.is_immutable()) { + if (!other.is_immutable() && other.buffer() != nullptr) { // Create a new TensorBuffer which contains a new atomic counter for each // result, to avoid downstream threads contending the original atomic // counter. @@ -72,7 +72,8 @@ class FallbackTensor { tensorflow::tfrt_stub::ImmutableTensor::Create(other.tensor()) .tensor()); } else { - // For immutable tensors, we just need to copy the pointer. + // For immutable tensors or empty tensors, we just need to copy the + // pointer as they don't incur atomic operations when they are referenced. tensor_ = other.tensor(); } is_immutable_ = true; diff --git a/tensorflow/core/tfrt/utils/fallback_tensor_test.cc b/tensorflow/core/tfrt/utils/fallback_tensor_test.cc index 9c54e8704158c3..1e3de50a38d9fc 100644 --- a/tensorflow/core/tfrt/utils/fallback_tensor_test.cc +++ b/tensorflow/core/tfrt/utils/fallback_tensor_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
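// Usage sketch, mirroring the test added below (shape and names illustrative):
//   tensorflow::Tensor t(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 0}));
//   FallbackTensor src(t);
//   auto copy = src;  // no new ImmutableTensor buffer; copy.buffer() stays null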
#include #include #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace tfrt_stub { @@ -158,6 +159,15 @@ TEST(FallbackTensorTest, FallbackTensorCopyRootBuffer) { tensorflow::DMAHelper::buffer(&tensor)); } +TEST(FallbackTensorTest, EmptyTensor) { + tensorflow::Tensor tensor(tensorflow::DT_FLOAT, + tensorflow::TensorShape({1, 0})); + + FallbackTensor fallback_tensor(tensor); + auto copy = fallback_tensor; + ASSERT_FALSE(copy.buffer()); +} + } // namespace } // namespace tfrt_stub } // namespace tensorflow From e03eac63f8d60859230c35bfade42111676c03fc Mon Sep 17 00:00:00 2001 From: Ziyin Huang Date: Mon, 13 Nov 2023 23:02:47 -0800 Subject: [PATCH 053/391] Internal change only PiperOrigin-RevId: 582198556 --- tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc | 3 ++- tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc index 2a01539f3a4823..53a0c70779534d 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc @@ -966,7 +966,8 @@ void GetMinibatchSplitsWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) { table_name_); CalculateHeadroom(this_max_ids, this_max_uniques, program_key, - max_ids_per_partition, max_unique_ids_per_partition); + max_ids_per_partition, max_unique_ids_per_partition, + dropped_id_count); Tensor* splits_tensor; OP_REQUIRES_OK( diff --git a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h index 275df89c7e17f0..f2d35b3fa76cd6 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h +++ b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h @@ -98,7 +98,8 @@ class GetMinibatchSplitsWithPhysicalReplicaOp : public OpKernel { virtual void CalculateHeadroom(int32 this_max_ids, int32 this_max_uniques, tstring program_key, int64_t max_ids_per_partition, - int64_t max_unique_ids_per_partition) {} + int64_t max_unique_ids_per_partition, + int32_t dropped_id_count) {} virtual inline int32_t CalculateBucketIdWithHashing(int32_t col_id, int32_t num_buckets) { // TODO(pineapplejuice233): Add a proper hashing function here. From 7327a309bf56668651de8ae7b2f0ff33f3a2b22f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Nov 2023 23:05:03 -0800 Subject: [PATCH 054/391] Enable NVCC only for XLA's build_gpu_nvcc Kokoro job PiperOrigin-RevId: 582199198 --- third_party/xla/.kokoro/linux/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/.kokoro/linux/build.sh b/third_party/xla/.kokoro/linux/build.sh index 49b10b04a899ca..9444f5a12bfea1 100644 --- a/third_party/xla/.kokoro/linux/build.sh +++ b/third_party/xla/.kokoro/linux/build.sh @@ -27,7 +27,7 @@ function is_linux_gpu_job() { } function is_use_nvcc() { - [[ -z "${USE_NVCC:-}" ]] || [[ "$USE_NVCC" == "true" ]] + [[ "${USE_NVCC:-}" == "true" ]] } # Pull the container (in case it was updated since the instance started) and From d28c2ab17e4db85de0c309b0c51d2df33afa4d1e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 01:02:21 -0800 Subject: [PATCH 055/391] Update GraphDef version to 1680. 
PiperOrigin-RevId: 582224743 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 12b17cb785e13b..28f282f8ebe8e5 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1679 // Updated: 2023/11/13 +#define TF_GRAPH_DEF_VERSION 1680 // Updated: 2023/11/14 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 9fa48d09da86703782f68c523327ad57c2cca8ed Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 01:02:24 -0800 Subject: [PATCH 056/391] compat: Update forward compatibility horizon to 2023-11-14 PiperOrigin-RevId: 582224756 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index fb919e026268a9..d49b9634a64a25 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 13) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 14) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 350774f252ed65a69c81da8e07804bb97d41ce8d Mon Sep 17 00:00:00 2001 From: Robert David Date: Tue, 14 Nov 2023 02:42:45 -0800 Subject: [PATCH 057/391] Use `malloc` instead of `new` to allocate buffers to reduce overhead needed to ensure alignment. Note this can also leave the memory uninitialized, potentially improving performance (untested). PiperOrigin-RevId: 582249483 --- tensorflow/lite/simple_memory_arena.cc | 30 +++++++++++++++----------- tensorflow/lite/simple_memory_arena.h | 14 ++++++++---- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index f0b5f281985539..73da4d86b48ea8 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -19,12 +19,12 @@ limitations under the License. #include #include +#include #include +#include #include #include -#include #include -#include #include #include "tensorflow/lite/core/c/common.h" @@ -57,14 +57,13 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), new_allocation_size); #endif - auto new_buffer = std::make_unique(new_allocation_size); + char* new_buffer = reinterpret_cast(malloc(new_allocation_size)); char* new_aligned_ptr = reinterpret_cast( - AlignTo(alignment_, reinterpret_cast(new_buffer.get()))); + AlignTo(alignment_, reinterpret_cast(new_buffer))); if (new_size > 0 && allocation_size_ > 0) { // Copy data when both old and new buffers are bigger than 0 bytes. 
- const size_t new_alloc_alignment_adjustment = - new_aligned_ptr - new_buffer.get(); - const size_t old_alloc_alignment_adjustment = aligned_ptr_ - buffer_.get(); + const size_t new_alloc_alignment_adjustment = new_aligned_ptr - new_buffer; + const size_t old_alloc_alignment_adjustment = aligned_ptr_ - buffer_; const size_t copy_amount = std::min(allocation_size_ - old_alloc_alignment_adjustment, new_allocation_size - new_alloc_alignment_adjustment); @@ -77,7 +76,8 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { allocation_size_); } #endif - buffer_ = std::move(new_buffer); + free(buffer_); + buffer_ = new_buffer; allocation_size_ = new_allocation_size; aligned_ptr_ = new_aligned_ptr; #ifdef TF_LITE_TENSORFLOW_PROFILER @@ -87,13 +87,17 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { } void ResizableAlignedBuffer::Release() { + if (buffer_ != nullptr) { #ifdef TF_LITE_TENSORFLOW_PROFILER - OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), - allocation_size_); + OnTfLiteArenaDealloc(subgraph_index_, + reinterpret_cast(this), + allocation_size_); #endif - buffer_.reset(); - allocation_size_ = 0; - aligned_ptr_ = nullptr; + free(buffer_); + buffer_ = nullptr; + allocation_size_ = 0; + aligned_ptr_ = nullptr; + } } void SimpleMemoryArena::PurgeAfter(int32_t node) { diff --git a/tensorflow/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h index 05bb52e6a225e4..f42545dcb0caaa 100644 --- a/tensorflow/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -17,8 +17,9 @@ limitations under the License. #include +#include +#include #include -#include #include #include @@ -58,7 +59,8 @@ struct ArenaAllocWithUsageInterval { class ResizableAlignedBuffer { public: explicit ResizableAlignedBuffer(size_t alignment, int subgraph_index) - : allocation_size_(0), + : buffer_(nullptr), + allocation_size_(0), alignment_(alignment), subgraph_index_(subgraph_index) { // To silence unused private member warning, only used with @@ -66,6 +68,8 @@ class ResizableAlignedBuffer { (void)subgraph_index_; } + ~ResizableAlignedBuffer() { Release(); } + // Resizes the buffer to make sure new_size bytes fit in the buffer. Keeps // alignment and any existing the data. Returns true when any external // pointers into the data array need to be adjusted (the buffer was moved). @@ -82,10 +86,12 @@ class ResizableAlignedBuffer { private: size_t RequiredAllocationSize(size_t data_array_size) const { - return data_array_size + alignment_ - 1; + // malloc guarantees returned pointers are aligned to at least max_align_t. + return data_array_size + + std::max(std::size_t{0}, alignment_ - alignof(std::max_align_t)); } - std::unique_ptr buffer_; + char* buffer_; size_t allocation_size_; size_t alignment_; char* aligned_ptr_; From 88ddb5843752577fb0e52f3fb33c22384aab61bf Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 14 Nov 2023 02:56:53 -0800 Subject: [PATCH 058/391] Internal Code Change PiperOrigin-RevId: 582253467 --- tensorflow/python/training/BUILD | 3 --- 1 file changed, 3 deletions(-) diff --git a/tensorflow/python/training/BUILD b/tensorflow/python/training/BUILD index a041d50f4b61cb..847e93392b2c67 100644 --- a/tensorflow/python/training/BUILD +++ b/tensorflow/python/training/BUILD @@ -75,7 +75,6 @@ py_strict_library( visibility = [ "//tensorflow:internal", "//tensorflow_minigo:__subpackages__", - "//tensorflow_model_optimization:__subpackages__", "//tensorflow_models:__subpackages__", "//third_party/cloud_tpu/convergence_tools:__subpackages__", "//third_party/mlperf:__subpackages__", @@ -229,7 +228,6 @@ py_strict_library( srcs_version = "PY3", visibility = [ "//tensorflow:internal", - "//tensorflow_estimator/python/estimator:__pkg__", "//third_party/py/tf_slim/training:__pkg__", ], deps = [ @@ -340,7 +338,6 @@ py_strict_library( srcs_version = "PY3", visibility = [ "//tensorflow:internal", - "//tensorflow_model_optimization/python/core/quantization/keras:__pkg__", "//third_party/py/tf_slim/layers:__pkg__", ], deps = [ From b65928ed4384a779a8b7e38ead84ae511935d3b4 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 14 Nov 2023 03:41:10 -0800 Subject: [PATCH 059/391] Do not use deprecated usePropertiesAsAttributes=0 for TFGraph dialects PiperOrigin-RevId: 582264389 --- tensorflow/core/ir/dialect.td | 1 - tensorflow/core/ir/types/dialect.td | 1 - 2 files changed, 2 deletions(-) diff --git a/tensorflow/core/ir/dialect.td b/tensorflow/core/ir/dialect.td index d80ecfa70cc19a..f3530af1232a16 100644 --- a/tensorflow/core/ir/dialect.td +++ b/tensorflow/core/ir/dialect.td @@ -172,7 +172,6 @@ def TFGraphDialect : Dialect { let useDefaultAttributePrinterParser = 1; let hasNonDefaultDestructor = 1; let hasOperationInterfaceFallback = 1; - let usePropertiesForAttributes = 0; } #endif // TFG_DIALECT diff --git a/tensorflow/core/ir/types/dialect.td b/tensorflow/core/ir/types/dialect.td index e6a13969b6843b..417b870977782a 100644 --- a/tensorflow/core/ir/types/dialect.td +++ b/tensorflow/core/ir/types/dialect.td @@ -47,7 +47,6 @@ def TFTypeDialect : Dialect { void printType(::mlir::Type type, ::mlir::DialectAsmPrinter &printer) const; }]; let useDefaultAttributePrinterParser = 1; - let usePropertiesForAttributes = 0; } //===----------------------------------------------------------------------===// From a71b73848c5a24916c75ef024e55f64fad783631 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 14 Nov 2023 04:31:15 -0800 Subject: [PATCH 060/391] Revert: Make the CPU backend participate in distributed initialization. The main effect of this change is that CPU devices end up with a unique global ID and the correct process index. 
PiperOrigin-RevId: 582275667 --- third_party/xla/xla/pjrt/BUILD | 2 - .../xla/xla/pjrt/tfrt_cpu_pjrt_client.cc | 62 ++--- .../xla/xla/pjrt/tfrt_cpu_pjrt_client.h | 27 +- third_party/xla/xla/python/BUILD | 1 - third_party/xla/xla/python/xla.cc | 27 +- third_party/xla/xla/python/xla_client.py | 157 ++++------- third_party/xla/xla/python/xla_client.pyi | 6 +- .../xla/xla/python/xla_extension/__init__.pyi | 261 +++++++----------- 8 files changed, 183 insertions(+), 360 deletions(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index abbb5886b81cdd..29bd730a136c16 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -678,7 +678,6 @@ cc_library( "//xla/client:executable_build_options", "//xla/client:xla_computation", "//xla/hlo/ir:hlo", - "//xla/pjrt/distributed:topology_util", "//xla/runtime:cpu_event", "//xla/service:buffer_assignment", "//xla/service:compiler", @@ -707,7 +706,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", # TODO(zhangqiaorjc): Remove if use TFRT threadpool. "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc index 12e5cce95e42d2..863f46ce2725af 100644 --- a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc @@ -41,7 +41,6 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#include "absl/time/time.h" #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -58,7 +57,6 @@ limitations under the License. 
#include "xla/literal_util.h" #include "xla/pjrt/abstract_tfrt_cpu_buffer.h" #include "xla/pjrt/compile_options.pb.h" -#include "xla/pjrt/distributed/topology_util.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" @@ -238,11 +236,7 @@ class TfrtCpuAsyncHostToDeviceTransferManager } // namespace -TfrtCpuDeviceDescription::TfrtCpuDeviceDescription(int id, int process_index, - int local_hardware_id) - : id_(id), - process_index_(process_index), - local_hardware_id_(local_hardware_id) { +TfrtCpuDeviceDescription::TfrtCpuDeviceDescription(int id) : id_(id) { debug_string_ = absl::StrCat("TFRT_CPU_", id); to_string_ = absl::StrCat("CpuDevice(id=", id, ")"); } @@ -259,9 +253,8 @@ absl::string_view TfrtCpuDeviceDescription::ToString() const { return to_string_; } -TfrtCpuDevice::TfrtCpuDevice(int id, int process_index, int local_hardware_id, - int max_inflight_computations) - : description_(id, process_index, local_hardware_id), +TfrtCpuDevice::TfrtCpuDevice(int id, int max_inflight_computations) + : description_(id), max_inflight_computations_semaphore_( /*capacity=*/max_inflight_computations) {} @@ -288,47 +281,30 @@ static int CpuDeviceCount() { return GetDebugOptionsFromFlags().xla_force_host_platform_device_count(); } +static StatusOr>> GetTfrtCpuDevices( + int cpu_device_count, int max_inflight_computations_per_device) { + std::vector> devices; + for (int i = 0; i < cpu_device_count; ++i) { + auto device = std::make_unique( + /*id=*/i, max_inflight_computations_per_device); + devices.push_back(std::move(device)); + } + return std::move(devices); +} + StatusOr> GetTfrtCpuClient( const CpuClientOptions& options) { // Need at least CpuDeviceCount threads to launch one collective. int cpu_device_count = options.cpu_device_count.value_or(CpuDeviceCount()); size_t num_threads = std::max(DefaultThreadPoolSize(), cpu_device_count); - LocalTopologyProto local_topology; - local_topology.set_node_id(options.node_id); - std::string boot_id_str; - auto boot_id_str_or_status = GetBootIdString(); - if (!boot_id_str_or_status.ok()) { - LOG(INFO) << boot_id_str_or_status.status(); - } else { - boot_id_str = boot_id_str_or_status.value(); - } - local_topology.set_boot_id(boot_id_str); - for (int i = 0; i < cpu_device_count; ++i) { - DeviceProto* device_proto = local_topology.add_devices(); - device_proto->set_local_device_ordinal(i); - device_proto->set_name("cpu"); - } - - GlobalTopologyProto global_topology; - TF_RETURN_IF_ERROR( - ExchangeTopologies("cpu", options.node_id, options.num_nodes, - absl::Minutes(2), absl::Minutes(5), options.kv_get, - options.kv_put, local_topology, &global_topology)); - - std::vector> devices; - for (const LocalTopologyProto& node : global_topology.nodes()) { - for (const DeviceProto& device_proto : node.devices()) { - auto device = std::make_unique( - /*id=*/device_proto.global_device_id(), node.node_id(), - device_proto.local_device_ordinal(), - options.max_inflight_computations_per_device); - devices.push_back(std::move(device)); - } - } + TF_ASSIGN_OR_RETURN( + std::vector> devices, + GetTfrtCpuDevices(cpu_device_count, + options.max_inflight_computations_per_device)); return std::unique_ptr(std::make_unique( - /*process_index=*/options.node_id, std::move(devices), num_threads)); + /*process_index=*/0, std::move(devices), num_threads)); } TfrtCpuClient::TfrtCpuClient( diff --git a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h index e9543ab92e93e7..c744d10ac3a9ea 
100644 --- a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h @@ -63,13 +63,11 @@ namespace xla { class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { public: - TfrtCpuDeviceDescription(int id, int process_index, int local_hardware_id); + explicit TfrtCpuDeviceDescription(int id); int id() const override { return id_; } - int process_index() const override { return process_index_; } - - int local_hardware_id() const { return local_hardware_id_; } + int process_index() const override { return 0; } absl::string_view device_kind() const override; @@ -84,8 +82,6 @@ class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { private: int id_; - int process_index_; - int local_hardware_id_; std::string debug_string_; std::string to_string_; absl::flat_hash_map attributes_ = {}; @@ -93,8 +89,7 @@ class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { class TfrtCpuDevice final : public PjRtDevice { public: - explicit TfrtCpuDevice(int id, int process_index, int local_hardware_id, - int max_inflight_computations = 32); + explicit TfrtCpuDevice(int id, int max_inflight_computations = 32); const TfrtCpuDeviceDescription& description() const override { return description_; @@ -111,9 +106,8 @@ class TfrtCpuDevice final : public PjRtDevice { return process_index() == client()->process_index(); } - int local_hardware_id() const override { - return description_.local_hardware_id(); - } + // Used as `device_ordinal`. + int local_hardware_id() const override { return id(); } Status TransferToInfeed(const LiteralSlice& literal) override; @@ -524,17 +518,6 @@ struct CpuClientOptions { std::optional cpu_device_count = std::nullopt; int max_inflight_computations_per_device = 32; - - // Number of distributed nodes. node_id, kv_get, and kv_put are ignored if - // this is set to 1. - int num_nodes = 1; - - // My node ID. - int node_id = 0; - - // KV store primitives for sharing topology information. - PjRtClient::KeyValueGetCallback kv_get = nullptr; - PjRtClient::KeyValuePutCallback kv_put = nullptr; }; StatusOr> GetTfrtCpuClient( const CpuClientOptions& options); diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index c7367a7f35ba59..c0a5938032acd7 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -1153,7 +1153,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@local_config_python//:python_headers", # buildcleaner: keep "//xla:literal", diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index b2269a423e5393..a4db3deef34730 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -41,7 +41,6 @@ limitations under the License. 
#include "absl/strings/ascii.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" -#include "absl/time/time.h" #include "absl/types/span.h" #include "pybind11/attr.h" // from @pybind11 #include "pybind11/cast.h" // from @pybind11 @@ -493,38 +492,16 @@ static void Init(py::module_& m) { m.def( "get_tfrt_cpu_client", - [](bool asynchronous, - std::shared_ptr distributed_client, - int node_id, int num_nodes) -> std::shared_ptr { + [](bool asynchronous) -> std::shared_ptr { py::gil_scoped_release gil_release; CpuClientOptions options; - if (distributed_client != nullptr) { - std::string key_prefix = "cpu:"; - options.kv_get = - [distributed_client, key_prefix]( - const std::string& k, - absl::Duration timeout) -> xla::StatusOr { - return distributed_client->BlockingKeyValueGet( - absl::StrCat(key_prefix, k), timeout); - }; - options.kv_put = [distributed_client, key_prefix]( - const std::string& k, - const std::string& v) -> xla::Status { - return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), - v); - }; - options.node_id = node_id; - options.num_nodes = num_nodes; - } - options.asynchronous = asynchronous; std::unique_ptr client = xla::ValueOrThrow(GetTfrtCpuClient(options)); return std::make_shared( ifrt::PjRtClient::Create(std::move(client))); }, - py::arg("asynchronous") = true, py::arg("distributed_client") = nullptr, - py::arg("node_id") = 0, py::arg("num_nodes") = 1); + py::arg("asynchronous") = true); m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { xla::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); return pjrt_api.ok(); diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index eba39f47830c89..6239cce3b38351 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -48,7 +48,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 214 +_version = 213 # Version number for MLIR:Python components. 
mlir_api_version = 54 @@ -63,18 +63,11 @@ _NameValueMapping = Mapping[str, Union[str, int, List[int], float, bool]] -def make_cpu_client( - distributed_client=None, - node_id=0, - num_nodes=1, -) -> ...: - register_custom_call_handler('cpu', _xla.register_custom_call_target) - return _xla.get_tfrt_cpu_client( - asynchronous=True, - distributed_client=distributed_client, - node_id=node_id, - num_nodes=num_nodes, +def make_cpu_client() -> ...: + register_custom_call_handler( + 'cpu', _xla.register_custom_call_target ) + return _xla.get_tfrt_cpu_client(asynchronous=True) def make_gpu_client( @@ -104,8 +97,12 @@ def make_gpu_client( if memory_fraction: config.memory_fraction = float(memory_fraction) config.preallocate = preallocate not in ('0', 'false', 'False') - register_custom_call_handler('CUDA', _xla.register_custom_call_target) - register_custom_call_handler('ROCM', _xla.register_custom_call_target) + register_custom_call_handler( + 'CUDA', _xla.register_custom_call_target + ) + register_custom_call_handler( + 'ROCM', _xla.register_custom_call_target + ) return _xla.get_gpu_client( asynchronous=True, @@ -227,7 +224,6 @@ def generate_pjrt_gpu_plugin_options( class OpMetadata: """Python representation of a xla.OpMetadata protobuf.""" - __slots__ = ('op_type', 'op_name', 'source_file', 'source_line') def __init__(self, op_type='', op_name='', source_file='', source_line=0): @@ -242,8 +238,10 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): full_filename, lineno = inspect.stack()[skip_frames][1:3] filename = os.path.basename(full_filename) return OpMetadata( - op_type=op_type, op_name=op_name, source_file=filename, source_line=lineno - ) + op_type=op_type, + op_name=op_name, + source_file=filename, + source_line=lineno) PrimitiveType = _xla.PrimitiveType @@ -382,8 +380,7 @@ def convert(pyval): if isinstance(pyval, tuple): if layout is not None: raise NotImplementedError( - 'shape_from_pyval does not support layouts for tuple shapes' - ) + 'shape_from_pyval does not support layouts for tuple shapes') return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) else: return Shape.array_shape(pyval.dtype, np.shape(pyval), layout) @@ -481,9 +478,8 @@ class PaddingType(enum.Enum): SAME = 2 -def window_padding_type_to_pad_values( - padding_type, lhs_dims, rhs_dims, window_strides -): +def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, + window_strides): """Maps PaddingType or string to pad values (list of pairs of ints).""" if not isinstance(padding_type, (str, PaddingType)): msg = 'padding_type must be str or PaddingType, got {}.' 
@@ -505,8 +501,7 @@ def window_padding_type_to_pad_values( pad_sizes = [ max((out_size - 1) * stride + filter_size - in_size, 0) for out_size, stride, filter_size, in_size in zip( - out_shape, window_strides, rhs_dims, lhs_dims - ) + out_shape, window_strides, rhs_dims, lhs_dims) ] return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] else: @@ -610,7 +605,6 @@ def register_custom_call_handler(platform: str, handler: Any) -> None: class PaddingConfigDimension: """Python representation of a xla.PaddingConfigDimension protobuf.""" - __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding') edge_padding_low: int @@ -625,7 +619,6 @@ def __init__(self): class PaddingConfig: """Python representation of a xla.PaddingConfig protobuf.""" - __slots__ = ('dimensions',) def __init__(self): @@ -659,13 +652,8 @@ def make_padding_config( class DotDimensionNumbers: """Python representation of a xla.DotDimensionNumbers protobuf.""" - - __slots__ = ( - 'lhs_contracting_dimensions', - 'rhs_contracting_dimensions', - 'lhs_batch_dimensions', - 'rhs_batch_dimensions', - ) + __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions', + 'lhs_batch_dimensions', 'rhs_batch_dimensions') def __init__(self): self.lhs_contracting_dimensions = [] @@ -675,10 +663,9 @@ def __init__(self): def make_dot_dimension_numbers( - dimension_numbers: Union[ - DotDimensionNumbers, - Tuple[Tuple[List[int], List[int]], Tuple[List[int], List[int]]], - ] + dimension_numbers: Union[DotDimensionNumbers, + Tuple[Tuple[List[int], List[int]], + Tuple[List[int], List[int]]]] ) -> DotDimensionNumbers: """Builds a DotDimensionNumbers object from a specification. @@ -705,18 +692,11 @@ def make_dot_dimension_numbers( class ConvolutionDimensionNumbers: """Python representation of a xla.ConvolutionDimensionNumbers protobuf.""" - - __slots__ = ( - 'input_batch_dimension', - 'input_feature_dimension', - 'input_spatial_dimensions', - 'kernel_input_feature_dimension', - 'kernel_output_feature_dimension', - 'kernel_spatial_dimensions', - 'output_batch_dimension', - 'output_feature_dimension', - 'output_spatial_dimensions', - ) + __slots__ = ('input_batch_dimension', 'input_feature_dimension', + 'input_spatial_dimensions', 'kernel_input_feature_dimension', + 'kernel_output_feature_dimension', 'kernel_spatial_dimensions', + 'output_batch_dimension', 'output_feature_dimension', + 'output_spatial_dimensions') def __init__(self): self.input_batch_dimension = 0 @@ -731,32 +711,30 @@ def __init__(self): def make_convolution_dimension_numbers( - dimension_numbers: Union[ - None, ConvolutionDimensionNumbers, Tuple[str, str, str] - ], - num_spatial_dimensions: int, -) -> ConvolutionDimensionNumbers: + dimension_numbers: Union[None, ConvolutionDimensionNumbers, Tuple[str, str, + str]], + num_spatial_dimensions: int) -> ConvolutionDimensionNumbers: """Builds a ConvolutionDimensionNumbers object from a specification. Args: dimension_numbers: optional, either a ConvolutionDimensionNumbers object or - a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of length - N+2 identifying by position: (1) batch dimensions in lhs, rhs, and the - output with the character 'N', (2) feature dimensions in lhs and the - output with the character 'C', (3) input and output feature dimensions in - rhs with the characters 'I' and 'O' respectively, and (4) spatial - dimension correspondences between lhs, rhs, and the output using any - distinct characters. 
For example, to indicate dimension numbers consistent - with the Conv operation with two spatial dimensions, one could use - ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate dimension - numbers consistent with the TensorFlow Conv2D operation, one could use - ('NHWC', 'HWIO', 'NHWC'). When using the latter form of convolution - dimension specification, window strides are associated with spatial - dimension character labels according to the order in which the labels - appear in the rhs_spec string, so that window_strides[0] is matched with - the dimension corresponding to the first character appearing in rhs_spec - that is not 'I' or 'O'. By default, use the same dimension numbering as - Conv and ConvWithGeneralPadding. + a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of + length N+2 identifying by position: (1) batch dimensions in lhs, rhs, and + the output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions + in rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers + consistent with the Conv operation with two spatial dimensions, one + could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate + dimension numbers consistent with the TensorFlow Conv2D operation, one + could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of + convolution dimension specification, window strides are associated with + spatial dimension character labels according to the order in which the + labels appear in the rhs_spec string, so that window_strides[0] is + matched with the dimension corresponding to the first character + appearing in rhs_spec that is not 'I' or 'O'. By default, use the same + dimension numbering as Conv and ConvWithGeneralPadding. num_spatial_dimensions: the number of spatial dimensions. 
Returns: @@ -786,26 +764,18 @@ def make_convolution_dimension_numbers( dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') dimension_numbers.kernel_spatial_dimensions.extend( - i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'} - ) + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) dimension_numbers.input_spatial_dimensions.extend( - sorted( - (i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(lhs_spec[i]), - ) - ) + sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]))) dimension_numbers.output_spatial_dimensions.extend( - sorted( - (i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(out_spec[i]), - ) - ) + sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]))) return dimension_numbers class PrecisionConfig: """Python representation of a xla.PrecisionConfig protobuf.""" - __slots__ = ('operand_precision',) Precision = _xla.PrecisionConfig_Precision @@ -816,13 +786,8 @@ def __init__(self): class GatherDimensionNumbers: """Python representation of a xla.GatherDimensionNumbers protobuf.""" - - __slots__ = ( - 'offset_dims', - 'collapsed_slice_dims', - 'start_index_map', - 'index_vector_dim', - ) + __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map', + 'index_vector_dim') def __init__(self): self.offset_dims = [] @@ -833,13 +798,8 @@ def __init__(self): class ScatterDimensionNumbers: """Python representation of a xla.ScatterDimensionNumbers protobuf.""" - - __slots__ = ( - 'update_window_dims', - 'inserted_window_dims', - 'scatter_dims_to_operand_dims', - 'index_vector_dim', - ) + __slots__ = ('update_window_dims', 'inserted_window_dims', + 'scatter_dims_to_operand_dims', 'index_vector_dim') def __init__(self): self.update_window_dims = [] @@ -850,7 +810,6 @@ def __init__(self): class ReplicaGroup: """Python representation of a xla.ReplicaGroup protobuf.""" - __slots__ = ('replica_ids',) def __init__(self): diff --git a/third_party/xla/xla/python/xla_client.pyi b/third_party/xla/xla/python/xla_client.pyi index 04e22b8e7cf417..2eb82ec094f2d7 100644 --- a/third_party/xla/xla/python/xla_client.pyi +++ b/third_party/xla/xla/python/xla_client.pyi @@ -81,11 +81,7 @@ def shape_from_pyval(pyval: Any, layout: Sequence[int] | None = None) -> Any: .. def heap_profile(client: Client) -> bytes: ... -def make_cpu_client( - distributed_client: Optional[DistributedRuntimeClient] = ..., - node_id: int = ..., - num_nodes: int = ..., -) -> Client: +def make_cpu_client() -> Client: ... def make_gpu_client( diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi index ed841b38217f96..19752e1c593903 100644 --- a/third_party/xla/xla/python/xla_extension/__init__.pyi +++ b/third_party/xla/xla/python/xla_extension/__init__.pyi @@ -20,26 +20,14 @@ import inspect import types import typing from typing import ( - Any, - Callable, - ClassVar, - Dict, - Iterator, - List, - Optional, - Sequence, - Tuple, - Type, - TypeVar, - Union, - overload, -) + Any, Callable, ClassVar, Dict, Iterator, List, Optional, Sequence, Tuple, + Type, TypeVar, Union, overload) import numpy as np +from . import ops from . import jax_jit from . import mlir -from . import ops from . import outfeed_receiver from . import pmap_lib from . 
import profiler @@ -100,8 +88,7 @@ class Shape: type: Union[np.dtype, PrimitiveType], dims_seq: Any = ..., layout_seq: Any = ..., - dynamic_dimensions: Optional[List[bool]] = ..., - ) -> Shape: ... + dynamic_dimensions: Optional[List[bool]] = ...) -> Shape: ... @staticmethod def token_shape() -> Shape: ... @staticmethod @@ -149,7 +136,7 @@ class XlaComputation: def get_hlo_module(self) -> HloModule: ... def program_shape(self) -> ProgramShape: ... def as_serialized_hlo_module_proto(self) -> bytes: ... - def as_hlo_text(self, print_large_constants: bool = False) -> str: ... + def as_hlo_text(self, print_large_constants: bool=False) -> str: ... def as_hlo_dot_graph(self) -> str: ... def hash(self) -> int: ... def as_hlo_module(self) -> HloModule: ... @@ -189,11 +176,10 @@ class HloModule: @property def name(self) -> str: ... def to_string(self, options: HloPrintOptions = ...) -> str: ... - def as_serialized_hlo_module_proto(self) -> bytes: ... + def as_serialized_hlo_module_proto(self)-> bytes: ... @staticmethod def from_serialized_hlo_module_proto( - serialized_hlo_module_proto: bytes, - ) -> HloModule: ... + serialized_hlo_module_proto: bytes) -> HloModule: ... def computations(self) -> List[HloComputation]: ... class HloModuleGroup: @@ -205,9 +191,10 @@ class HloModuleGroup: def hlo_module_to_dot_graph(hlo_module: HloModule) -> str: ... def hlo_module_from_text(hlo_module_text: str) -> HloModule: ... + def hlo_module_cost_analysis( - client: Client, module: HloModule -) -> Dict[str, float]: ... + client: Client, + module: HloModule) -> Dict[str, float]: ... class XlaOp: ... @@ -227,8 +214,7 @@ class XlaBuilder: self, __output_index: Sequence[int], __param_number: int, - __param_index: Sequence[int], - ) -> None: ... + __param_index: Sequence[int]) -> None: ... class DeviceAssignment: @staticmethod @@ -252,18 +238,12 @@ class CompileOptions: profile_version: int device_assignment: Optional[DeviceAssignment] compile_portable_executable: bool - env_option_overrides: List[Tuple[str, str]] - -def register_custom_call_target( - fn_name: str, capsule: Any, platform: str -) -> _Status: ... -def register_custom_call_partitioner( - name: str, - prop_user_sharding: Callable, - partition: Callable, - infer_sharding_from_operands: Callable, - can_side_effecting_have_replicated_sharding: bool, -) -> None: ... + env_option_overrides: List[Tuple[str,str]] + +def register_custom_call_target(fn_name: str, capsule: Any, platform: str) -> _Status: ... +def register_custom_call_partitioner(name: str, prop_user_sharding: Callable, + partition: Callable, infer_sharding_from_operands: Callable, + can_side_effecting_have_replicated_sharding: bool) -> None: ... def encode_inspect_sharding_callback(handler: Any) -> bytes: ... class DebugOptions: @@ -369,16 +349,11 @@ class HloSharding: @staticmethod def from_string(sharding: str) -> HloSharding: ... @staticmethod - def tuple_sharding( - shape: Shape, shardings: Sequence[HloSharding] - ) -> HloSharding: ... + def tuple_sharding(shape: Shape, shardings: Sequence[HloSharding]) -> HloSharding: ... @staticmethod - def iota_tile( - dims: Sequence[int], - reshape_dims: Sequence[int], - transpose_perm: Sequence[int], - subgroup_types: Sequence[OpSharding.Type], - ) -> HloSharding: ... + def iota_tile(dims: Sequence[int], reshape_dims: Sequence[int], + transpose_perm: Sequence[int], + subgroup_types: Sequence[OpSharding.Type]) -> HloSharding: ... @staticmethod def replicate() -> HloSharding: ... 
@staticmethod @@ -440,17 +415,16 @@ class Memory: class GpuAllocatorConfig: class Kind(enum.IntEnum): - DEFAULT: int - PLATFORM: int - BFC: int - CUDA_ASYNC: int + DEFAULT: int + PLATFORM: int + BFC: int + CUDA_ASYNC: int def __init__( self, kind: Kind = ..., memory_fraction: float = ..., - preallocate: bool = ..., - ) -> None: ... + preallocate: bool = ...) -> None: ... class HostBufferSemantics(enum.IntEnum): IMMUTABLE_ONLY_DURING_CALL: HostBufferSemantics @@ -476,78 +450,61 @@ class Client: argument: Any, device: Optional[Device] = ..., force_copy: bool = ..., - host_buffer_semantics: HostBufferSemantics = ..., - ) -> ArrayImpl: ... + host_buffer_semantics: HostBufferSemantics = ...) -> ArrayImpl: ... def make_cross_host_receive_buffers( - self, shapes: Sequence[Shape], device: Device - ) -> List[Tuple[ArrayImpl, bytes]]: ... + self, + shapes: Sequence[Shape], + device: Device) -> List[Tuple[ArrayImpl, bytes]]: ... def compile( self, computation: Union[str, bytes], - compile_options: CompileOptions = ..., - host_callbacks: Sequence[Any] = ..., - ) -> LoadedExecutable: ... + compile_options: CompileOptions = ..., host_callbacks: Sequence[Any] = ...) -> LoadedExecutable: ... def serialize_executable(self, executable: LoadedExecutable) -> bytes: ... def deserialize_executable( - self, - serialized: bytes, - options: Optional[CompileOptions], - host_callbacks: Sequence[Any] = ..., - ) -> LoadedExecutable: ... + self, serialized: bytes, + options: Optional[CompileOptions], host_callbacks: Sequence[Any] = ...) -> LoadedExecutable: ... def heap_profile(self) -> bytes: ... def defragment(self) -> _Status: ... def get_emit_python_callback_descriptor( - self, - callable: Callable, - operand_shapes: Sequence[Shape], - results_shapes: Sequence[Shape], - ) -> Tuple[Any, Any]: ... + self, callable: Callable, operand_shapes: Sequence[Shape], + results_shapes: Sequence[Shape]) -> Tuple[Any, Any]: ... def make_python_callback_from_host_send_and_recv( - self, - callable: Callable, - operand_shapes: Sequence[Shape], - result_shapes: Sequence[Shape], - send_channel_ids: Sequence[int], - recv_channel_ids: Sequence[int], - serializer: Optional[Callable] = ..., - ) -> Any: ... + self, callable: Callable, operand_shapes: Sequence[Shape], + result_shapes: Sequence[Shape], send_channel_ids: Sequence[int], + recv_channel_ids: Sequence[int], serializer: Optional[Callable] = ...) -> Any: ... def __getattr__(self, name: str) -> Any: ... -def get_tfrt_cpu_client( - asynchronous: bool = ..., - distributed_client: Optional[DistributedRuntimeClient] = ..., - node_id: int = ..., - num_nodes: int = ..., -) -> Client: ... +def get_tfrt_cpu_client(asynchronous: bool = ...) -> Client: ... def get_gpu_client( asynchronous: bool = ..., allocator_config: GpuAllocatorConfig = ..., distributed_client: Optional[DistributedRuntimeClient] = ..., node_id: int = ..., - num_nodes: int = ..., allowed_devices: Optional[Any] = ..., platform_name: Optional[str] = ..., - mock: Optional[bool] = ..., -) -> Client: ... + mock:Optional[bool]=...) -> Client:... def get_mock_gpu_client( asynchronous: bool = ..., allocator_config: GpuAllocatorConfig = ..., distributed_client: Optional[DistributedRuntimeClient] = ..., node_id: int = ..., allowed_devices: Optional[Any] = ..., - platform_name: Optional[str] = ..., -) -> Client: ... + platform_name: Optional[str] = ...) -> Client:... 
def get_c_api_client( platform_name: str, options: Dict[str, Union[str, int, List[int], float, bool]], distributed_client: Optional[DistributedRuntimeClient] = ..., ) -> Client: ... + def get_default_c_api_topology( platform_name: str, topology_name: str, options: Dict[str, Union[str, int, List[int], float]], -) -> DeviceTopology: ... -def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: ... +) -> DeviceTopology: + ... +def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: + ... + def load_pjrt_plugin(platform_name: str, library_path: str) -> _Status: ... def pjrt_plugin_loaded(plugin_name: str) -> bool: ... def pjrt_plugin_initialized(plugin_name: str) -> bool: ... @@ -584,14 +541,10 @@ ArrayImpl = Any # traceback: Traceback # _HAS_DYNAMIC_ATTRIBUTES: bool = ... -def copy_array_to_devices_with_sharding( - self: ArrayImpl, devices: List[Device], sharding: Any -) -> ArrayImpl: ... +def copy_array_to_devices_with_sharding(self: ArrayImpl, devices: List[Device], sharding: Any) -> ArrayImpl: ... + def batched_device_put( - aval: Any, - sharding: Any, - shards: Sequence[Any], - devices: List[Device], + aval: Any, sharding: Any, shards: Sequence[Any], devices: List[Device], committed: bool = True, ) -> ArrayImpl: ... @@ -600,8 +553,11 @@ def check_and_canonicalize_memory_kind( memory_kind: Optional[str], device_list: DeviceList) -> Optional[str]: ... def array_result_handler( - aval: Any, sharding: Any, committed: bool, _skip_checks: bool = ... -) -> Callable: ... + aval: Any, + sharding: Any, + committed: bool, + _skip_checks: bool = ...) -> Callable: + ... class Token: def block_until_ready(self): ... @@ -613,9 +569,7 @@ class ShardedToken: class ExecuteResults: def __len__(self) -> int: ... def disassemble_into_single_device_arrays(self) -> List[List[ArrayImpl]]: ... - def disassemble_prefix_into_single_device_arrays( - self, n: int - ) -> List[List[ArrayImpl]]: ... + def disassemble_prefix_into_single_device_arrays(self, n: int) -> List[List[ArrayImpl]]: ... def consume_with_handlers(self, handlers: List[Callable]) -> List[Any]: ... def consume_token(self) -> ShardedToken: ... @@ -627,17 +581,18 @@ class LoadedExecutable: def delete(self) -> None: ... def execute(self, arguments: Sequence[ArrayImpl]) -> List[ArrayImpl]: ... def execute_with_token( - self, arguments: Sequence[ArrayImpl] - ) -> Tuple[List[ArrayImpl], Token]: ... + self, + arguments: Sequence[ArrayImpl]) -> Tuple[List[ArrayImpl], Token]: + ... def execute_sharded_on_local_devices( - self, arguments: Sequence[List[ArrayImpl]] - ) -> List[List[ArrayImpl]]: ... + self, + arguments: Sequence[List[ArrayImpl]]) -> List[List[ArrayImpl]]: ... def execute_sharded_on_local_devices_with_tokens( - self, arguments: Sequence[List[ArrayImpl]] - ) -> Tuple[List[List[ArrayImpl]], ShardedToken]: ... + self, + arguments: Sequence[List[ArrayImpl]]) -> Tuple[List[List[ArrayImpl]], ShardedToken]: ... def execute_sharded( - self, arguments: Sequence[List[ArrayImpl]], with_tokens: bool = ... - ) -> ExecuteResults: ... + self, + arguments: Sequence[List[ArrayImpl]], with_tokens: bool = ...) -> ExecuteResults: ... def hlo_modules(self) -> List[HloModule]: ... def get_output_memory_kinds(self) -> List[List[str]]: ... def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... @@ -670,18 +625,14 @@ class DeviceTopology: def __getattr__(self, name: str) -> Any: ... def buffer_to_dlpack_managed_tensor( - buffer: ArrayImpl, stream: int | None = None -) -> Any: ... 
+ buffer: ArrayImpl, + stream: int | None = None) -> Any: ... def dlpack_managed_tensor_to_buffer( - tensor: Any, device: Device, stream: int | None -) -> ArrayImpl: ... - + tensor: Any, device: Device, stream: int | None) -> ArrayImpl: ... # Legacy overload def dlpack_managed_tensor_to_buffer( - tensor: Any, - cpu_backend: Optional[Client] = ..., - gpu_backend: Optional[Client] = ..., -) -> ArrayImpl: ... + tensor: Any, cpu_backend: Optional[Client] = ..., + gpu_backend: Optional[Client] = ...) -> ArrayImpl: ... # === BEGIN py_traceback.cc @@ -700,12 +651,12 @@ class Traceback: def __str__(self) -> str: ... def as_python_traceback(self) -> Any: ... def raw_frames(self) -> Tuple[List[types.CodeType], List[int]]: ... + @staticmethod def code_addr2line(code: types.CodeType, lasti: int) -> int: ... @staticmethod - def code_addr2location( - code: types.CodeType, lasti: int - ) -> Tuple[int, int, int, int]: ... + def code_addr2location(code: types.CodeType, + lasti: int) -> Tuple[int, int, int, int]: ... def replace_thread_exc_traceback(traceback: Any): ... @@ -713,20 +664,16 @@ def replace_thread_exc_traceback(traceback: Any): ... class DistributedRuntimeService: def shutdown(self) -> None: ... - class DistributedRuntimeClient: def connect(self) -> _Status: ... def shutdown(self) -> _Status: ... def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> _Status: ... - def blocking_key_value_get_bytes( - self, key: str, timeout_in_ms: int - ) -> _Status: ... + def blocking_key_value_get_bytes(self, key: str, timeout_in_ms: int) -> _Status: ... def key_value_dir_get(self, key: str) -> _Status: ... def key_value_dir_get_bytes(self, key: str) -> _Status: ... def key_value_set(self, key: str, value: str) -> _Status: ... - def key_value_delete(self, key: str) -> _Status: ... + def key_value_delete(self, key:str) -> _Status: ... def wait_at_barrier(self, barrier_id: str, timeout_in_ms: int) -> _Status: ... - def get_distributed_runtime_service( address: str, num_nodes: int, @@ -743,16 +690,17 @@ def get_distributed_runtime_client( heartbeat_interval: Optional[int] = ..., max_missing_heartbeats: Optional[int] = ..., missed_heartbeat_callback: Optional[Any] = ..., - shutdown_on_destruction: Optional[bool] = ..., -) -> DistributedRuntimeClient: ... + shutdown_on_destruction: Optional[bool] = ...) -> DistributedRuntimeClient: ... class PreemptionSyncManager: def initialize(self, client: DistributedRuntimeClient) -> _Status: ... def reached_sync_point(self, step_counter: int) -> bool: ... - def create_preemption_sync_manager() -> PreemptionSyncManager: ... + def collect_garbage() -> None: ... + def is_optimized_build() -> bool: ... + def json_to_pprof_profile(json: str) -> bytes: ... def pprof_profile_to_json(proto: bytes) -> str: ... @@ -766,9 +714,8 @@ class PmapFunction: def _cache_size(self) -> int: ... def _cache_clear(self) -> None: ... -def weakref_lru_cache( - cache_context_fn: Callable, call: Callable, maxsize=... -): ... +def weakref_lru_cache(cache_context_fn: Callable, call: Callable, maxsize=...): + ... class DeviceList: def __init__(self, device_assignment: Tuple[Device, ...]): ... @@ -791,18 +738,13 @@ class DeviceList: def memory_kinds(self) -> Tuple[str, ...]: ... class Sharding: ... + class XLACompatibleSharding(Sharding): ... class NamedSharding(XLACompatibleSharding): - def __init__( - self, - mesh: Any, - spec: Any, - *, - memory_kind: Optional[str] = None, - _parsed_pspec: Any = None, - _manual_axes: frozenset[Any] = frozenset(), - ): ... 
+ def __init__(self, mesh: Any, spec: Any, *, memory_kind: Optional[str] = None, + _parsed_pspec: Any = None, + _manual_axes: frozenset[Any] = frozenset()): ... mesh: Any spec: Any _memory_kind: Optional[str] @@ -817,21 +759,15 @@ class SingleDeviceSharding(XLACompatibleSharding): _internal_device_list: DeviceList class PmapSharding(XLACompatibleSharding): - def __init__( - self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec - ): ... + def __init__(self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec): ... devices: List[Any] sharding_spec: pmap_lib.ShardingSpec _internal_device_list: DeviceList class GSPMDSharding(XLACompatibleSharding): - def __init__( - self, - devices: Sequence[Device], - op_sharding: Union[OpSharding, HloSharding], - *, - memory_kind: Optional[str] = None, - ): ... + def __init__(self, devices: Sequence[Device], + op_sharding: Union[OpSharding, HloSharding], + *, memory_kind: Optional[str] = None): ... _devices: Tuple[Device, ...] _hlo_sharding: HloSharding _memory_kind: Optional[str] @@ -850,16 +786,12 @@ class PjitFunctionCache: @staticmethod def clear_all(): ... -def pjit( - function_name: str, - fun: Optional[Callable], - cache_miss: Callable, - static_argnums: Sequence[int], - static_argnames: Sequence[str], - donate_argnums: Sequence[int], - pytree_registry: pytree.PyTreeRegistry, - cache: Optional[PjitFunctionCache] = ..., -) -> PjitFunction: ... +def pjit(function_name: str, fun: Optional[Callable], cache_miss: Callable, + static_argnums: Sequence[int], static_argnames: Sequence[str], + donate_argnums: Sequence[int], + pytree_registry: pytree.PyTreeRegistry, + cache: Optional[PjitFunctionCache] = ..., + ) -> PjitFunction: ... class HloPassInterface: @property @@ -881,6 +813,9 @@ class TupleSimplifer(HloPassInterface): def __init__(self) -> None: ... def is_asan() -> bool: ... + def is_msan() -> bool: ... + def is_tsan() -> bool: ... + def is_sanitized() -> bool: ... From 6bd15c8379685f3139f9446eb1fafb5a46d8b4a1 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Tue, 14 Nov 2023 04:43:15 -0800 Subject: [PATCH 061/391] Use `malloc` instead of `new` to allocate buffers to reduce overhead needed to ensure alignment. Note this can also leave the memory uninitialized, potentially improving performance (untested). PiperOrigin-RevId: 582277847 --- tensorflow/lite/simple_memory_arena.cc | 30 +++++++++++--------------- tensorflow/lite/simple_memory_arena.h | 14 ++++-------- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index 73da4d86b48ea8..f0b5f281985539 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -19,12 +19,12 @@ limitations under the License. #include #include -#include #include -#include #include #include +#include #include +#include #include #include "tensorflow/lite/core/c/common.h" @@ -57,13 +57,14 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), new_allocation_size); #endif - char* new_buffer = reinterpret_cast(malloc(new_allocation_size)); + auto new_buffer = std::make_unique(new_allocation_size); char* new_aligned_ptr = reinterpret_cast( - AlignTo(alignment_, reinterpret_cast(new_buffer))); + AlignTo(alignment_, reinterpret_cast(new_buffer.get()))); if (new_size > 0 && allocation_size_ > 0) { // Copy data when both old and new buffers are bigger than 0 bytes. 
- const size_t new_alloc_alignment_adjustment = new_aligned_ptr - new_buffer; - const size_t old_alloc_alignment_adjustment = aligned_ptr_ - buffer_; + const size_t new_alloc_alignment_adjustment = + new_aligned_ptr - new_buffer.get(); + const size_t old_alloc_alignment_adjustment = aligned_ptr_ - buffer_.get(); const size_t copy_amount = std::min(allocation_size_ - old_alloc_alignment_adjustment, new_allocation_size - new_alloc_alignment_adjustment); @@ -76,8 +77,7 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { allocation_size_); } #endif - free(buffer_); - buffer_ = new_buffer; + buffer_ = std::move(new_buffer); allocation_size_ = new_allocation_size; aligned_ptr_ = new_aligned_ptr; #ifdef TF_LITE_TENSORFLOW_PROFILER @@ -87,17 +87,13 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { } void ResizableAlignedBuffer::Release() { - if (buffer_ != nullptr) { #ifdef TF_LITE_TENSORFLOW_PROFILER - OnTfLiteArenaDealloc(subgraph_index_, - reinterpret_cast(this), - allocation_size_); + OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), + allocation_size_); #endif - free(buffer_); - buffer_ = nullptr; - allocation_size_ = 0; - aligned_ptr_ = nullptr; - } + buffer_.reset(); + allocation_size_ = 0; + aligned_ptr_ = nullptr; } void SimpleMemoryArena::PurgeAfter(int32_t node) { diff --git a/tensorflow/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h index f42545dcb0caaa..05bb52e6a225e4 100644 --- a/tensorflow/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -17,9 +17,8 @@ limitations under the License. #include -#include -#include #include +#include #include #include @@ -59,8 +58,7 @@ struct ArenaAllocWithUsageInterval { class ResizableAlignedBuffer { public: explicit ResizableAlignedBuffer(size_t alignment, int subgraph_index) - : buffer_(nullptr), - allocation_size_(0), + : allocation_size_(0), alignment_(alignment), subgraph_index_(subgraph_index) { // To silence unused private member warning, only used with @@ -68,8 +66,6 @@ class ResizableAlignedBuffer { (void)subgraph_index_; } - ~ResizableAlignedBuffer() { Release(); } - // Resizes the buffer to make sure new_size bytes fit in the buffer. Keeps // alignment and any existing the data. Returns true when any external // pointers into the data array need to be adjusted (the buffer was moved). @@ -86,12 +82,10 @@ class ResizableAlignedBuffer { private: size_t RequiredAllocationSize(size_t data_array_size) const { - // malloc guarantees returned pointers are aligned to at least max_align_t. - return data_array_size + - std::max(std::size_t{0}, alignment_ - alignof(std::max_align_t)); + return data_array_size + alignment_ - 1; } - char* buffer_; + std::unique_ptr buffer_; size_t allocation_size_; size_t alignment_; char* aligned_ptr_; From 9e33fb3d6786f2f14d89576551bfb425d1298476 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 14 Nov 2023 05:01:01 -0800 Subject: [PATCH 062/391] Internal visibility change only PiperOrigin-RevId: 582281437 --- tensorflow/core/util/autotune_maps/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/util/autotune_maps/BUILD b/tensorflow/core/util/autotune_maps/BUILD index 94ccd691946c78..f4f13211ab2f8e 100644 --- a/tensorflow/core/util/autotune_maps/BUILD +++ b/tensorflow/core/util/autotune_maps/BUILD @@ -107,7 +107,10 @@ tf_proto_library( "//tensorflow/core/util/autotune_maps:conv_parameters_proto", "@local_tsl//tsl/protobuf:dnn_proto", ], - visibility = ["//waymo/ml/deploy/system/autotuning:__subpackages__"], + visibility = [ + "//waymo/ml/deploy/benchmark:__subpackages__", + "//waymo/ml/deploy/system/autotuning:__subpackages__", + ], ) # copybara:uncomment_begin(google-only) From 1da546ba04fdfab6748737a5759db3a16c268169 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Tue, 14 Nov 2023 05:38:01 -0800 Subject: [PATCH 063/391] Refactor `SimpleMemoryArena`, extract code that deals with the resizable aligned buffer to a separate class. PiperOrigin-RevId: 582289220 --- tensorflow/lite/simple_memory_arena.cc | 118 +++++++++++-------------- tensorflow/lite/simple_memory_arena.h | 62 ++++--------- 2 files changed, 71 insertions(+), 109 deletions(-) diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index f0b5f281985539..694215115297bc 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -24,12 +24,10 @@ limitations under the License. #include #include #include -#include #include #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/macros.h" - #ifdef TF_LITE_TENSORFLOW_PROFILER #include "tensorflow/lite/tensorflow_profiler_logger.h" #endif // TF_LITE_TENSORFLOW_PROFILER @@ -46,56 +44,6 @@ T AlignTo(size_t alignment, T offset) { namespace tflite { -bool ResizableAlignedBuffer::Resize(size_t new_size) { - const size_t new_allocation_size = RequiredAllocationSize(new_size); - if (new_allocation_size <= allocation_size_) { - // Skip reallocation when resizing down. - return false; - } -#ifdef TF_LITE_TENSORFLOW_PROFILER - PauseHeapMonitoring(/*pause=*/true); - OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), - new_allocation_size); -#endif - auto new_buffer = std::make_unique(new_allocation_size); - char* new_aligned_ptr = reinterpret_cast( - AlignTo(alignment_, reinterpret_cast(new_buffer.get()))); - if (new_size > 0 && allocation_size_ > 0) { - // Copy data when both old and new buffers are bigger than 0 bytes. 
- const size_t new_alloc_alignment_adjustment = - new_aligned_ptr - new_buffer.get(); - const size_t old_alloc_alignment_adjustment = aligned_ptr_ - buffer_.get(); - const size_t copy_amount = - std::min(allocation_size_ - old_alloc_alignment_adjustment, - new_allocation_size - new_alloc_alignment_adjustment); - memcpy(new_aligned_ptr, aligned_ptr_, copy_amount); - } -#ifdef TF_LITE_TENSORFLOW_PROFILER - if (allocation_size_ > 0) { - OnTfLiteArenaDealloc(subgraph_index_, - reinterpret_cast(this), - allocation_size_); - } -#endif - buffer_ = std::move(new_buffer); - allocation_size_ = new_allocation_size; - aligned_ptr_ = new_aligned_ptr; -#ifdef TF_LITE_TENSORFLOW_PROFILER - PauseHeapMonitoring(/*pause=*/false); -#endif - return true; -} - -void ResizableAlignedBuffer::Release() { -#ifdef TF_LITE_TENSORFLOW_PROFILER - OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), - allocation_size_); -#endif - buffer_.reset(); - allocation_size_ = 0; - aligned_ptr_ = nullptr; -} - void SimpleMemoryArena::PurgeAfter(int32_t node) { for (int i = 0; i < active_allocs_.size(); ++i) { if (active_allocs_[i].first_node > node) { @@ -143,7 +91,7 @@ TfLiteStatus SimpleMemoryArena::Allocate( TfLiteContext* context, size_t alignment, size_t size, int32_t tensor, int32_t first_node, int32_t last_node, ArenaAllocWithUsageInterval* new_alloc) { - TF_LITE_ENSURE(context, alignment <= underlying_buffer_.GetAlignment()); + TF_LITE_ENSURE(context, alignment <= arena_alignment_); new_alloc->tensor = tensor; new_alloc->first_node = first_node; new_alloc->last_node = last_node; @@ -194,12 +142,48 @@ TfLiteStatus SimpleMemoryArena::Allocate( } TfLiteStatus SimpleMemoryArena::Commit(bool* arena_reallocated) { - // Resize the arena to the high water mark (calculated by Allocate), retaining - // old contents and alignment in the process. Since Alloc pointers are offset - // based, they will remain valid in the new memory block. - *arena_reallocated = underlying_buffer_.Resize(high_water_mark_); + size_t required_size = RequiredBufferSize(); + if (required_size > underlying_buffer_size_) { + *arena_reallocated = true; +#ifdef TF_LITE_TENSORFLOW_PROFILER + PauseHeapMonitoring(/*pause=*/true); + OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), + required_size); +#endif + char* new_alloc = new char[required_size]; + char* new_underlying_buffer_aligned_ptr = reinterpret_cast( + AlignTo(arena_alignment_, reinterpret_cast(new_alloc))); + + // If the arena had been previously allocated, copy over the old memory. + // Since Alloc pointers are offset based, they will remain valid in the new + // memory block. + if (high_water_mark_ > 0 && underlying_buffer_size_ > 0) { + size_t copy_amount = std::min( + underlying_buffer_.get() + underlying_buffer_size_ - + underlying_buffer_aligned_ptr_, + new_alloc + required_size - new_underlying_buffer_aligned_ptr); + memcpy(new_underlying_buffer_aligned_ptr, underlying_buffer_aligned_ptr_, + copy_amount); + } + +#ifdef TF_LITE_TENSORFLOW_PROFILER + if (underlying_buffer_size_ > 0) { + OnTfLiteArenaDealloc(subgraph_index_, + reinterpret_cast(this), + underlying_buffer_size_); + } +#endif + underlying_buffer_.reset(new_alloc); + underlying_buffer_size_ = required_size; + underlying_buffer_aligned_ptr_ = new_underlying_buffer_aligned_ptr; +#ifdef TF_LITE_TENSORFLOW_PROFILER + PauseHeapMonitoring(/*pause=*/false); +#endif + } else { + *arena_reallocated = false; + } committed_ = true; - return kTfLiteOk; + return underlying_buffer_ != nullptr ? 
kTfLiteOk : kTfLiteError; } TfLiteStatus SimpleMemoryArena::ResolveAlloc( @@ -207,12 +191,12 @@ TfLiteStatus SimpleMemoryArena::ResolveAlloc( char** output_ptr) { TF_LITE_ENSURE(context, committed_); TF_LITE_ENSURE(context, output_ptr != nullptr); - TF_LITE_ENSURE(context, underlying_buffer_.GetAllocationSize() >= - (alloc.offset + alloc.size)); + TF_LITE_ENSURE(context, + underlying_buffer_size_ >= (alloc.offset + alloc.size)); if (alloc.size == 0) { *output_ptr = nullptr; } else { - *output_ptr = underlying_buffer_.GetPtr() + alloc.offset; + *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; } return kTfLiteOk; } @@ -226,7 +210,13 @@ TfLiteStatus SimpleMemoryArena::ClearPlan() { TfLiteStatus SimpleMemoryArena::ReleaseBuffer() { committed_ = false; - underlying_buffer_.Release(); +#ifdef TF_LITE_TENSORFLOW_PROFILER + OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), + underlying_buffer_size_); +#endif + underlying_buffer_size_ = 0; + underlying_buffer_aligned_ptr_ = nullptr; + underlying_buffer_.reset(); return kTfLiteOk; } @@ -238,8 +228,8 @@ TFLITE_ATTRIBUTE_WEAK void DumpArenaInfo( void SimpleMemoryArena::DumpDebugInfo( const std::string& name, const std::vector& execution_plan) const { - tflite::DumpArenaInfo(name, execution_plan, - underlying_buffer_.GetAllocationSize(), active_allocs_); + tflite::DumpArenaInfo(name, execution_plan, underlying_buffer_size_, + active_allocs_); } } // namespace tflite diff --git a/tensorflow/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h index 05bb52e6a225e4..0e527df9ac98b1 100644 --- a/tensorflow/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -55,44 +55,6 @@ struct ArenaAllocWithUsageInterval { } }; -class ResizableAlignedBuffer { - public: - explicit ResizableAlignedBuffer(size_t alignment, int subgraph_index) - : allocation_size_(0), - alignment_(alignment), - subgraph_index_(subgraph_index) { - // To silence unused private member warning, only used with - // TF_LITE_TENSORFLOW_PROFILER - (void)subgraph_index_; - } - - // Resizes the buffer to make sure new_size bytes fit in the buffer. Keeps - // alignment and any existing the data. Returns true when any external - // pointers into the data array need to be adjusted (the buffer was moved). - bool Resize(size_t new_size); - // Releases any allocated memory. - void Release(); - - // Pointer to the data array. - char* GetPtr() const { return aligned_ptr_; } - // Size of the allocation (NOT of the data array). - size_t GetAllocationSize() const { return allocation_size_; } - // Alignment of the data array. - size_t GetAlignment() const { return alignment_; } - - private: - size_t RequiredAllocationSize(size_t data_array_size) const { - return data_array_size + alignment_ - 1; - } - - std::unique_ptr buffer_; - size_t allocation_size_; - size_t alignment_; - char* aligned_ptr_; - - int subgraph_index_; -}; - // This small class is responsible for allocating, deallocating and reusing // dynamic memory from a common underlying buffer. 
The arena can be used in // scenarios when the pattern of memory allocations and deallocations is @@ -101,9 +63,11 @@ class ResizableAlignedBuffer { class SimpleMemoryArena { public: explicit SimpleMemoryArena(size_t arena_alignment, int subgraph_index = 0) - : committed_(false), + : subgraph_index_(subgraph_index), + committed_(false), + arena_alignment_(arena_alignment), high_water_mark_(0), - underlying_buffer_(arena_alignment, subgraph_index), + underlying_buffer_size_(0), active_allocs_() {} // Delete all allocs. This should be called when allocating the first node of @@ -135,6 +99,10 @@ class SimpleMemoryArena { int32_t tensor, int32_t first_node, int32_t last_node, ArenaAllocWithUsageInterval* new_alloc); + inline size_t RequiredBufferSize() { + return high_water_mark_ + arena_alignment_ - 1; + } + TfLiteStatus Commit(bool* arena_reallocated); TfLiteStatus ResolveAlloc(TfLiteContext* context, @@ -151,12 +119,10 @@ class SimpleMemoryArena { // again until Commit() is called & tensor allocations are resolved. TfLiteStatus ReleaseBuffer(); - size_t GetBufferSize() const { - return underlying_buffer_.GetAllocationSize(); - } + size_t GetBufferSize() const { return underlying_buffer_size_; } std::intptr_t BasePointer() const { - return reinterpret_cast(underlying_buffer_.GetPtr()); + return reinterpret_cast(underlying_buffer_aligned_ptr_); } // Dumps the memory allocation information of this memory arena (which could @@ -176,10 +142,16 @@ class SimpleMemoryArena { void DumpDebugInfo(const std::string& name, const std::vector& execution_plan) const; + protected: + int subgraph_index_; + private: bool committed_; + size_t arena_alignment_; size_t high_water_mark_; - ResizableAlignedBuffer underlying_buffer_; + std::unique_ptr underlying_buffer_; + size_t underlying_buffer_size_; + char* underlying_buffer_aligned_ptr_; std::vector active_allocs_; }; From 4a700bfdee45d4dc18479752be7fa190fa95da18 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Tue, 14 Nov 2023 05:46:50 -0800 Subject: [PATCH 064/391] [XLA] Serialize replica_count and num_partitions module properties to IR Serializing them is in line with all other module properties which affect compilation (aliasing, layout, etc.), and not serializing creates an impure compilation environment where IR does not and can not capture semantics of the module. 
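As an illustration of the intended round-trip (the module text below is a made-up example; only the attribute names and config accessors come from this change), a header carrying non-default values should now parse back into the module config and print out again:

const char* const kIllustrativeHlo = R"(
HloModule illustrative_module, replica_count=4, num_partitions=2

ENTRY main {
  p0 = f32[8,4] parameter(0)
  ROOT r = f32[8,4] copy(p0)
}
)";
// Parsing populates the config; a num_partitions other than 1 also turns on
// use_spmd_partitioning, per the parser change below.
TF_ASSERT_OK_AND_ASSIGN(auto module,
                        ParseAndReturnVerifiedModule(kIllustrativeHlo));
EXPECT_EQ(module->config().replica_count(), 4);
EXPECT_EQ(module->config().num_partitions(), 2);
EXPECT_TRUE(module->config().use_spmd_partitioning());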
PiperOrigin-RevId: 582290809 --- third_party/xla/xla/hlo/ir/hlo_module.cc | 8 +++ third_party/xla/xla/service/hlo_parser.cc | 14 +++++ .../xla/xla/service/hlo_parser_test.cc | 51 +++++++++++++++---- 3 files changed, 63 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc index dd9570f874653a..d505039c06c65c 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module.cc @@ -379,6 +379,14 @@ void HloModule::Print(Printer* printer, const HloPrintOptions& options) const { }); printer->Append("}"); } + if (config.replica_count() != 1) { + printer->Append(", replica_count="); + printer->Append(config.replica_count()); + } + if (config.num_partitions() != 1) { + printer->Append(", num_partitions="); + printer->Append(config.num_partitions()); + } if (!frontend_attributes_.map().empty()) { AppendCat(printer, ", frontend_attributes=", FrontendAttributesToString(frontend_attributes_)); diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc index 132d14aa876ff5..d7105c6ddb5be1 100644 --- a/third_party/xla/xla/service/hlo_parser.cc +++ b/third_party/xla/xla/service/hlo_parser.cc @@ -1006,6 +1006,8 @@ bool HloParserImpl::ParseHloModule(HloModule* module, bool parse_module_without_header) { std::string name; std::optional is_scheduled; + std::optional replica_count; + std::optional num_partitions; std::optional aliasing_data; std::optional buffer_donor_data; std::optional alias_passthrough_params; @@ -1015,6 +1017,9 @@ bool HloParserImpl::ParseHloModule(HloModule* module, BoolList allow_spmd_sharding_propagation_to_output; attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled}; + attrs["replica_count"] = {/*required=*/false, AttrTy::kInt64, &replica_count}; + attrs["num_partitions"] = {/*required=*/false, AttrTy::kInt64, + &num_partitions}; attrs["input_output_alias"] = {/*required=*/false, AttrTy::kAliasing, &aliasing_data}; attrs["buffer_donor"] = {/*required=*/false, AttrTy::kBufferDonor, @@ -1068,6 +1073,15 @@ bool HloParserImpl::ParseHloModule(HloModule* module, config.set_alias_passthrough_params(true); default_config = false; } + if (num_partitions.value_or(1) != 1) { + config.set_num_partitions(*num_partitions); + config.set_use_spmd_partitioning(true); + default_config = false; + } + if (replica_count.value_or(1) != 1) { + config.set_replica_count(*replica_count); + default_config = false; + } if (entry_computation_layout.has_value()) { *config.mutable_entry_computation_layout() = *entry_computation_layout; default_config = false; diff --git a/third_party/xla/xla/service/hlo_parser_test.cc b/third_party/xla/xla/service/hlo_parser_test.cc index 4baf3be5bff223..3e069e3371d8e0 100644 --- a/third_party/xla/xla/service/hlo_parser_test.cc +++ b/third_party/xla/xla/service/hlo_parser_test.cc @@ -1816,7 +1816,7 @@ ENTRY CRS { // all-reduce with subgroups { "AllReduceWithSubgroups", -R"(HloModule CRS_Subgroups, entry_computation_layout={(f32[128,32]{0,1})->f32[128,32]{0,1}} +R"(HloModule CRS_Subgroups, entry_computation_layout={(f32[128,32]{0,1})->f32[128,32]{0,1}}, replica_count=4 add { lhs = f32[] parameter(0) @@ -1933,7 +1933,7 @@ ENTRY AllGather { // all-gather with subgroups { "AllGatherWithSubgroups", -R"(HloModule AllGatherWithSubgroups, entry_computation_layout={(f32[128,32]{0,1})->f32[128,64]{0,1}} +R"(HloModule AllGatherWithSubgroups, entry_computation_layout={(f32[128,32]{0,1})->f32[128,64]{0,1}}, replica_count=4 ENTRY 
AllGatherWithSubgroups { input = f32[128,32]{0,1} parameter(0) @@ -1958,7 +1958,7 @@ ENTRY AllToAll { // all-to-all with subgroups { "AllToAllWithSubgroups", -R"(HloModule AllToAllWithSubgroups, entry_computation_layout={(f32[128,32]{0,1}, f32[128,32]{0,1})->(f32[128,32]{0,1}, f32[128,32]{0,1})} +R"(HloModule AllToAllWithSubgroups, entry_computation_layout={(f32[128,32]{0,1}, f32[128,32]{0,1})->(f32[128,32]{0,1}, f32[128,32]{0,1})}, replica_count=4 ENTRY AllToAllWithSubgroups { p0 = f32[128,32]{0,1} parameter(0) @@ -1972,7 +1972,7 @@ ENTRY AllToAllWithSubgroups { // collective-permute { "CollectivePermute", -R"(HloModule CollectivePermute, entry_computation_layout={(f32[128,32]{0,1})->f32[128,32]{0,1}} +R"(HloModule CollectivePermute, entry_computation_layout={(f32[128,32]{0,1})->f32[128,32]{0,1}}, replica_count=4 ENTRY CollectivePermute { input = f32[128,32]{0,1} parameter(0) @@ -1985,7 +1985,7 @@ ENTRY CollectivePermute { // collective-permute with in-place updates { "CollectivePermuteInPlaceUpdate", -R"(HloModule CollectivePermuteInPlaceUpdate, entry_computation_layout={(f32[128,32]{0,1})->f32[128,128]{0,1}} +R"(HloModule CollectivePermuteInPlaceUpdate, entry_computation_layout={(f32[128,32]{0,1})->f32[128,128]{0,1}}, replica_count=4 ENTRY CollectivePermuteInPlaceUpdate { input = f32[128,32]{0,1} parameter(0) @@ -2004,7 +2004,7 @@ ENTRY CollectivePermuteInPlaceUpdate { // collective-permute with in-place updates with multiple targets per source { "CollectivePermuteInPlaceUpdateMultipleReadWrite", -R"(HloModule CollectivePermuteInPlaceUpdateMultipleReadWrite, entry_computation_layout={(f32[8,8,128]{2,1,0})->f32[8,8,128]{2,1,0}} +R"(HloModule CollectivePermuteInPlaceUpdateMultipleReadWrite, entry_computation_layout={(f32[8,8,128]{2,1,0})->f32[8,8,128]{2,1,0}}, replica_count=4 ENTRY CollectivePermuteInPlaceUpdate { constant.3 = s32[] constant(2) @@ -2028,7 +2028,7 @@ ENTRY CollectivePermuteInPlaceUpdate { }, { "CollectivePermuteInPlaceUpdateTupleMultipleReadWrite", -R"(HloModule hlo_runner_test_0.1, entry_computation_layout={()->(u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)})} +R"(HloModule hlo_runner_test_0.1, entry_computation_layout={()->(u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)})}, replica_count=4 ENTRY hlo_runner_test_0.1 { replica_id = u32[] replica-id() @@ -2059,7 +2059,7 @@ ENTRY hlo_runner_test_0.1 { // collective-permute tuple with in-place updates { "CollectivePermuteTupleInPlaceUpdate", -R"(HloModule CollectivePermuteTupleInPlaceUpdate, entry_computation_layout={(f32[128,32]{0,1})->(f32[128,128]{0,1}, f32[128,128]{0,1})} +R"(HloModule CollectivePermuteTupleInPlaceUpdate, entry_computation_layout={(f32[128,32]{0,1})->(f32[128,128]{0,1}, f32[128,128]{0,1})}, replica_count=4 ENTRY CollectivePermuteInPlaceUpdate { input = f32[128,32]{0,1} parameter(0) @@ -2084,7 +2084,7 @@ ENTRY CollectivePermuteInPlaceUpdate { // collective-permute-start and -done with inplace update { "CollectivePermuteStartAndDone", -R"(HloModule CollectivePermuteStartAndDone, entry_computation_layout={(f32[128,32]{0,1})->f32[128,32]{0,1}} +R"(HloModule CollectivePermuteStartAndDone, entry_computation_layout={(f32[128,32]{0,1})->f32[128,32]{0,1}}, replica_count=4 ENTRY CollectivePermuteStartAndDone { input = f32[128,32]{0,1} parameter(0) @@ -2098,7 +2098,7 @@ ENTRY CollectivePermuteStartAndDone { // collective-permute-start and -done { "CollectivePermuteStartAndDoneInplaceUpdate", -R"(HloModule CollectivePermuteStartAndDoneInplaceUpdate, 
entry_computation_layout={(f32[128,32]{0,1})->f32[128,128]{0,1}} +R"(HloModule CollectivePermuteStartAndDoneInplaceUpdate, entry_computation_layout={(f32[128,32]{0,1})->f32[128,128]{0,1}}, replica_count=4 ENTRY CollectivePermuteStartAndDoneInplaceUpdate { input = f32[128,32]{0,1} parameter(0) @@ -4487,6 +4487,37 @@ ENTRY TestComputation { EXPECT_TRUE(result.value()->config().alias_passthrough_params()); } +TEST_F(HloParserTest, CheckReplicaCount) { + const char* const hlo_string = R"( +HloModule TestModule, replica_count=5 + +ENTRY TestComputation { + p0 = f16[2048,1024] parameter(0) + p1 = f16[2048,1024] parameter(1) + ROOT root = (f16[2048,1024], f16[2048,1024]) tuple(p0, p1) +} +)"; + auto result = ParseAndReturnVerifiedModule(hlo_string); + TF_EXPECT_OK(result.status()); + EXPECT_EQ(result.value()->config().replica_count(), 5); +} + +TEST_F(HloParserTest, CheckNumPartitions) { + const char* const hlo_string = R"( +HloModule TestModule, num_partitions=3 + +ENTRY TestComputation { + p0 = f16[2048,1024] parameter(0) + p1 = f16[2048,1024] parameter(1) + ROOT root = (f16[2048,1024], f16[2048,1024]) tuple(p0, p1) +} +)"; + auto result = ParseAndReturnVerifiedModule(hlo_string); + TF_EXPECT_OK(result.status()); + EXPECT_EQ(result.value()->config().num_partitions(), 3); + EXPECT_TRUE(result.value()->config().use_spmd_partitioning()); +} + TEST_F(HloParserTest, CheckFrontendAttributes) { const char* const hlo_string = R"( HloModule TestModule, frontend_attributes={attr_name="attr_value"} From f8bd1ae461266720087224fdfa009b6aaf79803b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 05:53:17 -0800 Subject: [PATCH 065/391] [XLA:GPU] Loosen up expectations on Int8 gemm test. We do not need to check the backend config field. PiperOrigin-RevId: 582292178 --- .../xla/xla/service/gpu/tests/gemm_rewrite_test.cc | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc index a05f411fec5f31..11f917376ebabf 100644 --- a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -935,17 +935,6 @@ ENTRY main.4 { MatchOptimizedHlo(hlo_text, R"( ; CHECK: [[GEMM:%[^ ]+]] = (s32[8,4]{1,0}, s8[{{[0-9]+}}]{0}) custom-call(s8[8,4]{1,0} %fusion.1, s8[4,4]{0,1} %bitcast.13), custom_call_target="__cublas$gemm", -; CHECK: backend_config={ -; CHECK-DAG: "selected_algorithm":"0" -; CHECK-DAG: "alpha_real":1 -; CHECK-DAG: "alpha_imag":0 -; CHECK-DAG: "beta":0 -; CHECK-DAG: "dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]} -; CHECK-DAG: "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]} -; CHECK-DAG: "epilogue":"DEFAULT" -; CHECK: } -; CHECK: [[RES:%[^ ]+]] = s32[8,4]{1,0} get-tuple-element((s32[8,4]{1,0}, s8[{{[0-9]+}}]{0}) [[GEMM]]), index=0 -; CHECK: ROOT [[OUT:%[^ ]+]] = s32[1,8,4]{2,1,0} bitcast(s32[8,4]{1,0} [[RES]]) )", /*print_operand_shape=*/true); } From 7848fa9d42576c0975729e5e18a8416469044b43 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 14 Nov 2023 06:49:09 -0800 Subject: [PATCH 066/391] Error out if ptxas version < 11.8 PiperOrigin-RevId: 582305685 --- third_party/xla/xla/service/gpu/nvptx_compiler.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index 
c98f64520c51ad..6354250dddf602 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -670,6 +670,14 @@ StatusOr NVPTXCompiler::ChooseLinkingMethod( TF_ASSIGN_OR_RETURN(auto ptxas_version_tuple, se::GetAsmCompilerVersion(preferred_cuda_dir)); + // ptxas versions prior to 11.8 are not supported anymore. We check this here, + // since we are fetching the ptxas version anyway. Catching the error + // elsewhere might introduce unnecessary overhead. + if (ptxas_version_tuple < std::array{11, 8, 0}) { + return Status(absl::StatusCode::kInternal, + "XLA requires ptxas version 11.8 or higher"); + } + static const std::optional> nvlink_version = GetNvLinkVersion(preferred_cuda_dir); if (nvlink_version && *nvlink_version >= ptxas_version_tuple) { From 1fb7ae62ccc43b326a9609ead88a35b9db30a69c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Tue, 14 Nov 2023 07:32:29 -0800 Subject: [PATCH 067/391] Set HIGHEST precision for all dots and convolutions when TF32 execution is disabled Some operations, such as Einsum are converted through MlirXlaOpKernel, which doesn't set the precisions, so it caused precision problems recently. I added some testcases as well. PiperOrigin-RevId: 582315919 --- tensorflow/compiler/tf2xla/BUILD | 9 +- tensorflow/compiler/tf2xla/tf2xla_test.cc | 227 +++++++++++++++++++++ tensorflow/compiler/tf2xla/xla_compiler.cc | 41 ++++ 3 files changed, 275 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index c8dc6721853ccb..5ae1c907f0138d 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -14,6 +14,7 @@ load( ) load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -506,9 +507,11 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", + "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_xla//xla:executable_run_options", "@local_xla//xla:protobuf_util", "@local_xla//xla:shape_util", @@ -522,10 +525,11 @@ cc_library( "@local_xla//xla/client:xla_computation", "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/service:computation_placer_hdr", + "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/translate/mhlo_to_hlo:layout_util", ] + if_libtpu([ ":xla_tpu_backend_registration", - ]), + ]) + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), alwayslink = 1, ) @@ -901,6 +905,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_xla//xla:literal", "@local_xla//xla:literal_util", "@local_xla//xla:statusor", @@ -908,7 +913,7 @@ tf_cc_test( "@local_xla//xla/client:local_client", "@local_xla//xla/client:xla_computation", "@local_xla//xla/service:cpu_plugin", - ], + ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), ) tf_cc_test( diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 1336d58521404a..01bb69d16ee264 100644 --- 
a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.h" +#include + #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "xla/client/client_library.h" #include "xla/client/local_client.h" @@ -25,20 +27,61 @@ limitations under the License. #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/test.h" +#include "tsl/platform/tensor_float_32_utils.h" namespace tensorflow { namespace { +class ConvertGraphDefToXlaWithTF32Disabled : public ::testing::Test { + public: + ConvertGraphDefToXlaWithTF32Disabled() { + tsl::enable_tensor_float_32_execution(false); + } + ~ConvertGraphDefToXlaWithTF32Disabled() override { + tsl::enable_tensor_float_32_execution(true); + } +}; + AttrValue TypeAttrValue(DataType type) { AttrValue attr_value; SetAttrValue(type, &attr_value); return attr_value; } +AttrValue StringAttrValue(StringPiece str) { + AttrValue attr_value; + SetAttrValue(str, &attr_value); + return attr_value; +} + +AttrValue IntAttrValue(int i) { + AttrValue attr_value; + SetAttrValue(i, &attr_value); + return attr_value; +} + +AttrValue IntVectorAttrValue(const std::vector& ints) { + AttrValue attr_value; + SetAttrValue(ints, &attr_value); + return attr_value; +} + +TensorShapeProto TensorShape(const std::vector& dims) { + TensorShapeProto shape; + for (int i = 0; i < dims.size(); ++i) { + shape.add_dim(); + shape.mutable_dim(i)->set_size(dims[i]); + } + return shape; +} + GraphDef SumGraph() { GraphDef graph_def; NodeDef* x = graph_def.add_node(); @@ -97,6 +140,190 @@ TEST(ConvertGraphDefToXla, Sum) { ConvertGraphDefToXla(graph_def, config, client, &computation))); } +GraphDef EinsumGraph() { + GraphDef graph_def; + NodeDef* x = graph_def.add_node(); + x->set_name("x"); + x->set_op("Placeholder"); + (*x->mutable_attr())["dtype"] = TypeAttrValue(DT_FLOAT); + NodeDef* y = graph_def.add_node(); + y->set_name("y"); + y->set_op("Placeholder"); + (*y->mutable_attr())["dtype"] = TypeAttrValue(DT_FLOAT); + NodeDef* einsum = graph_def.add_node(); + einsum->set_name("einsum"); + einsum->set_op("Einsum"); + einsum->add_input("x"); + einsum->add_input("y"); + (*einsum->mutable_attr())["equation"] = StringAttrValue("ij,jk->ik"); + (*einsum->mutable_attr())["T"] = TypeAttrValue(DT_FLOAT); + (*einsum->mutable_attr())["N"] = IntAttrValue(2); + return graph_def; +} + +tf2xla::Config EinsumConfig() { + tf2xla::Config config; + + tf2xla::Feed* x_feed = config.add_feed(); + x_feed->mutable_id()->set_node_name("x"); + *x_feed->mutable_shape() = TensorShape({2, 2}); + + tf2xla::Feed* y_feed = config.add_feed(); + y_feed->mutable_id()->set_node_name("y"); + *y_feed->mutable_shape() = TensorShape({2, 2}); + + config.add_fetch()->mutable_id()->set_node_name("einsum"); + return config; +} + +TEST(ConvertGraphDefToXla, EinsumIsConvertedToDotWithDefaultPrecision) { + GraphDef graph_def = EinsumGraph(); + tf2xla::Config config = EinsumConfig(); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + xla::XlaComputation computation; + 
TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); + + int num_dots = 0; + const xla::HloModuleProto& module_proto = computation.proto(); + for (const xla::HloComputationProto& computation_proto : + module_proto.computations()) { + for (const xla::HloInstructionProto& instruction_proto : + computation_proto.instructions()) { + if (instruction_proto.opcode() == "dot") { + num_dots++; + ASSERT_EQ(instruction_proto.precision_config().operand_precision_size(), + 2); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(0), + xla::PrecisionConfig::DEFAULT); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(1), + xla::PrecisionConfig::DEFAULT); + } + } + } + EXPECT_EQ(num_dots, 1); +} + +TEST_F(ConvertGraphDefToXlaWithTF32Disabled, + EinsumIsConvertedToDotWithHighestPrecision) { + GraphDef graph_def = EinsumGraph(); + tf2xla::Config config = EinsumConfig(); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + xla::XlaComputation computation; + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); + + int num_dots = 0; + const xla::HloModuleProto& module_proto = computation.proto(); + for (const xla::HloComputationProto& computation_proto : + module_proto.computations()) { + for (const xla::HloInstructionProto& instruction_proto : + computation_proto.instructions()) { + if (instruction_proto.opcode() == "dot") { + num_dots++; + ASSERT_EQ(instruction_proto.precision_config().operand_precision_size(), + 2); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(0), + xla::PrecisionConfig::HIGHEST); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(1), + xla::PrecisionConfig::HIGHEST); + } + } + } + EXPECT_EQ(num_dots, 1); +} + +GraphDef Conv2DGraph() { + GraphDef graph_def; + NodeDef* x = graph_def.add_node(); + x->set_name("x"); + x->set_op("Placeholder"); + (*x->mutable_attr())["dtype"] = TypeAttrValue(DT_FLOAT); + NodeDef* y = graph_def.add_node(); + y->set_name("y"); + y->set_op("Placeholder"); + (*y->mutable_attr())["dtype"] = TypeAttrValue(DT_FLOAT); + NodeDef* einsum = graph_def.add_node(); + einsum->set_name("conv2d"); + einsum->set_op("Conv2D"); + einsum->add_input("x"); + einsum->add_input("y"); + (*einsum->mutable_attr())["T"] = TypeAttrValue(DT_FLOAT); + (*einsum->mutable_attr())["padding"] = StringAttrValue("VALID"); + (*einsum->mutable_attr())["strides"] = IntVectorAttrValue({1, 1, 1, 1}); + return graph_def; +} + +tf2xla::Config Conv2DConfig() { + tf2xla::Config config; + tf2xla::Feed* x_feed = config.add_feed(); + x_feed->mutable_id()->set_node_name("x"); + *x_feed->mutable_shape() = TensorShape({1, 1, 2, 2}); + + tf2xla::Feed* y_feed = config.add_feed(); + y_feed->mutable_id()->set_node_name("y"); + *y_feed->mutable_shape() = TensorShape({1, 1, 2, 2}); + config.add_fetch()->mutable_id()->set_node_name("conv2d"); + return config; +} + +TEST(ConvertGraphDefToXla, Conv2DIsConvertedToConvolutionWithDefaultPrecision) { + GraphDef graph_def = Conv2DGraph(); + tf2xla::Config config = Conv2DConfig(); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + xla::XlaComputation computation; + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); + + int num_convolutions = 0; + const xla::HloModuleProto& module_proto = computation.proto(); + for (const xla::HloComputationProto& computation_proto : + module_proto.computations()) { + for (const xla::HloInstructionProto& instruction_proto : + computation_proto.instructions()) { + 
if (instruction_proto.opcode() == "convolution") { + num_convolutions++; + ASSERT_EQ(instruction_proto.precision_config().operand_precision_size(), + 2); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(0), + xla::PrecisionConfig::DEFAULT); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(1), + xla::PrecisionConfig::DEFAULT); + } + } + } + EXPECT_EQ(num_convolutions, 1); +} + +TEST_F(ConvertGraphDefToXlaWithTF32Disabled, + Conv2DIsConvertedToConvolutionWithHighestPrecision) { + GraphDef graph_def = Conv2DGraph(); + tf2xla::Config config = Conv2DConfig(); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + xla::XlaComputation computation; + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); + + int num_convolutions = 0; + const xla::HloModuleProto& module_proto = computation.proto(); + for (const xla::HloComputationProto& computation_proto : + module_proto.computations()) { + for (const xla::HloInstructionProto& instruction_proto : + computation_proto.instructions()) { + if (instruction_proto.opcode() == "convolution") { + num_convolutions++; + ASSERT_EQ(instruction_proto.precision_config().operand_precision_size(), + 2); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(0), + xla::PrecisionConfig::HIGHEST); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(1), + xla::PrecisionConfig::HIGHEST); + } + } + } + EXPECT_EQ(num_convolutions, 1); +} + TEST(ConvertGraphDefToXla, SumWithUnusedArgument) { GraphDef graph_def = SumGraph(); tf2xla::Config config = SumConfig(); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index aa2c761ccb6e26..bb8b29de5b9acf 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include +#include #include #include #include @@ -27,9 +28,11 @@ limitations under the License. #include #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" @@ -52,6 +55,7 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" #include "xla/protobuf_util.h" +#include "xla/service/hlo.pb.h" #include "xla/shape_util.h" #include "xla/util.h" #include "tensorflow/core/common_runtime/device.h" @@ -72,6 +76,7 @@ limitations under the License. #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/util/dump_graph.h" +#include "tsl/platform/tensor_float_32_utils.h" namespace tensorflow { namespace { @@ -1435,6 +1440,38 @@ class DummyStackTrace : public AbstractStackTrace { StackFrame({"dummy_file_name", 10, "dummy_function_name"})}; }; +namespace { + +// Add precisions configs to the HLO module to avoid TensorFloat32 computations +// in XLA. +// +// Some operations, such as Einsum are converted through MlirXlaOpKernel, which +// doesn't set the precisions, so we set them all here. +// +// TODO(tdanyluk): We may want to restrict this logic to only set the operand +// precision for F32 operands. 
(Historically, it was set without regard to +// operand type in other parts of TF2XLA.) +void IncreasePrecisionsToAvoidTF32(xla::HloModuleProto& module) { + static constexpr std::array kOpsPossiblyUsingTF32 = { + "dot", "convolution"}; + + xla::PrecisionConfig precision_config; + precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST); + precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST); + + for (xla::HloComputationProto& computation : *module.mutable_computations()) { + for (xla::HloInstructionProto& instruction : + *computation.mutable_instructions()) { + if (absl::c_find(kOpsPossiblyUsingTF32, instruction.opcode()) != + kOpsPossiblyUsingTF32.end()) { + *instruction.mutable_precision_config() = precision_config; + } + } + } +} + +} // namespace + Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, @@ -1571,6 +1608,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, *result->host_compute_metadata.add_host_to_device() = recv; } + if (!tsl::tensor_float_32_execution_enabled()) { + IncreasePrecisionsToAvoidTF32(*result->computation->mutable_proto()); + } + VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; VLOG(2) << "XLA output shape: " From dae596cb908da354c27ded156d2083e26be81739 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 09:14:13 -0800 Subject: [PATCH 068/391] Fix test gpu_aot_compilation_test PiperOrigin-RevId: 582343268 --- .../xla/xla/service/gpu/gpu_aot_compilation_test.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc b/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc index 34149050ae20c0..60eb1453b40057 100644 --- a/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc @@ -45,6 +45,11 @@ ENTRY main { // Compile AOT. auto module_group = std::make_unique(std::move(module)); AotCompilationOptions aot_options(compiler.PlatformId()); + // ToDo: Remove after unification of AOT compiler + if (!aot_options.debug_options().xla_gpu_enable_xla_runtime_executable()) { + return; + } + aot_options.set_executor(stream_exec); TF_ASSERT_OK_AND_ASSIGN( std::vector> aot_results, @@ -85,6 +90,11 @@ ENTRY main { // Stream executor is not passed as an option. Compiler::TargetConfig gpu_target_config(stream_exec); AotCompilationOptions aot_options(compiler.PlatformId()); + // ToDo: Remove after unification of AOT compiler + if (!aot_options.debug_options().xla_gpu_enable_xla_runtime_executable()) { + return; + } + aot_options.set_target_config(gpu_target_config); TF_ASSERT_OK_AND_ASSIGN( From 4209e9cefaca57a0d708d770fb315ab6d361f148 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Tue, 14 Nov 2023 09:24:04 -0800 Subject: [PATCH 069/391] [XLA] [NFC] Unify graph rendering for fusion visualization and HTML visualization The HTML codepath has bitrotted, is not tested, and isn't currently working. Let's use the same approach as for fusion visualization, as it is working. 
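For context, a caller-side sketch of where this matters (the RenderGraph signature below is assumed from the surrounding API rather than taken from this change, so treat it as illustrative):

// Hypothetical usage: with kHtml, the returned page is now produced by the
// same fusion-explorer wrapper used by WrapFusionExplorer, instead of the
// old standalone HTML template that this change deletes.
StatusOr<std::string> html =
    RenderGraph(*module->entry_computation(), /*label=*/"entry",
                module->config().debug_options(), RenderedGraphFormat::kHtml);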
PiperOrigin-RevId: 582346207 --- .../xla/xla/service/hlo_graph_dumper.cc | 186 +++++------------- 1 file changed, 46 insertions(+), 140 deletions(-) diff --git a/third_party/xla/xla/service/hlo_graph_dumper.cc b/third_party/xla/xla/service/hlo_graph_dumper.cc index b367a27d25df8e..bcb761531d2a5f 100644 --- a/third_party/xla/xla/service/hlo_graph_dumper.cc +++ b/third_party/xla/xla/service/hlo_graph_dumper.cc @@ -1706,114 +1706,6 @@ NodeFilter MakeNodeFromToFilter(const HloInstruction* from, }); } -std::string WrapDotInHtml(absl::string_view dot) { - std::string html_prefix = - absl::StrReplaceAll(R"html( - - - - - - - - $JS_INCLUDE -
- - - -)html"; - - return absl::StrCat(html_prefix, dot, html_suffix); -} - absl::Mutex url_renderer_mu(absl::kConstInit); std::function(absl::string_view)>* url_renderer ABSL_GUARDED_BY(url_renderer_mu) = nullptr; @@ -1860,27 +1752,6 @@ static std::pair FusionVisualizerStateKey( computation.unique_id()); } -// Precondition: (url_renderer != nullptr || format != kUrl). -// -// (We specify this as a precondition rather than checking it in here and -// returning an error because we want to fail quickly when there's no URL -// renderer available, and this function runs only after we've done all the work -// of producing dot for the graph.) -StatusOr WrapDotInFormat(const HloComputation& computation, - absl::string_view dot, - RenderedGraphFormat format) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) { - switch (format) { - case RenderedGraphFormat::kUrl: - CHECK(url_renderer != nullptr) - << "Should have checked url_renderer != null before calling."; - return (*url_renderer)(dot); - case RenderedGraphFormat::kHtml: - return WrapDotInHtml(dot); - case RenderedGraphFormat::kDot: - return std::string(dot); - } -} } // namespace @@ -1926,13 +1797,9 @@ static std::string EscapeJSONString(absl::string_view raw) { "\""); } -StatusOr WrapFusionExplorer(const HloComputation& computation) { - absl::MutexLock lock(&fusion_visualizer_state_mu); - using absl::StrAppend; - using absl::StrFormat; - using absl::StrJoin; - const FusionVisualizerProgress& visualizer_progress = - fusion_visualizer_states[FusionVisualizerStateKey(computation)]; +StatusOr WrapFusionExplorer( + const FusionVisualizerProgress& visualizer_progress, + absl::string_view graph_title) { if (visualizer_progress.frames.empty()) { return InternalError("Empty"); } @@ -1954,7 +1821,7 @@ StatusOr WrapFusionExplorer(const HloComputation& computation) { CompressAndEncode(dot_graphs)); return absl::StrReplaceAll( - R"( + R"wrapper( @@ -2118,11 +1985,50 @@ StatusOr WrapFusionExplorer(const HloComputation& computation) { - )", + )wrapper", {{"$DOTS", dot_graphs_compressed}, {"$FRAMES", frames}, - {"$TITLE", - absl::StrCat(computation.parent()->name(), "_", computation.name())}}); + {"$TITLE", graph_title}}); +} + +static std::string GraphTitle(const HloComputation& computation) { + return absl::StrCat(computation.parent()->name(), "_", computation.name()); +} + +StatusOr WrapFusionExplorer(const HloComputation& computation) { + absl::MutexLock lock(&fusion_visualizer_state_mu); + const FusionVisualizerProgress& visualizer_progress = + fusion_visualizer_states[FusionVisualizerStateKey(computation)]; + return WrapFusionExplorer(visualizer_progress, GraphTitle(computation)); +} + +static StatusOr WrapDotInHtml(absl::string_view dot, + absl::string_view title) { + FusionVisualizerProgress progress; + progress.AddState(dot, title, std::nullopt); + return WrapFusionExplorer(progress, title); +} + +// Precondition: (url_renderer != nullptr || format != kUrl). +// +// (We specify this as a precondition rather than checking it in here and +// returning an error because we want to fail quickly when there's no URL +// renderer available, and this function runs only after we've done all the work +// of producing dot for the graph.) 
+static StatusOr WrapDotInFormat(const HloComputation& computation, + absl::string_view dot, + RenderedGraphFormat format) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) { + switch (format) { + case RenderedGraphFormat::kUrl: + CHECK(url_renderer != nullptr) + << "Should have checked url_renderer != null before calling."; + return (*url_renderer)(dot); + case RenderedGraphFormat::kHtml: + return WrapDotInHtml(dot, GraphTitle(computation)); + case RenderedGraphFormat::kDot: + return std::string(dot); + } } void RegisterGraphToURLRenderer( From 12c16f14d6f2406fbfb54dfa03c137c070d427a3 Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Tue, 14 Nov 2023 10:10:40 -0800 Subject: [PATCH 070/391] Redirect more references from the framework target to the new single-source-file targets. PiperOrigin-RevId: 582360353 --- tensorflow/lite/experimental/microfrontend/BUILD | 2 +- tensorflow/lite/python/BUILD | 6 +++--- tensorflow/lite/python/metrics/BUILD | 2 +- tensorflow/lite/python/testdata/BUILD | 2 +- tensorflow/python/framework/BUILD | 10 +++++++--- 5 files changed, 13 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/experimental/microfrontend/BUILD b/tensorflow/lite/experimental/microfrontend/BUILD index 1fb94ff67d2dea..e1c4f30baa7ffd 100644 --- a/tensorflow/lite/experimental/microfrontend/BUILD +++ b/tensorflow/lite/experimental/microfrontend/BUILD @@ -118,8 +118,8 @@ tf_custom_op_py_strict_library( srcs_version = "PY3", deps = [ ":audio_microfrontend_op", - "//tensorflow/python/framework", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:load_library", "//tensorflow/python/framework:ops", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:control_flow_ops", diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index 9f690c41cacbbe..1547947dc80fe1 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -158,9 +158,9 @@ py_strict_test( "//tensorflow/python:tf2", "//tensorflow/python/client:session", "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:importer", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", @@ -200,11 +200,11 @@ py_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:function", - "//tensorflow/python/framework", "//tensorflow/python/framework:byte_swap_tensor", "//tensorflow/python/framework:convert_to_constants", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:importer", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:versions", "//tensorflow/python/platform:gfile", @@ -363,9 +363,9 @@ py_strict_test( "//tensorflow/lite/python/testdata:double_op", "//tensorflow/python/client:session", "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:importer", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", diff --git a/tensorflow/lite/python/metrics/BUILD b/tensorflow/lite/python/metrics/BUILD index ba99624e18cef4..cc86ea7b46dc50 100644 --- a/tensorflow/lite/python/metrics/BUILD +++ b/tensorflow/lite/python/metrics/BUILD @@ 
-142,9 +142,9 @@ py_strict_test( "//tensorflow/python/client:session", "//tensorflow/python/eager:context", "//tensorflow/python/eager:monitoring", - "//tensorflow/python/framework", "//tensorflow/python/framework:convert_to_constants", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:importer", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", diff --git a/tensorflow/lite/python/testdata/BUILD b/tensorflow/lite/python/testdata/BUILD index 5faaea63bc77af..f05dec25d9cab6 100644 --- a/tensorflow/lite/python/testdata/BUILD +++ b/tensorflow/lite/python/testdata/BUILD @@ -148,8 +148,8 @@ tf_custom_op_py_strict_library( srcs_version = "PY3", deps = [ ":gen_double_op_wrapper", - "//tensorflow/python/framework", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:load_library", "//tensorflow/python/platform:resource_loader", ], ) diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index fcb92fb219d2d2..1217581e0cd647 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -24,7 +24,7 @@ load( "cuda_py_benchmark_test", ) -visibility = tf_python_framework_friends() +visibility = tf_python_framework_friends() # buildifier: disable=package-on-top package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -2004,7 +2004,9 @@ pytype_strict_library( "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", - ], + ] + if_xla_available([ + "//tensorflow/python:_pywrap_tfcompile", + ]), ) pytype_strict_library( @@ -2017,7 +2019,9 @@ pytype_strict_library( "//tensorflow/python/eager:context", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", - ], + ] + if_xla_available([ + "//tensorflow/python:_pywrap_tfcompile", + ]), ) py_strict_library( From f4efbd1151e09d5efd96450f30979e055953d098 Mon Sep 17 00:00:00 2001 From: pemeliya <141146080+pemeliya@users.noreply.github.com> Date: Tue, 14 Nov 2023 10:14:14 -0800 Subject: [PATCH 071/391] PR #6964: [ROCM] fixing build brakes 23-11-13 Imported from GitHub PR https://github.com/openxla/xla/pull/6964 Here is a fix for the oncoming build brakes due to recebt changes in GpuDriver API. Besides, I have also fixed the issue with headers in xla/service/gpu/ir_emitter_unnested.cc: otherwise, this would generate linker errors on ROCM platform when TF_HIPBLASLT=0 @xla-rotation: would you have a look, please ? Copybara import of the project: -- 14c2e300d827356ffc87c2e18a19f0514aa85e2a by Pavel Emeliyanenko : fixing buildbrakes -- 22b096266b0f65caf44317e03a63e9ff8d170de9 by Pavel Emeliyanenko : fixing buildifier warnings Merging this change closes #6964 PiperOrigin-RevId: 582361553 --- .../xla/xla/service/gpu/ir_emitter_unnested.cc | 4 ++-- third_party/xla/xla/service/gpu/kernels/BUILD | 8 +++++--- .../xla/xla/stream_executor/rocm/rocm_driver.cc | 12 ++++++++++++ .../xla/stream_executor/rocm/rocm_driver_wrapper.h | 1 + 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 7f4cf6dff9728b..4c2ca90915f744 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -162,12 +162,12 @@ limitations under the License. 
#include "tsl/protobuf/dnn.pb.h" #if GOOGLE_CUDA || TF_HIPBLASLT -#include "xla/service/gpu/cub_sort_thunk.h" #include "xla/service/gpu/gpublas_lt_matmul_thunk.h" -#include "xla/service/gpu/ir_emitter_triton.h" #endif // GOOGLE_CUDA || TF_HIPBLASLT #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "xla/service/gpu/cub_sort_thunk.h" +#include "xla/service/gpu/ir_emitter_triton.h" #include "xla/service/gpu/runtime3/cholesky_thunk.h" #include "xla/service/gpu/runtime3/triangular_solve_thunk.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 489e49c8d2b6a7..6c134e887a7af1 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -1,6 +1,8 @@ -load("//xla/tests:build_defs.bzl", "xla_test") -load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") -load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") +# copybara:uncomment_begin(google-only-loads) +# load("//xla/tests:build_defs.bzl", "xla_test") +# load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") +# load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") +# copybara:uncomment_end(google-only-loads) package( default_visibility = ["//visibility:public"], diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc index ff8d949a1be594..6c4309575c8099 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc @@ -781,6 +781,18 @@ GpuDriver::GraphNodeGetType(hipGraphNode_t node) { return tsl::OkStatus(); } +/*static*/ tsl::Status GpuDriver::GraphExecChildNodeSetParams( + GpuGraphExecHandle exec, GpuGraphNodeHandle node, GpuGraphHandle child) { + VLOG(2) << "Set child node params " << node << " in graph executable " << exec + << "to params contained in " << child; + + RETURN_IF_ROCM_ERROR( + wrap::hipGraphExecChildGraphNodeSetParams(exec, node, child), + "Failed to set ROCm graph child node params"); + + return tsl::OkStatus(); +} + /* static */ tsl::Status GpuDriver::LaunchKernel( GpuContext* context, absl::string_view kernel_name, hipFunction_t function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h index 5e5240f73b935c..6749e09dc74d0d 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h @@ -106,6 +106,7 @@ namespace wrap { __macro(hipGraphAddChildGraphNode) \ __macro(hipGraphAddMemcpyNode) \ __macro(hipGraphAddMemcpyNode1D) \ + __macro(hipGraphExecChildGraphNodeSetParams) \ __macro(hipGraphCreate) \ __macro(hipGraphDebugDotPrint) \ __macro(hipGraphDestroy) \ From f4ef3ceba204ab289789deba8aa7f0ab6fc7ee4a Mon Sep 17 00:00:00 2001 From: Robert David Date: Tue, 14 Nov 2023 10:14:15 -0800 Subject: [PATCH 072/391] Use the right include for `kDefaultTensorAlignment`. Also fix the error message. 
PiperOrigin-RevId: 582361555 --- tensorflow/lite/kernels/BUILD | 2 +- tensorflow/lite/kernels/eigen_support.cc | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 367d53ea31a554..a0eb5ce425fc2a 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -301,7 +301,7 @@ cc_library( visibility = ["//visibility:private"], deps = [ ":op_macros", - "//tensorflow/lite:arena_planner", + "//tensorflow/lite:util", "//tensorflow/lite/core/c:common", "//tensorflow/lite/kernels/internal:optimized_eigen", ], diff --git a/tensorflow/lite/kernels/eigen_support.cc b/tensorflow/lite/kernels/eigen_support.cc index 0dc977e876cfbf..22cf62d36d14e5 100644 --- a/tensorflow/lite/kernels/eigen_support.cc +++ b/tensorflow/lite/kernels/eigen_support.cc @@ -18,11 +18,14 @@ limitations under the License. #include #include -#include "tensorflow/lite/arena_planner.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" #include "tensorflow/lite/kernels/op_macros.h" +#ifndef EIGEN_DONT_ALIGN +#include "tensorflow/lite/util.h" +#endif // EIGEN_DONT_ALIGN + namespace tflite { namespace eigen_support { namespace { @@ -38,12 +41,11 @@ int GetNumThreads(int num_threads) { #ifndef EIGEN_DONT_ALIGN // Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on -// hardware architecture and build configurations. -// If the static assertion fails, try to increase `kDefaultTensorAlignment` to -// in `arena_planner.h` to 32 or 64. +// hardware architecture and build configurations. If the static assertion +// fails, try to increase `kDefaultTensorAlignment` in `util.h` to 32 or 64. static_assert( kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0, - "kDefaultArenaAlignment doesn't comply with Eigen alignment requirement."); + "kDefaultTensorAlignment doesn't comply with Eigen alignment requirement."); #endif // EIGEN_DONT_ALIGN // Helper routine for updating the global Eigen thread count used for OpenMP. From ad9ad163f8a9210bc0e13e24df2088ed2cc6dcfb Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Tue, 14 Nov 2023 10:14:44 -0800 Subject: [PATCH 073/391] Minor grammar fix in comment. PiperOrigin-RevId: 582361720 --- tensorflow/lite/core/interpreter.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/core/interpreter.cc b/tensorflow/lite/core/interpreter.cc index ee9748c031e87d..5c2917e8be9f24 100644 --- a/tensorflow/lite/core/interpreter.cc +++ b/tensorflow/lite/core/interpreter.cc @@ -225,8 +225,8 @@ TfLiteStatus Interpreter::Invoke() { ScopedRuntimeInstrumentationProfile scoped_runtime_event(root_profiler_.get(), "invoke"); - // "Resets" cancellation flag so cancellation happens before this invoke will - // not take effect. + // "Resets" cancellation flag so cancellation that happens before this invoke + // will not take effect. if (cancellation_enabled_) (void)continue_invocation_.test_and_set(); // Denormal floating point numbers could cause significant slowdown on From 2346078bd1b19f90a71ed13e83cab5747acbe6b6 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Tue, 14 Nov 2023 10:20:15 -0800 Subject: [PATCH 074/391] [XLA] Minor tweaks to memory-bound loop optimizer. - Make loop detection more accurate by recording the latest instance of an instruction with matching fingerprint. 
- If the loop value allocation type isn't supported by the optimizer, still allow that tensor to get alternate memory allocation using the usual MSA algorithm. - Export the minimum num loop iteration as a field in the proto. PiperOrigin-RevId: 582363814 --- .../memory_space_assignment.cc | 17 +++++-- .../memory_space_assignment.h | 4 ++ .../memory_space_assignment.proto | 4 ++ .../memory_space_assignment_test.cc | 46 ++++++++++++++++++- 4 files changed, 66 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc index a968b304940d4d..8e950d7dafafd9 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc @@ -2692,6 +2692,12 @@ std::string MemoryBoundLoopOptimizer::LoopValue::ToString() const { values_str, "\n", allocations_str); } +bool MemoryBoundLoopOptimizer::LoopValue::IsAllocationTypeSupported() const { + return allocation_type == AllocationType::kTemporary || + allocation_type == AllocationType::kPinned || + allocation_type == AllocationType::kPrefetch; +} + void MemoryBoundLoopOptimizer::SortLoopValues() { absl::c_stable_sort(loop_values_, [](const LoopValue& a, const LoopValue& b) { return a.savings_per_byte > b.savings_per_byte; @@ -3358,7 +3364,7 @@ Status AlternateMemoryBestFitHeap::OptimizeMemoryBoundLoop(int loop_start_idx, const int loop_optimized_allocations_original_size = loop_optimized_allocations_.size(); for (MemoryBoundLoopOptimizer::LoopValue& value : optimizer->loop_values()) { - if (!value.allocations.empty()) { + if (!value.allocations.empty() && value.IsAllocationTypeSupported()) { loop_optimized_allocations_.push_back(std::move(value.allocations)); } } @@ -3476,7 +3482,6 @@ void AlternateMemoryBestFitHeap::IdentifyAndOptimizeMemoryBoundLoops() { // The minimum and maximum loop sizes that we consider. const int kMinLoopSize = 4; const int kMaxLoopSize = 400; - const float kMinNumIterations = 3.0; int optimized_loop_idx = 0; while (optimized_loop_idx < instruction_sequence.size()) { // Iterate over the flattened instruction sequence. We first try to find a @@ -3502,6 +3507,11 @@ void AlternateMemoryBestFitHeap::IdentifyAndOptimizeMemoryBoundLoops() { // We found two instructions with the same fingerprint. The distance // between the two is the loop size candidate. loop_size_candidate = distance; + // Update the fingerprint map with the current loop index so that if + // the loop size candidate doesn't find a valid loop, we can resume + // searching from this instruction. + fingerprint_schedule_map[fingerprint_it->second] = + optimized_loop_idx; break; } } @@ -3633,7 +3643,8 @@ void AlternateMemoryBestFitHeap::IdentifyAndOptimizeMemoryBoundLoops() { optimized_loop_idx = std::max(optimized_loop_idx, loop_end_idx) + 1; - if (num_iterations >= kMinNumIterations) { + if (num_iterations >= + options_.memory_bound_loop_optimizer_options.min_num_iterations()) { VLOG(2) << "Found valid loop. 
Loop start: " << loop_start_idx << " loop end: " << loop_end_idx << " num iterations: " << num_iterations; diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h index 371b2c24708135..640b17b0fe59ab 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h @@ -1898,6 +1898,10 @@ class MemoryBoundLoopOptimizer { static std::string AllocationTypeToString(AllocationType allocation_type); std::string ToString() const; + // Returns true if memory-bound loop optimizer supports allocating this type + // of a loop value. + bool IsAllocationTypeSupported() const; + // The HloValues that correspond to this LoopValue. std::vector hlo_values; // The position in the header, if any. diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto index f85692df7af210..426e4a154ff383 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto @@ -58,4 +58,8 @@ message MemoryBoundLoopOptimizerOptions { // pipelined prefetch starts the same time as its counterpart in the previous // iteration finishes. optional bool allow_unsatisfied_fully_pipelined_prefetch = 3; + + // The minimum number of iterations that the loop needs to be unrolled for the + // memory-bound loop optimizer to kick in. + optional float min_num_iterations = 4; } diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc index dc20c252483c56..be2bc573705c7f 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -9562,6 +9562,7 @@ class MemoryBoundLoopOptimizerTest : public HloTestBase { optimizer_options.set_enabled(true); optimizer_options.set_desired_copy_ratio(0.7); optimizer_options.set_allow_unsatisfied_fully_pipelined_prefetch(false); + optimizer_options.set_min_num_iterations(3.0); options_.memory_bound_loop_optimizer_options = optimizer_options; options_.alternate_mem_bandwidth_bytes_per_second = 128; options_.async_copy_bandwidth_bytes_per_second = 32; @@ -9825,13 +9826,20 @@ ENTRY Entry { return preset_assignments; } - Status VerifyMsaEquivalence(HloModule* module) { + Status VerifyMsaEquivalence(HloModule* module, + bool expect_unsupported_allocations = false) { // Create a map indexed by instruction number and operand number. absl::flat_hash_map, const MemorySpaceAssignment::Allocation*> allocation_map; for (const MemoryBoundLoopOptimizer::LoopValue& value : optimizer_->loop_values()) { + // Skip verification for unsupported allocations as they will go through + // the usual MSA algorithm and may actually get an alternate memory + // allocation. 
+ if (!value.IsAllocationTypeSupported()) { + continue; + } for (const auto& allocation : value.allocations) { for (const HloUse& use : allocation->uses()) { absl::string_view inst_name = use.instruction->name(); @@ -9873,7 +9881,10 @@ ENTRY Entry { for (int operand_number = 0; operand_number < 2; ++operand_number) { const HloInstruction* operand = inst->operand(operand_number); LOG(INFO) << inst->name() << ", operand " << operand_number; - TF_RET_CHECK(allocation_map.contains({inst_number, operand_number})); + if (!allocation_map.contains({inst_number, operand_number})) { + TF_RET_CHECK(expect_unsupported_allocations); + continue; + } const MemorySpaceAssignment::Allocation* allocation = allocation_map.at({inst_number, operand_number}); if (!allocation->is_copy_allocation()) { @@ -10393,6 +10404,37 @@ TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEnd) { TF_ASSERT_OK(VerifyMsaEquivalence(module.get())); } +TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEndUnsupportedAllocation) { + // op2 is a loop-carried dependency, which is currently not supported. But the + // usual MSA algorithm should still be able to give it an alternate memory + // allocation. + absl::string_view hlo_loop_str = R"( + $op0 = f32[1,4] add(f32[1,4] $prev_op3, f32[1,4] $prev_op4) + $op1 = f32[8,4] add(f32[8,4] $param0, f32[8,4] $param1) + $op2 = f32[1,4] add(f32[1,4] $prev_op2, f32[1,4] $op0) + $op3 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op2) + $op4 = f32[1,4] add(f32[1,4] $op2, f32[1,4] $op3) + ROOT $root = tuple($op1, $op4) + )"; + + int loop_start_idx; + MemoryBoundLoopOptimizer* optimizer; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndCreateOptimizer(hlo_loop_str, + /*alternate_memory_size=*/1024, + loop_start_idx, &optimizer)); + + optimizer->Optimize(); + TF_ASSERT_OK_AND_ASSIGN(auto preset_assignments, + RunMsa(module.get(), /*alternate_memory_size=*/1024)); + + TF_ASSERT_OK(VerifyMsaEquivalence(module.get(), + /*expect_unsupported_allocations=*/true)); + + const HloInstruction* op2 = FindInstruction(module.get(), "op2"); + EXPECT_EQ(op2->shape().layout().memory_space(), kAlternateMemorySpace); +} + TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEndWhileLoop) { absl::string_view hlo_str = R"( HloModule module, is_scheduled=true From ef732533edc5b952260e2069c3ab0720a2de3ef8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 10:33:54 -0800 Subject: [PATCH 075/391] Re-enable layering_check for target. PiperOrigin-RevId: 582368537 --- tensorflow/python/BUILD | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 8435bcc738268e..cd256c59759520 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -463,15 +463,6 @@ py_strict_library( ], ) -# Necessary for the pywrap inclusion below. 
-tf_pybind_cc_library_wrapper( - name = "tfcompile_headers_lib", - compatible_with = [], - deps = [ - "//tensorflow/compiler/aot:tfcompile_lib", - ], -) - tf_python_pybind_extension( name = "_pywrap_tfcompile", srcs = ["tfcompile_wrapper.cc"], @@ -481,15 +472,13 @@ tf_python_pybind_extension( "//tensorflow:windows": [], }), enable_stub_generation = True, - features = ["-layering_check"], pytype_srcs = [ "_pywrap_tfcompile.pyi", ], static_deps = tf_python_pybind_static_deps(), deps = [ - ":tfcompile_headers_lib", "@pybind11", - "//third_party/python_runtime:headers", + "//tensorflow/compiler/aot:tfcompile_lib", "//tensorflow/python/lib/core:pybind11_lib", "//tensorflow/python/lib/core:pybind11_status", # The headers here cannot be brought in via cc_header_only_library From a26d6befeeaf40f36477c569d2c78c4ac8325bc6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 10:37:04 -0800 Subject: [PATCH 076/391] Enable running passes for H100 PiperOrigin-RevId: 582369654 --- .../xla/xla/service/gpu/ir_emitter_triton.cc | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index b31f2b5eb434fc..2ef9d0fdbe1c33 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -704,9 +704,11 @@ void CreateTritonPipeline(mlir::OpPassManager& pm, const se::CudaComputeCapability& cc, int num_warps, int num_stages) { const int ccAsInt = cc.major * 10 + cc.minor; + const int threadsPerWarp = 32; + const int numCTAs = 1; // Based on optimize_ttir() in // @triton//:python/triton/compiler/compiler.py - pm.addPass(mt::createRewriteTensorPointerPass()); + pm.addPass(mt::createRewriteTensorPointerPass(ccAsInt)); pm.addPass(mlir::createInlinerPass()); pm.addPass(mt::createCombineOpsPass()); pm.addPass(mlir::createCanonicalizerPass()); @@ -716,18 +718,25 @@ void CreateTritonPipeline(mlir::OpPassManager& pm, pm.addPass(mlir::createSymbolDCEPass()); // Based on ttir_to_ttgir() in // @triton//:python/triton/compiler/compiler.py - pm.addPass(mt::createConvertTritonToTritonGPUPass(num_warps)); + pm.addPass(mt::createConvertTritonToTritonGPUPass(num_warps, threadsPerWarp, + numCTAs, ccAsInt)); // Based on optimize_ttgir() in // @triton//:python/triton/compiler/compiler.py pm.addPass(mlir::createTritonGPUCoalescePass()); - pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass()); + pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(/*clusterInfo=*/)); + pm.addPass(mlir::createTritonGPURewriteTensorPointerPass(ccAsInt)); + pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(/*clusterInfo=*/)); pm.addPass(mlir::createTritonGPURemoveLayoutConversionsPass()); pm.addPass(mlir::createTritonGPUAccelerateMatmulPass(ccAsInt)); pm.addPass(mlir::createTritonGPURemoveLayoutConversionsPass()); pm.addPass(mlir::createTritonGPUOptimizeDotOperandsPass()); - pm.addPass(mlir::createTritonGPUPipelinePass(num_stages, num_warps)); - pm.addPass(mlir::createTritonNvidiaGPUMaterializeLoadStorePass()); - pm.addPass(mlir::createTritonGPUPrefetchPass()); + pm.addPass(mlir::createTritonGPUPipelinePass(num_stages, num_warps, numCTAs, + ccAsInt)); + pm.addPass( + mlir::createTritonNvidiaGPUMaterializeLoadStorePass(num_warps, ccAsInt)); + if (ccAsInt <= 80) { + pm.addPass(mlir::createTritonGPUPrefetchPass()); + } pm.addPass(mlir::createTritonGPUOptimizeDotOperandsPass()); pm.addPass(mlir::createTritonGPURemoveLayoutConversionsPass()); 
pm.addPass(mlir::createTritonGPUDecomposeConversionsPass()); @@ -735,13 +744,17 @@ void CreateTritonPipeline(mlir::OpPassManager& pm, pm.addPass(mlir::createTritonGPUReorderInstructionsPass()); pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createSymbolDCEPass()); + if (ccAsInt >= 90) { + pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt)); + } pm.addPass(mlir::createTritonNvidiaGPUWSFixupMissingAttrs()); // Based on translateTritonGPUToLLVMIR() in // @triton//:lib/Target/LLVMIR/LLVMIRTranslation.cpp pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertIndexToLLVMPass()); - pm.addPass( - mt::createConvertTritonGPUToLLVMPass(ccAsInt, mt::Default, nullptr)); + pm.addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt, + /*target=*/mt::Default, + /*tmaMetadata=*/nullptr)); pm.addPass(mt::createConvertNVGPUToLLVMPass()); pm.addPass(mlir::createArithToLLVMConversionPass()); pm.addPass(mlir::createCanonicalizerPass()); From ea9d0d7aa78537a4f1594aa5a6e84947e94cc730 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 10:40:58 -0800 Subject: [PATCH 077/391] Integrate LLVM at llvm/llvm-project@ed86e740effa Updates LLVM usage to match [ed86e740effa](https://github.com/llvm/llvm-project/commit/ed86e740effa) PiperOrigin-RevId: 582371041 --- third_party/llvm/generated.patch | 137 +++++++++++++++++++++++++++++++ third_party/llvm/workspace.bzl | 4 +- 2 files changed, 139 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index a37125c400d30a..c1d1504a0ac731 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,4 +1,141 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp +--- a/llvm/lib/IR/Instruction.cpp ++++ b/llvm/lib/IR/Instruction.cpp +@@ -244,6 +244,8 @@ + Instruction::getDbgValueRange() const { + BasicBlock *Parent = const_cast(getParent()); + assert(Parent && "Instruction must be inserted to have DPValues"); ++ (void)Parent; ++ + if (!DbgMarker) + return DPMarker::getEmptyDPValueRange(); + +diff -ruN --strip-trailing-cr a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir +--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir ++++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir +@@ -54,7 +54,7 @@ + + func.func @mul(%arg0: tensor<4x6xf64>, + %arg1: tensor<4x6xf64, #BSR>) -> tensor<4x4xf64> { +- %out = tensor.empty() : tensor<4x4xf64> ++ %out = arith.constant dense<0.0> : tensor<4x4xf64> + %0 = linalg.generic #trait_mul + ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64, #BSR>) + outs(%out: tensor<4x4xf64>) { +@@ -68,7 +68,7 @@ + + func.func @mul_dense(%arg0: tensor<4x6xf64>, + %arg1: tensor<4x6xf64>) -> tensor<4x4xf64> { +- %out = tensor.empty() : tensor<4x4xf64> ++ %out = arith.constant dense<0.0> : tensor<4x4xf64> + %0 = linalg.generic #trait_mul + ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64>) + outs(%out: tensor<4x4xf64>) { +diff -ruN --strip-trailing-cr a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir +--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir ++++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir +@@ -85,30 +85,32 @@ 
+ // A kernel that computes a BSR sampled dense matrix matrix multiplication + // using a "spy" function and in-place update of the sampling sparse matrix. + // +- func.func @SDDMM_block(%args: tensor, +- %arga: tensor, +- %argb: tensor) -> tensor { +- %result = linalg.generic #trait_SDDMM +- ins(%arga, %argb: tensor, tensor) +- outs(%args: tensor) { +- ^bb(%a: f32, %b: f32, %s: f32): +- %f0 = arith.constant 0.0 : f32 +- %u = sparse_tensor.unary %s : f32 to f32 +- present={ +- ^bb0(%p: f32): +- %mul = arith.mulf %a, %b : f32 +- sparse_tensor.yield %mul : f32 +- } +- absent={} +- %r = sparse_tensor.reduce %s, %u, %f0 : f32 { +- ^bb0(%p: f32, %q: f32): +- %add = arith.addf %p, %q : f32 +- sparse_tensor.yield %add : f32 +- } +- linalg.yield %r : f32 +- } -> tensor +- return %result : tensor +- } ++ // TODO: re-enable the following test. ++ // ++ // func.func @SDDMM_block(%args: tensor, ++ // %arga: tensor, ++ // %argb: tensor) -> tensor { ++ // %result = linalg.generic #trait_SDDMM ++ // ins(%arga, %argb: tensor, tensor) ++ // outs(%args: tensor) { ++ // ^bb(%a: f32, %b: f32, %s: f32): ++ // %f0 = arith.constant 0.0 : f32 ++ // %u = sparse_tensor.unary %s : f32 to f32 ++ // present={ ++ // ^bb0(%p: f32): ++ // %mul = arith.mulf %a, %b : f32 ++ // sparse_tensor.yield %mul : f32 ++ // } ++ // absent={} ++ // %r = sparse_tensor.reduce %s, %u, %f0 : f32 { ++ // ^bb0(%p: f32, %q: f32): ++ // %add = arith.addf %p, %q : f32 ++ // sparse_tensor.yield %add : f32 ++ // } ++ // linalg.yield %r : f32 ++ // } -> tensor ++ // return %result : tensor ++ // } + + func.func private @getTensorFilename(index) -> (!Filename) + +@@ -151,15 +153,15 @@ + // + %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename) + %m_csr = sparse_tensor.new %fileName : !Filename to tensor +- %m_bsr = sparse_tensor.new %fileName : !Filename to tensor ++ // %m_bsr = sparse_tensor.new %fileName : !Filename to tensor + + // Call the kernel. + %0 = call @SDDMM(%m_csr, %a, %b) + : (tensor, + tensor, tensor) -> tensor +- %1 = call @SDDMM_block(%m_bsr, %a, %b) +- : (tensor, +- tensor, tensor) -> tensor ++ // %1 = call @SDDMM_block(%m_bsr, %a, %b) ++ // : (tensor, ++ // tensor, tensor) -> tensor + + // + // Print the result for verification. Note that the "spy" determines what +@@ -168,18 +170,18 @@ + // in the original zero positions). + // + // CHECK: ( 5, 10, 24, 19, 53, 42, 55, 56 ) +- // CHECK-NEXT: ( 5, 10, 8, 19, 24, 24, 40, 53, 42, 55, 56, 64 ) ++ // C_HECK-NEXT: ( 5, 10, 8, 19, 24, 24, 40, 53, 42, 55, 56, 64 ) + // + %v0 = sparse_tensor.values %0 : tensor to memref + %vv0 = vector.transfer_read %v0[%c0], %d0 : memref, vector<8xf32> + vector.print %vv0 : vector<8xf32> +- %v1 = sparse_tensor.values %1 : tensor to memref +- %vv1 = vector.transfer_read %v1[%c0], %d0 : memref, vector<12xf32> +- vector.print %vv1 : vector<12xf32> ++ // %v1 = sparse_tensor.values %1 : tensor to memref ++ // %vv1 = vector.transfer_read %v1[%c0], %d0 : memref, vector<12xf32> ++ // vector.print %vv1 : vector<12xf32> + + // Release the resources. 
+ bufferization.dealloc_tensor %0 : tensor +- bufferization.dealloc_tensor %1 : tensor ++ // bufferization.dealloc_tensor %1 : tensor + + llvm.call @mgpuDestroySparseEnv() : () -> () + return diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 5df608affbc631..a8c20faa8cde27 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "b8a062061571b7868013f1fefb891bdaa2da1adc" - LLVM_SHA256 = "b5cafaa5d5f80f25e701a5c5e37898c2b6e0f925db298190c2b694f6e328275d" + LLVM_COMMIT = "ed86e740effaf1de540820a145a9df44eaf0df0e" + LLVM_SHA256 = "f5c849d3c450faa5a68a14e11f62bfd7d957df87603b39662d7f4179b21c7f7a" tf_http_archive( name = name, From f2d5de028fbc67b8f692d64324da20e22f683f2c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 11:00:25 -0800 Subject: [PATCH 078/391] Add the Python 3.12 classifier to setup.py. PiperOrigin-RevId: 582377585 --- tensorflow/tools/pip_package/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 471f322da06c57..3615a69b7f6fde 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -437,6 +437,7 @@ def find_files(pattern, root): 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Programming Language :: Python :: 3 :: Only', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Mathematics', From fe5b935aa50154e083dae8be9d79f69191783178 Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Tue, 14 Nov 2023 11:04:32 -0800 Subject: [PATCH 079/391] Reduce visibility of api/v1 targets PiperOrigin-RevId: 582379057 --- tensorflow/compiler/mlir/tf2xla/api/v1/BUILD | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index 9b09ee088d00b4..477ce4cb0229a8 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -3,13 +3,17 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], + default_visibility = [ + "//tensorflow/compiler/mlir/tf2xla/api:__subpackages__", + "//tensorflow/compiler/mlir/tf2xla/internal:__subpackages__", + ], ) cc_library( name = "compile_mlir_util_no_tf_dialect_passes", srcs = ["compile_mlir_util.cc"], hdrs = ["compile_mlir_util.h"], + visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/tensorflow", @@ -175,6 +179,9 @@ cc_library( name = "cluster_tf", srcs = ["cluster_tf.cc"], hdrs = ["cluster_tf.h"], + visibility = [ + "//tensorflow/compiler/tf2xla:__pkg__", + ], deps = [ ":tf_dialect_to_executor", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", @@ -233,6 +240,7 @@ cc_library( name = "tf_dialect_to_executor", srcs = ["tf_dialect_to_executor.cc"], hdrs = 
["tf_dialect_to_executor.h"], + visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", From d1c387d19f1fe49bd8a0332a3ba4c915fc623d87 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 14 Nov 2023 11:23:49 -0800 Subject: [PATCH 080/391] Add nvml headers for Windows, based on https://github.com/openxla/xla/pull/6994 PiperOrigin-RevId: 582385702 --- third_party/gpus/cuda/BUILD.windows.tpl | 8 ++++++++ .../tsl/third_party/gpus/cuda/BUILD.windows.tpl | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/gpus/cuda/BUILD.windows.tpl index f20ecbd654bf6f..dee0e898d9ae7a 100644 --- a/third_party/gpus/cuda/BUILD.windows.tpl +++ b/third_party/gpus/cuda/BUILD.windows.tpl @@ -197,6 +197,14 @@ cuda_header_library( deps = [":cuda_headers"], ) +cuda_header_library( + name = "nvml_headers", + hdrs = [":nvml"], + include_prefix = "third_party/gpus", + includes = ["cuda/nvml/include/"], + deps = [":cuda_headers"], +) + cc_import( name = "cupti_dsos", interface_library = "cuda/lib/%{cupti_lib}", diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl index f20ecbd654bf6f..dee0e898d9ae7a 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl @@ -197,6 +197,14 @@ cuda_header_library( deps = [":cuda_headers"], ) +cuda_header_library( + name = "nvml_headers", + hdrs = [":nvml"], + include_prefix = "third_party/gpus", + includes = ["cuda/nvml/include/"], + deps = [":cuda_headers"], +) + cc_import( name = "cupti_dsos", interface_library = "cuda/lib/%{cupti_lib}", From f6bfbdcd67ce74d9e438503578f5dee54cce83c5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 12:01:58 -0800 Subject: [PATCH 081/391] Use sharding propagation to infer dot/conv sharding given the operand shardings when enumerating sharding strategies for those ops. This is as opposed to the previous approach of using sharding propagation to infer operand shardings given the dot/conv. This approach does not work when one is looking to shard the contraction dimension and is therefore less cleaner than this new approach. 
PiperOrigin-RevId: 582397850 --- .../xla/hlo/experimental/auto_sharding/BUILD | 1 + .../auto_sharding/auto_sharding.cc | 2 +- .../auto_sharding_dot_handler.cc | 328 ++++++++++-------- 3 files changed, 189 insertions(+), 142 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index 72854e36ead212..27df7e0fe86b17 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -50,6 +50,7 @@ cc_library( "//xla/service:buffer_value", "//xla/service:call_graph", "//xla/service:computation_layout", + "//xla/service:dot_as_convolution_util", "//xla/service:dump", "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 43985bb4cd1ef0..e7646ad59bf67f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -4573,7 +4573,7 @@ StatusOr AutoSharding::Run( std::unique_ptr module_with_default_solution = nullptr; if (option_.use_sharding_propagation_for_default_shardings) { module_with_default_solution = CloneModule(module); - // TODO(pratikf): Ensure that we're passing the correct customc all sharding + // TODO(pratikf): Ensure that we're passing the correct custom call sharding // helper to the sharding propagation pass. auto sharding_prop = ShardingPropagation( /*is_spmd */ true, /*propagate_metadata */ false, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index fb082f1169e2dd..a006072cc0a9f8 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -34,8 +34,11 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" #include "xla/hlo/experimental/auto_sharding/cluster_environment.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/service/call_graph.h" +#include "xla/service/dot_as_convolution_util.h" +#include "xla/service/sharding_propagation.h" #include "xla/status.h" #include "tsl/platform/errors.h" @@ -118,38 +121,88 @@ class HandlerBase { return Tile(ins->shape(), tensor_dims, mesh_dims, device_mesh); } - HloSharding CreateInputSpecUsingShardingPropagation( - int operand_index, const HloSharding& output_spec) const { - std::optional operand_sharding = - GetInputSharding(ins_, ins_->operand(operand_index), operand_index, - output_spec, call_graph_, cluster_env_.NumDevices()); - CHECK(operand_sharding.has_value()); - return operand_sharding.value(); + // Given lhs and rhs dim maps, infers a sharding for the output by relying on + // the sharding_propagation pass. Given that this is a relatively new change + // (as of 11/2023), we also take an optional expected output dim map as an + // argument, to verify that sharding propagation in fact infers the sharding + // we expect (and to crash if it doesn't). + // TODO(b/309638633) As we build more confidence in this, we should remove + // this expected_output_dim_map argument and fully rely on sharding + // propagation. 
+ void MaybeAppend( + const std::string& name, const DimMap& lhs_dim_map, + const DimMap& rhs_dim_map, + const std::optional& expected_output_dim_map, + const Array& device_mesh, double compute_cost = 0, + const std::optional>& + communication_cost_fn = std::nullopt) { + HloSharding lhs_spec = CreateInputSpec(lhs_, lhs_dim_map, device_mesh); + HloSharding rhs_spec = CreateInputSpec(rhs_, rhs_dim_map, device_mesh); + if (std::optional output_spec = + GetShardingFromUser(lhs_spec, rhs_spec); + output_spec.has_value()) { + if (expected_output_dim_map.has_value()) { + HloSharding expected_output_spec = + CreateInputSpec(ins_, *expected_output_dim_map, device_mesh); + // TODO(b/308687597) Once the bug is resolved, we ideally either want + // have a CHECK statement verifying that the sharding inferred by + // sharding propagation is in fact what we expect, or we trust sharding + // propagation's results without the check. b/308687597 currently + // prevents us from doing so. AutoShardingTest.LargeSize in + // //third_party/tensorflow/compiler/xla/hlo/experimental/auto_sharding:auto_sharding_test + // currently fails due to the issue. + if (ins_->opcode() == HloOpcode::kDot && + *output_spec != expected_output_spec) { + output_spec = expected_output_spec; + LOG(ERROR) + << "The sharding inferred by sharding propagation in this case " + "does not match the expected sharding for the dot " + "instruction. This may be related to b/308687597. Given this " + "mismatch, we continue with the expected sharding"; + } + } + double communication_cost = 0; + if (communication_cost_fn.has_value()) { + communication_cost = communication_cost_fn.value()(*output_spec); + } + AppendNewStrategy(name, *output_spec, {lhs_spec, rhs_spec}, compute_cost, + communication_cost); + } else { + LOG(FATAL) << "Sharding propagation could not infer output sharding"; + } } - void MaybeAppend(const std::string& name, const HloSharding& output_spec, - const DimMap& lhs_dim_map, const DimMap& rhs_dim_map, - const Array& device_mesh, double compute_cost = 0, - double communication_cost = 0, - bool use_sharding_propagation = true) { - if (!CheckDims(lhs_, lhs_dim_map) || !CheckDims(rhs_, rhs_dim_map)) return; - HloSharding lhs_spec = - use_sharding_propagation - ? CreateInputSpecUsingShardingPropagation(0, output_spec) - : CreateInputSpec(lhs_, lhs_dim_map, device_mesh); - HloSharding rhs_spec = - use_sharding_propagation - ? 
CreateInputSpecUsingShardingPropagation(1, output_spec) - : CreateInputSpec(rhs_, rhs_dim_map, device_mesh); - AppendNewStrategy(name, output_spec, {lhs_spec, rhs_spec}, compute_cost, - communication_cost); + std::optional GetShardingFromUser(const HloSharding& lhs_spec, + const HloSharding& rhs_spec) { + std::unique_ptr ins_clone = ins_->Clone(); + std::unique_ptr lhs_clone = lhs_->Clone(); + std::unique_ptr rhs_clone = rhs_->Clone(); + ins_clone->clear_sharding(); + lhs_clone->set_sharding(lhs_spec); + rhs_clone->set_sharding(rhs_spec); + CHECK_OK(ins_clone->ReplaceOperandWith(0, lhs_clone.get())); + CHECK_OK(ins_clone->ReplaceOperandWith(1, rhs_clone.get())); + if (ins_->opcode() == HloOpcode::kConvolution) { + xla::InferConvolutionShardingFromOperands( + ins_clone.get(), call_graph_, 10, + /* may_combine_partial_sharding */ true, /* is_spmd */ true); + } else { + xla::InferDotShardingFromOperands( + ins_clone.get(), call_graph_, + dot_as_convolution_util::ParseDotGeneralFromDot(ins_clone.get()), + /* may_combine_partial_sharding/ */ true, /* is_spmd */ true); + } + if (!ins_clone->has_sharding()) { + return std::nullopt; + } + return ins_clone->sharding(); } // Enumerates combinations of the given mesh + tensor dimensions. void Enumerate(std::function split_func, size_t num_outer_dims = 2, size_t num_inner_dims = 2, bool half = false) { - auto mesh_shape = device_mesh_.dimensions(); + absl::Span mesh_shape = device_mesh_.dimensions(); for (int64_t dim0 = 0; dim0 < mesh_shape.size(); ++dim0) { for (int64_t dim1 = 0; dim1 < mesh_shape.size(); ++dim1) { if (dim0 == dim1) continue; @@ -211,13 +264,12 @@ class DotHandler : public HandlerBase { const DimMap rhs_dim_map = {{rhs_space_dims_[e.j], e.mesh_dims[1]}}; std::string name = absl::StrFormat("SS = SR x RS @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - HloSharding output_spec = - Tile(ins_->shape(), - {space_base_dim_ + e.i, - space_base_dim_ + static_cast(lhs_space_dims_.size()) + - e.j}, - e.mesh_dims, device_mesh_); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_); + + const DimMap out_dim_map = DimMap{ + {space_base_dim_ + e.i, e.mesh_dims[0]}, + {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.j, + e.mesh_dims[1]}}; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); }; Enumerate(func, lhs_space_dims_.size(), rhs_space_dims_.size()); } @@ -228,10 +280,10 @@ class DotHandler : public HandlerBase { {lhs_space_dims_[e.j], e.mesh_dims[1]}}; std::string name = absl::StrFormat("SSR = SSR x RR @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - HloSharding output_spec = - Tile(ins_->shape(), {space_base_dim_ + e.i, space_base_dim_ + e.j}, - e.mesh_dims, device_mesh_); - MaybeAppend(name, output_spec, lhs_dim_map, {}, device_mesh_); + const DimMap out_dim_map = + DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}, + {space_base_dim_ + e.j, e.mesh_dims[1]}}; + MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_); }; EnumerateHalf(func, lhs_space_dims_.size(), lhs_space_dims_.size()); } @@ -242,13 +294,12 @@ class DotHandler : public HandlerBase { {rhs_space_dims_[e.j], e.mesh_dims[1]}}; std::string name = absl::StrFormat("RSS = RR x RSS @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - HloSharding output_spec = Tile( - ins_->shape(), + const DimMap out_dim_map = DimMap{ {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.i, - space_base_dim_ + static_cast(lhs_space_dims_.size()) + - e.j}, - e.mesh_dims, device_mesh_); - MaybeAppend(name, output_spec, {}, rhs_dim_map, device_mesh_); + 
e.mesh_dims[0]}, + {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.j, + e.mesh_dims[1]}}; + MaybeAppend(name, {}, rhs_dim_map, out_dim_map, device_mesh_); }; EnumerateHalf(func, rhs_space_dims_.size(), rhs_space_dims_.size()); } @@ -264,13 +315,15 @@ class DotHandler : public HandlerBase { const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[0]}, {lhs_con_dims_[e.j], e.mesh_dims[1]}}; const DimMap rhs_dim_map = {{rhs_con_dims_[e.j], e.mesh_dims[1]}}; - HloSharding output_spec = Tile(ins_->shape(), {space_base_dim_ + e.i}, - {e.mesh_dims[0]}, device_mesh_); - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - double communication_cost = - cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_, 0, - communication_cost, /* use_sharding_propagation */ false); + const DimMap out_dim_map = + DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}}; + + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); + }; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, + communication_cost_fn); }; Enumerate(func, lhs_space_dims_.size(), lhs_con_dims_.size()); } @@ -284,16 +337,15 @@ class DotHandler : public HandlerBase { const DimMap rhs_dim_map = {{rhs_space_dims_[e.i], e.mesh_dims[1]}, {rhs_con_dims_[e.j], e.mesh_dims[0]}}; const DimMap lhs_dim_map = {{lhs_con_dims_[e.j], e.mesh_dims[0]}}; - HloSharding output_spec = - Tile(ins_->shape(), - {space_base_dim_ + static_cast(lhs_space_dims_.size()) + - e.i}, - {e.mesh_dims[1]}, device_mesh_); - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - double communication_cost = - cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_, 0, - communication_cost, /* use_sharding_propagation */ false); + const DimMap out_dim_map = DimMap{ + {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.i, + e.mesh_dims[1]}}; + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); + }; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, + communication_cost_fn); }; Enumerate(func, rhs_space_dims_.size(), lhs_con_dims_.size()); } @@ -307,8 +359,8 @@ class DotHandler : public HandlerBase { const DimMap lhs_dim_map = {{lhs_batch_dims_[e.i], e.j}}; const DimMap rhs_dim_map = {{rhs_batch_dims_[e.i], e.j}}; std::string name = absl::StrFormat("Sb_%d = Sb x Sb @ {%d}", e.i, e.j); - HloSharding output_spec = Tile(ins_->shape(), {e.i}, {e.j}, device_mesh_); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_); + const DimMap out_dim_map = DimMap{{e.i, e.j}}; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); }; Enumerate(func, lhs_batch_dims_.size(), device_mesh_.num_dimensions()); } @@ -325,9 +377,9 @@ class DotHandler : public HandlerBase { {rhs_batch_dims_[1], e.mesh_dims[1]}}; std::string name = absl::StrFormat("Sb = Sb x Sb @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - HloSharding output_spec = - Tile(ins_->shape(), {0, 1}, e.mesh_dims, device_mesh_); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_); + const DimMap out_dim_map = + DimMap{{0, 
e.mesh_dims[0]}, {1, e.mesh_dims[1]}}; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); }; EnumerateHalf(func, lhs_batch_dims_.size(), lhs_batch_dims_.size()); } @@ -343,10 +395,9 @@ class DotHandler : public HandlerBase { const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[1]}, {lhs_batch_dims_[e.j], e.mesh_dims[0]}}; const DimMap rhs_dim_map = {{rhs_batch_dims_[e.j], e.mesh_dims[0]}}; - HloSharding output_spec = - Tile(ins_->shape(), {e.j, space_base_dim_ + e.i}, e.mesh_dims, - device_mesh_); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_); + const DimMap out_dim_map = DimMap{ + {e.j, e.mesh_dims[0]}, {space_base_dim_ + e.i, e.mesh_dims[1]}}; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); }; Enumerate(func, lhs_space_dims_.size(), lhs_batch_dims_.size()); } @@ -362,12 +413,11 @@ class DotHandler : public HandlerBase { const DimMap rhs_dim_map = {{rhs_space_dims_[e.i], e.mesh_dims[1]}, {rhs_batch_dims_[e.j], e.mesh_dims[0]}}; const DimMap lhs_dim_map = {{lhs_batch_dims_[e.j], e.mesh_dims[0]}}; - HloSharding output_spec = - Tile(ins_->shape(), - {e.j, space_base_dim_ + - static_cast(lhs_space_dims_.size()) + e.i}, - e.mesh_dims, device_mesh_); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_); + const DimMap out_dim_map = { + {e.j, e.mesh_dims[0]}, + {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.i, + e.mesh_dims[1]}}; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); }; Enumerate(func, rhs_space_dims_.size(), lhs_batch_dims_.size()); } @@ -384,13 +434,13 @@ class DotHandler : public HandlerBase { const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[1]}, {lhs_batch_dims_[e.j], e.mesh_dims[0]}}; const DimMap rhs_dim_map = {{rhs_batch_dims_[e.j], e.mesh_dims[0]}}; - HloSharding output_spec = - Tile(ins_->shape(), {e.j}, {e.mesh_dims[0]}, device_mesh_); - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - double communication_cost = - cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_, 0, - communication_cost, /* use_sharding_propagation */ false); + const DimMap out_dim_map = DimMap{{e.j, e.mesh_dims[0]}}; + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); + }; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, + communication_cost_fn); }; Enumerate(func, lhs_con_dims_.size(), lhs_batch_dims_.size()); } @@ -409,12 +459,14 @@ class DotHandler : public HandlerBase { {lhs_con_dims_[e.j], e.mesh_dims[1]}}; const DimMap rhs_dim_map = {{rhs_con_dims_[e.i], e.mesh_dims[0]}, {rhs_con_dims_[e.j], e.mesh_dims[1]}}; - HloSharding output_spec = HloSharding::Replicate(); - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - double communication_cost = cluster_env_.AllReduceCost( - memory_cost, e.mesh_dims[0], e.mesh_dims[1]); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_, 0, - communication_cost, /* use_sharding_propagation */ false); + const DimMap out_dim_map = DimMap{}; + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0], + e.mesh_dims[1]); + }; + 
MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, + communication_cost_fn); }; EnumerateHalf(func, lhs_con_dims_.size(), lhs_con_dims_.size()); } @@ -428,14 +480,14 @@ class DotHandler : public HandlerBase { e.mesh_dims[0], e.mesh_dims[0]); const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[0]}}; const DimMap rhs_dim_map = {{rhs_con_dims_[e.i], e.mesh_dims[0]}}; - HloSharding output_spec = HloSharding::Replicate(); - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + const DimMap out_dim_map = DimMap{}; double compute_cost = cluster_env_.DotCost(lhs_->shape(), rhs_->shape()); - double communication_cost = - cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_, - compute_cost, communication_cost, - /* use_sharding_propagation */ false); + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); + }; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, + compute_cost, communication_cost_fn); }; Enumerate(func, lhs_con_dims_.size(), 1); } @@ -459,9 +511,8 @@ class DotHandler : public HandlerBase { continue; } std::string name = absl::StrFormat("Si = Si x R @ %d", mesh_dim); - HloSharding output_spec = Tile(ins_->shape(), {space_base_dim_ + i}, - {mesh_dim}, device_mesh_1d_); - MaybeAppend(name, output_spec, lhs_dim_map, {}, device_mesh_1d_); + const DimMap out_dim_map = DimMap{{space_base_dim_ + i, mesh_dim}}; + MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_1d_); } // R = Sk x Sk @ (allreduce @ 0) @@ -478,13 +529,14 @@ class DotHandler : public HandlerBase { } std::string name = absl::StrFormat("R = Sk x Sk @ %d (allreduce @ %d)", mesh_dim, mesh_dim); - HloSharding output_spec = HloSharding::Replicate(); - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - double communication_cost = - cluster_env_.AllReduceCost(memory_cost, mesh_dim); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, - device_mesh_1d_, 0, communication_cost, - /* use_sharding_propagation */ false); + const DimMap out_dim_map = DimMap{}; + auto communication_cost_fn = [this, mesh_dim]( + const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, mesh_dim); + }; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, + device_mesh_1d_, 0, communication_cost_fn); } } } @@ -499,9 +551,8 @@ class DotHandler : public HandlerBase { const DimMap rhs_dim_map = {{rhs_batch_dims_[i], mesh_dim}}; std::string name = absl::StrFormat("Sb_%d = Sb x Sb @ {%d} 1d", i, mesh_dim); - HloSharding output_spec = - Tile(ins_->shape(), {i}, {mesh_dim}, device_mesh_1d_); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, + const DimMap out_dim_map = DimMap{{i, mesh_dim}}; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_1d_); } } @@ -647,11 +698,9 @@ class ConvHandler : public HandlerBase { const DimMap rhs_dim_map = {{rhs_out_channel_dim_, e.mesh_dims[1]}}; std::string name = absl::StrFormat("SS = SR x RS @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - HloSharding output_spec = - Tile(ins_->shape(), {out_batch_dim_, out_out_channel_dim_}, - e.mesh_dims, device_mesh_); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_, 0, - 0, /* use_sharding_propagation */ 
false); + const DimMap out_dim_map = {{out_batch_dim_, e.mesh_dims[0]}, + {out_out_channel_dim_, e.mesh_dims[1]}}; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); }; EnumerateHalf(func); } @@ -667,13 +716,13 @@ class ConvHandler : public HandlerBase { std::string name = absl::StrFormat("SR = SS x SR @ {%s} (allreduce @ %d)", absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[1]); - HloSharding output_spec = - Tile(ins_->shape(), {out_batch_dim_}, {e.mesh_dims[0]}, device_mesh_); - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - double communication_cost = - cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_, 0, - communication_cost, /* use_sharding_propagation */ false); + const DimMap out_dim_map = {{out_batch_dim_, e.mesh_dims[0]}}; + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); + }; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, + communication_cost_fn); }; EnumerateHalf(func); } @@ -687,13 +736,13 @@ class ConvHandler : public HandlerBase { std::string name = absl::StrFormat("RS = RS x SS @ {%s} (allreduce @ %d)", absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[0]); - HloSharding output_spec = Tile(ins_->shape(), {out_out_channel_dim_}, - {e.mesh_dims[1]}, device_mesh_); - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - double communication_cost = - cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_, 0, - communication_cost, /* use_sharding_propagation */ false); + const DimMap out_dim_map = {{out_out_channel_dim_, e.mesh_dims[1]}}; + auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); + }; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, + communication_cost_fn); }; EnumerateHalf(func); } @@ -709,10 +758,8 @@ class ConvHandler : public HandlerBase { if (lhs_->shape().dimensions(lhs_batch_dim_) % num_devices == 0) { const DimMap lhs_dim_map = {{lhs_batch_dim_, mesh_dim}}; std::string name = absl::StrFormat("Si = Si x R @ 0"); - HloSharding output_spec = - Tile(ins_->shape(), {out_batch_dim_}, {mesh_dim}, device_mesh_1d_); - MaybeAppend(name, output_spec, lhs_dim_map, {}, device_mesh_1d_, 0, 0, - /* use_sharding_propagation */ false); + const DimMap out_dim_map = {{out_batch_dim_, mesh_dim}}; + MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_1d_); } // R = Sk x Sk @ (allreduce @ 0) @@ -722,13 +769,14 @@ class ConvHandler : public HandlerBase { const DimMap rhs_dim_map = {{rhs_in_channel_dim_, mesh_dim}}; std::string name = absl::StrFormat("R = Sk x Sk @ %d (allreduce @ %d)", mesh_dim, mesh_dim); - HloSharding output_spec = HloSharding::Replicate(); - double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); - double communication_cost = cluster_env_.AllReduceCost(memory_cost, 0) + - cluster_env_.AllReduceCost(memory_cost, 1); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, - device_mesh_1d_, 0, communication_cost, - /* use_sharding_propagation */ false); + const DimMap out_dim_map = {}; + auto communication_cost_fn = [this](const HloSharding& output_spec) 
{ + double memory_cost = GetBytes(ins_->shape()) / output_spec.NumTiles(); + return cluster_env_.AllReduceCost(memory_cost, 0) + + cluster_env_.AllReduceCost(memory_cost, 1); + }; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, + device_mesh_1d_, 0, communication_cost_fn); } } } @@ -741,11 +789,9 @@ class ConvHandler : public HandlerBase { const DimMap rhs_dim_map = {{rhs_out_channel_dim_, e.mesh_dims[1]}}; std::string name = absl::StrFormat("SS = SS x RS @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - HloSharding output_spec = - Tile(ins_->shape(), {out_batch_dim_, out_out_channel_dim_}, - e.mesh_dims, device_mesh_); - MaybeAppend(name, output_spec, lhs_dim_map, rhs_dim_map, device_mesh_, 0, - 0, /* use_sharding_propagation */ false); + const DimMap out_dim_map = {{out_batch_dim_, e.mesh_dims[0]}, + {out_out_channel_dim_, e.mesh_dims[1]}}; + MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); }; EnumerateHalf(func); } From fe31095df7df3a522f8c52a4b16c8d4fd0b7b3f7 Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Tue, 14 Nov 2023 12:05:57 -0800 Subject: [PATCH 082/391] [xla:gpu] Disable gpu_aot_compilation_test in jitrt_executable_tests PiperOrigin-RevId: 582399236 --- third_party/xla/xla/service/gpu/BUILD | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 20b36c1ec5204d..8619a614e135dc 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -4170,9 +4170,10 @@ test_suite( ":cudnn_fused_conv_rewriter_test", ":cudnn_fused_mha_rewriter_test", ":custom_call_test", - # copybara:uncomment ":gpu_aot_compilation_test", + # TODO(anlunx): Re-enable when AOT is available in Thunk-based runtime. + # copybara:uncomment # ":gpu_aot_compilation_test", # copybara:uncomment "//platforms/xla/tests/internal:xfeed_test_gpu", - # TODO(anlunx): Re-enable when the FFI mechanism is avalable in Thunk-based runtime. + # TODO(anlunx): Re-enable when the FFI mechanism is available in Thunk-based runtime. # copybara:uncomment # "//third_party/py/jax/experimental/jax2tf/tests:primitives_test_gpu", # copybara:uncomment "//third_party/py/jax/tests:pmap_test_gpu", # copybara:uncomment "//tensorflow/compiler/tests:fft_test_gpu", From a1e43d0070756e6b368157f5c0b965ccf3a2b855 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 12:11:54 -0800 Subject: [PATCH 083/391] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/f5091e2c05925158e0be192370a37a6cf6fcf241. PiperOrigin-RevId: 582401066 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index d92da02faa0943..dcbd5d3e632cd2 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. 
- TFRT_COMMIT = "fa14cd8fbac47b3545e91b387df41d18262ead38" - TFRT_SHA256 = "1a8771c039520824dd66b404bf56bb4a387089fd1497b1ceace52e2bf3ce35f2" + TFRT_COMMIT = "f5091e2c05925158e0be192370a37a6cf6fcf241" + TFRT_SHA256 = "0b3cbc0ca3862115b2b15122402c42090dd3b52090183e5f6579fe7769e9df0f" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index d92da02faa0943..dcbd5d3e632cd2 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "fa14cd8fbac47b3545e91b387df41d18262ead38" - TFRT_SHA256 = "1a8771c039520824dd66b404bf56bb4a387089fd1497b1ceace52e2bf3ce35f2" + TFRT_COMMIT = "f5091e2c05925158e0be192370a37a6cf6fcf241" + TFRT_SHA256 = "0b3cbc0ca3862115b2b15122402c42090dd3b52090183e5f6579fe7769e9df0f" tf_http_archive( name = "tf_runtime", From 245a80785642c2b6adfbb91ccefdba179e9012d5 Mon Sep 17 00:00:00 2001 From: CJ Carey Date: Tue, 14 Nov 2023 12:54:44 -0800 Subject: [PATCH 084/391] Avoid putting a py_library target in the srcs list PiperOrigin-RevId: 582413012 --- tensorflow/python/compiler/tensorrt/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD index bb563bc38836e1..0dc49c5a56f3fb 100644 --- a/tensorflow/python/compiler/tensorrt/BUILD +++ b/tensorflow/python/compiler/tensorrt/BUILD @@ -84,13 +84,13 @@ py_strict_library( py_strict_library( name = "tf_trt_integration_test_base", - srcs = ["//tensorflow/python/compiler/tensorrt/test:tf_trt_integration_test_base_srcs"], srcs_version = "PY3", deps = [ ":trt_convert_py", ":utils", "//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils", "//tensorflow/core:protos_all_py", + "//tensorflow/python/compiler/tensorrt/test:tf_trt_integration_test_base_srcs", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:config", "//tensorflow/python/framework:graph_io", From ce174cc2c2ec32d7eb217ec11b77389b94ddf7e4 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 14 Nov 2023 13:17:10 -0800 Subject: [PATCH 085/391] [xla:gpu] NFC: Fix typo in filecheck based test PiperOrigin-RevId: 582419743 --- third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc index b2a8c8656c590e..55268442fc48a7 100644 --- a/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter_test.cc @@ -64,7 +64,7 @@ TEST_F(CustomFusionRewriterTest, SimpleGemm) { ; CHECK: [[P0:%[^ ]+]] = f16[15,19]{1,0} parameter(0) ; CHECK: [[P1:%[^ ]+]] = f16[19,17]{1,0} parameter(1) ; CHECK: ROOT [[DOT:%[^ ]+]] = f16[15,17]{1,0} dot([[P0]], [[P1]]), - ; CEHCK: lhs_contracting_dims={1}, rhs_contracting_dims={0} + ; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0} ; CHECK: } ; CHECK: ENTRY %main {{.*}} { From 49b6a05e97516c579ff7a6cb53763ac8e755e47c Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 14 Nov 2023 13:47:31 -0800 Subject: [PATCH 086/391] [stream_executor] NFC: Rename KernelArgsArrayBase to KernelArgs Re-arrange structs/classes declarations in kernel.h to 
avoid forward declaring arguments types. PiperOrigin-RevId: 582428745 --- .../xla/xla/backends/interpreter/executor.h | 2 +- .../gpu/kernels/cutlass_gemm_kernel.cu.cc | 2 +- .../xla/xla/stream_executor/command_buffer.cc | 2 +- .../xla/xla/stream_executor/command_buffer.h | 2 +- .../cuda/cuda_command_buffer_test.cc | 2 +- .../xla/stream_executor/cuda/cuda_executor.cc | 3 +- .../stream_executor/gpu/gpu_command_buffer.cc | 2 +- .../stream_executor/gpu/gpu_command_buffer.h | 3 +- .../xla/stream_executor/gpu/gpu_executor.h | 2 +- .../xla/xla/stream_executor/gpu/gpu_graph.cc | 2 +- .../xla/xla/stream_executor/gpu/gpu_graph.h | 2 +- .../stream_executor/host/host_gpu_executor.h | 2 +- third_party/xla/xla/stream_executor/kernel.h | 147 +++++++++--------- .../xla/xla/stream_executor/kernel_spec.h | 4 +- .../stream_executor/rocm/rocm_gpu_executor.cc | 3 +- .../stream_executor_internal.h | 5 +- .../stream_executor/stream_executor_pimpl.cc | 2 +- .../stream_executor/stream_executor_pimpl.h | 2 +- .../xla/xla/stream_executor/trace_listener.h | 2 +- 19 files changed, 93 insertions(+), 98 deletions(-) diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index 1f730788380bd6..a2c6fc13b360f9 100644 --- a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -57,7 +57,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { } tsl::Status Launch(Stream *stream, const ThreadDim &thread_dims, const BlockDim &block_dims, const Kernel &kernel, - const KernelArgsArrayBase &args) override { + const KernelArgs &args) override { return tsl::errors::Unimplemented("Not Implemented"); } diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc index 4f375d388f7854..6211d9eaf965e1 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc @@ -54,7 +54,7 @@ StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, size_t shared_memory_bytes = sizeof(typename GemmKernel::SharedStorage); // Packs device memory arguments into CUTLASS kernel parameters struct. - auto pack = [problem_size, tiled_shape](const se::KernelArgsArrayBase &args) { + auto pack = [problem_size, tiled_shape](const se::KernelArgs &args) { auto *mem_args = Cast(&args); // Converts DeviceMemoryBase to an opaque `void *` device pointer. diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index 55e0c3fa20f975..9eeb5dad43be6b 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -77,7 +77,7 @@ CommandBuffer::CommandBuffer( tsl::Status CommandBuffer::Launch(const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, - const KernelArgsArrayBase& args) { + const KernelArgs& args) { return implementation_->Launch(threads, blocks, kernel, args); } diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 2a1206307bc785..dade330b666579 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -100,7 +100,7 @@ class CommandBuffer { // Adds a kernel launch command to the command buffer. 
tsl::Status Launch(const ThreadDim& threads, const BlockDim& blocks, - const Kernel& kernel, const KernelArgsArrayBase& args); + const Kernel& kernel, const KernelArgs& args); // Adds a nested command buffer to the command buffer. tsl::Status AddNestedCommandBuffer(const CommandBuffer& nested); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index 949c2cc0a98324..7f37980caa7f88 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -110,7 +110,7 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { // Register a kernel with a custom arguments packing function that packs // device memory arguments into a struct with pointers. - MultiKernelLoaderSpec spec(/*arity=*/1, [&](const KernelArgsArrayBase& args) { + MultiKernelLoaderSpec spec(/*arity=*/1, [&](const KernelArgs& args) { auto bufs = Cast(&args)->device_memory_args(); auto cast = [](auto m) { return reinterpret_cast(m.opaque()); }; return PackKernelArgs(add, internal::Ptrs3{ diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index d46db3508b7360..4a9bbaa92e9625 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -414,8 +414,7 @@ tsl::Status GpuExecutor::GetKernelMetadata(GpuKernel* cuda_kernel, tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, const BlockDim& block_dims, - const Kernel& kernel, - const KernelArgsArrayBase& args) { + const Kernel& kernel, const KernelArgs& args) { CUstream custream = AsGpuStreamValue(stream); const GpuKernel* cuda_kernel = AsGpuKernel(&kernel); CUfunction cufunc = cuda_kernel->AsGpuFunctionHandle(); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 99037894a887df..d50bb091d040a6 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -163,7 +163,7 @@ tsl::Status GpuCommandBuffer::CheckPrimary() { tsl::Status GpuCommandBuffer::Launch(const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, - const KernelArgsArrayBase& args) { + const KernelArgs& args) { TF_RETURN_IF_ERROR(CheckNotFinalized()); const GpuKernel* gpu_kernel = AsGpuKernel(&kernel); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index eb1a3633e5c442..7c30e40414b101 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -44,8 +44,7 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { absl::AnyInvocable function) override; tsl::Status Launch(const ThreadDim& threads, const BlockDim& blocks, - const Kernel& kernel, - const KernelArgsArrayBase& args) override; + const Kernel& kernel, const KernelArgs& args) override; tsl::Status AddNestedCommandBuffer(const CommandBuffer& nested) override; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index 9b0fb7e4955a80..35a261fda6ace7 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ 
b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -117,7 +117,7 @@ class GpuExecutor : public internal::StreamExecutorInterface { tsl::Status Launch(Stream* stream, const ThreadDim& thread_dims, const BlockDim& block_dims, const Kernel& k, - const KernelArgsArrayBase& args) override; + const KernelArgs& args) override; tsl::Status Submit(Stream* stream, const CommandBuffer& command_buffer) override; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_graph.cc b/third_party/xla/xla/stream_executor/gpu/gpu_graph.cc index 0aa6610676e0ad..5d9df50df7c855 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_graph.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_graph.cc @@ -211,7 +211,7 @@ tsl::StatusOr CreateGpuGraph() { tsl::StatusOr AddKernelNode( GpuGraphHandle graph, absl::Span deps, ThreadDim threads, BlockDim blocks, const Kernel& kernel, - const KernelArgsArrayBase& args) { + const KernelArgs& args) { const GpuKernel* gpu_kernel = AsGpuKernel(&kernel); GpuFunctionHandle gpu_func = gpu_kernel->AsGpuFunctionHandle(); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_graph.h b/third_party/xla/xla/stream_executor/gpu/gpu_graph.h index 4331e8932d67c6..28abf986049e8b 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_graph.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_graph.h @@ -116,7 +116,7 @@ tsl::StatusOr CreateGpuGraph(); tsl::StatusOr AddKernelNode( GpuGraphHandle graph, absl::Span deps, ThreadDim threads, BlockDim blocks, const Kernel& kernel, - const KernelArgsArrayBase& args); + const KernelArgs& args); // Adds a memory copy node to the graph. tsl::StatusOr AddMemcpyD2DNode( diff --git a/third_party/xla/xla/stream_executor/host/host_gpu_executor.h b/third_party/xla/xla/stream_executor/host/host_gpu_executor.h index 91fda304be940b..6ca6d0bb6594c7 100644 --- a/third_party/xla/xla/stream_executor/host/host_gpu_executor.h +++ b/third_party/xla/xla/stream_executor/host/host_gpu_executor.h @@ -55,7 +55,7 @@ class HostExecutor : public internal::StreamExecutorInterface { } tsl::Status Launch(Stream* stream, const ThreadDim& thread_dims, const BlockDim& block_dims, const Kernel& kernel, - const KernelArgsArrayBase& args) override { + const KernelArgs& args) override { return tsl::errors::Unimplemented("Not Implemented"); } diff --git a/third_party/xla/xla/stream_executor/kernel.h b/third_party/xla/xla/stream_executor/kernel.h index 075f7ec80b50cb..9077f80955ed8e 100644 --- a/third_party/xla/xla/stream_executor/kernel.h +++ b/third_party/xla/xla/stream_executor/kernel.h @@ -98,9 +98,6 @@ namespace internal { class KernelInterface; } // namespace internal -class KernelArgsArrayBase; // forward declare -class KernelArgsPackedArrayBase; // forward declare - //===----------------------------------------------------------------------===// // Kernel cache config //===----------------------------------------------------------------------===// @@ -151,6 +148,67 @@ class KernelMetadata { std::optional shared_memory_bytes_; }; +//===----------------------------------------------------------------------===// +// Kernel arguments +//===----------------------------------------------------------------------===// + +// A virtual base class for passing kernel arguments to a stream executor APIs. +class KernelArgs { + public: + template + using IsKernelArgs = std::enable_if_t::value>; + + enum class Kind { + // A list of type-erased DeviceMemoryBase pointers to on-device memory. 
This + // type of kernel arguments used only when the kernel has to do its own + // custom packing, e.g. wrap all device pointers into a custom + // structure, but can't be implemented as a TypedKernel because it has to be + // passed around as a generic Kernel. + kDeviceMemoryArray, + + // A list of kernel arguments packed into a storage that can be passed + // directly to device kernel as void** kernel parameters. + kPackedArray + }; + + virtual ~KernelArgs() = default; + + // Gets the number of arguments added so far, including shared memory + // arguments. + virtual size_t number_of_arguments() const = 0; + + // Gets the total number of shared memory bytes added so far. + virtual uint64_t number_of_shared_bytes() const = 0; + + virtual Kind kind() const = 0; +}; + +//===----------------------------------------------------------------------===// +// Kernel arguments packed array +//===----------------------------------------------------------------------===// + +// A virtual base class for passing kernel arguments packed into a storage so +// that we have stable addresses for all arguments. This is a low level API for +// passing arguments in a platform-specific way that relies on the knowledge of +// the ABI of the underlying platform. +// +// For example `cuLaunchKernel` accepts arguments as `void** kernelParams`, and +// packed array base guarantees that `argument_addresses` are compatible with +// the CUDA APIs. +// +// See: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html +class KernelArgsPackedArrayBase : public KernelArgs { + public: + // Gets the list of argument addresses. + virtual absl::Span argument_addresses() const = 0; + + static bool classof(const KernelArgs *args) { + return args->kind() == Kind::kPackedArray; + } + + Kind kind() const final { return Kind::kPackedArray; } +}; + //===----------------------------------------------------------------------===// // Kernel //===----------------------------------------------------------------------===// @@ -168,7 +226,7 @@ class Kernel { // StreamExecutor as a generic `Kernel`. using KernelArgsPacking = std::function>( - const KernelArgsArrayBase &args)>; + const KernelArgs &args)>; Kernel(Kernel &&from); @@ -255,66 +313,33 @@ class TypedKernel : public Kernel { }; //===----------------------------------------------------------------------===// -// Kernel arguments +// Kernel arguments LLVM-style RTTI library //===----------------------------------------------------------------------===// -// A virtual base class for passing kernel arguments to a stream executor APIs. -class KernelArgsArrayBase { - public: - template - using IsKernelArgs = - std::enable_if_t::value>; - - enum class Kind { - // A list of type-erased DeviceMemoryBase pointers to on-device memory. This - // type of kernel arguments used only when the kernel has to do its own - // custom packing, e.g. wrap all device pointers into a custom - // structure, but can't be implemented as a TypedKernel because it has to be - // passed around as a generic Kernel. - kDeviceMemoryArray, - - // A list of kernel arguments packed into a storage that can be passed - // directly to device kernel as void** kernel parameters. - kPackedArray - }; - - virtual ~KernelArgsArrayBase() = default; - - // Gets the number of arguments added so far, including shared memory - // arguments. - virtual size_t number_of_arguments() const = 0; - - // Gets the total number of shared memory bytes added so far. 
- virtual uint64_t number_of_shared_bytes() const = 0; - - virtual Kind kind() const = 0; -}; - -// A small LLVM-style RTTI library for casting kernel arguments. -template * = nullptr> -T *Cast(KernelArgsArrayBase *args) { +template * = nullptr> +T *Cast(KernelArgs *args) { CHECK(T::classof(args)) << "Invalid arguments casting to a destination type: " << typeid(T).name(); CHECK(args != nullptr) << "Casted arguments must be not null"; return static_cast(args); } -template * = nullptr> -const T *Cast(const KernelArgsArrayBase *args) { +template * = nullptr> +const T *Cast(const KernelArgs *args) { CHECK(T::classof(args)) << "Invalid arguments casting to a destination type: " << typeid(T).name(); CHECK(args != nullptr) << "Casted arguments must be not null"; return static_cast(args); } -template * = nullptr> -const T *DynCast(const KernelArgsArrayBase *args) { +template * = nullptr> +const T *DynCast(const KernelArgs *args) { CHECK(args != nullptr) << "Casted arguments must be not null"; return T::classof(args) ? static_cast(args) : nullptr; } -template * = nullptr> -const T *DynCastOrNull(const KernelArgsArrayBase *args) { +template * = nullptr> +const T *DynCastOrNull(const KernelArgs *args) { return args && T::classof(args) ? static_cast(args) : nullptr; } @@ -322,14 +347,14 @@ const T *DynCastOrNull(const KernelArgsArrayBase *args) { // Kernel arguments device memory array //===----------------------------------------------------------------------===// -class KernelArgsDeviceMemoryArray : public KernelArgsArrayBase { +class KernelArgsDeviceMemoryArray : public KernelArgs { public: KernelArgsDeviceMemoryArray(absl::Span args, size_t shared_memory_bytes) : device_memory_args_(args.begin(), args.end()), shared_memory_bytes_(shared_memory_bytes) {} - static bool classof(const KernelArgsArrayBase *args) { + static bool classof(const KernelArgs *args) { return args->kind() == Kind::kDeviceMemoryArray; } @@ -358,32 +383,6 @@ class KernelArgsDeviceMemoryArray : public KernelArgsArrayBase { size_t shared_memory_bytes_ = 0; }; -//===----------------------------------------------------------------------===// -// Kernel arguments packed array -//===----------------------------------------------------------------------===// - -// A virtual base class for passing kernel arguments packed into a storage so -// that we have stable addresses for all arguments. This is a low level API for -// passing arguments in a platform-specific way that relies on the knowledge of -// the ABI of the underlying platform. -// -// For example `cuLaunchKernel` accepts arguments as `void** kernelParams`, and -// packed array base guarantees that `argument_addresses` are compatible with -// the CUDA APIs. -// -// See: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html -class KernelArgsPackedArrayBase : public KernelArgsArrayBase { - public: - // Gets the list of argument addresses. 
- virtual absl::Span argument_addresses() const = 0; - - static bool classof(const KernelArgsArrayBase *args) { - return args->kind() == Kind::kPackedArray; - } - - Kind kind() const final { return Kind::kPackedArray; } -}; - //===----------------------------------------------------------------------===// // Kernel arguments packing for device memory and POD args //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/kernel_spec.h b/third_party/xla/xla/stream_executor/kernel_spec.h index 7e7aef9d3887a2..6144944306bef2 100644 --- a/third_party/xla/xla/stream_executor/kernel_spec.h +++ b/third_party/xla/xla/stream_executor/kernel_spec.h @@ -61,7 +61,7 @@ limitations under the License. namespace stream_executor { -class KernelArgsArrayBase; // defined in kernel.h +class KernelArgs; // defined in kernel.h class KernelArgsPackedArrayBase; // defined in kernel.h // Describes how to load a kernel on a target platform. @@ -262,7 +262,7 @@ class MultiKernelLoaderSpec { // StreamExecutor as a generic `Kernel`. using KernelArgsPacking = std::function>( - const KernelArgsArrayBase &args)>; + const KernelArgs &args)>; explicit MultiKernelLoaderSpec( size_t arity, KernelArgsPacking kernel_args_packing = nullptr); diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc index 9e27e444cb3487..3b7e6dd0a178c6 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc @@ -256,8 +256,7 @@ tsl::Status GpuExecutor::GetKernelMetadata(GpuKernel* rocm_kernel, tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, const BlockDim& block_dims, - const Kernel& kernel, - const KernelArgsArrayBase& args) { + const Kernel& kernel, const KernelArgs& args) { CHECK_EQ(kernel.Arity() + (args.number_of_shared_bytes() > 0), args.number_of_arguments()); GpuStreamHandle hipstream = AsGpuStreamValue(stream); diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index f8184848870107..1928d1afc913bf 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -130,8 +130,7 @@ class CommandBufferInterface { // Adds a kernel launch command to the command buffer. virtual tsl::Status Launch(const ThreadDim& threads, const BlockDim& blocks, - const Kernel& kernel, - const KernelArgsArrayBase& args) = 0; + const Kernel& kernel, const KernelArgs& args) = 0; // Adds a nested command buffer to the command buffer. 
virtual tsl::Status AddNestedCommandBuffer(const CommandBuffer& nested) = 0; @@ -247,7 +246,7 @@ class StreamExecutorInterface { } virtual tsl::Status Launch(Stream* stream, const ThreadDim& thread_dims, const BlockDim& block_dims, const Kernel& k, - const KernelArgsArrayBase& args) { + const KernelArgs& args) { return absl::UnimplementedError("Not Implemented"); } diff --git a/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc b/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc index 56201c27e74d26..83bfbcfd8e4cc8 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc +++ b/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc @@ -440,7 +440,7 @@ fft::FftSupport* StreamExecutor::AsFft() { tsl::Status StreamExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, const BlockDim& block_dims, const Kernel& kernel, - const KernelArgsArrayBase& args) { + const KernelArgs& args) { SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims, kernel, args); diff --git a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h index e38e8f880415d0..a7f16930aa8b58 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h +++ b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h @@ -414,7 +414,7 @@ class StreamExecutor { // implementation in StreamExecutorInterface::Launch(). tsl::Status Launch(Stream* stream, const ThreadDim& thread_dims, const BlockDim& block_dims, const Kernel& kernel, - const KernelArgsArrayBase& args); + const KernelArgs& args); // Submits command buffer for execution to the underlying platform driver. tsl::Status Submit(Stream* stream, const CommandBuffer& command_buffer); diff --git a/third_party/xla/xla/stream_executor/trace_listener.h b/third_party/xla/xla/stream_executor/trace_listener.h index 11977fbafa506c..79909261aca078 100644 --- a/third_party/xla/xla/stream_executor/trace_listener.h +++ b/third_party/xla/xla/stream_executor/trace_listener.h @@ -48,7 +48,7 @@ class TraceListener { virtual void LaunchSubmit(Stream* stream, const ThreadDim& thread_dims, const BlockDim& block_dims, const Kernel& kernel, - const KernelArgsArrayBase& args) {} + const KernelArgs& args) {} virtual void SynchronousMemcpyH2DBegin(int64_t correlation_id, const void* host_src, int64_t size, From ccb912daad3050ee7ebb99e9eb3e20a77e4fd972 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 14 Nov 2023 14:25:44 -0800 Subject: [PATCH 087/391] Create script for generating `compile_commands.json` PiperOrigin-RevId: 582440759 --- third_party/xla/build_tools/lint/BUILD | 7 +- .../lint/generate_compile_commands.py | 119 ++++++++++++++++++ 2 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 third_party/xla/build_tools/lint/generate_compile_commands.py diff --git a/third_party/xla/build_tools/lint/BUILD b/third_party/xla/build_tools/lint/BUILD index 8ca1872bb1b064..0270b76421a545 100644 --- a/third_party/xla/build_tools/lint/BUILD +++ b/third_party/xla/build_tools/lint/BUILD @@ -13,7 +13,7 @@ # limitations under the License. 
# ============================================================================ -load("//xla:pytype.default.bzl", "pytype_strict_library") +load("//xla:pytype.default.bzl", "pytype_strict_binary", "pytype_strict_library") # Placeholder: load py_test package( @@ -34,6 +34,11 @@ pytype_strict_library( visibility = ["//visibility:public"], ) +pytype_strict_binary( + name = "generate_compile_commands", + srcs = ["generate_compile_commands.py"], +) + py_test( name = "check_contents_test", srcs = ["check_contents_test.py"], diff --git a/third_party/xla/build_tools/lint/generate_compile_commands.py b/third_party/xla/build_tools/lint/generate_compile_commands.py new file mode 100644 index 00000000000000..1dc84fbf14a8ff --- /dev/null +++ b/third_party/xla/build_tools/lint/generate_compile_commands.py @@ -0,0 +1,119 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +r"""Produces a `compile_commands.json` from the output of `bazel aquery`. + +Example usage: + bazel aquery "mnemonic(CppCompile, //xla/...)" | \ + python3 build_tools/lint/generate_compile_commands.py +""" +import dataclasses +import json +import logging +import pathlib +import sys +from typing import Any + +_JSONDict = dict[Any, Any] # Approximates parsed JSON + +_DISALLOWED_ARGS = frozenset(["-fno-canonical-system-headers"]) +_XLA_SRC_ROOT = pathlib.Path(__file__).absolute().parent.parent.parent + + +@dataclasses.dataclass +class ClangTidyCommand: + """Represents a clang-tidy command with options on a specific file.""" + + file: str + arguments: list[str] + + @classmethod + def from_args_list(cls, args_list: list[str]) -> "ClangTidyCommand": + """Alternative constructor which uses the args_list from `bazel aquery`. + + This collects arguments and the file being run on from the output of + `bazel aquery`. Also filters out arguments which break clang-tidy. + + Arguments: + args_list: List of arguments generated by `bazel aquery` + + Returns: + The corresponding ClangTidyCommand. + """ + cc_file = None + filtered_args = [] + + for arg in args_list: + if arg in _DISALLOWED_ARGS: + continue + + if arg.endswith(".cc"): + cc_file = arg + + filtered_args.append(arg) + + return cls(cc_file, filtered_args) + + def to_dumpable_json(self, directory: str) -> _JSONDict: + return { + "directory": directory, + "file": self.file, + "arguments": self.arguments, + } + + +def extract_compile_commands( + parsed_aquery_output: _JSONDict, +) -> list[ClangTidyCommand]: + """Gathers clang-tidy commands to run from `bazel aquery` JSON output. + + Arguments: + parsed_aquery_output: Parsed JSON representing the output of `bazel aquery + --output=jsonproto`. + + Returns: + The list of ClangTidyCommands that should be executed. 
+ """ + actions = parsed_aquery_output["actions"] + + commands = [] + for action in actions: + command = ClangTidyCommand.from_args_list(action["arguments"]) + commands.append(command) + return commands + + +def main(): + # Setup logging + logging.basicConfig() + logging.getLogger().setLevel(logging.INFO) + + # Gather and run clang-tidy invocations + logging.info("Reading `bazel aquery` output from stdin...") + parsed_aquery_output = json.loads(sys.stdin.read()) + + commands = extract_compile_commands(parsed_aquery_output) + + with (_XLA_SRC_ROOT / "compile_commands.json").open("w") as f: + json.dump( + [ + command.to_dumpable_json(directory=str(_XLA_SRC_ROOT)) + for command in commands + ], + f, + ) + + +if __name__ == "__main__": + main() From 5e48d6e144f1b5523d3c84452974dc5cfec7d8bf Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 14 Nov 2023 14:48:01 -0800 Subject: [PATCH 088/391] Prevent OOB indexing in StableHLO/MHLO ops. PiperOrigin-RevId: 582447653 --- third_party/stablehlo/temporary.patch | 75 +++++++++++++++++++ .../xla/third_party/stablehlo/temporary.patch | 75 +++++++++++++++++++ .../xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc | 2 +- .../mhlo/analysis/shape_component_analysis.cc | 4 + 4 files changed, 155 insertions(+), 1 deletion(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 3b31bf25fb5ab6..ba7bcb6b5de29c 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -202,6 +202,45 @@ diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Ba -#endif +#endif // STABLEHLO_DIALECT_BASE_H +diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp +--- stablehlo/stablehlo/dialect/TypeInference.cpp ++++ stablehlo/stablehlo/dialect/TypeInference.cpp +@@ -3510,12 +3510,21 @@ + } + + LogicalResult verifyDynamicIotaOp(std::optional location, +- Value outputShape, int64_t outputDimension, ++ Value outputShape, int64_t iotaDimension, + Value result) { +- if (!isCompatibleForHloTypeInference(outputShape, result.getType())) ++ auto shape = result.getType().cast(); ++ if (!isCompatibleForHloTypeInference(outputShape, shape)) + return emitOptionalError( + location, "output_shape is incompatible with return type of operation ", + result.getType()); ++ ++ if (!shape.hasRank()) return success(); ++ ++ if (iotaDimension >= shape.getRank() || iotaDimension < 0) ++ return emitOptionalError( ++ location, ++ "iota dimension cannot go beyond the output rank or be negative."); ++ + return success(); + } + +diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.h b/stablehlo/stablehlo/dialect/TypeInference.h +--- stablehlo/stablehlo/dialect/TypeInference.h ++++ stablehlo/stablehlo/dialect/TypeInference.h +@@ -427,7 +427,7 @@ + Value result); + + LogicalResult verifyDynamicIotaOp(std::optional location, +- Value outputShape, int64_t outputDimension, ++ Value outputShape, int64_t iotaDimension, + Value result); + + LogicalResult verifyDynamicPadOp(std::optional location, diff --ruN a/stablehlo/stablehlo/experimental/BUILD b/stablehlo/stablehlo/experimental/BUILD --- stablehlo/stablehlo/experimental/BUILD +++ stablehlo/stablehlo/experimental/BUILD @@ -3884,6 +3923,42 @@ diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/B tags = ["stablehlo_tests"], ) for src in glob(["**/*.mlir"]) +diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/tests/ops_stablehlo.mlir +--- 
stablehlo/stablehlo/tests/ops_stablehlo.mlir ++++ stablehlo/stablehlo/tests/ops_stablehlo.mlir +@@ -5648,6 +5648,32 @@ + + // ----- + ++func.func @dynamic_iota_unranked_large() -> tensor<*xf32> { ++ %0 = stablehlo.constant dense<[4]> : tensor<1xi64> ++ %1 = stablehlo.dynamic_iota %0, dim = 3 : (tensor<1xi64>) -> tensor<*xf32> ++ func.return %1 : tensor<*xf32> ++} ++ ++// ----- ++ ++func.func @dynamic_iota_invalid_iota_dimension_negative() -> tensor { ++ // expected-error@+2 {{iota dimension cannot go beyond the output rank or be negative}} ++ %0 = stablehlo.constant dense<[4]> : tensor<1xi64> ++ %1 = stablehlo.dynamic_iota %0, dim = -1 : (tensor<1xi64>) -> tensor ++ func.return %1 : tensor ++} ++ ++// ----- ++ ++func.func @dynamic_iota_invalid_iota_dimension_too_big() -> tensor { ++ %0 = stablehlo.constant dense<[4]> : tensor<1xi64> ++ // expected-error@+1 {{iota dimension cannot go beyond the output rank or be negative}} ++ %1 = stablehlo.dynamic_iota %0, dim = 2 : (tensor<1xi64>) -> tensor ++ func.return %1 : tensor ++} ++ ++// ----- ++ + func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> { + // @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}} + %0 = stablehlo.constant dense<[-1]> : tensor<1xi64> diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h --- stablehlo/stablehlo/transforms/Passes.h +++ stablehlo/stablehlo/transforms/Passes.h diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 70705cb3bef92e..f132c3c6a10ce1 100644 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -202,6 +202,45 @@ diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Ba -#endif +#endif // STABLEHLO_DIALECT_BASE_H +diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp +--- stablehlo/stablehlo/dialect/TypeInference.cpp ++++ stablehlo/stablehlo/dialect/TypeInference.cpp +@@ -3510,12 +3510,21 @@ + } + + LogicalResult verifyDynamicIotaOp(std::optional location, +- Value outputShape, int64_t outputDimension, ++ Value outputShape, int64_t iotaDimension, + Value result) { +- if (!isCompatibleForHloTypeInference(outputShape, result.getType())) ++ auto shape = result.getType().cast(); ++ if (!isCompatibleForHloTypeInference(outputShape, shape)) + return emitOptionalError( + location, "output_shape is incompatible with return type of operation ", + result.getType()); ++ ++ if (!shape.hasRank()) return success(); ++ ++ if (iotaDimension >= shape.getRank() || iotaDimension < 0) ++ return emitOptionalError( ++ location, ++ "iota dimension cannot go beyond the output rank or be negative."); ++ + return success(); + } + +diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.h b/stablehlo/stablehlo/dialect/TypeInference.h +--- stablehlo/stablehlo/dialect/TypeInference.h ++++ stablehlo/stablehlo/dialect/TypeInference.h +@@ -427,7 +427,7 @@ + Value result); + + LogicalResult verifyDynamicIotaOp(std::optional location, +- Value outputShape, int64_t outputDimension, ++ Value outputShape, int64_t iotaDimension, + Value result); + + LogicalResult verifyDynamicPadOp(std::optional location, diff --ruN a/stablehlo/stablehlo/experimental/BUILD b/stablehlo/stablehlo/experimental/BUILD --- stablehlo/stablehlo/experimental/BUILD +++ stablehlo/stablehlo/experimental/BUILD @@ -3884,6 +3923,42 @@ diff --ruN 
a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/B tags = ["stablehlo_tests"], ) for src in glob(["**/*.mlir"]) +diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/tests/ops_stablehlo.mlir +--- stablehlo/stablehlo/tests/ops_stablehlo.mlir ++++ stablehlo/stablehlo/tests/ops_stablehlo.mlir +@@ -5648,6 +5648,32 @@ + + // ----- + ++func.func @dynamic_iota_unranked_large() -> tensor<*xf32> { ++ %0 = stablehlo.constant dense<[4]> : tensor<1xi64> ++ %1 = stablehlo.dynamic_iota %0, dim = 3 : (tensor<1xi64>) -> tensor<*xf32> ++ func.return %1 : tensor<*xf32> ++} ++ ++// ----- ++ ++func.func @dynamic_iota_invalid_iota_dimension_negative() -> tensor { ++ // expected-error@+2 {{iota dimension cannot go beyond the output rank or be negative}} ++ %0 = stablehlo.constant dense<[4]> : tensor<1xi64> ++ %1 = stablehlo.dynamic_iota %0, dim = -1 : (tensor<1xi64>) -> tensor ++ func.return %1 : tensor ++} ++ ++// ----- ++ ++func.func @dynamic_iota_invalid_iota_dimension_too_big() -> tensor { ++ %0 = stablehlo.constant dense<[4]> : tensor<1xi64> ++ // expected-error@+1 {{iota dimension cannot go beyond the output rank or be negative}} ++ %1 = stablehlo.dynamic_iota %0, dim = 2 : (tensor<1xi64>) -> tensor ++ func.return %1 : tensor ++} ++ ++// ----- ++ + func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> { + // @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}} + %0 = stablehlo.constant dense<[-1]> : tensor<1xi64> diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h --- stablehlo/stablehlo/transforms/Passes.h +++ stablehlo/stablehlo/transforms/Passes.h diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 99007914acbd9d..4d01cd520861c1 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -7076,7 +7076,7 @@ static LogicalResult verifyArgResultAliasAttr(StringAttr attrName, LogicalResult verifyCrossProgramPrefetchAttr(CrossProgramPrefetchAttr cpp, ModuleOp module) { func::FuncOp main = module.lookupSymbol("main"); - if (cpp.getParameter() >= main.getNumArguments()) + if (cpp.getParameter() >= main.getNumArguments() || cpp.getParameter() < 0) return module->emitOpError() << "cross_program_prefetch: parameter " << cpp.getParameter() << " out of range. main has only " << main.getNumArguments() diff --git a/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc b/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc index 380d48d3b3754f..e1fa72e2a80ac0 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc @@ -16,10 +16,12 @@ limitations under the License. 
#include "mhlo/analysis/shape_component_analysis.h" #include +#include #include #include #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" #include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -520,6 +522,8 @@ struct ShapeVisitor { if (auto index = op.getIndex().getDefiningOp()) { int64_t i = index.getValue().cast().getInt(); auto in = lookup(ShapeOrValueInfo::getShapeInfoOf(op.getSource())); + if (i >= static_cast(in.size()) || i < 0) + llvm::report_fatal_error("tensor dim out of bounds"); dims.push_back({in[i].symbols, in[i].expr}); } else { forwardUnknown(op); From 20be0d19b68cca0a8762bf6a794e35b6f3d5164a Mon Sep 17 00:00:00 2001 From: Shan Han Date: Tue, 14 Nov 2023 15:03:13 -0800 Subject: [PATCH 089/391] Export batch costs by processed size. PiperOrigin-RevId: 582452027 --- .../batching_util/batch_resource_base.cc | 43 ++++++++++++++++--- .../batching_util/batch_resource_base.h | 2 + .../batching_util/batch_resource_base_test.cc | 35 ++++++--------- 3 files changed, 52 insertions(+), 28 deletions(-) diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index c1395ed464252c..7c601a35f0bca5 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -234,6 +234,23 @@ void RecordBatchParamAllowedBatchSizes(const string& allowed_batch_sizes, cell->GetCell(model_name, op_name)->Set(allowed_batch_sizes); } +void RecordBatchCosts(const std::string& model_name, + const int64_t processed_size, + const absl::string_view cost_type, + const absl::Duration total_cost) { + static auto* cell = tensorflow::monitoring::Sampler<3>::New( + {"/tensorflow/serving/batching/costs", + "Tracks the batch costs (in microseconds) by model name and processed " + "size.", + "model_name", "processed_size", "cost_type"}, + // It's 27 buckets with the last bucket being 2^26 to DBL_MAX; + // so the limits are [1, 2, 4, 8, ..., 64 * 1024 * 1024 (~64s), DBL_MAX]. + monitoring::Buckets::Exponential(1, 2, 27)); + cell->GetCell(model_name, std::to_string(processed_size), + std::string(cost_type)) + ->Add(absl::ToDoubleMicroseconds(total_cost)); +} + const string& GetModelName(OpKernelContext* ctx) { static string* kModelNameUnset = new string("model_name_unset"); if (!ctx->session_metadata()) return *kModelNameUnset; @@ -827,6 +844,7 @@ void BatchResourceBase::ProcessFuncBatch(std::unique_ptr batch) const { auto& last_task = batch->task(batch->num_tasks() - 1); OpKernelContext* last_task_context = last_task.context; + const std::string& model_name = GetModelName(last_task_context); // Regardless of the outcome, we need to propagate the status to the // individual tasks and signal that they are done. We use MakeCleanup() to @@ -838,8 +856,8 @@ void BatchResourceBase::ProcessFuncBatch(std::unique_ptr batch) const { if (cleanup_done) { return; } - SplitBatchCostsAndRecordMetrics(batch_cost_measurements, processed_size, - *batch); + SplitBatchCostsAndRecordMetrics(model_name, batch_cost_measurements, + processed_size, *batch); // Clear the measurements before unblocking the batch task, as measurements // are associated with the task's thread context. 
batch_cost_measurements.clear(); @@ -878,7 +896,6 @@ void BatchResourceBase::ProcessFuncBatch(std::unique_ptr batch) const { args.insert(args.end(), captured_inputs.begin(), captured_inputs.end()); uint64 current_time = EnvTime::NowNanos(); - const string& model_name = GetModelName(last_task_context); for (int i = 0; i < batch->num_tasks(); ++i) { RecordBatchDelayUs((current_time - batch->task(i).start_time) * 1e-3, model_name, last_task_context->op_kernel().name(), @@ -930,15 +947,17 @@ void BatchResourceBase::ProcessBatch(std::unique_ptr batch) const { CreateCostMeasurements(batching_context); int64_t processed_size = batch->size(); - auto batch_cost_split_cleanup = gtl::MakeCleanup([&] { - SplitBatchCostsAndRecordMetrics(batch_cost_measurements, processed_size, - *batch); - }); OpKernelContext* last_task_context = batch->task(batch->num_tasks() - 1).context; AsyncOpKernel::DoneCallback last_task_callback = batch->task(batch->num_tasks() - 1).done_callback; + const std::string& model_name = GetModelName(last_task_context); + + auto batch_cost_cleanup = gtl::MakeCleanup([&] { + SplitBatchCostsAndRecordMetrics(model_name, batch_cost_measurements, + processed_size, *batch); + }); OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch), last_task_callback); @@ -1056,6 +1075,7 @@ Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name, } void BatchResourceBase::SplitBatchCostsAndRecordMetrics( + const std::string& model_name, const std::vector>& batch_cost_measurements, const int64_t processed_size, BatchT& batch) { @@ -1078,6 +1098,15 @@ void BatchResourceBase::SplitBatchCostsAndRecordMetrics( const absl::string_view cost_type = batch_cost_measurement->GetCostType(); const absl::Duration total_cost = batch_cost_measurement->GetTotalCost(); + // Smeared batch cost: cost for processing this batch. + RecordBatchCosts(model_name, processed_size, + absl::StrCat(cost_type, kWithSmearSuffix), total_cost); + // Non-smeared batch cost: cost for processing inputs in this batch, i.e. + // cost for processing paddings is excluded. + RecordBatchCosts(model_name, processed_size, + absl::StrCat(cost_type, kNoSmearSuffix), + total_cost / processed_size * batch.size()); + for (int i = 0; i < batch.num_tasks(); i++) { RequestCost* request_cost = batch.task(i).request_cost; // Skip recording the cost if the request_cost is null. diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h index 5124e9f031733a..b86d25c097da39 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.h +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -238,6 +239,7 @@ class BatchResourceBase : public ResourceBase { // 2) the input size from this task; // 3) the padding amount. 
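  // In addition, batch-level costs are exported to the
  // /tensorflow/serving/batching/costs metric (see RecordBatchCosts).
  // Illustration with hypothetical numbers: for processed_size = 16,
  // batch.size() = 12 and a total cost of 80us, it records 80us for the
  // padded batch ("with smear") and 80us * 12 / 16 = 60us for the real
  // inputs only ("no smear").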
static void SplitBatchCostsAndRecordMetrics( + const std::string& model_name, const std::vector>& batch_cost_measurements, int64_t processed_size, BatchT& batch); diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc index dc75fde050cc6f..cd4ae4644ed62e 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc @@ -70,9 +70,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnNoCostMeasurement) { batch.Close(); std::vector> batch_cost_measurements; - BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/16, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/16, batch); EXPECT_TRUE(batch.task(0).request_cost->GetCosts().empty()); EXPECT_THAT(batch.task(0).request_cost->GetBatchMetrics(), ::testing::ElementsAre(::testing::FieldsAre( @@ -90,9 +89,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnZeroCost) { std::vector> batch_cost_measurements; batch_cost_measurements.push_back( CostMeasurementRegistry::CreateByNameOrNull("no_op", context)); - BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/16, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/16, batch); EXPECT_TRUE(batch.task(0).request_cost->GetCosts().empty()); EXPECT_THAT(batch.task(0).request_cost->GetBatchMetrics(), ::testing::ElementsAre(::testing::FieldsAre( @@ -108,9 +106,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnZeroBatchSize) { std::vector> batch_cost_measurements; batch_cost_measurements.push_back( CostMeasurementRegistry::CreateByNameOrNull("test_tpu", context)); - BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/0, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/0, batch); } TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnNoRequestCost) { @@ -123,9 +120,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnNoRequestCost) { std::vector> batch_cost_measurements; batch_cost_measurements.push_back( CostMeasurementRegistry::CreateByNameOrNull("test_tpu", context)); - BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/16, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/16, batch); EXPECT_EQ(batch.task(0).request_cost, nullptr); EXPECT_EQ(batch.task(1).request_cost, nullptr); @@ -142,9 +138,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitSingleCostType) { std::vector> batch_cost_measurements; batch_cost_measurements.push_back( CostMeasurementRegistry::CreateByNameOrNull("test_tpu", context)); - BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/20, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/20, batch); EXPECT_THAT( batch.task(0).request_cost->GetCosts(), @@ -179,9 +174,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitMultiCostTypes) { CostMeasurementRegistry::CreateByNameOrNull("test_tpu", context)); batch_cost_measurements.push_back( CostMeasurementRegistry::CreateByNameOrNull("test_gcu", context)); - 
BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/20, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/20, batch); EXPECT_THAT( batch.task(0).request_cost->GetCosts(), @@ -223,9 +217,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitOnlyNonZeroCostTypes) { CostMeasurementRegistry::CreateByNameOrNull("no_op", context)); batch_cost_measurements.push_back( CostMeasurementRegistry::CreateByNameOrNull("test_tpu", context)); - BatchResourceBase::SplitBatchCostsAndRecordMetrics(batch_cost_measurements, - /*processed_size=*/20, - batch); + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + "model_name", batch_cost_measurements, /*processed_size=*/20, batch); EXPECT_THAT( batch.task(0).request_cost->GetCosts(), From 06636eb12a87634ee4715e647b560bad8d9e07fa Mon Sep 17 00:00:00 2001 From: Raviteja Gorijala Date: Tue, 14 Nov 2023 15:09:49 -0800 Subject: [PATCH 090/391] Update Release notes after 2.15.0 release PiperOrigin-RevId: 582453966 --- RELEASE.md | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 75350aeccc5542..3e87a523116c95 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -164,29 +164,26 @@ This release contains contributions from many people at Google, as well as: * Provided a new `experimental_skip_saver` argument which, if specified, will suppress the addition of `SavedModel`-native save and restore ops to the `SavedModel`, for cases where users already build custom save/restore ops and checkpoint formats for the model being saved, and the creation of the SavedModel-native save/restore ops simply cause longer model serialization times. -* `tf.math.bincount` - * Updated documentation. Fixed "[Bincount doesn't check the tensor type](https://github.com/tensorflow/tensorflow/issues/56499)" and some other corner cases. - -## Keras - -### Breaking Changes - -### Known Caveats - -### Major Features and Improvements - -### Bug Fixes and Other Changes - * Add ops to `tensorflow.raw_ops` that were missing. + * `tf.CheckpointOptions` * It now takes in a new argument called `experimental_write_callbacks`. These are callbacks that will be executed after a saving event finishes writing the checkpoint file. + * Add an option `disable_eager_executer_streaming_enqueue` to `tensorflow.ConfigProto.Experimental` to control the eager runtime's behavior around parallel remote function invocations; when set to `True`, the eager runtime will be allowed to execute multiple function invocations in parallel. + * `tf.constant_initializer` - * It now takes a new argument called `support_partition`. If True, constant_initializers can create sharded variables. This is disabled by default similar to existing behavior. + * It now takes a new argument called `support_partition`. If True, constant_initializers can create sharded variables. This is disabled by default, similar to existing behavior. * `tf.lite` * Added support for `stablehlo.scatter`. +* `tf.estimator` + * The tf.estimator API removal is in progress and will be targeted for the 2.16 release. + +## Keras + +* This will be the final release before the launch of Keras 3.0, when Keras will become multi-backend. 
For the compatibility page and other info, please see: https://github.com/keras-team/keras-core + ## Thanks to our Contributors This release contains contributions from many people at Google, as well as: From daa1af4343134ab1489e35e9b7e0f44d52992a64 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Tue, 14 Nov 2023 15:24:06 -0800 Subject: [PATCH 091/391] [xla:gpu] Add gpu-schedule-postprocessing pass. Add a boolean field, no_parallel_gpu_op, to CollectiveBackendConfig. This field asserts that an asynchronous collective operation does not execute in parallel with other operations in GPU. The default value of the attribute is false, which should lead to conservative runtime behavior. Add BackendConfig test for the field. Add gpu-schedule-postprocessing pass, to refine the attribute value. Add test cases for the pass. PiperOrigin-RevId: 582457930 --- third_party/xla/xla/service/gpu/BUILD | 35 ++++ .../xla/xla/service/gpu/backend_configs.proto | 6 +- .../xla/service/gpu/backend_configs_test.cc | 1 + .../gpu/gpu_schedule_postprocessing.cc | 164 ++++++++++++++++++ .../service/gpu/gpu_schedule_postprocessing.h | 48 +++++ .../gpu/gpu_schedule_postprocessing_test.cc | 157 +++++++++++++++++ 6 files changed, 410 insertions(+), 1 deletion(-) create mode 100644 third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc create mode 100644 third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h create mode 100644 third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 8619a614e135dc..b760136fa34912 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3337,6 +3337,41 @@ xla_cc_test( ], ) +cc_library( + name = "gpu_schedule_postprocessing", + srcs = ["gpu_schedule_postprocessing.cc"], + hdrs = ["gpu_schedule_postprocessing.h"], + visibility = ["//visibility:public"], + deps = [ + ":backend_configs_cc", + "//xla:statusor", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "gpu_schedule_postprocessing_test", + srcs = ["gpu_schedule_postprocessing_test.cc"], + deps = [ + ":gpu_schedule_postprocessing", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "gpu_hlo_schedule", srcs = ["gpu_hlo_schedule.cc"], diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index 7f068fb77e1f7f..00504c2764b0f0 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -111,9 +111,13 @@ message BitcastBackendConfig { // Backend config for async collective operations. Note that for is_sync will // be false by default, so even if a backend config is not explicitly attached // to the HLOInstruction, getting the backend_config will yield a default valued -// proto which will have is_sync = false. +// proto which will have is_sync = false. 
Attribute no_parallel_gpu_op asserts +// that an asynchronous collective operation does not execute in parallel with +// other operations in GPU. This attribute will also be false by default, which +// should lead to conversative runtime behavior. message CollectiveBackendConfig { bool is_sync = 1; + bool no_parallel_gpu_op = 2; } message ReificationCost { diff --git a/third_party/xla/xla/service/gpu/backend_configs_test.cc b/third_party/xla/xla/service/gpu/backend_configs_test.cc index 89eaa639d033d5..d1b32e1abda201 100644 --- a/third_party/xla/xla/service/gpu/backend_configs_test.cc +++ b/third_party/xla/xla/service/gpu/backend_configs_test.cc @@ -50,6 +50,7 @@ TEST_F(BackendConfigsTest, DefaultCollectiveBackendConfig) { ags->backend_config(); EXPECT_THAT(collective_backend_config.status(), IsOk()); EXPECT_THAT(collective_backend_config->is_sync(), IsFalse()); + EXPECT_THAT(collective_backend_config->no_parallel_gpu_op(), IsFalse()); } } // namespace diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc new file mode 100644 index 00000000000000..603dfe28b94be7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc @@ -0,0 +1,164 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gpu_schedule_postprocessing.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/statusor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { +// Maps a computation to a boolean that indicates whether the computation +// invokes gpu ops directly or indirectly. +using GpuOpInComputation = absl::flat_hash_map; + +// Returns whether the hlo may invoke gpu-ops which are operations that call +// into CUDA, directly or indirectly. Currently, we only check for custom-calls +// and fusion, because they are the only gpu-ops that can be parallel with +// asynchronous collectives operations. +bool MayInvokeGpuOp(HloInstruction* hlo, + GpuOpInComputation& gpu_op_in_computation) { + if (hlo->opcode() == HloOpcode::kCustomCall || + hlo->opcode() == HloOpcode::kFusion) { + return true; + } + + return std::any_of(hlo->called_computations().begin(), + hlo->called_computations().end(), + [&](const HloComputation* callee) { + return gpu_op_in_computation.find(callee)->second; + }); +} + +// Returns true if this is an asynchronous collective start operation, excluding +// P2P operations. 
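+// For illustration (mirroring the tests added with this pass): an
+// all-gather-start with backend_config="{\"is_sync\":false}" is a relevant
+// asynchronous start, while one marked is_sync=true, and send/recv (P2P)
+// operations, are not.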
+StatusOr IsRelevantAsynchronousStart(HloInstruction* hlo) { + HloOpcode opcode = hlo->opcode(); + if (!hlo_query::IsAsyncCollectiveStartOp(opcode, + /*include_send_recv=*/false)) { + return false; + } + TF_ASSIGN_OR_RETURN(CollectiveBackendConfig collective_backend_config, + hlo->backend_config()); + return !collective_backend_config.is_sync(); +} + +// Returns true if this is a collective done operation, excluding P2P +// operations. +StatusOr IsRelevantAsynchronousDone(HloInstruction* hlo) { + HloOpcode opcode = hlo->opcode(); + return hlo_query::IsAsyncCollectiveDoneOp(opcode, + /*include_send_recv=*/false); +} + +// For a given computation, finds all the asynchronous collective operations +// that aren't parallel with other gpu-op-invoking instructions and sets its +// no_parallel_gpu_op attribute to true. Also records whether the given +// computation may invoke gpu-ops. +StatusOr ProcessComputation(HloSchedule& schedule, + HloComputation* computation, + GpuOpInComputation& gpu_op_in_computation) { + bool changed = false; + bool has_gpu_op = false; + absl::flat_hash_set async_starts; + const HloInstructionSequence& sequence = schedule.sequence(computation); + + // Visit instructions in the sequence. Collect relevant asynchronous + // collective start ops. When we see a relevant asynchronous collective done + // op, remove the corresponding start op from the collection and set its + // attribute no_parallel_gpu_op to true. When we see a gpu-op, clear the start + // ops from the collection and keep their attribute no_parallel_gpu_op as + // false. + const std::vector all_instructions = sequence.instructions(); + for (auto instr_it = all_instructions.begin(); + instr_it != all_instructions.end(); ++instr_it) { + HloInstruction* hlo = *instr_it; + if (MayInvokeGpuOp(hlo, gpu_op_in_computation)) { + async_starts.clear(); + has_gpu_op = true; + continue; + } + TF_ASSIGN_OR_RETURN(bool is_async_start, IsRelevantAsynchronousStart(hlo)); + if (is_async_start) { + async_starts.insert(hlo); + continue; + } + + TF_ASSIGN_OR_RETURN(bool is_async_done, IsRelevantAsynchronousDone(hlo)); + if (is_async_done) { + HloInstruction* async_start = hlo->mutable_operand(0); + if (async_starts.contains(async_start)) { + changed = true; + TF_ASSIGN_OR_RETURN( + CollectiveBackendConfig collective_backend_config, + async_start->backend_config()); + collective_backend_config.set_no_parallel_gpu_op(true); + TF_RETURN_IF_ERROR( + async_start->set_backend_config(collective_backend_config)); + async_starts.erase(async_start); + } + } + } + + gpu_op_in_computation[computation] = has_gpu_op; + return changed; +} + +} // anonymous namespace + +StatusOr GpuSchedulePostprocessing::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + if (!module->has_schedule()) return false; + HloSchedule& schedule = module->schedule(); + bool changed = false; + GpuOpInComputation gpu_op_in_computation; + + // We visit computations in the order of callees to callers, as information is + // propagated from calles to callers. 
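+  // MakeComputationPostOrder visits callees before their callers, so the
+  // gpu_op_in_computation entries consulted by MayInvokeGpuOp are already
+  // populated when each caller's schedule is processed.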
+ std::vector all_computations = + module->MakeComputationPostOrder(execution_threads); + for (auto iter = all_computations.begin(); iter != all_computations.end(); + ++iter) { + HloComputation* computation = *iter; + if (computation->IsFusionComputation()) { + gpu_op_in_computation[computation] = false; + continue; + } + + TF_ASSIGN_OR_RETURN(bool result, ProcessComputation(schedule, computation, + gpu_op_in_computation)); + changed |= result; + } + + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h new file mode 100644 index 00000000000000..578dffabc146f2 --- /dev/null +++ b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h @@ -0,0 +1,48 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_ +#define XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_ + +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Amends a schedule result with the needed information to support a runtime +// implementation. Currently, this pass refines attribute no_parallel_gpu_op +// for asynchronous collective operations to support runtime optimization, such +// as skipping rendezvous of all participating threads for NCCL collective +// operations. In particular, it sets the attribute value for Collective-start +// operations with is_sync=false; it also keeps the attribute value untouch for +// the operations with is_sync=true and for P2P operations, assumming the +// runtime won't use those values. +// +class GpuSchedulePostprocessing : public HloModulePass { + public: + absl::string_view name() const override { + return "gpu-schedule-postprocessing"; + } + + using HloPassInterface::Run; + StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc new file mode 100644 index 00000000000000..e803e7843eb007 --- /dev/null +++ b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc @@ -0,0 +1,157 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gpu_schedule_postprocessing.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/hlo_parser.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using GpuSchedulePostprocessingTest = HloTestBase; + +TEST_F(GpuSchedulePostprocessingTest, SynchronousOpsNotChanged) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + pf32 = f32[1] parameter(0) + + all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config="{\"is_sync\":true}" + ROOT all-gather-done = f32[2] all-gather-done(all-gather-start) + } +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kHloString))); + GpuSchedulePostprocessing pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(GpuSchedulePostprocessingTest, P2POpsNotChanged) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + + ENTRY main { + f0 = f32[] constant(0.0) + init = f32[1, 1024, 1024] broadcast(f0), dimensions={} + + after-all = token[] after-all() + recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}}" + } + recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=2 + ROOT recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done), index=0 + } +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kHloString))); + GpuSchedulePostprocessing pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsChanged) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + pf32 = f32[1] parameter(0) + pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call" + all-gather-start = (f32[1], f32[2]) all-gather-start(pf32.2), dimensions={0}, backend_config="{\"is_sync\":false}" + ROOT all-gather-done = f32[2] all-gather-done(all-gather-start) + } +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kHloString))); + GpuSchedulePostprocessing pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); + TF_ASSERT_OK_AND_ASSIGN(CollectiveBackendConfig collective_backend_config, + start->backend_config()); + EXPECT_TRUE(collective_backend_config.no_parallel_gpu_op()); +} + +TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + pf32 = f32[1] parameter(0) + all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config="{\"is_sync\":false}" + pf32.2 = f32[1] custom-call(pf32), custom_call_target="my_custom_call" + all-gather-done = f32[2] all-gather-done(all-gather-start) 
+ ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done) + } +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kHloString))); + GpuSchedulePostprocessing pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + EXPECT_FALSE(changed); + + HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); + TF_ASSERT_OK_AND_ASSIGN(CollectiveBackendConfig collective_backend_config, + start->backend_config()); + EXPECT_FALSE(collective_backend_config.no_parallel_gpu_op()); +} + +TEST_F(GpuSchedulePostprocessingTest, + AsynchronousOpsWithParallelNestedCustomcall) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + foo { + v = f32[1] parameter(0) + ROOT ret = f32[1] custom-call(v), custom_call_target="my_custom_call" + } + + ENTRY entry { + pf32 = f32[1] parameter(0) + all-gather-start = (f32[1], f32[2]) all-gather-start(pf32), dimensions={0}, backend_config="{\"is_sync\":false}" + pf32.2 = f32[1] call(f32[1] pf32), to_apply=foo + all-gather-done = f32[2] all-gather-done(all-gather-start) + ROOT out = (f32[1], f32[2]) tuple(f32[1] pf32.2, f32[2] all-gather-done) + } +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kHloString))); + GpuSchedulePostprocessing pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + EXPECT_FALSE(changed); + + HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); + TF_ASSERT_OK_AND_ASSIGN(CollectiveBackendConfig collective_backend_config, + start->backend_config()); + EXPECT_FALSE(collective_backend_config.no_parallel_gpu_op()); +} + +} // namespace +} // namespace gpu +} // namespace xla From 5788f66ebcd9c269c252a49dd7d4b7aaca1e5519 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Tue, 14 Nov 2023 15:26:33 -0800 Subject: [PATCH 092/391] Populate manual sharding for asynchronous instructions. PiperOrigin-RevId: 582458600 --- .../xla/xla/service/sharding_propagation.cc | 21 ++- .../xla/xla/service/sharding_propagation.h | 8 +- .../xla/service/sharding_propagation_test.cc | 151 ++++++++++++++++++ 3 files changed, 172 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index ad89c0536f5f5a..c82bb6bc87daaf 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -2164,7 +2164,8 @@ bool ShardingPropagation::InferShardingFromShardGroup( // changed and false otherwise. bool ShardingPropagation::InferShardingFromOperands( HloInstruction* instruction, const ComputationMap& computation_map, - int64_t aggressiveness, const CallGraph& call_graph) { + int64_t aggressiveness, const CallGraph& call_graph, + const absl::flat_hash_set& execution_threads) { if (!CanPropagateThroughAtAggressiveLevel(*instruction, aggressiveness)) { return false; } @@ -2175,16 +2176,27 @@ bool ShardingPropagation::InferShardingFromOperands( // Propagate manual sharding. Avoid tuple shaped HLOs that group independent // together. Reduce, ReduceWindow, and Sort can be tuples but the elements // are correlated, so we propagate manual sharding through them. + // For custom-calls with manual operand, the default propagation logic will // just assign manual to the whole custom-call. 
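+  // (Illustration, mirroring the new async sharding tests: with {manual}
+  // operands, an async call-start/call-done pair whose execution thread is
+  // excluded from the pass's execution_threads is likewise assigned {manual}.)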
+ const bool custom_call_condition = + instruction->opcode() == HloOpcode::kCustomCall && + instruction->shape().IsTuple(); + // For asynchronous instructions with manual operand, we assign manual to the + // whole instructions if the async_execution_thread is not in the + // execution_threads. + const bool async_instr_condition = + instruction->IsAsynchronous() && + !HloInstruction::IsThreadIncluded(instruction->async_execution_thread(), + execution_threads); + if ((!instruction->has_sharding() || instruction->sharding().IsTileMaximal()) && (instruction->shape().IsArray() || instruction->opcode() == HloOpcode::kReduce || instruction->opcode() == HloOpcode::kSort || instruction->opcode() == HloOpcode::kReduceWindow || - (instruction->opcode() == HloOpcode::kCustomCall && - instruction->shape().IsTuple()))) { + custom_call_condition || async_instr_condition)) { for (const HloInstruction* op : instruction->operands()) { if (!op->has_sharding() || !op->sharding().IsManual()) continue; // Do not pass through manual sharding to SPMDShardToFullShape. @@ -3165,7 +3177,8 @@ StatusOr ShardingPropagation::Run( } already_inferred_from_operands.insert(instruction); if (InferShardingFromOperands(instruction, computation_map, - aggressiveness, *call_graph)) { + aggressiveness, *call_graph, + execution_threads)) { ++inferred_from_operand_counter; any_changed = true; VLOG(2) << "Add sharding (forward-pass): " diff --git a/third_party/xla/xla/service/sharding_propagation.h b/third_party/xla/xla/service/sharding_propagation.h index 2cdf11a92ac197..82aa10c6deccc2 100644 --- a/third_party/xla/xla/service/sharding_propagation.h +++ b/third_party/xla/xla/service/sharding_propagation.h @@ -148,10 +148,10 @@ class ShardingPropagation : public HloModulePass { HloInstruction* instruction, const ComputationMap& computation_map, int64_t aggressiveness, const absl::flat_hash_set& shard_group); - bool InferShardingFromOperands(HloInstruction* instruction, - const ComputationMap& computation_map, - int64_t aggressiveness, - const CallGraph& call_graph); + bool InferShardingFromOperands( + HloInstruction* instruction, const ComputationMap& computation_map, + int64_t aggressiveness, const CallGraph& call_graph, + const absl::flat_hash_set& execution_threads); bool InferShardingFromUsers( HloInstruction* instruction, const ShardingPropagation::ComputationMap& computation_map, diff --git a/third_party/xla/xla/service/sharding_propagation_test.cc b/third_party/xla/xla/service/sharding_propagation_test.cc index a3a2143afa879f..e3cdf176be5133 100644 --- a/third_party/xla/xla/service/sharding_propagation_test.cc +++ b/third_party/xla/xla/service/sharding_propagation_test.cc @@ -10373,5 +10373,156 @@ ENTRY %entry { "{devices=[4,4,1,4]<=[4,16]T(1,0) last_tile_dim_replicate}")); } +TEST_F(ShardingPropagationTest, AsyncInstructionManualShardingArray) { + const char* const hlo_string = R"( +HloModule module + +called_computation { + p0 = s32[8] parameter(0) + p1 = s32[8] parameter(1) + ROOT add = s32[8] add(p0, p1) +}, execution_thread="thread_1" // called_computation + +ENTRY entry_computation { + p0 = s32[8] parameter(0), sharding={manual} + p1 = s32[8] parameter(1), sharding={manual} + async-start = ((s32[8], s32[8]), s32[8], u32[]) call-start(p0, p1), async_group_id=0, async_execution_thread="thread_1", to_apply=called_computation + ROOT async-done = s32[8] call-done(async-start), async_group_id=0, async_execution_thread="thread_1", to_apply=called_computation +}, execution_thread="thread_0" // entry_computation + +)"; + + { + 
// Test with execution_threads = {"thread_0"} + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true}) + .Run(module.get(), {"thread_0"})); + EXPECT_TRUE(changed); + + XLA_VLOG_LINES(1, module->ToString()); + + auto* instruction = FindInstruction(module.get(), "async-start"); + ASSERT_NE(instruction, nullptr); + EXPECT_THAT(instruction, + op::Sharding("{{manual}, {manual}, {manual}, {manual}}")); + + auto* async_done = FindInstruction(module.get(), "async-done"); + ASSERT_NE(async_done, nullptr); + EXPECT_THAT(async_done, op::Sharding("{manual}")); + } + + { + // Test with execution_threads = {"thread_0", "thread_1"} + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true}) + .Run(module.get(), {"thread_0", "thread_1"})); + EXPECT_FALSE(changed); + } + + { + // Test with execution_threads = {}. Empty execution_threads means all + // execution_threads are included. + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true}) + .Run(module.get())); + EXPECT_FALSE(changed); + } +} + +TEST_F(ShardingPropagationTest, AsyncInstructionManualShardingTuple) { + const char* const hlo_string = R"( +HloModule module + +called_computation { + p0 = s32[8] parameter(0) + p1 = s32[8] parameter(1) + add = s32[8] add(p0, p1) + mul = s32[8] multiply(p0, p1) + ROOT result = (s32[8], s32[8]) tuple(add, mul) +}, execution_thread="thread_1" // called_computation + +ENTRY entry_computation { + p0 = s32[8] parameter(0), sharding={manual} + p1 = s32[8] parameter(1), sharding={manual} + async-start = ((s32[8], s32[8]), (s32[8], s32[8]), u32[]) call-start(p0, p1), async_group_id=0, async_execution_thread="thread_1", to_apply=called_computation + ROOT async-done = (s32[8], s32[8]) call-done(async-start), async_group_id=0, async_execution_thread="thread_1", to_apply=called_computation +}, execution_thread="thread_0" // entry_computation + +)"; + + { + // Test with execution_threads = {"thread_0"} + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true}) + .Run(module.get(), {"thread_0"})); + EXPECT_TRUE(changed); + + XLA_VLOG_LINES(1, module->ToString()); + + auto* async_start = FindInstruction(module.get(), "async-start"); + ASSERT_NE(async_start, nullptr); + EXPECT_THAT( + async_start, + op::Sharding("{{manual}, {manual}, {manual}, {manual}, {manual}}")); + + auto* async_done = FindInstruction(module.get(), "async-done"); + ASSERT_NE(async_done, nullptr); + EXPECT_THAT(async_done, op::Sharding("{{manual}, {manual}}")); + } + + { + // Test with execution_threads = {"thread_0", "thread_1"} + 
TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true}) + .Run(module.get(), {"thread_0", "thread_1"})); + EXPECT_FALSE(changed); + } + + { + // Test with execution_threads = {}. Empty execution_threads means all + // execution_threads are included. + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true}) + .Run(module.get())); + EXPECT_FALSE(changed); + } +} + } // namespace } // namespace xla From b2249ef457380f959b076a4ef9073f9a62c6aaf9 Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Tue, 14 Nov 2023 15:41:34 -0800 Subject: [PATCH 093/391] Mark api/v1 methods as deprecated. PiperOrigin-RevId: 582462791 --- tensorflow/compiler/mlir/tf2xla/api/v1/BUILD | 1 + .../mlir/tf2xla/api/v1/compile_mlir_util.h | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index 477ce4cb0229a8..693d1f37766d81 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -50,6 +50,7 @@ cc_library( "//tensorflow/core/platform:logging", "//tensorflow/core/platform:status", "//tensorflow/core/tpu:tpu_defs", + "@com_google_absl//absl/base:core_headers", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h index 12e2212ba81445..3f6e446ca28fd9 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/base/attributes.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -63,8 +64,7 @@ namespace tensorflow { // result shapes. // custom_legalization_passes: passes to run before the default TF legalization // passes for backend-specific ops. -// -// TODO(hinsu): Migrate options to a separate struct. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") Status ConvertMLIRToXlaComputation( mlir::ModuleOp module_op, llvm::StringRef device_type, xla::XlaComputation* xla_computation, bool use_tuple_args, @@ -98,6 +98,7 @@ Status ConvertMLIRToXlaComputation( // true, includes legalization and MHLO lowering passes. // allow_partial_conversion: when this is true, allow operations that can't be // legalized. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") void CreateConvertMlirToXlaHloPipeline( mlir::OpPassManager& pm, llvm::StringRef device_type, bool enable_op_fallback, @@ -112,12 +113,14 @@ struct TensorOrResourceShape { }; // Refine MLIR types based on new shape information. +ABSL_DEPRECATED("Not meant to be used directly and should be a util.") Status RefineShapes(llvm::ArrayRef arg_shapes, mlir::ModuleOp module); // Lower TF to MHLO and insert HLO into the XlaBuilder. 
xla_params are HLO-level
 // inputs to module_op that have already been added to the XlaBuilder. returns
 // are the returned XlaOps.
+ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.")
 Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
 llvm::ArrayRef xla_params,
 std::vector& returns,
@@ -129,6 +132,7 @@ Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
 // Apply shape, description, and resource information to inputs and outputs
 // in the XlaCompilationResult. This should be called after
 // compilation_result->computation was set.
+ABSL_DEPRECATED("Not meant to be used directly and should be a util.")
 Status PopulateResultIOInfo(
 mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes,
 bool use_tuple_args, bool use_resource_updates_for_aliases,
@@ -142,6 +146,7 @@ Status PopulateResultIOInfo(
 //
 // If enable_op_fallback is set to false, graph is legalized only if the graph
 // analysis for the graph is successful. Otherwise, an error is returned.
+ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.")
 StatusOr CompileMlirToXlaHlo(
 mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes,
 llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback,
@@ -157,6 +162,7 @@ StatusOr CompileMlirToXlaHlo(
 //
 // If lower_to_xla_hlo is true then compiles down into XLA HLO, generates all
 // accompanying metadata and stores them in CompilationResult.
+ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.")
 StatusOr CompileSerializedMlirToXlaHlo(
 llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes,
 llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback,
@@ -172,6 +178,7 @@ StatusOr CompileSerializedMlirToXlaHlo(
 // metadata and stores them in CompilationResult. This will rewrite arguments
 // and run the TensorFlow standard pipeline prior to invoking
 // `CompileMlirToXlaHlo`.
+ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.")
 Status CompileGraphToXlaHlo(
 mlir::ModuleOp module_op, llvm::ArrayRef args,
 llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback,
@@ -183,6 +190,8 @@ Status CompileGraphToXlaHlo(
 
 // Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata
 // and stores them in CompilationResult.
+ABSL_DEPRECATED(
+    "Use v1/compile_tf_graph.h::CompileTensorflowGraphToHlo instead.")
 Status CompileGraphToXlaHlo(
 const Graph& graph, llvm::ArrayRef args,
 llvm::ArrayRef control_rets, llvm::StringRef device_type,
@@ -197,6 +206,8 @@ Status CompileGraphToXlaHlo(
 // XlaBuilder. This function adds HLO to a larger HLO computation, so
 // HLO-level inputs are supplied, and HLO-level outputs are produced.
 // xla_params is the HLO-level inputs and returns is the HLO-level outputs.
+ABSL_DEPRECATED(
+    "Use v1/compile_tf_graph.h::CompileTensorflowGraphToHlo instead.")
 Status BuildHloFromGraph(
 const Graph& graph, xla::XlaBuilder& builder,
 mlir::MLIRContext& mlir_context, llvm::ArrayRef xla_params,
From 3559d134bfc2f4cfca8e3f59c57cb40c75714077 Mon Sep 17 00:00:00 2001
From: Berkin Ilbeyi
Date: Tue, 14 Nov 2023 16:10:58 -0800
Subject: [PATCH 094/391] [XLA] Exclude async ops from elapsed times in MSA
 since we expect async ops to be efficiently scheduled.
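
A standalone sketch of the idea, outside the MSA code: the start/done halves of
asynchronous ops are assumed to be overlapped with other work, so a cost model
can report zero elapsed time for them and let only synchronous instructions
contribute to the critical path. The opcode strings, struct, and cost values
below are illustrative assumptions, not the real HloOpcode set or cost
analysis.

    #include <cstdio>
    #include <string>
    #include <vector>

    // Illustrative stand-in for an HLO instruction with a precomputed cost.
    struct FakeInstruction {
      std::string opcode;
      float elapsed_seconds;  // cost if the op were modeled as synchronous
    };

    // Returns true for the start/done halves of async ops (subset shown).
    bool IsAsyncHalf(const std::string& opcode) {
      static const std::vector<std::string> kAsyncOpcodes = {
          "all-gather-start",         "all-gather-done",
          "all-reduce-start",         "all-reduce-done",
          "collective-permute-start", "collective-permute-done",
          "copy-start",               "copy-done"};
      for (const std::string& async_opcode : kAsyncOpcodes) {
        if (opcode == async_opcode) return true;
      }
      return false;
    }

    // Async halves are excluded (counted as zero) from the elapsed estimate.
    float ModeledElapsed(const FakeInstruction& instruction) {
      return IsAsyncHalf(instruction.opcode) ? 0.0f
                                             : instruction.elapsed_seconds;
    }

    int main() {
      const std::vector<FakeInstruction> schedule = {
          {"collective-permute-start", 5.0f},
          {"negate", 1.0f},
          {"collective-permute-done", 5.0f}};
      float total = 0.0f;
      for (const FakeInstruction& instruction : schedule) {
        total += ModeledElapsed(instruction);
      }
      std::printf("modeled elapsed: %.1f\n", total);  // prints 1.0, not 11.0
      return 0;
    }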
PiperOrigin-RevId: 582470977 --- .../memory_space_assignment.cc | 36 +++++++++++++++++++ .../memory_space_assignment_test.cc | 26 ++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc index 8e950d7dafafd9..062d3189a0d201 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc @@ -882,8 +882,29 @@ float MemorySpaceAssignmentCostAnalysis::GetBytesAccessedFromAlternateMemory( return bytes_accessed_from_alternate_mem; } +namespace { +// Returns true on async instructions since we assume they are already +// efficiently scheduled such that they are not in the critical path and appear +// to take no time. +bool ExcludeInstructionFromElapsed(const HloInstruction& instruction) { + return instruction.opcode() == HloOpcode::kAllGatherStart || + instruction.opcode() == HloOpcode::kAllGatherDone || + instruction.opcode() == HloOpcode::kAllReduceStart || + instruction.opcode() == HloOpcode::kAllReduceDone || + instruction.opcode() == HloOpcode::kAsyncStart || + instruction.opcode() == HloOpcode::kAsyncDone || + instruction.opcode() == HloOpcode::kCollectivePermuteStart || + instruction.opcode() == HloOpcode::kCollectivePermuteDone || + instruction.opcode() == HloOpcode::kCopyStart || + instruction.opcode() == HloOpcode::kCopyDone; +} +} // namespace + float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute( const HloInstruction& instruction) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } return std::max( cost_analysis_.flop_count(instruction) / cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey), @@ -895,6 +916,9 @@ float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory( const HloInstruction& instruction, absl::Span> operands_in_alternate_mem, absl::Span outputs_in_alternate_mem) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } float total_bytes_accessed = cost_analysis_.bytes_accessed(instruction); float bytes_accessed_from_alternate_mem = GetBytesAccessedFromAlternateMemory( instruction, operands_in_alternate_mem, outputs_in_alternate_mem); @@ -910,6 +934,9 @@ float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory( float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory( const HloInstruction& instruction, IsInAlternateMemoryFun is_in_alternate_mem) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } float total_bytes_accessed = cost_analysis_.bytes_accessed(instruction); float bytes_accessed_from_alternate_mem = 0.0; for (int operand_num = 0; operand_num < instruction.operand_count(); @@ -948,6 +975,9 @@ float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory( float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed( const HloInstruction& instruction) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } float overhead = GetDefaultMemoryAccessOverhead(instruction); return std::max(GetInstructionElapsedDueToCompute(instruction), GetInstructionElapsedDueToMemory(instruction) + overhead); @@ -957,6 +987,9 @@ float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory( const HloInstruction& instruction, absl::Span> operands_in_alternate_mem, absl::Span outputs_in_alternate_mem) const { 
+ if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } float overhead = GetDefaultMemoryAccessOverhead( instruction, operands_in_alternate_mem, outputs_in_alternate_mem); return std::max( @@ -969,6 +1002,9 @@ float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory( float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory( const HloInstruction& instruction, IsInAlternateMemoryFun is_in_alternate_mem) const { + if (ExcludeInstructionFromElapsed(instruction)) { + return 0.0f; + } return std::max( GetInstructionElapsedDueToCompute(instruction), GetInstructionElapsedDueToMemory(instruction, is_in_alternate_mem)); diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc index be2bc573705c7f..91362ecaf034e0 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -7317,6 +7317,32 @@ ENTRY entry { AssignMemorySpaceUsingCostAnalysis(module.get(), options, cost_options); } +TEST_P(MemorySpaceAssignmentTest, AsyncOpElapsedTime) { + // Test that async ops are treated to take no time. We assume async operations + // are efficiently scheduled. So, in this example, collective-permute-start + // should take zero time, which should be insufficient time to overlap a + // prefetch for negate1's operand. + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + param0 = bf16[16]{0} parameter(0) + param1 = bf16[4]{0} parameter(1) + collective-permute-start = (bf16[16]{0}, bf16[16]{0}, u32[], u32[]) collective-permute-start(param0), source_target_pairs={{0,1},{1,2},{2,3}} + negate1 = bf16[4]{0} negate(param1) + collective-permute-done = bf16[16]{0} collective-permute-done(collective-permute-start) + ROOT negate2 = bf16[4]{0} negate(negate1) +} + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AssignMemorySpaceUsingCostAnalysis(module.get()); + EXPECT_THAT(FindInstruction(module.get(), "negate1")->operand(0), + op::Parameter(1)); +} + INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation, MemorySpaceAssignmentTest, ::testing::Values(false, true)); From a6bf104fad37bcf4c6fcfb2e08c541473f886513 Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Tue, 14 Nov 2023 17:18:33 -0800 Subject: [PATCH 095/391] Refactoring: Set default quantization method before checking PiperOrigin-RevId: 582488375 --- .../tensorflow/python/quantize_model.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index 959d20f4148f3d..715eef9cf41837 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -1205,6 +1205,19 @@ def _populate_quantization_options_default_values( 'Legacy weight-only is deprecated. Use weight-only quantization method.' ) + # Converter assumes options are specified. So set SRQ explicitly. + if ( + quantization_options.quantization_method.preset_method + == _PresetMethod.METHOD_UNSPECIFIED + ): + logging.debug( + '"preset_method" for QuantizationMethod is not specified.' + 'Static range quantization is used by default.' 
+ ) + quantization_options.quantization_method.preset_method = ( + _PresetMethod.METHOD_STATIC_RANGE_INT8 + ) + # Check default quantization option values for weight-only quantization. # TODO(b/242805842): Find good minimum_elements_for_weights number for server. # please also update default value in tflite converter: @@ -1266,19 +1279,6 @@ def _populate_quantization_options_default_values( ' quantization via TF Quantizer.' ) - # Converter assumes options are specified. So set SRQ explicitly. - if ( - quantization_options.quantization_method.preset_method - == _PresetMethod.METHOD_UNSPECIFIED - ): - logging.debug( - '"preset_method" for QuantizationMethod is not specified.' - 'Static range quantization is used by default.' - ) - quantization_options.quantization_method.preset_method = ( - _PresetMethod.METHOD_STATIC_RANGE_INT8 - ) - if quantization_options.HasField('debugger_options'): # Set `force_graph_mode_calibration` to True to avoid skipping op execution, # which are not connected to return ops, during calibration execution. From f4e2c81b5a32deb6d33a34bcf0d787275c31bc4b Mon Sep 17 00:00:00 2001 From: "Ryan M. Lefever" Date: Tue, 14 Nov 2023 17:52:29 -0800 Subject: [PATCH 096/391] Fix a bug in sliced prefetching in which it allows slice start times that are not necessarily in the same computation as the use. PiperOrigin-RevId: 582495253 --- .../xla/service/memory_space_assignment/BUILD | 3 + .../memory_space_assignment.cc | 122 ++--- .../memory_space_assignment.h | 34 +- .../memory_space_assignment_test.cc | 446 ++++++++++++++++++ 4 files changed, 545 insertions(+), 60 deletions(-) diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD index 475f9c67cdfd2a..9f800a22fb022f 100644 --- a/third_party/xla/xla/service/memory_space_assignment/BUILD +++ b/third_party/xla/xla/service/memory_space_assignment/BUILD @@ -64,6 +64,7 @@ cc_library( "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", @@ -94,7 +95,9 @@ xla_cc_test( "//xla/service:hlo_cost_analysis", "//xla/service:hlo_value", "//xla/service:instruction_hoister", + "//xla/service:time_utils", "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc index 062d3189a0d201..66ea87342f95f1 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc @@ -38,6 +38,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" #include "absl/functional/function_ref.h" #include "absl/log/check.h" #include "absl/memory/memory.h" @@ -6937,9 +6938,23 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::CheckPrefetchFit( } // Update the prefetch start time in our working solution. 
- std::vector exclusive_slice_start_times = PickSliceStartTimes( - sliced_buffer_interval->num_slices(), - context.exclusive_prefetch_start_time, context.prefetch_end_time); + std::vector exclusive_slice_start_times = + SlicedPrefetchStartTimePicker::Pick( + sliced_buffer_interval->num_slices(), + context.exclusive_prefetch_start_time, context.prefetch_end_time, + [&](int64_t exclusive_start_time, + int64_t exclusive_end_time) -> float { + return options_.prefetch_interval_picker->GetLogicalIntervalElapsed( + exclusive_start_time, exclusive_end_time); + }, + [&](int64_t lhs_time, int64_t rhs_time) -> bool { + return hlo_live_range_.flattened_instruction_sequence() + .instructions()[lhs_time] + ->parent() == + hlo_live_range_.flattened_instruction_sequence() + .instructions()[rhs_time] + ->parent(); + }); CHECK_EQ(sliced_buffer_interval->num_slices(), exclusive_slice_start_times.size()); sliced_buffer_interval->UpdateExclusiveSliceStartTimes( @@ -7149,12 +7164,14 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::CheckPrefetchFit( return Result::kFailOutOfMemory; } -std::vector AlternateMemoryBestFitHeap::PickSliceStartTimes( - int64_t num_slices, int64_t prefetch_start_time, - int64_t prefetch_end_time) const { - CHECK_LE(prefetch_start_time, prefetch_end_time); +std::vector SlicedPrefetchStartTimePicker::Pick( + int64_t num_slices, int64_t exclusive_prefetch_start_time, + int64_t prefetch_end_time, absl::AnyInvocable elapsed_fn, + absl::AnyInvocable has_same_parent_fn) { + CHECK_LE(exclusive_prefetch_start_time, prefetch_end_time); VLOG(5) << "Picking slice start times. num_slices = " << num_slices - << "; prefetch_start_time = " << prefetch_start_time + << "; exclusive_prefetch_start_time = " + << exclusive_prefetch_start_time << "; prefetch_end_time = " << prefetch_end_time; // Prefetching starts after the selected start instruction and ends @@ -7162,59 +7179,54 @@ std::vector AlternateMemoryBestFitHeap::PickSliceStartTimes( // instructions worth of time to perform all of the sliced copies. So, the // only choices for start times that give us time to copy are <= // prefetch_end_time - 2. - if (prefetch_start_time >= prefetch_end_time - 2 || num_slices == 1) { - return std::vector(num_slices, prefetch_start_time); + if (exclusive_prefetch_start_time >= prefetch_end_time - 2 || + num_slices == 1) { + return std::vector(num_slices, exclusive_prefetch_start_time); } float total_elapsed = - options_.prefetch_interval_picker->GetLogicalIntervalElapsed( - prefetch_start_time, prefetch_end_time); + elapsed_fn(exclusive_prefetch_start_time, prefetch_end_time); if (total_elapsed <= 0.0) { - return std::vector(num_slices, prefetch_start_time); - } - - CHECK_LE(prefetch_start_time, prefetch_end_time - 2); - std::vector reverse_start_times; - reverse_start_times.reserve(num_slices); - for (int64_t candidate_start_time = prefetch_end_time - 2; - reverse_start_times.size() < num_slices && - candidate_start_time >= prefetch_start_time; - --candidate_start_time) { - if (candidate_start_time == prefetch_start_time) { - while (reverse_start_times.size() < num_slices) { - // This is the last good start time, so use it for all remaining slices. 
- reverse_start_times.push_back(candidate_start_time); - } - break; + return std::vector(num_slices, exclusive_prefetch_start_time); + } + + std::vector start_times; + start_times.reserve(num_slices); + start_times.push_back(exclusive_prefetch_start_time); + int64_t last_valid_candidate = exclusive_prefetch_start_time; + int64_t candidate = exclusive_prefetch_start_time; + while (candidate < prefetch_end_time - 1 && start_times.size() < num_slices) { + float target_elapsed = total_elapsed * + static_cast(num_slices - start_times.size()) / + static_cast(num_slices); + float elapsed = elapsed_fn(candidate, prefetch_end_time); + if (elapsed < target_elapsed) { + // We've gone past our target, so use the last valid candidate. + start_times.push_back(last_valid_candidate); + continue; } - float used = options_.prefetch_interval_picker->GetLogicalIntervalElapsed( - candidate_start_time, prefetch_end_time); - CHECK_GE(used, 0.0) << used << " real time elapses in logical interval (" - << candidate_start_time << ", " << prefetch_end_time - << "). Expected something >= 0.0."; - CHECK_LE(used, total_elapsed); - auto compute_target_fraction = - [num_slices](const std::vector& reverse_start_times) -> float { - return (static_cast(reverse_start_times.size()) + 1.0f) / - static_cast(num_slices); - }; - while (used >= - compute_target_fraction(reverse_start_times) * total_elapsed) { - CHECK_LE(reverse_start_times.size(), num_slices) - << "Num slices = " << num_slices - << "; Prefetch start = " << prefetch_start_time - << "; Slice candidate time = " << candidate_start_time - << "; Prefetch end = " << prefetch_end_time - << "; Total elapsed = " << total_elapsed << "; Used = " << used - << "; Target fraction = " - << compute_target_fraction(reverse_start_times); - reverse_start_times.push_back(candidate_start_time); + bool updating_candidate_impacts_elapsed = + last_valid_candidate != candidate && + elapsed_fn(last_valid_candidate, + ExclusiveToInclusiveStartTime(candidate)) > 0.0; + // has_same_parent_fn will look up the computation parent of the + // instructions at prefetch_start_time and prefetch_end_time. If + // prefetch_start_time is -1, no such instruction will exist. However, if we + // want to insert an instruction after the -1 schedule position, we can + // use the parent of the instruction at index 0 instead. Thus, we use + // std::max below. 
+ if (has_same_parent_fn(std::max(0, exclusive_prefetch_start_time), + std::max(0, candidate)) && + updating_candidate_impacts_elapsed) { + last_valid_candidate = candidate; } + ++candidate; + } + while (start_times.size() < num_slices) { + start_times.push_back(last_valid_candidate); } - CHECK_EQ(reverse_start_times.size(), num_slices); - absl::c_reverse(reverse_start_times); - return reverse_start_times; + return start_times; } std::string @@ -7228,7 +7240,7 @@ AlternateMemoryBestFitHeap::AlternateMemoryAllocationAttemptToString( for (int i = 0; i < sliced_buffer_interval->num_slices(); ++i) { slice_times.push_back(absl::StrCat( - "(", sliced_buffer_interval->IntervalForMakeFreeChunks(i).start, ", ", + "[", sliced_buffer_interval->IntervalForMakeFreeChunks(i).start, ", ", sliced_buffer_interval->full_buffer_interval().end, ")")); if (context.slice_proposal_collection) { estimated_slice_prefetch_end_times.push_back( @@ -7918,7 +7930,7 @@ Status MemorySpaceAssignment::SlicedCopyAllocation::Process() { TF_RETURN_IF_ERROR(slice_detail.CreateAsyncSlice( shape, *producing_instruction, *computation, update_layout_fn_)); VLOG(4) << "Created " << slice_detail.copy_start->name() - << " for copy allocation: " << ToString(); + << " for sliced copy allocation: " << ToString(); slice_dones.push_back(slice_detail.copy_done); } diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h index 640b17b0fe59ab..cd26fdb21ee458 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h @@ -27,6 +27,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -37,6 +38,7 @@ limitations under the License. #endif #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" #include "absl/functional/function_ref.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -557,6 +559,33 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { std::optional shape_override_; }; +// A class for turning a copy start time and end time into slice start times. +class SlicedPrefetchStartTimePicker { + public: + // Returns the amount of time elapsed in the instruction schedule between + // (exclusive_start_time, exclusive_end_time). + using ElapsedTimeFn = std::add_pointer::type; + + // Returns true if the instructions at lhs_time and rhs_time are in the same + // computation. + using SameComputationParentFn = + std::add_pointer::type; + + // Picks slice start times, given the num_slices, prefetch_start_time, and + // prefetch_end_time. The returned times are exclusive. + // + // REQUIRES: + // - The instructions following each start time are guaranateed to be in the + // same computation. + // - The returned times sorted. + // - The first returned time is equal to prefetch_start_time. + static std::vector Pick( + int64_t num_slices, int64_t exclusive_prefetch_start_time, + int64_t prefetch_end_time, absl::AnyInvocable elapsed_fn, + absl::AnyInvocable has_same_parent_fn); +}; + // MemorySpaceAssignment assigns memory spaces (default or alternate) to each // instruction in the module. It will greedily try placing as as many values in // the alternate memory space as possible. 
It uses the heap simulator to @@ -2486,11 +2515,6 @@ class AlternateMemoryBestFitHeap // Check if for the specified type of solution, using the parameters in // context. If we find a solution, it will be stored in context. Result CheckPrefetchFit(bool for_sliced_solution, PrefetchContext& context); - // Given a specified number of slices, start times, and end times, pick times - // to start each slice. - std::vector PickSliceStartTimes(int64_t num_slices, - int64_t prefetch_start_time, - int64_t prefetch_end_time) const; // Creates a debugging string describing the timing of the prefetch solution // we are currently attempting (as dictated by for_sliced_solution and // context). diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc index 91362ecaf034e0..c54a1d03aebf7c 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -55,10 +56,12 @@ limitations under the License. #include "xla/service/instruction_hoister.h" #include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" #include "xla/service/memory_space_assignment/repacking.h" +#include "xla/service/time_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" @@ -10647,6 +10650,195 @@ ENTRY entry { RunMsa(module.get(), /*alternate_memory_size=*/512)); } +class SlicedPrefetchStartTimePickerTest : public ::testing::Test { + protected: + struct FakeInstructionData { + float elapsed_time = 0.0; + std::string computation; + }; + + std::vector Pick( + const std::vector& schedule_data, int64_t num_slices, + int64_t prefetch_start_time, int64_t prefetch_end_time) { + return memory_space_assignment::SlicedPrefetchStartTimePicker::Pick( + num_slices, prefetch_start_time, prefetch_end_time, + [&schedule_data](int64_t exclusive_start_time, + int64_t exclusive_end_time) { + auto start_it = schedule_data.begin() + + ExclusiveToInclusiveStartTime(exclusive_start_time); + auto end_it = (exclusive_end_time < schedule_data.size() + ? schedule_data.begin() + exclusive_end_time + : schedule_data.end()); + return std::accumulate( + start_it, end_it, 0.0, + [](float total, const FakeInstructionData& data) { + return total + data.elapsed_time; + }); + }, + [&schedule_data](int64_t lhs_time, int64_t rhs_time) { + CHECK_GE(lhs_time, 0); + CHECK_GE(rhs_time, 0); + CHECK_LT(lhs_time, schedule_data.size()); + CHECK_LT(rhs_time, schedule_data.size()); + return schedule_data[lhs_time].computation == + schedule_data[rhs_time].computation; + }); + } +}; + +TEST_F(SlicedPrefetchStartTimePickerTest, Base1) { + // The 2nd slice naturally should start after 1.5 time units have passed, + // forcing us to start before t=1. + EXPECT_THAT(Pick({ + /*t=0*/ {1.0, "a"}, + /*t=1*/ {1.0, "a"}, + /*t=2*/ {1.0, "a"}, + }, + /*num_slices=*/2, /*prefetch_start_time=*/-1, + /*prefetch_end_time=*/3), + ::testing::ElementsAre(-1, 0)); +} + +TEST_F(SlicedPrefetchStartTimePickerTest, Base2) { + // The 2nd slice naturally should start after 6.0 time units have passed, + // forcing us to start before t=0. 
+ EXPECT_THAT(Pick({ + /*t=0*/ {10.0, "a"}, + /*t=1*/ {1.0, "a"}, + /*t=2*/ {1.0, "a"}, + }, + /*num_slices=*/2, /*prefetch_start_time=*/-1, + /*prefetch_end_time=*/3), + ::testing::ElementsAre(-1, -1)); +} + +TEST_F(SlicedPrefetchStartTimePickerTest, Base3) { + // The 2nd slice naturally should start after 1.0 time unit has passed. + EXPECT_THAT(Pick({ + /*t=0*/ {1.0, "a"}, + /*t=1*/ {1.0, "a"}, + }, + /*num_slices=*/2, /*prefetch_start_time=*/-1, + /*prefetch_end_time=*/2), + ::testing::ElementsAre(-1, 0)); +} + +TEST_F(SlicedPrefetchStartTimePickerTest, Zeros1) { + // The 2nd slice naturally should start after 1.0 time unit has passed. + // Make sure we don't add extra 0.0 cost instructions to the start time. + EXPECT_THAT(Pick({ + /*t=0*/ {1.0, "a"}, + /*t=1*/ {0.0, "a"}, + /*t=2*/ {0.0, "a"}, + /*t=3*/ {0.0, "a"}, + /*t=4*/ {1.0, "a"}, + }, + /*num_slices=*/2, /*prefetch_start_time=*/-1, + /*prefetch_end_time=*/5), + ::testing::ElementsAre(-1, 0)); +} + +TEST_F(SlicedPrefetchStartTimePickerTest, Zeros2) { + // The 2nd slice naturally should start after 2.0 time units have passed. + // Make sure we don't add extra 0.0 cost instructions to the start time. + EXPECT_THAT(Pick({ + /*t=0*/ {1.0, "a"}, + /*t=1*/ {0.0, "a"}, + /*t=2*/ {1.0, "a"}, + /*t=3*/ {0.0, "a"}, + /*t=4*/ {1.0, "a"}, + /*t=5*/ {0.0, "a"}, + /*t=6*/ {1.0, "a"}, + }, + /*num_slices=*/2, /*prefetch_start_time=*/-1, + /*prefetch_end_time=*/7), + ::testing::ElementsAre(-1, 2)); +} + +TEST_F(SlicedPrefetchStartTimePickerTest, Zeros3) { + // The first slice always comes at prefetch_start_time. The 2nd slice + // naturally should start after 1.5 time units have passed, causing us to + // start after t=2. Make sure we don't add extra 0.0 cost instructions to the + // start time. + EXPECT_THAT(Pick({ + /*t=0*/ {1.0, "a"}, + /*t=1*/ {0.0, "a"}, + /*t=2*/ {1.0, "a"}, + /*t=3*/ {0.0, "a"}, + /*t=4*/ {1.0, "a"}, + /*t=5*/ {0.0, "a"}, + /*t=6*/ {1.0, "a"}, + }, + /*num_slices=*/2, /*prefetch_start_time=*/1, + /*prefetch_end_time=*/7), + ::testing::ElementsAre(1, 2)); +} + +TEST_F(SlicedPrefetchStartTimePickerTest, MidSchedule) { + EXPECT_THAT(Pick({ + /*t=0*/ {1.0, "a"}, + /*t=1*/ {1.0, "a"}, + /*t=3*/ {1.0, "a"}, + /*t=4*/ {1.0, "a"}, + /*t=5*/ {1.0, "a"}, + /*t=6*/ {1.0, "a"}, + /*t=7*/ {1.0, "a"}, + /*t=8*/ {1.0, "a"}, + /*t=9*/ {1.0, "a"}, + /*t=10*/ {1.0, "a"}, + /*t=11*/ {1.0, "a"}, + /*t=12*/ {1.0, "a"}, + }, + /*num_slices=*/2, /*prefetch_start_time=*/5, + /*prefetch_end_time=*/10), + ::testing::ElementsAre(5, 7)); +} + +TEST_F(SlicedPrefetchStartTimePickerTest, ManySlices) { + EXPECT_THAT(Pick({ + /*t=0*/ {1.0, "a"}, + /*t=1*/ {1.0, "a"}, + /*t=2*/ {1.0, "a"}, + /*t=3*/ {1.0, "a"}, + /*t=4*/ {1.0, "a"}, + /*t=5*/ {1.0, "a"}, + /*t=6*/ {1.0, "a"}, + /*t=7*/ {1.0, "a"}, + /*t=8*/ {1.0, "a"}, + /*t=9*/ {1.0, "a"}, + /*t=10*/ {1.0, "a"}, + /*t=11*/ {1.0, "a"}, + /*t=12*/ {1.0, "a"}, + /*t=13*/ {1.0, "a"}, + /*t=14*/ {1.0, "a"}, + /*t=15*/ {1.0, "a"}, + /*t=16*/ {1.0, "a"}, + /*t=17*/ {1.0, "a"}, + /*t=18*/ {1.0, "a"}, + /*t=19*/ {1.0, "a"}, + }, + /*num_slices=*/5, /*prefetch_start_time=*/-1, + /*prefetch_end_time=*/20), + ::testing::ElementsAre(-1, 3, 7, 11, 15)); +} + +TEST_F(SlicedPrefetchStartTimePickerTest, DifferentParents) { + // The 2nd slice naturally should start after t=2, but we are forced to push + // it after t=1, since the instruction at t=3 has parent "b", while the first + // instruction has parent "a." 
+ EXPECT_THAT(Pick({ + /*t=0*/ {1.0, "a"}, + /*t=1*/ {1.0, "a"}, + /*t=2*/ {1.0, "b"}, + /*t=3*/ {1.0, "b"}, + /*t=4*/ {1.0, "b"}, + /*t=5*/ {1.0, "a"}, + }, + /*num_slices=*/2, /*prefetch_start_time=*/-1, + /*prefetch_end_time=*/6), + ::testing::ElementsAre(-1, 1)); +} + class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { protected: // Used by CheckSchedule() to classify instructions in the schedule. @@ -11059,6 +11251,38 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { return nullptr; } + static StatusOr> GetSliceStartIndicies( + const std::vector& schedule, + const HloInstruction* concat_bitcast) { + std::vector indicies; + + if (!IsConcatBitcast(concat_bitcast)) { + return InvalidArgumentStrCat(concat_bitcast->name(), + " is not a concat-bitcast."); + } + for (int i = 0; i < concat_bitcast->operand_count(); ++i) { + const HloInstruction* async_slice_done = concat_bitcast->operand(i); + if (!IsAsyncSliceDone(async_slice_done)) { + return InvalidArgumentStrCat("Operand ", i, " of ", + concat_bitcast->name(), + " is not an async-slice-done."); + } + const HloInstruction* async_slice_start = async_slice_done->operand(0); + if (!IsAsyncSliceStart(async_slice_start)) { + return InvalidArgumentStrCat("Operand 0, of operand ", i, " of ", + concat_bitcast->name(), + " is not an async-slice-start."); + } + TF_ASSIGN_OR_RETURN( + int schedule_index, + FindScheduleIndexOfInstruction(schedule, async_slice_start->name(), + InstructionClass::kRelatedSliceStart)); + indicies.push_back(schedule_index); + } + + return indicies; + } + // REQUIRES: // - Concat-bitcast and all slices were found in the schedule used to // construct schedule_to_class. @@ -12124,5 +12348,227 @@ ENTRY main { EXPECT_EQ(p2_slice_offsets[1], 2048); } +struct ModuleAndAssignments { + std::unique_ptr module; + std::unique_ptr assignments; +}; + +// In this test, we ensure that sliced prefetching does not attempt to start a +// slice during a different computation than the one where the slice finishes. +// We do this by forcing a sliced prefetch to start just before back-to-back +// while loops and to immediately finish after them. We use while loops with +// different expected elapse times, so that the ideal place to start the second +// slice is during one of the while loops. +TEST_F(SlicedPrefetchTest, BackToBackWhileLoops) { + // Define constants for building our test HLO. 
+ const std::string while_cond = R"zz( +WhileCond$ID { + cond_param = (f32[8,8], f32[8,8], f32[], f32[]) parameter(0) + i = f32[] get-tuple-element(cond_param), index=2 + limit = f32[] get-tuple-element(cond_param), index=3 + + ROOT cond_result = pred[] compare(i, limit), direction=LT +})zz"; + + const std::string while_body = R"zz( +WhileBody$ID { + body_param = (f32[8,8], f32[8,8], f32[], f32[]) parameter(0) + v0 = f32[8,8] get-tuple-element(body_param), index=0 + v1 = f32[8,8] get-tuple-element(body_param), index=1 + i = f32[] get-tuple-element(body_param), index=2 + limit = f32[] get-tuple-element(body_param), index=3 + one = f32[] constant(1) + + new_i = f32[] add(i, one) + $COMPUTATION + + ROOT while_result = (f32[8,8], f32[8,8], f32[], f32[]) tuple(v0, new_v1, new_i, limit) +})zz"; + + const std::string while_computation_cheap = R"zz( + new_v1 = f32[8,8] add(v0, v1))zz"; + + std::string while_computation_expensive = R"zz( + new_v1_0 = f32[8,8] add(v0, v1) + new_v1_1 = f32[8,8] tanh(new_v1_0) + new_v1_2 = f32[8,8] tanh(new_v1_1) + new_v1_3 = f32[8,8] tanh(new_v1_2) + new_v1 = f32[8,8] tanh(new_v1_3))zz"; + + std::string module_text = R"zz( +HloModule Slice, is_scheduled=true + +$WHILEBODY1 +$WHILECOND1 +$WHILEBODY2 +$WHILECOND2 + +ENTRY main { + loop1_input1 = f32[8,8] parameter(0) + loop1_input2 = f32[8,8] parameter(1) + loop1_iterations = f32[] parameter(2) + loop1_begin = f32[] constant(0) + loop1_tuple = (f32[8,8], f32[8,8], f32[], f32[]) tuple(loop1_input1, loop1_input2, loop1_iterations, loop1_begin) + loop2_input1 = f32[8,8] parameter(3) + loop2_input2 = f32[8,8] parameter(4) + loop2_iterations = f32[] parameter(5) + loop2_begin = f32[] constant(0) + loop2_tuple = (f32[8,8], f32[8,8], f32[], f32[]) tuple(loop2_input1, loop2_input2, loop2_iterations, loop2_begin) + + prefetch = f32[8,8] parameter(6) + loop1_output = (f32[8,8], f32[8,8], f32[], f32[]) while(loop1_tuple), condition=WhileCond1, body=WhileBody1 + loop2_output = (f32[8,8], f32[8,8], f32[], f32[]) while(loop2_tuple), condition=WhileCond2, body=WhileBody2 + prefetch_use = f32[8,8] tanh(prefetch) + + loop1_result = f32[8,8] get-tuple-element(loop1_output), index=1 + loop2_result = f32[8,8] get-tuple-element(loop2_output), index=1 + + tmp1 = f32[8,8] add(loop1_result, loop2_result) + ROOT r = f32[8,8] add(tmp1, prefetch_use) +})zz"; + + // A lambda for generating HLO with 2 while loops called back to back. The + // first while loop will execute while_computation1 and the second while loop + // will execute while_computation2. + auto gen_hlo = [&](std::string_view while_computation1, + std::string_view while_computation2) { + return absl::StrReplaceAll( + module_text, + { + {"$WHILEBODY1", + absl::StrReplaceAll( + while_body, + {{"$ID", "1"}, {"$COMPUTATION", while_computation1}})}, + {"$WHILECOND1", absl::StrReplaceAll(while_cond, {{"$ID", "1"}})}, + {"$WHILEBODY2", + absl::StrReplaceAll( + while_body, + {{"$ID", "2"}, {"$COMPUTATION", while_computation2}})}, + {"$WHILECOND2", absl::StrReplaceAll(while_cond, {{"$ID", "2"}})}, + }); + }; + + // Configure MSA. + SetupProposeSlicesToExpect2SlicesOfF32x8x8(); + // Force MSA to prefer prefetching 'prefetch'. + MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare = + [](const MemorySpaceAssignment::BufferInterval& lhs, + const MemorySpaceAssignment::BufferInterval& rhs) { + auto lookup = [](const MemorySpaceAssignment::BufferInterval& x) { + // An arbitrary value that is greater than that used for 'prefetch'. 
+ int priority = 100; + if (x.buffer->instruction()->name() == "prefetch") { + priority = 1; + } + return std::make_tuple(priority, x.buffer->instruction()->name()); + }; + + return lookup(lhs) < lookup(rhs); + }; + // We set the minimum prefetch interval to a large enough value (32) to force + // us to prefetch around both while loops, and not just 1. + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(32, 100); + options_.max_size_in_bytes = 4 * 64; + + // Define a lambda for running MSA on the specified HLO, with the + // configuration above. + auto run_msa = + [&](std::string_view hlo_text) -> StatusOr { + ModuleAndAssignments module_and_assignments; + TF_ASSIGN_OR_RETURN(module_and_assignments.module, + ParseAndReturnVerifiedModule(hlo_text)); + VLOG(1) << "Original module:\n" + << module_and_assignments.module->ToString( + HloPrintOptions::ShortParsable()); + module_and_assignments.assignments = + AssignMemorySpace(module_and_assignments.module.get(), options_, + buffer_interval_compare, &prefetch_interval_picker); + VLOG(1) << "Post-MSA module:\n" + << module_and_assignments.module->ToString( + HloPrintOptions::ShortParsable()); + return module_and_assignments; + }; + + // In this case, less time elapses during the first while loop than the + // second. Make sure we start the second slice between the two while loops, + // rather than during the second while loop. + TF_ASSERT_OK_AND_ASSIGN( + ModuleAndAssignments module_and_assignments1, + run_msa(gen_hlo(while_computation_cheap, while_computation_expensive))); + auto root1 = + module_and_assignments1.module->entry_computation()->root_instruction(); + EXPECT_THAT(root1, op::Add(_, op::Tanh(IsAsyncSlicedCopy( + kAlternateMemorySpace, kDefaultMemorySpace, + {{{0, 4}, {0, 8}}, {{4, 8}, {0, 8}}}, + op::Parameter(6))))); + TF_EXPECT_OK(CheckSchedule( + *module_and_assignments1.module, root1->operand(1)->operand(0), + /*slices_start_after_instruction_name=*/"prefetch", + /*slices_done_before_instruction_name=*/"prefetch_use", + /*expect_slices_started_at_different_times=*/true)); + auto entry_schedule1 = + module_and_assignments1.module->schedule() + .sequence(module_and_assignments1.module->entry_computation()) + .instructions(); + TF_ASSERT_OK_AND_ASSIGN( + std::vector start_indicies, + GetSliceStartIndicies(entry_schedule1, root1->operand(1)->operand(0))); + ASSERT_EQ(start_indicies.size(), 2); + TF_ASSERT_OK_AND_ASSIGN( + int first_while, + FindScheduleIndexOfInstruction( + entry_schedule1, "loop1_output", + SlicedPrefetchTest::InstructionClass::kUnrelatedNonCopy)); + TF_ASSERT_OK_AND_ASSIGN( + int second_while, + FindScheduleIndexOfInstruction( + entry_schedule1, "loop2_output", + SlicedPrefetchTest::InstructionClass::kUnrelatedNonCopy)); + EXPECT_TRUE( + absl::c_is_sorted>( + {start_indicies[0], first_while, start_indicies[1], second_while}) || + absl::c_is_sorted>( + {start_indicies[1], first_while, start_indicies[0], second_while})); + + // In this case, more time elapses during the first while loop than the + // second. This should push us to use a normal prefetch, rather than slicing, + // since the ideal time to start the second slice will get pushed before + // both while loops. 
+ TF_ASSERT_OK_AND_ASSIGN( + ModuleAndAssignments module_and_assignments2, + run_msa(gen_hlo(while_computation_expensive, while_computation_cheap))); + auto root2 = + module_and_assignments2.module->entry_computation()->root_instruction(); + EXPECT_THAT(root2, op::Add(_, op::Tanh(op::AsyncCopy(kAlternateMemorySpace, + kDefaultMemorySpace, + op::Parameter(6))))); + auto entry_schedule2 = + module_and_assignments2.module->schedule() + .sequence(module_and_assignments2.module->entry_computation()) + .instructions(); + TF_ASSERT_OK_AND_ASSIGN( + int copy_done, + FindScheduleIndexOfInstruction( + entry_schedule2, root2->operand(1)->operand(0)->name(), + SlicedPrefetchTest::InstructionClass::kUnrelatedNonCopy)); + TF_ASSERT_OK_AND_ASSIGN( + int copy_start, + FindScheduleIndexOfInstruction( + entry_schedule2, root2->operand(1)->operand(0)->operand(0)->name(), + SlicedPrefetchTest::InstructionClass::kUnrelatedNonCopy)); + TF_ASSERT_OK_AND_ASSIGN( + first_while, + FindScheduleIndexOfInstruction( + entry_schedule2, "loop1_output", + SlicedPrefetchTest::InstructionClass::kUnrelatedNonCopy)); + TF_ASSERT_OK_AND_ASSIGN( + second_while, + FindScheduleIndexOfInstruction( + entry_schedule2, "loop2_output", + SlicedPrefetchTest::InstructionClass::kUnrelatedNonCopy)); + EXPECT_TRUE(absl::c_is_sorted>( + {copy_start, first_while, second_while, copy_done})); +} + } // namespace } // namespace xla From 7eb2d53b26e41f47f83ea0cf0cb358fcd493c9b6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 14 Nov 2023 18:19:27 -0800 Subject: [PATCH 097/391] [TSL] Change coordination service const std::string& arguments for keys and values to std::string_view. I plan to add a caller that has a std::vector, and this saves a copy in that case. PiperOrigin-RevId: 582500430 --- ...coordination_service_barrier_proxy_test.cc | 25 ++--- .../coordination_service_agent.cc | 95 +++++++++---------- .../coordination/coordination_service_agent.h | 29 +++--- .../xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc | 6 +- .../xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc | 6 +- .../xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 13 +-- .../xla/xla/pjrt/distributed/client.cc | 18 ++-- third_party/xla/xla/pjrt/distributed/client.h | 10 +- .../pjrt/distributed/client_server_test.cc | 16 ++-- .../pjrt/distributed/topology_util_test.cc | 7 +- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 6 +- .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 6 +- third_party/xla/xla/pjrt/pjrt_client.h | 7 +- third_party/xla/xla/python/xla.cc | 12 +-- .../functional_hlo_runner.cc | 6 +- 15 files changed, 133 insertions(+), 129 deletions(-) diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc index c4a7af7c6a26fd..0261268a589e2c 100644 --- a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc +++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -78,36 +79,36 @@ class MockCoordinationServiceAgent : public CoordinationServiceAgent { MOCK_METHOD(Status, ReportError, (const Status& error), (override)); MOCK_METHOD(Status, Shutdown, (), (override)); MOCK_METHOD(Status, Reset, (), (override)); - MOCK_METHOD(StatusOr, GetKeyValue, (const std::string& key), + MOCK_METHOD(StatusOr, GetKeyValue, (std::string_view key), (override)); MOCK_METHOD(StatusOr, GetKeyValue, (const char* key, int64_t key_size), (override)); MOCK_METHOD(StatusOr, GetKeyValue, - (const std::string& key, absl::Duration timeout), (override)); + (std::string_view key, absl::Duration timeout), (override)); MOCK_METHOD(std::shared_ptr, GetKeyValueAsync, - (const std::string& key, StatusOrValueCallback done), (override)); - MOCK_METHOD(StatusOr, TryGetKeyValue, (const std::string& key), + (std::string_view key, StatusOrValueCallback done), (override)); + MOCK_METHOD(StatusOr, TryGetKeyValue, (std::string_view key), (override)); MOCK_METHOD(StatusOr>, GetKeyValueDir, - (const std::string& key), (override)); + (std::string_view key), (override)); MOCK_METHOD(void, GetKeyValueDirAsync, - (const std::string& key, StatusOrValueDirCallback done), + (std::string_view key, StatusOrValueDirCallback done), (override)); MOCK_METHOD(Status, InsertKeyValue, - (const std::string& key, const std::string& value), (override)); + (std::string_view key, std::string_view value), (override)); MOCK_METHOD(Status, InsertKeyValue, (const char* key, int64_t key_size, const char* value, int64_t value_size), (override)); - MOCK_METHOD(Status, DeleteKeyValue, (const std::string& key), (override)); + MOCK_METHOD(Status, DeleteKeyValue, (std::string_view key), (override)); MOCK_METHOD(Status, DeleteKeyValue, (const char* key, int64_t key_size), (override)); MOCK_METHOD(Status, UpdateKeyValue, - (const std::string& key, const std::string& value), (override)); + (std::string_view key, std::string_view value), (override)); MOCK_METHOD(Status, StartWatchKey, - (const std::string& key, ChangedKeyValuesCallback on_change), + (std::string_view key, ChangedKeyValuesCallback on_change), (override)); - MOCK_METHOD(Status, StopWatchKey, (const std::string& key), (override)); + MOCK_METHOD(Status, StopWatchKey, (std::string_view key), (override)); MOCK_METHOD(void, WaitAtBarrierAsync, (const std::string& barrier_id, absl::Duration timeout, const std::vector& tasks, StatusCallback done), @@ -117,7 +118,7 @@ class MockCoordinationServiceAgent : public CoordinationServiceAgent { MOCK_METHOD(StatusOr, GetEnv, (), (override)); MOCK_METHOD(void, SetError, (const Status& error), (override)); MOCK_METHOD(Status, ActivateWatch, - (const std::string& key, + (std::string_view key, (const std::map&)), (override)); }; diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc index a45213d1817624..79065f7a9118ab 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -92,29 +93,27 @@ class CoordinationServiceAgentImpl : public CoordinationServiceAgent { Status Shutdown() override; Status Reset() override; - StatusOr GetKeyValue(const std::string& key) override; + StatusOr GetKeyValue(std::string_view key) override; StatusOr GetKeyValue(const char* key, int64_t key_size) override; - StatusOr GetKeyValue(const std::string& key, + StatusOr GetKeyValue(std::string_view key, absl::Duration timeout) override; std::shared_ptr GetKeyValueAsync( - const std::string& key, StatusOrValueCallback done) override; - StatusOr TryGetKeyValue(const std::string& key) override; + std::string_view key, StatusOrValueCallback done) override; + StatusOr TryGetKeyValue(std::string_view key) override; StatusOr> GetKeyValueDir( - const std::string& key) override; - void GetKeyValueDirAsync(const std::string& key, + std::string_view key) override; + void GetKeyValueDirAsync(std::string_view key, StatusOrValueDirCallback done) override; - Status InsertKeyValue(const std::string& key, - const std::string& value) override; + Status InsertKeyValue(std::string_view key, std::string_view value) override; Status InsertKeyValue(const char* key, int64_t key_size, const char* value, int64_t value_size) override; - Status DeleteKeyValue(const std::string& key) override; + Status DeleteKeyValue(std::string_view key) override; Status DeleteKeyValue(const char* key, int64_t key_size) override; - Status UpdateKeyValue(const std::string& key, - const std::string& value) override; + Status UpdateKeyValue(std::string_view key, std::string_view value) override; - Status StartWatchKey(const std::string& key, + Status StartWatchKey(std::string_view key, ChangedKeyValuesCallback on_change) override; - Status StopWatchKey(const std::string& key) override; + Status StopWatchKey(std::string_view key) override; Status WaitAtBarrier(const std::string& barrier_id, absl::Duration timeout, const std::vector& tasks) override; void WaitAtBarrierAsync(const std::string& barrier_id, absl::Duration timeout, @@ -128,7 +127,7 @@ class CoordinationServiceAgentImpl : public CoordinationServiceAgent { protected: void SetError(const Status& error) override; - Status ActivateWatch(const std::string& key, + Status ActivateWatch(std::string_view key, const std::map&) override; // Returns an error if agent is not running. If `allow_disconnected` is true, // returns OK even if the agent is in DISCONNECTED state. 
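// A minimal standalone sketch (not the TSL agent code) of the motivation for
// the std::string_view signatures above: a string_view parameter binds to
// string literals, substrings, and other callers' buffers without building a
// temporary std::string at each call site. The function names and key string
// below are illustrative assumptions only.
#include <cstdio>
#include <string>
#include <string_view>

void PutKeyByRef(const std::string& key) {  // forces a std::string temporary
  std::printf("ref:  %s\n", key.c_str());
}

void PutKeyByView(std::string_view key) {  // accepts any contiguous chars
  std::printf("view: %.*s\n", static_cast<int>(key.size()), key.data());
}

int main() {
  const std::string owned = "workers/task_0/incarnation";
  const std::string_view prefix = std::string_view(owned).substr(0, 13);
  PutKeyByRef(std::string(prefix));  // extra allocation and copy
  PutKeyByView(prefix);              // no copy; the view is forwarded as-is
  return 0;
}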
@@ -567,17 +566,17 @@ Status CoordinationServiceAgentImpl::Reset() { } StatusOr CoordinationServiceAgentImpl::GetKeyValue( - const std::string& key) { + std::string_view key) { return GetKeyValue(key, /*timeout=*/absl::InfiniteDuration()); } StatusOr CoordinationServiceAgentImpl::GetKeyValue( const char* key, int64_t key_size) { - return GetKeyValue(std::string(key, key_size)); + return GetKeyValue(std::string_view(key, key_size)); } StatusOr CoordinationServiceAgentImpl::GetKeyValue( - const std::string& key, absl::Duration timeout) { + std::string_view key, absl::Duration timeout) { auto n = std::make_shared(); auto result = std::make_shared>(); GetKeyValueAsync(key, @@ -597,9 +596,9 @@ StatusOr CoordinationServiceAgentImpl::GetKeyValue( } std::shared_ptr CoordinationServiceAgentImpl::GetKeyValueAsync( - const std::string& key, StatusOrValueCallback done) { + std::string_view key, StatusOrValueCallback done) { auto request = std::make_shared(); - request->set_key(key); + request->set_key(key.data(), key.size()); VLOG(3) << "GetKeyValueRequest: " << request->DebugString(); auto response = std::make_shared(); auto call_opts = std::make_shared(); @@ -633,33 +632,31 @@ std::shared_ptr CoordinationServiceAgentImpl::GetKeyValueAsync( } StatusOr CoordinationServiceAgentImpl::TryGetKeyValue( - const std::string& key) { + std::string_view key) { absl::Notification n; StatusOr result; TryGetKeyValueRequest request; - request.set_key(key); + request.set_key(key.data(), key.size()); VLOG(3) << "TryGetKeyValueRequest: " << request.DebugString(); TryGetKeyValueResponse response; - leader_client_->TryGetKeyValueAsync(&request, &response, - [&](const Status& s) { - if (s.ok()) { - result = response.kv().value(); - VLOG(3) << "TryGetKeyValueResponse: " - << result.value(); - } else { - result = s; - VLOG(3) << "TryGetKeyValueResponse: " - << s; - } - n.Notify(); - }); + leader_client_->TryGetKeyValueAsync( + &request, &response, [&](const Status& s) { + if (s.ok()) { + result = response.kv().value(); + VLOG(3) << "TryGetKeyValueResponse: " << result.value(); + } else { + result = s; + VLOG(3) << "TryGetKeyValueResponse: " << s; + } + n.Notify(); + }); n.WaitForNotification(); return result; } StatusOr> -CoordinationServiceAgentImpl::GetKeyValueDir(const std::string& key) { +CoordinationServiceAgentImpl::GetKeyValueDir(std::string_view key) { absl::Notification n; StatusOr> result; GetKeyValueDirAsync( @@ -673,9 +670,9 @@ CoordinationServiceAgentImpl::GetKeyValueDir(const std::string& key) { } void CoordinationServiceAgentImpl::GetKeyValueDirAsync( - const std::string& key, StatusOrValueDirCallback done) { + std::string_view key, StatusOrValueDirCallback done) { auto request = std::make_shared(); - request->set_directory_key(key); + request->set_directory_key(key.data(), key.size()); VLOG(3) << "GetKeyValueDirRequest: " << request->DebugString(); auto response = std::make_shared(); leader_client_->GetKeyValueDirAsync( @@ -694,8 +691,8 @@ void CoordinationServiceAgentImpl::GetKeyValueDirAsync( }); } -Status CoordinationServiceAgentImpl::InsertKeyValue(const std::string& key, - const std::string& value) { +Status CoordinationServiceAgentImpl::InsertKeyValue(std::string_view key, + std::string_view value) { InsertKeyValueRequest request; request.mutable_kv()->set_key(key.data(), key.size()); request.mutable_kv()->set_value(value.data(), value.size()); @@ -717,13 +714,13 @@ Status CoordinationServiceAgentImpl::InsertKeyValue(const char* key, int64_t key_size, const char* value, int64_t value_size) { - return 
InsertKeyValue(std::string(key, key_size), - std::string(value, value_size)); + return InsertKeyValue(std::string_view(key, key_size), + std::string_view(value, value_size)); } -Status CoordinationServiceAgentImpl::DeleteKeyValue(const std::string& key) { +Status CoordinationServiceAgentImpl::DeleteKeyValue(std::string_view key) { DeleteKeyValueRequest request; - request.set_key(key); + request.set_key(key.data(), key.size()); request.set_is_directory(true); VLOG(3) << "DeleteKeyValueRequest: " << request.DebugString(); DeleteKeyValueResponse response; @@ -741,23 +738,23 @@ Status CoordinationServiceAgentImpl::DeleteKeyValue(const std::string& key) { Status CoordinationServiceAgentImpl::DeleteKeyValue(const char* key, int64_t key_size) { - return DeleteKeyValue(std::string(key, key_size)); + return DeleteKeyValue(std::string_view(key, key_size)); } -Status CoordinationServiceAgentImpl::UpdateKeyValue(const std::string& key, - const std::string& value) { +Status CoordinationServiceAgentImpl::UpdateKeyValue(std::string_view key, + std::string_view value) { return MakeCoordinationError(errors::Unimplemented( "CoordinationServiceAgent::UpdateKeyValue is not implemented.")); } Status CoordinationServiceAgentImpl::StartWatchKey( - const std::string& key, + std::string_view key, CoordinationServiceAgentImpl::ChangedKeyValuesCallback on_change) { return MakeCoordinationError(errors::Unimplemented( "CoordinationServiceAgent::StartWatchKey is not implemented.")); } -Status CoordinationServiceAgentImpl::StopWatchKey(const std::string& key) { +Status CoordinationServiceAgentImpl::StopWatchKey(std::string_view key) { return MakeCoordinationError(errors::Unimplemented( "CoordinationServiceAgent::StopWatchKey is not implemented.")); } @@ -774,7 +771,7 @@ void CoordinationServiceAgentImpl::SetError(const Status& error) { } Status CoordinationServiceAgentImpl::ActivateWatch( - const std::string& key, const std::map& kvs) { + std::string_view key, const std::map& kvs) { return MakeCoordinationError(errors::Unimplemented( "CoordinationServiceAgent::ActivateWatch is not implemented.")); } diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.h b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.h index a567272f9d72ef..f94e6ac9dcb209 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.h +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -168,19 +169,19 @@ class CoordinationServiceAgent { // If the key-value is not inserted yet, this is a blocking call that waits // until the corresponding key is inserted. // - errors::DeadlineExceeded: timed out waiting for key. - virtual StatusOr GetKeyValue(const std::string& key) = 0; + virtual StatusOr GetKeyValue(std::string_view key) = 0; virtual StatusOr GetKeyValue(const char* key, int64_t key_size) = 0; - virtual StatusOr GetKeyValue(const std::string& key, + virtual StatusOr GetKeyValue(std::string_view key, absl::Duration timeout) = 0; // Note: Cancel the underlying RPC call with `call_opts->StartCancel()` and // `call_opts->ClearCancelCallback()`. virtual std::shared_ptr GetKeyValueAsync( - const std::string& key, StatusOrValueCallback done) = 0; + std::string_view, StatusOrValueCallback done) = 0; // Get config key-value from the service. 
// - errors::NotFound: the requested key does not exist. - virtual StatusOr TryGetKeyValue(const std::string& key) = 0; + virtual StatusOr TryGetKeyValue(std::string_view key) = 0; // Get all values under a directory (key). // A value is considered to be in the directory if its key is prefixed with @@ -188,30 +189,30 @@ class CoordinationServiceAgent { // This is not a blocking call. If no keys are found, an empty vector is // returned immediately. virtual StatusOr> GetKeyValueDir( - const std::string& key) = 0; - virtual void GetKeyValueDirAsync(const std::string& key, + std::string_view key) = 0; + virtual void GetKeyValueDirAsync(std::string_view key, StatusOrValueDirCallback done) = 0; // Insert config key-value to the service. // - errors::AlreadyExists: key is already set. - virtual Status InsertKeyValue(const std::string& key, - const std::string& value) = 0; + virtual Status InsertKeyValue(std::string_view key, + std::string_view value) = 0; virtual Status InsertKeyValue(const char* key, int64_t key_size, const char* value, int64_t value_size) = 0; // Delete config keys in the coordination service. - virtual Status DeleteKeyValue(const std::string& key) = 0; + virtual Status DeleteKeyValue(std::string_view key) = 0; virtual Status DeleteKeyValue(const char* key, int64_t key_size) = 0; // Update the value of a config key. - virtual Status UpdateKeyValue(const std::string& key, - const std::string& value) = 0; + virtual Status UpdateKeyValue(std::string_view key, + std::string_view value) = 0; // Register a callback that will be invoked when the key or keys under the key // directory are changed (inserted, deleted, or updated). - virtual Status StartWatchKey(const std::string& key, + virtual Status StartWatchKey(std::string_view key, ChangedKeyValuesCallback on_change) = 0; - virtual Status StopWatchKey(const std::string& key) = 0; + virtual Status StopWatchKey(std::string_view key) = 0; // Blocks until all (or a subset of) tasks are at the barrier or the barrier // fails. @@ -273,7 +274,7 @@ class CoordinationServiceAgent { virtual void SetError(const Status& error) = 0; // Activate the key-value callback watch. - virtual Status ActivateWatch(const std::string& key, + virtual Status ActivateWatch(std::string_view, const std::map&) = 0; private: diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc index 561c9d99c4e429..ba1a144b0d9468 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include #include +#include #include // NOLINT(build/c++11) #include #include @@ -157,7 +158,7 @@ TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) { std::unique_ptr<::pjrt::PJRT_KeyValueCallbackData> CreateTestCKVCallback( absl::flat_hash_map* kv_store, absl::Mutex& mu) { xla::PjRtClient::KeyValueGetCallback kv_get = - [kv_store, &mu](const std::string& k, + [kv_store, &mu](std::string_view k, absl::Duration timeout) -> xla::StatusOr { absl::Duration wait_interval = absl::Milliseconds(10); int num_retry = timeout / wait_interval; @@ -175,8 +176,7 @@ std::unique_ptr<::pjrt::PJRT_KeyValueCallbackData> CreateTestCKVCallback( absl::StrCat(k, " is not found in the kv store.")); }; xla::PjRtClient::KeyValuePutCallback kv_put = - [kv_store, &mu](const std::string& k, - const std::string& v) -> xla::Status { + [kv_store, &mu](std::string_view k, std::string_view v) -> xla::Status { { absl::MutexLock lock(&mu); kv_store->insert(std::pair(k, v)); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc index 1a9de396c193db..65af68ba2bb9f9 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -135,7 +136,7 @@ TEST(PjRtCApiHelperTest, Callback) { absl::flat_hash_map kv_store; absl::Mutex mu; xla::PjRtClient::KeyValueGetCallback kv_get = - [&kv_store, &mu](const std::string& k, + [&kv_store, &mu](std::string_view k, absl::Duration timeout) -> xla::StatusOr { absl::Duration wait_interval = absl::Milliseconds(10); int num_retry = timeout / wait_interval; @@ -153,8 +154,7 @@ TEST(PjRtCApiHelperTest, Callback) { absl::StrCat(k, " is not found in the kv store.")); }; xla::PjRtClient::KeyValuePutCallback kv_put = - [&kv_store, &mu](const std::string& k, - const std::string& v) -> xla::Status { + [&kv_store, &mu](std::string_view k, std::string_view v) -> xla::Status { { absl::MutexLock lock(&mu); kv_store[k] = v; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 0c65b5f4ba7e4c..7d8d1cb1b8a2e2 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include #include +#include #include #include #include @@ -230,7 +231,7 @@ xla::PjRtClient::KeyValueGetCallback ToCppKeyValueGetCallback( return nullptr; } return [c_callback, user_arg]( - const std::string& key, + std::string_view key, absl::Duration timeout) -> xla::StatusOr { PJRT_CallbackError callback_error = [](PJRT_Error_Code code, const char* message, @@ -239,7 +240,7 @@ xla::PjRtClient::KeyValueGetCallback ToCppKeyValueGetCallback( std::string(message, message_size))}; }; PJRT_KeyValueGetCallback_Args args; - args.key = key.c_str(); + args.key = key.data(); args.key_size = key.size(); args.timeout_in_ms = timeout / absl::Milliseconds(1); args.callback_error = &callback_error; @@ -259,8 +260,8 @@ xla::PjRtClient::KeyValuePutCallback ToCppKeyValuePutCallback( if (c_callback == nullptr) { return nullptr; } - return [c_callback, user_arg](const std::string& key, - const std::string& value) -> xla::Status { + return [c_callback, user_arg](std::string_view key, + std::string_view value) -> xla::Status { PJRT_CallbackError callback_error = [](PJRT_Error_Code code, const char* message, size_t message_size) { @@ -268,9 +269,9 @@ xla::PjRtClient::KeyValuePutCallback ToCppKeyValuePutCallback( std::string(message, message_size))}; }; PJRT_KeyValuePutCallback_Args args; - args.key = key.c_str(); + args.key = key.data(); args.key_size = key.size(); - args.value = value.c_str(); + args.value = value.data(); args.value_size = value.size(); args.callback_error = &callback_error; args.user_arg = user_arg; diff --git a/third_party/xla/xla/pjrt/distributed/client.cc b/third_party/xla/xla/pjrt/distributed/client.cc index b3027eaf8b6169..60dc6878acd843 100644 --- a/third_party/xla/xla/pjrt/distributed/client.cc +++ b/third_party/xla/xla/pjrt/distributed/client.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -46,11 +47,12 @@ class DistributedRuntimeCoordinationServiceClient xla::Status Connect() override; xla::Status Shutdown() override; xla::StatusOr BlockingKeyValueGet( - std::string key, absl::Duration timeout) override; + std::string_view key, absl::Duration timeout) override; xla::StatusOr>> - KeyValueDirGet(absl::string_view key) override; - xla::Status KeyValueSet(std::string key, std::string value) override; - xla::Status KeyValueDelete(std::string key) override; + KeyValueDirGet(std::string_view key) override; + xla::Status KeyValueSet(std::string_view key, + std::string_view value) override; + xla::Status KeyValueDelete(std::string_view key) override; xla::Status WaitAtBarrier(std::string barrier_id, absl::Duration timeout) override; xla::StatusOr GetCoordinationServiceAgent() @@ -133,13 +135,13 @@ xla::Status DistributedRuntimeCoordinationServiceClient::Shutdown() { xla::StatusOr DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet( - std::string key, absl::Duration timeout) { + std::string_view key, absl::Duration timeout) { return coord_agent_->GetKeyValue(key, timeout); } xla::StatusOr>> DistributedRuntimeCoordinationServiceClient::KeyValueDirGet( - absl::string_view key) { + std::string_view key) { // TODO(hanyangtay): Migrate to string_view for both client and coordination // agent APIs. 
TF_ASSIGN_OR_RETURN(const auto results, @@ -157,12 +159,12 @@ DistributedRuntimeCoordinationServiceClient::KeyValueDirGet( } xla::Status DistributedRuntimeCoordinationServiceClient::KeyValueDelete( - std::string key) { + std::string_view key) { return coord_agent_->DeleteKeyValue(key); } xla::Status DistributedRuntimeCoordinationServiceClient::KeyValueSet( - std::string key, std::string value) { + std::string_view key, std::string_view value) { return coord_agent_->InsertKeyValue(key, value); } diff --git a/third_party/xla/xla/pjrt/distributed/client.h b/third_party/xla/xla/pjrt/distributed/client.h index 6ee9ce976df557..c780389c63ef22 100644 --- a/third_party/xla/xla/pjrt/distributed/client.h +++ b/third_party/xla/xla/pjrt/distributed/client.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -115,7 +116,7 @@ class DistributedRuntimeClient { // There are no concurrency guarantees. To avoid a race / impose an ordering // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). virtual xla::StatusOr BlockingKeyValueGet( - std::string key, absl::Duration timeout) = 0; + std::string_view key, absl::Duration timeout) = 0; // Get all key-value pairs under a directory (key). // A value is considered to be in the directory if its key is prefixed with @@ -123,13 +124,14 @@ class DistributedRuntimeClient { // This is not a blocking call. If no keys are found, an empty vector is // returned immediately. virtual xla::StatusOr>> - KeyValueDirGet(absl::string_view key) = 0; + KeyValueDirGet(std::string_view key) = 0; - virtual xla::Status KeyValueSet(std::string key, std::string value) = 0; + virtual xla::Status KeyValueSet(std::string_view key, + std::string_view value) = 0; // Delete the key-value. If the key is a directory, recursively clean // up all key-values under the directory. - virtual xla::Status KeyValueDelete(std::string key) = 0; + virtual xla::Status KeyValueDelete(std::string_view key) = 0; // Blocks until all nodes are at the barrier or the barrier times out. // `barrier_id` should be unique across barriers. diff --git a/third_party/xla/xla/pjrt/distributed/client_server_test.cc b/third_party/xla/xla/pjrt/distributed/client_server_test.cc index efd3f7cda98e59..7e3c05d3b154af 100644 --- a/third_party/xla/xla/pjrt/distributed/client_server_test.cc +++ b/third_party/xla/xla/pjrt/distributed/client_server_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/strings/str_cat.h" @@ -217,12 +218,11 @@ TEST_F(ClientServerTest, ConnectAndEnumerateDevices) { // Sleep a short while for the other thread to send their device info first. absl::SleepFor(absl::Seconds(1)); - auto kv_get = [&](const std::string& k, + auto kv_get = [&](std::string_view k, absl::Duration timeout) -> xla::StatusOr { return client->BlockingKeyValueGet(k, timeout); }; - auto kv_put = [&](const std::string& k, - const std::string& v) -> xla::Status { + auto kv_put = [&](std::string_view k, std::string_view v) -> xla::Status { return client->KeyValueSet(k, v); }; TF_RETURN_IF_ERROR( @@ -250,12 +250,11 @@ TEST_F(ClientServerTest, ConnectAndEnumerateDevices) { // We cannot send the notification after the call since there is a barrier // within the call that would cause a deadlock. 
n.Notify(); - auto kv_get = [&](const std::string& k, + auto kv_get = [&](std::string_view k, absl::Duration timeout) -> xla::StatusOr { return client->BlockingKeyValueGet(k, timeout); }; - auto kv_put = [&](const std::string& k, - const std::string& v) -> xla::Status { + auto kv_put = [&](std::string_view k, std::string_view v) -> xla::Status { return client->KeyValueSet(k, v); }; TF_RETURN_IF_ERROR( @@ -316,12 +315,11 @@ TEST_F(ClientServerTest, EnumerateElevenDevices) { auto client = GetClient(node_id); GlobalTopologyProto topology; TF_RETURN_IF_ERROR(client->Connect()); - auto kv_get = [&](const std::string& k, + auto kv_get = [&](std::string_view k, absl::Duration timeout) -> xla::StatusOr { return client->BlockingKeyValueGet(k, timeout); }; - auto kv_put = [&](const std::string& k, - const std::string& v) -> xla::Status { + auto kv_put = [&](std::string_view k, std::string_view v) -> xla::Status { return client->KeyValueSet(k, v); }; TF_RETURN_IF_ERROR( diff --git a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc index fd2b1a87709f50..c8baa3e4ac6f58 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc +++ b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/pjrt/distributed/topology_util.h" #include +#include #include #include "absl/container/flat_hash_map.h" @@ -66,7 +67,7 @@ TEST(TopologyTest, ExchangeTopology) { absl::Mutex mu; absl::flat_hash_map kv; - auto kv_get = [&](const std::string& key, + auto kv_get = [&](std::string_view key, absl::Duration timeout) -> xla::StatusOr { absl::MutexLock lock(&mu); auto ready = [&]() { return kv.contains(key); }; @@ -76,8 +77,8 @@ TEST(TopologyTest, ExchangeTopology) { return absl::NotFoundError("key not found"); }; - auto kv_put = [&](const std::string& key, - const std::string& value) -> xla::Status { + auto kv_put = [&](std::string_view key, + std::string_view value) -> xla::Status { absl::MutexLock lock(&mu); kv[key] = value; return absl::OkStatus(); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 87ec53368e8f82..9a7a5d96788031 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -903,7 +903,7 @@ StatusOr> GetStreamExecutorGpuClient( absl::Mutex mu; if (enable_mock_nccl) { kv_get = [&device_maps, &mu, &num_nodes]( - const std::string& k, + std::string_view k, absl::Duration timeout) -> xla::StatusOr { std::string result; { @@ -929,8 +929,8 @@ StatusOr> GetStreamExecutorGpuClient( } return result; }; - kv_put = [&device_maps, &mu](const std::string& k, - const std::string& v) -> xla::Status { + kv_put = [&device_maps, &mu](std::string_view k, + std::string_view v) -> xla::Status { { absl::MutexLock lock(&mu); device_maps[k] = v; diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index a2141f31162589..c6c018eed9ed34 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -542,7 +543,7 @@ TEST(StreamExecutorGpuClientTest, DistributeInit) { absl::flat_hash_map kv_store; absl::Mutex mu; PjRtClient::KeyValueGetCallback kv_get = - [&kv_store, &mu](const std::string& k, + [&kv_store, &mu](std::string_view k, absl::Duration timeout) -> xla::StatusOr { absl::Duration wait_interval = absl::Milliseconds(10); int num_retry = timeout / wait_interval; @@ -560,8 +561,7 @@ TEST(StreamExecutorGpuClientTest, DistributeInit) { absl::StrCat(k, " is not found in the kv store.")); }; PjRtClient::KeyValuePutCallback kv_put = - [&kv_store, &mu](const std::string& k, - const std::string& v) -> xla::Status { + [&kv_store, &mu](std::string_view k, std::string_view v) -> xla::Status { { absl::MutexLock lock(&mu); kv_store[k] = v; diff --git a/third_party/xla/xla/pjrt/pjrt_client.h b/third_party/xla/xla/pjrt/pjrt_client.h index 69c8ebda0cbd07..2bf0e62af4db63 100644 --- a/third_party/xla/xla/pjrt/pjrt_client.h +++ b/third_party/xla/xla/pjrt/pjrt_client.h @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -466,9 +467,9 @@ class PjRtClient { // Subclasses of PjRtClient can optionally take these callbacks in their // constructors. using KeyValueGetCallback = std::function( - const std::string& key, absl::Duration timeout)>; - using KeyValuePutCallback = std::function; + std::string_view key, absl::Duration timeout)>; + using KeyValuePutCallback = + std::function; PjRtClient() = default; explicit PjRtClient(std::unique_ptr diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index a4db3deef34730..9ece9636a1d7d7 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -546,14 +547,13 @@ static void Init(py::module_& m) { // Use the plugin name as key prefix. 
std::string key_prefix = "gpu:"; kv_get = [distributed_client, key_prefix]( - const std::string& k, + std::string_view k, absl::Duration timeout) -> xla::StatusOr { return distributed_client->BlockingKeyValueGet( absl::StrCat(key_prefix, k), timeout); }; kv_put = [distributed_client, key_prefix]( - const std::string& k, - const std::string& v) -> xla::Status { + std::string_view k, std::string_view v) -> xla::Status { return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); }; @@ -584,13 +584,13 @@ static void Init(py::module_& m) { PjRtClient::KeyValueGetCallback kv_get = nullptr; PjRtClient::KeyValuePutCallback kv_put = nullptr; if (distributed_client != nullptr) { - kv_get = [distributed_client, platform_name](const std::string& k, + kv_get = [distributed_client, platform_name](std::string_view k, absl::Duration timeout) { return distributed_client->BlockingKeyValueGet( absl::StrCat(platform_name, ":", k), timeout); }; - kv_put = [distributed_client, platform_name](const std::string& k, - const std::string& v) { + kv_put = [distributed_client, platform_name](std::string_view k, + std::string_view v) { return distributed_client->KeyValueSet( absl::StrCat(platform_name, ":", k), v); }; diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index 91414e4f6871a1..490a04ca06b534 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -273,15 +273,15 @@ StatusOr> FunctionalHloRunner::CreateGpuClient( xla::PjRtClient::KeyValueGetCallback kv_get = [distributed_client]( - const std::string& k, + std::string_view k, absl::Duration timeout) -> xla::StatusOr { return distributed_client->BlockingKeyValueGet(absl::StrCat(kKeyPrefix, k), timeout); }; xla::PjRtClient::KeyValuePutCallback kv_put = - [distributed_client](const std::string& k, - const std::string& v) -> xla::Status { + [distributed_client](std::string_view k, + std::string_view v) -> xla::Status { return distributed_client->KeyValueSet(absl::StrCat(kKeyPrefix, k), v); }; From c76be80a4b3f834627caafb7e7578fbf29934425 Mon Sep 17 00:00:00 2001 From: "Jiyoun (Jen) Ha" Date: Tue, 14 Nov 2023 18:41:40 -0800 Subject: [PATCH 098/391] Factor out type functions from uniform_quantized_stableho_to_tfl_pass These functions are general enough to be reused in other files. Also adds a safeguard for pattern matching JAX quantizer produced patterns. 
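For illustration, a rough sketch of how another pass could reuse the factored-out predicates (the helper below is hypothetical and not part of this change; only the mlir::quant::IsI8F32UniformQuantizedType entry point comes from the shared header):

    #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
    #include "mlir/IR/Value.h"  // from @llvm-project
    #include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h"

    // Returns true iff `value` is a tensor whose element type is per-tensor
    // quantized with i8 storage and f32 expressed type (qi8), i.e. the only
    // output form the TFL rewrite patterns accept.
    bool IsPerTensorQi8Tensor(mlir::Value value) {
      auto tensor_type = value.getType().dyn_cast<mlir::TensorType>();
      if (!tensor_type) return false;
      return mlir::quant::IsI8F32UniformQuantizedType(
          tensor_type.getElementType());
    }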
PiperOrigin-RevId: 582504409 --- tensorflow/compiler/mlir/lite/stablehlo/BUILD | 1 + .../uniform-quantized-stablehlo-to-tfl.mlir | 28 +++- ...uniform_quantized_stablehlo_to_tfl_pass.cc | 103 ++++--------- .../stablehlo/uniform_quantized_types.cc | 91 ++++++++++++ .../stablehlo/uniform_quantized_types.h | 19 +++ .../stablehlo/uniform_quantized_types_test.cc | 135 +++++++++++++++++- 6 files changed, 302 insertions(+), 75 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 826836a8acf9f7..152b48b1f9043a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -454,6 +454,7 @@ cc_library( deps = [ ":passes_inc_gen", "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/quantization/stablehlo:uniform_quantized_types", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir index 15b3e37326cfe0..706106dad3c24b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir @@ -24,7 +24,7 @@ func.func @uniform_quantize_op_quantized_input(%arg: tensor<2x2x!quant.uniform) -> tensor<2x // ----- -// Tests that the pattern doesn't match when the output tensor's sotrage type +// Tests that the pattern doesn't match when the output tensor's storage type // is i32. i32 storage type for quantized type is not compatible with // `tfl.quantize`. @@ -237,6 +237,30 @@ func.func @dot_general_full_integer_sym_input(%arg0: tensor<1x2x3x4x!quant.unifo // ----- +// Tests that the pattern does not match when the output tensor's storage +// type is i32. Currently we support qi8, qi8 -> qi8 only for GEMM ops that +// are quantized upstream. Other cases should be handled by regular quantized +// stablehlo.dot_general case. + +// CHECK-LABEL: dot_general_op_i32_output +func.func @dot_general_op_i32_output(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { + %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> + %1 = "stablehlo.dot_general"(%arg0, %0) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0, 1], + rhs_batching_dimensions = [0, 1], + lhs_contracting_dimensions = [3], + rhs_contracting_dimensions = [2] + >, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> + return %1 : tensor<1x2x3x5x!quant.uniform> +} +// CHECK: stablehlo.dot_general +// CHECK-NOT: tfl.quantize + +// ----- + // Test full integer quantized dot_general with activation as RHS // CHECK-LABEL: dot_general_full_integer_activation_rhs diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index 18070fe59134e3..cba64e220f82f0 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -39,6 +39,7 @@ limitations under the License. 
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h" #define DEBUG_TYPE "uniform-quantized-stablehlo-to-tfl" @@ -46,6 +47,8 @@ namespace mlir { namespace odml { namespace { +using ::mlir::quant::IsI8F32UniformQuantizedPerAxisType; +using ::mlir::quant::IsI8F32UniformQuantizedType; using ::mlir::quant::QuantizedType; using ::mlir::quant::UniformQuantizedPerAxisType; using ::mlir::quant::UniformQuantizedType; @@ -75,69 +78,6 @@ bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type) { return true; } -// Returns true iff the storage type of `quantized_type` is 8-bit integer. -bool IsStorageTypeI8(QuantizedType quantized_type) { - const Type storage_type = quantized_type.getStorageType(); - return storage_type.isInteger(/*width=*/8); -} - -// Returns true iff the expressed type of `quantized_type` is f32. -bool IsExpressedTypeF32(QuantizedType quantized_type) { - const Type expressed_type = quantized_type.getExpressedType(); - return expressed_type.isa(); -} - -// Returns true iff `type` is a uniform quantized type whose storage type is -// 8-bit integer and expressed type is f32. -bool IsI8F32UniformQuantizedType(const Type type) { - auto quantized_type = type.dyn_cast_or_null(); - if (!quantized_type) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a uniform quantized type. Got: " << type << ".\n"); - return false; - } - - if (!IsStorageTypeI8(quantized_type)) { - LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. Got: " - << quantized_type << ".\n"); - return false; - } - - if (!IsExpressedTypeF32(quantized_type)) { - LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " - << quantized_type << ".\n"); - return false; - } - - return true; -} - -// Returns true iff `type` is a uniform quantized per-axis (per-channel) type -// whose storage type is 8-bit integer and expressed type is f32. -bool IsI8F32UniformQuantizedPerAxisType(const Type type) { - auto quantized_per_axis_type = - type.dyn_cast_or_null(); - if (!quantized_per_axis_type) { - LLVM_DEBUG(llvm::dbgs() - << "Expected a uniform quantized type. Got: " << type << ".\n"); - return false; - } - - if (!IsStorageTypeI8(quantized_per_axis_type)) { - LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. Got: " - << quantized_per_axis_type << ".\n"); - return false; - } - - if (!IsExpressedTypeF32(quantized_per_axis_type)) { - LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " - << quantized_per_axis_type << ".\n"); - return false; - } - - return true; -} - // Bias scales for matmul-like ops should be input scale * filter scale. Here it // is assumed that the input is per-tensor quantized and filter is per-channel // quantized. @@ -257,7 +197,7 @@ class RewriteUniformDequantizeOp // * Not a depthwise convolution. // * Does not consider bias add fusion. // TODO: b/294771704 - Support bias quantization. -class RewriteQuantizedConvolutionOp +class RewriteUpstreamQuantizedConvolutionOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -654,7 +594,7 @@ class RewriteQuantizedConvolutionOp // // TODO: b/293650675 - Relax the conversion condition to support dot_general in // general. 
-class RewriteFullIntegerQuantizedDotGeneralOp +class RewriteUpstreamQuantizedDotGeneralOpToBatchMatmulOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -662,7 +602,7 @@ class RewriteFullIntegerQuantizedDotGeneralOp static LogicalResult MatchLhs( Value lhs, stablehlo::DotDimensionNumbersAttr dimension_numbers) { auto lhs_type = lhs.getType().cast(); - if (!(IsI8F32UniformQuantizedType(lhs_type.getElementType()))) { + if (!IsI8F32UniformQuantizedType(lhs_type.getElementType())) { LLVM_DEBUG(llvm::dbgs() << "Expected a per-tensor uniform " "quantized (i8->f32) input for dot_general. Got: " @@ -704,7 +644,7 @@ class RewriteFullIntegerQuantizedDotGeneralOp } auto rhs_type = rhs.getType().cast(); - if (!(IsI8F32UniformQuantizedType(rhs_type.getElementType()))) { + if (!IsI8F32UniformQuantizedType(rhs_type.getElementType())) { LLVM_DEBUG(llvm::dbgs() << "Expected a per-tensor uniform " "quantized (i8->f32) weight for dot_general. Got: " @@ -714,6 +654,19 @@ class RewriteFullIntegerQuantizedDotGeneralOp return success(); } + static LogicalResult MatchOutput( + Value output, stablehlo::DotDimensionNumbersAttr dimension_numbers) { + auto output_type = output.getType().cast(); + if (!IsI8F32UniformQuantizedType(output_type.getElementType())) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a per-tensor uniform " + "quantized (i8->f32) output for dot_general. Got: " + << output_type << "\n"); + return failure(); + } + return success(); + } + LogicalResult match(stablehlo::DotGeneralOp op) const override { stablehlo::DotDimensionNumbersAttr dimension_numbers = op.getDotDimensionNumbers(); @@ -746,6 +699,12 @@ class RewriteFullIntegerQuantizedDotGeneralOp return failure(); } + if (failed(MatchOutput(op.getResult(), dimension_numbers))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match output for quantized dot_general.\n"); + return failure(); + } + return success(); } @@ -816,10 +775,10 @@ class RewriteFullIntegerQuantizedDotGeneralOp // * Does not consider bias add fusion. // // TODO: b/294983811 - Merge this pattern into -// `RewriteFullIntegerQuantizedDotGeneralOp`. +// `RewriteUpstreamQuantizedDotGeneralOpToBatchMatmulOp`. // TODO: b/295264927 - `stablehlo.dot_general` with per-axis quantized operands // is not specified in the StableHLO dialect. Update the spec to allow this. -class RewriteQuantizedDotGeneralOpToTflFullyConnectedOp +class RewriteUpstreamQuantizedDotGeneralOpToTflFullyConnectedOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1071,9 +1030,9 @@ void UniformQuantizedStablehloToTflPass::runOnOperation() { RewritePatternSet patterns(&ctx); patterns.add(&ctx); + RewriteUpstreamQuantizedConvolutionOp, + RewriteUpstreamQuantizedDotGeneralOpToBatchMatmulOp, + RewriteUpstreamQuantizedDotGeneralOpToTflFullyConnectedOp>(&ctx); if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { func_op.emitError() << "Failed to convert stablehlo ops with uniform " diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc index bfd9de9ca60d25..5c1d362110799b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc @@ -16,13 +16,17 @@ limitations under the License. 
#include +#include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#define DEBUG_TYPE "uniform-quantized-types" + namespace mlir { namespace quant { @@ -60,5 +64,92 @@ UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( /*storageTypeMax=*/llvm::maxIntN(8)); } +bool IsStorageTypeI8(const QuantizedType quantized_type) { + const Type storage_type = quantized_type.getStorageType(); + return storage_type.isInteger(/*width=*/8); +} + +bool IsStorageTypeI32(const QuantizedType quantized_type) { + const Type storage_type = quantized_type.getStorageType(); + return storage_type.isInteger(/*width=*/32); +} + +bool IsExpressedTypeF32(const QuantizedType quantized_type) { + const Type expressed_type = quantized_type.getExpressedType(); + return expressed_type.isa(); +} + +bool IsI8F32UniformQuantizedType(const Type type) { + const UniformQuantizedType quantized_type = + type.dyn_cast_or_null(); + if (!quantized_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI8(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. Got: " + << quantized_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_type << ".\n"); + return false; + } + + return true; +} + +bool IsI8F32UniformQuantizedPerAxisType(const Type type) { + const UniformQuantizedPerAxisType quantized_per_axis_type = + type.dyn_cast_or_null(); + if (!quantized_per_axis_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI8(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + return true; +} + +bool IsI32F32UniformQuantizedType(const Type type) { + const UniformQuantizedType quantized_type = + type.dyn_cast_or_null(); + if (!quantized_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI32(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i32 storage type. Got: " + << quantized_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_type << ".\n"); + return false; + } + + return true; +} + } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h index 68774b2ecb876b..d938c3a235343a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h @@ -20,6 +20,7 @@ limitations under the License. 
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { @@ -54,6 +55,24 @@ UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( Location loc, MLIRContext& context, ArrayRef scales, ArrayRef zero_points, int quantization_dimension); +bool IsStorageTypeI8(QuantizedType quantized_type); + +bool IsStorageTypeI32(QuantizedType quantized_type); + +bool IsExpressedTypeF32(QuantizedType quantized_type); + +// Returns true iff `type` is a uniform quantized type whose storage type is +// 8-bit integer and expressed type is f32. +bool IsI8F32UniformQuantizedType(Type type); + +// Returns true iff `type` is a uniform quantized per-axis (per-channel) type +// whose storage type is 8-bit integer and expressed type is f32. +bool IsI8F32UniformQuantizedPerAxisType(Type type); + +// Returns true iff `type` is a uniform quantized type whose storage type is +// 32-bit integer and expressed type is f32. +bool IsI32F32UniformQuantizedType(Type type); + } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc index 0888bfa8d22908..ec91997fb9dc14 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -30,6 +31,7 @@ namespace quant { namespace { using ::testing::ElementsAreArray; +using ::testing::NotNull; class CreateI8F32UniformQuantizedTypeTest : public ::testing::Test { protected: @@ -64,7 +66,7 @@ TEST_F(CreateI8F32UniformQuantizedTypeTest, IsSigned) { EXPECT_TRUE(quantized_type.isSigned()); } -TEST_F(CreateI8F32UniformQuantizedTypeTest, SotrageTypeMinMaxEqualToI8MinMax) { +TEST_F(CreateI8F32UniformQuantizedTypeTest, StrageTypeMinMaxEqualToI8MinMax) { const UniformQuantizedType quantized_type = CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0); @@ -216,6 +218,137 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, EXPECT_THAT(quantized_type.getZeroPoints(), ElementsAreArray({98, 99})); } +class IsI8F32UniformQuantizedTypeTest : public ::testing::Test { + protected: + IsI8F32UniformQuantizedTypeTest() { + ctx_.loadDialect(); + } + + MLIRContext ctx_; + OpBuilder builder_{&ctx_}; +}; + +TEST_F(IsI8F32UniformQuantizedTypeTest, IsI8F32UniformQuantizedType) { + const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsI8F32UniformQuantizedType(qi8_type)); +} + +TEST_F(IsI8F32UniformQuantizedTypeTest, IsQuantizedType) { + const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, + 
/*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_THAT(qi8_type.dyn_cast_or_null(), NotNull()); +} + +TEST_F(IsI8F32UniformQuantizedTypeTest, IsStorageTypeI8) { + const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsStorageTypeI8(qi8_type)); +} + +TEST_F(IsI8F32UniformQuantizedTypeTest, IsExpressedTypeF32) { + const UniformQuantizedType qi8_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsExpressedTypeF32(qi8_type)); +} + +class IsI8F32UniformQuantizedPerAxisTypeTest : public ::testing::Test { + protected: + IsI8F32UniformQuantizedPerAxisTypeTest() { + ctx_.loadDialect(); + } + + MLIRContext ctx_; + OpBuilder builder_{&ctx_}; +}; + +TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, + IsI8F32UniformQuantizedPerAxisType) { + const UniformQuantizedPerAxisType qi8_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, + /*storageTypeMax=*/255); + EXPECT_TRUE(IsI8F32UniformQuantizedPerAxisType(qi8_per_axis_type)); + EXPECT_FALSE(IsI8F32UniformQuantizedType(qi8_per_axis_type)); +} + +TEST_F(IsI8F32UniformQuantizedTypeTest, IsQuantizedPerAxisType) { + const UniformQuantizedPerAxisType qi8_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, + /*storageTypeMax=*/255); + EXPECT_THAT(qi8_per_axis_type.dyn_cast_or_null(), + NotNull()); +} + +TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, IsStorageTypeI8) { + const UniformQuantizedPerAxisType qi8_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, + /*storageTypeMax=*/255); + EXPECT_TRUE(IsStorageTypeI8(qi8_per_axis_type)); +} + +TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, IsExpressedTypeF32) { + const UniformQuantizedPerAxisType qi8_per_axis_type = + quant::UniformQuantizedPerAxisType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*scales=*/{1.0}, + /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/0, + /*storageTypeMax=*/255); + EXPECT_TRUE(IsExpressedTypeF32(qi8_per_axis_type)); +} + +class IsI32F32UniformQuantizedTypeTest : public ::testing::Test { + protected: + IsI32F32UniformQuantizedTypeTest() { + ctx_.loadDialect(); + } + + MLIRContext ctx_; + OpBuilder builder_{&ctx_}; +}; + +TEST_F(IsI32F32UniformQuantizedTypeTest, IsI32F32UniformQuantizedType) { + const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI32Type(), builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsI32F32UniformQuantizedType(qi32_type)); +} + +TEST_F(IsI32F32UniformQuantizedTypeTest, IsQuantizedType) { + const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + 
EXPECT_THAT(qi32_type.dyn_cast_or_null(), NotNull()); +} + +TEST_F(IsI32F32UniformQuantizedTypeTest, IsStorageTypeI32) { + const UniformQuantizedType qi32_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI32Type(), builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsStorageTypeI32(qi32_type)); +} + +TEST_F(IsI32F32UniformQuantizedTypeTest, IsExpressedTypeF32) { + const UniformQuantizedType qi32_per_axis_type = + quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getI8Type(), builder_.getF32Type(), + /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsExpressedTypeF32(qi32_per_axis_type)); +} + } // namespace } // namespace quant } // namespace mlir From 3360f0dbb4d7622387d3bf9f9ecfebc3a8f99366 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 14 Nov 2023 20:38:42 -0800 Subject: [PATCH 099/391] [xla:gpu] Add name to custom kernels to improve logging and debugging PiperOrigin-RevId: 582526517 --- third_party/xla/xla/service/gpu/kernel_thunk.cc | 8 +++----- third_party/xla/xla/service/gpu/kernel_thunk.h | 3 --- third_party/xla/xla/service/gpu/kernels/BUILD | 1 + .../xla/xla/service/gpu/kernels/custom_kernel.cc | 15 +++++++++++++-- .../xla/xla/service/gpu/kernels/custom_kernel.h | 9 +++++++-- .../service/gpu/kernels/cutlass_gemm_kernel.cu.cc | 4 ++-- 6 files changed, 26 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/service/gpu/kernel_thunk.cc b/third_party/xla/xla/service/gpu/kernel_thunk.cc index c5fcc33231607c..58dfab7b11320a 100644 --- a/third_party/xla/xla/service/gpu/kernel_thunk.cc +++ b/third_party/xla/xla/service/gpu/kernel_thunk.cc @@ -188,10 +188,7 @@ CustomKernelThunk::CustomKernelThunk( } std::string CustomKernelThunk::ToStringExtra(int indent) const { - // TODO(ezhulenev): Add `name` to a custom kernel and add pretty printing for - // custom kernel launch dimensions. - return absl::StrFormat(", kernel = %s, launch dimensions = %s", "", - ""); + return custom_kernel_.ToString(); } Status CustomKernelThunk::Initialize(se::StreamExecutor* executor, @@ -217,7 +214,8 @@ Status CustomKernelThunk::ExecuteOnStream(const ExecuteParams& params) { return kernel_cache_[executor].get(); }(); - VLOG(3) << "Launching " << kernel->name(); + VLOG(3) << "Launching " << custom_kernel_.ToString() << " as device kernel " + << kernel->name(); absl::InlinedVector buffer_args; for (const BufferAllocation::Slice& arg : args_) { diff --git a/third_party/xla/xla/service/gpu/kernel_thunk.h b/third_party/xla/xla/service/gpu/kernel_thunk.h index ef857b52f9ccff..7f8ff1331324e4 100644 --- a/third_party/xla/xla/service/gpu/kernel_thunk.h +++ b/third_party/xla/xla/service/gpu/kernel_thunk.h @@ -152,9 +152,6 @@ class CustomKernelThunk : public Thunk { // args_[i] is written iff (written_[i] == true). std::vector written_; - // mlir::Value(s) corresponding to the buffer slice arguments. - std::vector values_; - CustomKernel custom_kernel_; // Loaded kernels for each `StreamExecutor`. 
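For reference, a rough sketch of how a named custom kernel is put together and surfaced in logs under the new API (the symbol pointer and launch dimensions below are placeholders, not values taken from this change):

    // Illustrative only: build a CustomKernel with a human-readable name so
    // thunk logging can identify it.
    void* kernel_symbol = /* in-process device kernel symbol */ nullptr;
    se::MultiKernelLoaderSpec spec(/*arity=*/3);
    spec.AddInProcessSymbol(kernel_symbol, "cutlass_gemm");
    CustomKernel kernel("cutlass_gemm:f32<-f32xf32", std::move(spec),
                        se::BlockDim(128), se::ThreadDim(256),
                        /*shared_memory_bytes=*/49152);
    // ToString() now renders the name plus launch dimensions, e.g.
    //   cutlass_gemm:f32<-f32xf32 grid: [128, 1, 1] threads: [256, 1, 1] shared_memory: 49152 bytes
    VLOG(3) << "Launching " << kernel.ToString();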
diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 6c134e887a7af1..bb68c506b01f34 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -54,6 +54,7 @@ cc_library( deps = [ "//xla:statusor", "//xla/stream_executor", + "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc b/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc index 0bc2de781bbe3b..b9451eb6ff0154 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc +++ b/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc @@ -16,17 +16,21 @@ limitations under the License. #include "xla/service/gpu/kernels/custom_kernel.h" #include +#include #include +#include "absl/strings/str_format.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" namespace xla::gpu { -CustomKernel::CustomKernel(se::MultiKernelLoaderSpec kernel_spec, +CustomKernel::CustomKernel(std::string name, + se::MultiKernelLoaderSpec kernel_spec, se::BlockDim block_dims, se::ThreadDim thread_dims, size_t shared_memory_bytes) - : kernel_spec_(std::move(kernel_spec)), + : name_(std::move(name)), + kernel_spec_(std::move(kernel_spec)), block_dims_(block_dims), thread_dims_(thread_dims), @@ -44,4 +48,11 @@ size_t CustomKernel::shared_memory_bytes() const { return shared_memory_bytes_; } +std::string CustomKernel::ToString() const { + return absl::StrFormat( + "%s grid: [%d, %d, %d] threads: [%d, %d, %d] shared_memory: %d bytes", + name_, block_dims_.x, block_dims_.y, block_dims_.z, thread_dims_.x, + thread_dims_.y, thread_dims_.z, shared_memory_bytes_); +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/kernels/custom_kernel.h b/third_party/xla/xla/service/gpu/kernels/custom_kernel.h index 85b1a2a9e639c3..32f65b76a7a49c 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_kernel.h +++ b/third_party/xla/xla/service/gpu/kernels/custom_kernel.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_SERVICE_GPU_KERNELS_CUSTOM_KERNEL_H_ #include +#include #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" @@ -42,8 +43,9 @@ namespace se = ::stream_executor; // NOLINT // define if it has to be zeroed first. 
class CustomKernel { public: - CustomKernel(se::MultiKernelLoaderSpec kernel_spec, se::BlockDim block_dims, - se::ThreadDim thread_dims, size_t shared_memory_bytes); + CustomKernel(std::string name, se::MultiKernelLoaderSpec kernel_spec, + se::BlockDim block_dims, se::ThreadDim thread_dims, + size_t shared_memory_bytes); const se::MultiKernelLoaderSpec& kernel_spec() const; @@ -53,7 +55,10 @@ class CustomKernel { size_t shared_memory_bytes() const; + std::string ToString() const; + private: + std::string name_; se::MultiKernelLoaderSpec kernel_spec_; se::BlockDim block_dims_; se::ThreadDim thread_dims_; diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc index 6211d9eaf965e1..07c7601af55987 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc @@ -99,8 +99,8 @@ StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, kernel_spec.AddInProcessSymbol( reinterpret_cast(cutlass::Kernel), "cutlass_gemm"); - return CustomKernel(std::move(kernel_spec), block_dims, thread_dims, - shared_memory_bytes); + return CustomKernel("cutlass_gemm:f32<-f32xf32", std::move(kernel_spec), + block_dims, thread_dims, shared_memory_bytes); } } // namespace xla::gpu::kernel From 6b43cb7462ea7235da49f6629d18e2f1afabddbe Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Tue, 14 Nov 2023 20:51:41 -0800 Subject: [PATCH 100/391] Move the calibration step into `py_function_lib`. This change moves the python-implemented calibration functions into `PyFunctionLibrary` so that it can be used within the c++ environment. As a consequence, the calibration step becomes a part of the `pywrap_quantize_model.quantize_ptq_model_pre_calibration`. 
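In rough terms, the C++ side can now drive the pre-calibration export and the calibration step through the overridden function library. A simplified sketch of the intended call site (variable names are placeholders; the real wiring lives in pywrap_quantize_model.cc):

    // `py_function_library` is the Python-overridden PyFunctionLibrary
    // instance handed in from the Python layer; `exported_model` is the
    // result of the pre-calibration export.
    ExportedModel calibrated_model = py_function_library.RunCalibration(
        saved_model_path, exported_model, quantization_options,
        /*representative_dataset=*/representative_dataset);
    // `calibrated_model` carries the collected min/max statistics on its
    // CustomAggregator nodes and feeds the post-calibration passes.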
PiperOrigin-RevId: 582528682 --- .../mlir/quantization/tensorflow/python/BUILD | 19 +- .../tensorflow/python/py_function_lib.h | 23 + .../tensorflow/python/py_function_lib.py | 540 +++++++++++++++++- .../tensorflow/python/pywrap_function_lib.cc | 21 +- .../tensorflow/python/pywrap_function_lib.pyi | 12 + .../python/pywrap_quantize_model.cc | 12 +- .../python/pywrap_quantize_model.pyi | 2 + .../tensorflow/python/quantize_model.py | 473 +-------------- .../tensorflow/python/type_casters.h | 13 +- 9 files changed, 634 insertions(+), 481 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index a7bc257909d8b2..6e39a8237575dc 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -104,10 +104,25 @@ pytype_strict_library( visibility = ["//visibility:private"], deps = [ ":pywrap_function_lib", + ":representative_dataset", ":save_model", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_py", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_algorithm", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_py", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:pywrap_calibration", "//tensorflow/core:protos_all_py", + "//tensorflow/python/client:session", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:wrap_function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/lib/io:file_io", + "//tensorflow/python/saved_model:load", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/trackable:autotrackable", + "//tensorflow/python/types:core", + "//third_party/py/numpy", "@absl_py//absl/logging", ], ) @@ -154,9 +169,11 @@ cc_library( visibility = ["//visibility:private"], deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:string_view", + "@pybind11", ], ) @@ -169,8 +186,8 @@ tf_python_pybind_extension( ":py_function_lib", ":type_casters", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/core/protobuf:for_core_protos_cc", - "//tensorflow/python/lib/core:pybind11_lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:string_view", "@pybind11", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h index e0a2d85c15fde1..4951e7746e923a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h @@ -20,7 +20,9 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "pybind11/pytypes.h" // from @pybind11 #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" namespace tensorflow::quantization { @@ -70,6 +72,27 @@ class PyFunctionLibrary { // pywrap_function_lib.pyi:save_exported_model, // py_function_lib.py:save_exported_model, // ) + + // Runs calibration on a model saved at `saved_model_path`. `exported_model` + // should be the corresponding exported model resulting from the + // pre-calibration step. `representative_dataset` is a python object of type + // `RepresentativeDatasetOrMapping`, which is used to run the calibration. + // + // Returns the updated exported model where the collected calibration + // statistics are added to `CustomAggregator` nodes at the `min` and `max` + // attributes. + // + // If the function signature changes, likely its corresponding .pyi type + // hinting and definition should also change. + // LINT.IfChange + virtual ExportedModel RunCalibration( + absl::string_view saved_model_path, const ExportedModel& exported_model, + const QuantizationOptions& quantization_options, + pybind11::object representative_dataset) const = 0; + // LINT.ThenChange( + // pywrap_function_lib.pyi:run_calibration, + // py_function_lib.py:run_calibration, + // ) }; } // namespace tensorflow::quantization diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py index 1051d9504e9fcb..1abf80417d869b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.py @@ -13,17 +13,33 @@ # limitations under the License. 
# ============================================================================== """Defines a wrapper class for overridden python method definitions.""" +from collections.abc import Callable, Collection, Mapping, Sequence from typing import Optional import uuid from absl import logging from tensorflow.compiler.mlir.quantization.tensorflow import exported_model_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_algorithm +from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import pywrap_calibration from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_function_lib +from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model +from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.client import session +from tensorflow.python.eager import context +from tensorflow.python.eager import wrap_function +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_conversion from tensorflow.python.lib.io import file_io +from tensorflow.python.saved_model import load +from tensorflow.python.saved_model import loader_impl +from tensorflow.python.trackable import autotrackable +from tensorflow.python.types import core # Name of the saved model assets directory. _ASSETS_DIR = 'assets' @@ -80,6 +96,474 @@ def _copy_assets(src_path: str, dst_path: str) -> None: ) +def _validate_representative_dataset( + representative_dataset: rd.RepresentativeDatasetOrMapping, + signature_keys: Collection[str], +) -> None: + """Validates the representative dataset, based on the signature keys. + + Representative dataset can be provided in two different forms: a single + instance of `RepresentativeDataset` or a map of signature key to the + corresponding `RepresentativeDataset`. These have a relationship with + `signature_keys`. + + This function validates the following conditions: + * If `len(signature_keys) > 1`, then `representative_dataset` should be a + mapping where the keys exactly match the elements in `signature_keys`. + * If `len(signature_keys) == 1`, then both a mapping and a single instance of + `RepresentativeDataset` are allowed. + * This function also assumes `len(signature_keys) > 0`. + + Args: + representative_dataset: A `RepresentativeDataset` or a map of string to + `RepresentativeDataset` to be validated. + signature_keys: A collection of strings that contains the signature keys, + each identifying a `SignatureDef`. + + Raises: + ValueError: Iff `representative_dataset` does not satisfy the conditions + above. + """ + if isinstance(representative_dataset, Mapping): + if set(signature_keys) != set(representative_dataset.keys()): + raise ValueError( + 'The signature keys and the keys of representative dataset map ' + f'do not match. Signature keys: {set(signature_keys)}, ' + f'representative dataset map: {set(representative_dataset.keys())}.' + ) + else: + if len(signature_keys) > 1: + raise ValueError( + 'Representative dataset is not a mapping ' + f'(got: {type(representative_dataset)}), ' + 'but there is more than one signature key provided. 
' + 'Please provide a map of {signature_key -> dataset} ' + 'with more than one signature key.' + ) + + +def _replace_tensors_by_numpy_ndarrays( + repr_ds_map: rd.RepresentativeDatasetMapping, +) -> None: + """Replaces tf.Tensors by their evaluated numpy arrays. + + This assumes that tf.Tensors in representative samples are created in the + default Graph. It will raise an error if tensors are created in a different + graph. + + Args: + repr_ds_map: SignatureDef key -> RepresentativeDataset mapping. + """ + with session.Session() as sess: + for signature_def_key in repr_ds_map: + # Replaces the dataset with a new dataset where tf.Tensors are replaced + # by their evaluated values. + ds = repr_ds_map[signature_def_key] + repr_ds_map[signature_def_key] = rd.replace_tensors_by_numpy_ndarrays( + ds, sess + ) + + +def _create_sample_validator( + expected_input_keys: Collection[str], +) -> Callable[[rd.RepresentativeSample], rd.RepresentativeSample]: + """Creates a validator function for a representative sample. + + Args: + expected_input_keys: Input keys (keyword argument names) that the function + the sample will be used for is expecting to receive. + + Returns: + A callable that validates a `RepresentativeSample`. + """ + + def validator( + sample: rd.RepresentativeSample, + ) -> rd.RepresentativeSample: + """Validates a single instance of representative sample. + + This provides a simple check for `sample` that this is a mapping of + {input_key: input_value}. + + Args: + sample: A `RepresentativeSample` to validate. + + Returns: + `sample` iff it is valid. + + Raises: + ValueError: iff the sample isn't an instance of `Mapping`. + KeyError: iff the sample does not have the set of input keys that match + the input keys of the function. + """ + if not isinstance(sample, Mapping): + raise ValueError( + 'Invalid representative sample type. Provide a mapping ' + '(usually a dict) of {input_key: input_value}. ' + f'Got type: {type(sample)} instead.' + ) + + if set(sample.keys()) != expected_input_keys: + raise KeyError( + 'Invalid input keys for representative sample. The function expects ' + f'input keys of: {set(expected_input_keys)}. ' + f'Got: {set(sample.keys())}. Please provide correct input keys for ' + 'representative samples.' + ) + + return sample + + return validator + + +# TODO(b/249918070): Implement a progress bar. +def _log_sample_num_for_calibration( + representative_dataset: rd.RepresentativeDataset, +) -> rd.RepresentativeDataset: + """Logs the sample number for calibration. + + If in debug logging level, the "sample number / total num samples" is logged + for every 5 iterations. + + This is often useful when tracking the progress of the calibration step which + is often slow and may look stale if there's no logs being printed. + + Args: + representative_dataset: The representative dataset. + + Yields: + The representative samples from `representative_dataset` without any + modification. + """ + num_samples: Optional[int] = rd.get_num_samples(representative_dataset) + if num_samples is None: + total_num_samples = '?' + logging.info('Representative dataset size unknown.') + else: + total_num_samples = str(num_samples) + logging.info('Using representative dataset of size: %s', total_num_samples) + + sample_num = 0 + for sample in representative_dataset: + sample_num += 1 + + # Log the sample number for every 5 iterations. 
+ logging.log_every_n( + logging.DEBUG, + 'Running representative sample for calibration: %d / %s', + 5, + sample_num, + total_num_samples, + ) + yield sample + + logging.info( + 'Running representative samples complete: %d / %s', + sample_num, + total_num_samples, + ) + + +def _run_function_for_calibration_graph_mode( + sess: session.Session, + signature_def: meta_graph_pb2.SignatureDef, + representative_dataset: rd.RepresentativeDataset, +) -> None: + """Runs the representative dataset through a function for calibration. + + NOTE: This is intended to be run in graph mode (TF1). + + The function is identified by the SignatureDef. + + Args: + sess: The Session object to run the function in. + signature_def: A SignatureDef that identifies a function by specifying the + inputs and outputs. + representative_dataset: The representative dataset to run through the + function. + """ + output_tensor_names = [ + output_tensor_info.name + for output_tensor_info in signature_def.outputs.values() + ] + + sample_validator = _create_sample_validator( + expected_input_keys=signature_def.inputs.keys() + ) + + for sample in map( + sample_validator, _log_sample_num_for_calibration(representative_dataset) + ): + # Create a mapping from input tensor name to the input tensor value. + # ex) "Placeholder:0" -> [0, 1, 2] + feed_dict = rd.create_feed_dict_from_input_data(sample, signature_def) + sess.run(output_tensor_names, feed_dict=feed_dict) + + +def _run_graph_for_calibration_graph_mode( + model_dir: str, + tags: Collection[str], + representative_dataset_map: rd.RepresentativeDatasetMapping, +) -> None: + """Runs the graph for calibration in graph mode. + + This function assumes _graph mode_ (used when legacy TF1 is used or when eager + mode is explicitly disabled) when running the graph. This step is used in + order to collect the statistics in CustomAggregatorOp for quantization using + the representative dataset for the actual data provided for inference. + + Args: + model_dir: Path to SavedModel directory. + tags: Collection of tags identifying the MetaGraphDef within the SavedModel. + representative_dataset_map: A map where signature keys are mapped to + corresponding representative datasets. + + Raises: + ValueError: When running the function with the representative dataset fails. + """ + # Replace tf.Tensors by numpy ndarrays in order to reuse the samples in a + # different graph when running the calibration. + _replace_tensors_by_numpy_ndarrays(representative_dataset_map) + + # Run the calibration in a new graph to avoid name collision, which could + # happen when the same model is loaded multiple times in the default graph. + with ops.Graph().as_default(), session.Session() as sess: + meta_graph: meta_graph_pb2.MetaGraphDef = loader_impl.load( + sess, tags, export_dir=model_dir + ) + + for signature_key, repr_ds in representative_dataset_map.items(): + sig_def = meta_graph.signature_def[signature_key] + + try: + _run_function_for_calibration_graph_mode( + sess, signature_def=sig_def, representative_dataset=repr_ds + ) + except Exception as ex: + raise ValueError( + 'Failed to run representative dataset through the ' + f'function with the signature key: {signature_key}.' + ) from ex + + +def _convert_values_to_tf_tensors( + sample: rd.RepresentativeSample, +) -> Mapping[str, core.Tensor]: + """Converts TensorLike values of `sample` to Tensors. + + Creates a copy of `sample`, where each value is converted to Tensors + unless it is already a Tensor. + The values are not converted in-place (i.e. 
`sample` is not mutated). + + Args: + sample: A representative sample, which is a map of {name -> tensorlike + value}. + + Returns: + Converted map of {name -> tensor}. + """ + tensor_mapping = {} + for name, tensorlike_value in sample.items(): + if isinstance(tensorlike_value, core.Tensor): + tensor_value = tensorlike_value + else: + tensor_value = tensor_conversion.convert_to_tensor_v2_with_dispatch( + tensorlike_value + ) + + tensor_mapping[name] = tensor_value + + return tensor_mapping + + +def _run_function_for_calibration_eager_mode( + func: wrap_function.WrappedFunction, + representative_dataset: rd.RepresentativeDataset, +) -> None: + """Runs the representative dataset through a function for calibration. + + NOTE: This is intended to be run in eager mode (TF2). + + Args: + func: The function to run the representative samples through. + representative_dataset: Representative dataset used for calibration. The + input keys and input values of the representative samples should match the + keyword arguments of `func`. + """ + _, keyword_args = func.structured_input_signature + sample_validator = _create_sample_validator( + expected_input_keys=keyword_args.keys() + ) + + for sample in map( + sample_validator, _log_sample_num_for_calibration(representative_dataset) + ): + # Convert any non-Tensor values from the sample to Tensors. + # This conversion is required because the model saved in `model_dir` is + # saved using TF1 SavedModelBuilder, which doesn't save the + # SavedObjectGraph. + # TODO(b/236795224): Remove the need for this conversion by keeping the + # FunctionSpec (object graph) in the SavedModel. Related: b/213406917. + func_kwargs = _convert_values_to_tf_tensors(sample) + func(**func_kwargs) + + +def _run_graph_for_calibration_eager_mode( + model_dir: str, + tags: Collection[str], + representative_dataset_map: rd.RepresentativeDatasetMapping, +) -> None: + """Runs the graph for calibration in eager mode. + + This function assumes _eager mode_ (enabled in TF2 by default) when running + the graph. This step is used in order to collect the statistics in + CustomAggregatorOp for quantization using the representative dataset for the + actual data provided for inference. + + Args: + model_dir: Path to SavedModel directory. + tags: Collection of tags identifying the MetaGraphDef within the SavedModel. + representative_dataset_map: A map where signature keys are mapped to + corresponding representative datasets. + + Raises: + ValueError: When running the function with the representative dataset fails. + """ + root: autotrackable.AutoTrackable = load.load(model_dir, tags) + for signature_key, repr_ds in representative_dataset_map.items(): + try: + _run_function_for_calibration_eager_mode( + func=root.signatures[signature_key], representative_dataset=repr_ds + ) + except Exception as ex: + raise ValueError( + 'Failed to run representative dataset through the ' + f'function with the signature key: {signature_key}.' + ) from ex + + +def _run_graph_for_calibration( + float_model_dir: str, + signature_keys: Sequence[str], + tags: Collection[str], + representative_dataset: rd.RepresentativeDatasetOrMapping, + force_graph_mode_calibration: bool, +) -> None: + """Runs the graph for calibration using representative datasets. + + Args: + float_model_dir: Path to the model to calibrate. + signature_keys: Sequence of keys identifying SignatureDef containing inputs + and outputs. + tags: Collection of tags identifying the MetaGraphDef within the SavedModel + to analyze. 
+ representative_dataset: An iterator that returns a dictionary of {input_key: + input_value} or a mapping from signature keys to such iterators. When + `signature_keys` contains more than one signature key, + `representative_datsaet` should be a mapping that maps each signature keys + to the corresponding representative dataset. + force_graph_mode_calibration: If set to true, it forces calibration in graph + model instead of eager mode when the context is in eager mode. + + Raises: + ValueError iff: + * The representative dataset format is invalid. + * It fails to run the functions using the representative datasets. + """ + try: + _validate_representative_dataset(representative_dataset, signature_keys) + except Exception as ex: + raise ValueError('Invalid representative dataset.') from ex + + # If `representative_dataset` is not a mapping, convert to a mapping for the + # following functions to handle representative datasets more conveniently. + representative_dataset_map = representative_dataset + if not isinstance(representative_dataset, Mapping): + # `signature_keys` is guaranteed to have only one element after the + # validation. + representative_dataset_map = {signature_keys[0]: representative_dataset} + + try: + if context.executing_eagerly() and not force_graph_mode_calibration: + logging.info('Calibration step is executed in eager mode.') + _run_graph_for_calibration_eager_mode( + float_model_dir, tags, representative_dataset_map + ) + else: + logging.info('Calibration step is executed in graph mode.') + _run_graph_for_calibration_graph_mode( + float_model_dir, tags, representative_dataset_map + ) + except Exception as ex: + raise ValueError( + 'Failed to run graph for post-training quantization calibration.' + ) from ex + + logging.info('Calibration step complete.') + + +def _get_min_max_from_calibrator( + node_id: bytes, + calib_opts: quantization_options_pb2.CalibrationOptions, +) -> tuple[float, float]: + """Calculate min and max from statistics using calibration options. + + Args: + node_id: bytes of node id. + calib_opts: Calibration options used for calculating min and max. + + Returns: + (min_value, max_value): Min and max calculated using calib_opts. + + Raises: + ValueError: Unsupported calibration method is given. + """ + statistics: calibration_statistics_pb2.CalibrationStatistics = ( + pywrap_calibration.get_statistics_from_calibrator(node_id) + ) + min_value, max_value = calibration_algorithm.get_min_max_value( + statistics, calib_opts + ) + return min_value, max_value + + +def _add_calibration_statistics( + graph_def: graph_pb2.GraphDef, + calib_opts: quantization_options_pb2.CalibrationOptions, +) -> None: + """Adds calibration statistics to the graph def. + + This function must be run after running the graph with a representative + dataset. Retrieves calibration statistics from the global calibrator and adds + them to the corresponding nodes as attributes. + + Args: + graph_def: GraphDef to add calibration statistics to. + calib_opts: Calibration options to calculate min and max. 
+ """ + for function_def in graph_def.library.function: + for node_def in function_def.node_def: + if node_def.op != 'CustomAggregator': + continue + + node_id = node_def.attr['id'].s + try: + min_value, max_value = _get_min_max_from_calibrator(node_id, calib_opts) + pywrap_calibration.clear_data_from_calibrator(node_id) + + node_def.attr['min'].f = min_value + node_def.attr['max'].f = max_value + except ValueError: + logging.warning( + ( + 'CustomAggregator id "%s" from FunctionDef "%s" does not have ' + 'min or max values. Parts of this function are not quantized.' + ), + node_id.decode('utf-8'), + function_def.signature.name, + ) + + class PyFunctionLibrary(pywrap_function_lib.PyFunctionLibrary): """Wrapper class for overridden python method definitions. @@ -92,7 +576,7 @@ def assign_ids_to_custom_aggregator_ops( self, exported_model_serialized: bytes, ) -> bytes: - # LINT.ThenChange(py_function_lib.h:assign_ids_to_custom_aggregator_ops) + # LINT.ThenChange(py_function_lib.h:assign_ids_to_custom_aggregator_ops) """Assigns UUIDs to each CustomAggregator op find in the graph def. Args: @@ -123,7 +607,7 @@ def save_exported_model( tags: set[str], serialized_signature_def_map: dict[str, bytes], ) -> None: - # LINT.ThenChange(py_function_lib.h:save_exported_model) + # LINT.ThenChange(py_function_lib.h:save_exported_model) """Saves `ExportedModel` to `dst_saved_model_path` as a SavedModel. Args: @@ -158,3 +642,55 @@ def save_exported_model( ) _copy_assets(src_saved_model_path, dst_saved_model_path) + + # TODO: b/311097139 - Extract calibration related functions into a separate + # file. + # LINT.IfChange(run_calibration) + def run_calibration( + self, + saved_model_path: str, + exported_model_serialized: bytes, + quantization_options_serialized: bytes, + representative_dataset: rd.RepresentativeDatasetOrMapping, + ) -> bytes: + # LINT.ThenChange(py_function_lib.h:run_calibration) + """Runs calibration and adds calibration statistics to exported model. + + Args: + saved_model_path: Path to the SavedModel to run calibration. + exported_model_serialized: Serialized `ExportedModel` that corresponds to + the SavedModel at `saved_model_path`. + quantization_options_serialized: Serialized `QuantizationOptions`. + representative_dataset: Representative dataset to run calibration. + + Returns: + Updated exported model (serialized) where the collected calibration + statistics are added to `CustomerAggregator` nodes at the `min` and `max` + attributes. + """ + quantization_options = ( + quantization_options_pb2.QuantizationOptions.FromString( + quantization_options_serialized + ) + ) + + # Uses the representative dataset to collect statistics for calibration. + # After this operation, min & max values are stored separately in a global + # CalibratorSingleton instance. 
+ _run_graph_for_calibration( + saved_model_path, + quantization_options.signature_keys, + quantization_options.tags, + representative_dataset, + quantization_options.force_graph_mode_calibration, + ) + + exported_model = exported_model_pb2.ExportedModel.FromString( + exported_model_serialized + ) + _add_calibration_statistics( + exported_model.graph_def, + quantization_options.calibration_options, + ) + + return exported_model.SerializeToString() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc index 48464cef4341b5..24857792653c73 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc @@ -23,14 +23,17 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" -#include "tensorflow/python/lib/core/pybind11_lib.h" + +namespace py = ::pybind11; namespace { using ::tensorflow::SignatureDef; using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PyFunctionLibrary; +using ::tensorflow::quantization::QuantizationOptions; // A "trampoline" class that redirects virtual function calls to the python // implementation. @@ -57,6 +60,16 @@ class PyFunctionLibraryTrampoline : public PyFunctionLibrary { dst_saved_model_path, exported_model, src_saved_model_path, tags, signature_def_map); } + + ExportedModel RunCalibration( + const absl::string_view saved_model_path, + const ExportedModel& exported_model, + const QuantizationOptions& quantization_options, + const py::object representative_dataset) const override { + PYBIND11_OVERRIDE_PURE(ExportedModel, PyFunctionLibrary, run_calibration, + saved_model_path, exported_model, + quantization_options, representative_dataset); + } }; } // namespace @@ -72,5 +85,9 @@ PYBIND11_MODULE(pywrap_function_lib, m) { py::arg("dst_saved_model_path"), py::arg("exported_model_serialized"), py::arg("src_saved_model_path"), py::arg("tags"), - py::arg("serialized_signature_def_map")); + py::arg("serialized_signature_def_map")) + .def("run_calibration", &PyFunctionLibrary::RunCalibration, + py::arg("saved_model_path"), py::arg("exported_model_serialized"), + py::arg("quantization_options_serialized"), + py::arg("representative_dataset")); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi index 0e0586983b096d..f9b358be53eeee 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.pyi @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from typing import Any + class PyFunctionLibrary: # LINT.IfChange(assign_ids_to_custom_aggregator_ops) @@ -30,3 +32,13 @@ class PyFunctionLibrary: serialized_signature_def_map: dict[str, bytes], ) -> None: ... 
# LINT.ThenChange() + + # LINT.IfChange(run_calibration) + def run_calibration( + self, + saved_model_path: str, + exported_model_serialized: bytes, + quantization_options_serialized: bytes, + representative_dataset: Any, + ) -> bytes: ... + # LINT.ThenChange() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index 9d05447a4c4bc1..945468f9a178a0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -230,7 +230,8 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { const absl::flat_hash_map& signature_def_map, const absl::flat_hash_map& function_aliases, - const PyFunctionLibrary& py_function_library) + const PyFunctionLibrary& py_function_library, + py::object representative_dataset) -> absl::StatusOr> { // LINT.ThenChange(pywrap_quantize_model.pyi:quantize_ptq_model_pre_calibration) std::unordered_set tags; @@ -252,7 +253,12 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { precalibrated_saved_model_dir, exported_model_ids_assigned, saved_model_path, tags, signature_def_map); - return std::make_pair(exported_model_ids_assigned, + const ExportedModel calibrated_exported_model = + py_function_library.RunCalibration( + precalibrated_saved_model_dir, exported_model_ids_assigned, + quantization_options, representative_dataset); + + return std::make_pair(calibrated_exported_model, precalibrated_saved_model_dir); }, R"pbdoc( @@ -273,7 +279,7 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { py::arg("saved_model_path"), py::arg("quantization_options_serialized"), py::kw_only(), py::arg("signature_keys"), py::arg("signature_def_map_serialized"), py::arg("function_aliases"), - py::arg("py_function_library")); + py::arg("py_function_library"), py::arg("representative_dataset")); m.def( // If the function signature changes, likely its corresponding .pyi type diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.pyi b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.pyi index 6f42691c6cdae4..6ffcdefda905e8 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.pyi +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.pyi @@ -16,6 +16,7 @@ from typing import Any from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib +from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd # LINT.IfChange(quantize_qat_model) def quantize_qat_model( @@ -67,6 +68,7 @@ def quantize_ptq_model_pre_calibration( signature_def_map_serialized: dict[str, bytes], function_aliases: dict[str, str], py_function_library: py_function_lib.PyFunctionLibrary, + representative_dataset: rd.RepresentativeDatasetOrMapping, ) -> tuple[bytes, str]: ... 
# LINT.ThenChange() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index 715eef9cf41837..125f51e0d0b9e4 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -14,9 +14,8 @@ # ============================================================================== """Defines TF Quantization API from SavedModel to SavedModel.""" -import collections.abc import tempfile -from typing import Callable, Collection, Mapping, Optional, Sequence +from typing import Mapping, Optional from absl import logging @@ -31,18 +30,12 @@ from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 -from tensorflow.python.client import session -from tensorflow.python.eager import context -from tensorflow.python.eager import wrap_function -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_conversion from tensorflow.python.lib.io import file_io from tensorflow.python.saved_model import load as saved_model_load from tensorflow.python.saved_model import loader_impl as saved_model_loader from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants from tensorflow.python.trackable import autotrackable -from tensorflow.python.types import core from tensorflow.python.util import tf_export # Type aliases for quant_opts_pb2 messages. @@ -90,60 +83,6 @@ def _is_qat_saved_model(saved_model_path: str): return False -def _create_sample_validator( - expected_input_keys: Collection[str], -) -> Callable[ - [repr_dataset.RepresentativeSample], repr_dataset.RepresentativeSample -]: - """Creates a validator function for a representative sample. - - Args: - expected_input_keys: Input keys (keyword argument names) that the function - the sample will be used for is expecting to receive. - - Returns: - A callable that validates a `RepresentativeSample`. - """ - - def validator( - sample: repr_dataset.RepresentativeSample, - ) -> repr_dataset.RepresentativeSample: - """Validates a single instance of representative sample. - - This provides a simple check for `sample` that this is a mapping of - {input_key: input_value}. - - Args: - sample: A `RepresentativeSample` to validate. - - Returns: - `sample` iff it is valid. - - Raises: - ValueError: iff the sample isn't an instance of `Mapping`. - KeyError: iff the sample does not have the set of input keys that match - the input keys of the function. - """ - if not isinstance(sample, collections.abc.Mapping): - raise ValueError( - 'Invalid representative sample type. Provide a mapping ' - '(usually a dict) of {input_key: input_value}. ' - f'Got type: {type(sample)} instead.' - ) - - if set(sample.keys()) != expected_input_keys: - raise KeyError( - 'Invalid input keys for representative sample. The function expects ' - f'input keys of: {set(expected_input_keys)}. ' - f'Got: {set(sample.keys())}. Please provide correct input keys for ' - 'representative samples.' 
- ) - - return sample - - return validator - - def _serialize_signature_def_map( signature_def_map: _SignatureDefMap, ) -> dict[str, bytes]: @@ -162,364 +101,6 @@ def _serialize_signature_def_map( return signature_def_map_serialized -def _validate_representative_dataset( - representative_dataset: repr_dataset.RepresentativeDatasetOrMapping, - signature_keys: Collection[str], -) -> None: - """Validates the representative dataset, based on the signature keys. - - Representative dataset can be provided in two different forms: a single - instance of `RepresentativeDataset` or a map of signature key to the - corresponding `RepresentativeDataset`. These have a relationship with - `signature_keys`. - - This function validates the following conditions: - * If `len(signature_keys) > 1`, then `representative_dataset` should be a - mapping where the keys exactly match the elements in `signature_keys`. - * If `len(signature_keys) == 1`, then both a mapping and a single instance of - `RepresentativeDataset` are allowed. - * This function also assumes `len(signature_keys) > 0`. - - Args: - representative_dataset: A `RepresentativeDataset` or a map of string to - `RepresentativeDataset` to be validated. - signature_keys: A collection of strings that contains the signature keys, - each identifying a `SignatureDef`. - - Raises: - ValueError: Iff `representative_dataset` does not satisfy the conditions - above. - """ - if isinstance(representative_dataset, collections.abc.Mapping): - if set(signature_keys) != set(representative_dataset.keys()): - raise ValueError( - 'The signature keys and the keys of representative dataset map ' - f'do not match. Signature keys: {set(signature_keys)}, ' - f'representative dataset map: {set(representative_dataset.keys())}.' - ) - else: - if len(signature_keys) > 1: - raise ValueError( - 'Representative dataset is not a mapping ' - f'(got: {type(representative_dataset)}), ' - 'but there is more than one signature key provided. ' - 'Please provide a map of {signature_key -> dataset} ' - 'with more than one signature key.' - ) - - -def _convert_values_to_tf_tensors( - sample: repr_dataset.RepresentativeSample, -) -> Mapping[str, core.Tensor]: - """Converts TensorLike values of `sample` to Tensors. - - Creates a copy of `sample`, where each value is converted to Tensors - unless it is already a Tensor. - The values are not converted in-place (i.e. `sample` is not mutated). - - Args: - sample: A representative sample, which is a map of {name -> tensorlike - value}. - - Returns: - Converted map of {name -> tensor}. - """ - tensor_mapping = {} - for name, tensorlike_value in sample.items(): - if isinstance(tensorlike_value, core.Tensor): - tensor_value = tensorlike_value - else: - tensor_value = tensor_conversion.convert_to_tensor_v2_with_dispatch( - tensorlike_value - ) - - tensor_mapping[name] = tensor_value - - return tensor_mapping - - -# TODO(b/249918070): Implement a progress bar. -def _log_sample_num_for_calibration( - representative_dataset: repr_dataset.RepresentativeDataset, -) -> repr_dataset.RepresentativeDataset: - """Logs the sample number for calibration. - - If in debug logging level, the "sample number / total num samples" is logged - for every 5 iterations. - - This is often useful when tracking the progress of the calibration step which - is often slow and may look stale if there's no logs being printed. - - Args: - representative_dataset: The representative dataset. 
- - Yields: - The representative samples from `representative_dataset` without any - modification. - """ - num_samples: Optional[int] = repr_dataset.get_num_samples( - representative_dataset - ) - if num_samples is None: - total_num_samples = '?' - logging.info('Representative dataset size unknown.') - else: - total_num_samples = str(num_samples) - logging.info('Using representative dataset of size: %s', total_num_samples) - - sample_num = 0 - for sample in representative_dataset: - sample_num += 1 - - # Log the sample number for every 5 iterations. - logging.log_every_n( - logging.DEBUG, - 'Running representative sample for calibration: %d / %s', - 5, - sample_num, - total_num_samples, - ) - yield sample - - logging.info( - 'Running representative samples complete: %d / %s', - sample_num, - total_num_samples, - ) - - -def _run_function_for_calibration_graph_mode( - sess: session.Session, - signature_def: meta_graph_pb2.SignatureDef, - representative_dataset: repr_dataset.RepresentativeDataset, -) -> None: - """Runs the representative dataset through a function for calibration. - - NOTE: This is intended to be run in graph mode (TF1). - - The function is identified by the SignatureDef. - - Args: - sess: The Session object to run the function in. - signature_def: A SignatureDef that identifies a function by specifying the - inputs and outputs. - representative_dataset: The representative dataset to run through the - function. - """ - output_tensor_names = [ - output_tensor_info.name - for output_tensor_info in signature_def.outputs.values() - ] - - sample_validator = _create_sample_validator( - expected_input_keys=signature_def.inputs.keys() - ) - - for sample in map( - sample_validator, _log_sample_num_for_calibration(representative_dataset) - ): - # Create a mapping from input tensor name to the input tensor value. - # ex) "Placeholder:0" -> [0, 1, 2] - feed_dict = repr_dataset.create_feed_dict_from_input_data( - sample, signature_def - ) - sess.run(output_tensor_names, feed_dict=feed_dict) - - -def _replace_tensors_by_numpy_ndarrays( - repr_ds_map: repr_dataset.RepresentativeDatasetMapping, -) -> None: - """Replaces tf.Tensors by their evaluated numpy arrays. - - This assumes that tf.Tensors in representative samples are created in the - default Graph. It will raise an error if tensors are created in a different - graph. - - Args: - repr_ds_map: SignatureDef key -> RepresentativeDataset mapping. - """ - with session.Session() as sess: - for signature_def_key in repr_ds_map: - # Replaces the dataset with a new dataset where tf.Tensors are replaced - # by their evaluated values. - ds = repr_ds_map[signature_def_key] - repr_ds_map[signature_def_key] = ( - repr_dataset.replace_tensors_by_numpy_ndarrays(ds, sess) - ) - - -def _run_graph_for_calibration_graph_mode( - model_dir: str, - tags: Collection[str], - representative_dataset_map: repr_dataset.RepresentativeDatasetMapping, -) -> None: - """Runs the graph for calibration in graph mode. - - This function assumes _graph mode_ (used when legacy TF1 is used or when eager - mode is explicitly disabled) when running the graph. This step is used in - order to collect the statistics in CustomAggregatorOp for quantization using - the representative dataset for the actual data provided for inference. - - Args: - model_dir: Path to SavedModel directory. - tags: Collection of tags identifying the MetaGraphDef within the SavedModel. - representative_dataset_map: A map where signature keys are mapped to - corresponding representative datasets. 
- - Raises: - ValueError: When running the function with the representative dataset fails. - """ - # Replace tf.Tensors by numpy ndarrays in order to reuse the samples in a - # different graph when running the calibration. - _replace_tensors_by_numpy_ndarrays(representative_dataset_map) - - # Run the calibration in a new graph to avoid name collision, which could - # happen when the same model is loaded multiple times in the default graph. - with ops.Graph().as_default(), session.Session() as sess: - meta_graph: meta_graph_pb2.MetaGraphDef = saved_model_loader.load( - sess, tags, export_dir=model_dir - ) - - for signature_key, repr_ds in representative_dataset_map.items(): - sig_def = meta_graph.signature_def[signature_key] - - try: - _run_function_for_calibration_graph_mode( - sess, signature_def=sig_def, representative_dataset=repr_ds - ) - except Exception as ex: - raise ValueError( - 'Failed to run representative dataset through the ' - f'function with the signature key: {signature_key}.' - ) from ex - - -def _run_function_for_calibration_eager_mode( - func: wrap_function.WrappedFunction, - representative_dataset: repr_dataset.RepresentativeDataset, -) -> None: - """Runs the representative dataset through a function for calibration. - - NOTE: This is intended to be run in eager mode (TF2). - - Args: - func: The function to run the representative samples through. - representative_dataset: Representative dataset used for calibration. The - input keys and input values of the representative samples should match the - keyword arguments of `func`. - """ - _, keyword_args = func.structured_input_signature - sample_validator = _create_sample_validator( - expected_input_keys=keyword_args.keys() - ) - - for sample in map( - sample_validator, _log_sample_num_for_calibration(representative_dataset) - ): - # Convert any non-Tensor values from the sample to Tensors. - # This conversion is required because the model saved in `model_dir` is - # saved using TF1 SavedModelBuilder, which doesn't save the - # SavedObjectGraph. - # TODO(b/236795224): Remove the need for this conversion by keeping the - # FunctionSpec (object graph) in the SavedModel. Related: b/213406917. - func_kwargs = _convert_values_to_tf_tensors(sample) - func(**func_kwargs) - - -def _run_graph_for_calibration_eager_mode( - model_dir: str, - tags: Collection[str], - representative_dataset_map: repr_dataset.RepresentativeDatasetMapping, -) -> None: - """Runs the graph for calibration in eager mode. - - This function assumes _eager mode_ (enabled in TF2 by default) when running - the graph. This step is used in order to collect the statistics in - CustomAggregatorOp for quantization using the representative dataset for the - actual data provided for inference. - - Args: - model_dir: Path to SavedModel directory. - tags: Collection of tags identifying the MetaGraphDef within the SavedModel. - representative_dataset_map: A map where signature keys are mapped to - corresponding representative datasets. - - Raises: - ValueError: When running the function with the representative dataset fails. - """ - root: autotrackable.AutoTrackable = saved_model_load.load(model_dir, tags) - for signature_key, repr_ds in representative_dataset_map.items(): - try: - _run_function_for_calibration_eager_mode( - func=root.signatures[signature_key], representative_dataset=repr_ds - ) - except Exception as ex: - raise ValueError( - 'Failed to run representative dataset through the ' - f'function with the signature key: {signature_key}.' 
- ) from ex - - -def _run_graph_for_calibration( - float_model_dir: str, - signature_keys: Sequence[str], - tags: Collection[str], - representative_dataset: repr_dataset.RepresentativeDatasetOrMapping, - force_graph_mode_calibration: bool, -) -> None: - """Runs the graph for calibration using representative datasets. - - Args: - float_model_dir: Path to the model to calibrate. - signature_keys: Sequence of keys identifying SignatureDef containing inputs - and outputs. - tags: Collection of tags identifying the MetaGraphDef within the SavedModel - to analyze. - representative_dataset: An iterator that returns a dictionary of {input_key: - input_value} or a mapping from signature keys to such iterators. When - `signature_keys` contains more than one signature key, - `representative_datsaet` should be a mapping that maps each signature keys - to the corresponding representative dataset. - force_graph_mode_calibration: If set to true, it forces calibration in graph - model instead of eager mode when the context is in eager mode. - - Raises: - ValueError iff: - * The representative dataset format is invalid. - * It fails to run the functions using the representative datasets. - """ - try: - _validate_representative_dataset(representative_dataset, signature_keys) - except Exception as ex: - raise ValueError('Invalid representative dataset.') from ex - - # If `representative_dataset` is not a mapping, convert to a mapping for the - # following functions to handle representative datasets more conveniently. - representative_dataset_map = representative_dataset - if not isinstance(representative_dataset, collections.abc.Mapping): - # `signature_keys` is guaranteed to have only one element after the - # validation. - representative_dataset_map = {signature_keys[0]: representative_dataset} - - try: - if context.executing_eagerly() and not force_graph_mode_calibration: - logging.info('Calibration step is executed in eager mode.') - _run_graph_for_calibration_eager_mode( - float_model_dir, tags, representative_dataset_map - ) - else: - logging.info('Calibration step is executed in graph mode.') - _run_graph_for_calibration_graph_mode( - float_model_dir, tags, representative_dataset_map - ) - except Exception as ex: - raise ValueError( - 'Failed to run graph for post-training quantization calibration.' - ) from ex - - logging.info('Calibration step complete.') - - def _run_static_range_qat( src_saved_model_path: str, dst_saved_model_path: str, @@ -581,43 +162,6 @@ def _get_min_max_from_calibrator( return min_value, max_value -def _add_calibration_statistics( - graph_def: graph_pb2.GraphDef, - calib_opts: quant_opts_pb2.CalibrationOptions, -) -> None: - """Adds calibration statistics to the graph def. - - This function must be run after running the graph with a representative - dataset. Retrieves calibration statistics from the global calibrator and adds - them to the corresponding nodes as attributes. - - Args: - graph_def: GraphDef to add calibration statistics to. - calib_opts: Calibration options to calculate min and max. 
- """ - for function_def in graph_def.library.function: - for node_def in function_def.node_def: - if node_def.op != 'CustomAggregator': - continue - - node_id = node_def.attr['id'].s - try: - min_value, max_value = _get_min_max_from_calibrator(node_id, calib_opts) - pywrap_calibration.clear_data_from_calibrator(node_id) - - node_def.attr['min'].f = min_value - node_def.attr['max'].f = max_value - except ValueError: - logging.warning( - ( - 'CustomAggregator id "%s" from FunctionDef "%s" does not have ' - 'min or max values. Parts of this function are not quantized.' - ), - node_id.decode('utf-8'), - function_def.signature.name, - ) - - def _enable_dump_tensor(graph_def: graph_pb2.GraphDef) -> None: """Enable DumpTensor in the graph def. @@ -697,6 +241,7 @@ def _run_static_range_ptq( signature_def_map_serialized=signature_def_map_serialized, function_aliases=dict(function_aliases), py_function_library=py_function_lib.PyFunctionLibrary(), + representative_dataset=representative_dataset, ) ) exported_model = exported_model_pb2.ExportedModel.FromString( @@ -704,20 +249,6 @@ def _run_static_range_ptq( ) graph_def = exported_model.graph_def - # Uses the representative dataset to collect statistics for calibration. - # Handles the graph mode execution separately in case TF2 is disabled or - # eager execution is disabled. The min & max values are stored separately - # in a global CalibratorSingleton instance. - _run_graph_for_calibration( - pre_calib_output_model_path, - quant_opts.signature_keys, - quant_opts.tags, - representative_dataset, - quant_opts.force_graph_mode_calibration, - ) - - _add_calibration_statistics(graph_def, quant_opts.calibration_options) - py_function_library = py_function_lib.PyFunctionLibrary() if quant_opts.HasField('debugger_options'): # Since DumpTensor was disabled by default, we need to enable them. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h index ef0dfee6394d22..35d7458e614775 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h @@ -91,14 +91,15 @@ struct type_caster { } }; -// Python -> cpp conversion for `QuantizationOptions`. Accepts a serialized -// protobuf string and deserializes into an instance of `QuantizationOptions`. +// Handles type conversion for `QuantizationOptions`. template <> struct type_caster { public: PYBIND11_TYPE_CASTER(tensorflow::quantization::QuantizationOptions, const_name("QuantizationOptions")); + // Python -> C++. Converts a serialized protobuf string and deserializes into + // an instance of `QuantizationOptions`. bool load(handle src, const bool convert) { auto caster = make_caster(); // The user should have passed a valid python string. @@ -112,6 +113,14 @@ struct type_caster { // NOLINTNEXTLINE: Explicit std::string conversion required for OSS. return value.ParseFromString(std::string(quantization_opts_serialized)); } + + // C++ -> Python. Constructs a `bytes` object after serializing `src`. + static handle cast(const tensorflow::quantization::QuantizationOptions& src, + return_value_policy policy, handle parent) { + // release() prevents the reference count from decreasing upon the + // destruction of py::bytes and returns a raw python object handle. 
+ return py::bytes(internal::Serialize(src)).release(); + } }; template <> From d187e69335d3a50ba3a18ee988f9ef7a78737706 Mon Sep 17 00:00:00 2001 From: Deqiang Chen Date: Tue, 14 Nov 2023 21:06:14 -0800 Subject: [PATCH 101/391] Add in-memory cache for ifrt executable PiperOrigin-RevId: 582531337 --- tensorflow/core/tfrt/ifrt/BUILD | 36 ++++ .../core/tfrt/ifrt/ifrt_serving_executable.cc | 82 +++++++-- .../core/tfrt/ifrt/ifrt_serving_executable.h | 36 +++- .../tfrt/ifrt/ifrt_serving_executable_test.cc | 162 ++++++++++++++++++ tensorflow/core/tfrt/ifrt/testdata/BUILD | 12 ++ .../core/tfrt/ifrt/testdata/executable.mlir | 6 + 6 files changed, 316 insertions(+), 18 deletions(-) create mode 100644 tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc create mode 100644 tensorflow/core/tfrt/ifrt/testdata/BUILD create mode 100644 tensorflow/core/tfrt/ifrt/testdata/executable.mlir diff --git a/tensorflow/core/tfrt/ifrt/BUILD b/tensorflow/core/tfrt/ifrt/BUILD index 5fdc9b3ad8aa10..7aa13d0926a474 100644 --- a/tensorflow/core/tfrt/ifrt/BUILD +++ b/tensorflow/core/tfrt/ifrt/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ @@ -22,11 +24,14 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", "@local_tsl//tsl/concurrency:ref_count", @@ -66,3 +71,34 @@ cc_library( "@local_xla//xla/python/ifrt", ], ) + +tf_cc_test( + name = "ifrt_serving_executable_test", + srcs = [ + "ifrt_serving_executable_test.cc", + ], + data = [ + "//tensorflow/core/tfrt/ifrt/testdata", + ], + tags = ["no_oss"], + deps = [ + ":ifrt_serving_executable", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core/framework:tensor", + "//tensorflow/core/framework:types_proto_cc", + "//tensorflow/core/platform:resource_loader", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/python/ifrt", + "@local_xla//xla/python/ifrt:test_util", + "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", + ], +) diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc index a03958bcbd644d..649e813a511ddd 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc @@ -1,4 +1,3 @@ - /* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +24,9 @@ limitations under the License. 
#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -33,6 +34,7 @@ limitations under the License. #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/future.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" @@ -84,24 +86,70 @@ IfrtServingExecutable::ConvertTensorToArray(const tensorflow::Tensor& tensor) { return single_array; } -absl::StatusOr> IfrtServingExecutable::Execute( +xla::ifrt::Future>> +IfrtServingExecutable::LookUpOrCreateExecutable( absl::Span inputs) { - // TODO(b/304839793): Build cache based on tensorshape etc - if (!ifrt_executable_) { - LOG(INFO) << "Cache missed. Building executable"; - - TF_ASSIGN_OR_RETURN(auto mlir_hlo_module, - CompileTfToHlo(*module_, inputs, signature_name(), - ifrt_client_->GetDefaultCompiler(), - shape_representation_fn_)); - - TF_ASSIGN_OR_RETURN( - ifrt_executable_, - ifrt_client_->GetDefaultCompiler()->Compile( - std::make_unique(mlir_hlo_module.get()), - std::make_unique())); + std::vector input_shapes; + for (const auto& tensor : inputs) { + input_shapes.push_back(tensor.shape()); + } + Key key(input_shapes); + + xla::ifrt::Promise< + absl::StatusOr>> + promise; + xla::ifrt::Future< + absl::StatusOr>> + future; + + { + absl::MutexLock lock(&mutex_); + + const auto it = ifrt_executables_.find(key); + if (it != ifrt_executables_.end()) { + return it->second; + } + + // Only create promise and future when cache missed. + promise = xla::ifrt::Future>>::CreatePromise(); + future = xla::ifrt::Future< + absl::StatusOr>>(promise); + + ifrt_executables_.emplace(key, future); + } + + LOG(INFO) << "Cache missed. Building executable"; + + absl::StatusOr> mlir_hlo_module = + CompileTfToHlo(*module_, inputs, signature_name(), + ifrt_client_->GetDefaultCompiler(), + shape_representation_fn_); + if (!mlir_hlo_module.ok()) { + promise.Set(mlir_hlo_module.status()); + return future; + } + + absl::StatusOr> ifrt_executable = + ifrt_client_->GetDefaultCompiler()->Compile( + std::make_unique(mlir_hlo_module->get()), + std::make_unique()); + if (!ifrt_executable.ok()) { + promise.Set(ifrt_executable.status()); + return future; } + promise.Set(std::shared_ptr( + std::move(*ifrt_executable))); + return future; +} + +absl::StatusOr> IfrtServingExecutable::Execute( + absl::Span inputs) { + TF_ASSIGN_OR_RETURN( + std::shared_ptr ifrt_executable, + LookUpOrCreateExecutable(inputs).Await()); + std::vector> args; args.reserve(inputs.size()); for (auto& tensor : inputs) { @@ -110,7 +158,7 @@ absl::StatusOr> IfrtServingExecutable::Execute( } TF_ASSIGN_OR_RETURN(auto execution_result, - ifrt_executable_->Execute( + ifrt_executable->Execute( absl::MakeSpan(args), /*options=*/{.untuple_result = true}, std::nullopt)); diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h index 64ce6580f8cfab..9b1d86cbcbbfd7 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h @@ -21,9 +21,12 @@ limitations under the License. 
#include #include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -32,6 +35,7 @@ limitations under the License. #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tsl/concurrency/ref_count.h" @@ -65,7 +69,30 @@ class IfrtServingExecutable { absl::StatusOr> Execute( absl::Span inputs); + int num_executables() const { + absl::MutexLock lock(&mutex_); + return ifrt_executables_.size(); + } + private: + // In memory cache key. + struct Key { + std::vector input_shapes; + template + friend H AbslHashValue(H h, const Key& key) { + for (const auto& shape : key.input_shapes) { + for (auto size : shape.dim_sizes()) { + h = H::combine(std::move(h), size); + } + } + return h; + } + + friend bool operator==(const Key& x, const Key& y) { + return x.input_shapes == y.input_shapes; + } + }; + std::string model_name_; std::string signature_name_; @@ -76,10 +103,17 @@ class IfrtServingExecutable { tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_; - std::unique_ptr ifrt_executable_; + mutable absl::Mutex mutex_; + absl::flat_hash_map>>> + ifrt_executables_ ABSL_GUARDED_BY(mutex_); absl::StatusOr> ConvertTensorToArray( const tensorflow::Tensor& tensor); + + xla::ifrt::Future< + absl::StatusOr>> + LookUpOrCreateExecutable(absl::Span inputs); }; } // namespace ifrt_serving diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc new file mode 100644 index 00000000000000..a2de6e9a68e16e --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc @@ -0,0 +1,162 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h" + +#include +#include +#include +#include +#include + +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/test_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/resource_loader.h" +#include "tensorflow/core/platform/test.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { + +TEST(IfrtServingExecutableTest, Basic) { + // Create test input module + constexpr absl::string_view kDataDirectory = + "tensorflow/core/tfrt/ifrt/testdata"; + std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( + absl::StrCat(kDataDirectory, "/executable.mlir")); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + + mlir::MLIRContext context(registry); + + mlir::OwningOpRef mlir_module = + mlir::parseSourceFile(mlir_module_path, &context); + + ASSERT_TRUE(mlir_module); + + // Create contexts required for the compiler execution. + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + + IfrtServingExecutable executable("test", "main", std::move(mlir_module), + client, + tensorflow::IdentityShapeRepresentationFn()); + + tensorflow::Tensor x(tensorflow::DT_INT32, tensorflow::TensorShape({1, 3})); + tensorflow::Tensor y(tensorflow::DT_INT32, tensorflow::TensorShape({3, 1})); + for (int i = 0; i < 3; ++i) { + x.flat()(i) = i + 1; + y.flat()(i) = i + 1; + } + + std::vector inputs{x, y}; + TF_ASSERT_OK_AND_ASSIGN(auto result, + executable.Execute(absl::MakeSpan(inputs))); + + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0].dtype(), tensorflow::DT_INT32); + ASSERT_EQ(result[0].shape(), tensorflow::TensorShape({1, 1})); + ASSERT_EQ(result[0].flat()(0), 14); +} + +TEST(IfrtServingExecutableTest, MultipleShapes) { + // Create test input module + constexpr absl::string_view kDataDirectory = + "tensorflow/core/tfrt/ifrt/testdata"; + std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( + absl::StrCat(kDataDirectory, "/executable.mlir")); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + + mlir::MLIRContext context(registry); + + mlir::OwningOpRef mlir_module = + mlir::parseSourceFile(mlir_module_path, &context); + + ASSERT_TRUE(mlir_module); + + // Create contexts required for the compiler execution. 
+ TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + + IfrtServingExecutable executable("test", "main", std::move(mlir_module), + client, + tensorflow::IdentityShapeRepresentationFn()); + + constexpr int kDim1 = 3; + tensorflow::Tensor x1(tensorflow::DT_INT32, + tensorflow::TensorShape({1, kDim1})); + tensorflow::Tensor y1(tensorflow::DT_INT32, + tensorflow::TensorShape({kDim1, 1})); + for (int i = 0; i < kDim1; ++i) { + x1.flat()(i) = i + 1; + y1.flat()(i) = i + 1; + } + std::vector inputs1{x1, y1}; + + constexpr int kDim2 = 4; + tensorflow::Tensor x2(tensorflow::DT_INT32, + tensorflow::TensorShape({1, kDim2})); + tensorflow::Tensor y2(tensorflow::DT_INT32, + tensorflow::TensorShape({kDim2, 1})); + for (int i = 0; i < kDim2; ++i) { + x2.flat()(i) = i + 1; + y2.flat()(i) = i + 1; + } + std::vector inputs2{x2, y2}; + + std::vector outputs1, outputs2; + for (int i = 0; i < 3; i++) { + TF_ASSERT_OK_AND_ASSIGN(outputs1, + executable.Execute(absl::MakeSpan(inputs1))); + TF_ASSERT_OK_AND_ASSIGN(outputs2, + executable.Execute(absl::MakeSpan(inputs2))); + } + ASSERT_EQ(outputs1.size(), 1); + ASSERT_EQ(outputs1[0].dtype(), tensorflow::DT_INT32); + ASSERT_EQ(outputs1[0].shape(), tensorflow::TensorShape({1, 1})); + ASSERT_EQ(outputs1[0].flat()(0), 14); + + ASSERT_EQ(outputs2.size(), 1); + ASSERT_EQ(outputs2[0].dtype(), tensorflow::DT_INT32); + ASSERT_EQ(outputs2[0].shape(), tensorflow::TensorShape({1, 1})); + ASSERT_EQ(outputs2[0].flat()(0), 30); + + ASSERT_EQ(executable.num_executables(), 2); +} + +} // namespace +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/testdata/BUILD b/tensorflow/core/tfrt/ifrt/testdata/BUILD new file mode 100644 index 00000000000000..948ce54ab983a7 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/testdata/BUILD @@ -0,0 +1,12 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/core/tfrt/ifrt:__subpackages__"], + licenses = ["notice"], +) + +filegroup( + name = "testdata", + srcs = glob( + ["*"], + ), +) diff --git a/tensorflow/core/tfrt/ifrt/testdata/executable.mlir b/tensorflow/core/tfrt/ifrt/testdata/executable.mlir new file mode 100644 index 00000000000000..95c558ddb7ae0b --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/testdata/executable.mlir @@ -0,0 +1,6 @@ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.MatMul"(%arg0, %arg1): (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + func.return %0 : tensor<*xi32> + } +} \ No newline at end of file From 5b4049ba9aa6ecba6fd5e499c6a4a6704b973376 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 21:58:39 -0800 Subject: [PATCH 102/391] Integrate LLVM at llvm/llvm-project@5d6304f01742 Updates LLVM usage to match [5d6304f01742](https://github.com/llvm/llvm-project/commit/5d6304f01742) PiperOrigin-RevId: 582539795 --- third_party/llvm/generated.patch | 137 ------------------------------- third_party/llvm/workspace.bzl | 4 +- 2 files changed, 2 insertions(+), 139 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index c1d1504a0ac731..a37125c400d30a 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,141 +1,4 @@ Auto generated patch. Do not edit or delete it, even if empty. 
-diff -ruN --strip-trailing-cr a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp ---- a/llvm/lib/IR/Instruction.cpp -+++ b/llvm/lib/IR/Instruction.cpp -@@ -244,6 +244,8 @@ - Instruction::getDbgValueRange() const { - BasicBlock *Parent = const_cast(getParent()); - assert(Parent && "Instruction must be inserted to have DPValues"); -+ (void)Parent; -+ - if (!DbgMarker) - return DPMarker::getEmptyDPValueRange(); - -diff -ruN --strip-trailing-cr a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir ---- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir -+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir -@@ -54,7 +54,7 @@ - - func.func @mul(%arg0: tensor<4x6xf64>, - %arg1: tensor<4x6xf64, #BSR>) -> tensor<4x4xf64> { -- %out = tensor.empty() : tensor<4x4xf64> -+ %out = arith.constant dense<0.0> : tensor<4x4xf64> - %0 = linalg.generic #trait_mul - ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64, #BSR>) - outs(%out: tensor<4x4xf64>) { -@@ -68,7 +68,7 @@ - - func.func @mul_dense(%arg0: tensor<4x6xf64>, - %arg1: tensor<4x6xf64>) -> tensor<4x4xf64> { -- %out = tensor.empty() : tensor<4x4xf64> -+ %out = arith.constant dense<0.0> : tensor<4x4xf64> - %0 = linalg.generic #trait_mul - ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64>) - outs(%out: tensor<4x4xf64>) { -diff -ruN --strip-trailing-cr a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir ---- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir -+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir -@@ -85,30 +85,32 @@ - // A kernel that computes a BSR sampled dense matrix matrix multiplication - // using a "spy" function and in-place update of the sampling sparse matrix. - // -- func.func @SDDMM_block(%args: tensor, -- %arga: tensor, -- %argb: tensor) -> tensor { -- %result = linalg.generic #trait_SDDMM -- ins(%arga, %argb: tensor, tensor) -- outs(%args: tensor) { -- ^bb(%a: f32, %b: f32, %s: f32): -- %f0 = arith.constant 0.0 : f32 -- %u = sparse_tensor.unary %s : f32 to f32 -- present={ -- ^bb0(%p: f32): -- %mul = arith.mulf %a, %b : f32 -- sparse_tensor.yield %mul : f32 -- } -- absent={} -- %r = sparse_tensor.reduce %s, %u, %f0 : f32 { -- ^bb0(%p: f32, %q: f32): -- %add = arith.addf %p, %q : f32 -- sparse_tensor.yield %add : f32 -- } -- linalg.yield %r : f32 -- } -> tensor -- return %result : tensor -- } -+ // TODO: re-enable the following test. 
-+ // -+ // func.func @SDDMM_block(%args: tensor, -+ // %arga: tensor, -+ // %argb: tensor) -> tensor { -+ // %result = linalg.generic #trait_SDDMM -+ // ins(%arga, %argb: tensor, tensor) -+ // outs(%args: tensor) { -+ // ^bb(%a: f32, %b: f32, %s: f32): -+ // %f0 = arith.constant 0.0 : f32 -+ // %u = sparse_tensor.unary %s : f32 to f32 -+ // present={ -+ // ^bb0(%p: f32): -+ // %mul = arith.mulf %a, %b : f32 -+ // sparse_tensor.yield %mul : f32 -+ // } -+ // absent={} -+ // %r = sparse_tensor.reduce %s, %u, %f0 : f32 { -+ // ^bb0(%p: f32, %q: f32): -+ // %add = arith.addf %p, %q : f32 -+ // sparse_tensor.yield %add : f32 -+ // } -+ // linalg.yield %r : f32 -+ // } -> tensor -+ // return %result : tensor -+ // } - - func.func private @getTensorFilename(index) -> (!Filename) - -@@ -151,15 +153,15 @@ - // - %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename) - %m_csr = sparse_tensor.new %fileName : !Filename to tensor -- %m_bsr = sparse_tensor.new %fileName : !Filename to tensor -+ // %m_bsr = sparse_tensor.new %fileName : !Filename to tensor - - // Call the kernel. - %0 = call @SDDMM(%m_csr, %a, %b) - : (tensor, - tensor, tensor) -> tensor -- %1 = call @SDDMM_block(%m_bsr, %a, %b) -- : (tensor, -- tensor, tensor) -> tensor -+ // %1 = call @SDDMM_block(%m_bsr, %a, %b) -+ // : (tensor, -+ // tensor, tensor) -> tensor - - // - // Print the result for verification. Note that the "spy" determines what -@@ -168,18 +170,18 @@ - // in the original zero positions). - // - // CHECK: ( 5, 10, 24, 19, 53, 42, 55, 56 ) -- // CHECK-NEXT: ( 5, 10, 8, 19, 24, 24, 40, 53, 42, 55, 56, 64 ) -+ // C_HECK-NEXT: ( 5, 10, 8, 19, 24, 24, 40, 53, 42, 55, 56, 64 ) - // - %v0 = sparse_tensor.values %0 : tensor to memref - %vv0 = vector.transfer_read %v0[%c0], %d0 : memref, vector<8xf32> - vector.print %vv0 : vector<8xf32> -- %v1 = sparse_tensor.values %1 : tensor to memref -- %vv1 = vector.transfer_read %v1[%c0], %d0 : memref, vector<12xf32> -- vector.print %vv1 : vector<12xf32> -+ // %v1 = sparse_tensor.values %1 : tensor to memref -+ // %vv1 = vector.transfer_read %v1[%c0], %d0 : memref, vector<12xf32> -+ // vector.print %vv1 : vector<12xf32> - - // Release the resources. - bufferization.dealloc_tensor %0 : tensor -- bufferization.dealloc_tensor %1 : tensor -+ // bufferization.dealloc_tensor %1 : tensor - - llvm.call @mgpuDestroySparseEnv() : () -> () - return diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index a8c20faa8cde27..4bd62858469ca5 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "ed86e740effaf1de540820a145a9df44eaf0df0e" - LLVM_SHA256 = "f5c849d3c450faa5a68a14e11f62bfd7d957df87603b39662d7f4179b21c7f7a" + LLVM_COMMIT = "5d6304f01742a0a7c628fe6850e921c745eaea08" + LLVM_SHA256 = "5230a6bd323dc27893adf688ce1769854e0b92d8ce2f4d14ac62b9a200a1e452" tf_http_archive( name = name, From 50e76284aeff50bbf37c6de3d858aa9204840114 Mon Sep 17 00:00:00 2001 From: Hye Soo Yang Date: Tue, 14 Nov 2023 22:16:43 -0800 Subject: [PATCH 103/391] Move REGISTER_OP call for `global_iter_id_op` to sprase_core_ops.cc. 
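Keeping the registration in the ops target gives GlobalIterId a proper shape function (scalar int64 output) and makes it visible to the generated ApiDef and golden files updated below, while the kernel file keeps only the CPU kernel registration. The pairing follows the usual split between op definition and kernel (sketch only, mirroring the sparse_core_ops.cc and global_iter_id.cc hunks below):

// Ops library: interface, statefulness, and shape inference.
REGISTER_OP("GlobalIterId")
    .Output("iter_id: int64")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
      c->set_output(0, c->Scalar());
      return OkStatus();
    });

// Kernel library: the CPU implementation stays next to the OpKernel.
REGISTER_KERNEL_BUILDER(Name("GlobalIterId").Device(DEVICE_CPU), GlobalIterId);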
PiperOrigin-RevId: 582542958 --- .../core/api_def/base_api/api_def_GlobalIterId.pbtxt | 4 ++++ .../core/api_def/python_api/api_def_GlobalIterId.pbtxt | 4 ++++ tensorflow/core/tpu/kernels/BUILD | 1 + tensorflow/core/tpu/kernels/global_iter_id.cc | 1 - tensorflow/core/tpu/ops/sparse_core_ops.cc | 8 ++++++++ tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt | 4 ++++ tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt | 4 ++++ 7 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_GlobalIterId.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_GlobalIterId.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_GlobalIterId.pbtxt b/tensorflow/core/api_def/base_api/api_def_GlobalIterId.pbtxt new file mode 100644 index 00000000000000..7ec4d4db81f96c --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_GlobalIterId.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "GlobalIterId" + visibility: HIDDEN +} \ No newline at end of file diff --git a/tensorflow/core/api_def/python_api/api_def_GlobalIterId.pbtxt b/tensorflow/core/api_def/python_api/api_def_GlobalIterId.pbtxt new file mode 100644 index 00000000000000..7ec4d4db81f96c --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_GlobalIterId.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "GlobalIterId" + visibility: HIDDEN +} \ No newline at end of file diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 862793b3bb1082..9266fc48e6d923 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -1368,6 +1368,7 @@ tf_kernel_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core/kernels:partitioned_function_ops", + "//tensorflow/core/tpu/ops:sparse_core_ops", ], ) diff --git a/tensorflow/core/tpu/kernels/global_iter_id.cc b/tensorflow/core/tpu/kernels/global_iter_id.cc index 92a44d7c106a2b..11b80146f63153 100644 --- a/tensorflow/core/tpu/kernels/global_iter_id.cc +++ b/tensorflow/core/tpu/kernels/global_iter_id.cc @@ -29,7 +29,6 @@ class GlobalIterId : public OpKernel { ctx->set_output(0, Tensor(ctx->frame_iter().iter_id)); } }; -REGISTER_OP("GlobalIterId").Output("iter_id: int64").SetIsStateful(); REGISTER_KERNEL_BUILDER(Name("GlobalIterId").Device(DEVICE_CPU), GlobalIterId); } // anonymous namespace diff --git a/tensorflow/core/tpu/ops/sparse_core_ops.cc b/tensorflow/core/tpu/ops/sparse_core_ops.cc index f9b9d64339e572..e770c1814399a2 100644 --- a/tensorflow/core/tpu/ops/sparse_core_ops.cc +++ b/tensorflow/core/tpu/ops/sparse_core_ops.cc @@ -322,4 +322,12 @@ REGISTER_OP("XlaSparseCoreFtrl") return OkStatus(); }); +REGISTER_OP("GlobalIterId") + .Output("iter_id: int64") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) -> Status { + c->set_output(0, c->Scalar()); + return OkStatus(); + }); + } // namespace tensorflow diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index f78ba2e0839c78..30efcc7aa3a0ad 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1936,6 +1936,10 @@ tf_module { name: "GetSessionTensor" argspec: "args=[\'handle\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "GlobalIterId" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "Greater" argspec: "args=[\'x\', \'y\', 
\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index f78ba2e0839c78..30efcc7aa3a0ad 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1936,6 +1936,10 @@ tf_module { name: "GetSessionTensor" argspec: "args=[\'handle\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "GlobalIterId" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "Greater" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " From 3e6a592532a8024d0a7a6617dfa63dd44e3e9359 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 22:18:29 -0800 Subject: [PATCH 104/391] Update ops-related pbtxt files. PiperOrigin-RevId: 582543274 --- .../core/ops/compat/ops_history_v2/GlobalIterId.pbtxt | 8 ++++++++ tensorflow/core/ops/ops.pbtxt | 8 ++++++++ 2 files changed, 16 insertions(+) create mode 100644 tensorflow/core/ops/compat/ops_history_v2/GlobalIterId.pbtxt diff --git a/tensorflow/core/ops/compat/ops_history_v2/GlobalIterId.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/GlobalIterId.pbtxt new file mode 100644 index 00000000000000..5fa2302622c9ac --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/GlobalIterId.pbtxt @@ -0,0 +1,8 @@ +op { + name: "GlobalIterId" + output_arg { + name: "iter_id" + type: DT_INT64 + } + is_stateful: true +} diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 2bcdc8b109329d..441712fdaa6d02 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -22034,6 +22034,14 @@ op { } is_stateful: true } +op { + name: "GlobalIterId" + output_arg { + name: "iter_id" + type: DT_INT64 + } + is_stateful: true +} op { name: "Greater" input_arg { From 48260140136bc5928bb231e62e2bbaa7a5521139 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 22:39:57 -0800 Subject: [PATCH 105/391] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/fce7c27b3191264a8ed581e03900e094c793593a. PiperOrigin-RevId: 582546998 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index dcbd5d3e632cd2..12ea2f27056a25 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. 
- TFRT_COMMIT = "f5091e2c05925158e0be192370a37a6cf6fcf241" - TFRT_SHA256 = "0b3cbc0ca3862115b2b15122402c42090dd3b52090183e5f6579fe7769e9df0f" + TFRT_COMMIT = "fce7c27b3191264a8ed581e03900e094c793593a" + TFRT_SHA256 = "3fe89488c1f138c36e9b4b6a220fe5ea6ecd184163c20917b0ab7eb215d32979" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index dcbd5d3e632cd2..12ea2f27056a25 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "f5091e2c05925158e0be192370a37a6cf6fcf241" - TFRT_SHA256 = "0b3cbc0ca3862115b2b15122402c42090dd3b52090183e5f6579fe7769e9df0f" + TFRT_COMMIT = "fce7c27b3191264a8ed581e03900e094c793593a" + TFRT_SHA256 = "3fe89488c1f138c36e9b4b6a220fe5ea6ecd184163c20917b0ab7eb215d32979" tf_http_archive( name = "tf_runtime", From f32f0c7b04da82e726cb5037444abe193d83b3d9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Nov 2023 23:18:00 -0800 Subject: [PATCH 106/391] Internal Code Change PiperOrigin-RevId: 582554026 --- tensorflow/core/tfrt/mlrt/kernel/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/core/tfrt/mlrt/kernel/BUILD b/tensorflow/core/tfrt/mlrt/kernel/BUILD index 9dacee29030c8f..cca6cddbfb650a 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/BUILD +++ b/tensorflow/core/tfrt/mlrt/kernel/BUILD @@ -10,7 +10,6 @@ package( # copybara:uncomment "//learning/brain/tfrt:__subpackages__", # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", "//tensorflow/core/tfrt/graph_executor:__subpackages__", - "//tensorflow/core/tfrt/mlrt/application/tensorflow/tests:__subpackages__", "//tensorflow/core/tfrt/saved_model:__subpackages__", "//tensorflow/core/tfrt/tfrt_session:__subpackages__", ], From 559646ac10616b3c76b3aeb96d65caee303565d7 Mon Sep 17 00:00:00 2001 From: Doyeon Kim Date: Tue, 14 Nov 2023 23:57:01 -0800 Subject: [PATCH 107/391] Quantize same scale ops that are connected to quantized composite function PiperOrigin-RevId: 582561746 --- .../mlir/quantization/stablehlo/BUILD | 31 ++ .../mlir/quantization/stablehlo/ops/BUILD | 29 ++ .../stablehlo/ops/stablehlo_op_quant_spec.cc | 105 +++++ .../stablehlo/ops/stablehlo_op_quant_spec.h | 41 ++ .../stablehlo/passes/prepare_quantize.cc | 53 +-- .../stablehlo/passes/quantization_pattern.h | 376 ++++++++++++++++++ .../quantization/stablehlo/passes/quantize.cc | 16 +- .../mlir/quantization/stablehlo/tests/BUILD | 22 +- .../stablehlo/tests/quantize_same_scale.mlir | 66 +++ .../tests/stablehlo_op_quant_spec_test.cc | 209 ++++++++++ 10 files changed, 892 insertions(+), 56 deletions(-) create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD 
b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 640cd2e6cb7366..a19017c95006ab 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -51,6 +51,7 @@ cc_library( ":lift_quantizable_spots_as_functions_fusion_inc_gen", ":lift_quantizable_spots_as_functions_simple_inc_gen", ":quantization_options_proto_cc", + ":quantization_pattern", ":stablehlo_passes_inc_gen", ":stablehlo_type_utils", ":uniform_quantized_types", @@ -58,6 +59,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -67,6 +69,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/ir/types:Dialect", "//tensorflow/core/platform:path", "//tensorflow/core/tpu:tpu_defs", @@ -104,6 +107,34 @@ cc_library( alwayslink = True, ) +cc_library( + name = "quantization_pattern", + hdrs = [ + "passes/quantization_pattern.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":bridge_passes", + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", + "//tensorflow/compiler/mlir/quantization/tensorflow:passes", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:path", + "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], + # Alwayslink is required for registering the MLIR passes. + # TODO(b/255530126): Split the pass registration from the definitions to avoid binary size bloat. 
+ alwayslink = True, +) + td_library( name = "quant_td_files", srcs = [ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD new file mode 100644 index 00000000000000..d3bf62dfce4923 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD @@ -0,0 +1,29 @@ +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist_package", + ], + licenses = ["notice"], +) + +cc_library( + name = "stablehlo_op_quant_spec", + srcs = [ + "stablehlo_op_quant_spec.cc", + ], + hdrs = ["stablehlo_op_quant_spec.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc new file mode 100644 index 00000000000000..05ea20c29a942c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc @@ -0,0 +1,105 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" + +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::quant::stablehlo { + +std::unique_ptr GetStableHloOpQuantSpec(Operation* op) { + auto spec = std::make_unique(); + if (auto call_op = dyn_cast_or_null(op)) { + auto entry_function = + call_op->getAttrOfType("_entry_function"); + StringRef function_name = entry_function.getValue(); + if (!function_name.startswith("composite_")) { + return spec; + } + if (function_name.contains("conv")) { + spec->coeff_op_quant_dim[1] = 3; + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, + quant::GetUniformQuantizedTypeForBias}; + } + } else if (function_name.contains("dot_general")) { + spec->coeff_op_quant_dim[1] = -1; + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, + quant::GetUniformQuantizedTypeForBias}; + } + } else if (function_name.contains("dot")) { + spec->coeff_op_quant_dim[1] = -1; + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, + quant::GetUniformQuantizedTypeForBias}; + } + } + for (auto quantizable_operand : spec->coeff_op_quant_dim) { + spec->quantizable_operands.insert(quantizable_operand.first); + } + } + return spec; +} + +std::unique_ptr GetStableHloQuantScaleSpec(Operation* op) { + auto scale_spec = std::make_unique(); + // TODO - b/307619822: Add below ops to the spec with unit tests. + // mlir::stablehlo::SelectOp, mlir::stablehlo::PadOp, + // mlir::stablehlo::GatherOp, mlir::stablehlo::SliceOp, + // mlir::stablehlo::BroadcastInDimOp + if (llvm::isa(op)) { + scale_spec->has_same_scale_requirement = true; + } + return scale_spec; +} + +bool IsOpQuantizableStableHlo(Operation* op) { + if (mlir::isa(op)) { + // Constant ops do not have QuantizableResult attribute but can be + // quantized. + return true; + } else if (op->hasTrait() || + isa(op)) { + // Terminators, qcast and decast are not quantizable. + return false; + } + + if (GetStableHloQuantScaleSpec(op)->has_same_scale_requirement) { + return true; + } + + const bool attr_enforced_quantizable = + op->hasAttrOfType(kQuantTraitAttrName) && + op->getAttrOfType(kQuantTraitAttrName).getValue().str() == + QuantTraitValues[QuantizationTrait::FullyQuantizable]; + return attr_enforced_quantizable; +} + +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h new file mode 100644 index 00000000000000..c898a99c08f68f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h @@ -0,0 +1,41 @@ +/* Copyright 2023 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_STABLEHLO_OP_QUANT_SPEC_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_STABLEHLO_OP_QUANT_SPEC_H_ + +#include + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir::quant::stablehlo { + +// Returns StableHLO quantization specs for an op. +std::unique_ptr GetStableHloOpQuantSpec(Operation* op); + +// Returns quantization scale specs (fixed output, same scale) for a StableHLO +// op. +std::unique_ptr GetStableHloQuantScaleSpec(Operation* op); + +// Checks if an op is quantizable in StableHLO quantizer. Argument op is not +// necessarily a StableHLO op. +bool IsOpQuantizableStableHlo(Operation* op); + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_STABLEHLO_OP_QUANT_SPEC_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc index 6da27d9e3c2823..24d15dfd6688d5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc @@ -15,11 +15,9 @@ limitations under the License. // Copied and modified from // //third_party/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc // This transformation pass applies quantization propagation on TF dialect. -#include #include #include -#include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project @@ -35,6 +33,7 @@ limitations under the License. 
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -134,50 +133,6 @@ class ConvertArithConstToStablehloConstOp } }; -std::unique_ptr GetStableHLOOpQuantSpec(Operation* op) { - auto spec = std::make_unique(); - if (auto call_op = dyn_cast_or_null(op)) { - auto entry_function = - call_op->getAttrOfType("_entry_function"); - StringRef function_name = entry_function.getValue(); - if (!function_name.startswith("composite_")) { - return spec; - } - if (function_name.contains("conv")) { - spec->coeff_op_quant_dim[1] = 3; - if (function_name.contains("with_bias")) { - spec->biases_params[2] = {{0, 1}, - quant::GetUniformQuantizedTypeForBias}; - } - } else if (function_name.contains("dot_general")) { - spec->coeff_op_quant_dim[1] = -1; - if (function_name.contains("with_bias")) { - spec->biases_params[2] = {{0, 1}, - quant::GetUniformQuantizedTypeForBias}; - } - } else if (function_name.contains("dot")) { - spec->coeff_op_quant_dim[1] = -1; - if (function_name.contains("with_bias")) { - spec->biases_params[2] = {{0, 1}, - quant::GetUniformQuantizedTypeForBias}; - } - } - for (auto quantizable_operand : spec->coeff_op_quant_dim) { - spec->quantizable_operands.insert(quantizable_operand.first); - } - } - return spec; -} - -std::unique_ptr GetStableHLOQuantScaleSpec(Operation* op) { - auto scale_spec = std::make_unique(); - if (llvm::isa( - op)) { - scale_spec->has_same_scale_requirement = true; - } - return scale_spec; -} - void PrepareQuantizePass::runOnOperation() { func::FuncOp func = getOperation(); MLIRContext* ctx = func.getContext(); @@ -185,8 +140,8 @@ void PrepareQuantizePass::runOnOperation() { // The function might contain more stats ops than required, and it will // introduce requantize if the calibration stats have conflicts. This tries to // remove all the redundant stats ops. - RemoveRedundantStatsOps(func, GetStableHLOOpQuantSpec, - GetStableHLOQuantScaleSpec); + RemoveRedundantStatsOps(func, GetStableHloOpQuantSpec, + GetStableHloQuantScaleSpec); RewritePatternSet patterns(ctx); // Convert quant stats to int8 quantization parameters. @@ -209,7 +164,7 @@ void PrepareQuantizePass::runOnOperation() { // values (tensors). ApplyQuantizationParamsPropagation( func, /*is_signed=*/true, bit_width_, !enable_per_channel_quantization_, - GetStableHLOOpQuantSpec, GetStableHLOQuantScaleSpec, + GetStableHloOpQuantSpec, GetStableHloQuantScaleSpec, /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false); // Restore constants as stablehlo::ConstantOp. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h new file mode 100644 index 00000000000000..23564308da8d6d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h @@ -0,0 +1,376 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERN_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERN_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir::quant::stablehlo { + +// Checks if an op is quantizable in StableHLO quantizer. Argument op is not +// necessarily a StableHLO op. +bool IsOpQuantizableStableHlo(Operation* op); + +// A base rewrite pattern which matches any N-in-M-out operations with +// quantization parameters propagated to at least one of its operands. The +// quantization parameters are annotated by the QuantizeOp/DequantizeOp pairs. +// Each matched pattern are rewritten by its quantized alternatives. +// +// The concrete pattern, extends from this base pattern, can specify whether it +// allows dynamic range quantized operands and results for the operations in the +// current context. These "DynamicRangeQuantized" operands and results don't +// have quantization parameters propagated to, so will be in float in the +// quantized results. The concrete pattern should define the following two +// functions: +// +// bool AllowDynamicRangeQuantizedOperand(Operation *) const +// bool AllowDynamicRangeQuantizedResult(Operation *) const +// +// Full integer quantization disallows "DynamicRangeQuantized" operands or +// results. Dynamic range quantization allows "DynamicRangeQuantized" operands +// and results. 
+// +// Implementation of this pattern is mostly copied from QuantizationPattern in +// third_party/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h. +// TODO - b/310545259 : Split declarations and implementations of +// StableHloQuantizationPattern. +template +class StableHloQuantizationPattern : public RewritePattern { + public: + using BaseType = + StableHloQuantizationPattern; + + explicit StableHloQuantizationPattern( + MLIRContext* context, const mlir::quant::QuantPassSpec& quant_params) + // Set the score to a large number so it is always preferred. + : RewritePattern(RootOpT::getOperationName(), 300, context), + quant_params_(quant_params) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + llvm::SmallVector quantizing_ops; + + // Collect all the ops to quantize, as the user / producer of the root op. + if constexpr (std::is_same_v) { + if (op->getNumResults() != 1) { + op->emitError("Dequantize op should have exactly one result."); + return failure(); + } + auto users = op->getResult(0).getUsers(); + quantizing_ops.append(users.begin(), users.end()); + } else if constexpr (std::is_same_v) { + if (op->getNumOperands() != 1) { + op->emitError("Quantize op should have exactly one operand."); + return failure(); + } + Value quantize_operand = op->getOperand(0); + if (QuantizedType::getQuantizedElementType(quantize_operand.getType())) { + // The input of the quantize op has already been quantized, i.e. + // rescale. + return failure(); + } + DenseFPElementsAttr attr; + if (matchPattern(quantize_operand, m_Constant(&attr))) { + // Const-> QuantizeOp pattern will be handled separately. + return failure(); + } + if (Operation* quantizing_op = quantize_operand.getDefiningOp()) { + quantizing_ops.push_back(quantizing_op); + } + } + + absl::flat_hash_set ops_blocklist = + quant_params_.quant_spec.ops_blocklist; + absl::flat_hash_set nodes_blocklist = + quant_params_.quant_spec.nodes_blocklist; + CustomMap custom_map = quant_params_.quant_spec.custom_map; + + // Rewrite the floating-point ops to the quantized version, by fusing + // preceding dequantize ops and succeding quantize ops. + for (Operation* quantizing_op : quantizing_ops) { + // If it is requantize op, we shouldn't rewrite this op. + if (llvm::isa(quantizing_op)) { + return failure(); + } + + // If the op is terminator, we shouldn't rewrite. + if (quantizing_op->hasTrait()) { + return failure(); + } + + if (!IsOpQuantizableStableHlo(quantizing_op) && + !static_cast(this)->IsQuantizableCustomOp( + quantizing_op, custom_map)) { + return failure(); + } + + if (GetStableHloQuantScaleSpec(quantizing_op) + ->has_same_scale_requirement && + !IsConnectedWithQuantizedCompsiteFunction(quantizing_op)) { + return failure(); + } + + // Blocklist op is checked in advance for non-dynamic range quantization + // case. + if (!quant_params_.quant_spec.weight_quantization && + (ops_blocklist.find(quantizing_op->getName().getStringRef().str()) != + ops_blocklist.end())) { + return failure(); + } + + if (!nodes_blocklist.empty()) { + if (auto name_loc = quantizing_op->getLoc().dyn_cast()) { + std::string sloc = name_loc.getName().str(); + if (!sloc.empty() && + (nodes_blocklist.find(sloc) != nodes_blocklist.end())) { + return failure(); + } + } + } + + // Collect all the quantized inputs and "clone" the matched op by these + // inputs. 
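+      // Each f32 operand must be fed by a quantfork.dcast; its quantized
+      // input is used instead, fusing the dequantize into the new op.
+      // Operands that are already non-float (e.g. integer tensors) pass
+      // through unchanged; any other f32 operand means no quantization
+      // parameters were propagated to it, so the pattern does not match.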
+ SmallVector inputs; + inputs.reserve(quantizing_op->getNumOperands()); + for (auto operand : quantizing_op->getOperands()) { + Type operand_type = operand.getType(); + if (operand_type.isa()) { + inputs.push_back(operand); + continue; + } + + auto ele_type = operand.getType().cast().getElementType(); + if (auto dq_op = + dyn_cast_or_null(operand.getDefiningOp())) { + inputs.push_back(dq_op.getOperand()); + } else if (!ele_type.isF32()) { + // If the operand is an integer tensor, then it doesn't require the + // DequantizeOp in the pattern. + inputs.push_back(operand); + } else { + return failure(); + } + } + + // Collect all the quantized outputs and replace them by the results of + // the new quantized op. + llvm::SmallDenseMap outputs_replaced; + SmallVector output_types; + output_types.reserve(quantizing_op->getNumResults()); + for (const auto& enumerated_result : + llvm::enumerate(quantizing_op->getResults())) { + Value result = enumerated_result.value(); + Type result_type = result.getType(); + // Add this to the test coverage once we create test ops with none type + // results. + if (result_type.isa()) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result_type); + continue; + } + Type result_ele_type = + result.getType().cast().getElementType(); + // If the user is the QuantizeOp, it must be the only user. + if (result.hasOneUse() && + llvm::isa(*result.user_begin())) { + auto user = llvm::cast(*result.user_begin()); + outputs_replaced.insert( + {user.getResult(), enumerated_result.index()}); + output_types.push_back(user.getType()); + } else if (!result_ele_type.isF32()) { + // If the result is an integer tensor, then it doesn't require the + // D op in the pattern. + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else if (static_cast(this) + ->AllowDynamicRangeQuantizedResult(quantizing_op, + custom_map)) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else { + return failure(); + } + } + + rewriter.setInsertionPointAfter(quantizing_op); + OperationState new_state(quantizing_op->getLoc(), + quantizing_op->getName().getStringRef(), inputs, + output_types, quantizing_op->getAttrs()); + for (int i = 0; i < quantizing_op->getNumRegions(); ++i) { + new_state.addRegion(); + } + Operation* quantized_op = rewriter.create(new_state); + if (quantizing_op->getNumRegions() != 0) { + for (const auto& indexed_regions : + llvm::enumerate(quantizing_op->getRegions())) { + Region& target_region = + quantized_op->getRegion(indexed_regions.index()); + IRMapping mapping; + indexed_regions.value().cloneInto(&target_region, mapping); + } + } + for (auto output : outputs_replaced) { + output.getFirst().replaceAllUsesWith( + quantized_op->getResult(output.getSecond())); + } + } + return success(); + } + + private: + QuantPassSpec quant_params_; + + // Checks whether the operation is connnected with a quantized composite + // function. If not, the same-scale op will not be quantized. This decision is + // based on the current assumption that the performance gain of the same-scale + // op itself could not beat the overhead of the quantize and dequantize + // routines need to be added around that op. When the assumption changes, + // this policy might change as well. 
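+  // For example (mirroring tests/quantize_same_scale.mlir in this change):
+  // a stablehlo.reshape whose operand is the dequantized result of a
+  // quantized composite XlaCallModule, or whose only user re-quantizes into
+  // such a composite, counts as connected and is rewritten to operate on
+  // quantized types; the same reshape fed only by a quantized constant is
+  // left in float, with its surrounding qcast/dcast pairs intact.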
+ bool IsConnectedWithQuantizedCompsiteFunction( + Operation* same_scale_op) const { + for (const auto& operand : same_scale_op->getOperands()) { + auto dq_op = dyn_cast_or_null( + operand.getDefiningOp()); + if (!dq_op) continue; + + Operation* preceding_op = dq_op.getArg().getDefiningOp(); + if (!preceding_op) continue; + + // Check whether the preceding op is a quantized composite function. + if (llvm::isa(preceding_op)) { + auto call_op = llvm::cast(preceding_op); + if (!IsQuantizedCompositeFunction(call_op)) continue; + return true; + } + + // Check whether the preceding op is a quantized same-scale op. + if (GetStableHloQuantScaleSpec(preceding_op) + ->has_same_scale_requirement) { + for (auto result : preceding_op->getResults()) { + auto element_type = getElementTypeOrSelf(result.getType()); + if (element_type.isa()) { + return true; + } + } + } + } + + for (const auto& result : same_scale_op->getResults()) { + // If the user is the Quantize op, it must be the only user. + if (!result.hasOneUse() || + !llvm::isa(*result.user_begin())) { + continue; + } + + auto q_op = llvm::cast(*result.user_begin()); + for (auto following_op : q_op->getUsers()) { + // Check whether the following op is a quantized composite function. + if (llvm::isa(following_op)) { + auto call_op = llvm::cast(following_op); + if (!IsQuantizedCompositeFunction(call_op)) continue; + return true; + } + + // Check whether the following op is a quantized same-scale op. + if (GetStableHloQuantScaleSpec(following_op) + ->has_same_scale_requirement) { + for (auto operand : following_op->getOperands()) { + auto element_type = getElementTypeOrSelf(operand.getType()); + if (element_type.isa()) { + return true; + } + } + } + } + } + + return false; + } + + // Checks if op calls a composite function and all the inputs and outputs are + // quantized. + bool IsQuantizedCompositeFunction(TF::XlaCallModuleOp call_op) const { + if (!call_op->hasAttr(kQuantTraitAttrName)) { + return false; + } + + const auto function_name = call_op->getAttrOfType( + TF::kStablehloEntryFunctionAttrName); + if (!function_name || !function_name.getValue().startswith("composite_")) { + return false; + } + + bool has_quantized_types = false; + for (Value input : call_op.getArgs()) { + if (auto type = input.getType().dyn_cast()) { + if (type.getElementType().isa()) { + return false; + } + if (type.getElementType().isa()) { + has_quantized_types = true; + } + } + } + for (Value output : call_op.getOutput()) { + if (auto type = output.getType().dyn_cast()) { + if (type.getElementType().isa()) { + return false; + } + if (type.getElementType().isa()) { + has_quantized_types = true; + } + } + } + return has_quantized_types; + } +}; + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERN_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index 16e7ad1cfd7010..811c53125056cd 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -20,10 +20,12 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project @@ -31,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_pattern.h" namespace mlir::quant::stablehlo { @@ -42,14 +45,15 @@ namespace { // Base struct for quantization. template struct StableHloQuantizationBase - : public QuantizationPattern { + : public StableHloQuantizationPattern { explicit StableHloQuantizationBase(MLIRContext* ctx, const QuantPassSpec& quant_params) - : QuantizationPattern(ctx, quant_params) {} + : StableHloQuantizationPattern( + ctx, quant_params) {} static bool IsQuantizableCustomOp(Operation* op, const CustomMap& custom_op_map) { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD index 4c078033215618..0392eb669a4a71 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD @@ -1,6 +1,6 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -46,3 +46,23 @@ tf_cc_test( "@local_tsl//tsl/platform:protobuf", ], ) + +tf_cc_test( + name = "stablehlo_op_quant_spec_test", + srcs = ["stablehlo_op_quant_spec_test.cc"], + deps = [ + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/stablehlo/ops:stablehlo_op_quant_spec", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/core:test", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:QuantOps", + "@stablehlo//:stablehlo_ops", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir new file mode 100644 index 00000000000000..09df3a094741c6 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir @@ -0,0 +1,66 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize -verify-each=false | FileCheck %s + +// CHECK-LABEL: same_scale_after_composite +func.func @same_scale_after_composite() -> tensor<3x1xf32> { 
+ // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: () -> tensor<1x3x!quant.uniform> + // CHECK: %[[RESHAPE:.*]] = "stablehlo.reshape"(%[[CALL]]) : (tensor<1x3x!quant.uniform>) -> tensor<3x1x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[RESHAPE]]) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = stablehlo.reshape %2 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + return %5 : tensor<3x1xf32> +} + +// ----- + +// CHECK-LABEL: same_scale_indirectly_connected +func.func @same_scale_indirectly_connected() -> tensor<1x3xf32> { + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: () -> tensor<1x3x!quant.uniform> + // CHECK: %[[RESHAPE:.*]] = "stablehlo.reshape"(%[[CALL]]) : (tensor<1x3x!quant.uniform>) -> tensor<3x1x!quant.uniform> + // CHECK: %[[TRANSPOSE:.*]] = "stablehlo.transpose"(%[[RESHAPE]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<3x1x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[TRANSPOSE]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = stablehlo.reshape %2 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + %6 = "stablehlo.transpose"(%5) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<3x1xf32>) -> tensor<1x3xf32> + %7 = "quantfork.qcast"(%6) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %8 = "quantfork.dcast"(%7) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %8 : tensor<1x3xf32> +} + +// ----- + +// CHECK-LABEL: same_scale_not_connected_to_composite +func.func @same_scale_not_connected_to_composite() -> tensor<3x1xf32> { + // CHECK: %[[CST:.*]] = stablehlo.constant + // CHECK: %[[Q1:.*]] = 
"quantfork.qcast"(%[[CST]]) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[DQ1:.*]] = "quantfork.dcast"(%[[Q1]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[DQ1]] + // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[RESHAPE]]) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + // CHECK: %[[DQ2:.*]] = "quantfork.dcast"(%[[Q2]]) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + // CHECK: return %[[DQ2]] + + %0 = stablehlo.constant dense<1.000000e+00> : tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = stablehlo.reshape %2 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + return %5 : tensor<3x1xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc new file mode 100644 index 00000000000000..281bfc996a1c77 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/stablehlo_op_quant_spec_test.cc @@ -0,0 +1,209 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" + +#include +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/platform/test.h" + +namespace mlir::quant::stablehlo { +namespace { + +class IsOpQuantizableStableHloTest : public ::testing::Test { + protected: + IsOpQuantizableStableHloTest() { + ctx_.loadDialect(); + } + + // Parses `module_op_str` to create a `ModuleOp`. Checks whether the created + // module op is valid. + OwningOpRef ParseModuleOpString( + const absl::string_view module_op_str) { + auto module_op_ref = parseSourceString(module_op_str, &ctx_); + EXPECT_TRUE(module_op_ref); + return module_op_ref; + } + + // Gets the function with the given name from the module. 
+ func::FuncOp GetFunctionFromModule(ModuleOp module, + absl::string_view function_name) { + SymbolTable symbol_table(module); + return symbol_table.lookup(function_name); + } + + // Returns the first operation with the given type in the function. + template + OpType FindOperationOfType(func::FuncOp function) { + for (auto op : function.getBody().getOps()) { + return op; + } + return nullptr; + } + + mlir::MLIRContext ctx_{}; +}; + +// Quantizable ops: constants +// Non-quantizable ops: normal StableHLO ops and terminators +constexpr absl::string_view module_constant_add = R"mlir( + module { + func.func @constant_add() -> (tensor<3x2xf32>) { + %cst1 = stablehlo.constant dense<2.4> : tensor<3x2xf32> + %cst2 = stablehlo.constant dense<5.7> : tensor<3x2xf32> + %add = stablehlo.add %cst1, %cst2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> + func.return %add : tensor<3x2xf32> + } + } +)mlir"; + +// Quantizable ops: XlaCallModule op with "fully_quantizable" attribute and +// same-scale StableHLO ops +// Non-quantizable ops: quantize/dequantize ops +constexpr absl::string_view module_composite_same_scale = R"mlir( + module { + func.func @same_scale_after_composite() -> tensor<3x1xf32> { + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = stablehlo.reshape %2 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + return %5 : tensor<3x1xf32> + } + } +)mlir"; + +// Non-quantizable ops: XlaCallModule op without "fully_quantizable" attribute +constexpr absl::string_view module_composite_no_attr = R"mlir( + module { + func.func @composite_without_attr() -> tensor<1x3xf32> { + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @non_quantizable_composite, _original_entry_function = "non_quantizable_composite", _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } + } +)mlir"; + +TEST_F(IsOpQuantizableStableHloTest, ConstantOp) { + OwningOpRef module_op_ref = + ParseModuleOpString(module_constant_add); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "constant_add"); + Operation* constant_op = + FindOperationOfType(test_func); + bool is_constant_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(constant_op); + + EXPECT_TRUE(is_constant_quantizable); +} + +TEST_F(IsOpQuantizableStableHloTest, TerminatorOp) { + OwningOpRef module_op_ref = + ParseModuleOpString(module_constant_add); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "constant_add"); + Operation* return_op = FindOperationOfType(test_func); + bool is_return_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(return_op); + + EXPECT_FALSE(is_return_quantizable); +} + +TEST_F(IsOpQuantizableStableHloTest, 
NonSameScaleStableHloOp) { + OwningOpRef module_op_ref = + ParseModuleOpString(module_constant_add); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "constant_add"); + Operation* add_op = FindOperationOfType(test_func); + bool is_add_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(add_op); + + EXPECT_FALSE(is_add_quantizable); +} + +TEST_F(IsOpQuantizableStableHloTest, QuantizableXlaCallModuleOp) { + OwningOpRef module_op_ref = + ParseModuleOpString(module_composite_same_scale); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); + Operation* xla_call_module_op = + FindOperationOfType(test_func); + bool is_xla_call_module_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(xla_call_module_op); + + EXPECT_TRUE(is_xla_call_module_quantizable); +} + +TEST_F(IsOpQuantizableStableHloTest, NonQuantizableXlaCallModuleOp) { + OwningOpRef module_op_ref = + ParseModuleOpString(module_composite_no_attr); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "composite_without_attr"); + Operation* xla_call_module_op = + FindOperationOfType(test_func); + bool is_xla_call_module_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(xla_call_module_op); + + EXPECT_FALSE(is_xla_call_module_quantizable); +} + +TEST_F(IsOpQuantizableStableHloTest, SameScaleStableHloOp) { + OwningOpRef module_op_ref = + ParseModuleOpString(module_composite_same_scale); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); + Operation* reshape_op = + FindOperationOfType(test_func); + bool is_reshape_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(reshape_op); + + EXPECT_TRUE(is_reshape_quantizable); +} + +TEST_F(IsOpQuantizableStableHloTest, QuantizeDequantizeOp) { + OwningOpRef module_op_ref = + ParseModuleOpString(module_composite_same_scale); + func::FuncOp test_func = + GetFunctionFromModule(*module_op_ref, "same_scale_after_composite"); + Operation* quantize_op = + FindOperationOfType(test_func); + Operation* dequantize_op = + FindOperationOfType(test_func); + bool is_quantize_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(quantize_op); + bool is_dequantize_quantizable = + mlir::quant::stablehlo::IsOpQuantizableStableHlo(dequantize_op); + + EXPECT_FALSE(is_quantize_quantizable); + EXPECT_FALSE(is_dequantize_quantizable); +} + +} // namespace +} // namespace mlir::quant::stablehlo From 066708b5ed74276e9930d4897c4bf13aa385af53 Mon Sep 17 00:00:00 2001 From: Haibo Huang Date: Wed, 15 Nov 2023 00:31:09 -0800 Subject: [PATCH 108/391] Make flib_def mutable PiperOrigin-RevId: 582569370 --- tensorflow/core/tpu/kernels/tpu_compile_op_support.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h index d098abe6e1ae08..5cb7e5a5d55511 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h @@ -62,7 +62,7 @@ using GuaranteedConsts = std::variant, // List of parameters for lowering function library definition to HLO IR. struct FunctionToHloArgs { const NameAttrList* const function; - const FunctionLibraryDefinition* const flib_def; + const FunctionLibraryDefinition* flib_def; int graph_def_version; GuaranteedConsts guaranteed_constants; }; From c5f5294d9d3f5fe6029a6a5cec76460fa7b14632 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 15 Nov 2023 00:34:02 -0800 Subject: [PATCH 109/391] Remove unused visibility specs PiperOrigin-RevId: 582570313 --- tensorflow/lite/acceleration/configuration/BUILD | 3 --- 1 file changed, 3 deletions(-) diff --git a/tensorflow/lite/acceleration/configuration/BUILD b/tensorflow/lite/acceleration/configuration/BUILD index 4f1b2fe568cb97..6b6aab6638d033 100644 --- a/tensorflow/lite/acceleration/configuration/BUILD +++ b/tensorflow/lite/acceleration/configuration/BUILD @@ -277,11 +277,8 @@ cc_library( "//conditions:default": [], }), visibility = [ - "//tensorflow/lite/acceleration/configuration/c:__pkg__", "//tensorflow/lite/core/acceleration/configuration/c:__pkg__", - "//tensorflow/lite/core/experimental/acceleration/configuration/c:__pkg__", "//tensorflow/lite/experimental/acceleration/configuration:__pkg__", - "//tensorflow/lite/experimental/acceleration/configuration/c:__pkg__", ], deps = [ ":configuration_fbs", From 39a6f00890be347e4a11d092d8117b77fa9db86a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 01:02:12 -0800 Subject: [PATCH 110/391] compat: Update forward compatibility horizon to 2023-11-15 PiperOrigin-RevId: 582577030 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index d49b9634a64a25..5bb692137843f1 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 14) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 15) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 8f8b5481066f8e628be1426ab605a4553c8e28d4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 01:06:01 -0800 Subject: [PATCH 111/391] Update GraphDef version to 1681. PiperOrigin-RevId: 582578078 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 28f282f8ebe8e5..1cbde634435938 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1680 // Updated: 2023/11/14 +#define TF_GRAPH_DEF_VERSION 1681 // Updated: 2023/11/15 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). 
// From 31920d0a1ae7c864644eced2cac15ccfb36a246b Mon Sep 17 00:00:00 2001 From: Doyeon Kim Date: Wed, 15 Nov 2023 03:18:29 -0800 Subject: [PATCH 112/391] Support quantization of stablehlo.concatenate, convert, pad and select PiperOrigin-RevId: 582609952 --- .../stablehlo/ops/stablehlo_op_quant_spec.cc | 4 +- .../stablehlo/tests/quantize_same_scale.mlir | 114 ++++++++++++++++++ 2 files changed, 116 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc index 05ea20c29a942c..39a2a96313079f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.cc @@ -70,11 +70,11 @@ std::unique_ptr GetStableHloOpQuantSpec(Operation* op) { std::unique_ptr GetStableHloQuantScaleSpec(Operation* op) { auto scale_spec = std::make_unique(); // TODO - b/307619822: Add below ops to the spec with unit tests. - // mlir::stablehlo::SelectOp, mlir::stablehlo::PadOp, // mlir::stablehlo::GatherOp, mlir::stablehlo::SliceOp, // mlir::stablehlo::BroadcastInDimOp if (llvm::isa(op)) { + mlir::stablehlo::PadOp, mlir::stablehlo::ReshapeOp, + mlir::stablehlo::SelectOp, mlir::stablehlo::TransposeOp>(op)) { scale_spec->has_same_scale_requirement = true; } return scale_spec; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir index 09df3a094741c6..ff294bb4eb7031 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/quantize_same_scale.mlir @@ -64,3 +64,117 @@ func.func @same_scale_not_connected_to_composite() -> tensor<3x1xf32> { %5 = "quantfork.dcast"(%4) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> return %5 : tensor<3x1xf32> } + +// ----- + +// CHECK-LABEL: concatenate_and_composite +// CHECK: %[[ARG0:.*]]: tensor<3x2xf32> +// CHECK-SAME: %[[ARG1:.*]]: tensor<1x2xf32> +func.func @concatenate_and_composite(%arg0: tensor<3x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<4x5xf32> { + // CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[ARG0]]) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform:f32, 5.000000e-03>> + // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[ARG1]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform:f32, 5.000000e-03>> + // CHECK: %[[PAD:.*]] = "stablehlo.concatenate"(%[[Q1]], %[[Q2]]) {dimension = 0 : i64} + // CHECK-SAME: (tensor<3x2x!quant.uniform:f32, 5.000000e-03>>, tensor<1x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<4x2x!quant.uniform:f32, 5.000000e-03>> + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"(%[[PAD]]) + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: (tensor<4x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<4x5x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[CALL]]) : (tensor<4x5x!quant.uniform>) -> tensor<4x5xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "quantfork.qcast"(%arg0) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform:f32, 5.000000e-03>> + %1 = "quantfork.dcast"(%0) : (tensor<3x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<3x2xf32> + %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform:f32, 
5.000000e-03>> + %3 = "quantfork.dcast"(%2) : (tensor<1x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<1x2xf32> + %4 = "stablehlo.concatenate"(%1, %3) { + dimension = 0 : i64 + } : (tensor<3x2xf32>, tensor<1x2xf32>) -> tensor<4x2xf32> + %5 = "quantfork.qcast"(%4) {volatile} : (tensor<4x2xf32>) -> tensor<4x2x!quant.uniform:f32, 5.000000e-03>> + %6 = "quantfork.dcast"(%5) : (tensor<4x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<4x2xf32> + %7 = "tf.XlaCallModule"(%6) {Sout = [#tf_type.shape<4x5>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<4x2xf32>) -> tensor<4x5xf32> + %8 = "quantfork.qcast"(%7) {volatile} : (tensor<4x5xf32>) -> tensor<4x5x!quant.uniform> + %9 = "quantfork.dcast"(%8) : (tensor<4x5x!quant.uniform>) -> tensor<4x5xf32> + return %9 : tensor<4x5xf32> +} + +// ----- + +// CHECK-LABEL: composite_and_convert +func.func @composite_and_convert() -> tensor<1x3xf32> { + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: () -> tensor<1x3x!quant.uniform> + // CHECK: %[[CONVERT:.*]] = "stablehlo.convert"(%[[CALL]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[CONVERT]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = stablehlo.convert %2 : (tensor<1x3xf32>) -> (tensor<1x3xf32>) + %4 = "quantfork.qcast"(%3) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %5 = "quantfork.dcast"(%4) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %5 : tensor<1x3xf32> +} + +// ----- + +// CHECK-LABEL: pad_and_composite +// CHECK: %[[ARG0:.*]]: tensor<2x3xf32> +// CHECK-SAME: %[[ARG1:.*]]: tensor +func.func @pad_and_composite(%arg0: tensor<2x3xf32>, %arg1: tensor) -> tensor<5x6xf32> { + // CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[ARG0]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 5.000000e-03>> + // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[ARG1]]) {volatile} : (tensor) -> tensor:f32, 5.000000e-03>> + // CHECK: %[[PAD:.*]] = "stablehlo.pad"(%[[Q1]], %[[Q2]]) + // CHECK-SAME: {edge_padding_high = dense<[2, 1]> : tensor<2xi64>, edge_padding_low = dense<[0, 1]> : tensor<2xi64>, interior_padding = dense<[1, 2]> : tensor<2xi64>} + // CHECK-SAME: (tensor<2x3x!quant.uniform:f32, 5.000000e-03>>, tensor:f32, 5.000000e-03>>) -> tensor<5x9x!quant.uniform:f32, 5.000000e-03>> + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"(%[[PAD]]) + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // 
CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: (tensor<5x9x!quant.uniform:f32, 5.000000e-03>>) -> tensor<5x6x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[CALL]]) : (tensor<5x6x!quant.uniform>) -> tensor<5x6xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "quantfork.qcast"(%arg0) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 5.000000e-03>> + %1 = "quantfork.dcast"(%0) : (tensor<2x3x!quant.uniform:f32, 5.000000e-03>>) -> tensor<2x3xf32> + %2 = "quantfork.qcast"(%arg1) {volatile} : (tensor) -> tensor:f32, 5.000000e-03>> + %3 = "quantfork.dcast"(%2) : (tensor:f32, 5.000000e-03>>) -> tensor + %4 = "stablehlo.pad"(%1, %3) { + edge_padding_low = dense<[0, 1]> : tensor<2xi64>, + edge_padding_high = dense<[2, 1]> : tensor<2xi64>, + interior_padding = dense<[1, 2]> : tensor<2xi64> + }: (tensor<2x3xf32>, tensor) -> tensor<5x9xf32> + %5 = "quantfork.qcast"(%4) {volatile} : (tensor<5x9xf32>) -> tensor<5x9x!quant.uniform:f32, 5.000000e-03>> + %6 = "quantfork.dcast"(%5) : (tensor<5x9x!quant.uniform:f32, 5.000000e-03>>) -> tensor<5x9xf32> + %7 = "tf.XlaCallModule"(%6) {Sout = [#tf_type.shape<5x6>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<5x9xf32>) -> tensor<5x6xf32> + %8 = "quantfork.qcast"(%7) {volatile} : (tensor<5x6xf32>) -> tensor<5x6x!quant.uniform> + %9 = "quantfork.dcast"(%8) : (tensor<5x6x!quant.uniform>) -> tensor<5x6xf32> + return %9 : tensor<5x6xf32> +} + +// ----- + +// CHECK-LABEL: composite_and_select +// CHECK: %[[ARG0:.*]]: tensor<1x3xi1> +// CHECK-SAME: %[[ARG1:.*]]: tensor<1x3xf32> +func.func @composite_and_select(%arg0: tensor<1x3xi1>, %arg1: tensor<1x3xf32>) -> tensor<1x3xf32> { + // CHECK: %[[CALL:.*]] = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-SAME: () -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[ARG1]]) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[SELECT:.*]] = "stablehlo.select"(%[[ARG0]], %[[CALL]], %[[Q1]]) : (tensor<1x3xi1>, tensor<1x3x!quant.uniform>, tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%2) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: "func.return"(%[[DQ]]) + + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantfork.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantfork.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = "quantfork.qcast"(%arg1) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %4 = "quantfork.dcast"(%3) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %7 = stablehlo.select %arg0, %2, %4 : (tensor<1x3xi1>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %8 = "quantfork.qcast"(%7) {volatile} : 
(tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %9 = "quantfork.dcast"(%8) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %9 : tensor<1x3xf32> +} From 193a8150d359bf6fffa9aa60834b95b4b7255f6f Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 15 Nov 2023 03:37:02 -0800 Subject: [PATCH 113/391] [XLA:GPU] Add helpers to calculate fused and unfused time separately. (NFC) Fused and unfused logic happens to be intertwined in one loop, but there is no good reason for it. There are no shared computations, except for producer_data. PiperOrigin-RevId: 582613392 --- .../gpu/model/gpu_performance_model.cc | 112 ++++++++++++------ .../service/gpu/model/gpu_performance_model.h | 12 ++ 2 files changed, 87 insertions(+), 37 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index bc23ba1c5e2153..7d5045cef03337 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -399,30 +399,60 @@ absl::Duration GpuPerformanceModel::ComputeTime( return absl::Nanoseconds(1.0f * flops / flop_per_ns_effective); } -GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( - const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis, +absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( + const HloInstruction* producer, const EstimateRunTimeData& producer_data, + const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, - std::vector fused_consumers, bool multi_output) { - VLOG(8) << "Producer: " << producer->name(); - if (producer->opcode() == HloOpcode::kFusion) { - VLOG(10) << producer->fused_instructions_computation()->ToString(); + std::vector fused_consumers) { + const se::DeviceDescription* device_info = cost_analysis->device_info_; + + absl::Duration producer_output_read_time_unfused = absl::ZeroDuration(); + + for (const HloInstruction* fused_consumer : fused_consumers) { + VLOG(8) << "Unfused consumer: " << fused_consumer->name(); + float utilization_by_this_consumer = cost_analysis->operand_utilization( + *fused_consumer, fused_consumer->operand_index(producer)); + + auto analysis_unfused = AnalyzeFusion(*fused_consumer, *device_info); + + LaunchDimensions launch_dimensions_unfused = EstimateFusionLaunchDimensions( + ShapeUtil::ElementsInRecursive(fused_consumer->shape()), + analysis_unfused, *device_info); + + int64_t n_bytes_total = std::llround(producer_data.bytes_written * + utilization_by_this_consumer); + int64_t n_bytes_net = std::min(producer_data.bytes_written, n_bytes_total); + + auto read_time_unfused = ReadTime( + *device_info, launch_dimensions_unfused.num_blocks(), n_bytes_net, + n_bytes_total, fused_consumer->shape().element_type(), + /*coalesced=*/!TransposesMinorDimension(fused_consumer), + config.first_read_from_dram); + + VLOG(10) << " Read time unfused: " << read_time_unfused; + producer_output_read_time_unfused += read_time_unfused; } - const se::DeviceDescription* device_info = cost_analysis->device_info_; + absl::Duration time_unfused = + kKernelLaunchOverhead * (fused_consumers.size() + 1) + + producer_data.exec_time + producer_output_read_time_unfused; - EstimateRunTimeData producer_data = - EstimateRunTimeForInstruction(producer, cost_analysis, config); + return time_unfused; +} - int64_t fused_consumer_count = fused_consumers.size(); - float total_producer_utilization = 0; +absl::Duration
GpuPerformanceModel::EstimateFusedExecTime( + const HloInstruction* producer, const EstimateRunTimeData& producer_data, + const GpuHloCostAnalysis* cost_analysis, + const GpuPerformanceModelOptions& config, + std::vector fused_consumers, bool multi_output) { + const se::DeviceDescription* device_info = cost_analysis->device_info_; - absl::Duration exec_time_fused = absl::ZeroDuration(); - absl::Duration producer_output_read_time_unfused = absl::ZeroDuration(); + absl::Duration exec_time_fused = + kKernelLaunchOverhead * fused_consumers.size(); for (const HloInstruction* fused_consumer : fused_consumers) { - VLOG(8) << "Consumer: " << fused_consumer->name(); + VLOG(8) << "Fused consumer: " << fused_consumer->name(); float utilization_by_this_consumer = cost_analysis->operand_utilization( *fused_consumer, fused_consumer->operand_index(producer)); - total_producer_utilization += utilization_by_this_consumer; // The model ignores consumer computation and output writes. The main goal // of the model is to compare estimates of fused and unfused cases. Since @@ -434,14 +464,10 @@ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( // make it complete. auto analysis_fused = AnalyzeProducerConsumerFusion(*producer, *fused_consumer, *device_info); - auto analysis_unfused = AnalyzeFusion(*fused_consumer, *device_info); LaunchDimensions launch_dimensions_fused = EstimateFusionLaunchDimensions( producer_data.num_threads * utilization_by_this_consumer, analysis_fused, *device_info); - LaunchDimensions launch_dimensions_unfused = EstimateFusionLaunchDimensions( - ShapeUtil::ElementsInRecursive(fused_consumer->shape()), - analysis_unfused, *device_info); absl::Duration compute_time_by_this_consumer = ComputeTime( *device_info, producer_data.flops * utilization_by_this_consumer, @@ -460,31 +486,43 @@ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( exec_time_fused += std::max(compute_time_by_this_consumer, input_access_time_by_this_consumer); + } - int64_t n_bytes_total = std::llround(producer_data.bytes_written * - utilization_by_this_consumer); - int64_t n_bytes_net = std::min(producer_data.bytes_written, n_bytes_total); + // Multi-output fusion still writes the initial output of the producer. + // For now assume that the producer's output does not need to be recomputed. 
+ if (multi_output) { + exec_time_fused += producer_data.write_time; + } - auto read_time_unfused = ReadTime( - *device_info, launch_dimensions_unfused.num_blocks(), n_bytes_net, - n_bytes_total, fused_consumer->shape().element_type(), - /*coalesced=*/!TransposesMinorDimension(fused_consumer), - config.first_read_from_dram); + return exec_time_fused; +} - VLOG(10) << " Read time unfused: " << read_time_unfused; - producer_output_read_time_unfused += read_time_unfused; +GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( + const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis, + const GpuPerformanceModelOptions& config, + std::vector fused_consumers, bool multi_output) { + VLOG(8) << "Producer: " << producer->name(); + if (producer->opcode() == HloOpcode::kFusion) { + VLOG(10) << producer->fused_instructions_computation()->ToString(); } - absl::Duration time_unfused = - kKernelLaunchOverhead * (fused_consumer_count + 1) + - producer_data.exec_time + producer_output_read_time_unfused; + EstimateRunTimeData producer_data = + EstimateRunTimeForInstruction(producer, cost_analysis, config); + + absl::Duration time_unfused = EstimateUnfusedExecTime( + producer, producer_data, cost_analysis, config, fused_consumers); absl::Duration time_fused = - kKernelLaunchOverhead * fused_consumer_count + exec_time_fused; - // Multi-output fusion still writes the initial output of the producer. - // For now assume that the producer's output does not need to be recomputed. - if (multi_output) { - time_fused += producer_data.write_time; + EstimateFusedExecTime(producer, producer_data, cost_analysis, config, + fused_consumers, multi_output); + + int64_t fused_consumer_count = fused_consumers.size(); + float total_producer_utilization = 0; + + for (const HloInstruction* fused_consumer : fused_consumers) { + float utilization_by_this_consumer = cost_analysis->operand_utilization( + *fused_consumer, fused_consumer->operand_index(producer)); + total_producer_utilization += utilization_by_this_consumer; } if (VLOG_IS_ON(8)) { diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h index 7afdedb8a748af..b7b28fff1eeda7 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h @@ -91,6 +91,18 @@ class GpuPerformanceModel { const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config); + static absl::Duration EstimateUnfusedExecTime( + const HloInstruction* producer, const EstimateRunTimeData& producer_data, + const GpuHloCostAnalysis* cost_analysis, + const GpuPerformanceModelOptions& config, + std::vector fused_consumers); + + static absl::Duration EstimateFusedExecTime( + const HloInstruction* producer, const EstimateRunTimeData& producer_data, + const GpuHloCostAnalysis* cost_analysis, + const GpuPerformanceModelOptions& config, + std::vector fused_consumers, bool multi_output); + static RunTimes EstimateRunTimes( const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis, const GpuPerformanceModelOptions& config, From 1ab3e95961daa6544b121d953156063da1942022 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Wed, 15 Nov 2023 04:31:29 -0800 Subject: [PATCH 114/391] [XLA] Allow XLA to propagate the layouts into parameters without layouts. 
PiperOrigin-RevId: 582624872 --- third_party/xla/xla/service/BUILD | 2 + .../xla/service/algebraic_simplifier_test.cc | 2 +- .../xla/xla/service/computation_layout.cc | 22 +++------ .../xla/xla/service/computation_layout.h | 4 +- .../xla/xla/service/layout_assignment.cc | 49 ++++++++++--------- .../xla/xla/service/layout_assignment.h | 4 +- .../xla/xla/service/layout_assignment_test.cc | 30 ++++++++++++ .../xla/xla/service/sharding_propagation.cc | 7 +++ 8 files changed, 78 insertions(+), 42 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index aed7d55f90f0ae..52d88a01978b81 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -672,6 +672,7 @@ cc_library( "//xla:shape_tree", "//xla:shape_util", "//xla:sharding_op_util", + "//xla:status", "//xla:status_macros", "//xla:statusor", "//xla:util", @@ -5477,6 +5478,7 @@ cc_library( "//xla:shape_layout", "//xla:types", "//xla:xla_data_proto_cc", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", ], ) diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index 6de7433f729ca7..fbbfff2a584049 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -7532,7 +7532,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) { HloInstruction::CreateReshape(reshaped_shape, broadcast)); std::unique_ptr module = CreateNewVerifiedModule(); - module->AddEntryComputationWithLayouts(builder.Build()); + module->AddEntryComputation(builder.Build()); AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); diff --git a/third_party/xla/xla/service/computation_layout.cc b/third_party/xla/xla/service/computation_layout.cc index c4510143bd77e7..717b7ac550442e 100644 --- a/third_party/xla/xla/service/computation_layout.cc +++ b/third_party/xla/xla/service/computation_layout.cc @@ -15,10 +15,10 @@ limitations under the License. 
#include "xla/service/computation_layout.h" -#include #include #include +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "xla/printer.h" @@ -34,8 +34,6 @@ ComputationLayout::ComputationLayout(const ProgramShape& program_shape, } if (ignore_layouts) { SetToDefaultLayout(); - } else { - SetToDefaultLayoutIfEmpty(); } } @@ -45,24 +43,18 @@ void ComputationLayout::SetToDefaultLayout() { } result_layout_.SetToDefaultLayout(); } - -void ComputationLayout::SetToDefaultLayoutIfEmpty() { - for (auto& parameter_layout : parameter_layouts_) { - if (!parameter_layout.LayoutIsSet()) { - parameter_layout.SetToDefaultLayout(); - } - } - if (!result_layout_.LayoutIsSet()) { - result_layout_.SetToDefaultLayout(); - } -} - bool ComputationLayout::LayoutIsSet() const { return absl::c_all_of(parameter_layouts_, [](const ShapeLayout& s) { return s.LayoutIsSet(); }) && result_layout_.LayoutIsSet(); } +bool ComputationLayout::AnyLayoutSet() const { + return absl::c_any_of(parameter_layouts_, + [](const ShapeLayout& s) { return s.LayoutIsSet(); }) || + result_layout_.LayoutIsSet(); +} + void ComputationLayout::Print(Printer* printer) const { printer->Append("("); if (!parameter_layouts_.empty()) { diff --git a/third_party/xla/xla/service/computation_layout.h b/third_party/xla/xla/service/computation_layout.h index 572aaac669e358..c44d34d5e9d04b 100644 --- a/third_party/xla/xla/service/computation_layout.h +++ b/third_party/xla/xla/service/computation_layout.h @@ -78,10 +78,10 @@ class ComputationLayout { // Sets layouts of all parameters and the result to the default layout. void SetToDefaultLayout(); - void SetToDefaultLayoutIfEmpty(); - // Returns true if all layouts (parameters and result) have been set. bool LayoutIsSet() const; + // Returns true if any layouts (parameters and result) have been set. + bool AnyLayoutSet() const; // Prints a string representation of this object. void Print(Printer* printer) const; diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index 9d2183f22352ca..3e5e64a51823e5 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -221,12 +221,6 @@ std::string ComputationLayoutConstraint::ToString() const { layout_state_, computation_layout_.ToString()); } -LayoutAssignment::LayoutConstraints::LayoutConstraints( - HloComputation* computation, ComputationLayout* computation_layout, - int64_t priority) - : computation_(computation), - computation_constraint_(computation, computation_layout, priority) {} - PointsToSet::BufferSet* LayoutAssignment::GetBufferSet( const HloInstruction* instruction) const { auto it = buffer_sets_cache_.find(instruction); @@ -715,19 +709,23 @@ Status LayoutAssignment::AddMandatoryConstraints( } else if (instruction->opcode() == HloOpcode::kParameter) { if (reverse_computation_order_ || (constraints->computation()->IsEntryComputation() && - entry_computation_layout_->LayoutIsSet()) || + entry_computation_layout_->AnyLayoutSet()) || (conditional_mismatch_.count(constraints->computation()) == 0 && constraints->computation_constraint().parameter_layout_is_set())) { const ShapeLayout& parameter_layout = constraints->computation_layout().parameter_layout( instruction->parameter_number()); - // Parameter layouts must match the respective layout in - // ComputationLayout, if there is one. 
- TF_RETURN_IF_ERROR( - SetInstructionLayout(parameter_layout.shape(), instruction)); - if (reverse_computation_order_) { - TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers( - instruction, parameter_layout.shape(), this)); + // Allow some paramter/result layouts to be unset in the entry + // computation. + if (parameter_layout.LayoutIsSet()) { + // Parameter layouts must match the respective layout in + // ComputationLayout, if there is one. + TF_RETURN_IF_ERROR( + SetInstructionLayout(parameter_layout.shape(), instruction)); + if (reverse_computation_order_) { + TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers( + instruction, parameter_layout.shape(), this)); + } } } } else if (IsLayoutConstrainedCustomCall(instruction)) { @@ -927,7 +925,8 @@ Status LayoutAssignment::AddMandatoryConstraints( current_priority_ + kNumberOfPropagationRounds)); } else if (reverse_computation_order_ || (constraints->computation()->IsEntryComputation() && - entry_computation_layout_->LayoutIsSet()) || + entry_computation_layout_->AnyLayoutSet() && + entry_computation_layout_->result_layout().LayoutIsSet()) || current_priority_ > LayoutConstraint::kBeginningPriority) { const ShapeLayout* result_layout = constraints->ResultLayout(); if (result_layout != nullptr) { @@ -1680,14 +1679,17 @@ Status LayoutAssignment::PropagateUseConstraintToDefs( [&shape_layout, this, priority, user]( const ShapeIndex& index, const PointsToSet::BufferList& buffers) -> Status { - if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) { + const auto& subshape = + ShapeUtil::GetSubshape(shape_layout.shape(), index); + if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index) && + subshape.has_layout()) { for (const LogicalBuffer* buffer : buffers) { if (buffer->shape().IsArray() && (buffer->instruction()->opcode() != HloOpcode::kReduce || !buffer->instruction()->shape().IsTuple())) { - TF_RETURN_IF_ERROR(SetBufferLayout( - ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(), - *buffer, /*mandatory=*/false, /*dfs=*/true, priority, user)); + TF_RETURN_IF_ERROR(SetBufferLayout(subshape.layout(), *buffer, + /*mandatory=*/false, + /*dfs=*/true, priority, user)); } } } @@ -2202,7 +2204,8 @@ Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) { } // Copy the root instruction's result if its layout does not match the result // layout constraint. - if (constraints.ResultLayout() != nullptr) { + if (constraints.ResultLayout() != nullptr && + constraints.ResultLayout()->LayoutIsSet()) { // Layout assignment at this point only does minor-to-major assignment so // tiling info should be ignored here for comparison. VLOG(5) << "Computation result layout needs root copying\n"; @@ -2250,7 +2253,7 @@ Status LayoutAssignment::CalculateComputationLayout( ShapeUtil::ForEachSubshape( operand->shape(), [this, &change, operand, update]( const Shape& subshape, const ShapeIndex& index) { - if (subshape.IsTuple()) { + if (subshape.IsTuple() || !subshape.has_layout()) { return; } auto param_layout = InferArrayLayout(operand, index); @@ -2661,10 +2664,10 @@ StatusOr LayoutAssignment::Run( computation_layouts_.emplace( module->entry_computation(), new LayoutConstraints(entry, - entry_computation_layout_->LayoutIsSet() + entry_computation_layout_->AnyLayoutSet() ? entry_computation_layout_ : nullptr, - entry_computation_layout_->LayoutIsSet() + entry_computation_layout_->AnyLayoutSet() ? 
LayoutConstraint::kGivenPriority : LayoutConstraint::kDefaultPriority)); for (int64_t i = 0; i < kNumberOfPropagationRounds; ++i) { diff --git a/third_party/xla/xla/service/layout_assignment.h b/third_party/xla/xla/service/layout_assignment.h index a901cef0b98295..18fd5df38f41a8 100644 --- a/third_party/xla/xla/service/layout_assignment.h +++ b/third_party/xla/xla/service/layout_assignment.h @@ -273,7 +273,9 @@ class LayoutAssignment : public HloModulePass { public: explicit LayoutConstraints(HloComputation* computation, ComputationLayout* computation_layout, - int64_t priority); + int64_t priority) + : computation_(computation), + computation_constraint_(computation, computation_layout, priority) {} ~LayoutConstraints() = default; const HloComputation* computation() const { return computation_; } diff --git a/third_party/xla/xla/service/layout_assignment_test.cc b/third_party/xla/xla/service/layout_assignment_test.cc index c3665e7bea1886..e0db0346e4c826 100644 --- a/third_party/xla/xla/service/layout_assignment_test.cc +++ b/third_party/xla/xla/service/layout_assignment_test.cc @@ -1692,5 +1692,35 @@ ENTRY main { // Expecting a copy before custom call to reconcile the different layouts. EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kCopy); } + +// Test the ability to enforce a partially specified parameter constraint. +TEST_F(LayoutAssignmentTest, PartialEntryParameterLayout) { + const char* module_str = R"( + HloModule EntryLayout, entry_computation_layout={(f32[32,650]{1,0},s32[16,1,18]{0,1,2})->(f32[650,32]{1,0},s32[18,16,1]{0,1,2})} + + ENTRY %main { + operand = f32[32,650] parameter(0) + transpose = transpose(operand), dimensions={1,0} + indices = s32[16,1,18] parameter(1) + transpose_indices = transpose(indices), dimensions={2,0,1} + ROOT t = tuple(transpose, transpose_indices) + } )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); + // Allow propagation only to parameter 0 + m->mutable_entry_computation_layout()->mutable_parameter_layout(0)->Clear(); + + LayoutAssignment layout_assignment(m->mutable_entry_computation_layout(), + nullptr); + EXPECT_IS_OK(layout_assignment.Run(m.get()).status()); + // Assign bitcasting layout to parameter 0 + ExpectLayoutIs(m->entry_computation_layout().parameter_layout(0).shape(), + {0, 1}); + // Parameter layout that is set is unmodified. + ExpectLayoutIs(m->entry_computation_layout().parameter_layout(1).shape(), + {0, 1, 2}); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index c82bb6bc87daaf..b72ac39b0870a9 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -48,6 +48,7 @@ limitations under the License. #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/sharding_op_util.h" +#include "xla/status.h" #include "xla/status_macros.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -2818,6 +2819,12 @@ Status ShardingPropagation::CanonicalizeLayouts(HloModule* module) { LOG(INFO) << "There is no registered layout_canonicalization_callback."; return OkStatus(); } + // If the result layout is automatically set, allow layout assignment to + // choose the layout. 
+ if (!module->entry_computation_layout().LayoutIsSet() || + !module->entry_computation_layout().result_layout().LayoutIsSet()) { + return OkStatus(); + } TF_ASSIGN_OR_RETURN(auto layouts, module->layout_canonicalization_callback()(*module)); Shape& result_shape = layouts.second; From 4619736580692868d1b0ec6ee8d81908d20ddef9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 06:23:35 -0800 Subject: [PATCH 115/391] This is an automatic update to the GPU allowlist. PiperOrigin-RevId: 582651501 --- .../compatibility/gpu_compatibility.bin | Bin 27992 -> 31856 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin b/tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.bin index 8108897c68e54b083e88834a2800262dafffbdbf..e0d9029a639d79618178c578385689b05cb9ded9 100644 GIT binary patch literal 31856 zcmZ{t4_IAip4U$(lev>&GMNsWWxH%;vuu}EHhZ~T;L>Fy+hwzCymm3mR2;Dvl1l>g zLhiy1rFnSr@bqCvj94*Z#1U(Z7_nl+iWMtXtXMH(#flXxMywdI#)=U~jF|fQo%7!C zo*&0P^P7A4obUO*zw3Y{!`;MrM&LK$CcLttOZ@1SZ29cd7b!ymshj$Iu4I3 zuRK^jI)^A=dE&f3uWsXYOnIFXuDqIHOK(tHBFS>SsZ+LI?;Ed0)=N#Z{0@9JA+7v#;0=Gk-{vmcm!Qih7&1)m zl`-+oN3V15G~S}ObAJ_{uX1mME%%Wlv7dW|^jA0cOo0Dg<3Gzbw*Lq4$qql8g{9Cn z9rQ1k{8t0~U%2MgrcwDn3vcZ7`{%>*(CuD6kUvM7@sIVoKG%%@4%^uIUxjbB`&kAy zjc%OSuYE<8c#pEKvDL4?Z@fisdw&KWR^BVH26Q#VI?L_KEAmIYI>eRNbMUyTb0I7j z-8N%t3CsP;>#Ff`lvfXY`=USBcGw&~lSGE)VdZt#c+Ilj-p3E%lgeurmO|G=EU>JP z#qkkuAKR7J5qOjGa$rU1@`!I!|JGzrTlb|zNv7-hL}j$xDPnz%%C*nSA3mVG)?sz%!o(KKv&!q#TVB0tl~)8FR$e8r-QVKciionjq`a;fuN~Ig`*;<; zne?*^Y#QA-G05_|^7_Z@d;+pG%X62zk8&pH=hxBy0d(J2As@ zQh6O7_4ct}c|AhguDrHkVRVJW9?MH=?mBC{N|o1Xc%kyDfo)!LoE4&;<#pxtw(;6v zy}ggu;VWt%hGFCAhKMnix79XZ7_TwK@4<(Z*D5TIE=D|Nxy5XAnEKWnWeoln`4TNm z&MV*aatrxlwli~El`Vwh+mN{C%9f|iMKB1wb^xiDH` zk{c(IM3~s)zGRu0AW}pP5h7MeaFR$9HAEh<#=Xio(Lxjwn-p?^7$usC5@P2m7%@UL z5=F#z4s8%}Ka(VCi9BMBjy6dQ5)DKVu}zw@vVTMkfu&P<{or+9DQt)9{0)LKI#)7n ziF2veWLqK?>tPP;4A`axwt#+`_>OOz*78KEBhv=E6|mhnY?y5_LEA;bkuckR7O>q4 ztbuK6i0|t5BH^-X5`Fql9Nf1MtblErNr*$?^u6{*A{DN7P6cSAum_9{GDf=E^Hd-b zaXJIEEwBX{M~t?&8*P8YxdAV2@t+qXuraiQ#Cri=)36Nd=91oZQ*T$IySg%bfpnew z0s3`V9s1#YeP^sJd{Okr#}3vY50;O{?SI6&=r8!&pNA#UPocjb;I|H|WBoAe!*_kNtBXp>zXiQb6R>ZJgE47SzjIHl;n^+X?QO(ZY3QbMOM zV4DesvZTnNxN!xqtp&`bT(=gB8h?Oca%{Xze+9EOdf9U|Th z@S1}4vu>W&WbEm>l)03+SW@+TBAsgGzWjc`HtVoDwvlUX9x++lTuNs;y12^i&y-bH zM3SlJo#+14!TJ=!a@oG#-)3dD&F=?n(*Rp(=em+@^7q?h8@I7Yg*3*0Hv+aDfsL{4 zT8p`#_V%$-Huxw&KMzZypF%J7b!*aA={yY3Zoy(`2eRI+J@k$8R)#$1#Q4E|FMx&6 z$I*wx+tqu0>1=?$9=67rU5MVz#nU@i1GE{~G}`Ti?OjG|T>k{OUEgy7+9uc%+Pti` zE72ci)_N;II|Q3RyYZ^yd{gr6O8iz;S@`6Ii`~u#0s2{33jGZFBR;+3d>NqMgw>!Q zwDgs(zvDdn!-F-+hvlJfK%Zy)(U*{q+jq|fY*Px9`uhgd8xf-C&z^2h| zGX{K3w3psUl=1pbfc7SA5bYvb>4%-k?ua`vv;0|reg)QmzJ`$6aD|adq$(=H75x2Y z4(_W1D?*z`NbQY29FC-&7X$QR*f#UoI`Nd~d1Q1hK-UCYLX$_{htajao(PvQzq}Qo z9fD1u-Dvh|*WSam2iFMi2WV$tX|%IhZI{=#?+0ktVRdMSjrM%HJrWL=IZsa^e-^m3)7Zkov7Q zk|z;8%RTPzrT-FPVz~iM6c9_l!M+g=G4V3vC5BO68#utrzlZ%Ga*3O9`UWw>fIj_8 z+!wRS^6#b%YL(A68x*IR6KRZMHNkh!uuK$Cuo_O{c)sI2c#3vOun0c?>*$Hydio_1 zBUTu6Hh+zIsg5y%C?-~ah4%~Tn7Ooblbg>aI%t^JMh><-D`>_){pUHC$p`Bvedrc! 
zi~Y9QIE-d|>`RzzHp(1pYd?q2s`G74wnrB=^KnmGYofcqI^Y>^)XBZY9IM)@gf*(1 zN?|&uH(`>~9`l@?Qy;uY<&=WWqPs~@4kuC`j#j77Cu0{6K6|c;R5%mpH&k99!{+H1 z(w2+?NwaFY_H zOZ2w&bhPy*&h>XCVu?saB|d-g;5k_dYh;|U+bD${*nb%IhxV8kZEYXCNadJ<&7!+W z$T-~>DT_vVb?;}r_q5Bxyh$i;ppK*6WS#AK7q-B7CrtW&h~Tg}i*TLSGq=6Gk}9vG z@C@77zPYeMbRmLa(MfmobI}}g3gL7)pSH*N<|*F`;-R$7u*KicI7~>sj}Dp($6n_AWPKZGsG*v@+C`|ZR6A%8EhPVf8izvS(EJMG&0c?4enOMd3S ziqPc|Qcs?#^;U=F2?zHwqUY82=rP`+xA_J5gBm|;hAlD|&Jg?FWv+LBfcMA7TlBW~ zhwxG5JqMF{yV2)e>3aV>Ik$FSOtmFSdRx1CFLz%|^q)_4qwoRSMf-PTL^9OV!t%Jr5GMC2#T0ju*M44;ye^x*K8#M+ z;T{ahp_{a4dwv36G~+<_SS`Vt(A5#%*uw0}#v1{?mG5}I^~(1JcuM&ei5A@h>Xxl< zId9d+;5y&8jPJIp@tg1!I+?u>{je!?V}y(Wkw`dN7AdP_x^>0^yuUEsqPKnS!RO4_ zkgdfkERHUwyt&M9wgbFR{}rzuqPM+Ygcqp1!?5i}?*9pQE{}w($UD-Cd7O2+Cbx{Y z=xy&C@cEdZjljmx4HDT}N11rPMc&>R$x|5Shf4mQ3ux=;w7VZUY1LK^d`z8pyKG;A zE(OoWFFQZ>T)KF`_yOJ38{8E;km!=-=`MXg3N=o^AiW#$#C9; zd4zkpmvtj(2YI#Y^Z{)7*Zgc2rgJ*-SG{AHQFHENv>CPi4y+Vi0TH6Tu1FcrW~w=Y z2hZ!uXv8VTdqm~+3M{I8<1n4qEtA)}x|Y5HpQpn1zK+1g&olzHw8y?f zdo!GDi%iDy9^wWAPP0tRM4IIXuqHO%Nbv4%fA~WqjAwym|w^{L6md5!e{ILF3CCkk#C- zIFkXskBsk-^8FH?QNEk78gwOu)Y$bk@4Y+k9?ba_n3PuZc8$-#OH?gtVH?c9^86*+ zFXRkt3h=#Sd{@xgzPI2jD(7+7FuIKKb&nm-u4e*#5C07>-+tx$2yMUe-G+tH6i;r?mN*_^P&6 zU=8TT-~rb)ZLxBBV#+g!fOcN~n{M55!m6E@;Ek%C7;KBVvh@GkPB<*>w4mGkDSwTw z!*)-3jG@;_KE>R~4$1c*M!&;70(0*O%b{;GzaC|7WlmkoV?5;GPjk^3Cz`*>XOQ@N zm*o=t^S{ONG48!6aE#Y9eh<*(6~O3_^ZdPzZyh)3{k*Qz&Jfx$wv*2p&HotB9MHzO zNQ?5CM|*MDJepNHM)GO$CB(bt8Rqum`2M)#Y@mseYdyYswQ|(>_GT`)^{t2NI=pRs zH`KNDb@8qC*JcU z{?5U?UxBeEr$l+5ftM)nTG&P%W47g87LK@mxjDf5UE{sXHg?~*4PTD=*#v9|oxG1^ z^6n}NS6z^_=K{Qs&U<;Mm3Iz2t-N>Hz64zg@5A`@o=-$0t;|Dc^>Msld<&HC_u&P~ zw*j^ybK-u^t>t0w+)f2}e`vh*JpLYho*TJba?75B)36NQbHBLXNASKN#a;^Ve)jKr z^=M%m+vjO`i}DV^3eYuM-t>_-#JiYvx+YhQcdqj8hUY5pBy67RW_h#I>f3JKg8|;3 zdftbo)V20q_*A8TUlw33=;B1yo0_=i&3b@$&epe%KVcF(R8c^}a0kePaRMUl?!E+urx!W6FCK7DpE& zWK3l~PxHYS=Rtt)OMlPnBQ?sm5?-TxOJO^;_%a5VoO?OflQQ4dvrgCJZR5ScHg?X} z;Tusu8-|UeTjzYr)&p<19+LsykBs+_^8OM&q`Ws_HRuNM4moZ;TFcDmaGiyNdA|Zf zmQ$j<&%jHRcP(s#IkDL9?T%5+0p9N#?`5{JYjPXDT<&KRupx9yoKO3CM_wm4XD-0| z=->D9PAl&mcv^YyvV94Z>toqzyz`a!tMGi~-3VJ|emsWv{;~ANsdRw% zN5*@OZS0!751)(p*$ga=ZkBV^^zBRq@0=|K_&)ogcdoZ6->2a%$~Ocn5FL@NHRq#y zt{1aT*W-%u7QLNwH(WmdV6!A_p80Xgn%fyiE=aKl1H3==ymLfv``m?3F+kXC0oH%-^mdMK!nebI z)(@LPH@5GcDGyh(W{B*ZSNQ1 zG36bGZ8MLp6JNux_k3TJ;n1mv>w3Ixd^gbAzSrR!zu;%XuyJ%lM0RYY|G8u9WPtZ0 z<1KpI`%Czc^4^5iperG=Idl3rdjY;@|DjilQssLZUaEX+V4E-USs=pnJvkTMYriJe z>3ZBT-lDg2z6D?Sc|RM64Wr8trXCl|t7Gmp!)$=}k&nE*MQ?jQMw>b5XFIS`bOprz z{8-`S;h}RrXM78l?{b-u)5t8|71!qmgPhNsj{dw}-`##{7u&hNt8mG=~kk1sh% z%e$)F`|itPfcG>1$g79wZSSM-JmsAWD->VLo4)7HVMVOdHR&_nqPM*};WO%dX@M;; zH%=4gIy)R;;Oh_Y{n+@%E zy$owaSL^q7&&RC*?^FNStBL4s?+AQLoiin{T?UO!B0D~nn`bR`tkX65q45^I?fn*f zM0pRvCeRHN`*V8OnF#QGXnaSM?|pc<)Zf-RtPWk6FnvUxJ-PjQH^BSlyIws+Z|D3H zeD!zwSq!$toV#kxyBtT?yODLe9`6`$(c9i{!~2!@7;FSx|Gu|8A9B5C0=y6Z6EAPk z+uo1R7Ao&;n0yzYaNk?b8`nDzAD#C(<1KpI`xSUziJ!$`t6U#15T@RpRT0m-J;3_| z<1KpI`(1duI+jzgesoD9JEzMvlyvUJ0Pkm(y_$&L_C5;FQ{K6-LUbX^n`;r*yNGqV zCVj?R^tN{=eD^s&Yk@6reLPK=nzU9`c;7)B2=M;Ic#Gcl{s=yyyysy_bPfC7U5r!i zHT+tD_sM_i)kO5R_X&8d@-Bdd(Cu(7Yw{Lvx1U5=r)zS}c#Gclz6zgs!Ot?VX>{ZJ z-cl3Sdo;lNzVQ~l?fn_Nro_)yU=8SMh-^O*Z|Nrw0=!@P*sF=?ZSP8W>CgIEDQt(i zbA!kpOSxy4vAdpix+ZTMZ_(S{*WoM0el`pnM>j-d^QI5F*Q}EP-j9s8=xy&W;ib>} z*(R(8T?vuRTf8N2=bs(S`xO{#4vF6OJ_9dMdDp@=n46ag>a-(2O}4d3}`e_Kh|JahgOk-h(5XmJJt zd_OV1(`fB{KY}+D``J7!iLQaj_753byYXHN@ILv8R}0bG-Y4L-D(3=N2;C0%CDt)? 
[GIT binary patch literal data omitted]

Date: Wed, 15 Nov 2023 06:38:10 -0800
Subject: [PATCH 116/391] Enhance the sorting lambda to be strict weak ordering
 by using instruction ids generated by post order visiting of the module.
PiperOrigin-RevId: 582655900 --- third_party/xla/xla/client/lib/BUILD | 6 +- third_party/xla/xla/service/BUILD | 1 + third_party/xla/xla/service/copy_insertion.cc | 64 +++++++++++++++++-- 3 files changed, 63 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/client/lib/BUILD b/third_party/xla/xla/client/lib/BUILD index 134c8c47d6e115..f6c2fe2d91ee9e 100644 --- a/third_party/xla/xla/client/lib/BUILD +++ b/third_party/xla/xla/client/lib/BUILD @@ -208,10 +208,10 @@ cc_library( xla_test( name = "math_test", srcs = ["math_test.cc"], - backend_tags = { + tags = [ # Times out. - "ghostfish_iss": ["noasan"], - }, + "noasan", + ], deps = [ ":constants", ":math", diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 52d88a01978b81..0094c3e6336835 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4720,6 +4720,7 @@ cc_library( "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", diff --git a/third_party/xla/xla/service/copy_insertion.cc b/third_party/xla/xla/service/copy_insertion.cc index a49fb6914dc566..13f2bb2bdc5605 100644 --- a/third_party/xla/xla/service/copy_insertion.cc +++ b/third_party/xla/xla/service/copy_insertion.cc @@ -16,9 +16,11 @@ limitations under the License. #include "xla/service/copy_insertion.h" #include +#include #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" @@ -1191,6 +1193,18 @@ class CopyRemover { HloOrdering* ordering, bool check_live_range_ordering, const absl::flat_hash_set& execution_threads) : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) { + // Instruction indices based on post order traversal of computations and + // instructions. Used as an enhancement for getting strict weak ordering + // used for sorting below. + absl::flat_hash_map instruction_ids; + int64_t id = 0; + for (HloComputation* computation : module.MakeComputationPostOrder()) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + instruction_ids[instruction] = id++; + } + } + // Construct a list for each HLO buffer in the alias analysis. Maintain a // map from HloValue to the respective list element representing that // value. The map is used to construct the copy info map below. @@ -1241,8 +1255,48 @@ class CopyRemover { } std::vector values = buffer.values(); - absl::c_sort(values, [this](const HloValue* a, const HloValue* b) { - return ordering_->IsDefinedBefore(*a, *b); + absl::c_sort(values, [this, instruction_ids](const HloValue* a, + const HloValue* b) { + // IsDefinedBefore() is generally not strict weak ordering required by + // the sort algorithm, since a may not be comparable to b or c by + // IsDefinedBefore(), but b and c can be comparable. Such as in: + // if () { b = ...; c = b + 1; } else { a = ...; } + // or + // a = param(0) + param(1); b = param(2) + param(3); c = b + 1; + // So it fails the "incomparability being transitive" requirement by + // strict weak ordering. We enhance the ordering test by using + // instruction ids generated by post order visiting of the + // computations/instructions. 
All HloValue's are comparable and + // dependency (thus transitivity) is respected when hlo ordering cannot + // decide the order. + if (a == b) { + return false; + } + const bool a_has_smaller_id = + instruction_ids.at(a->defining_instruction()) < + instruction_ids.at(b->defining_instruction()); + // Use a_has_smaller_id as a hint for the order between a and b. In case + // it's right, there is no need for two IsDefinedBefore() tests. + if (a_has_smaller_id) { + // Test a is defined before b first. + if (ordering_->IsDefinedBefore(*a, *b)) { + return true; + } + if (ordering_->IsDefinedBefore(*b, *a)) { + return false; + } + } else { + // Test b is defined before a first. + if (ordering_->IsDefinedBefore(*b, *a)) { + return false; + } + if (ordering_->IsDefinedBefore(*a, *b)) { + return true; + } + } + + // Use post order as tie breaker. + return a_has_smaller_id; }); // Create a list containing all of the values in the buffer. @@ -1497,7 +1551,7 @@ class CopyRemover { // s_x will be ordered before the definition of d_1. To make sure the // copy elision is safe, the following code checks that this ordering is // valid --- in particular we check it is safe to order d_m ahead of all - // the liverages at and after x_{x+1}, and it is safe to order all uses + // the liverages at and after s_{x+1}, and it is safe to order all uses // of s_x before the definition of d_1, by checking the live range // constraints for each pair --- we cannot skip the later checks because // the live range ordering is not guranteed to be transitive --- while it @@ -1833,7 +1887,7 @@ class CopyRemover { } // namespace -// We add copies for all non-phi indices of the true and false computation +// We add copies for all phi indices of the true and false computation // roots, in order to resolve interference. We later rely on // RemoveUnnecessaryCopies to drop the unnecessary ones. Status CopyInsertion::AddCopiesForConditional( @@ -1844,7 +1898,7 @@ Status CopyInsertion::AddCopiesForConditional( TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional); if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(), conditional, &indices_to_copy)) { - VLOG(2) << "No copies necessary for kWhile instruction " + VLOG(2) << "No copies necessary for kConditional instruction " << conditional->name(); return OkStatus(); } From d1eebeeac94ee50cf3f06c8cefe234592df8a8ee Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Wed, 15 Nov 2023 07:35:22 -0800 Subject: [PATCH 117/391] Multi-threaded fusion priority evaluation. This improves the overall time we spend on HLO passes by ~14% (priority vs. priority). The other low-hanging fruit is caching of analysis results. We rerun it a lot right now: in each priority computation, we do it for the producer, for each consumer, and for each producer-consumer fusion. The first two can be cached. Similarly, in GpuPriorityFusion::ChooseKind. This is still a todo. 
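The parallel priority computation is a plain fan-out/join. A standalone sketch
of that pattern follows (the real change uses tsl::thread::ThreadPool and
tsl::BlockingCounter, as in the diff below; std::async stands in here so the
example is self-contained, and CalculatePriority is a placeholder for the
cost-model call):

```
#include <future>
#include <numeric>
#include <vector>

// Placeholder for CalculateProducerPriority(instruction).
int CalculatePriority(int instruction) { return instruction * instruction; }

std::vector<int> ComputePriorities(const std::vector<int>& instructions) {
  std::vector<std::future<int>> tasks;
  tasks.reserve(instructions.size());
  for (int instr : instructions) {
    // Each priority depends only on its own instruction, so the evaluations
    // can run concurrently.
    tasks.push_back(std::async(std::launch::async, CalculatePriority, instr));
  }
  std::vector<int> priorities;
  priorities.reserve(tasks.size());
  for (auto& task : tasks) priorities.push_back(task.get());  // Join.
  return priorities;
}

int main() {
  std::vector<int> instructions(8);
  std::iota(instructions.begin(), instructions.end(), 1);
  return ComputePriorities(instructions).back() == 64 ? 0 : 1;
}
```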
PiperOrigin-RevId: 582670292 --- third_party/xla/xla/service/gpu/BUILD | 4 ++ .../xla/xla/service/gpu/fusion_pipeline.cc | 12 ++-- .../xla/xla/service/gpu/fusion_pipeline.h | 3 + .../xla/xla/service/gpu/gpu_compiler.cc | 8 +-- .../xla/xla/service/gpu/priority_fusion.cc | 57 ++++++++++++++++--- .../xla/xla/service/gpu/priority_fusion.h | 13 +++-- .../xla/service/gpu/priority_fusion_test.cc | 2 +- 7 files changed, 78 insertions(+), 21 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index b760136fa34912..2bbdaa6deeb495 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2081,7 +2081,10 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@local_tsl//tsl/platform:blocking_counter", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", ], @@ -2719,6 +2722,7 @@ cc_library( "//xla/service:layout_assignment", "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/stream_executor:device_description", + "@local_tsl//tsl/platform:env", ], ) diff --git a/third_party/xla/xla/service/gpu/fusion_pipeline.cc b/third_party/xla/xla/service/gpu/fusion_pipeline.cc index 5b0914dbd47723..6554ca858bd910 100644 --- a/third_party/xla/xla/service/gpu/fusion_pipeline.cc +++ b/third_party/xla/xla/service/gpu/fusion_pipeline.cc @@ -43,6 +43,7 @@ namespace gpu { HloPassPipeline FusionPipeline( const DebugOptions& debug_options, HloCostAnalysis::ShapeSizeFunction shape_size_bytes_function, + tsl::thread::ThreadPool* thread_pool, const se::DeviceDescription& gpu_device_info) { HloPassFix fusion("fusion"); // We try to split variadic ops with many parameters into several such ops @@ -56,12 +57,13 @@ HloPassPipeline FusionPipeline( LayoutAssignment::InstructionCanChangeLayout)), "hlo verifier (debug)"); - GpuHloCostAnalysis::Options cost_analysis_options{ - shape_size_bytes_function, - /*per_second_rates=*/{}, - /*count_multiple_input_accesses=*/true}; if (debug_options.xla_gpu_enable_priority_fusion()) { - fusion.AddPass(gpu_device_info, cost_analysis_options); + GpuHloCostAnalysis::Options cost_analysis_options{ + shape_size_bytes_function, + /*per_second_rates=*/{}, + /*count_multiple_input_accesses=*/true}; + fusion.AddPass(thread_pool, gpu_device_info, + std::move(cost_analysis_options)); } else { fusion.AddPass(/*may_duplicate=*/false, gpu_device_info); diff --git a/third_party/xla/xla/service/gpu/fusion_pipeline.h b/third_party/xla/xla/service/gpu/fusion_pipeline.h index cecb1a98091270..bd2d7bb1f7ed27 100644 --- a/third_party/xla/xla/service/gpu/fusion_pipeline.h +++ b/third_party/xla/xla/service/gpu/fusion_pipeline.h @@ -20,14 +20,17 @@ limitations under the License. #include "xla/service/hlo_pass_pipeline.h" #include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" +#include "tsl/platform/threadpool.h" namespace xla { namespace gpu { // Function wrapper around the (non-horizontal) XLA GPU fusion pipeline. +// Thread pool may be nullptr. HloPassPipeline FusionPipeline( const DebugOptions& debug_options, HloCostAnalysis::ShapeSizeFunction shape_size_bytes_function, + tsl::thread::ThreadPool* thread_pool, const se::DeviceDescription& gpu_device_info); // Function wrapper around the horizontal XLA GPU fusion pipeline. 
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 9707f2fbf3e03d..0c4e6fe52701b2 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -835,10 +835,10 @@ Status GpuCompiler::OptimizeHloModule(HloModule* hlo_module, const se::DeviceDescription& gpu_device_info = gpu_target_config.device_description; - TF_RETURN_IF_ERROR( - FusionPipeline(debug_options, ShapeSizeBytesFunction(), gpu_device_info) - .Run(hlo_module) - .status()); + TF_RETURN_IF_ERROR(FusionPipeline(debug_options, ShapeSizeBytesFunction(), + thread_pool, gpu_device_info) + .Run(hlo_module) + .status()); if (debug_options.xla_gpu_collect_cost_model_stats()) { GpuHloCostAnalysis::Options cost_analysis_options{ diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 069514a435fa9d..8df808c1bfad82 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/priority_fusion.h" +#include #include #include #include @@ -31,6 +32,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/meta/type_traits.h" +#include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -46,6 +48,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/blocking_counter.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" @@ -75,23 +78,30 @@ class GpuPriorityFusionQueue : public FusionQueue { HloComputation* computation, const GpuHloCostAnalysis::Options& cost_analysis_options, const se::DeviceDescription* device_info, const CanFuseCallback& can_fuse, - FusionProcessDumpProto* fusion_process_dump) + FusionProcessDumpProto* fusion_process_dump, + tsl::thread::ThreadPool* thread_pool) : computation_(computation), cost_analysis_(cost_analysis_options, device_info), can_fuse_(can_fuse), - fusion_process_dump_(fusion_process_dump) { + fusion_process_dump_(fusion_process_dump), + thread_pool_(thread_pool) { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); TF_CHECK_OK(computation_->Accept(&cost_analysis_)); // Initializes the priority queue. 
- for (auto instruction : computation->MakeInstructionPostOrder()) { + std::vector instructions; + for (auto* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kParameter || instruction->user_count() == 0 || !instruction->IsFusible() || instruction->opcode() == HloOpcode::kTuple || instruction->opcode() == HloOpcode::kGetTupleElement) { continue; } - Priority priority = CalculateProducerPriority(instruction); + instructions.push_back(instruction); + } + std::vector priorities = ComputePriorities(instructions); + + for (auto [instruction, priority] : llvm::zip(instructions, priorities)) { auto emplace_result = producer_priority_queue_.emplace( std::make_pair(priority, instruction->unique_id()), instruction); CHECK(emplace_result.second); @@ -100,6 +110,28 @@ class GpuPriorityFusionQueue : public FusionQueue { } } + std::vector ComputePriorities( + const std::vector& instructions) { + auto schedule_or_run = [this](std::function fn) { + if (thread_pool_) { + thread_pool_->Schedule(std::move(fn)); + } else { + fn(); + } + }; + tsl::BlockingCounter counter(instructions.size()); + std::vector priorities(instructions.size()); + + for (size_t i = 0; i < instructions.size(); ++i) { + schedule_or_run([&, i] { + priorities[i] = CalculateProducerPriority(instructions[i]); + counter.DecrementCount(); + }); + } + counter.Wait(); + return priorities; + } + std::pair> DequeueNextInstructionAndOperandsToFuseInOrder() override { while (current_consumers_.empty()) { @@ -193,9 +225,14 @@ class GpuPriorityFusionQueue : public FusionQueue { TF_CHECK_OK(cost_analysis_.RevisitInstruction(instruction)); } - for (auto instruction : to_update_priority_) { + std::vector to_update_vector{to_update_priority_.begin(), + to_update_priority_.end()}; + std::vector new_priorities = + ComputePriorities(to_update_vector); + + for (auto [instruction, new_priority] : + llvm::zip(to_update_vector, new_priorities)) { auto reverse_it = reverse_map_.find(instruction); - const auto new_priority = CalculateProducerPriority(instruction); const auto new_key = std::make_pair(new_priority, instruction->unique_id()); if (reverse_it != reverse_map_.end()) { @@ -240,6 +277,7 @@ class GpuPriorityFusionQueue : public FusionQueue { if (auto fusion_decision = CanFuseWithAllUsers(producer); !fusion_decision) { if (fusion_process_dump_) { + absl::MutexLock lock(&fusion_process_dump_mutex_); auto* step = fusion_process_dump_->add_fusion_steps() ->mutable_producer_ineligible(); step->set_producer_name(std::string(producer->name())); @@ -253,6 +291,7 @@ class GpuPriorityFusionQueue : public FusionQueue { producer, &cost_analysis_, GpuPerformanceModelOptions::PriorityFusion(), producer->users()); if (fusion_process_dump_) { + absl::MutexLock lock(&fusion_process_dump_mutex_); auto* step = fusion_process_dump_->add_fusion_steps()->mutable_update_priority(); step->set_producer_name(std::string(producer->name())); @@ -323,6 +362,9 @@ class GpuPriorityFusionQueue : public FusionQueue { // Proto with structured logs of fusion decisions. Used only for debugging. If // null, logging is disabled. FusionProcessDumpProto* fusion_process_dump_; + absl::Mutex fusion_process_dump_mutex_; + + tsl::thread::ThreadPool* thread_pool_; }; } // namespace @@ -438,6 +480,7 @@ FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, // kernels, in which case we don't want to fuse. // TODO(b/119692968): Remove this once we have fixed our fusion emitter. 
if (consumer->opcode() == HloOpcode::kFusion) { + absl::MutexLock lock(&fusion_node_evaluations_mutex_); if (fusion_node_evaluations_.find(consumer) == fusion_node_evaluations_.end()) { // We have no cached results for this fusion node yet. Compute it now. @@ -501,7 +544,7 @@ std::unique_ptr GpuPriorityFusion::GetFusionQueue( [this](HloInstruction* consumer, int64_t operand_index) { return ShouldFuse(consumer, operand_index); }, - fusion_process_dump_.get())); + fusion_process_dump_.get(), thread_pool_)); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/priority_fusion.h b/third_party/xla/xla/service/gpu/priority_fusion.h index af24b6e9c688f9..1723766d7784c8 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/priority_fusion.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -33,18 +34,20 @@ limitations under the License. #include "xla/service/hlo_pass_interface.h" #include "xla/service/instruction_fusion.h" #include "xla/statusor.h" +#include "tsl/platform/threadpool.h" namespace xla { namespace gpu { class GpuPriorityFusion : public InstructionFusion { public: - explicit GpuPriorityFusion( - const se::DeviceDescription& d, - const GpuHloCostAnalysis::Options& cost_analysis_options) + GpuPriorityFusion(tsl::thread::ThreadPool* thread_pool, + const se::DeviceDescription& d, + GpuHloCostAnalysis::Options cost_analysis_options) : InstructionFusion(GpuPriorityFusion::IsExpensive), + thread_pool_(thread_pool), device_info_(d), - cost_analysis_options_(cost_analysis_options) {} + cost_analysis_options_(std::move(cost_analysis_options)) {} absl::string_view name() const override { return "priority-fusion"; } @@ -68,6 +71,7 @@ class GpuPriorityFusion : public InstructionFusion { HloInstruction* FuseInstruction(HloInstruction* fusion_instruction, HloInstruction* producer) override; + tsl::thread::ThreadPool* thread_pool_; se::DeviceDescription device_info_; // Cost model options that defines priorities in the queue. @@ -79,6 +83,7 @@ class GpuPriorityFusion : public InstructionFusion { // Keep track of the number of times each instruction inside a fusion node is // indexed with different index vectors. 
+ absl::Mutex fusion_node_evaluations_mutex_; absl::flat_hash_map fusion_node_evaluations_; }; diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/priority_fusion_test.cc index 8198a6e46f0a08..5574173a75ef11 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion_test.cc @@ -69,7 +69,7 @@ class PriorityFusionTest : public HloTestBase { } GpuPriorityFusion priority_fusion_{ - TestGpuDeviceInfo::RTXA6000DeviceInfo(), + /*thread_pool=*/nullptr, TestGpuDeviceInfo::RTXA6000DeviceInfo(), GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(), /*per_second_rates=*/{}, /*count_multiple_input_accesses=*/true}}; From bd613db37d552ae703c1e543b141a2c1c92d3351 Mon Sep 17 00:00:00 2001 From: Chris Kennelly Date: Wed, 15 Nov 2023 08:33:23 -0800 Subject: [PATCH 118/391] Internal Code Change PiperOrigin-RevId: 582687138 --- tensorflow/core/grappler/costs/graph_properties.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 4342ec637492ec..84b33460db1b03 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -841,7 +841,6 @@ class SymbolicShapeRefiner { } int output_port_num = input_tensor.index(); - AttrValue attr_output_shape; TensorShapeProto proto; const auto handle = input_ic->output(output_port_num); input_ic->ShapeHandleToProto(handle, &proto); From 6f50f70c873d60a84baf35957190d8e8ccff5dd6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 15 Nov 2023 08:48:17 -0800 Subject: [PATCH 119/391] Add a private API to allow setting layouts on jitted computations. We expose 3 modes: * `SpecifiedLayout`: User specifies the `minor_to_major` field of the layout. Tiling not exposed yet. * `DefaultLayout`: PJRT chooses the layout. It defaults to the current behavior. * `AUTO`: Compiler chooses the layout. This field is not a layout per se. It's a request to get the layout from the compiler. This field cannot be on an Array or other data types. It can only be on jit. Public API coming soon. Co-authored-by: Roy Frostig PiperOrigin-RevId: 582692036 --- third_party/xla/xla/python/xla_client.py | 3 ++- third_party/xla/xla/python/xla_compiler.cc | 4 ++++ third_party/xla/xla/python/xla_extension/__init__.pyi | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index 6239cce3b38351..1a5c3fff7c4bcb 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -48,7 +48,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 213 +_version = 215 # Version number for MLIR:Python components. 
mlir_api_version = 54 @@ -864,3 +864,4 @@ def heap_profile(client: Client) -> bytes: copy_array_to_devices_with_sharding = _xla.copy_array_to_devices_with_sharding batched_device_put = _xla.batched_device_put check_and_canonicalize_memory_kind = _xla.check_and_canonicalize_memory_kind +Layout = _xla.Layout diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc index d27bc48fb1899c..d92c2ca020e6a5 100644 --- a/third_party/xla/xla/python/xla_compiler.cc +++ b/third_party/xla/xla/python/xla_compiler.cc @@ -288,6 +288,10 @@ void BuildXlaCompilerSubmodule(py::module& m) { // Shapes py::class_ layout_class(m, "Layout"); layout_class + .def(py::init([](py::object minor_to_major) { + return std::make_unique( + py::cast>(minor_to_major)); + })) .def("minor_to_major", [](Layout layout) { return SpanToTuple(layout.minor_to_major()); }) .def("__eq__", [](const Layout& layout, diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi index 19752e1c593903..5267109c42bfec 100644 --- a/third_party/xla/xla/python/xla_extension/__init__.pyi +++ b/third_party/xla/xla/python/xla_extension/__init__.pyi @@ -73,6 +73,7 @@ class PrimitiveType(enum.IntEnum): # === BEGIN xla_compiler.cc class Layout: + def __init__(self, minor_to_major: Tuple[int, ...]): ... def minor_to_major(self) -> Tuple[int, ...]: ... def to_string(self) -> str: ... def __eq__(self, other: Layout) -> bool: ... From 9b667896d82022406fa6594eae1aff04be7493b3 Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Wed, 15 Nov 2023 09:43:07 -0800 Subject: [PATCH 120/391] Split out the Java and JNI code for TestInit.java into separate build targets. PiperOrigin-RevId: 582709953 --- tensorflow/lite/java/BUILD | 40 +++++++++++++++++----- tensorflow/lite/java/src/test/native/BUILD | 19 ++++++++-- 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index 3c0f0265236d5b..3f9fe7fea2a364 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -465,6 +465,14 @@ java_library( ], ) +java_library_with_tflite( + name = "test_init", + testonly = True, + srcs = [ + "src/test/java/org/tensorflow/lite/TestInit.java", + ], +) + #----------------------------------------------------------------------------- # java_library targets that also include native code dependencies. @@ -516,7 +524,6 @@ java_test_with_tflite( size = "small", srcs = [ "src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java", - "src/test/java/org/tensorflow/lite/TestInit.java", ], javacopts = JAVACOPTS, # We want to ensure that every test case in the test also verifies that the @@ -532,6 +539,9 @@ java_test_with_tflite( "v1only", ], test_class = "org.tensorflow.lite.TensorFlowLiteTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", ], @@ -597,7 +607,6 @@ java_test_with_tflite( size = "small", srcs = [ "src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java", - "src/test/java/org/tensorflow/lite/TestInit.java", ], data = [ # The files named as .bin reshape the incoming tensor from (2, 8, 8, 3) to (2, 4, 4, 12). 
@@ -613,6 +622,9 @@ java_test_with_tflite( ], javacopts = JAVACOPTS, test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", ], @@ -631,7 +643,6 @@ java_test_with_tflite( srcs = [ "src/test/java/org/tensorflow/lite/InterpreterTest.java", "src/test/java/org/tensorflow/lite/SupportedFeatures.java", - "src/test/java/org/tensorflow/lite/TestInit.java", "src/test/java/org/tensorflow/lite/TestUtils.java", ], data = [ @@ -646,6 +657,9 @@ java_test_with_tflite( ], javacopts = JAVACOPTS, test_class = "org.tensorflow.lite.InterpreterTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", ], @@ -663,7 +677,6 @@ java_test_with_tflite( srcs = [ "src/test/java/org/tensorflow/lite/InterpreterApiTest.java", "src/test/java/org/tensorflow/lite/SupportedFeatures.java", - "src/test/java/org/tensorflow/lite/TestInit.java", "src/test/java/org/tensorflow/lite/TestUtils.java", ], data = [ @@ -677,6 +690,9 @@ java_test_with_tflite( ], javacopts = JAVACOPTS, test_class = "org.tensorflow.lite.InterpreterApiTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_stable_test_jni.so", ], @@ -695,7 +711,6 @@ java_test_with_tflite( srcs = [ "src/test/java/org/tensorflow/lite/InterpreterApiNoRuntimeTest.java", "src/test/java/org/tensorflow/lite/SupportedFeatures.java", - "src/test/java/org/tensorflow/lite/TestInit.java", "src/test/java/org/tensorflow/lite/TestUtils.java", ], data = [ @@ -703,6 +718,9 @@ java_test_with_tflite( ], javacopts = JAVACOPTS, test_class = "org.tensorflow.lite.InterpreterApiNoRuntimeTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_stable_test_jni.so", ], @@ -720,7 +738,6 @@ java_test_with_tflite( srcs = [ "src/test/java/org/tensorflow/lite/NnApiDelegateNativeTest.java", "src/test/java/org/tensorflow/lite/SupportedFeatures.java", - "src/test/java/org/tensorflow/lite/TestInit.java", "src/test/java/org/tensorflow/lite/TestUtils.java", ], data = [ @@ -728,6 +745,9 @@ java_test_with_tflite( ], tags = ["no_mac"], test_class = "org.tensorflow.lite.NnApiDelegateNativeTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", ], @@ -745,7 +765,6 @@ java_test_with_tflite( size = "small", srcs = [ "src/test/java/org/tensorflow/lite/SupportedFeatures.java", - "src/test/java/org/tensorflow/lite/TestInit.java", "src/test/java/org/tensorflow/lite/TestUtils.java", "src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java", ], @@ -755,6 +774,9 @@ java_test_with_tflite( javacopts = JAVACOPTS, tags = ["no_mac"], test_class = "org.tensorflow.lite.nnapi.NnApiDelegateTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ "//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", ], @@ -797,7 +819,6 @@ java_test_with_tflite( size = "small", srcs = [ "src/test/java/org/tensorflow/lite/TensorTest.java", - "src/test/java/org/tensorflow/lite/TestInit.java", ], data = [ "src/testdata/add.bin", @@ -808,6 +829,9 @@ java_test_with_tflite( ], javacopts = JAVACOPTS, test_class = "org.tensorflow.lite.TensorTest", + tflite_deps = [ + ":test_init", + ], tflite_jni_binaries = [ 
"//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", ], diff --git a/tensorflow/lite/java/src/test/native/BUILD b/tensorflow/lite/java/src/test/native/BUILD index 2c32fc618331a8..db20aafcd2d0f8 100644 --- a/tensorflow/lite/java/src/test/native/BUILD +++ b/tensorflow/lite/java/src/test/native/BUILD @@ -30,10 +30,9 @@ cc_library_with_tflite( "interpreter_test_jni.cc", "nnapi_delegate_test_jni.cc", "supported_features_jni.cc", - "test_init_jni.cc", ], tflite_deps = [ - "//tensorflow/lite/c:test_util", + ":test_init_jni", "//tensorflow/lite/delegates/nnapi/java/src/main/native", "//tensorflow/lite/java/src/main/native", "//tensorflow/lite/java/src/main/native:jni_utils", @@ -51,6 +50,22 @@ cc_library_with_tflite( alwayslink = 1, ) +cc_library_with_tflite( + name = "test_init_jni", + testonly = 1, + srcs = [ + "test_init_jni.cc", + ], + tflite_deps = [ + "//tensorflow/lite/java/src/main/native:jni_utils", + "//tensorflow/lite/c:test_util", + ], + deps = [ + "//tensorflow/lite/java/jni", + ], + alwayslink = 1, +) + # Same as "native", but excluding dependencies on experimental features. cc_library_with_tflite( name = "native_stable", From 68b404991a10bc8ae70042c64899e54a5f01981a Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Wed, 15 Nov 2023 09:54:34 -0800 Subject: [PATCH 121/391] Remove deprecation for target "for_generated_wrappers_v2". PiperOrigin-RevId: 582713647 --- tensorflow/python/framework/BUILD | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 1217581e0cd647..38b5ad2ac9344d 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -258,13 +258,14 @@ py_strict_library( ], ) -# What is needed for tf_gen_op_wrapper_py. This is the same as -# "for_generated_wrappers" minus the "function" dep. This is to avoid -# circular dependencies, as "function" uses generated op wrappers. +# This rule should only be depended on by tf_gen_op_wrapper_py. +# Do not depend on this rule! Depend on the fine-grained sub-targets instead. +# This is the same as "for_generated_wrappers" minus the "function" dep. +# This is to avoid circular dependencies, as "function" uses generated op wrappers. py_strict_library( name = "for_generated_wrappers_v2", - deprecation = "Depending on this target can cause build dependency cycles. Depend on the fine-grained sub-targets instead.", srcs_version = "PY3", + tags = ["avoid_dep"], visibility = ["//visibility:public"], deps = [ ":byte_swap_tensor", From 1b7cc437e03b657592d9044fdd8529546cdf5d82 Mon Sep 17 00:00:00 2001 From: Marcello Maggioni Date: Wed, 15 Nov 2023 10:02:51 -0800 Subject: [PATCH 122/391] [XLA] Fix typo in cost model in dot_handler. Clearly the parenthesis is not in the right place ... 
PiperOrigin-RevId: 582716345 --- third_party/xla/xla/service/spmd/dot_handler.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/spmd/dot_handler.cc b/third_party/xla/xla/service/spmd/dot_handler.cc index 289a443ffb722c..2f487f2b90e805 100644 --- a/third_party/xla/xla/service/spmd/dot_handler.cc +++ b/third_party/xla/xla/service/spmd/dot_handler.cc @@ -2347,9 +2347,9 @@ GetNonContractingPartitionGroupedShardingForOtherOperand( other_group_dims = std::move(*found_dims); } else if (may_replicate_other_contracting_dims && (!may_replicate_other_non_contracting_dims || - ShapeUtil::ByteSizeOf(other_shape)) <= - ShapeUtil::ByteSizeOf(MakePartitionedShape( - output_base_shape, output_sharding))) { + ShapeUtil::ByteSizeOf(other_shape) <= + ShapeUtil::ByteSizeOf(MakePartitionedShape( + output_base_shape, output_sharding)))) { for (const auto& dim : other_contracting_dims) { other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs); } From 70cb2616c5d1050620231a544b7b78c084460ec1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 10:17:05 -0800 Subject: [PATCH 123/391] Integrate LLVM at llvm/llvm-project@b88308b1b481 Updates LLVM usage to match [b88308b1b481](https://github.com/llvm/llvm-project/commit/b88308b1b481) PiperOrigin-RevId: 582721429 --- third_party/llvm/generated.patch | 119 ++++++++++++++++++++++++++----- third_party/llvm/workspace.bzl | 4 +- 2 files changed, 102 insertions(+), 21 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index a37125c400d30a..502a7c7cd1863e 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,20 +1,101 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -@@ -1191,6 +1191,7 @@ - name = "exp10f_impl", - hdrs = ["src/math/generic/exp10f_impl.h"], - deps = [ -+ ":__support_fputil_basic_operations", - ":__support_fputil_fma", - ":__support_fputil_multiply_add", - ":__support_fputil_nearest_integer", -@@ -1206,6 +1207,7 @@ - name = "exp2f_impl", - hdrs = ["src/math/generic/exp2f_impl.h"], - deps = [ -+ ":__support_fputil_except_value_utils", - ":__support_fputil_fma", - ":__support_fputil_multiply_add", - ":__support_fputil_nearest_integer", +diff -ruN --strip-trailing-cr a/clang/test/Analysis/builtin_signbit.cpp b/clang/test/Analysis/builtin_signbit.cpp +--- a/clang/test/Analysis/builtin_signbit.cpp ++++ b/clang/test/Analysis/builtin_signbit.cpp +@@ -5,6 +5,7 @@ + // RUN: -O0 %s -o - | FileCheck %s --check-prefixes=CHECK-BE64 + // RUN: %clang -target powerpc64le-linux-gnu -emit-llvm -S -mabi=ibmlongdouble \ + // RUN: -O0 %s -o - | FileCheck %s --check-prefixes=CHECK-LE ++// REQUIRES: asserts + + bool b; + double d = -1.0; +diff -ruN --strip-trailing-cr a/lldb/test/Shell/SymbolFile/DWARF/ignored_artificial_fields.test b/lldb/test/Shell/SymbolFile/DWARF/ignored_artificial_fields.test +--- a/lldb/test/Shell/SymbolFile/DWARF/ignored_artificial_fields.test ++++ b/lldb/test/Shell/SymbolFile/DWARF/ignored_artificial_fields.test +@@ -1,4 +1,5 @@ + # UNSUPPORTED: system-darwin, system-windows ++# REQUIRES: gcc + + # Make sure the artifical field `vptr.ClassName` from gcc debug info is ignored. 
+ # RUN: %build --compiler=gcc %S/Inputs/debug-types-expressions.cpp -o %t +diff -ruN --strip-trailing-cr a/llvm/lib/Support/CommandLine.cpp b/llvm/lib/Support/CommandLine.cpp +--- a/llvm/lib/Support/CommandLine.cpp ++++ b/llvm/lib/Support/CommandLine.cpp +@@ -1667,13 +1667,6 @@ + Handler = LookupLongOption(*ChosenSubCommand, ArgName, Value, + LongOptionsUseDoubleDash, HaveDoubleDash); + +- // If Handler is not found in a specialized subcommand, look up handler +- // in the top-level subcommand. +- // cl::opt without cl::sub belongs to top-level subcommand. +- if (!Handler && ChosenSubCommand != &SubCommand::getTopLevel()) +- Handler = LookupLongOption(SubCommand::getTopLevel(), ArgName, Value, +- LongOptionsUseDoubleDash, HaveDoubleDash); +- + // Check to see if this "option" is really a prefixed or grouped argument. + if (!Handler && !(LongOptionsUseDoubleDash && HaveDoubleDash)) + Handler = HandlePrefixedOrGroupedOption(ArgName, Value, ErrorParsing, +diff -ruN --strip-trailing-cr a/llvm/unittests/Support/CommandLineTest.cpp b/llvm/unittests/Support/CommandLineTest.cpp +--- a/llvm/unittests/Support/CommandLineTest.cpp ++++ b/llvm/unittests/Support/CommandLineTest.cpp +@@ -525,59 +525,6 @@ + EXPECT_FALSE(Errs.empty()); + } + +-TEST(CommandLineTest, TopLevelOptInSubcommand) { +- enum LiteralOptionEnum { +- foo, +- bar, +- baz, +- }; +- +- cl::ResetCommandLineParser(); +- +- // This is a top-level option and not associated with a subcommand. +- // A command line using subcommand should parse both subcommand options and +- // top-level options. A valid use case is that users of llvm command line +- // tools should be able to specify top-level options defined in any library. +- cl::opt TopLevelOpt("str", cl::init("txt"), +- cl::desc("A top-level option.")); +- +- StackSubCommand SC("sc", "Subcommand"); +- StackOption PositionalOpt( +- cl::Positional, cl::desc("positional argument test coverage"), +- cl::sub(SC)); +- StackOption LiteralOpt( +- cl::desc("literal argument test coverage"), cl::sub(SC), cl::init(bar), +- cl::values(clEnumVal(foo, "foo"), clEnumVal(bar, "bar"), +- clEnumVal(baz, "baz"))); +- StackOption EnableOpt("enable", cl::sub(SC), cl::init(false)); +- StackOption ThresholdOpt("threshold", cl::sub(SC), cl::init(1)); +- +- const char *PositionalOptVal = "input-file"; +- const char *args[] = {"prog", "sc", PositionalOptVal, +- "-enable", "--str=csv", "--threshold=2"}; +- +- // cl::ParseCommandLineOptions returns true on success. Otherwise, it will +- // print the error message to stderr and exit in this setting (`Errs` ostream +- // is not set). +- ASSERT_TRUE(cl::ParseCommandLineOptions(sizeof(args) / sizeof(args[0]), args, +- StringRef())); +- EXPECT_STREQ(PositionalOpt.getValue().c_str(), PositionalOptVal); +- EXPECT_TRUE(EnableOpt); +- // Tests that the value of `str` option is `csv` as specified. +- EXPECT_STREQ(TopLevelOpt.getValue().c_str(), "csv"); +- EXPECT_EQ(ThresholdOpt, 2); +- +- for (auto &[LiteralOptVal, WantLiteralOpt] : +- {std::pair{"--bar", bar}, {"--foo", foo}, {"--baz", baz}}) { +- const char *args[] = {"prog", "sc", LiteralOptVal}; +- ASSERT_TRUE(cl::ParseCommandLineOptions(sizeof(args) / sizeof(args[0]), +- args, StringRef())); +- +- // Tests that literal options are parsed correctly. 
+- EXPECT_EQ(LiteralOpt, WantLiteralOpt); +- } +-} +- + TEST(CommandLineTest, AddToAllSubCommands) { + cl::ResetCommandLineParser(); + diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 4bd62858469ca5..d462b582b7076b 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "5d6304f01742a0a7c628fe6850e921c745eaea08" - LLVM_SHA256 = "5230a6bd323dc27893adf688ce1769854e0b92d8ce2f4d14ac62b9a200a1e452" + LLVM_COMMIT = "b88308b1b4813e55ce8f54ceff6e57736328fb58" + LLVM_SHA256 = "fe27af49a5596d91929b989b40cb2879829914802ad45b6d947ca5e070ec20d4" tf_http_archive( name = name, From 2e822f07ea297345819b1d57465fc438f2f1c4d1 Mon Sep 17 00:00:00 2001 From: Changhui Lin Date: Wed, 15 Nov 2023 10:21:31 -0800 Subject: [PATCH 124/391] Remove an unused variable. PiperOrigin-RevId: 582723150 --- third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 9a7a5d96788031..e5be3dc6de5dfd 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -721,8 +721,6 @@ GetStreamExecutorGpuDeviceAllocator( case GpuAllocatorConfig::Kind::kDefault: case GpuAllocatorConfig::Kind::kBFC: { LOG(INFO) << "Using BFC allocator."; - std::vector executors; - executors.reserve(addressable_devices.size()); std::vector allocators_and_streams; for (const auto& ordinal_and_device : addressable_devices) { From 4bf6a7a57c110e0f57c71b82873cc831dd00fdc5 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 15 Nov 2023 11:01:26 -0800 Subject: [PATCH 125/391] [xla:ffi] Add API for accessing platform specific stream in XLA FFI handlers This should allow to get access to CUDA streams in FFI handlers. 
``` static Status Handler(Custream stream) { } Ffi::Bind() .Ctx() .To(Handler); `` PiperOrigin-RevId: 582737147 --- third_party/xla/xla/ffi/api/c_api.h | 30 +++++++++++++++++ third_party/xla/xla/ffi/api/ffi.h | 52 +++++++++++++++++++++++++++-- third_party/xla/xla/ffi/ffi.cc | 23 +++++++++++++ 3 files changed, 103 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h index 55d92827558218..3e96b8035296a0 100644 --- a/third_party/xla/xla/ffi/api/c_api.h +++ b/third_party/xla/xla/ffi/api/c_api.h @@ -117,6 +117,16 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Error_Create_Args, errc); typedef XLA_FFI_Error* XLA_FFI_Error_Create(XLA_FFI_Error_Create_Args* args); +struct XLA_FFI_Error_Destroy_Args { + size_t struct_size; + void* priv; + XLA_FFI_Error* error; +}; + +XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Error_Destroy_Args, error); + +typedef void XLA_FFI_Error_Destroy(XLA_FFI_Error_Destroy_Args* args); + //===----------------------------------------------------------------------===// // Builtin argument types //===----------------------------------------------------------------------===// @@ -228,6 +238,24 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Handler_Register_Args, handler); typedef XLA_FFI_Error* XLA_FFI_Handler_Register( XLA_FFI_Handler_Register_Args* args); +//===----------------------------------------------------------------------===// +// Stream +//===----------------------------------------------------------------------===// + +struct XLA_FFI_Stream_Get_Args { + size_t struct_size; + void* priv; + + XLA_FFI_ExecutionContext* ctx; + void* stream; // out +}; + +XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Stream_Get_Args, stream); + +// Returns an underling platform-specific stream via out argument, i.e. for CUDA +// platform it returns `CUstream` (same as `cudaStream`). +typedef XLA_FFI_Error* XLA_FFI_Stream_Get(XLA_FFI_Stream_Get_Args* args); + //===----------------------------------------------------------------------===// // API access //===----------------------------------------------------------------------===// @@ -241,7 +269,9 @@ struct XLA_FFI_Api { XLA_FFI_InternalApi* internal_api; _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Create); + _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Destroy); _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Handler_Register); + _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Stream_Get); }; #undef _XLA_FFI_API_STRUCT_FIELD diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index 1985864be1e3bf..9a4882e656c22b 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -20,11 +20,59 @@ limitations under the License. #error Two different XLA FFI implementations cannot be included together #endif // XLA_FFI_API_H_ +#include +#include + +#include "xla/ffi/api/c_api.h" + // IWYU pragma: begin_exports #include "xla/ffi/api/api.h" // IWYU pragma: end_exports -// TODO(ezhulenev): Implement FFI arguments and attributes decoding for external -// FFI users without any dependencies on absl or other libraries. +namespace xla::ffi { + +namespace internal { +// TODO(ezhulenev): We need to log error message somewhere, currently we +// silently destroy it. 
+inline void DestroyError(const XLA_FFI_Api* api, XLA_FFI_Error* error) { + XLA_FFI_Error_Destroy_Args destroy_args; + destroy_args.struct_size = XLA_FFI_Error_Destroy_Args_STRUCT_SIZE; + destroy_args.priv = nullptr; + destroy_args.error = error; + api->XLA_FFI_Error_Destroy(&destroy_args); +} +} // namespace internal + +//===----------------------------------------------------------------------===// +// PlatformStream +//===----------------------------------------------------------------------===// + +template +struct PlatformStream {}; + +template +struct CtxDecoding> { + using Type = T; + + static_assert(std::is_pointer_v, "stream type must be a pointer"); + + static std::optional Decode(const XLA_FFI_Api* api, + XLA_FFI_ExecutionContext* ctx) { + XLA_FFI_Stream_Get_Args args; + args.struct_size = XLA_FFI_Stream_Get_Args_STRUCT_SIZE; + args.priv = nullptr; + args.ctx = ctx; + args.stream = nullptr; + + if (XLA_FFI_Error* error = api->XLA_FFI_Stream_Get(&args); error) { + internal::DestroyError(api, error); + return std::nullopt; + } + + return reinterpret_cast(args.stream); + } +}; + +} // namespace xla::ffi #endif // XLA_FFI_API_FFI_H_ diff --git a/third_party/xla/xla/ffi/ffi.cc b/third_party/xla/xla/ffi/ffi.cc index 9899010c41f8c2..6f2b02cda3867f 100644 --- a/third_party/xla/xla/ffi/ffi.cc +++ b/third_party/xla/xla/ffi/ffi.cc @@ -173,6 +173,16 @@ static XLA_FFI_Error* XLA_FFI_Error_Create(XLA_FFI_Error_Create_Args* args) { return new XLA_FFI_Error{Status(ToStatusCode(args->errc), args->message)}; } +static void XLA_FFI_Error_Destroy(XLA_FFI_Error_Destroy_Args* args) { + Status struct_size_check = ActualStructSizeIsGreaterOrEqual( + "XLA_FFI_Error_Destroy", XLA_FFI_Error_Destroy_Args_STRUCT_SIZE, + args->struct_size); + if (!struct_size_check.ok()) { + LOG(ERROR) << struct_size_check.message(); + } + delete args->error; +} + static XLA_FFI_Error* XLA_FFI_Handler_Register( XLA_FFI_Handler_Register_Args* args) { XLA_FFI_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( @@ -185,6 +195,17 @@ static XLA_FFI_Error* XLA_FFI_Handler_Register( return nullptr; } +static XLA_FFI_Error* XLA_FFI_Stream_Get(XLA_FFI_Stream_Get_Args* args) { + XLA_FFI_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "XLA_FFI_Stream_Get", XLA_FFI_Stream_Get_Args_STRUCT_SIZE, + args->struct_size)); + + auto handle = args->ctx->run_options->stream()->platform_specific_handle(); + args->stream = handle.stream; + + return nullptr; +} + //===----------------------------------------------------------------------===// // XLA FFI Internal Api Implementation //===----------------------------------------------------------------------===// @@ -214,7 +235,9 @@ static XLA_FFI_Api api = { &internal_api, XLA_FFI_Error_Create, // creates error + XLA_FFI_Error_Destroy, // frees error XLA_FFI_Handler_Register, // registers handler + XLA_FFI_Stream_Get, // returns platform specific stream }; XLA_FFI_Api* GetXlaFfiApi() { return &api; } From 2c0a7a10e3c306edbdc592ae93ea0dad114a36a7 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 15 Nov 2023 11:04:09 -0800 Subject: [PATCH 126/391] [XLA] hlo-opt tool: conversions from HLO to optimized HLO/LLVM/PTX/TTIR etc PiperOrigin-RevId: 582738218 --- third_party/xla/xla/runlit.cfg.py | 1 + third_party/xla/xla/runlit.site.cfg.py | 1 + third_party/xla/xla/tools/BUILD | 8 + third_party/xla/xla/tools/hlo_opt/BUILD | 119 +++++++++++ third_party/xla/xla/tools/hlo_opt/gpu_hlo.hlo | 12 ++ .../xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo | 17 ++ .../xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo | 17 ++ 
third_party/xla/xla/tools/hlo_opt/gpu_opt.cc | 108 ++++++++++ .../xla/tools/hlo_opt/gpu_specs/a100.txtpb | 42 ++++ third_party/xla/xla/tools/hlo_opt/opt_lib.cc | 51 +++++ third_party/xla/xla/tools/hlo_opt/opt_lib.h | 53 +++++ third_party/xla/xla/tools/hlo_opt/opt_main.cc | 195 ++++++++++++++++++ 12 files changed, 624 insertions(+) create mode 100644 third_party/xla/xla/tools/hlo_opt/BUILD create mode 100644 third_party/xla/xla/tools/hlo_opt/gpu_hlo.hlo create mode 100644 third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo create mode 100644 third_party/xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo create mode 100644 third_party/xla/xla/tools/hlo_opt/gpu_opt.cc create mode 100644 third_party/xla/xla/tools/hlo_opt/gpu_specs/a100.txtpb create mode 100644 third_party/xla/xla/tools/hlo_opt/opt_lib.cc create mode 100644 third_party/xla/xla/tools/hlo_opt/opt_lib.h create mode 100644 third_party/xla/xla/tools/hlo_opt/opt_main.cc diff --git a/third_party/xla/xla/runlit.cfg.py b/third_party/xla/xla/runlit.cfg.py index 28d411422eb865..1c28ed57e8b3c8 100644 --- a/third_party/xla/xla/runlit.cfg.py +++ b/third_party/xla/xla/runlit.cfg.py @@ -97,6 +97,7 @@ 'xla-translate', 'xla-translate-gpu-opt', 'xla-translate-opt', + 'hlo-opt', ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/third_party/xla/xla/runlit.site.cfg.py b/third_party/xla/xla/runlit.site.cfg.py index 6c932591cac135..5fc38b581a0baf 100644 --- a/third_party/xla/xla/runlit.site.cfg.py +++ b/third_party/xla/xla/runlit.site.cfg.py @@ -48,6 +48,7 @@ "service/mlir_gpu", "translate", "translate/mhlo_to_lhlo_with_xla", + "tools", ] config.mlir_tf_tools_dirs = [ os.path.join(real_test_srcdir, os.environ["TEST_WORKSPACE"], xla_root_dir, diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index 057facad20511a..2f23d3b68b8717 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -282,6 +282,14 @@ xla_cc_binary( ], ) +xla_cc_binary( + name = "hlo-opt", + testonly = True, + deps = [ + "//xla/tools/hlo_opt:opt_main", + ], +) + cc_library( name = "hlo_expand_main", srcs = ["hlo_expand_main.cc"], diff --git a/third_party/xla/xla/tools/hlo_opt/BUILD b/third_party/xla/xla/tools/hlo_opt/BUILD new file mode 100644 index 00000000000000..1052ea229ceb65 --- /dev/null +++ b/third_party/xla/xla/tools/hlo_opt/BUILD @@ -0,0 +1,119 @@ +load("//xla:glob_lit_test.bzl", "glob_lit_tests") +load( + "//xla/stream_executor:build_defs.bzl", + "if_gpu_is_configured", +) +load("@local_tsl//tsl:tsl.default.bzl", "filegroup") +load( + "@local_tsl//tsl/platform:build_config_root.bzl", + "tf_cuda_tests_tags", +) +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") + +# hlo-opt tool. +load( + "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +# Includes a macro to register a provider. 
+cc_library( + name = "opt_lib", + srcs = ["opt_lib.cc"], + hdrs = ["opt_lib.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla:debug_options_flags", + "//xla:statusor", + "//xla:types", + "//xla/hlo/ir:hlo", + "//xla/service:compiler", + "//xla/stream_executor:platform", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "gpu_opt", + testonly = True, + srcs = if_cuda_is_configured(["gpu_opt.cc"]), + visibility = ["//visibility:public"], + deps = [ + ":opt_lib", + "//xla:debug_options_flags", + "//xla:statusor", + "//xla:types", + "//xla/service:compiler", + "//xla/service:platform_util", + "//xla/service/gpu:executable_proto_cc", + "//xla/stream_executor/cuda:cuda_platform_id", + "//xla/stream_executor/platform", + "@com_google_absl//absl/container:flat_hash_map", + ] + if_gpu_is_configured([ + "//xla/service:gpu_plugin", + "//xla/service/gpu:gpu_executable", + ]) + if_cuda_is_configured([ + "//xla/stream_executor:cuda_platform", + ]), + alwayslink = True, # Initializer needs to run. +) + +cc_library( + name = "opt_main", + testonly = True, + srcs = ["opt_main.cc"], + visibility = ["//visibility:public"], + deps = [ + ":opt_lib", + "//xla:debug_options_flags", + "//xla:status", + "//xla:statusor", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_runner", + "//xla/service:platform_util", + "//xla/tools:hlo_module_loader", + "//xla/tools:run_hlo_module_lib", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/util:command_line_flags", + ] + if_gpu_is_configured([ + ":gpu_opt", + ]) + if_cuda_is_configured([ + "//xla/stream_executor:cuda_platform", + ]), +) + +glob_lit_tests( + name = "gpu_opt_tests", + data = [":test_utilities"], + default_tags = tf_cuda_tests_tags() + [ + ], + driver = "//xla:run_lit.sh", + test_file_exts = ["hlo"], +) + +# Bundle together all of the test utilities that are used by tests. 
+filegroup( + name = "test_utilities", + testonly = True, + data = [ + "gpu_specs/a100.txtpb", + "//xla/tools:hlo-opt", + "@llvm-project//llvm:FileCheck", + "@llvm-project//mlir:run_lit.sh", + ], + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_hlo.hlo b/third_party/xla/xla/tools/hlo_opt/gpu_hlo.hlo new file mode 100644 index 00000000000000..c170f21bb31942 --- /dev/null +++ b/third_party/xla/xla/tools/hlo_opt/gpu_hlo.hlo @@ -0,0 +1,12 @@ +// RUN: hlo-opt %s --platform=CUDA --stage=hlo --xla_gpu_target_config_filename=%S/gpu_specs/a100.txtpb | FileCheck %s + +HloModule module + +ENTRY computation { +// CHECK: bitcast + p = f32[5000,6000]{1,0} parameter(0) + e = f32[5000,6000]{1,0} sqrt(p) + c = f32[6000,5000] transpose(p), dimensions={1,0} + r = f32[300,20,5000] reshape(c) + ROOT out = (f32[5000,6000], f32[300,20,5000]) tuple(e,r) +} diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo new file mode 100644 index 00000000000000..6abb61fcc51d9b --- /dev/null +++ b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo @@ -0,0 +1,17 @@ +// RUN: hlo-opt %s --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=%S/gpu_specs/a100.txtpb | FileCheck %s + +HloModule m, is_scheduled=true + +add { + a = f16[] parameter(0) + b = f16[] parameter(1) + ROOT out = f16[] add(a, b) +} + + +// CHECK: load half +ENTRY e { + p1 = f16[1048576] parameter(0) + i = f16[] constant(0) + ROOT out = f16[] reduce(p1, i), dimensions={0}, to_apply=add +} diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo new file mode 100644 index 00000000000000..9ed67ee237db0e --- /dev/null +++ b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo @@ -0,0 +1,17 @@ +// RUN: hlo-opt %s --platform=CUDA --stage=ptx --xla_gpu_target_config_filename=%S/gpu_specs/a100.txtpb | FileCheck %s --dump-input-filter=all + +HloModule m, is_scheduled=true + +add { + a = f16[] parameter(0) + b = f16[] parameter(1) + ROOT out = f16[] add(a, b) +} + + +// CHECK: shfl.sync.down +ENTRY e { + p1 = f16[1048576] parameter(0) + i = f16[] constant(0) + ROOT out = f16[] reduce(p1, i), dimensions={0}, to_apply=add +} diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc new file mode 100644 index 00000000000000..422bccdfdd81cf --- /dev/null +++ b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc @@ -0,0 +1,108 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "xla/debug_options_flags.h" +#include "xla/service/compiler.h" +#include "xla/service/gpu/executable.pb.h" +#include "xla/service/gpu/gpu_executable.h" +#include "xla/service/platform_util.h" +#include "xla/statusor.h" +#include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/platform/initialize.h" +#include "xla/tools/hlo_opt/opt_lib.h" +#include "xla/types.h" + +namespace xla { + +namespace { + +// TODO(cheshire): Switch CUDA/ROCM +static auto kGpuPlatformId = se::cuda::kCudaPlatformId; + +static StatusOr> ToGpuExecutable( + std::unique_ptr module, Compiler* compiler, + se::StreamExecutor* executor, const Compiler::CompileOptions& opts) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr optimized_module, + compiler->RunHloPasses(std::move(module), executor, opts)); + DebugOptions d = optimized_module->config().debug_options(); + d.set_xla_embed_ir_in_executable(true); + optimized_module->mutable_config().set_debug_options(d); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + compiler->RunBackend(std::move(optimized_module), executor, opts)); + return executable; +} + +struct GpuOptProvider : public OptProvider { + StatusOr> GenerateStage( + std::unique_ptr module, absl::string_view s) override { + TF_ASSIGN_OR_RETURN( + se::Platform * platform, + se::MultiPlatformManager::PlatformWithId(kGpuPlatformId)); + + TF_ASSIGN_OR_RETURN(Compiler * compiler, + Compiler::GetForPlatform(platform)); + DebugOptions debug_opts = GetDebugOptionsFromFlags(); + + Compiler::CompileOptions opts; + + se::StreamExecutor* executor = nullptr; + if (debug_opts.xla_gpu_target_config_filename().empty()) { + TF_ASSIGN_OR_RETURN(std::vector stream_executors, + PlatformUtil::GetStreamExecutors( + platform, /*allowed_devices=*/std::nullopt)); + executor = stream_executors[0]; + } + + if (s == "hlo") { + TF_ASSIGN_OR_RETURN( + std::unique_ptr optimized_module, + compiler->RunHloPasses(std::move(module), executor, opts)); + return optimized_module->ToString(); + } else if (s == "llvm") { + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + ToGpuExecutable(std::move(module), compiler, executor, opts)); + return static_cast(executable.get()) + ->ir_module_string(); + } else if (s == "ptx") { + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + ToGpuExecutable(std::move(module), compiler, executor, opts)); + return static_cast(executable.get())->text(); + } + + // Unimplemented stage. + return std::nullopt; + } + + std::vector SupportedStages() override { + return {"hlo", "llvm", "ptx"}; + } +}; + +} // namespace +} // namespace xla + +REGISTER_MODULE_INITIALIZER(gpu_opt_provider, { + xla::OptProvider::RegisterForPlatform( + xla::kGpuPlatformId, std::make_unique()); +}); diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_specs/a100.txtpb b/third_party/xla/xla/tools/hlo_opt/gpu_specs/a100.txtpb new file mode 100644 index 00000000000000..864125066c3ae6 --- /dev/null +++ b/third_party/xla/xla/tools/hlo_opt/gpu_specs/a100.txtpb @@ -0,0 +1,42 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +gpu_device_info { + cuda_compute_capability { + major: 8 + minor: 0 + } + threads_per_block_limit: 1024 + threads_per_warp: 32 + shared_memory_per_block: 65536 + shared_memory_per_block_optin: 65536 + shared_memory_per_core: 65536 + threads_per_core_limit: 2048 + core_count: 6192 + fpus_per_core: 64 + block_dim_limit_x: 2147483647 + block_dim_limit_y: 65535 + block_dim_limit_z: 65535 + memory_bandwidth: 2039000000000 + l2_cache_size: 4194304 + clock_rate_ghz: 1.1105 + device_memory_size: 79050250240 +} +platform_name: "CUDA" +dnn_version_info { + major: 8 + minor: 3 + patch: 2 +} +device_description_str: "A100 80GB" diff --git a/third_party/xla/xla/tools/hlo_opt/opt_lib.cc b/third_party/xla/xla/tools/hlo_opt/opt_lib.cc new file mode 100644 index 00000000000000..3bf19f43004f2d --- /dev/null +++ b/third_party/xla/xla/tools/hlo_opt/opt_lib.cc @@ -0,0 +1,51 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tools/hlo_opt/opt_lib.h" + +#include "absl/container/flat_hash_map.h" +#include "xla/types.h" + +namespace xla { + +using ProviderMap = + absl::flat_hash_map>; +static absl::Mutex provider_mu(absl::kConstInit); + +static ProviderMap& GetProviderMap() { + static auto& provider_map = *new ProviderMap(); + return provider_map; +} + +/*static*/ void OptProvider::RegisterForPlatform( + se::Platform::Id platform, + std::unique_ptr translate_provider) { + absl::MutexLock l(&provider_mu); + CHECK(!GetProviderMap().contains(platform)); + GetProviderMap()[platform] = std::move(translate_provider); +} + +/*static*/ OptProvider* OptProvider::ProviderForPlatform( + se::Platform::Id platform) { + absl::MutexLock l(&provider_mu); + auto it = GetProviderMap().find(platform); + if (it == GetProviderMap().end()) { + return nullptr; + } + + return it->second.get(); +} + +} // namespace xla diff --git a/third_party/xla/xla/tools/hlo_opt/opt_lib.h b/third_party/xla/xla/tools/hlo_opt/opt_lib.h new file mode 100644 index 00000000000000..29ae5ba73c96d0 --- /dev/null +++ b/third_party/xla/xla/tools/hlo_opt/opt_lib.h @@ -0,0 +1,53 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TOOLS_HLO_OPT_OPT_LIB_H_ +#define XLA_TOOLS_HLO_OPT_OPT_LIB_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/compiler.h" +#include "xla/statusor.h" +#include "xla/stream_executor/platform.h" +#include "xla/types.h" + +namespace xla { + +// Platform-specific provider of `hlo_translate` functionality. +struct OptProvider { + // Generates textual output for a given stage on a given platform, returns + // empty optional if the stage is not supported. + virtual StatusOr> GenerateStage( + std::unique_ptr module, absl::string_view stage) = 0; + + virtual ~OptProvider() = default; + + virtual std::vector SupportedStages() = 0; + + static void RegisterForPlatform( + se::Platform::Id platform, + std::unique_ptr translate_provider); + + static OptProvider* ProviderForPlatform(se::Platform::Id platform); +}; + +} // namespace xla + +#endif // XLA_TOOLS_HLO_OPT_OPT_LIB_H_ diff --git a/third_party/xla/xla/tools/hlo_opt/opt_main.cc b/third_party/xla/xla/tools/hlo_opt/opt_main.cc new file mode 100644 index 00000000000000..4b00cf4b25934a --- /dev/null +++ b/third_party/xla/xla/tools/hlo_opt/opt_main.cc @@ -0,0 +1,195 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A tool for reading a HloModule from a HloProto file and execute the module on +// given platform(s). See kUsage for details. + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_runner.h" +#include "xla/service/platform_util.h" +#include "xla/statusor.h" +#include "xla/tools/hlo_module_loader.h" +#include "xla/tools/hlo_opt/opt_lib.h" +#include "xla/tools/run_hlo_module.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/init_main.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/path.h" +#include "tsl/platform/status.h" +#include "tsl/util/command_line_flags.h" + +namespace { +const char* const kUsage = R"( +This tool lets you run a given HloModule from a file (or stdin) and convert it +to expanded HLO, fully optimized HLO, or a binary depending on options. + +You can also pass in debug option flags for the HloModule. + +Usage: + + bazel run opt -- --platform=[CUDA|CPU|Interpreter|...] 
path/to/hlo_module +)"; + +struct HloOptConfig { + // Optional flags. + bool help{false}; + bool split_input_file{false}; + std::string platform{"cuda"}; + std::string input_file{""}; + std::string input_format{""}; + std::string output_file{"-"}; + std::string stage{"optimized_hlo"}; + std::string input_stage{"hlo"}; + bool list_stages{false}; +}; + +} // namespace + +namespace xla { + +namespace { + +std::string GetHloPath(const HloOptConfig& opts, int argc, char** argv) { + if (!opts.input_file.empty()) { + return opts.input_file; + } + QCHECK(argc == 2) << "Must specify a single input file"; + return argv[1]; +} + +StatusOr GetHloContents(const HloOptConfig& opts, int argc, + char** argv) { + std::string hlo_path = GetHloPath(opts, argc, argv); + if (hlo_path == "-") { + std::string stdin; + std::getline(std::cin, stdin, static_cast(EOF)); + return stdin; + } + + std::string data; + TF_RETURN_IF_ERROR( + tsl::ReadFileToString(tsl::Env::Default(), hlo_path, &data)); + return data; +} + +StatusOr> GetModule(const HloOptConfig& opts, + int argc, char** argv) { + TF_ASSIGN_OR_RETURN(std::string module_data, + GetHloContents(opts, argc, argv)); + + std::string format = opts.input_format; + if (format.empty()) { + format = std::string(tsl::io::Extension(GetHloPath(opts, argc, argv))); + } + return LoadModuleFromData(module_data, format); +} + +StatusOr TranslateToStage(int argc, char** argv, + const HloOptConfig& opts) { + se::Platform* platform = + xla::PlatformUtil::GetPlatform(opts.platform).value(); + + OptProvider* provider = OptProvider::ProviderForPlatform(platform->id()); + if (provider == nullptr) { + return absl::UnimplementedError( + absl::StrCat("Provider not found for platform: ", platform->Name())); + } + + if (opts.list_stages) { + return absl::StrJoin(provider->SupportedStages(), "\n"); + } + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + GetModule(opts, argc, argv)); + + TF_ASSIGN_OR_RETURN(std::optional out, + provider->GenerateStage(std::move(module), opts.stage)); + + if (!out.has_value()) { + return absl::UnimplementedError("Stage not supported"); + } + + return *out; +} + +Status RunOpt(int argc, char** argv, const HloOptConfig& opts) { + TF_ASSIGN_OR_RETURN(std::string output, TranslateToStage(argc, argv, opts)); + if (opts.output_file == "-") { + std::cout << output << std::endl; + } else { + TF_RETURN_IF_ERROR( + tsl::WriteStringToFile(tsl::Env::Default(), opts.output_file, output)); + } + return OkStatus(); +} + +} // namespace +} // namespace xla + +// gpu_device_config_filename: Probably deserves it's own flag? Since in here it +// will affect more top-level logic? +int main(int argc, char** argv) { + HloOptConfig opts; + std::vector flag_list = { + tsl::Flag("o", &opts.output_file, + "Output filename, or '-' for stdout (default)."), + tsl::Flag("platform", &opts.platform, + "The platform for which we perform the translation"), + tsl::Flag("format", &opts.input_format, + "The format of the input file. By default inferred from the " + "filename. Valid values:\n" + "\t\t\t hlo : HLO textual format\n" + "\t\t\t pb : xla::HloProto in binary proto format\n" + "\t\t\t pbtxt : xla::HloProto in text proto format"), + tsl::Flag("stage", &opts.stage, + "Output stage to dump. 
" + "Valid values depend on the platform, for GPUs:\n" + "\t\t\t * hlo : HLO after all optimizations\n" + "\t\t\t * llvm : LLVM IR\n" + "\t\t\t * ptx : PTX dump\n"), + tsl::Flag("list-stages", &opts.list_stages, + "Print all supported stages for a given platform and exit")}; + // Modifies global DebugOptions, populates flags with every flag available + // from xla.proto. + xla::AppendDebugOptionsFlags(&flag_list); + // The usage string includes the message at the top of the file, the + // DebugOptions flags and the flags defined above. + const std::string kUsageString = + absl::StrCat(kUsage, "\n\n", tsl::Flags::Usage(argv[0], flag_list)); + bool parse_ok = tsl::Flags::Parse(&argc, argv, flag_list); + tsl::port::InitMain(kUsageString.c_str(), &argc, &argv); + + if (!parse_ok) { + LOG(QFATAL) << kUsageString; + } + + xla::Status s = xla::RunOpt(argc, argv, opts); + if (!s.ok()) { + std::cerr << s; + return 1; + } + return 0; +} From 8b28fcf4e7b8c783e9a97905b383c4f4f64ce1d4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 11:15:20 -0800 Subject: [PATCH 127/391] Allows adding a large `FunctionDef` to context in eager mode. Introduces `TFE_ContextAddFunctionDefNoSerialization`, which is similar to `TFE_ContextAddFunctionDef`, only without serialization when passing the `FunctionDef` protobuf to C++. PiperOrigin-RevId: 582741936 --- tensorflow/python/_pywrap_tfe.pyi | 1 + tensorflow/python/eager/context.py | 10 +++++++--- tensorflow/python/eager/wrap_function.py | 2 +- tensorflow/python/tfe_wrapper.cc | 25 ++++++++++++++++++++++++ 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/_pywrap_tfe.pyi b/tensorflow/python/_pywrap_tfe.pyi index 26d129cd2a8566..1385ae69244d58 100644 --- a/tensorflow/python/_pywrap_tfe.pyi +++ b/tensorflow/python/_pywrap_tfe.pyi @@ -179,6 +179,7 @@ def TFE_ClearScalarCache() -> object: ... def TFE_CollectiveOpsCheckPeerHealth(arg0: object, arg1: str, arg2: int) -> None: ... def TFE_ContextAddFunction(arg0: object, arg1: TF_Function) -> None: ... def TFE_ContextAddFunctionDef(arg0: object, arg1: str, arg2: int) -> None: ... +def TFE_ContextAddFunctionDefNoSerialization(ctx: object, function_def) -> None: ... def TFE_ContextCheckAlive(arg0: object, arg1: str) -> bool: ... def TFE_ContextClearCaches(arg0: object) -> None: ... def TFE_ContextClearExecutors(arg0: object) -> None: ... diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index aeda5a61594fd7..88614b43afb37f 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -1378,9 +1378,13 @@ def add_function_def(self, fdef): fdef: A FunctionDef protocol buffer message. """ self.ensure_initialized() - fdef_string = fdef.SerializeToString() - pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string, - len(fdef_string)) + if is_oss: + fdef_string = fdef.SerializeToString() + pywrap_tfe.TFE_ContextAddFunctionDef( + self._handle, fdef_string, len(fdef_string) + ) + else: + pywrap_tfe.TFE_ContextAddFunctionDefNoSerialization(self._handle, fdef) def get_function_def(self, name): """Get a function definition from the context. 
diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index 65228aeb2bbe19..5a641aba2da70f 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -224,7 +224,7 @@ def __init__(self, fn_graph, variable_holder, attrs=None, signature=None): _lift_unlifted_variables(fn_graph, variable_holder) # We call __init__ after lifting variables so that the function's signature # properly reflects the new captured inputs. - for f in fn_graph.as_graph_def().library.function: + for f in fn_graph.as_graph_def(use_pybind11_proto=True).library.function: context.context().add_function_def(f) self._signature = signature function_type = function_type_lib.from_structured_signature( diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index cae983b25dfb02..21fecca23371bc 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -858,6 +858,23 @@ PYBIND11_MODULE(_pywrap_tfe, m) { // TODO(b/309152522): Remove the switch once it works on Windows. #if !IS_OSS pybind11_protobuf::ImportNativeProtoCasters(); + m.def( + "TFE_ContextAddFunctionDefNoSerialization", + [](py::handle& ctx, tensorflow::FunctionDef function_def) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + // Annotate eager runtime construction context to the given + // `function_def` as an attribute. + tensorflow::AttrValue value; + SetAttrValue("kEagerRuntime", &value); + (*function_def.mutable_attr())["_construction_context"] = value; + status->status = tensorflow::unwrap(tensorflow::InputTFE_Context(ctx)) + ->AddFunctionDef(function_def); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return; + }, + pybind11::arg("ctx"), pybind11::arg("function_def")); + m.def("TFE_ContextGetFunctionDefNoSerialization", [](py::handle& ctx, const char* function_name) -> tensorflow::FunctionDef { @@ -885,6 +902,14 @@ PYBIND11_MODULE(_pywrap_tfe, m) { LOG(FATAL) << "This function cannot be called."; return -1; }); + m.def("TFE_ContextAddFunctionDefNoSerialization", + // Opensource fails whenever a protobuf is used as argument. The + // disrepency in the type is to make opensource tests pass. + [](py::handle& ctx, int function_def) { + LOG(FATAL) << "This function cannot be called."; + return -1; + }); + #endif m.def("TFE_ContextGetGraphDebugInfo", [](py::handle& ctx, const char* function_name, TF_Buffer& buf) { From db2c3e26a8157c791480b65c6b7ddc98d8705041 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 11:39:17 -0800 Subject: [PATCH 128/391] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/694731c33516aaac207ecd4378e7821a25dc86d7. PiperOrigin-RevId: 582750294 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 12ea2f27056a25..28ddd55af5174b 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. 
- TFRT_COMMIT = "fce7c27b3191264a8ed581e03900e094c793593a" - TFRT_SHA256 = "3fe89488c1f138c36e9b4b6a220fe5ea6ecd184163c20917b0ab7eb215d32979" + TFRT_COMMIT = "694731c33516aaac207ecd4378e7821a25dc86d7" + TFRT_SHA256 = "31fcb5b04a4c41ae13ca2e1b92d48cead97bfa6f6fa3f637e74dd0b596dbf195" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 12ea2f27056a25..28ddd55af5174b 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "fce7c27b3191264a8ed581e03900e094c793593a" - TFRT_SHA256 = "3fe89488c1f138c36e9b4b6a220fe5ea6ecd184163c20917b0ab7eb215d32979" + TFRT_COMMIT = "694731c33516aaac207ecd4378e7821a25dc86d7" + TFRT_SHA256 = "31fcb5b04a4c41ae13ca2e1b92d48cead97bfa6f6fa3f637e74dd0b596dbf195" tf_http_archive( name = "tf_runtime", From 39ba7cba3cbc339dcd0e37bfc948d3df56585290 Mon Sep 17 00:00:00 2001 From: Matt Kreileder Date: Wed, 15 Nov 2023 11:47:48 -0800 Subject: [PATCH 129/391] Remove 'experimental' note from C cancellation API. PiperOrigin-RevId: 582753190 --- tensorflow/lite/core/c/c_api.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tensorflow/lite/core/c/c_api.h b/tensorflow/lite/core/c/c_api.h index 38433325452e9f..5c9ae3962ff6a7 100644 --- a/tensorflow/lite/core/c/c_api.h +++ b/tensorflow/lite/core/c/c_api.h @@ -285,8 +285,6 @@ TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsAddRegistrationExternal( /// /// By default it is disabled and calling to `TfLiteInterpreterCancel` will /// return kTfLiteError. See `TfLiteInterpreterCancel`. -/// -/// \warning This is an experimental API and subject to change. TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterOptionsEnableCancellation( TfLiteInterpreterOptions* options, bool enable); @@ -457,8 +455,6 @@ TfLiteTensor* TfLiteInterpreterGetTensor(const TfLiteInterpreter* interpreter, /// /// Returns kTfLiteError if cancellation is not enabled via /// `TfLiteInterpreterOptionsEnableCancellation`. -/// -/// \warning This is an experimental API and subject to change. TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterCancel( const TfLiteInterpreter* interpreter); From 0a24283d294d65208a274c0e59a2c369588396c3 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 15 Nov 2023 12:31:12 -0800 Subject: [PATCH 130/391] Integrate LLVM at llvm/llvm-project@c5dd1bbcc37e Updates LLVM usage to match [c5dd1bbcc37e](https://github.com/llvm/llvm-project/commit/c5dd1bbcc37e) PiperOrigin-RevId: 582767215 --- third_party/llvm/generated.patch | 89 -------------------------------- third_party/llvm/workspace.bzl | 4 +- 2 files changed, 2 insertions(+), 91 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 502a7c7cd1863e..af1f3cebfc9100 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -10,92 +10,3 @@ diff -ruN --strip-trailing-cr a/clang/test/Analysis/builtin_signbit.cpp b/clang/ bool b; double d = -1.0; -diff -ruN --strip-trailing-cr a/lldb/test/Shell/SymbolFile/DWARF/ignored_artificial_fields.test b/lldb/test/Shell/SymbolFile/DWARF/ignored_artificial_fields.test ---- a/lldb/test/Shell/SymbolFile/DWARF/ignored_artificial_fields.test -+++ b/lldb/test/Shell/SymbolFile/DWARF/ignored_artificial_fields.test -@@ -1,4 +1,5 @@ - # UNSUPPORTED: system-darwin, system-windows -+# REQUIRES: gcc - - # Make sure the artifical field `vptr.ClassName` from gcc debug info is ignored. - # RUN: %build --compiler=gcc %S/Inputs/debug-types-expressions.cpp -o %t -diff -ruN --strip-trailing-cr a/llvm/lib/Support/CommandLine.cpp b/llvm/lib/Support/CommandLine.cpp ---- a/llvm/lib/Support/CommandLine.cpp -+++ b/llvm/lib/Support/CommandLine.cpp -@@ -1667,13 +1667,6 @@ - Handler = LookupLongOption(*ChosenSubCommand, ArgName, Value, - LongOptionsUseDoubleDash, HaveDoubleDash); - -- // If Handler is not found in a specialized subcommand, look up handler -- // in the top-level subcommand. -- // cl::opt without cl::sub belongs to top-level subcommand. -- if (!Handler && ChosenSubCommand != &SubCommand::getTopLevel()) -- Handler = LookupLongOption(SubCommand::getTopLevel(), ArgName, Value, -- LongOptionsUseDoubleDash, HaveDoubleDash); -- - // Check to see if this "option" is really a prefixed or grouped argument. - if (!Handler && !(LongOptionsUseDoubleDash && HaveDoubleDash)) - Handler = HandlePrefixedOrGroupedOption(ArgName, Value, ErrorParsing, -diff -ruN --strip-trailing-cr a/llvm/unittests/Support/CommandLineTest.cpp b/llvm/unittests/Support/CommandLineTest.cpp ---- a/llvm/unittests/Support/CommandLineTest.cpp -+++ b/llvm/unittests/Support/CommandLineTest.cpp -@@ -525,59 +525,6 @@ - EXPECT_FALSE(Errs.empty()); - } - --TEST(CommandLineTest, TopLevelOptInSubcommand) { -- enum LiteralOptionEnum { -- foo, -- bar, -- baz, -- }; -- -- cl::ResetCommandLineParser(); -- -- // This is a top-level option and not associated with a subcommand. -- // A command line using subcommand should parse both subcommand options and -- // top-level options. A valid use case is that users of llvm command line -- // tools should be able to specify top-level options defined in any library. 
-- cl::opt TopLevelOpt("str", cl::init("txt"), -- cl::desc("A top-level option.")); -- -- StackSubCommand SC("sc", "Subcommand"); -- StackOption PositionalOpt( -- cl::Positional, cl::desc("positional argument test coverage"), -- cl::sub(SC)); -- StackOption LiteralOpt( -- cl::desc("literal argument test coverage"), cl::sub(SC), cl::init(bar), -- cl::values(clEnumVal(foo, "foo"), clEnumVal(bar, "bar"), -- clEnumVal(baz, "baz"))); -- StackOption EnableOpt("enable", cl::sub(SC), cl::init(false)); -- StackOption ThresholdOpt("threshold", cl::sub(SC), cl::init(1)); -- -- const char *PositionalOptVal = "input-file"; -- const char *args[] = {"prog", "sc", PositionalOptVal, -- "-enable", "--str=csv", "--threshold=2"}; -- -- // cl::ParseCommandLineOptions returns true on success. Otherwise, it will -- // print the error message to stderr and exit in this setting (`Errs` ostream -- // is not set). -- ASSERT_TRUE(cl::ParseCommandLineOptions(sizeof(args) / sizeof(args[0]), args, -- StringRef())); -- EXPECT_STREQ(PositionalOpt.getValue().c_str(), PositionalOptVal); -- EXPECT_TRUE(EnableOpt); -- // Tests that the value of `str` option is `csv` as specified. -- EXPECT_STREQ(TopLevelOpt.getValue().c_str(), "csv"); -- EXPECT_EQ(ThresholdOpt, 2); -- -- for (auto &[LiteralOptVal, WantLiteralOpt] : -- {std::pair{"--bar", bar}, {"--foo", foo}, {"--baz", baz}}) { -- const char *args[] = {"prog", "sc", LiteralOptVal}; -- ASSERT_TRUE(cl::ParseCommandLineOptions(sizeof(args) / sizeof(args[0]), -- args, StringRef())); -- -- // Tests that literal options are parsed correctly. -- EXPECT_EQ(LiteralOpt, WantLiteralOpt); -- } --} -- - TEST(CommandLineTest, AddToAllSubCommands) { - cl::ResetCommandLineParser(); - diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index d462b582b7076b..d3a0e41d030015 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "b88308b1b4813e55ce8f54ceff6e57736328fb58" - LLVM_SHA256 = "fe27af49a5596d91929b989b40cb2879829914802ad45b6d947ca5e070ec20d4" + LLVM_COMMIT = "c5dd1bbcc37e8811e7c6050159014d084eac6438" + LLVM_SHA256 = "f374bf677707588fc07235215e2bf03e27dfb299c4b478306dc918099c60b583" tf_http_archive( name = name, From 4fa5f7bb88921e8ac416d7c86bedef55af671cd7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 12:47:38 -0800 Subject: [PATCH 131/391] Add --expt-extended-lambda and --expt-relaxed-constexpr options for NVCC compiler. PiperOrigin-RevId: 582772885 --- .../crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl | 2 +- .../crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl index 7303a1355b1b30..0da1d7b58f4bb0 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl @@ -263,7 +263,7 @@ def InvokeNvcc(argv, log=False): # may cause compilation failure or incorrect run time execution. # Use at your own risk. 
if USE_CLANG_COMPILER: - nvccopts += ' -allow-unsupported-compiler' + nvccopts += ' -allow-unsupported-compiler --expt-extended-lambda --expt-relaxed-constexpr ' if depfiles: # Generate the dependency file diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl index 7303a1355b1b30..0da1d7b58f4bb0 100755 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl @@ -263,7 +263,7 @@ def InvokeNvcc(argv, log=False): # may cause compilation failure or incorrect run time execution. # Use at your own risk. if USE_CLANG_COMPILER: - nvccopts += ' -allow-unsupported-compiler' + nvccopts += ' -allow-unsupported-compiler --expt-extended-lambda --expt-relaxed-constexpr ' if depfiles: # Generate the dependency file From a2c7b935eb8dd55025b303d9530ec408af14fc47 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 12:48:21 -0800 Subject: [PATCH 132/391] Allow CompilationEnvironments::ProcessNewEnvFn to return error status. PiperOrigin-RevId: 582773087 --- third_party/xla/xla/service/BUILD | 1 + third_party/xla/xla/service/compilation_environments.cc | 5 +++-- third_party/xla/xla/service/compilation_environments.h | 5 +++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 0094c3e6336835..9b2c0ff3dc079c 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -6921,6 +6921,7 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", # fixdeps: keep "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/compilation_environments.cc b/third_party/xla/xla/service/compilation_environments.cc index 2268268744b11f..48a6ea0f1cd0bf 100644 --- a/third_party/xla/xla/service/compilation_environments.cc +++ b/third_party/xla/xla/service/compilation_environments.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -253,8 +254,8 @@ Status CompilationEnvironments::AddEnvImpl( return tsl::errors::InvalidArgument( "Unknown compilation environment type: %s", descriptor.full_name()); } - std::unique_ptr processed_env = - process_new_env(std::move(env)); + TF_ASSIGN_OR_RETURN(std::unique_ptr processed_env, + process_new_env(std::move(env))); // Check for unknown fields const tsl::protobuf::UnknownFieldSet& unknown_fields = diff --git a/third_party/xla/xla/service/compilation_environments.h b/third_party/xla/xla/service/compilation_environments.h index d3778a96c8553f..34fb9edebb8a0f 100644 --- a/third_party/xla/xla/service/compilation_environments.h +++ b/third_party/xla/xla/service/compilation_environments.h @@ -46,8 +46,9 @@ namespace xla { // CompilationEnvironments is not thread-safe. 
class CompilationEnvironments { public: - using ProcessNewEnvFn = std::function( - std::unique_ptr)>; + using ProcessNewEnvFn = + std::function>( + std::unique_ptr)>; CompilationEnvironments() = default; CompilationEnvironments(const CompilationEnvironments& rhs) { *this = rhs; } From cf0dfaeedf59a57c53756e48d9fc5eb75260bac8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 12:49:52 -0800 Subject: [PATCH 133/391] Turn on clang+nvcc compiler for rbe_linux_cuda_nvcc config. PiperOrigin-RevId: 582773576 --- .bazelrc | 2 +- third_party/xla/.bazelrc | 2 +- third_party/xla/third_party/tsl/.bazelrc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.bazelrc b/.bazelrc index cf45e783817917..b11faac6bc96c0 100644 --- a/.bazelrc +++ b/.bazelrc @@ -527,6 +527,7 @@ build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_confi test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=cuda +build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 build:rbe_linux_cuda_nvcc --@local_xla//xla/python:enable_gpu=true build:rbe_linux_cuda_nvcc --@local_xla//xla/python:jax_cuda_pip_rpaths=true @@ -536,7 +537,6 @@ build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,s build:rbe_linux_cuda_nvcc --action_env=TF_CUDA_VERSION="12" build:rbe_linux_cuda_nvcc --action_env=TF_CUDNN_VERSION="8" build:rbe_linux_cuda_nvcc --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2" -build:rbe_linux_cuda_nvcc --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" build:rbe_linux_cuda_nvcc --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" build:rbe_linux_cuda_nvcc --config=rbe_linux diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index cf45e783817917..b11faac6bc96c0 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -527,6 +527,7 @@ build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_confi test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=cuda +build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 build:rbe_linux_cuda_nvcc --@local_xla//xla/python:enable_gpu=true build:rbe_linux_cuda_nvcc --@local_xla//xla/python:jax_cuda_pip_rpaths=true @@ -536,7 +537,6 @@ build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,s build:rbe_linux_cuda_nvcc --action_env=TF_CUDA_VERSION="12" build:rbe_linux_cuda_nvcc --action_env=TF_CUDNN_VERSION="8" build:rbe_linux_cuda_nvcc --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2" -build:rbe_linux_cuda_nvcc --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" build:rbe_linux_cuda_nvcc --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" build:rbe_linux_cuda_nvcc --config=rbe_linux diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index cf45e783817917..b11faac6bc96c0 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -527,6 
+527,7 @@ build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.16-clang_confi test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=cuda +build:rbe_linux_cuda_nvcc --action_env=TF_NVCC_CLANG="1" build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 build:rbe_linux_cuda_nvcc --@local_xla//xla/python:enable_gpu=true build:rbe_linux_cuda_nvcc --@local_xla//xla/python:jax_cuda_pip_rpaths=true @@ -536,7 +537,6 @@ build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,s build:rbe_linux_cuda_nvcc --action_env=TF_CUDA_VERSION="12" build:rbe_linux_cuda_nvcc --action_env=TF_CUDNN_VERSION="8" build:rbe_linux_cuda_nvcc --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.2" -build:rbe_linux_cuda_nvcc --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" build:rbe_linux_cuda_nvcc --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.2-cudnn8.9_config_cuda//crosstool:toolchain" build:rbe_linux_cuda_nvcc --config=rbe_linux From 4218ecacc8c86bd3cf73c013ff0932120c361f91 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 13:07:16 -0800 Subject: [PATCH 134/391] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/0bb14f40fe30f95a190e6b932cb3cc1ed7376d8a. PiperOrigin-RevId: 582780091 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 28ddd55af5174b..e17daad0c7dbcf 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "694731c33516aaac207ecd4378e7821a25dc86d7" - TFRT_SHA256 = "31fcb5b04a4c41ae13ca2e1b92d48cead97bfa6f6fa3f637e74dd0b596dbf195" + TFRT_COMMIT = "0bb14f40fe30f95a190e6b932cb3cc1ed7376d8a" + TFRT_SHA256 = "0c0544c04e42a3967382c358d72d1f9dd957c939d003d7778671ed73e404f753" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 28ddd55af5174b..e17daad0c7dbcf 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "694731c33516aaac207ecd4378e7821a25dc86d7" - TFRT_SHA256 = "31fcb5b04a4c41ae13ca2e1b92d48cead97bfa6f6fa3f637e74dd0b596dbf195" + TFRT_COMMIT = "0bb14f40fe30f95a190e6b932cb3cc1ed7376d8a" + TFRT_SHA256 = "0c0544c04e42a3967382c358d72d1f9dd957c939d003d7778671ed73e404f753" tf_http_archive( name = "tf_runtime", From 0228381427549d9dc0339b7645eea301aada347e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 13:13:13 -0800 Subject: [PATCH 135/391] When respecting user sharding annotations, consider all sharding strategies that match the annotation. 
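For illustration only (this example is not part of the change), take a dot whose output carries a user sharding annotation:

  HloModule example

  ENTRY e {
    p0 = f32[256,128]{1,0} parameter(0)
    p1 = f32[128,64]{1,0} parameter(1)
    ROOT d = f32[256,64]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,2]0,1,2,3}
  }

Several of the enumerated strategies can produce exactly this output sharding while differing in how the dot operands are sharded. Previously a single matching strategy was kept; now every strategy whose output sharding matches the annotation stays available to the solver.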
PiperOrigin-RevId: 582782018 --- .../xla/hlo/experimental/auto_sharding/BUILD | 1 + .../auto_sharding/auto_sharding.cc | 32 +++++++++++-------- .../auto_sharding/auto_sharding_cost_graph.h | 5 +-- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index 27df7e0fe86b17..d629719b8d01e9 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -133,6 +133,7 @@ cc_library( deps = [ ":auto_sharding_strategy", ":matrix", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index e7646ad59bf67f..bdf3c393b01df7 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1270,21 +1270,26 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( cluster_env.device_mesh_.num_elements())) { // Sharding provided by XLA users, we need to keep them. strategies->following = nullptr; - int32_t strategy_index = -1; + std::vector strategy_indices; for (size_t i = 0; i < strategies->leaf_vector.size(); i++) { if (strategies->leaf_vector[i].output_sharding == existing_sharding) { - strategy_index = i; + strategy_indices.push_back(i); } } - if (strategy_index >= 0) { - VLOG(1) << "Keeping strategy index: " << strategy_index; + if (!strategy_indices.empty()) { + VLOG(1) << "Keeping strategy indices: " + << spmd::ToString(strategy_indices); // Stores other strategies in the map, removes them in the vector and // only keeps the one we found. - ShardingStrategy found_strategy = - strategies->leaf_vector[strategy_index]; pretrimmed_strategy_map[strategies->node_idx] = strategies->leaf_vector; + std::vector new_leaf_vector; + for (int32_t found_strategy_index : strategy_indices) { + ShardingStrategy found_strategy = + strategies->leaf_vector[found_strategy_index]; + new_leaf_vector.push_back(found_strategy); + } strategies->leaf_vector.clear(); - strategies->leaf_vector.push_back(found_strategy); + strategies->leaf_vector = new_leaf_vector; } else { VLOG(1) << "Generate a new strategy based on user sharding."; std::string name = ToStringSimple(existing_sharding); @@ -1330,16 +1335,17 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( ShardingStrategy({name, existing_sharding, 0, 0, memory_cost, resharding_costs, input_shardings})); } - CHECK_EQ(strategies->leaf_vector.size(), 1); // If there is only one option for resharding, and the cost computed for // that option is kInfinityCost, set the cost to zero. This is okay // because there is only one option anyway, and having the costs set to // kInfinityCost is problematic for the solver. 
- for (auto& operand_resharding_costs : - strategies->leaf_vector[0].resharding_costs) { - if (operand_resharding_costs.size() == 1 && - operand_resharding_costs[0] >= kInfinityCost) { - operand_resharding_costs[0] = 0; + if (strategies->leaf_vector.size() == 1) { + for (auto& operand_resharding_costs : + strategies->leaf_vector[0].resharding_costs) { + if (operand_resharding_costs.size() == 1 && + operand_resharding_costs[0] >= kInfinityCost) { + operand_resharding_costs[0] = 0; + } } } } else if (!strategies->following) { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index 21808e3fc6eafc..6102683fe4fdf4 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "absl/types/span.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/matrix.h" @@ -113,8 +114,8 @@ class CostGraph { Matrix CreateEdgeCost(NodeIdx src_idx, NodeIdx dst_idx, size_t in_node_idx, StrategyVector* strategies, bool zero_cost = false) { - CHECK_GE(node_lens_.size(), src_idx); - CHECK_GE(node_lens_.size(), dst_idx); + CHECK_LT(src_idx, node_lens_.size()); + CHECK_LT(dst_idx, node_lens_.size()); Matrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]); for (NodeStrategyIdx k = 0; k < strategies->leaf_vector.size(); ++k) { const ShardingStrategy& strategy = strategies->leaf_vector[k]; From 3a7b0aa98b495dabade91e3b9169652084e2a03b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 13:15:26 -0800 Subject: [PATCH 136/391] [HloValueSemanticsAnalysis] Adds handler for RNG bit generator. 
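As a minimal sketch (not taken from the change itself), the new handler covers instructions like the one below, labeling the whole rng-bit-generator output, both the updated state and the generated bits, with HloValueSemanticLabel::kRandom:

  HloModule example

  ENTRY e {
    state = u64[2]{0} parameter(0)
    ROOT rng = (u64[2]{0}, u32[1024]{0}) rng-bit-generator(state), algorithm=rng_three_fry
  }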
PiperOrigin-RevId: 582782655 --- .../xla/xla/service/hlo_value_semantics_analysis.cc | 10 ++++++++++ .../xla/xla/service/hlo_value_semantics_analysis.h | 1 + 2 files changed, 11 insertions(+) diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc index dfb029a54194fa..2e7f98cf5f357f 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc @@ -1071,6 +1071,16 @@ Status HloValueSemanticsPropagation::HandleReplicaId( return OkStatus(); } +Status HloValueSemanticsPropagation::HandleRngBitGenerator( + HloInstruction* rng_bit_generator) { + const HloValueSemantics* semantics = analysis_->NewHloValueSemantics( + HloValueSemanticLabel::kRandom, {rng_bit_generator, {}}); + ShapeTree rbg_semantics_tree( + rng_bit_generator->shape(), semantics); + analysis_->SetHloValueSemantics(rng_bit_generator, rbg_semantics_tree); + return OkStatus(); +} + Status HloValueSemanticsPropagation::HandleClamp(HloInstruction* clamp) { const ShapeTree& operand_semantics = analysis_->GetInstructionSemantics(clamp->operand(1)); diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.h b/third_party/xla/xla/service/hlo_value_semantics_analysis.h index c0491079555a90..7037951eddaaf0 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.h +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.h @@ -240,6 +240,7 @@ class HloValueSemanticsPropagation : public DfsHloVisitorWithDefault { Status HandleAsyncDone(HloInstruction* async_done) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleDomain(HloInstruction* domain) override; + Status HandleRngBitGenerator(HloInstruction* rng_bit_generator) override; protected: HloValueSemantics CopySemantics(const HloValueSemantics& semantics) const; From d013ae4342c855e343332723035ab09211d08b58 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 15 Nov 2023 13:25:41 -0800 Subject: [PATCH 137/391] [xla:gpu] Add optional custom fusions pass to XLA:GPU pipeline PiperOrigin-RevId: 582785959 --- third_party/xla/xla/debug_options_flags.cc | 13 ++++ third_party/xla/xla/service/gpu/BUILD | 2 + .../xla/service/gpu/custom_fusion_rewriter.h | 3 +- .../xla/xla/service/gpu/gpu_compiler.cc | 39 +++++++---- third_party/xla/xla/service/gpu/kernels/BUILD | 14 ++++ .../gpu/kernels/custom_fusion_pattern.cc | 5 ++ .../gpu/kernels/custom_fusion_pattern.h | 19 ++++++ .../gpu/kernels/cutlass_gemm_fusion.cc | 67 +++++++++++++++++-- .../gpu/kernels/cutlass_gemm_fusion_test.cc | 20 +++--- .../gpu/kernels/cutlass_gemm_kernel.cu.cc | 18 ++++- third_party/xla/xla/xla.proto | 11 ++- 11 files changed, 178 insertions(+), 33 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index e7b4805aa2a34f..d46444f7910594 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -136,6 +136,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_enable_dumping(true); opts.set_xla_gpu_enable_xla_runtime_executable(true); + opts.set_xla_gpu_enable_custom_fusions(false); opts.set_xla_gpu_nccl_termination_timeout_seconds(-1); opts.set_xla_gpu_enable_shared_constants(true); @@ -1066,6 +1067,18 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_xla_runtime_executable), debug_options->xla_gpu_enable_xla_runtime_executable(), "Whether to 
enable XLA runtime for XLA:GPU backend")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_custom_fusions", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_custom_fusions), + debug_options->xla_gpu_enable_custom_fusions(), + "Whether to enable XLA custom fusions")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_custom_fusions_re", + string_setter_for(&DebugOptions::set_xla_gpu_enable_custom_fusions_re), + debug_options->xla_gpu_enable_custom_fusions_re(), + "Limits custom fusion only to fusions which match this regular " + "expression. Default is all custom fusions registerered in a current " + "process.")); flag_list->push_back( tsl::Flag("xla_gpu_enable_gpu2_runtime", bool_setter_for(&DebugOptions::set_xla_gpu_enable_gpu2_runtime), diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 2bbdaa6deeb495..757c3f313aae17 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2674,6 +2674,7 @@ cc_library( "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "//xla/service/gpu/kernels:custom_fusion_library", "//xla/service/gpu/kernels:custom_fusion_pattern", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -2769,6 +2770,7 @@ cc_library( ":compile_module_to_llvm_ir", ":conv_layout_normalization", ":copy_fusion", + ":custom_fusion_rewriter", ":dot_dimension_sorter", ":executable_proto_cc", ":fusion_merger", diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h index 2d7ce59207e3c0..1e5e643a4a4c13 100644 --- a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.h @@ -61,7 +61,8 @@ namespace xla::gpu { // class CustomFusionRewriter : public HloModulePass { public: - explicit CustomFusionRewriter(const CustomFusionPatternRegistry* patterns); + explicit CustomFusionRewriter(const CustomFusionPatternRegistry* patterns = + CustomFusionPatternRegistry::Default()); absl::string_view name() const override { return "custom-fusion-rewriter"; } diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 0c4e6fe52701b2..551a386c203de8 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -117,6 +117,7 @@ limitations under the License. #include "xla/service/gpu/compile_module_to_llvm_ir.h" #include "xla/service/gpu/conv_layout_normalization.h" #include "xla/service/gpu/copy_fusion.h" +#include "xla/service/gpu/custom_fusion_rewriter.h" #include "xla/service/gpu/dot_dimension_sorter.h" #include "xla/service/gpu/fusion_pipeline.h" #include "xla/service/gpu/fusion_wrapper.h" @@ -1053,6 +1054,17 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( }); pipeline.AddPass>(); + // Greedy pattern matching for custom fusions. We run it before Triton + // rewriter or a regular Gemm rewriter to be able to match compatible GEMMs + // before they matched into Triton gemm or a cuBLAS custom call. + // + // TODO(ezhulenev): This should be plugged into the cost model and fusion + // heuristic, so we can mix and match various Gemm implementations based + // on projected (measured) performance. + if (debug_options.xla_gpu_enable_custom_fusions()) { + pipeline.AddPass(); + } + // Rewrite GEMMs into custom calls. 
se::GpuComputeCapability gpu_version = gpu_target_config.device_description.gpu_compute_capability(); @@ -1500,19 +1512,19 @@ StatusOr GpuCompiler::CompileToTargetBinary( llvm_modules.size()); tsl::BlockingCounter counter(llvm_modules.size()); for (int i = 0; i < llvm_modules.size(); i++) { - thread_pool->Schedule( - [&compile_results, i, &llvm_modules, &counter, this, &module_config, - &gpu_version, &debug_module, &options] { - // Each thread has its own context to avoid race conditions. - llvm::LLVMContext new_context; - std::unique_ptr new_module = - CopyToContext(*llvm_modules.at(i), new_context); - compile_results.at(i) = CompileSingleModule( - module_config, gpu_version, debug_module, new_module.get(), - /*relocatable=*/true, options, - /*shard_number=*/i); - counter.DecrementCount(); - }); + thread_pool->Schedule([&compile_results, i, &llvm_modules, &counter, this, + &module_config, &gpu_version, &debug_module, + &options] { + // Each thread has its own context to avoid race conditions. + llvm::LLVMContext new_context; + std::unique_ptr new_module = + CopyToContext(*llvm_modules.at(i), new_context); + compile_results.at(i) = CompileSingleModule( + module_config, gpu_version, debug_module, new_module.get(), + /*relocatable=*/true, options, + /*shard_number=*/i); + counter.DecrementCount(); + }); } counter.Wait(); @@ -1771,7 +1783,6 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, // TODO(ezhulenev): Unify AOT compilation with GpuRuntimeExecutable::Create // (see `gpu/runtime/executable.h`). - // Options for the default XLA runtime compilation pipeline. runtime::CompilationPipelineOptions copts; diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index bb68c506b01f34..5f71b06d94dfc2 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -43,6 +43,7 @@ cc_library( deps = [ "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", + "@com_google_absl//absl/base:core_headers", ], ) @@ -60,6 +61,16 @@ cc_library( ], ) +# Bundle all custom fusions into a single target, so we can link all fusions and patterns by adding +# a single dependency. +cc_library( + name = "custom_fusion_library", + visibility = ["//visibility:public"], + # copybara:uncomment_begin(google-only) + # deps = [":cutlass_gemm_fusion"], + # copybara:uncomment_end(google-only) +) + # copybara:uncomment_begin(google-only) # # TODO(ezhulenev): We currently do not have a CUTLASS dependency in open source BUILD. # @@ -68,9 +79,12 @@ cc_library( # srcs = ["cutlass_gemm_fusion.cc"], # deps = [ # ":custom_fusion", +# ":custom_fusion_pattern", # ":custom_kernel", # ":cutlass_gemm_kernel", # "@com_google_absl//absl/status", +# "//xla:shape_util", +# "//xla:status", # "//xla:statusor", # "//xla:xla_data_proto_cc", # "//xla/hlo/ir:hlo", diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc index 99f3fcb6274870..db4d5d21409189 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.cc @@ -23,6 +23,11 @@ limitations under the License. 
namespace xla::gpu { +CustomFusionPatternRegistry* CustomFusionPatternRegistry::Default() { + static auto* registry = new CustomFusionPatternRegistry(); + return registry; +} + std::vector CustomFusionPatternRegistry::Match( HloInstruction* instr) const { std::vector matches; diff --git a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h index f040d647490add..02123388e26004 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h +++ b/third_party/xla/xla/service/gpu/kernels/custom_fusion_pattern.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/base/attributes.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/backend_configs.pb.h" @@ -56,6 +57,10 @@ class CustomFusionPattern { class CustomFusionPatternRegistry { public: + // Returns a pointer to a default custom fusion pattern registry, which is a + // global static registry. + static CustomFusionPatternRegistry *Default(); + std::vector Match(HloInstruction *instr) const; void Add(std::unique_ptr pattern); @@ -71,4 +76,18 @@ class CustomFusionPatternRegistry { } // namespace xla::gpu +#define XLA_REGISTER_CUSTOM_FUSION_PATTERN(PATTERN) \ + XLA_REGISTER_CUSTOM_FUSION_PATTERN_(PATTERN, __COUNTER__) + +#define XLA_REGISTER_CUSTOM_FUSION_PATTERN_(PATTERN, N) \ + XLA_REGISTER_CUSTOM_FUSION_PATTERN__(PATTERN, N) + +#define XLA_REGISTER_CUSTOM_FUSION_PATTERN__(PATTERN, N) \ + ABSL_ATTRIBUTE_UNUSED static const bool \ + xla_custom_fusion_pattern_##N##_registered_ = [] { \ + ::xla::gpu::CustomFusionPatternRegistry::Default() \ + ->Emplace(); \ + return true; \ + }() + #endif // XLA_SERVICE_GPU_KERNELS_CUSTOM_FUSION_PATTERN_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index 6fb46343652591..250eec73172bdc 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -14,37 +14,91 @@ limitations under the License. 
==============================================================================*/ #include +#include #include #include #include "absl/status/status.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/kernels/custom_fusion.h" +#include "xla/service/gpu/kernels/custom_fusion_pattern.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/kernels/cutlass_gemm_kernel.h" +#include "xla/shape.h" +#include "xla/status.h" #include "xla/statusor.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla::gpu { +//===----------------------------------------------------------------------===// +// Cutlass Gemm pattern matching helpers +//===----------------------------------------------------------------------===// + +static Status IsF32Gemm(const HloDotInstruction* dot) { + const Shape& lhs = dot->operand(0)->shape(); + const Shape& rhs = dot->operand(1)->shape(); + const Shape& out = dot->shape(); + + if (lhs.dimensions_size() != 2 || rhs.dimensions_size() != 2) + return absl::InternalError("dot operands must have rank 2"); + + if (lhs.element_type() != PrimitiveType::F32 || + rhs.element_type() != PrimitiveType::F32 || + out.element_type() != PrimitiveType::F32) + return absl::InternalError("dot operations must use F32 data type"); + + // Check that we do not transpose any of the operands. + auto& dot_dims = dot->dot_dimension_numbers(); + + if (dot_dims.lhs_contracting_dimensions().size() != 1 || + dot_dims.lhs_contracting_dimensions()[0] != 1) + return absl::InternalError("lhs contracting dimensions must be 1"); + + if (dot_dims.rhs_contracting_dimensions().size() != 1 || + dot_dims.rhs_contracting_dimensions()[0] != 0) + return absl::InternalError("rhs contracting dimensions must be 0"); + + return OkStatus(); +} + +//===----------------------------------------------------------------------===// +// CutlassGemmPattern +//===----------------------------------------------------------------------===// + +class CutlassGemmPattern : public CustomFusionPattern { + public: + std::optional TryMatch(HloInstruction* instr) const override { + auto* dot = DynCast(instr); + if (!dot || !IsF32Gemm(dot).ok()) return std::nullopt; + + CustomFusionConfig config; + config.set_name("cutlass_gemm"); + return Match{config, {instr}}; + } +}; + +//===----------------------------------------------------------------------===// +// CutlassGemmFusion +//===----------------------------------------------------------------------===// + class CutlassGemmFusion : public CustomFusion { public: StatusOr> LoadKernels( const HloComputation* computation) const final { - // TODO(ezhulenev): This is the most basic check to pass a single test we - // have today. Expand it to properly check all invariants of a dot - // instruction supported by CUTLASS gemm kernels. 
auto* dot = DynCast(computation->root_instruction()); if (dot == nullptr) return absl::InternalError( "cutlass_gemm requires ROOT operation to be a dot"); - PrimitiveType dtype = dot->shape().element_type(); - if (dtype != PrimitiveType::F32) - return absl::InternalError("Unsupported element type"); + TF_RETURN_IF_ERROR(IsF32Gemm(dot)); + + auto dtype = dot->shape().element_type(); auto& lhs_shape = dot->operand(0)->shape(); auto& rhs_shape = dot->operand(1)->shape(); @@ -61,4 +115,5 @@ class CutlassGemmFusion : public CustomFusion { } // namespace xla::gpu +XLA_REGISTER_CUSTOM_FUSION_PATTERN(::xla::gpu::CutlassGemmPattern); XLA_REGISTER_CUSTOM_FUSION("cutlass_gemm", ::xla::gpu::CutlassGemmFusion); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index e9f23f5604aa1b..4829648dd03cc7 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -36,28 +36,28 @@ TEST_F(CutlassFusionTest, SimpleF32Gemm) { HloModule cublas ENTRY e { - arg0 = f32[32, 64]{1,0} parameter(0) - arg1 = f32[64, 16]{1,0} parameter(1) - gemm = (f32[32,16]{1,0}, s8[0]{0}) custom-call(arg0, arg1), + arg0 = f32[100,784]{1,0} parameter(0) + arg1 = f32[784,10]{1,0} parameter(1) + gemm = (f32[100,10]{1,0}, s8[0]{0}) custom-call(arg0, arg1), custom_call_target="__cublas$gemm", backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} - ROOT get-tuple-element = f32[32,16]{1,0} get-tuple-element((f32[32,16]{1,0}, s8[0]{0}) gemm), index=0 + ROOT get-tuple-element = f32[100,10]{1,0} get-tuple-element((f32[100,10]{1,0}, s8[0]{0}) gemm), index=0 })"; const char* hlo_text_custom_fusion = R"( HloModule cutlass cutlass_gemm { - arg0 = f32[32,64]{1,0} parameter(0) - arg1 = f32[64,16]{1,0} parameter(1) - ROOT dot = f32[32,16]{1,0} dot(arg0, arg1), + arg0 = f32[100,784]{1,0} parameter(0) + arg1 = f32[784,10]{1,0} parameter(1) + ROOT dot = f32[100,10]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} } ENTRY e { - arg0 = f32[32, 64]{1,0} parameter(0) - arg1 = f32[64, 16]{1,0} parameter(1) - ROOT _ = f32[32,16]{1,0} fusion(arg0, arg1), kind=kCustom, calls=cutlass_gemm, + arg0 = f32[100,784]{1,0} parameter(0) + arg1 = f32[784,10]{1,0} parameter(1) + ROOT _ = f32[100,10]{1,0} fusion(arg0, arg1), kind=kCustom, calls=cutlass_gemm, backend_config={kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm"}} })"; diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc index 07c7601af55987..5a897b830d3c05 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.cc @@ -54,7 +54,9 @@ StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m, size_t shared_memory_bytes = sizeof(typename GemmKernel::SharedStorage); // Packs device memory arguments into CUTLASS kernel parameters struct. 
-  auto pack = [problem_size, tiled_shape](const se::KernelArgs &args) {
+  using PackedArgs = StatusOr>;
+  auto pack = [problem_size,
+               tiled_shape](const se::KernelArgs &args) -> PackedArgs {
     auto *mem_args = Cast(&args);

     // Converts DeviceMemoryBase to an opaque `void *` device pointer.
@@ -71,6 +73,20 @@ StatusOr GetCutlassGemmKernel(PrimitiveType dtype, int32_t m,
     int32_t ldb = problem_size.n();
     int32_t ldc = problem_size.n();

+    // Check if GemmKernel can implement the given problem size.
+    cutlass::Status can_implement = GemmKernel::can_implement(
+        problem_size,          // problem size
+        {device_ptr(0), lda},  // Tensor-ref for source matrix A
+        {device_ptr(1), ldb},  // Tensor-ref for source matrix B
+        {device_ptr(2), ldc},  // Tensor-ref for source matrix C
+        {device_ptr(2), ldc}   // Tensor-ref for destination matrix D
+    );
+
+    if (can_implement != cutlass::Status::kSuccess) {
+      return absl::InternalError(
+          "CUTLASS GemmKernel cannot implement gemm for the given problem size");
+    }
+
     // Sanity check that we do not accidentally get a giant parameters struct.
     static_assert(sizeof(GemmKernel::Params) < 512,
                   "GemmKernel::Params struct size is unexpectedly large");
diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto
index d579a8c11a30e8..ec3cc81380d4ff 100644
--- a/third_party/xla/xla/xla.proto
+++ b/third_party/xla/xla/xla.proto
@@ -436,6 +436,15 @@ message DebugOptions {
   // If true, use XLA runtime for XLA:GPU backend.
   bool xla_gpu_enable_xla_runtime_executable = 169;

+  // If true, XLA will try to pattern match subgraphs of HLO operations into
+  // custom fusions registered in the current process (pre-compiled,
+  // hand-written kernels, e.g. various GEMM fusions written in CUTLASS).
+  bool xla_gpu_enable_custom_fusions = 263;
+
+  // A regular expression enabling only a subset of custom fusions. Enabled
+  // only if `xla_gpu_enable_custom_fusions` is set to true.
+  string xla_gpu_enable_custom_fusions_re = 264;
+
   // If true, use OpenXLA runtime for XLA:GPU backend. That is, use IREE VM
   // as a host executable, optional CUDA HAL for dispatching device kernels and
   // custom modules for integration with libraries required for running
@@ -663,7 +672,7 @@ message DebugOptions {
   // Enable radix sort using CUB.
   bool xla_gpu_enable_cub_radix_sort = 259;

-  // Next id: 263
+  // Next id: 265

   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.

From 1ed9ba9c5e7b37f79f2af0546f4bd5ced1b1d2ef Mon Sep 17 00:00:00 2001
From: Juan Martinez Castellanos
Date: Wed, 15 Nov 2023 13:25:56 -0800
Subject: [PATCH 138/391] Remove deprecated target "python/framework:framework".

Also, migrate the last few remaining references away from the deprecated
target.

PiperOrigin-RevId: 582786070
---
 tensorflow/python/framework/BUILD | 75 ------------------------------
 1 file changed, 75 deletions(-)

diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD
index 38b5ad2ac9344d..f7d2e01d74fbe4 100644
--- a/tensorflow/python/framework/BUILD
+++ b/tensorflow/python/framework/BUILD
@@ -301,81 +301,6 @@ py_strict_library(
     ],
 )

-py_strict_library(
-    name = "framework",
-    deprecation = "This target has been split.
Depend on the sub-targets instead.", - srcs_version = "PY3", - visibility = visibility + ["//tensorflow:internal"], - deps = [ - ":_errors_test_helper", - ":_pywrap_python_api_dispatcher", - ":_pywrap_python_api_info", - ":_pywrap_python_api_parameter_converter", - ":_pywrap_python_op_gen", - ":byte_swap_tensor", - ":c_api_util", - ":composite_tensor", - ":config", - ":cpp_shape_inference_proto_py", - ":device", - ":dtypes", - ":error_interpolation", - ":errors", - ":fast_tensor_util", - ":for_generated_wrappers", - ":framework_lib", - ":function", - ":graph_io", - ":graph_util", - ":importer", - ":indexed_slices", - ":load_library", - ":meta_graph", - ":op_def_registry", - ":ops", - ":random_seed", - ":sparse_tensor", - ":tensor", - ":tensor_conversion_registry", - ":tensor_shape", - ":tensor_spec", - ":tensor_util", - ":type_spec", - ":versions", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:_pywrap_py_exception_registry", - "//tensorflow/python:_pywrap_quantize_training", - "//tensorflow/python:pywrap_mlir", - "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python:pywrap_tfe", - "//tensorflow/python:tf2", - "//tensorflow/python/client:_pywrap_debug_events_writer", - "//tensorflow/python/client:_pywrap_events_writer", - "//tensorflow/python/client:pywrap_tf_session", - "//tensorflow/python/eager:context", - "//tensorflow/python/lib/core:_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed. - "//tensorflow/python/lib/io:file_io", - "//tensorflow/python/ops:control_flow_util", - "//tensorflow/python/platform:_pywrap_stacktrace_handler", - "//tensorflow/python/platform:tf_logging", - "//tensorflow/python/util:_pywrap_checkpoint_reader", - "//tensorflow/python/util:_pywrap_kernel_registry", - "//tensorflow/python/util:_pywrap_nest", - "//tensorflow/python/util:_pywrap_stat_summarizer", - "//tensorflow/python/util:_pywrap_tfprof", - "//tensorflow/python/util:_pywrap_transform_graph", - "//tensorflow/python/util:_pywrap_util_port", - "//tensorflow/python/util:_pywrap_utils", - "//tensorflow/python/util:compat", - "//tensorflow/python/util:deprecation", - "//tensorflow/python/util:tf_export", - "//third_party/py/numpy", - "@pypi_packaging//:pkg", - ] + if_xla_available([ - "//tensorflow/python:_pywrap_tfcompile", - ]), -) - py_strict_library( name = "byte_swap_tensor", srcs = ["byte_swap_tensor.py"], From 5513a07ea3edb2cfed17fd8972e42a2d22a620e2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 15 Nov 2023 13:35:14 -0800 Subject: [PATCH 139/391] Reapply: Make the CPU backend participate in distributed initialization. The main effect of this change is that CPU devices end up with a unique global ID and the correct process index. 
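As a rough usage sketch (not part of the patch itself; the coordinator address, node count, and import paths below are placeholders), a multi-process CPU client can be wired through the distributed runtime along these lines:

    # Hypothetical sketch based on the Python APIs touched by this patch;
    # adjust the coordinator address and node IDs for a real deployment.
    from xla.python import xla_client
    from xla.python import xla_extension as _xla

    coordinator = "localhost:12345"  # placeholder address
    num_nodes = 2                    # placeholder cluster size
    node_id = 0                      # this process's index

    # Exactly one process hosts the coordination service.
    if node_id == 0:
        service = _xla.get_distributed_runtime_service(coordinator, num_nodes=num_nodes)

    distributed_client = _xla.get_distributed_runtime_client(coordinator, node_id=node_id)
    distributed_client.connect()

    # With a distributed client attached, CPU devices receive globally unique
    # IDs and report the correct process index.
    cpu_client = xla_client.make_cpu_client(
        distributed_client=distributed_client,
        node_id=node_id,
        num_nodes=num_nodes,
    )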
PiperOrigin-RevId: 582789416 --- third_party/xla/xla/pjrt/BUILD | 2 + .../xla/xla/pjrt/tfrt_cpu_pjrt_client.cc | 92 ++++-- .../xla/xla/pjrt/tfrt_cpu_pjrt_client.h | 27 +- third_party/xla/xla/python/BUILD | 1 + third_party/xla/xla/python/xla.cc | 27 +- third_party/xla/xla/python/xla_client.py | 157 +++++++---- third_party/xla/xla/python/xla_client.pyi | 6 +- .../xla/xla/python/xla_extension/__init__.pyi | 261 +++++++++++------- 8 files changed, 389 insertions(+), 184 deletions(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 29bd730a136c16..abbb5886b81cdd 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -678,6 +678,7 @@ cc_library( "//xla/client:executable_build_options", "//xla/client:xla_computation", "//xla/hlo/ir:hlo", + "//xla/pjrt/distributed:topology_util", "//xla/runtime:cpu_event", "//xla/service:buffer_assignment", "//xla/service:compiler", @@ -706,6 +707,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", # TODO(zhangqiaorjc): Remove if use TFRT threadpool. "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc index 863f46ce2725af..ad4b5ea45711ec 100644 --- a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc @@ -41,6 +41,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/time/time.h" #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -57,6 +58,7 @@ limitations under the License. 
#include "xla/literal_util.h" #include "xla/pjrt/abstract_tfrt_cpu_buffer.h" #include "xla/pjrt/compile_options.pb.h" +#include "xla/pjrt/distributed/topology_util.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" @@ -236,7 +238,11 @@ class TfrtCpuAsyncHostToDeviceTransferManager } // namespace -TfrtCpuDeviceDescription::TfrtCpuDeviceDescription(int id) : id_(id) { +TfrtCpuDeviceDescription::TfrtCpuDeviceDescription(int id, int process_index, + int local_hardware_id) + : id_(id), + process_index_(process_index), + local_hardware_id_(local_hardware_id) { debug_string_ = absl::StrCat("TFRT_CPU_", id); to_string_ = absl::StrCat("CpuDevice(id=", id, ")"); } @@ -253,8 +259,9 @@ absl::string_view TfrtCpuDeviceDescription::ToString() const { return to_string_; } -TfrtCpuDevice::TfrtCpuDevice(int id, int max_inflight_computations) - : description_(id), +TfrtCpuDevice::TfrtCpuDevice(int id, int process_index, int local_hardware_id, + int max_inflight_computations) + : description_(id, process_index, local_hardware_id), max_inflight_computations_semaphore_( /*capacity=*/max_inflight_computations) {} @@ -281,30 +288,47 @@ static int CpuDeviceCount() { return GetDebugOptionsFromFlags().xla_force_host_platform_device_count(); } -static StatusOr>> GetTfrtCpuDevices( - int cpu_device_count, int max_inflight_computations_per_device) { - std::vector> devices; - for (int i = 0; i < cpu_device_count; ++i) { - auto device = std::make_unique( - /*id=*/i, max_inflight_computations_per_device); - devices.push_back(std::move(device)); - } - return std::move(devices); -} - StatusOr> GetTfrtCpuClient( const CpuClientOptions& options) { // Need at least CpuDeviceCount threads to launch one collective. int cpu_device_count = options.cpu_device_count.value_or(CpuDeviceCount()); size_t num_threads = std::max(DefaultThreadPoolSize(), cpu_device_count); - TF_ASSIGN_OR_RETURN( - std::vector> devices, - GetTfrtCpuDevices(cpu_device_count, - options.max_inflight_computations_per_device)); + LocalTopologyProto local_topology; + local_topology.set_node_id(options.node_id); + std::string boot_id_str; + auto boot_id_str_or_status = GetBootIdString(); + if (!boot_id_str_or_status.ok()) { + LOG(INFO) << boot_id_str_or_status.status(); + } else { + boot_id_str = boot_id_str_or_status.value(); + } + local_topology.set_boot_id(boot_id_str); + for (int i = 0; i < cpu_device_count; ++i) { + DeviceProto* device_proto = local_topology.add_devices(); + device_proto->set_local_device_ordinal(i); + device_proto->set_name("cpu"); + } + + GlobalTopologyProto global_topology; + TF_RETURN_IF_ERROR( + ExchangeTopologies("cpu", options.node_id, options.num_nodes, + absl::Minutes(2), absl::Minutes(5), options.kv_get, + options.kv_put, local_topology, &global_topology)); + + std::vector> devices; + for (const LocalTopologyProto& node : global_topology.nodes()) { + for (const DeviceProto& device_proto : node.devices()) { + auto device = std::make_unique( + /*id=*/device_proto.global_device_id(), node.node_id(), + device_proto.local_device_ordinal(), + options.max_inflight_computations_per_device); + devices.push_back(std::move(device)); + } + } return std::unique_ptr(std::make_unique( - /*process_index=*/0, std::move(devices), num_threads)); + /*process_index=*/options.node_id, std::move(devices), num_threads)); } TfrtCpuClient::TfrtCpuClient( @@ -613,6 +637,25 @@ StatusOr> TfrtCpuClient::Compile( }, &num_replicas, &num_partitions, &device_assignment)); + // TODO(phawkins): cross-process 
computations aren't implemented yet. Check + // for these and error. + if (device_assignment) { + for (int replica = 0; replica < device_assignment->replica_count(); + ++replica) { + for (int computation = 0; + computation < device_assignment->computation_count(); + ++computation) { + int id = (*device_assignment)(replica, computation); + TF_ASSIGN_OR_RETURN(auto* device, LookupDevice(id)); + if (device->process_index() != process_index()) { + return InvalidArgument( + "Multiprocess computations aren't implemented on the CPU " + "backend."); + } + } + } + } + std::vector argument_layout_pointers; TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions( computation, &LayoutUtil::GetWithDefaultLayout, options.argument_layouts, @@ -762,6 +805,10 @@ StatusOr> TfrtCpuClient::BufferFromHostBuffer( VLOG(2) << "TfrtCpuClient::BufferFromHostBuffer: shape: " << shape.ToString() << " device: " << device->DebugString(); + if (!device->IsAddressable()) { + return InvalidArgument("Cannot copy array to non-addressable device %s", + device->DebugString()); + } TF_ASSIGN_OR_RETURN( std::unique_ptr tracked_device_buffer, AbstractTfrtCpuBuffer::BufferFromHostBufferHelper( @@ -833,6 +880,11 @@ StatusOr> TfrtCpuBuffer::CopyToDevice( return CopyToDeviceAcrossClients(dst_device); } + if (!dst_device->IsAddressable()) { + return InvalidArgument("Cannot copy array to non-addressable device %s", + dst_device->DebugString()); + } + TF_ASSIGN_OR_RETURN( std::unique_ptr tracked_device_buffer, CopyToDeviceHelper(client()->async_work_runner())); @@ -1179,7 +1231,7 @@ StatusOr TfrtCpuExecutable::ExecuteHelper( ExecutableRunOptions run_options; run_options.set_run_id(run_id); - run_options.set_device_ordinal(device->local_hardware_id()); + run_options.set_device_ordinal(device->id()); // Need to keep device_assignment alive until execution completes. 
run_options.set_device_assignment(device_assignment.get()); run_options.set_intra_op_thread_pool(client_->eigen_intraop_device()); diff --git a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h index c744d10ac3a9ea..e9543ab92e93e7 100644 --- a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h @@ -63,11 +63,13 @@ namespace xla { class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { public: - explicit TfrtCpuDeviceDescription(int id); + TfrtCpuDeviceDescription(int id, int process_index, int local_hardware_id); int id() const override { return id_; } - int process_index() const override { return 0; } + int process_index() const override { return process_index_; } + + int local_hardware_id() const { return local_hardware_id_; } absl::string_view device_kind() const override; @@ -82,6 +84,8 @@ class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { private: int id_; + int process_index_; + int local_hardware_id_; std::string debug_string_; std::string to_string_; absl::flat_hash_map attributes_ = {}; @@ -89,7 +93,8 @@ class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { class TfrtCpuDevice final : public PjRtDevice { public: - explicit TfrtCpuDevice(int id, int max_inflight_computations = 32); + explicit TfrtCpuDevice(int id, int process_index, int local_hardware_id, + int max_inflight_computations = 32); const TfrtCpuDeviceDescription& description() const override { return description_; @@ -106,8 +111,9 @@ class TfrtCpuDevice final : public PjRtDevice { return process_index() == client()->process_index(); } - // Used as `device_ordinal`. - int local_hardware_id() const override { return id(); } + int local_hardware_id() const override { + return description_.local_hardware_id(); + } Status TransferToInfeed(const LiteralSlice& literal) override; @@ -518,6 +524,17 @@ struct CpuClientOptions { std::optional cpu_device_count = std::nullopt; int max_inflight_computations_per_device = 32; + + // Number of distributed nodes. node_id, kv_get, and kv_put are ignored if + // this is set to 1. + int num_nodes = 1; + + // My node ID. + int node_id = 0; + + // KV store primitives for sharing topology information. + PjRtClient::KeyValueGetCallback kv_get = nullptr; + PjRtClient::KeyValuePutCallback kv_put = nullptr; }; StatusOr> GetTfrtCpuClient( const CpuClientOptions& options); diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index c0a5938032acd7..c7367a7f35ba59 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -1153,6 +1153,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@local_config_python//:python_headers", # buildcleaner: keep "//xla:literal", diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 9ece9636a1d7d7..5f2d4321aaf8c6 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -42,6 +42,7 @@ limitations under the License. 
#include "absl/strings/ascii.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/time/time.h" #include "absl/types/span.h" #include "pybind11/attr.h" // from @pybind11 #include "pybind11/cast.h" // from @pybind11 @@ -493,16 +494,38 @@ static void Init(py::module_& m) { m.def( "get_tfrt_cpu_client", - [](bool asynchronous) -> std::shared_ptr { + [](bool asynchronous, + std::shared_ptr distributed_client, + int node_id, int num_nodes) -> std::shared_ptr { py::gil_scoped_release gil_release; CpuClientOptions options; + if (distributed_client != nullptr) { + std::string key_prefix = "cpu:"; + options.kv_get = + [distributed_client, key_prefix]( + std::string_view k, + absl::Duration timeout) -> xla::StatusOr { + return distributed_client->BlockingKeyValueGet( + absl::StrCat(key_prefix, k), timeout); + }; + options.kv_put = [distributed_client, key_prefix]( + std::string_view k, + std::string_view v) -> xla::Status { + return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), + v); + }; + options.node_id = node_id; + options.num_nodes = num_nodes; + } + options.asynchronous = asynchronous; std::unique_ptr client = xla::ValueOrThrow(GetTfrtCpuClient(options)); return std::make_shared( ifrt::PjRtClient::Create(std::move(client))); }, - py::arg("asynchronous") = true); + py::arg("asynchronous") = true, py::arg("distributed_client") = nullptr, + py::arg("node_id") = 0, py::arg("num_nodes") = 1); m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { xla::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); return pjrt_api.ok(); diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index 1a5c3fff7c4bcb..9e13fe5e4b0d66 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -48,7 +48,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 215 +_version = 216 # Version number for MLIR:Python components. 
mlir_api_version = 54 @@ -63,11 +63,18 @@ _NameValueMapping = Mapping[str, Union[str, int, List[int], float, bool]] -def make_cpu_client() -> ...: - register_custom_call_handler( - 'cpu', _xla.register_custom_call_target +def make_cpu_client( + distributed_client=None, + node_id=0, + num_nodes=1, +) -> ...: + register_custom_call_handler('cpu', _xla.register_custom_call_target) + return _xla.get_tfrt_cpu_client( + asynchronous=True, + distributed_client=distributed_client, + node_id=node_id, + num_nodes=num_nodes, ) - return _xla.get_tfrt_cpu_client(asynchronous=True) def make_gpu_client( @@ -97,12 +104,8 @@ def make_gpu_client( if memory_fraction: config.memory_fraction = float(memory_fraction) config.preallocate = preallocate not in ('0', 'false', 'False') - register_custom_call_handler( - 'CUDA', _xla.register_custom_call_target - ) - register_custom_call_handler( - 'ROCM', _xla.register_custom_call_target - ) + register_custom_call_handler('CUDA', _xla.register_custom_call_target) + register_custom_call_handler('ROCM', _xla.register_custom_call_target) return _xla.get_gpu_client( asynchronous=True, @@ -224,6 +227,7 @@ def generate_pjrt_gpu_plugin_options( class OpMetadata: """Python representation of a xla.OpMetadata protobuf.""" + __slots__ = ('op_type', 'op_name', 'source_file', 'source_line') def __init__(self, op_type='', op_name='', source_file='', source_line=0): @@ -238,10 +242,8 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): full_filename, lineno = inspect.stack()[skip_frames][1:3] filename = os.path.basename(full_filename) return OpMetadata( - op_type=op_type, - op_name=op_name, - source_file=filename, - source_line=lineno) + op_type=op_type, op_name=op_name, source_file=filename, source_line=lineno + ) PrimitiveType = _xla.PrimitiveType @@ -380,7 +382,8 @@ def convert(pyval): if isinstance(pyval, tuple): if layout is not None: raise NotImplementedError( - 'shape_from_pyval does not support layouts for tuple shapes') + 'shape_from_pyval does not support layouts for tuple shapes' + ) return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) else: return Shape.array_shape(pyval.dtype, np.shape(pyval), layout) @@ -478,8 +481,9 @@ class PaddingType(enum.Enum): SAME = 2 -def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, - window_strides): +def window_padding_type_to_pad_values( + padding_type, lhs_dims, rhs_dims, window_strides +): """Maps PaddingType or string to pad values (list of pairs of ints).""" if not isinstance(padding_type, (str, PaddingType)): msg = 'padding_type must be str or PaddingType, got {}.' 
@@ -501,7 +505,8 @@ def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, pad_sizes = [ max((out_size - 1) * stride + filter_size - in_size, 0) for out_size, stride, filter_size, in_size in zip( - out_shape, window_strides, rhs_dims, lhs_dims) + out_shape, window_strides, rhs_dims, lhs_dims + ) ] return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] else: @@ -605,6 +610,7 @@ def register_custom_call_handler(platform: str, handler: Any) -> None: class PaddingConfigDimension: """Python representation of a xla.PaddingConfigDimension protobuf.""" + __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding') edge_padding_low: int @@ -619,6 +625,7 @@ def __init__(self): class PaddingConfig: """Python representation of a xla.PaddingConfig protobuf.""" + __slots__ = ('dimensions',) def __init__(self): @@ -652,8 +659,13 @@ def make_padding_config( class DotDimensionNumbers: """Python representation of a xla.DotDimensionNumbers protobuf.""" - __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions', - 'lhs_batch_dimensions', 'rhs_batch_dimensions') + + __slots__ = ( + 'lhs_contracting_dimensions', + 'rhs_contracting_dimensions', + 'lhs_batch_dimensions', + 'rhs_batch_dimensions', + ) def __init__(self): self.lhs_contracting_dimensions = [] @@ -663,9 +675,10 @@ def __init__(self): def make_dot_dimension_numbers( - dimension_numbers: Union[DotDimensionNumbers, - Tuple[Tuple[List[int], List[int]], - Tuple[List[int], List[int]]]] + dimension_numbers: Union[ + DotDimensionNumbers, + Tuple[Tuple[List[int], List[int]], Tuple[List[int], List[int]]], + ] ) -> DotDimensionNumbers: """Builds a DotDimensionNumbers object from a specification. @@ -692,11 +705,18 @@ def make_dot_dimension_numbers( class ConvolutionDimensionNumbers: """Python representation of a xla.ConvolutionDimensionNumbers protobuf.""" - __slots__ = ('input_batch_dimension', 'input_feature_dimension', - 'input_spatial_dimensions', 'kernel_input_feature_dimension', - 'kernel_output_feature_dimension', 'kernel_spatial_dimensions', - 'output_batch_dimension', 'output_feature_dimension', - 'output_spatial_dimensions') + + __slots__ = ( + 'input_batch_dimension', + 'input_feature_dimension', + 'input_spatial_dimensions', + 'kernel_input_feature_dimension', + 'kernel_output_feature_dimension', + 'kernel_spatial_dimensions', + 'output_batch_dimension', + 'output_feature_dimension', + 'output_spatial_dimensions', + ) def __init__(self): self.input_batch_dimension = 0 @@ -711,30 +731,32 @@ def __init__(self): def make_convolution_dimension_numbers( - dimension_numbers: Union[None, ConvolutionDimensionNumbers, Tuple[str, str, - str]], - num_spatial_dimensions: int) -> ConvolutionDimensionNumbers: + dimension_numbers: Union[ + None, ConvolutionDimensionNumbers, Tuple[str, str, str] + ], + num_spatial_dimensions: int, +) -> ConvolutionDimensionNumbers: """Builds a ConvolutionDimensionNumbers object from a specification. Args: dimension_numbers: optional, either a ConvolutionDimensionNumbers object or - a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of - length N+2 identifying by position: (1) batch dimensions in lhs, rhs, and - the output with the character 'N', (2) feature dimensions in lhs and the - output with the character 'C', (3) input and output feature dimensions - in rhs with the characters 'I' and 'O' respectively, and (4) spatial - dimension correspondences between lhs, rhs, and the output using any - distinct characters. 
For example, to indicate dimension numbers - consistent with the Conv operation with two spatial dimensions, one - could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate - dimension numbers consistent with the TensorFlow Conv2D operation, one - could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of - convolution dimension specification, window strides are associated with - spatial dimension character labels according to the order in which the - labels appear in the rhs_spec string, so that window_strides[0] is - matched with the dimension corresponding to the first character - appearing in rhs_spec that is not 'I' or 'O'. By default, use the same - dimension numbering as Conv and ConvWithGeneralPadding. + a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of length + N+2 identifying by position: (1) batch dimensions in lhs, rhs, and the + output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions in + rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers consistent + with the Conv operation with two spatial dimensions, one could use + ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate dimension + numbers consistent with the TensorFlow Conv2D operation, one could use + ('NHWC', 'HWIO', 'NHWC'). When using the latter form of convolution + dimension specification, window strides are associated with spatial + dimension character labels according to the order in which the labels + appear in the rhs_spec string, so that window_strides[0] is matched with + the dimension corresponding to the first character appearing in rhs_spec + that is not 'I' or 'O'. By default, use the same dimension numbering as + Conv and ConvWithGeneralPadding. num_spatial_dimensions: the number of spatial dimensions. 
Returns: @@ -764,18 +786,26 @@ def make_convolution_dimension_numbers( dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') dimension_numbers.kernel_spatial_dimensions.extend( - i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'} + ) dimension_numbers.input_spatial_dimensions.extend( - sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(lhs_spec[i]))) + sorted( + (i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]), + ) + ) dimension_numbers.output_spatial_dimensions.extend( - sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(out_spec[i]))) + sorted( + (i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]), + ) + ) return dimension_numbers class PrecisionConfig: """Python representation of a xla.PrecisionConfig protobuf.""" + __slots__ = ('operand_precision',) Precision = _xla.PrecisionConfig_Precision @@ -786,8 +816,13 @@ def __init__(self): class GatherDimensionNumbers: """Python representation of a xla.GatherDimensionNumbers protobuf.""" - __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map', - 'index_vector_dim') + + __slots__ = ( + 'offset_dims', + 'collapsed_slice_dims', + 'start_index_map', + 'index_vector_dim', + ) def __init__(self): self.offset_dims = [] @@ -798,8 +833,13 @@ def __init__(self): class ScatterDimensionNumbers: """Python representation of a xla.ScatterDimensionNumbers protobuf.""" - __slots__ = ('update_window_dims', 'inserted_window_dims', - 'scatter_dims_to_operand_dims', 'index_vector_dim') + + __slots__ = ( + 'update_window_dims', + 'inserted_window_dims', + 'scatter_dims_to_operand_dims', + 'index_vector_dim', + ) def __init__(self): self.update_window_dims = [] @@ -810,6 +850,7 @@ def __init__(self): class ReplicaGroup: """Python representation of a xla.ReplicaGroup protobuf.""" + __slots__ = ('replica_ids',) def __init__(self): diff --git a/third_party/xla/xla/python/xla_client.pyi b/third_party/xla/xla/python/xla_client.pyi index 2eb82ec094f2d7..04e22b8e7cf417 100644 --- a/third_party/xla/xla/python/xla_client.pyi +++ b/third_party/xla/xla/python/xla_client.pyi @@ -81,7 +81,11 @@ def shape_from_pyval(pyval: Any, layout: Sequence[int] | None = None) -> Any: .. def heap_profile(client: Client) -> bytes: ... -def make_cpu_client() -> Client: +def make_cpu_client( + distributed_client: Optional[DistributedRuntimeClient] = ..., + node_id: int = ..., + num_nodes: int = ..., +) -> Client: ... def make_gpu_client( diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi index 5267109c42bfec..2f9a147c379927 100644 --- a/third_party/xla/xla/python/xla_extension/__init__.pyi +++ b/third_party/xla/xla/python/xla_extension/__init__.pyi @@ -20,14 +20,26 @@ import inspect import types import typing from typing import ( - Any, Callable, ClassVar, Dict, Iterator, List, Optional, Sequence, Tuple, - Type, TypeVar, Union, overload) + Any, + Callable, + ClassVar, + Dict, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) import numpy as np -from . import ops from . import jax_jit from . import mlir +from . import ops from . import outfeed_receiver from . import pmap_lib from . 
import profiler @@ -89,7 +101,8 @@ class Shape: type: Union[np.dtype, PrimitiveType], dims_seq: Any = ..., layout_seq: Any = ..., - dynamic_dimensions: Optional[List[bool]] = ...) -> Shape: ... + dynamic_dimensions: Optional[List[bool]] = ..., + ) -> Shape: ... @staticmethod def token_shape() -> Shape: ... @staticmethod @@ -137,7 +150,7 @@ class XlaComputation: def get_hlo_module(self) -> HloModule: ... def program_shape(self) -> ProgramShape: ... def as_serialized_hlo_module_proto(self) -> bytes: ... - def as_hlo_text(self, print_large_constants: bool=False) -> str: ... + def as_hlo_text(self, print_large_constants: bool = False) -> str: ... def as_hlo_dot_graph(self) -> str: ... def hash(self) -> int: ... def as_hlo_module(self) -> HloModule: ... @@ -177,10 +190,11 @@ class HloModule: @property def name(self) -> str: ... def to_string(self, options: HloPrintOptions = ...) -> str: ... - def as_serialized_hlo_module_proto(self)-> bytes: ... + def as_serialized_hlo_module_proto(self) -> bytes: ... @staticmethod def from_serialized_hlo_module_proto( - serialized_hlo_module_proto: bytes) -> HloModule: ... + serialized_hlo_module_proto: bytes, + ) -> HloModule: ... def computations(self) -> List[HloComputation]: ... class HloModuleGroup: @@ -192,10 +206,9 @@ class HloModuleGroup: def hlo_module_to_dot_graph(hlo_module: HloModule) -> str: ... def hlo_module_from_text(hlo_module_text: str) -> HloModule: ... - def hlo_module_cost_analysis( - client: Client, - module: HloModule) -> Dict[str, float]: ... + client: Client, module: HloModule +) -> Dict[str, float]: ... class XlaOp: ... @@ -215,7 +228,8 @@ class XlaBuilder: self, __output_index: Sequence[int], __param_number: int, - __param_index: Sequence[int]) -> None: ... + __param_index: Sequence[int], + ) -> None: ... class DeviceAssignment: @staticmethod @@ -239,12 +253,18 @@ class CompileOptions: profile_version: int device_assignment: Optional[DeviceAssignment] compile_portable_executable: bool - env_option_overrides: List[Tuple[str,str]] - -def register_custom_call_target(fn_name: str, capsule: Any, platform: str) -> _Status: ... -def register_custom_call_partitioner(name: str, prop_user_sharding: Callable, - partition: Callable, infer_sharding_from_operands: Callable, - can_side_effecting_have_replicated_sharding: bool) -> None: ... + env_option_overrides: List[Tuple[str, str]] + +def register_custom_call_target( + fn_name: str, capsule: Any, platform: str +) -> _Status: ... +def register_custom_call_partitioner( + name: str, + prop_user_sharding: Callable, + partition: Callable, + infer_sharding_from_operands: Callable, + can_side_effecting_have_replicated_sharding: bool, +) -> None: ... def encode_inspect_sharding_callback(handler: Any) -> bytes: ... class DebugOptions: @@ -350,11 +370,16 @@ class HloSharding: @staticmethod def from_string(sharding: str) -> HloSharding: ... @staticmethod - def tuple_sharding(shape: Shape, shardings: Sequence[HloSharding]) -> HloSharding: ... + def tuple_sharding( + shape: Shape, shardings: Sequence[HloSharding] + ) -> HloSharding: ... @staticmethod - def iota_tile(dims: Sequence[int], reshape_dims: Sequence[int], - transpose_perm: Sequence[int], - subgroup_types: Sequence[OpSharding.Type]) -> HloSharding: ... + def iota_tile( + dims: Sequence[int], + reshape_dims: Sequence[int], + transpose_perm: Sequence[int], + subgroup_types: Sequence[OpSharding.Type], + ) -> HloSharding: ... @staticmethod def replicate() -> HloSharding: ... 
@staticmethod @@ -416,16 +441,17 @@ class Memory: class GpuAllocatorConfig: class Kind(enum.IntEnum): - DEFAULT: int - PLATFORM: int - BFC: int - CUDA_ASYNC: int + DEFAULT: int + PLATFORM: int + BFC: int + CUDA_ASYNC: int def __init__( self, kind: Kind = ..., memory_fraction: float = ..., - preallocate: bool = ...) -> None: ... + preallocate: bool = ..., + ) -> None: ... class HostBufferSemantics(enum.IntEnum): IMMUTABLE_ONLY_DURING_CALL: HostBufferSemantics @@ -451,61 +477,78 @@ class Client: argument: Any, device: Optional[Device] = ..., force_copy: bool = ..., - host_buffer_semantics: HostBufferSemantics = ...) -> ArrayImpl: ... + host_buffer_semantics: HostBufferSemantics = ..., + ) -> ArrayImpl: ... def make_cross_host_receive_buffers( - self, - shapes: Sequence[Shape], - device: Device) -> List[Tuple[ArrayImpl, bytes]]: ... + self, shapes: Sequence[Shape], device: Device + ) -> List[Tuple[ArrayImpl, bytes]]: ... def compile( self, computation: Union[str, bytes], - compile_options: CompileOptions = ..., host_callbacks: Sequence[Any] = ...) -> LoadedExecutable: ... + compile_options: CompileOptions = ..., + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... def serialize_executable(self, executable: LoadedExecutable) -> bytes: ... def deserialize_executable( - self, serialized: bytes, - options: Optional[CompileOptions], host_callbacks: Sequence[Any] = ...) -> LoadedExecutable: ... + self, + serialized: bytes, + options: Optional[CompileOptions], + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... def heap_profile(self) -> bytes: ... def defragment(self) -> _Status: ... def get_emit_python_callback_descriptor( - self, callable: Callable, operand_shapes: Sequence[Shape], - results_shapes: Sequence[Shape]) -> Tuple[Any, Any]: ... + self, + callable: Callable, + operand_shapes: Sequence[Shape], + results_shapes: Sequence[Shape], + ) -> Tuple[Any, Any]: ... def make_python_callback_from_host_send_and_recv( - self, callable: Callable, operand_shapes: Sequence[Shape], - result_shapes: Sequence[Shape], send_channel_ids: Sequence[int], - recv_channel_ids: Sequence[int], serializer: Optional[Callable] = ...) -> Any: ... + self, + callable: Callable, + operand_shapes: Sequence[Shape], + result_shapes: Sequence[Shape], + send_channel_ids: Sequence[int], + recv_channel_ids: Sequence[int], + serializer: Optional[Callable] = ..., + ) -> Any: ... def __getattr__(self, name: str) -> Any: ... -def get_tfrt_cpu_client(asynchronous: bool = ...) -> Client: ... +def get_tfrt_cpu_client( + asynchronous: bool = ..., + distributed_client: Optional[DistributedRuntimeClient] = ..., + node_id: int = ..., + num_nodes: int = ..., +) -> Client: ... def get_gpu_client( asynchronous: bool = ..., allocator_config: GpuAllocatorConfig = ..., distributed_client: Optional[DistributedRuntimeClient] = ..., node_id: int = ..., + num_nodes: int = ..., allowed_devices: Optional[Any] = ..., platform_name: Optional[str] = ..., - mock:Optional[bool]=...) -> Client:... + mock: Optional[bool] = ..., +) -> Client: ... def get_mock_gpu_client( asynchronous: bool = ..., allocator_config: GpuAllocatorConfig = ..., distributed_client: Optional[DistributedRuntimeClient] = ..., node_id: int = ..., allowed_devices: Optional[Any] = ..., - platform_name: Optional[str] = ...) -> Client:... + platform_name: Optional[str] = ..., +) -> Client: ... 
def get_c_api_client( platform_name: str, options: Dict[str, Union[str, int, List[int], float, bool]], distributed_client: Optional[DistributedRuntimeClient] = ..., ) -> Client: ... - def get_default_c_api_topology( platform_name: str, topology_name: str, options: Dict[str, Union[str, int, List[int], float]], -) -> DeviceTopology: - ... -def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: - ... - +) -> DeviceTopology: ... +def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: ... def load_pjrt_plugin(platform_name: str, library_path: str) -> _Status: ... def pjrt_plugin_loaded(plugin_name: str) -> bool: ... def pjrt_plugin_initialized(plugin_name: str) -> bool: ... @@ -542,10 +585,14 @@ ArrayImpl = Any # traceback: Traceback # _HAS_DYNAMIC_ATTRIBUTES: bool = ... -def copy_array_to_devices_with_sharding(self: ArrayImpl, devices: List[Device], sharding: Any) -> ArrayImpl: ... - +def copy_array_to_devices_with_sharding( + self: ArrayImpl, devices: List[Device], sharding: Any +) -> ArrayImpl: ... def batched_device_put( - aval: Any, sharding: Any, shards: Sequence[Any], devices: List[Device], + aval: Any, + sharding: Any, + shards: Sequence[Any], + devices: List[Device], committed: bool = True, ) -> ArrayImpl: ... @@ -554,11 +601,8 @@ def check_and_canonicalize_memory_kind( memory_kind: Optional[str], device_list: DeviceList) -> Optional[str]: ... def array_result_handler( - aval: Any, - sharding: Any, - committed: bool, - _skip_checks: bool = ...) -> Callable: - ... + aval: Any, sharding: Any, committed: bool, _skip_checks: bool = ... +) -> Callable: ... class Token: def block_until_ready(self): ... @@ -570,7 +614,9 @@ class ShardedToken: class ExecuteResults: def __len__(self) -> int: ... def disassemble_into_single_device_arrays(self) -> List[List[ArrayImpl]]: ... - def disassemble_prefix_into_single_device_arrays(self, n: int) -> List[List[ArrayImpl]]: ... + def disassemble_prefix_into_single_device_arrays( + self, n: int + ) -> List[List[ArrayImpl]]: ... def consume_with_handlers(self, handlers: List[Callable]) -> List[Any]: ... def consume_token(self) -> ShardedToken: ... @@ -582,18 +628,17 @@ class LoadedExecutable: def delete(self) -> None: ... def execute(self, arguments: Sequence[ArrayImpl]) -> List[ArrayImpl]: ... def execute_with_token( - self, - arguments: Sequence[ArrayImpl]) -> Tuple[List[ArrayImpl], Token]: - ... + self, arguments: Sequence[ArrayImpl] + ) -> Tuple[List[ArrayImpl], Token]: ... def execute_sharded_on_local_devices( - self, - arguments: Sequence[List[ArrayImpl]]) -> List[List[ArrayImpl]]: ... + self, arguments: Sequence[List[ArrayImpl]] + ) -> List[List[ArrayImpl]]: ... def execute_sharded_on_local_devices_with_tokens( - self, - arguments: Sequence[List[ArrayImpl]]) -> Tuple[List[List[ArrayImpl]], ShardedToken]: ... + self, arguments: Sequence[List[ArrayImpl]] + ) -> Tuple[List[List[ArrayImpl]], ShardedToken]: ... def execute_sharded( - self, - arguments: Sequence[List[ArrayImpl]], with_tokens: bool = ...) -> ExecuteResults: ... + self, arguments: Sequence[List[ArrayImpl]], with_tokens: bool = ... + ) -> ExecuteResults: ... def hlo_modules(self) -> List[HloModule]: ... def get_output_memory_kinds(self) -> List[List[str]]: ... def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... @@ -626,14 +671,18 @@ class DeviceTopology: def __getattr__(self, name: str) -> Any: ... def buffer_to_dlpack_managed_tensor( - buffer: ArrayImpl, - stream: int | None = None) -> Any: ... 
+ buffer: ArrayImpl, stream: int | None = None +) -> Any: ... def dlpack_managed_tensor_to_buffer( - tensor: Any, device: Device, stream: int | None) -> ArrayImpl: ... + tensor: Any, device: Device, stream: int | None +) -> ArrayImpl: ... + # Legacy overload def dlpack_managed_tensor_to_buffer( - tensor: Any, cpu_backend: Optional[Client] = ..., - gpu_backend: Optional[Client] = ...) -> ArrayImpl: ... + tensor: Any, + cpu_backend: Optional[Client] = ..., + gpu_backend: Optional[Client] = ..., +) -> ArrayImpl: ... # === BEGIN py_traceback.cc @@ -652,12 +701,12 @@ class Traceback: def __str__(self) -> str: ... def as_python_traceback(self) -> Any: ... def raw_frames(self) -> Tuple[List[types.CodeType], List[int]]: ... - @staticmethod def code_addr2line(code: types.CodeType, lasti: int) -> int: ... @staticmethod - def code_addr2location(code: types.CodeType, - lasti: int) -> Tuple[int, int, int, int]: ... + def code_addr2location( + code: types.CodeType, lasti: int + ) -> Tuple[int, int, int, int]: ... def replace_thread_exc_traceback(traceback: Any): ... @@ -665,16 +714,20 @@ def replace_thread_exc_traceback(traceback: Any): ... class DistributedRuntimeService: def shutdown(self) -> None: ... + class DistributedRuntimeClient: def connect(self) -> _Status: ... def shutdown(self) -> _Status: ... def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> _Status: ... - def blocking_key_value_get_bytes(self, key: str, timeout_in_ms: int) -> _Status: ... + def blocking_key_value_get_bytes( + self, key: str, timeout_in_ms: int + ) -> _Status: ... def key_value_dir_get(self, key: str) -> _Status: ... def key_value_dir_get_bytes(self, key: str) -> _Status: ... def key_value_set(self, key: str, value: str) -> _Status: ... - def key_value_delete(self, key:str) -> _Status: ... + def key_value_delete(self, key: str) -> _Status: ... def wait_at_barrier(self, barrier_id: str, timeout_in_ms: int) -> _Status: ... + def get_distributed_runtime_service( address: str, num_nodes: int, @@ -691,17 +744,16 @@ def get_distributed_runtime_client( heartbeat_interval: Optional[int] = ..., max_missing_heartbeats: Optional[int] = ..., missed_heartbeat_callback: Optional[Any] = ..., - shutdown_on_destruction: Optional[bool] = ...) -> DistributedRuntimeClient: ... + shutdown_on_destruction: Optional[bool] = ..., +) -> DistributedRuntimeClient: ... class PreemptionSyncManager: def initialize(self, client: DistributedRuntimeClient) -> _Status: ... def reached_sync_point(self, step_counter: int) -> bool: ... -def create_preemption_sync_manager() -> PreemptionSyncManager: ... +def create_preemption_sync_manager() -> PreemptionSyncManager: ... def collect_garbage() -> None: ... - def is_optimized_build() -> bool: ... - def json_to_pprof_profile(json: str) -> bytes: ... def pprof_profile_to_json(proto: bytes) -> str: ... @@ -715,8 +767,9 @@ class PmapFunction: def _cache_size(self) -> int: ... def _cache_clear(self) -> None: ... -def weakref_lru_cache(cache_context_fn: Callable, call: Callable, maxsize=...): - ... +def weakref_lru_cache( + cache_context_fn: Callable, call: Callable, maxsize=... +): ... class DeviceList: def __init__(self, device_assignment: Tuple[Device, ...]): ... @@ -739,13 +792,18 @@ class DeviceList: def memory_kinds(self) -> Tuple[str, ...]: ... class Sharding: ... - class XLACompatibleSharding(Sharding): ... 
class NamedSharding(XLACompatibleSharding): - def __init__(self, mesh: Any, spec: Any, *, memory_kind: Optional[str] = None, - _parsed_pspec: Any = None, - _manual_axes: frozenset[Any] = frozenset()): ... + def __init__( + self, + mesh: Any, + spec: Any, + *, + memory_kind: Optional[str] = None, + _parsed_pspec: Any = None, + _manual_axes: frozenset[Any] = frozenset(), + ): ... mesh: Any spec: Any _memory_kind: Optional[str] @@ -760,15 +818,21 @@ class SingleDeviceSharding(XLACompatibleSharding): _internal_device_list: DeviceList class PmapSharding(XLACompatibleSharding): - def __init__(self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec): ... + def __init__( + self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec + ): ... devices: List[Any] sharding_spec: pmap_lib.ShardingSpec _internal_device_list: DeviceList class GSPMDSharding(XLACompatibleSharding): - def __init__(self, devices: Sequence[Device], - op_sharding: Union[OpSharding, HloSharding], - *, memory_kind: Optional[str] = None): ... + def __init__( + self, + devices: Sequence[Device], + op_sharding: Union[OpSharding, HloSharding], + *, + memory_kind: Optional[str] = None, + ): ... _devices: Tuple[Device, ...] _hlo_sharding: HloSharding _memory_kind: Optional[str] @@ -787,12 +851,16 @@ class PjitFunctionCache: @staticmethod def clear_all(): ... -def pjit(function_name: str, fun: Optional[Callable], cache_miss: Callable, - static_argnums: Sequence[int], static_argnames: Sequence[str], - donate_argnums: Sequence[int], - pytree_registry: pytree.PyTreeRegistry, - cache: Optional[PjitFunctionCache] = ..., - ) -> PjitFunction: ... +def pjit( + function_name: str, + fun: Optional[Callable], + cache_miss: Callable, + static_argnums: Sequence[int], + static_argnames: Sequence[str], + donate_argnums: Sequence[int], + pytree_registry: pytree.PyTreeRegistry, + cache: Optional[PjitFunctionCache] = ..., +) -> PjitFunction: ... class HloPassInterface: @property @@ -814,9 +882,6 @@ class TupleSimplifer(HloPassInterface): def __init__(self) -> None: ... def is_asan() -> bool: ... - def is_msan() -> bool: ... - def is_tsan() -> bool: ... - def is_sanitized() -> bool: ... From c2fbada040e94138491c47a1805dba2a66302157 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 15 Nov 2023 13:44:22 -0800 Subject: [PATCH 140/391] Add `--output=jsonproto` to example usage in generate_compile_commands.py PiperOrigin-RevId: 582792461 --- third_party/xla/build_tools/lint/generate_compile_commands.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/build_tools/lint/generate_compile_commands.py b/third_party/xla/build_tools/lint/generate_compile_commands.py index 1dc84fbf14a8ff..2deb1fc3f177c5 100644 --- a/third_party/xla/build_tools/lint/generate_compile_commands.py +++ b/third_party/xla/build_tools/lint/generate_compile_commands.py @@ -15,7 +15,7 @@ r"""Produces a `compile_commands.json` from the output of `bazel aquery`. 
Example usage: - bazel aquery "mnemonic(CppCompile, //xla/...)" | \ + bazel aquery "mnemonic(CppCompile, //xla/...)" --output=jsonproto | \ python3 build_tools/lint/generate_compile_commands.py """ import dataclasses From 0fcda0bab2759da390df399c3708557fd2bddcbc Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 15 Nov 2023 13:47:55 -0800 Subject: [PATCH 141/391] [xla:gpu] Add support for multi-instruction custom fusions Add an example of pattern matching dot instruction with upcasting one of the operands PiperOrigin-RevId: 582793687 --- third_party/xla/xla/service/gpu/BUILD | 2 + .../xla/service/gpu/custom_fusion_rewriter.cc | 111 +++++++++++----- third_party/xla/xla/service/gpu/kernels/BUILD | 4 + .../gpu/kernels/cutlass_gemm_fusion.cc | 119 ++++++++++++++---- .../service/gpu/kernels/cutlass_gemm_fusion.h | 41 ++++++ .../gpu/kernels/cutlass_gemm_fusion_test.cc | 92 +++++++++++++- 6 files changed, 312 insertions(+), 57 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.h diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 757c3f313aae17..58ae615215ee3b 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2676,10 +2676,12 @@ cc_library( "//xla/service:hlo_pass", "//xla/service/gpu/kernels:custom_fusion_library", "//xla/service/gpu/kernels:custom_fusion_pattern", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc index 4c80f7636a06c8..666b816650c58c 100644 --- a/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/custom_fusion_rewriter.cc @@ -15,14 +15,17 @@ limitations under the License. #include "xla/service/gpu/custom_fusion_rewriter.h" +#include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -37,42 +40,86 @@ CustomFusionRewriter::CustomFusionRewriter( const CustomFusionPatternRegistry* patterns) : patterns_(patterns) {} +// Returns instructions that have to become custom fusion parameters. Returns an +// error if matched pattern can't be outlined as a fusion. +static StatusOr> GetPatternCaptures( + const CustomFusionPattern::Match& match) { + HloInstruction* root = match.instructions.back(); + absl::InlinedVector captures; + + // Instruction that will go into the fusion body. + absl::flat_hash_set instructions_set( + match.instructions.begin(), match.instructions.end()); + + // Check that intermediate instructions do not have users outside of the + // matched pattern. Only root instruction can have external users. 
+ for (HloInstruction* instr : match.instructions) { + for (HloInstruction* user : instr->users()) { + if (instr != root && !instructions_set.contains(user)) { + return absl::InvalidArgumentError(absl::StrCat( + "Custom fusion intermediate result ", instr->name(), + " has users outside of a matched pattern: ", user->name())); + } + } + } + + // Collect instructions captured by a matched pattern. + for (HloInstruction* instr : match.instructions) { + for (HloInstruction* operand : instr->operands()) { + if (!instructions_set.contains(operand)) captures.push_back(operand); + } + } + + return captures; +} + // Creates custom fusion computation and moves all matched instructions into it. static StatusOr CreateFusionBody( - HloModule* module, const CustomFusionPattern::Match& match) { + HloModule* module, const CustomFusionPattern::Match& match, + absl::Span captures) { HloComputation::Builder builder(match.config.name()); - // We do not currently support matching custom fusions with more than one - // instruction. - HloInstruction* root = match.instructions[0]; + // A mapping from original instructions to instructions in the fusion body. + absl::flat_hash_map instr_mapping; - // Fusion computation parameters inferred from a matched instruction. - absl::InlinedVector parameters; - for (HloInstruction* operand : root->operands()) { - parameters.push_back(builder.AddInstruction( - HloInstruction::CreateParameter(parameters.size(), operand->shape(), - absl::StrCat("p", parameters.size())))); + auto mapped_operands = [&](HloInstruction* instr) { + absl::InlinedVector operands; + for (HloInstruction* operand : instr->operands()) { + operands.push_back(instr_mapping.at(operand)); + } + return operands; + }; + + // For every parameter create a parameter instruction in the computation body + // and set up instruction mapping. + for (const HloInstruction* capture : captures) { + int64_t index = instr_mapping.size(); + instr_mapping[capture] = + builder.AddInstruction(HloInstruction::CreateParameter( + index, capture->shape(), absl::StrCat("p", index))); } - builder.AddInstruction(root->CloneWithNewOperands(root->shape(), parameters)); + // TODO(ezhulenev): Instructions in the pattern must be topologically sorted, + // otherwise we'll get a crash! Figure out how to do it! + for (HloInstruction* instr : match.instructions) { + instr_mapping[instr] = builder.AddInstruction( + instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr))); + } return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false); } static StatusOr CreateFusionInstruction( HloModule* module, const CustomFusionPattern::Match& match, - HloComputation* body) { + absl::Span captures, HloComputation* body) { // We'll be replacing the root operation of a custom fusion with a fusion // instruction calling fusion computation. - HloInstruction* fusion_root = match.instructions[0]; - HloComputation* fusion_parent = fusion_root->parent(); + HloInstruction* root = match.instructions.back(); + HloComputation* parent = root->parent(); - HloInstruction* fusion = - fusion_parent->AddInstruction(HloInstruction::CreateFusion( - fusion_root->shape(), HloInstruction::FusionKind::kCustom, - fusion_root->operands(), body)); - - // Assign unique name to a new fusion instruction. + // Add a fusion operation calling outlined fusion computation. 
+ HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion( + root->shape(), HloInstruction::FusionKind::kCustom, captures, body)); module->SetAndUniquifyInstrName(fusion, match.config.name()); // Set backends config to a matched custom fusion config. @@ -81,9 +128,7 @@ static StatusOr CreateFusionInstruction( *backend_config.mutable_custom_fusion_config() = match.config; TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(backend_config))); - // Replace fusion root with a fusion instruction. - TF_RETURN_IF_ERROR(fusion_parent->ReplaceInstruction(fusion_root, fusion)); - + TF_RETURN_IF_ERROR(parent->ReplaceInstruction(root, fusion)); return fusion; } @@ -103,17 +148,23 @@ StatusOr CustomFusionRewriter::Run( if (matches.empty()) return false; for (const CustomFusionPattern::Match& match : matches) { - if (match.instructions.size() != 1) - return absl::InternalError( - "Custom fusions with multiple instruction are not yet supported"); + // Check if pattern can be outlined as a fusion and collect captured + // parameters (instructions defined outside of a fusion). + auto captures = GetPatternCaptures(match); + if (!captures.ok()) { + VLOG(2) << "Skip custom fusion " << match.config.name() << ": " + << captures.status(); + continue; + } TF_ASSIGN_OR_RETURN(HloComputation * fusion_body, - CreateFusionBody(module, match)); + CreateFusionBody(module, match, *captures)); - TF_ASSIGN_OR_RETURN(HloInstruction * fusion, - CreateFusionInstruction(module, match, fusion_body)); + TF_ASSIGN_OR_RETURN( + HloInstruction * fusion, + CreateFusionInstruction(module, match, *captures, fusion_body)); - VLOG(5) << "Added a fusion instruction: " << fusion->name() + VLOG(2) << "Added a fusion instruction: " << fusion->name() << " for custom fusion " << match.config.name() << " (instruction count = " << match.instructions.size() << ")"; } diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 5f71b06d94dfc2..267d51c520e97c 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -77,6 +77,7 @@ cc_library( # cc_library( # name = "cutlass_gemm_fusion", # srcs = ["cutlass_gemm_fusion.cc"], +# hdrs = ["cutlass_gemm_fusion.h"], # deps = [ # ":custom_fusion", # ":custom_fusion_pattern", @@ -88,6 +89,7 @@ cc_library( # "//xla:statusor", # "//xla:xla_data_proto_cc", # "//xla/hlo/ir:hlo", +# "//xla/service:pattern_matcher", # "@local_tsl//tsl/platform:errors", # "@local_tsl//tsl/platform:logging", # "@local_tsl//tsl/platform:statusor", @@ -100,10 +102,12 @@ cc_library( # srcs = ["cutlass_gemm_fusion_test.cc"], # backends = ["gpu"], # deps = [ +# ":custom_fusion_pattern", # ":cutlass_gemm_fusion", # "@com_google_absl//absl/strings", # "//xla:debug_options_flags", # "//xla:error_spec", +# "//xla/service/gpu:custom_fusion_rewriter", # "//xla/tests:hlo_test_base", # "@local_tsl//tsl/platform:test", # "@local_tsl//tsl/platform:test_main", diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index 250eec73172bdc..e2b2de33aec502 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "xla/service/gpu/kernels/cutlass_gemm_fusion.h" + #include #include #include @@ -27,6 +29,7 @@ limitations under the License. #include "xla/service/gpu/kernels/custom_fusion_pattern.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/kernels/cutlass_gemm_kernel.h" +#include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/status.h" #include "xla/statusor.h" @@ -40,48 +43,111 @@ namespace xla::gpu { // Cutlass Gemm pattern matching helpers //===----------------------------------------------------------------------===// -static Status IsF32Gemm(const HloDotInstruction* dot) { - const Shape& lhs = dot->operand(0)->shape(); - const Shape& rhs = dot->operand(1)->shape(); - const Shape& out = dot->shape(); +namespace { +namespace m = match; + +// Pattern for matching mixed precision GEMMs. +struct GemmWithUpcast { + explicit GemmWithUpcast(HloDotInstruction* dot) : dot(dot) {} - if (lhs.dimensions_size() != 2 || rhs.dimensions_size() != 2) - return absl::InternalError("dot operands must have rank 2"); + HloInstruction* dot; + HloInstruction* lhs_upcast = nullptr; // HLO convert instr + HloInstruction* rhs_upcast = nullptr; // HLO convert instr +}; +} // namespace - if (lhs.element_type() != PrimitiveType::F32 || - rhs.element_type() != PrimitiveType::F32 || - out.element_type() != PrimitiveType::F32) - return absl::InternalError("dot operations must use F32 data type"); +// Returns OK if dot instruction is a simple 2D row-major gemm. +static Status MatchRowMajorGemm(HloDotInstruction* dot) { + if (dot->operand(0)->shape().dimensions_size() != 2 || + dot->operand(1)->shape().dimensions_size() != 2) { + return absl::InternalError("operands must have rank 2"); + } - // Check that we do not transpose any of the operands. auto& dot_dims = dot->dot_dimension_numbers(); if (dot_dims.lhs_contracting_dimensions().size() != 1 || - dot_dims.lhs_contracting_dimensions()[0] != 1) + dot_dims.lhs_contracting_dimensions()[0] != 1) { return absl::InternalError("lhs contracting dimensions must be 1"); + } if (dot_dims.rhs_contracting_dimensions().size() != 1 || - dot_dims.rhs_contracting_dimensions()[0] != 0) + dot_dims.rhs_contracting_dimensions()[0] != 0) { return absl::InternalError("rhs contracting dimensions must be 0"); + } + + return OkStatus(); +} + +// Return OK if dot instruction is a simple gemm with all operands and result +// having the same data type. +static Status MatchSimpleGemm(HloDotInstruction* dot, PrimitiveType dtype) { + TF_RETURN_IF_ERROR(MatchRowMajorGemm(dot)); + + if (dot->operand(0)->shape().element_type() != dtype || + dot->operand(1)->shape().element_type() != dtype || + dot->shape().element_type() != dtype) { + return absl::InternalError("operands and result must have the same type"); + } return OkStatus(); } +// Returns matched GEMM with one of the operands upcasted to the accumulator +// data type with an HLO convert instruction. 
+static StatusOr MatchGemmWithUpcast(HloDotInstruction* dot) { + TF_RETURN_IF_ERROR(MatchRowMajorGemm(dot)); + + GemmWithUpcast matched(dot); + + // C <- convert(A) * B + if (Match(const_cast(dot->operand(0)), + m::Convert(&matched.lhs_upcast, m::Op()))) { + return matched; + } + + // C <- A * convert(B) + if (Match(const_cast(dot->operand(1)), + m::Convert(&matched.rhs_upcast, m::Op()))) { + return matched; + } + + return absl::InternalError("unsupported gemm with upcasing"); +} + //===----------------------------------------------------------------------===// -// CutlassGemmPattern +// Cutlass Gemm Patterns //===----------------------------------------------------------------------===// -class CutlassGemmPattern : public CustomFusionPattern { - public: - std::optional TryMatch(HloInstruction* instr) const override { - auto* dot = DynCast(instr); - if (!dot || !IsF32Gemm(dot).ok()) return std::nullopt; +std::optional CutlassGemmPattern::TryMatch( + HloInstruction* instr) const { + auto* dot = DynCast(instr); + if (!dot) return std::nullopt; - CustomFusionConfig config; - config.set_name("cutlass_gemm"); - return Match{config, {instr}}; - } -}; + auto matched = MatchSimpleGemm(dot, PrimitiveType::F32); + if (!matched.ok()) return std::nullopt; + + CustomFusionConfig config; + config.set_name("cutlass_gemm"); + return Match{config, {instr}}; +} + +std::optional +CutlassGemmWithUpcastPattern::TryMatch(HloInstruction* instr) const { + auto* dot = DynCast(instr); + if (!dot) return std::nullopt; + + auto matched = MatchGemmWithUpcast(dot); + if (!matched.ok()) return std::nullopt; + + // Only one operand can be upcasted. + DCHECK(matched->lhs_upcast == nullptr || matched->rhs_upcast == nullptr); + + CustomFusionConfig config; + config.set_name("cutlass_gemm_with_upcast"); + + return matched->lhs_upcast ? Match{config, {matched->lhs_upcast, instr}} + : Match{config, {matched->rhs_upcast, instr}}; +} //===----------------------------------------------------------------------===// // CutlassGemmFusion @@ -92,11 +158,12 @@ class CutlassGemmFusion : public CustomFusion { StatusOr> LoadKernels( const HloComputation* computation) const final { auto* dot = DynCast(computation->root_instruction()); - if (dot == nullptr) + if (dot == nullptr) { return absl::InternalError( "cutlass_gemm requires ROOT operation to be a dot"); + } - TF_RETURN_IF_ERROR(IsF32Gemm(dot)); + TF_RETURN_IF_ERROR(MatchSimpleGemm(dot, PrimitiveType::F32)); auto dtype = dot->shape().element_type(); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.h new file mode 100644 index 00000000000000..f448b2d0a4d915 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.h @@ -0,0 +1,41 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_FUSION_H_ +#define XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_FUSION_H_ + +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/kernels/custom_fusion_pattern.h" + +namespace xla::gpu { + +// Pattern matches simple row-major gemms to CUTLASS kernels. +class CutlassGemmPattern : public CustomFusionPattern { + public: + std::optional TryMatch(HloInstruction* instr) const override; +}; + +// Pattern matches mixed dtype gemms when one of the operands is upcasted to an +// accumulator (output) dtype, i.e. BF16 <= BF16 x S8. +class CutlassGemmWithUpcastPattern : public CustomFusionPattern { + public: + std::optional TryMatch(HloInstruction* instr) const override; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_KERNELS_CUTLASS_GEMM_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index 4829648dd03cc7..128139ae2ba90c 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -13,8 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/service/gpu/kernels/cutlass_gemm_fusion.h" + +#include + #include "xla/debug_options_flags.h" #include "xla/error_spec.h" +#include "xla/service/gpu/custom_fusion_rewriter.h" +#include "xla/service/gpu/kernels/custom_fusion_pattern.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" @@ -29,7 +35,91 @@ class CutlassFusionTest : public HloTestBase { } }; -TEST_F(CutlassFusionTest, SimpleF32Gemm) { +//===----------------------------------------------------------------------===// +// Pattern matching tests +//===----------------------------------------------------------------------===// + +TEST_F(CutlassFusionTest, RowMajorGemm) { + const char* hlo = R"( + HloModule test + + ENTRY %main (p0: f32[15,19], p1: f32[19,17]) -> f32[15,17] { + %p0 = f32[15,19]{1,0} parameter(0) + %p1 = f32[19,17]{1,0} parameter(1) + ROOT %r = f32[15,17]{1,0} dot(%p0, %p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + + const char* expected = R"( + ; CHECK: %cutlass_gemm {{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f32[15,19]{1,0} parameter(0) + ; CHECK: [[P1:%[^ ]+]] = f32[19,17]{1,0} parameter(1) + ; CHECK: ROOT [[DOT:%[^ ]+]] = f32[15,17]{1,0} dot([[P0]], [[P1]]), + ; CEHCK: lhs_contracting_dims={1}, rhs_contracting_dims={0} + ; CHECK: } + + ; CHECK: ENTRY %main {{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[15,17]{1,0} fusion + ; CHECK: kind=kCustom, calls=%cutlass_gemm, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"cutlass_gemm"} + ; CHECK: } + ; CHECK: } + )"; + + CustomFusionPatternRegistry patterns; + patterns.Emplace(); + + CustomFusionRewriter pass(&patterns); + RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); +} + +TEST_F(CutlassFusionTest, RowMajorGemmWithUpcast) { + const char* hlo = R"( + HloModule test + + ENTRY %main (p0: bf16[15,19], p1: s8[19,17]) -> bf16[15,17] { + %p0 = bf16[15,19]{1,0} parameter(0) + %p1 = s8[19,17]{1,0} parameter(1) + %c1 = bf16[19,17]{1,0} convert(%p1) + ROOT %r = bf16[15,17]{1,0} dot(%p0, %c1), + lhs_contracting_dims={1}, 
rhs_contracting_dims={0} + } + )"; + + const char* expected = R"( + ; CHECK: %cutlass_gemm_with_upcast {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = bf16[15,19]{1,0} parameter + ; CHECK-DAG: [[P1:%[^ ]+]] = s8[19,17]{1,0} parameter + ; CHECK: [[C1:%[^ ]+]] = bf16[19,17]{1,0} convert([[P1]]) + ; CHECK: ROOT [[DOT:%[^ ]+]] = bf16[15,17]{1,0} dot([[P0]], [[C1]]), + ; CEHCK: lhs_contracting_dims={1}, rhs_contracting_dims={0} + ; CHECK: } + + ; CHECK: ENTRY %main {{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = bf16[15,17]{1,0} fusion + ; CHECK: kind=kCustom, calls=%cutlass_gemm_with_upcast, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"cutlass_gemm_with_upcast"} + ; CHECK: } + ; CHECK: } + )"; + + CustomFusionPatternRegistry patterns; + patterns.Emplace(); + + CustomFusionRewriter pass(&patterns); + RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); +} + +//===----------------------------------------------------------------------===// +// Run And Compare Tests +//===----------------------------------------------------------------------===// + +TEST_F(CutlassFusionTest, RowMajorGemmKernel) { ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; const char* hlo_text_cublas = R"( From 3c86102b5c2b2b9da15e0dc661855f95497fd109 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 13:53:13 -0800 Subject: [PATCH 142/391] Re-enable layering_check for target. PiperOrigin-RevId: 582795705 --- tensorflow/core/common_runtime/BUILD | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index bf5b15eebdc72f..586025cca162d1 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -131,14 +131,12 @@ cc_library( srcs = ["collective_test_util.cc"], hdrs = ["collective_test_util.h"], copts = tf_copts(), - features = ["-layering_check"], deps = [ ":device_resolver_local", ":process_util", "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:session_options", "//tensorflow/core:testlib", "//tensorflow/core/framework:allocator", "//tensorflow/core/framework:device_attributes_proto_cc", @@ -146,6 +144,8 @@ cc_library( "//tensorflow/core/nccl:collective_communicator", "//tensorflow/core/platform:refcount", "//tensorflow/core/platform:status", + "//tensorflow/core/platform:unbounded_work_queue", + "@com_google_absl//absl/synchronization", ], ) From 6f4fd81ee62d0a22699532bf6e5ce28c26606c86 Mon Sep 17 00:00:00 2001 From: Changhui Lin Date: Wed, 15 Nov 2023 14:01:27 -0800 Subject: [PATCH 143/391] Wrap GPU client creation options in a struct, and overload GetStreamExecutorGpuClient. This makes adding new options easier. PiperOrigin-RevId: 582798266 --- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 57 +++++++++----- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.h | 30 +++++++- .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 74 ++++++++----------- 3 files changed, 96 insertions(+), 65 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index e5be3dc6de5dfd..84844600eee99f 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include "xla/client/local_client.h" #include "xla/client/xla_computation.h" #include "xla/pjrt/distributed/topology_util.h" +#include "xla/pjrt/gpu/gpu_helpers.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" @@ -874,33 +875,53 @@ StatusOr> GetStreamExecutorGpuClient( bool should_stage_host_to_device_transfers, PjRtClient::KeyValueGetCallback kv_get, PjRtClient::KeyValuePutCallback kv_put, bool enable_mock_nccl) { + GpuClientOptions options; + options.allocator_config = allocator_config; + options.node_id = node_id; + options.num_nodes = num_nodes; + options.allowed_devices = allowed_devices; + options.platform_name = platform_name; + options.should_stage_host_to_device_transfers = + should_stage_host_to_device_transfers; + options.kv_get = kv_get; + options.kv_put = kv_put; + options.enable_mock_nccl = enable_mock_nccl; + + return GetStreamExecutorGpuClient(options); +} + +StatusOr> GetStreamExecutorGpuClient( + const GpuClientOptions& options) { #if TENSORFLOW_USE_ROCM auto pjrt_platform_name = xla::RocmName(); #else // TENSORFLOW_USE_ROCM auto pjrt_platform_name = xla::CudaName(); #endif // TENSORFLOW_USE_ROCM - TF_ASSIGN_OR_RETURN(LocalClient * xla_client, - GetGpuXlaClient(platform_name, allowed_devices)); + TF_ASSIGN_OR_RETURN( + LocalClient * xla_client, + GetGpuXlaClient(options.platform_name, options.allowed_devices)); std::map> local_device_states; TF_ASSIGN_OR_RETURN(local_device_states, BuildLocalDeviceStates(xla_client)); EnablePeerAccess(xla_client->backend().stream_executors()); - TF_ASSIGN_OR_RETURN( - auto allocator, - GetStreamExecutorGpuDeviceAllocator( - xla_client->platform(), allocator_config, local_device_states)); + TF_ASSIGN_OR_RETURN(auto allocator, + GetStreamExecutorGpuDeviceAllocator( + xla_client->platform(), options.allocator_config, + local_device_states)); auto host_memory_allocator = GetGpuHostAllocator(local_device_states.begin()->second->executor()); std::vector> devices; auto gpu_run_options = std::make_unique(); - if (enable_mock_nccl) { + if (options.enable_mock_nccl) { gpu_run_options->set_enable_mock_nccl_collectives(); } absl::flat_hash_map device_maps; absl::Mutex mu; - if (enable_mock_nccl) { - kv_get = [&device_maps, &mu, &num_nodes]( + PjRtClient::KeyValueGetCallback kv_get = options.kv_get; + PjRtClient::KeyValuePutCallback kv_put = options.kv_put; + if (options.enable_mock_nccl) { + kv_get = [&device_maps, &mu, &options]( std::string_view k, absl::Duration timeout) -> xla::StatusOr { std::string result; @@ -912,7 +933,7 @@ StatusOr> GetStreamExecutorGpuClient( int device_id; std::vector tokens = absl::StrSplit(k, ':'); if (tokens.size() != 2 || !absl::SimpleAtoi(tokens[1], &device_id)) { - device_id = num_nodes - 1; + device_id = options.num_nodes - 1; } // Return fake local topology with device_id info back. 
xla::LocalTopologyProto local; @@ -936,17 +957,17 @@ StatusOr> GetStreamExecutorGpuClient( return xla::OkStatus(); }; } - TF_RET_CHECK(num_nodes == 1 || kv_get != nullptr); - TF_RET_CHECK(num_nodes == 1 || kv_put != nullptr); + TF_RET_CHECK(options.num_nodes == 1 || kv_get != nullptr); + TF_RET_CHECK(options.num_nodes == 1 || kv_put != nullptr); TF_RETURN_IF_ERROR(BuildDistributedDevices( - pjrt_platform_name, std::move(local_device_states), node_id, num_nodes, - &devices, gpu_run_options.get(), kv_get, kv_put)); + pjrt_platform_name, std::move(local_device_states), options.node_id, + options.num_nodes, &devices, gpu_run_options.get(), kv_get, kv_put)); return std::unique_ptr(std::make_unique( - pjrt_platform_name, xla_client, std::move(devices), - /*node_id=*/node_id, std::move(allocator), - std::move(host_memory_allocator), should_stage_host_to_device_transfers, - /*gpu_run_options=*/std::move(gpu_run_options))); + pjrt_platform_name, xla_client, std::move(devices), options.node_id, + std::move(allocator), std::move(host_memory_allocator), + options.should_stage_host_to_device_transfers, + std::move(gpu_run_options))); } absl::StatusOr StreamExecutorGpuTopologyDescription::Serialize() diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index d51613e0604339..c66888b2bc09e5 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -221,9 +221,33 @@ std::vector> BuildLocalDevices( std::map> local_device_states, int node_id); -// kv_get and kv_put are callbacks provided by the caller to access a key-value -// store shared between nodes. kv_get and kv_put must be non-null if num_nodes -// > 1. +struct GpuClientOptions { + GpuAllocatorConfig allocator_config; + + int node_id = 0; + + int num_nodes = 1; + + std::optional> allowed_devices = std::nullopt; + + std::optional platform_name = std::nullopt; + + bool should_stage_host_to_device_transfers = true; + + // `kv_get` and `kv_put` are callbacks provided by the caller to access a + // key-value store shared between nodes. `kv_get` and `kv_put` must be + // non-null if `num_nodes` > 1. + PjRtClient::KeyValueGetCallback kv_get = nullptr; + PjRtClient::KeyValuePutCallback kv_put = nullptr; + + bool enable_mock_nccl = false; +}; + +StatusOr> GetStreamExecutorGpuClient( + const GpuClientOptions& options); + +// TODO(b/311119497): Remove this function after all callsites are updated. 
+ABSL_DEPRECATED("Use the the above function that takes GpuClientOptions.") StatusOr> GetStreamExecutorGpuClient( bool asynchronous, const GpuAllocatorConfig& allocator_config, int node_id, int num_nodes = 1, diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index c6c018eed9ed34..4361139b06c938 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -106,10 +106,8 @@ static constexpr char const* kProgram = R"(HloModule HostTransfer })"; TEST(StreamExecutorGpuClientTest, SendRecvChunked) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); TF_ASSERT_OK_AND_ASSIGN(auto executable, CompileExecutable(kProgram, *client)); @@ -160,9 +158,8 @@ TEST(StreamExecutorGpuClientTest, SendRecvChunked) { } TEST(StreamExecutorGpuClientTest, SendErrorNoDeadLock) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); TF_ASSERT_OK_AND_ASSIGN(auto executable, CompileExecutable(kProgram, *client)); @@ -195,9 +192,8 @@ TEST(StreamExecutorGpuClientTest, SendErrorNoDeadLock) { } TEST(StreamExecutorGpuClientTest, RecvErrorNoDeadLock) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); TF_ASSERT_OK_AND_ASSIGN(auto executable, CompileExecutable(kProgram, *client)); @@ -233,9 +229,8 @@ TEST(StreamExecutorGpuClientTest, RecvErrorNoDeadLock) { } TEST(StreamExecutorGpuClientTest, ToLiteralAsync) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); ASSERT_GE(client->addressable_devices().size(), 1); auto src_literal = LiteralUtil::CreateR1({41.0f, 42.0f, 43.0f, 44.0f}); @@ -271,9 +266,8 @@ TEST(StreamExecutorGpuClientTest, ToLiteralAsync) { } TEST(StreamExecutorGpuClientTest, ToLiteralAsyncBeforeBufferReady) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); ASSERT_GE(client->addressable_devices().size(), 1); auto src_literal = LiteralUtil::CreateR1({41.0f, 42.0f, 43.0f, 44.0f}); @@ -312,9 +306,8 @@ TEST(StreamExecutorGpuClientTest, ToLiteralAsyncBeforeBufferReady) { } TEST(StreamExecutorGpuClientTest, FromHostAsync) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); ASSERT_GE(client->addressable_devices().size(), 1); std::vector src_literals; @@ -380,9 +373,8 @@ TEST(StreamExecutorGpuClientTest, FromHostAsync) { } } TEST(StreamExecutorGpuClientTest, CopyRawToHostFullBuffer) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); auto literal = xla::LiteralUtil::CreateR1({41.0f, 
42.0f}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr buffer, @@ -400,9 +392,8 @@ TEST(StreamExecutorGpuClientTest, CopyRawToHostFullBuffer) { } TEST(StreamExecutorGpuClientTest, CopyRawToHostSubBuffer) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); auto literal = xla::LiteralUtil::CreateR1({41.0f, 42.0f}); TF_ASSERT_OK_AND_ASSIGN( @@ -418,9 +409,8 @@ TEST(StreamExecutorGpuClientTest, CopyRawToHostSubBuffer) { } TEST(StreamExecutorGpuClientTest, CopyRawToHostOutOfRange) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); auto literal = xla::LiteralUtil::CreateR1({41.0f, 42.0f}); TF_ASSERT_OK_AND_ASSIGN( @@ -436,9 +426,8 @@ TEST(StreamExecutorGpuClientTest, CopyRawToHostOutOfRange) { } TEST(StreamExecutorGpuClientTest, AsyncCopyToDevice) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); ASSERT_GE(client->addressable_devices().size(), 2); // d0 is the device we will perform local/remote sends from. @@ -469,9 +458,8 @@ TEST(StreamExecutorGpuClientTest, AsyncCopyToDevice) { } TEST(StreamExecutorGpuClientTest, CreateMixOfErrorBuffers) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); ASSERT_GE(client->addressable_devices().size(), 1); std::vector src_literals; @@ -574,13 +562,12 @@ TEST(StreamExecutorGpuClientTest, DistributeInit) { int num_nodes = 2; for (int i = 0; i < num_nodes; i++) { thread_pool.Schedule([&, i] { - TF_ASSERT_OK_AND_ASSIGN( - auto client, - GetStreamExecutorGpuClient( - true, /*allocator_config=*/{}, - /*node_id=*/i, num_nodes, /*allowed_devices=*/std::nullopt, - /*platform_name=*/std::nullopt, - /*should_stage_host_to_device_transfers=*/true, kv_get, kv_put)); + GpuClientOptions options; + options.node_id = i; + options.num_nodes = num_nodes; + options.kv_get = kv_get; + options.kv_put = kv_put; + TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(options)); EXPECT_TRUE(client->platform_name() == "cuda" || client->platform_name() == "rocm"); EXPECT_EQ(client->addressable_device_count(), 2); @@ -590,9 +577,8 @@ TEST(StreamExecutorGpuClientTest, DistributeInit) { } TEST(StreamExecutorGpuClientTest, GetAllocatorStatsTest) { - TF_ASSERT_OK_AND_ASSIGN( - auto client, GetStreamExecutorGpuClient(true, /*allocator_config=*/{}, - /*node_id=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); ASSERT_GE(client->addressable_devices().size(), 2); for (auto device : client->addressable_devices()) { From e6d2961ce26aae8cf3fe96cb13fd277dcfd03dbc Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 15 Nov 2023 14:16:45 -0800 Subject: [PATCH 144/391] Remove xla/stream_executor:cuda_platform from `tf_additional_binary_deps` Attempt to fix https://github.com/tensorflow/tensorflow/issues/62075 As described in the issue, there is a problem where various CUDA plugins are being registered twice. 
I'm not sure this will fix the issue, but I noticed a year ago that `cuda_platform` is depended on via `tf_additional_binary_deps` which is incorrect based the `check_deps` within tensorflow/BUILD. I'm not sure this will work - but a good first try. More info here: https://github.com/tensorflow/tensorflow/commit/c05b07dccd66feeada466d4e5951151a1a7473ca PiperOrigin-RevId: 582804278 --- .../core/platform/build_config.default.bzl | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/tensorflow/core/platform/build_config.default.bzl b/tensorflow/core/platform/build_config.default.bzl index 24421c6d6e8b87..80c7d25ad1dd9e 100644 --- a/tensorflow/core/platform/build_config.default.bzl +++ b/tensorflow/core/platform/build_config.default.bzl @@ -1,24 +1,23 @@ """OSS versions of Bazel macros that can't be migrated to TSL.""" +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load( + "@local_xla//xla:xla.bzl", + _xla_clean_dep = "clean_dep", +) load( "//tensorflow/core/platform:build_config_root.bzl", "if_static", ) load( - "@local_xla//xla:xla.bzl", - _xla_clean_dep = "clean_dep", + "//third_party/mkl:build_defs.bzl", + "if_mkl_ml", ) load( "@local_tsl//tsl:tsl.bzl", "if_libtpu", _tsl_clean_dep = "clean_dep", ) -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load( - "//third_party/mkl:build_defs.bzl", - "if_mkl_ml", -) def tf_tpu_dependencies(): return if_libtpu(["//tensorflow/core/tpu/kernels"]) @@ -34,11 +33,7 @@ def tf_additional_binary_deps(): # core. str(Label("//tensorflow/core/kernels:lookup_util")), str(Label("//tensorflow/core/util/tensor_bundle")), - ] + if_cuda( - [ - str(Label("@local_xla//xla/stream_executor:cuda_platform")), - ], - ) + if_rocm( + ] + if_rocm( [ str(Label("@local_xla//xla/stream_executor:rocm_platform")), str(Label("@local_xla//xla/stream_executor/rocm:rocm_rpath")), From 950b2c1daaf1a975f5df15c28d6141f58c598e3a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 14:25:22 -0800 Subject: [PATCH 145/391] Turn computing default solutions using sharding propagation off by default. Avoid printing a warning about the absence of default solutions for nodes when default solutions have not been computed. 
PiperOrigin-RevId: 582806847 --- .../xla/xla/hlo/experimental/auto_sharding/BUILD | 1 + .../hlo/experimental/auto_sharding/auto_sharding.cc | 12 ++++++------ .../experimental/auto_sharding/auto_sharding_impl.cc | 10 ++++------ .../auto_sharding/auto_sharding_option.h | 2 +- .../auto_sharding/auto_sharding_wrapper.h | 5 ++--- 5 files changed, 14 insertions(+), 16 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index d629719b8d01e9..a792011b5a254b 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -159,6 +159,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":auto_sharding_cost_graph", + ":auto_sharding_option", ":auto_sharding_strategy", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index bdf3c393b01df7..7f172c0b07b251 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -2490,20 +2490,19 @@ AutoShardingSolverResult CallSolver( const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, const LeafStrategies& leaf_strategies, const CostGraph& cost_graph, const AliasSet& alias_set, const std::vector& s_hint, - int64_t memory_budget_per_device, bool crash_at_infinity_costs_check, bool compute_iis, int64_t solver_timeout_in_seconds, - bool allow_alias_to_follower_conversion, + const AutoShardingOption& option, const absl::flat_hash_map& sharding_propagation_solution) { // Serialize edges and edge costs to 1d numpy arrays AutoShardingSolverRequest request; request.num_nodes = leaf_strategies.size(); - request.memory_budget = memory_budget_per_device; + request.memory_budget = option.memory_budget_per_device; request.s_len = cost_graph.node_lens_; request.s_follow = cost_graph.follow_idx_; request.s_hint = s_hint; request.solver_timeout_in_seconds = solver_timeout_in_seconds; - request.crash_at_infinity_costs_check = crash_at_infinity_costs_check; + request.crash_at_infinity_costs_check = !option.try_multiple_mesh_shapes; request.compute_iis = compute_iis; for (const auto& iter : cost_graph.edge_costs_) { request.e.push_back(iter.first); @@ -2549,7 +2548,8 @@ AutoShardingSolverResult CallSolver( mi.push_back(strategy.memory_cost); pi.push_back(default_strategy && sharding == *default_strategy ? 
0 : 1); } - if (*std::min_element(pi.begin(), pi.end()) > 0) { + if (option.use_sharding_propagation_for_default_shardings && + *std::min_element(pi.begin(), pi.end()) > 0) { LOG(WARNING) << "No default strategy for {node_idx " << node_idx << ", instruction ID " << strategies->instruction_id << ", instruction name " << instruction_name << "}"; @@ -2614,7 +2614,7 @@ AutoShardingSolverResult CallSolver( for (NodeStrategyIdx i = 0; i < row_indices.size() && convertable; ++i) { if (vij[i * col_indices.size() + i] != 0.0) convertable = false; } - if (convertable && allow_alias_to_follower_conversion) { + if (convertable && option.allow_alias_to_follower_conversion) { new_followers.push_back(std::make_pair(idx_a, idx_b)); } else { request.a.push_back(std::make_pair(idx_a, idx_b)); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc index 25ddf351deaa45..027426f1331b84 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc @@ -37,12 +37,10 @@ AutoShardingSolverResult Solve( const AliasSet& alias_set, const AutoShardingOption& option, const absl::flat_hash_map& sharding_propagation_solution) { - return CallSolver( - hlo_live_range, liveness_node_set, strategy_map, leaf_strategies, - cost_graph, alias_set, /*s_hint*/ {}, option.memory_budget_per_device, - /*crash_at_infinity_costs_check*/ !option.try_multiple_mesh_shapes, - /*compute_iis*/ true, option.solver_timeout_in_seconds, - option.allow_alias_to_follower_conversion, sharding_propagation_solution); + return CallSolver(hlo_live_range, liveness_node_set, strategy_map, + leaf_strategies, cost_graph, alias_set, /*s_hint*/ {}, + /*compute_iis*/ true, option.solver_timeout_in_seconds, + option, sharding_propagation_solution); } void PopulateTemporalValues(const CostGraph& cost_graph, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h index c2898ce610693b..775c4a59bc301e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h @@ -189,7 +189,7 @@ struct AutoShardingOption { // In order to obtain default sharding strategies for instructions to limit // departures from the defaults, use sharding propagation instead of assuming // a simple replicated default. - bool use_sharding_propagation_for_default_shardings = true; + bool use_sharding_propagation_for_default_shardings = false; // Prints a debug string. std::string ToString() const; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h index 2b7d3e3854068a..374fb1b37f7e19 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h @@ -21,8 +21,8 @@ limitations under the License. 
#include #include -#include "absl/container/flat_hash_map.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -40,9 +40,8 @@ AutoShardingSolverResult CallSolver( const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, const LeafStrategies& leaf_strategies, const CostGraph& cost_graph, const AliasSet& alias_set, const std::vector& s_hint, - int64_t memory_budget_per_device, bool crash_at_infinity_costs_check, bool compute_iis, int64_t solver_timeout_in_seconds, - bool allow_alias_to_follower_conversion, + const AutoShardingOption& option, const absl::flat_hash_map& sharding_propagation_solution = {}); From 55ab3045f37c72e03e6f344f73c799f3ce708c18 Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Wed, 15 Nov 2023 14:27:55 -0800 Subject: [PATCH 146/391] Disable bazel test on TF-TPU wheels PiperOrigin-RevId: 582807555 --- ci/official/envs/nightly_linux_x86_tpu_py310 | 1 + ci/official/envs/nightly_linux_x86_tpu_py311 | 1 + ci/official/envs/nightly_linux_x86_tpu_py312 | 1 + ci/official/envs/nightly_linux_x86_tpu_py39 | 1 + 4 files changed, 4 insertions(+) diff --git a/ci/official/envs/nightly_linux_x86_tpu_py310 b/ci/official/envs/nightly_linux_x86_tpu_py310 index 22a7965e3454d7..31f7a15bd3ca00 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py310 +++ b/ci/official/envs/nightly_linux_x86_tpu_py310 @@ -9,4 +9,5 @@ TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-releases/index.html" TFCI_PYTHON_VERSION=3.10 TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_BAZEL_TEST_ENABLE=0 TFCI_WHL_SIZE_LIMIT=580M diff --git a/ci/official/envs/nightly_linux_x86_tpu_py311 b/ci/official/envs/nightly_linux_x86_tpu_py311 index 474d80dced1852..4061b0330e6e93 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py311 +++ b/ci/official/envs/nightly_linux_x86_tpu_py311 @@ -9,4 +9,5 @@ TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-releases/index.html" TFCI_PYTHON_VERSION=3.11 TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_BAZEL_TEST_ENABLE=0 TFCI_WHL_SIZE_LIMIT=580M diff --git a/ci/official/envs/nightly_linux_x86_tpu_py312 b/ci/official/envs/nightly_linux_x86_tpu_py312 index 0c26cb57f3b58b..31afd3709e4b92 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py312 +++ b/ci/official/envs/nightly_linux_x86_tpu_py312 @@ -9,4 +9,5 @@ TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-releases/index.html" TFCI_PYTHON_VERSION=3.12 TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_BAZEL_TEST_ENABLE=0 TFCI_WHL_SIZE_LIMIT=580M diff --git a/ci/official/envs/nightly_linux_x86_tpu_py39 b/ci/official/envs/nightly_linux_x86_tpu_py39 index bdd0784f767f98..645eeed5827e6e 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py39 +++ b/ci/official/envs/nightly_linux_x86_tpu_py39 @@ -9,4 +9,5 @@ TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-releases/index.html" TFCI_PYTHON_VERSION=3.9 TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 +TFCI_WHL_BAZEL_TEST_ENABLE=0 TFCI_WHL_SIZE_LIMIT=580M From c51acd1ddb9edff219bea1195ce7a184fac1dd2e Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: 
Wed, 15 Nov 2023 14:28:19 -0800 Subject: [PATCH 147/391] [xla:gpu] Add a stub for gemm with upcast custom fusion We do not have latest CUTLASS available in XLA to test it end-to-end, for now just add stub for fusion. PiperOrigin-RevId: 582807673 --- .../gpu/kernels/cutlass_gemm_fusion.cc | 32 ++++++++++++++- .../gpu/kernels/cutlass_gemm_fusion_test.cc | 40 +++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index e2b2de33aec502..ac58ba9c2d8cae 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -150,7 +150,7 @@ CutlassGemmWithUpcastPattern::TryMatch(HloInstruction* instr) const { } //===----------------------------------------------------------------------===// -// CutlassGemmFusion +// Cutlass Gemm Fusions //===----------------------------------------------------------------------===// class CutlassGemmFusion : public CustomFusion { @@ -180,7 +180,37 @@ class CutlassGemmFusion : public CustomFusion { } }; +class CutlassGemmWithUpcastFusion : public CustomFusion { + public: + StatusOr> LoadKernels( + const HloComputation* computation) const final { + auto* dot = DynCast(computation->root_instruction()); + if (dot == nullptr) { + return absl::InternalError( + "cutlass_gemm requires ROOT operation to be a dot"); + } + + TF_ASSIGN_OR_RETURN(auto matched, MatchGemmWithUpcast(dot)); + + // We only support upcasting of rhs operand. + if (matched.lhs_upcast != nullptr) + return absl::InternalError("only rhs upcasting is implemented"); + + auto dot_dtype = dot->shape().element_type(); + auto upcast_dtype = matched.rhs_upcast->shape().element_type(); + + // We only support BF16 <- BF16 x S8 upcasted gemm. 
+ if (dot_dtype != PrimitiveType::BF16 || upcast_dtype != PrimitiveType::S8) + return absl::InternalError("unsupported upcasting pattern"); + + return absl::UnimplementedError("requires CUTLASS 3.3.0"); + } +}; + } // namespace xla::gpu XLA_REGISTER_CUSTOM_FUSION_PATTERN(::xla::gpu::CutlassGemmPattern); + XLA_REGISTER_CUSTOM_FUSION("cutlass_gemm", ::xla::gpu::CutlassGemmFusion); +XLA_REGISTER_CUSTOM_FUSION("cutlass_gemm_with_upcast", + ::xla::gpu::CutlassGemmWithUpcastFusion); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index 128139ae2ba90c..541ba5c569b088 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -155,4 +155,44 @@ TEST_F(CutlassFusionTest, RowMajorGemmKernel) { error_spec, /*run_hlo_passes=*/false)); } +TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastKernel) { + GTEST_SKIP() << "Requires CUTLASS 3.3.0+"; + + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_text_cublas = R"( + HloModule cublas + + ENTRY e { + p0 = bf16[16,32]{1,0} parameter(0) + p1 = s8[32,8]{1,0} parameter(1) + c1 = bf16[32,8]{1,0} convert(p1) + gemm = (bf16[16,8]{1,0}, s8[0]{0}) custom-call(p0, c1), + custom_call_target="__cublas$gemm", + backend_config={"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"} + ROOT get-tuple-element = bf16[16,8]{1,0} get-tuple-element((bf16[16,8]{1,0}, s8[0]{0}) gemm), index=0 + })"; + + const char* hlo_text_custom_fusion = R"( + HloModule cutlass + + cutlass_gemm_with_upcast { + p0 = bf16[16,32]{1,0} parameter(0) + p1 = s8[32,8]{1,0} parameter(1) + c1 = bf16[32,8]{1,0} convert(p1) + ROOT dot = bf16[16,8]{1,0} dot(p0, c1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY e { + p0 = bf16[16,32]{1,0} parameter(0) + p1 = s8[32,8]{1,0} parameter(1) + ROOT _ = bf16[16,8]{1,0} fusion(p0, p1), kind=kCustom, calls=cutlass_gemm_with_upcast, + backend_config={kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm_with_upcast"}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion, + error_spec, /*run_hlo_passes=*/false)); +} + } // namespace xla::gpu From 15d63ea5f52d8db3d584afed807d96257c50b31b Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Wed, 15 Nov 2023 14:28:42 -0800 Subject: [PATCH 148/391] [stream_executor] Fix absl::Span usage in GpuCommandBuffer #5857 absl::Span could be invalidated if we call emplace_back to vector that it points to. 
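A minimal self-contained sketch of the hazard being avoided (illustrative names, not code taken from this patch): a span that aliases a vector's storage dangles once emplace_back on that same vector reallocates, which is why GetDependencies now returns an owning absl::InlinedVector rather than an absl::Span.

    #include <vector>

    #include "absl/container/inlined_vector.h"
    #include "absl/types/span.h"

    void SpanInvalidationSketch() {
      std::vector<int> nodes = {1, 2, 3};
      // The span aliases the vector's current heap storage.
      absl::Span<int> deps(&nodes.back(), 1);
      // emplace_back may reallocate that storage, leaving `deps` dangling.
      nodes.emplace_back(4);
      // Copying the needed elements into an owning container first is safe.
      absl::InlinedVector<int, 1> owned_deps = {nodes.back()};
      (void)deps;       // dereferencing `deps` here would be undefined behavior
      (void)owned_deps;
    }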
PiperOrigin-RevId: 582807794 --- .../gpu/runtime3/command_buffer_thunk_test.cc | 67 +++++++++++++++++++ third_party/xla/xla/stream_executor/gpu/BUILD | 2 +- .../stream_executor/gpu/gpu_command_buffer.cc | 29 ++++---- .../stream_executor/gpu/gpu_command_buffer.h | 5 +- 4 files changed, 88 insertions(+), 15 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc index 12e84df1b0ae59..9ad683a640980e 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc @@ -246,4 +246,71 @@ TEST(CommandBufferThunkTest, GemmCmd) { ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); } +TEST(CommandBufferThunkTest, MultipleLaunchCmd) { + se::StreamExecutor* executor = CudaExecutor(); + + se::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=42, b=0 + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + se::DeviceMemory c = executor->AllocateArray(length, 0); + se::DeviceMemory d = executor->AllocateArray(length, 0); + + stream.ThenMemset32(&a, 42, byte_length); + stream.ThenMemZero(&b, byte_length); + stream.ThenMemset32(&c, 21, byte_length); + stream.ThenMemZero(&d, byte_length); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation alloc_b(/*index=*/1, byte_length, /*color=*/0); + BufferAllocation alloc_c(/*index=*/2, byte_length, /*color=*/0); + BufferAllocation alloc_d(/*index=*/3, byte_length, /*color=*/0); + + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); + BufferAllocation::Slice slice_c(&alloc_c, 0, byte_length); + BufferAllocation::Slice slice_d(&alloc_d, 0, byte_length); + + auto args = {slice_a, slice_a, slice_b}; // b = a + a + auto args_1 = {slice_c, slice_c, slice_d}; // d = c + c + + // Prepare commands sequence for constructing command buffer. + CommandBufferCmdSequence commands; + commands.Emplace("add", args, LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + commands.Emplace("add", args_1, LaunchDimensions(1, 4), + /*shmem_bytes=*/0); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({a, b, c, d}, 0, executor->GetAllocator()); + Thunk::ExecuteParams params(run_options, allocations, &stream, {}); + + CommandBufferCmd::ExecutableSource source = { + /*text=*/se::cuda::internal::kAddI32Kernel, /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize(executor, source)); + + // Execute command buffer thunk and verify that it added the value. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + + // Copy `b` data back to host. + std::vector dst(4, 0); + stream.ThenMemcpy(dst.data(), b, byte_length); + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Copy `d` data back to host. 
+ std::vector dst_1(4, 0); + stream.ThenMemcpy(dst.data(), d, byte_length); + ASSERT_EQ(dst, std::vector(4, 21 + 21)); +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 26d5aa4680b0c3..1ea30b06ef827f 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -105,6 +105,7 @@ cc_library( ":gpu_types_header", "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor:stream_executor_internal", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -114,7 +115,6 @@ cc_library( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index d50bb091d040a6..0f98cb3be02da9 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include "absl/functional/any_invocable.h" #include "absl/log/check.h" @@ -141,9 +142,12 @@ tsl::Status GpuCommandBuffer::Trace( return tsl::OkStatus(); } -absl::Span GpuCommandBuffer::GetDependencies() { - return nodes_.empty() ? absl::Span() - : absl::Span(&nodes_.back(), 1); +GpuCommandBuffer::Dependencies GpuCommandBuffer::GetDependencies() { + if (nodes_.empty()) { + return {}; + } + + return {nodes_.back()}; } tsl::Status GpuCommandBuffer::CheckNotFinalized() { @@ -178,11 +182,11 @@ tsl::Status GpuCommandBuffer::Launch(const ThreadDim& threads, // Adds a new kernel node to the graph under construction. if (state_ == State::kCreate) { - absl::Span deps = GetDependencies(); + Dependencies deps = GetDependencies(); GpuGraphNodeHandle* node = &nodes_.emplace_back(); return GpuDriver::GraphAddKernelNode( - node, graph_, deps, kernel.name(), gpu_func, blocks.x, blocks.y, - blocks.z, threads.x, threads.y, threads.z, + node, graph_, absl::MakeSpan(deps), kernel.name(), gpu_func, blocks.x, + blocks.y, blocks.z, threads.x, threads.y, threads.z, args.number_of_shared_bytes(), kernel_params, /*extra=*/nullptr); } @@ -206,9 +210,10 @@ tsl::Status GpuCommandBuffer::AddNestedCommandBuffer( GpuGraphHandle child_graph = GpuCommandBuffer::Cast(&nested)->graph(); // Adds a child graph node to the graph under construction. if (state_ == State::kCreate) { - absl::Span deps = GetDependencies(); + Dependencies deps = GetDependencies(); GpuGraphNodeHandle* node = &nodes_.emplace_back(); - return GpuDriver::GraphAddChildNode(node, graph_, deps, child_graph); + return GpuDriver::GraphAddChildNode(node, graph_, absl::MakeSpan(deps), + child_graph); } // Updates child graph node in the executable graph. @@ -227,11 +232,11 @@ tsl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, // Adds a new memcpy node to the graph under construction. 
if (state_ == State::kCreate) { - absl::Span deps = GetDependencies(); + Dependencies deps = GetDependencies(); GpuGraphNodeHandle* node = &nodes_.emplace_back(); - return GpuDriver::GraphAddMemcpyD2DNode(parent_->gpu_context(), node, - graph_, deps, AsDevicePtr(*dst), - AsDevicePtr(src), size); + return GpuDriver::GraphAddMemcpyD2DNode( + parent_->gpu_context(), node, graph_, absl::MakeSpan(deps), + AsDevicePtr(*dst), AsDevicePtr(src), size); } return UnsupportedStateError(state_); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 7c30e40414b101..70215b30d4f24a 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -20,8 +20,8 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" -#include "absl/types/span.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" @@ -83,7 +83,8 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { // TODO(ezhulenev): Currently we serialize all Gpu nodes by adding a // dependency between all nodes added to a command buffer. We need a concept // of a barrier at a command buffer level. - absl::Span GetDependencies(); + using Dependencies = absl::InlinedVector; + Dependencies GetDependencies(); // Returns OK status if command buffer is not finalized and it is still // possible to add new commands to it, otherwise returns internal error. From f18323f8745b5e0f1a1dbb52046fa1f415c541d4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 15:31:24 -0800 Subject: [PATCH 149/391] Reintroduces some CP-SAT solver parameters that are designed to ensure determinism. PiperOrigin-RevId: 582826685 --- .../experimental/auto_sharding/auto_sharding_solver.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 2924b53931ff6a..80b2ce58957355 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -210,9 +210,12 @@ AutoShardingSolverResult CallORToolsSolver( #ifdef PLATFORM_GOOGLE if (solver->ProblemType() == operations_research::MPSolver::SAT_INTEGER_PROGRAMMING) { - // Set num_workers for parallelism. - solver_parameter_str = absl::StrCat("num_workers:", num_workers); - solver->SetSolverSpecificParametersAsString(solver_parameter_str); + // Set random_seed, interleave_search and share_binary_clauses for + // determinism, and num_workers for parallelism. + solver_parameter_str = absl::StrCat( + "share_binary_clauses:false,random_seed:1,interleave_" + "search:true,num_workers:", + num_workers); } #endif // Create variables From 42b40d22c1b5316d826c770698d037c6368a308b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 15:45:28 -0800 Subject: [PATCH 150/391] Add MHLO and StableHLO passes to stablehlo_quant_opt This makes it easier to run passes locally. 
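As a usage sketch only (the binary name and pass flag below are illustrative assumptions, not verified against the BUILD target; any pass registered via mlir::stablehlo::registerPasses or mlir::mhlo::registerAllMhloPasses becomes available as a command-line flag on the tool):

    # Binary name and pass flag are illustrative placeholders.
    stablehlo-quant-opt /tmp/module.mlir -stablehlo-refine-shapes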
PiperOrigin-RevId: 582830620 --- tensorflow/compiler/mlir/quantization/stablehlo/BUILD | 5 +++++ .../quantization/stablehlo/tools/stablehlo_quant_opt.cc | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index a19017c95006ab..15dce89f4aac99 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -534,14 +534,19 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", + "//tensorflow/core/ir/types:Dialect", "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", + "@local_xla//xla/mlir_hlo:mhlo_passes", "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_passes", ], ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc index 3afc42e21d1f6e..6883fbababb535 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project @@ -21,6 +23,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" @@ -29,8 +32,11 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/register.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "tensorflow/core/ir/types/dialect.h" int main(int argc, char** argv) { tensorflow::InitMlir y(&argc, &argv); @@ -39,6 +45,8 @@ int main(int argc, char** argv) { mlir::registerTensorFlowPasses(); mlir::quant::stablehlo::registerPasses(); mlir::quant::stablehlo::registerBridgePasses(); + mlir::stablehlo::registerPasses(); + mlir::mhlo::registerAllMhloPasses(); mlir::DialectRegistry registry; registry.insert Date: Wed, 15 Nov 2023 16:12:48 -0800 Subject: [PATCH 151/391] [XLA] Match other dimensions to try to swap instead of just directly matching replicated dimensions when the current replicated dimensions doesn't match. It's not necessarily faster to match a non-matching replicated dimension when grouping NonContracting dimensions and it can cause significant reshuffling of data. Instead try to match another dimension that if it was the replicated one then we would be able to match. If we find it the reshuffling is much more regular (and over specific axes) than the random one introduced by just blindly matching the replicated dimension. PiperOrigin-RevId: 582838552 --- .../xla/xla/service/spmd/dot_handler.cc | 52 ++++++++++++++----- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/service/spmd/dot_handler.cc b/third_party/xla/xla/service/spmd/dot_handler.cc index 2f487f2b90e805..8f0ed21153b4ef 100644 --- a/third_party/xla/xla/service/spmd/dot_handler.cc +++ b/third_party/xla/xla/service/spmd/dot_handler.cc @@ -2330,11 +2330,39 @@ GetNonContractingPartitionGroupedShardingForOtherOperand( GroupedSharding output_grouped = hlo_sharding_util::GroupShardingOnDims(output_sharding, output_dims); std::vector other_group_dims; + // Try to match on the replicated dimensions first. if (other_sharding.ReplicateOnLastTileDim() && other_sharding.tile_assignment().dimensions().back() % group_count == 0) { + // Try to aggressively match the replicated dimension with the current + // output device groups. If fails then try find a dimension to swap instead + // of reordering the mesh with collective permutes that can create weird + // patterns. If that fails also do the traditional replication matching. + for (int64_t i = other_sharding.tile_assignment().num_dimensions() - 1; + i >= 0; --i) { + if (other_sharding.tile_assignment().dimensions()[i] % group_count == 0) { + std::vector perm( + other_sharding.tile_assignment().num_dimensions(), 0); + absl::c_iota(perm, 0); + std::swap(perm[i], + perm[other_sharding.tile_assignment().num_dimensions() - 1]); + auto sharding_to_match = + i == other_sharding.tile_assignment().num_dimensions() - 1 + ? 
other_sharding + : hlo_sharding_util::TransposeSharding(other_sharding, perm); + if (auto grouped_sharding = hlo_sharding_util:: + PartialReplicatedGroupShardingWithAssignedDeviceGroups( + sharding_to_match, + sharding_to_match.tile_assignment().dimensions().back() / + group_count, + output_grouped.device_groups)) { + return grouped_sharding.value(); + } + } + } other_group_dims.push_back( other_sharding.tile_assignment().num_dimensions() - 1); - } else { + } + if (other_group_dims.empty()) { const bool may_replicate_other_contracting_dims = (other_contracting_partitions == group_count && other_non_contracting_partitions == @@ -2342,9 +2370,18 @@ GetNonContractingPartitionGroupedShardingForOtherOperand( const bool may_replicate_other_non_contracting_dims = group_count == other_non_contracting_partitions && matching_contracting_partitions == other_contracting_partitions; + if (auto found_dims = FindMatchingPartitionedDimsForGrouping( other_sharding, output_grouped.device_groups)) { other_group_dims = std::move(*found_dims); + } else if (other_sharding.ReplicateOnLastTileDim() && + // Match grouping non-matching replicated dimension at a lower + // priority than finding matched dimensions as it usually pro + other_sharding.tile_assignment().dimensions().back() % + group_count == + 0) { + other_group_dims.push_back( + other_sharding.tile_assignment().num_dimensions() - 1); } else if (may_replicate_other_contracting_dims && (!may_replicate_other_non_contracting_dims || ShapeUtil::ByteSizeOf(other_shape) <= @@ -2364,18 +2401,6 @@ GetNonContractingPartitionGroupedShardingForOtherOperand( if (other_group_dims.size() == 1 && other_group_dims[0] == other_sharding.tile_assignment().num_dimensions() - 1) { - // Try to reuse the device groups from the output to match the partially - // replicated dim. - if (auto grouped_sharding = hlo_sharding_util:: - PartialReplicatedGroupShardingWithAssignedDeviceGroups( - other_sharding, - other_sharding.tile_assignment().dimensions().back() / - group_count, - output_grouped.device_groups)) { - std::vector group_dim_shards = { - other_sharding.tile_assignment().dimensions().back() / group_count}; - return grouped_sharding.value(); - } std::vector group_dim_shards = { other_sharding.tile_assignment().dimensions().back() / group_count}; return AlignGroupsWith( @@ -2445,7 +2470,6 @@ StatusOr PartitionDotGroupOnNonContracting( lhs_matching ? dims_mapping.rhs_non_contracting_dims : dims_mapping.lhs_non_contracting_dims, dims_mapping.contracting_dims); - if (!other_grouped) { other = other.Replicate(); } From f54b6e2359b5b5df9b2c5531d72fedd16023b74e Mon Sep 17 00:00:00 2001 From: Jian Cai Date: Wed, 15 Nov 2023 16:27:02 -0800 Subject: [PATCH 152/391] Improve side effect model of `tf.CheckNumerics` This creates dependencies for `tf.CheckNumerics` ops on the same device. This is needed to avoid the ops being treated as unknown side-effecting and causing deadlock in multi-GPU jobs. 
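The essence of the change is the pair of interface hooks shown in the diff below; this sketch only restates them with the intended ordering semantics spelled out in comments (the template arguments are reconstructed here, since the patch text lost them).

    // Declaring a write on a dedicated CheckNumerics resource keeps the op out
    // of the "unknown side effect" bucket, which would otherwise order it
    // against every other side-effecting op in the function.
    void CheckNumericsOp::getEffects(
        SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>&
            effects) {
      effects.emplace_back(MemoryEffects::Write::get(),
                           ResourceEffects::CheckNumerics::get());
      MarkResourceAsReadOnly(getTensor(), effects);
    }

    // Keying the resource instance on the `device` attribute serializes two
    // CheckNumerics ops placed on the same device while leaving ops on
    // different devices independent of each other.
    std::optional<std::string> CheckNumericsOp::GetResourceInstanceStr() {
      return GetDeviceAttrAsResourceInstanceStr(*this);
    }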
PiperOrigin-RevId: 582842646 --- tensorflow/compiler/mlir/tensorflow/BUILD | 20 ++++++ .../mlir/tensorflow/ir/tf_generated_ops.td | 2 +- .../compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | 20 ++++++ .../compiler/mlir/tensorflow/ir/tf_ops_n_z.cc | 62 +++++------------- .../mlir/tensorflow/ir/tf_side_effects.h | 5 ++ .../utils/side_effect_analysis_util.cc | 63 +++++++++++++++++++ .../utils/side_effect_analysis_util.h | 44 +++++++++++++ 7 files changed, 170 insertions(+), 46 deletions(-) create mode 100644 tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index a24ea5e7a8fe63..89c54afc244f99 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -353,6 +353,7 @@ cc_library( ":attribute_utils", ":convert_type", ":dynamic_shape_utils", + ":side_effect_analysis_util", ":tensorflow_all_ops_inc_gen", ":tensorflow_attributes", ":tensorflow_op_interfaces", @@ -407,6 +408,7 @@ cc_library( deps = [ ":attribute_utils", ":serialize_mlir_module_utils", + ":side_effect_analysis_util", ":tensorflow_attributes", ":tensorflow_op_interfaces", ":tensorflow_op_interfaces_inc_gen", @@ -451,6 +453,7 @@ cc_library( "ir/tf_remaining_ops.h.inc", ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], deps = [ + ":side_effect_analysis_util", ":tensorflow_attributes", ":tensorflow_op_interfaces", ":tensorflow_op_interfaces_inc_gen", @@ -492,6 +495,7 @@ cc_library( "ir/tfrt_ops.h", ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], deps = [ + ":side_effect_analysis_util", ":tensorflow_all_ops_inc_gen", ":tensorflow_attributes", ":tensorflow_op_interfaces", @@ -1978,6 +1982,22 @@ tf_cc_test( ], ) +cc_library( + name = "side_effect_analysis_util", + srcs = [ + "utils/side_effect_analysis_util.cc", + ], + hdrs = [ + "utils/side_effect_analysis_util.h", + ], + deps = [ + "tensorflow_side_effects", + "tensorflow_types", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + build_test( name = "tensorflow_build_test", targets = [ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index f69d3d4f9c97f4..c35ff30bff0a9f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -2095,7 +2095,7 @@ def TF_CeilOp : TF_Op<"Ceil", [Pure, TF_Idempotent, TF_SameOperandsAndResultType }]; } -def TF_CheckNumericsOp : TF_Op<"CheckNumerics", [TF_SameOperandsAndResultTypeResolveRef]> { +def TF_CheckNumericsOp : TF_Op<"CheckNumerics", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, TF_SameOperandsAndResultTypeResolveRef]> { let summary = "Checks a tensor for NaN and Inf values."; let description = [{ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index cee0e40f9cfeb5..8b8b069ea6f40d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -67,6 +67,7 @@ limitations under the License. 
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -85,6 +86,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h" #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/status.h" @@ -1090,6 +1092,24 @@ OpFoldResult CastOp::fold(FoldAdaptor) { return {}; } +//===----------------------------------------------------------------------===// +// CheckNumericsOp +//===----------------------------------------------------------------------===// + +void CheckNumericsOp::getEffects( + SmallVectorImpl>& + effects) { + effects.emplace_back(MemoryEffects::Write::get(), + ResourceEffects::CheckNumerics::get()); + MarkResourceAsReadOnly(getTensor(), effects); +} + +// For `CheckNumerics` ops the `device` attribute corresponds to the resource +// instance. +std::optional CheckNumericsOp::GetResourceInstanceStr() { + return GetDeviceAttrAsResourceInstanceStr(*this); +} + //===----------------------------------------------------------------------===// // CollectiveReduceV2Op //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 01cbbb9a46967c..122677ee4ad6da 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -88,6 +88,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h" namespace mlir { namespace TF { @@ -2344,21 +2345,13 @@ void TPUExecuteOp::getEffects( effects.emplace_back(MemoryEffects::Write::get(), ResourceEffects::TPUExecute::get()); + // Conservatively mark resource handles as read and write, as without + // analyzing TPUCompile, there is not sufficient information to determine + // effects on resources. For the MLIR bridge, this op will never be + // populated with resource handles and tf.TPUExecuteAndUpdateVariables is + // used instead. for (Value value : getArgs()) { - if (value.getType() - .cast() - .getElementType() - .isa()) { - // Conservatively mark resource handles as read and write, as without - // analyzing TPUCompile, there is not sufficient information to determine - // effects on resources. For the MLIR bridge, this op will never be - // populated with resource handles and tf.TPUExecuteAndUpdateVariables is - // used instead. 
- effects.emplace_back(MemoryEffects::Read::get(), value, - ResourceEffects::Variable::get()); - effects.emplace_back(MemoryEffects::Write::get(), value, - ResourceEffects::Variable::get()); - } + MarkResourceAsReadAndWrite(value, effects); } } @@ -2373,19 +2366,11 @@ void _XlaRunOp::getEffects( effects.emplace_back(MemoryEffects::Write::get(), ResourceEffects::_XlaRun::get()); + // Conservatively mark resource handles as read and write, as without + // analyzing _XlaCompile, there is not sufficient information to determine + // effects on resources. for (Value value : getArgs()) { - if (value.getType() - .cast() - .getElementType() - .isa()) { - // Conservatively mark resource handles as read and write, as without - // analyzing _XlaCompile, there is not sufficient information to determine - // effects on resources. - effects.emplace_back(MemoryEffects::Read::get(), value, - ResourceEffects::Variable::get()); - effects.emplace_back(MemoryEffects::Write::get(), value, - ResourceEffects::Variable::get()); - } + MarkResourceAsReadAndWrite(value, effects); } } @@ -3059,35 +3044,22 @@ LogicalResult XlaCallModuleOp::verifySymbolUses( void XlaLaunchOp::getEffects( SmallVectorImpl> &effects) { - effects.reserve(getArgs().size() + 1); + effects.reserve(2 * getArgs().size() + 1); effects.emplace_back(MemoryEffects::Write::get(), ResourceEffects::XlaLaunch::get()); + // Conservatively mark resource handles as read and write, as without + // analyzing XlaLaunch, there is not sufficient information to determine + // effects on resources. for (Value value : getArgs()) { - if (value.getType() - .cast() - .getElementType() - .isa()) { - // Conservatively mark resource handles as read and write, as without - // analyzing XlaLaunch, there is not sufficient information to determine - // effects on resources. - effects.emplace_back(MemoryEffects::Read::get(), value, - ResourceEffects::Variable::get()); - effects.emplace_back(MemoryEffects::Write::get(), value, - ResourceEffects::Variable::get()); - } + MarkResourceAsReadAndWrite(value, effects); } } // For `XlaLaunch` ops the `device` attribute corresponds to the resource // instance. std::optional XlaLaunchOp::GetResourceInstanceStr() { - auto device_attr = (*this)->getAttrOfType("device"); - // Treat missing device attribute like unspecified (= empty string) attribute. - // Note that different op instances with the same string (including empty - // string) are seen as dependent (same resource instance). - if (!device_attr) return ""; - return device_attr.str(); + return GetDeviceAttrAsResourceInstanceStr(*this); } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h index 6384d0770a3358..9bcc75fbe1e424 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h @@ -121,6 +121,11 @@ struct _XlaRun : public ::mlir::SideEffects::Resource::Base<_XlaRun> { StringRef getName() final { return "_XlaRun"; } }; +struct CheckNumerics + : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "CheckNumerics"; } +}; + // Returns true iff resource type with given ID is only self-dependent, i.e., // there are no dependencies to other resource types (including unknown resource // type). 
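The same refactor is applied to TPUExecute, _XlaRun and XlaLaunch above: their hand-rolled loops collapse onto the shared helpers defined in the new utils file that follows. Roughly, a conservatively modeled op's effect declaration ends up shaped like the sketch below (SomeConservativeOp is a made-up stand-in, and the template arguments are reconstructed since the patch text lost them).

    void SomeConservativeOp::getEffects(
        SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>&
            effects) {
      // One read plus one write per argument, plus the op's own resource.
      effects.reserve(2 * getArgs().size() + 1);
      effects.emplace_back(MemoryEffects::Write::get(),
                           ResourceEffects::SomeConservativeOp::get());
      for (Value value : getArgs()) {
        // The helper is a no-op unless `value` has a tf.resource element type,
        // so non-resource arguments add no dependency edges.
        MarkResourceAsReadAndWrite(value, effects);
      }
    }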
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc new file mode 100644 index 00000000000000..7a6da9fcbd04d2 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc @@ -0,0 +1,63 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h" + +#include + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { + +std::string GetDeviceAttrAsResourceInstanceStr(mlir::Operation* op) { + auto device_attr = op->getAttrOfType("device"); + // Treat missing device attribute like unspecified (= empty string) attribute. + // Note that different op instances with the same string (including empty + // string) are seen as dependent (same resource instance). + if (!device_attr) return ""; + return device_attr.str(); +} + +void MarkResourceAsReadAndWrite( + Value value, + SmallVectorImpl>& + effects) { + if (value.getType().cast().getElementType().isa()) { + effects.emplace_back(MemoryEffects::Read::get(), value, + ResourceEffects::Variable::get()); + effects.emplace_back(MemoryEffects::Write::get(), value, + ResourceEffects::Variable::get()); + } +} + +void MarkResourceAsReadOnly( + Value value, + SmallVectorImpl>& + effects) { + if (value.getType().cast().getElementType().isa()) { + effects.emplace_back(MemoryEffects::Read::get(), value, + ResourceEffects::Variable::get()); + } +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h new file mode 100644 index 00000000000000..c55ad530f15962 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h @@ -0,0 +1,44 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SIDE_EFFECT_ANALYSIS_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SIDE_EFFECT_ANALYSIS_UTIL_H_ + +#include + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { + +std::string GetDeviceAttrAsResourceInstanceStr(Operation* op); + +void MarkResourceAsReadAndWrite( + Value value, + SmallVectorImpl>& + effect); + +void MarkResourceAsReadOnly( + Value value, + SmallVectorImpl>& + effect); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SIDE_EFFECT_ANALYSIS_UTIL_H_ From 77f18668c626bd42a512beab1ee7c2e1cebd6779 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 15 Nov 2023 17:01:48 -0800 Subject: [PATCH 153/391] [PJRT] Fix PjRtExecutable::GetParameterLayouts for tupled params Also factors out layout-getting logic into ComputationLayout. PiperOrigin-RevId: 582851059 --- third_party/xla/xla/pjrt/pjrt_executable.cc | 19 +----- third_party/xla/xla/python/xla_client_test.py | 48 ++++++++++++--- third_party/xla/xla/service/BUILD | 2 + .../xla/xla/service/computation_layout.cc | 60 +++++++++++++++++++ .../xla/xla/service/computation_layout.h | 10 ++++ 5 files changed, 113 insertions(+), 26 deletions(-) diff --git a/third_party/xla/xla/pjrt/pjrt_executable.cc b/third_party/xla/xla/pjrt/pjrt_executable.cc index 771a2295aad649..141987063ae1de 100644 --- a/third_party/xla/xla/pjrt/pjrt_executable.cc +++ b/third_party/xla/xla/pjrt/pjrt_executable.cc @@ -335,12 +335,7 @@ StatusOr> PjRtExecutable::GetParameterLayouts() const { "from executable."); } ComputationLayout comp_layout = hlo_modules[0]->entry_computation_layout(); - std::vector result; - result.reserve(comp_layout.parameter_count()); - for (const ShapeLayout& layout : comp_layout.parameter_layouts()) { - result.push_back(layout.layout()); - } - return result; + return comp_layout.FlattenedParameterLayouts(); } StatusOr> PjRtExecutable::GetOutputLayouts() const { @@ -357,17 +352,7 @@ StatusOr> PjRtExecutable::GetOutputLayouts() const { "from executable."); } ComputationLayout comp_layout = hlo_modules[0]->entry_computation_layout(); - const Shape& result_shape = comp_layout.result_shape(); - - std::vector result; - if (!result_shape.IsTuple()) { - result.push_back(result_shape.layout()); - } else { - for (const Shape& subshape : result_shape.tuple_shapes()) { - result.push_back(subshape.layout()); - } - } - return result; + return comp_layout.FlattenedResultLayouts(); } StatusOr> diff --git a/third_party/xla/xla/python/xla_client_test.py b/third_party/xla/xla/python/xla_client_test.py index 511b6bf13b2074..2708311aaab9d6 100644 --- a/third_party/xla/xla/python/xla_client_test.py +++ b/third_party/xla/xla/python/xla_client_test.py @@ -513,7 +513,6 @@ class LayoutsTest(ComputationTest): """Tests related to getting and setting on-device memory layouts.""" @unittest.skipIf(pathways, "not implemented") - @unittest.skipIf(pathways_ifrt, "check fails") def testGetArgumentLayouts(self): # Create computation with a few parameters. 
c = self._NewComputation() @@ -542,7 +541,39 @@ def MakeArg(shape, dtype): self.assertEmpty(layouts[2].minor_to_major()) @unittest.skipIf(pathways, "not implemented") - @unittest.skipIf(pathways_ifrt, "not implemented") + def testGetArgumentLayoutsTupled(self): + # Generated with: + # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)), + # np.int32(42), + # np.ones(10)) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"}, + %arg1: tensor {mhlo.sharding = "{replicated}"}, + %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}"}) + -> (tensor<1024x8x128xf32> {jax.result_info = "[0]"}, + tensor {jax.result_info = "[1]"}, + tensor<10xf32> {jax.result_info = "[2]"}) { + return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor, tensor<10xf32> + } +} +""" + options = xla_client.CompileOptions() + # 'parameter_is_tupled_arguments' causes MLIR untupled arguments to get + # turned into HLO tupled arguments. + options.parameter_is_tupled_arguments = True + executable = self.backend.compile(module_str, compile_options=options) + + # Test that compiled executable returns plausible layouts. + layouts: Sequence[xla_client.Layout] = executable.get_parameter_layouts() + self.assertLen(layouts, 3) + self.assertLen(layouts[0].minor_to_major(), 3) + self.assertEmpty(layouts[1].minor_to_major()) + self.assertLen(layouts[2].minor_to_major(), 1) + + @unittest.skipIf(pathways, "not implemented") def testGetOutputLayouts(self): # Generated with jax.jit(lambda: (np.ones((1024, 128)), np.int32(42), # np.ones(10)))() @@ -552,13 +583,12 @@ def testGetOutputLayouts(self): func.func public @main() -> (tensor<1024x128xf32> {jax.result_info = "[0]"}, tensor {jax.result_info = "[1]"}, tensor<10xf32> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1.000000e+00> : tensor<1024x128xf32> loc(#loc) - %1 = stablehlo.constant dense<1.000000e+00> : tensor<10xf32> loc(#loc) - %2 = stablehlo.constant dense<42> : tensor loc(#loc) - return %0, %2, %1 : tensor<1024x128xf32>, tensor, tensor<10xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) + %0 = stablehlo.constant dense<1.000000e+00> : tensor<1024x128xf32> + %1 = stablehlo.constant dense<1.000000e+00> : tensor<10xf32> + %2 = stablehlo.constant dense<42> : tensor + return %0, %2, %1 : tensor<1024x128xf32>, tensor, tensor<10xf32> + } +} """ executable = self.backend.compile(module_str) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 9b2c0ff3dc079c..788e7c110ef4e4 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -5477,6 +5477,8 @@ cc_library( deps = [ "//xla:printer", "//xla:shape_layout", + "//xla:shape_util", + "//xla:statusor", "//xla:types", "//xla:xla_data_proto_cc", "@com_google_absl//absl/algorithm:container", diff --git a/third_party/xla/xla/service/computation_layout.cc b/third_party/xla/xla/service/computation_layout.cc index 717b7ac550442e..5f83c31343bf5a 100644 --- a/third_party/xla/xla/service/computation_layout.cc +++ b/third_party/xla/xla/service/computation_layout.cc @@ -17,11 +17,15 @@ limitations under the License. 
#include #include +#include #include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "xla/layout.h" #include "xla/printer.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" #include "xla/types.h" namespace xla { @@ -55,6 +59,62 @@ bool ComputationLayout::AnyLayoutSet() const { result_layout_.LayoutIsSet(); } +StatusOr> ComputationLayout::FlattenedParameterLayouts() + const { + std::vector result; + for (int i = 0; i < parameter_count(); ++i) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + parameter_shape(i), + [this, &result](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsTuple()) { + return OkStatus(); + } + if (!subshape.IsArray()) { + return Unimplemented( + "ComputationLayout::FlattenedParameterLayouts doesn't support " + "token or opaque parameters (got: %s)", + ToString()); + } + if (!subshape.has_layout()) { + return InvalidArgument( + "ComputationLayout::FlattenedParameterLayouts can only be " + "called after all parameters have layouts assigned (got: %s)", + ToString()); + } + result.push_back(subshape.layout()); + return OkStatus(); + })); + } + return result; +} + +StatusOr> ComputationLayout::FlattenedResultLayouts() + const { + std::vector result; + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + result_shape(), + [this, &result](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsTuple()) { + return OkStatus(); + } + if (!subshape.IsArray()) { + return Unimplemented( + "ComputationLayout::FlattenedResultLayouts doesn't support " + "token or opaque outputs (got: %s)", + ToString()); + } + if (!subshape.has_layout()) { + return InvalidArgument( + "ComputationLayout::FlattenedResultLayouts can only be called " + "after all outputs have layouts assigned (got: %s)", + ToString()); + } + result.push_back(subshape.layout()); + return OkStatus(); + })); + return result; +} + void ComputationLayout::Print(Printer* printer) const { printer->Append("("); if (!parameter_layouts_.empty()) { diff --git a/third_party/xla/xla/service/computation_layout.h b/third_party/xla/xla/service/computation_layout.h index c44d34d5e9d04b..659ce362201829 100644 --- a/third_party/xla/xla/service/computation_layout.h +++ b/third_party/xla/xla/service/computation_layout.h @@ -83,6 +83,16 @@ class ComputationLayout { // Returns true if any layouts (parameters and result) have been set. bool AnyLayoutSet() const; + // Returns a list of each parameter's layout. If the parameters are tupled, + // returns an untupled list. Must only be called if all parameters have + // layouts set (check with LayoutIsSet()). + StatusOr> FlattenedParameterLayouts() const; + + // Returns a list of each output's layout. If the result shape is a tuple, + // returns an untupled list. Must only be called if all outputs have layouts + // set (check with LayoutIsSet()). + StatusOr> FlattenedResultLayouts() const; + // Prints a string representation of this object. void Print(Printer* printer) const; From 7d1f07cc30dc9febbabe8d163e3a75edcba806a5 Mon Sep 17 00:00:00 2001 From: Jackson Stokes Date: Wed, 15 Nov 2023 17:22:07 -0800 Subject: [PATCH 154/391] [XLA:GPU] Modify the the triton softmax emitter to accept multiple parameters where tiling is across a single dim. This prepares for fusion of auxiliary kernels into triton softmax kernels. The parameters can be included in any order, but each must be tiled on a single dimension. 
We retain the assumption that the entire kernel maintains the same block size as the reduce op, and so the last axis of every reduction parameter must retain the same length. PiperOrigin-RevId: 582855608 --- .../xla/xla/service/gpu/ir_emitter_triton.cc | 113 ++++- .../xla/service/gpu/ir_emitter_triton_test.cc | 437 ++++++++++++++++++ 2 files changed, 525 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index 2ef9d0fdbe1c33..dc19d7d572e18a 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/ir_emitter_triton.h" #include +#include #include #include #include @@ -1579,7 +1580,8 @@ Status EmitSoftMax(mlir::OpBuilder builder, absl::string_view libdevice_path, // * the last axis of every reduction parameter has the same length // * reductions only reduce a single operand // * all the shapes have canonical layout (logical layout = physical layout) - // * the computation has a single input and a single output + // * the computation has a single output + // * we tile along a single dimension // TODO(bchetioui): allow doing several rows per block (e.g. for when rows // are smaller than the minimum transaction size) @@ -1596,49 +1598,110 @@ Status EmitSoftMax(mlir::OpBuilder builder, absl::string_view libdevice_path, CHECK_EQ(reduce->dimensions()[0], reduce_input_shape.rank() - 1); int row_len = reduce_input_shape.dimensions_minor(0); - int block_size = 1; - - // block_size must be a power of two. - while (block_size < row_len) { - block_size *= 2; - } Value pid = b.create( b.getI64Type(), b.create(mt::ProgramIDDim::X)); Value row_stride = CreateConst(b, b.getI32Type(), row_len); + Value row_offset = b.create( + pid, b.create(b.getI64Type(), row_stride)); + Value zero_offset = CreateConst(b, b.getI64Type(), 0); + absl::flat_hash_map values_out; - auto make_tensor_pointer = [&](Value base) { - Value offset = b.create( - pid, b.create(b.getI64Type(), row_stride)); - return b.create( - /*base=*/AddPtr(b, base, offset), - /*shape=*/ValueRange{CreateConst(b, b.getI64Type(), row_len)}, - /*strides=*/ValueRange{CreateConst(b, b.getI64Type(), 1)}, - /*offsets=*/ValueRange{CreateConst(b, b.getI32Type(), 0)}, + std::vector boundary_checks; + + // block_size must be a power of two. + int result_block_size = pow(2, ceil(log(row_len) / log(2))); + + if (result_block_size != row_len) { + boundary_checks.push_back(0); + } + + // Emits load instructions + for (int param_idx = 0; param_idx < computation->num_parameters(); + ++param_idx) { + HloInstruction* param = computation->parameter_instruction(param_idx); + // Current tiling derivation assigns index 0 to the reduction dimension and + // index 1 to the batch dimension. + auto reduce_iterspec = analysis.IterSpec( + TritonFusionAnalysis::Scope::OUTPUT, param, /*dimension=*/0); + auto batch_iterspec = analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, + param, /*dimension=*/1); + + // Make sure only batch and reduce dims are present in tiling + CHECK_EQ(analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, param, + /*dimension=*/2), + nullptr); + + if (!reduce_iterspec) { + // This parameter's broadcast is along the reduce dimension, and so + // each pid uses and broadcasts its own index. 
+ + // If batchDimIterSpec is also not present, then this parameter is a + // scalar, in which case we reuse this for each pid with offset. + Value batch_offset = batch_iterspec ? pid : zero_offset; + + values_out[param] = EmitParameterLoad( + b, AddPtr(b, fn.getArgument(param_idx), batch_offset), + boundary_checks); + continue; + } + + CHECK_NE(reduce_iterspec, nullptr); + CHECK_EQ(reduce_iterspec->size(), 1); + + // TODO(b/310721908): The below assumes that we tile along a single dim. + int reduce_dim_len = reduce_iterspec->front().count; + int reduce_dim_stride = reduce_iterspec->front().stride; + int slice_offset = reduce_iterspec->front().slice_start; + + // If the batch dimension is present in this parameter's tile, we must make + // sure each batch idx is offset by the correct number of rows. If it is not + // present, then the reduce dim data is reused without any offset. + Value base_offset = batch_iterspec ? row_offset : zero_offset; + + // We assume that the reduced axis of this parameter has length row_len. + CHECK_EQ(reduce_dim_len, row_len); + + // block_size must be a power of two. + int block_size = pow(2, ceil(log(reduce_dim_len) / log(2))); + + // Verify that this param contains a single contiguous fragment. + CHECK_EQ(reduce_iterspec->front().subfragments.size(), 1); + + Value emitted_tensor = b.create( + /*base=*/AddPtr(b, fn.getArgument(param_idx), base_offset), + /*shape=*/ValueRange{CreateConst(b, b.getI64Type(), reduce_dim_len)}, + /*strides=*/ + ValueRange{CreateConst(b, b.getI64Type(), reduce_dim_stride)}, + /*offsets=*/ValueRange{CreateConst(b, b.getI32Type(), slice_offset)}, /*tensorShape=*/std::vector{block_size}, /*order=*/std::vector{0}); - }; - std::vector boundary_checks; - if (block_size != row_len) { - boundary_checks.push_back(0); + values_out[param] = EmitParameterLoad(b, emitted_tensor, boundary_checks); } - values_out[computation->parameter_instruction(0)] = EmitParameterLoad( - b, make_tensor_pointer(fn.getArgument(0)), boundary_checks); + // Dimension 0 is the reduced one by construction and it's the only one // present in the tile shapes. 
std::vector tiled_dims = {DimProperties( - /*index=*/0, pid, block_size, /*split_value=*/1)}; + /*index=*/0, pid, result_block_size, /*split_value=*/1)}; TF_ASSIGN_OR_RETURN( Value result, EmitScope(b, libdevice_path, &analysis, TritonFusionAnalysis::Scope::OUTPUT, tiled_dims, computation->MakeInstructionPostOrder(), values_out)); - b.create(make_tensor_pointer(fn.getArgument(1)), result, - std::vector{0}, mt::CacheModifier::NONE, - mt::EvictionPolicy::NORMAL); + Value store_tensor = b.create( + /*base=*/AddPtr(b, fn.getArgument(computation->num_parameters()), + row_offset), + /*shape=*/ValueRange{CreateConst(b, b.getI64Type(), row_len)}, + /*strides=*/ValueRange{CreateConst(b, b.getI64Type(), 1)}, + /*offsets=*/ValueRange{CreateConst(b, b.getI32Type(), 0)}, + /*tensorShape=*/std::vector{result_block_size}, + /*order=*/std::vector{0}); + + b.create(store_tensor, result, std::vector{0}, + mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); return OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 148977393bc81f..682c430c82d414 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -206,6 +206,443 @@ CHECK: } tsl::testing::IsOkAndHolds(true)); } +TEST_F(TritonFilecheckTest, TestSoftmaxEmitterWithSingleParameter) { + const std::string kHloText = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) +} + +ENTRY main { + param_0 = f32[125,127]{1,0} parameter(0) + ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} +})"; + TritonGemmConfig config(16, 64, 32, 1, 1, 1); + ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK: %[[PID:.*]] = tt.get_program_id x : i32 +CHECK: arith.extsi %[[PID]] : i32 to i64 +CHECK: tt.addptr %[[P0]] +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.load +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> +CHECK: tt.reduce +CHECK-NEXT: ^bb0(%[[ARG2:[^:]*]]: f32, %[[ARG3:[^:]*]]: f32): +CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[ARG2]], %[[ARG3]] : f32 +CHECK-NEXT: tt.reduce.return %[[ADD]] : f32 +CHECK-NEXT: }) : (tensor<128xf32>) -> f32 +CHECK: tt.splat +CHECK: arith.mulf +CHECK-SAME: tensor<128xf32> +CHECK: tt.addptr %[[P1]] +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.store +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> +CHECK: tt.return +CHECK: } +)"), + tsl::testing::IsOkAndHolds(true)); +} + +TEST_F(TritonFilecheckTest, TestSoftmaxEmitterWithSingleScalarParameter) { + const std::string kHloText = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + 
Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + parameter_0 = f32[] parameter(0) + broadcast_1 = f32[125,127]{1,0} broadcast(parameter_0), dimensions={} + multiply_0 = f32[125,127]{1,0} multiply(broadcast_1, broadcast_1) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) +} + +ENTRY main { + param_0 = f32[] constant(42) + ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} +})"; + TritonGemmConfig config(16, 64, 32, 1, 1, 1); + ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 +CHECK-DAG: arith.extsi %[[PID]] : i32 to i64 +CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i64 +CHECK-DAG: %[[ARG_0:.*]] = tt.addptr %[[P0]], %[[ZERO_OFFSET]] : !tt.ptr, i64 +CHECK: tt.load %[[ARG_0]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 +CHECK-NEXT: tt.splat +CHECK: tt.reduce +CHECK-NEXT: ^bb0(%[[ARG2:[^:]*]]: f32, %[[ARG3:[^:]*]]: f32): +CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[ARG2]], %[[ARG3]] : f32 +CHECK-NEXT: tt.reduce.return %[[ADD]] : f32 +CHECK-NEXT: }) : (tensor<128xf32>) -> f32 +CHECK: tt.splat +CHECK: arith.mulf +CHECK-SAME: tensor<128xf32> +CHECK: tt.addptr %[[P1]] +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.store +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> +CHECK: tt.return +CHECK: } +)"), + tsl::testing::IsOkAndHolds(true)); +} + +TEST_F(TritonFilecheckTest, TestSoftmaxEmitterWithMultipleParameters) { + const std::string kHloText = R"( +HloModule t + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + param_0 = f32[125,127]{1,0} parameter(0) + param_1 = f32[127]{0} parameter(1) + broadcast_0 = f32[125,127]{1,0} broadcast(param_1), dimensions={1} + multiply_0 = f32[125,127]{1,0} multiply(param_0, broadcast_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) +} + +ENTRY main { + param_0 = f32[125,127]{1,0} parameter(0) + param_1 = f32[127]{0} parameter(1) + ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} +} +)"; + TritonGemmConfig config(16, 64, 32, 1, 1, 1); + ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 +CHECK-DAG: %[[PID_i64:.*]] = arith.extsi %[[PID]] : i32 to i64 +CHECK-DAG: %[[C127_i64:.*]] = arith.constant 127 : i64 +CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i64 +CHECK: %[[ROW_OFFSET:.*]] = 
arith.muli %[[PID_i64]], %[[C127_i64]] : i64 +CHECK: %[[ARG0:.*]] = tt.addptr %[[P0]], %[[ROW_OFFSET]] : !tt.ptr, i64 +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.load +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> +CHECK: %[[ARG1:.*]] = tt.addptr %[[P1]], %[[ZERO_OFFSET]] : !tt.ptr, i64 +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.load +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> +CHECK: tt.reduce +CHECK-NEXT: ^bb0(%[[ARG3:[^:]*]]: f32, %[[ARG4:[^:]*]]: f32): +CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[ARG3]], %[[ARG4]] : f32 +CHECK-NEXT: tt.reduce.return %[[ADD]] : f32 +CHECK-NEXT: }) : (tensor<128xf32>) -> f32 +CHECK: tt.addptr %[[P2]] +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.store +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> +CHECK: tt.return +CHECK: } +)"), + tsl::testing::IsOkAndHolds(true)); +} + +TEST_F(TritonFilecheckTest, + TestSoftmaxEmitterWithMultipleParametersOrderSwapped) { + // This mirrors the multiple parameter test above, but with the parameter to + // be batch-broadcasted in the parameter_0 place instead of parameter_1. + const std::string kHloText = R"( +HloModule t + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + param_0 = f32[125,127]{1,0} parameter(1) + param_1 = f32[127]{0} parameter(0) + broadcast_0 = f32[125,127]{1,0} broadcast(param_1), dimensions={1} + multiply_0 = f32[125,127]{1,0} multiply(param_0, broadcast_0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4) +} + +ENTRY main { + param_0 = f32[125,127]{1,0} parameter(1) + param_1 = f32[127]{0} parameter(0) + ROOT triton_softmax = f32[125,127]{1,0} fusion(param_1, param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} +} +)"; + TritonGemmConfig config(16, 64, 32, 1, 1, 1); + ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 +CHECK-DAG: %[[PID_i64:.*]] = arith.extsi %[[PID]] : i32 to i64 +CHECK-DAG: %[[C127_i64:.*]] = arith.constant 127 : i64 +CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i64 +CHECK: %[[ROW_OFFSET:.*]] = arith.muli %[[PID_i64]], %[[C127_i64]] : i64 +CHECK: %[[ARG0:.*]] = tt.addptr %[[P0]], %[[ZERO_OFFSET]] : !tt.ptr, i64 +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.load +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> +CHECK: %[[ARG1:.*]] = tt.addptr %[[P1]], %[[ROW_OFFSET]] : !tt.ptr, i64 +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.load +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> +CHECK: tt.reduce +CHECK-NEXT: 
^bb0(%[[ARG3:[^:]*]]: f32, %[[ARG4:[^:]*]]: f32): +CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[ARG3]], %[[ARG4]] : f32 +CHECK-NEXT: tt.reduce.return %[[ADD]] : f32 +CHECK-NEXT: }) : (tensor<128xf32>) -> f32 +CHECK: tt.splat +CHECK: tt.addptr %[[P2]] +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.store +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> +CHECK: tt.return +CHECK: } +)"), + tsl::testing::IsOkAndHolds(true)); +} + +TEST_F(TritonFilecheckTest, + TestSoftmaxEmitterWithAdditionalParameterEnteringAfterDiamond) { + const std::string kHloText = R"( +HloModule t + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + param_0 = f32[125,127]{1,0} parameter(0) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(param_0, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + param_1 = f32[127]{0} parameter(1) + broadcast_0 = f32[125,127]{1,0} broadcast(param_1), dimensions={1} + ROOT multiply_0 = f32[125,127]{1,0} multiply(broadcast_4, broadcast_0) +} + +ENTRY main { + param_0 = f32[125,127]{1,0} parameter(0) + param_1 = f32[127]{0} parameter(1) + ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} +} +)"; + TritonGemmConfig config(16, 64, 32, 1, 1, 1); + ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 +CHECK-DAG: %[[PID_i64:.*]] = arith.extsi %[[PID]] : i32 to i64 +CHECK-DAG: %[[C127_i64:.*]] = arith.constant 127 : i64 +CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i64 +CHECK: %[[ROW_OFFSET:.*]] = arith.muli %[[PID_i64]], %[[C127_i64]] : i64 +CHECK: %[[ARG0:.*]] = tt.addptr %[[P0]], %[[ROW_OFFSET]] : !tt.ptr, i64 +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.load +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> +CHECK: %[[ARG1:.*]] = tt.addptr %[[P1]], %[[ZERO_OFFSET]] : !tt.ptr, i64 +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.load +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> +CHECK: tt.reduce +CHECK-NEXT: ^bb0(%[[ARG3:[^:]*]]: f32, %[[ARG4:[^:]*]]: f32): +CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[ARG3]], %[[ARG4]] : f32 +CHECK-NEXT: tt.reduce.return %[[ADD]] : f32 +CHECK-NEXT: }) : (tensor<128xf32>) -> f32 +CHECK: tt.addptr %[[P2]] +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.store +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> +CHECK: tt.return +CHECK: } +)"), + tsl::testing::IsOkAndHolds(true)); +} + +TEST_F(TritonFilecheckTest, + TestSoftmaxEmitterWithMultipleParametersAlongTiledDimension) { + const std::string kHloText = R"( +HloModule t + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + param_0 = f32[125,127]{1,0} parameter(0) + param_1 = f32[127]{0} parameter(1) + 
param_2 = f32[125]{0} parameter(2) + broadcast_0 = f32[125,127]{1,0} broadcast(param_1), dimensions={1} + multiply_0 = f32[125,127]{1,0} multiply(param_0, broadcast_0) + broadcast_1 = f32[125,127]{1,0} broadcast(param_2), dimensions={0} + multiply_1 = f32[125,127]{1,0} multiply(multiply_0, broadcast_1) + constant_0 = f32[] constant(0) + reduce_0 = f32[125]{0} reduce(multiply_1, constant_0), dimensions={1}, to_apply=add + broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_1, broadcast_4) +} + +ENTRY main { + param_0 = f32[125,127]{1,0} parameter(1) + param_1 = f32[127]{0} parameter(0) + param_2 = f32[125]{0} parameter(2) + ROOT triton_softmax = f32[125,127]{1,0} fusion(param_0, param_1, param_2), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} +} +)"; + TritonGemmConfig config(16, 64, 32, 1, 1, 1); + ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 +CHECK-DAG: %[[PID_i64:.*]] = arith.extsi %[[PID]] : i32 to i64 +CHECK-DAG: %[[C127_i64:.*]] = arith.constant 127 : i64 +CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i64 +CHECK: %[[ROW_OFFSET:.*]] = arith.muli %[[PID_i64]], %[[C127_i64]] : i64 +CHECK: %[[ARG0:.*]] = tt.addptr %[[P0]], %[[ROW_OFFSET]] : !tt.ptr, i64 +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.load +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> +CHECK: %[[ARG1:.*]] = tt.addptr %[[P1]], %[[ZERO_OFFSET]] : !tt.ptr, i64 +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.load +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> +CHECK: %[[ARG2:.*]] = tt.addptr %[[P2]], %[[PID_i64]] : !tt.ptr, i64 +CHECK-NEXT: tt.load %[[ARG2]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 +CHECK: tt.reduce +CHECK-NEXT: ^bb0(%[[ARG4:[^:]*]]: f32, %[[ARG5:[^:]*]]: f32): +CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[ARG4]], %[[ARG5]] : f32 +CHECK-NEXT: tt.reduce.return %[[ADD]] : f32 +CHECK-NEXT: }) : (tensor<128xf32>) -> f32 +CHECK: tt.splat +CHECK: arith.mulf +CHECK-SAME: tensor<128xf32> +CHECK: tt.addptr %[[P3]] +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.store +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> +CHECK: tt.return +CHECK: } +)"), + tsl::testing::IsOkAndHolds(true)); +} + +TEST_F(TritonFilecheckTest, TestSoftmaxEmitterWithMultipleTiledDimensions) { + const std::string kHloText = R"( +HloModule t + +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + param_0 = f32[10,125,127]{2,1,0} parameter(0) + param_1 = f32[127]{0} parameter(1) + param_2 = f32[10,125]{1,0} parameter(2) + broadcast_0 = f32[10,125,127]{2,1,0} broadcast(param_1), dimensions={2} + multiply_0 = f32[10,125,127]{2,1,0} multiply(param_0, broadcast_0) + broadcast_1 = f32[10,125,127]{2,1,0} broadcast(param_2), dimensions={0,1} + multiply_1 = f32[10,125,127]{2,1,0} 
multiply(multiply_0, broadcast_1) + constant_0 = f32[] constant(0) + reduce_0 = f32[10,125]{1,0} reduce(multiply_1, constant_0), dimensions={2}, to_apply=add + broadcast_4 = f32[10,125,127]{2,1,0} broadcast(reduce_0), dimensions={0,1} + ROOT multiply = f32[10,125,127]{2,1,0} multiply(multiply_1, broadcast_4) +} + +ENTRY main { + param_0 = f32[10,125,127]{2,1,0} parameter(0) + param_1 = f32[127]{0} parameter(1) + param_2 = f32[10,125]{1,0} parameter(2) + ROOT triton_softmax = f32[10,125,127]{2,1,0} fusion(param_0, param_1, param_2), kind=kCustom, calls=triton_softmax_computation, backend_config={"kind":"__triton_softmax"} +} +)"; + TritonGemmConfig config(16, 64, 32, 1, 1, 1); + ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitSoftMax, + "triton_softmax_computation", R"( +CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { +CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 +CHECK-DAG: %[[PID_i64:.*]] = arith.extsi %[[PID]] : i32 to i64 +CHECK-DAG: %[[C127_i64:.*]] = arith.constant 127 : i64 +CHECK-DAG: %[[ZERO_OFFSET:.*]] = arith.constant 0 : i64 +CHECK: %[[ROW_OFFSET:.*]] = arith.muli %[[PID_i64]], %[[C127_i64]] : i64 +CHECK: %[[ARG0:.*]] = tt.addptr %[[P0]], %[[ROW_OFFSET]] : !tt.ptr, i64 +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.load +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> +CHECK: %[[ARG1:.*]] = tt.addptr %[[P1]], %[[ZERO_OFFSET]] : !tt.ptr, i64 +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.load +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr, 1> -> tensor<128xf32> +CHECK: %[[ARG2:.*]] = tt.addptr %[[P2]], %[[PID_i64]] : !tt.ptr, i64 +CHECK-NEXT: tt.load %[[ARG2]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 +CHECK: tt.reduce +CHECK-NEXT: ^bb0(%[[ARG4:[^:]*]]: f32, %[[ARG5:[^:]*]]: f32): +CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[ARG4]], %[[ARG5]] : f32 +CHECK-NEXT: tt.reduce.return %[[ADD]] : f32 +CHECK-NEXT: }) : (tensor<128xf32>) -> f32 +CHECK: tt.splat +CHECK: arith.mulf +CHECK-SAME: tensor<128xf32> +CHECK: tt.addptr %[[P3]] +CHECK-NEXT: tt.make_tensor_ptr +CHECK-SAME: , 1> +CHECK-NEXT: tt.store +CHECK-SAME: {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128xf32> +CHECK: tt.return +CHECK: } +)"), + tsl::testing::IsOkAndHolds(true)); +} + TEST_F(TritonGemmTest, DoNotUseTensorCoresWithNonDefaultPrecision) { const std::string kHloText = R"( triton_gemm_r { From 831c0849a0c0c9ef0acd4e92b7c2ea07dd6ef0ff Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Wed, 15 Nov 2023 17:37:42 -0800 Subject: [PATCH 155/391] Restructure the Py[Graph|Tensor|Operation] structs to use a separate struct for TensorFlow members. This avoids calling their dtor in tp_dealloc before calling tf_free. tf_free can access memory for GC booking keeping, and that's UB after the dtor is called. 
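The layout trick, reduced to its core (a sketch only; PyFoo and PyFooData are illustrative names, the real macro is TFObject_HEAD in the diff below): the C++ members live in a trailing region placed after the Python object header, so tp_dealloc can run their destructor explicitly and still hand a fully intact Python object to tp_free, whose GC bookkeeping may touch that memory afterwards.

    #include <Python.h>

    #include <new>
    #include <string>

    struct PyFooData {      // members that need a real C++ destructor
      std::string name;
    };

    struct PyFoo {          // plain Python header; never destroyed as a C++ type
      PyObject_HEAD;
      PyObject* dict;
      PyFooData data[0];    // storage starts immediately after the header
    };

    // tp_basicsize = sizeof(PyFoo) + sizeof(PyFooData);
    // tp_new:     new (&self->data[0]) PyFooData();   // placement-construct
    // tp_dealloc: self->data[0].~PyFooData();         // destroy C++ members,
    //             Py_TYPE(self)->tp_free(self);       // then free the object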
PiperOrigin-RevId: 582858777 --- .../python/client/tf_session_wrapper.cc | 277 ++++++++++-------- 1 file changed, 153 insertions(+), 124 deletions(-) diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc index 790629c96d2e4f..160416c4199102 100644 --- a/tensorflow/python/client/tf_session_wrapper.cc +++ b/tensorflow/python/client/tf_session_wrapper.cc @@ -138,9 +138,9 @@ pybind11::object method(pybind11::object type, Func&& function, // generation. The type is assumed to be a GC type (containing other types). // To add the required Python type fields, classes definitions must start with // -// TFObject_Head(classname) +// TFObject_Head(classname, TfObjectDataType) // -// Required attributes/methods: +// Required attributes/methods for TfObjectDataType type: // // Constructor(PyObject* args, PyObject* kw) // ~Destructor @@ -148,8 +148,10 @@ pybind11::object method(pybind11::object type, Func&& function, // Visit(visitproc visit, void* arg) // // Individual methods/attributes are added to the type later, as seen below. -template +template void MakeTfObjectType(PyObject** py_type) { + using TfObjectDataType = typename T::TfObjectDataType; + py::str name = py::str(T::kTypeName); py::str qualname = py::str(T::kTypeName); PyHeapTypeObject* heap_type = reinterpret_cast( @@ -162,11 +164,14 @@ void MakeTfObjectType(PyObject** py_type) { type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE; type->tp_name = T::kTypeName; - type->tp_basicsize = sizeof(T); + + // Allocation size for both Python object header and the TF data members. + type->tp_basicsize = sizeof(T) + sizeof(TfObjectDataType); type->tp_new = [](PyTypeObject* subtype, PyObject* args, PyObject* kwds) -> PyObject* { T* self = reinterpret_cast(subtype->tp_alloc(subtype, 0)); + TfObjectDataType* data = reinterpret_cast(&self[1]); if (!self) return nullptr; // PyType_GenericAlloc (the default implementation of tp_alloc) by default @@ -176,7 +181,7 @@ void MakeTfObjectType(PyObject** py_type) { // // We disable the GC here until initialization is finished. 
PyObject_GC_UnTrack(self); - new (self) T(args, kwds); + new (data) TfObjectDataType(args, kwds); self->dict = PyDict_New(); PyObject_GC_Track(self); @@ -193,9 +198,9 @@ void MakeTfObjectType(PyObject** py_type) { PyObject_ClearWeakRefs(self); T* o = reinterpret_cast(self); + TfObjectDataType* data = reinterpret_cast(&o[1]); Py_CLEAR(o->dict); - o->~T(); - + data->~TfObjectDataType(); tp->tp_free(self); Py_DECREF(tp); }; @@ -203,16 +208,18 @@ void MakeTfObjectType(PyObject** py_type) { type->tp_traverse = [](PyObject* self, visitproc visit, void* arg) { VLOG(3) << "Visit: " << T::kTypeName; T* o = reinterpret_cast(self); + TfObjectDataType* data = reinterpret_cast(&o[1]); Py_VISIT(Py_TYPE(self)); Py_VISIT(o->dict); - return o->Visit(visit, arg); + return data->Visit(visit, arg); }; type->tp_clear = [](PyObject* self) { VLOG(3) << "Clear: " << T::kTypeName; T* o = reinterpret_cast(self); + TfObjectDataType* data = reinterpret_cast(&o[1]); Py_CLEAR(o->dict); - o->Clear(); + data->Clear(); return 0; }; @@ -238,11 +245,13 @@ void MakeTfObjectType(PyObject** py_type) { *py_type = reinterpret_cast(type); } -#define TFObject_HEAD(typename) \ - PyObject_HEAD; \ - PyObject* dict = nullptr; \ - PyObject* weakrefs = nullptr; \ - static PyObject* py_type; \ +#define TFObject_HEAD(typename, datatypename) \ + using TfObjectDataType = datatypename; \ + PyObject_HEAD; \ + PyObject* dict = nullptr; \ + PyObject* weakrefs = nullptr; \ + TfObjectDataType data[0]; \ + static PyObject* py_type; \ static constexpr const char* kTypeName = #typename; struct PyGraph; @@ -272,7 +281,7 @@ PYBIND11_MAKE_OPAQUE(OpsByIdMap); PYBIND11_MAKE_OPAQUE(OpsByNameMap); // Convert the given handle to a TF object type. -template +template T* AsPyTfObject(py::handle handle) { if (handle.get_type() == T::py_type) { return reinterpret_cast(handle.ptr()); @@ -296,11 +305,15 @@ T* AsPyTfObject(py::handle handle) { py::cast(py::str(handle)))); } -template +template py::object AsPyObject(T* obj) { return py::reinterpret_borrow(reinterpret_cast(obj)); } +template +typename T::TfObjectDataType* AsPyTfObjectData(py::handle handle) { + return AsPyTfObject(handle)->data; +} // Reference counting helper for PyTfObjects. // // Similar to the pybind holder types, this manages the Python reference @@ -309,7 +322,7 @@ py::object AsPyObject(T* obj) { // As a special case to support Dismantle(), this allows setting our underlying // pointer to None when clearing the type. Direct access to attributes is not // allowed after this point. 
-template +template class tf_handle { public: tf_handle() : obj_(nullptr) {} @@ -402,9 +415,7 @@ struct TF_OperationDeleter { void operator()(TF_Operation* op) {} }; -struct PyGraph { - TFObject_HEAD(PyGraph); - +struct PyGraphData { TF_Graph* graph; // The C++ graph maintains an ID for every node, however our Python code has @@ -424,7 +435,7 @@ struct PyGraph { OpsByIdMap ops_by_id; OpsByNameMap ops_by_name; - PyGraph(PyObject* args, PyObject* kwds) { + PyGraphData(PyObject* args, PyObject* kwds) { graph = TF_NewGraph(); // By default shape inference functions are required, however this breaks @@ -433,7 +444,7 @@ struct PyGraph { graph->refiner.set_require_shape_inference_fns(false); } - ~PyGraph() { + ~PyGraphData() { Clear(); TF_DeleteGraph(graph); } @@ -462,22 +473,26 @@ struct PyGraph { } return 0; } +}; + +struct PyGraph { + TFObject_HEAD(PyGraph, PyGraphData); int64_t add_op(py::object obj); - py::list operations() { return op_list; } - int64_t num_operations() const { return op_list.size(); } + py::list operations() { return data->op_list; } + int64_t num_operations() const { return data->op_list.size(); } // Return operations that are part of the Graph, but do not yet have // OperationHandle's. This logic is only invoked when importing an existing // GraphDef into Python. It should be removed once all logic moves to C++. std::vector new_operations() { - tsl::mutex_lock l(graph->mu); + tsl::mutex_lock l(tf_graph()->mu); std::vector ops; // SUBTLE: `op_nodes` skips the SOURCE and SINK nodes - for (auto n : graph->graph.op_nodes()) { - if (ops_by_name.find(n->name()) == ops_by_name.end()) { + for (auto n : tf_graph()->graph.op_nodes()) { + if (data->ops_by_name.find(n->name()) == data->ops_by_name.end()) { ops.push_back(reinterpret_cast(n)); } } @@ -485,15 +500,15 @@ struct PyGraph { } py::object get_operation_by_name(const std::string& name) { - tsl::mutex_lock l(graph->mu); - auto it = ops_by_name.find(name); - if (it == ops_by_name.end()) { + tsl::mutex_lock l(tf_graph()->mu); + auto it = data->ops_by_name.find(name); + if (it == data->ops_by_name.end()) { throw py::key_error(); } return it->second; } - int version() const { return ops_by_id.size(); } + int version() const { return data->ops_by_id.size(); } py::bytes version_def() const { // Potential deadlock: @@ -509,8 +524,8 @@ struct PyGraph { std::string versions; { py::gil_scoped_release release; - tsl::mutex_lock l(graph->mu); - versions = graph->graph.versions().SerializeAsString(); + tsl::mutex_lock l(tf_graph()->mu); + versions = tf_graph()->graph.versions().SerializeAsString(); } pybind11::gil_scoped_acquire acquire; return py::bytes(versions); @@ -518,52 +533,52 @@ struct PyGraph { tsl::StatusOr _op_def_for_type( const std::string& kTypeName) const { - tsl::mutex_lock l(graph->mu); + tsl::mutex_lock l(tf_graph()->mu); const tensorflow::OpDef* op_def; TF_RETURN_IF_ERROR( - graph->graph.op_registry()->LookUpOpDef(kTypeName, &op_def)); + tf_graph()->graph.op_registry()->LookUpOpDef(kTypeName, &op_def)); return py::bytes(op_def->SerializeAsString()); } void add_control_input(tensorflow::Node* src, tensorflow::Node* dst) { - tsl::mutex_lock l(graph->mu); + tsl::mutex_lock l(tf_graph()->mu); - graph->graph.AddControlEdge(src, dst); + tf_graph()->graph.AddControlEdge(src, dst); record_mutation(*dst, "adding control edge"); } void remove_all_control_inputs(const tensorflow::Node& node) { - tsl::mutex_lock l(graph->mu); + tsl::mutex_lock l(tf_graph()->mu); std::vector control_edges; for (const tensorflow::Edge* edge : 
node.in_edges()) { if (!edge->IsControlEdge()) continue; control_edges.push_back(edge); } for (const tensorflow::Edge* edge : control_edges) { - graph->graph.RemoveControlEdge(edge); + tf_graph()->graph.RemoveControlEdge(edge); } } void record_mutation(const tensorflow::Node& node, const std::string& reason) - TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { - tensorflow::RecordMutation( - graph, reinterpret_cast(node), reason.c_str()); + TF_EXCLUSIVE_LOCKS_REQUIRED(tf_graph()->mu) { + tensorflow::RecordMutation(tf_graph(), + reinterpret_cast(node), + reason.c_str()); } - TF_Graph* tf_graph() { return graph; } + TF_Graph* tf_graph() const { return data->graph; } }; -struct PyOperation { - TFObject_HEAD(PyOperation); - +struct PyOperationData { TF_Operation* tf_op = nullptr; + py::list outputs; // N.B. initialized later by Python. tf_handle graph; py::function tensor_fn; - PyOperation(PyObject* args, PyObject* kwds) { + PyOperationData(PyObject* args, PyObject* kwds) { PyObject *py_op, *py_tensor_fn; if (!PyArg_ParseTuple(args, "OO", &py_op, &py_tensor_fn)) { return; @@ -572,90 +587,92 @@ struct PyOperation { tensor_fn = py::cast(py_tensor_fn); } - ~PyOperation() { Clear(); } + ~PyOperationData() { Clear(); } + + void Dismantle(PyOperation* py_op); void Clear() { Py_CLEAR(outputs.release().ptr()); graph.Clear(); } - void Dismantle(); - int Visit(visitproc visit, void* arg) { Py_VISIT(graph.ptr()); Py_VISIT(outputs.ptr()); return 0; } +}; + +struct PyOperation { + TFObject_HEAD(PyOperation, PyOperationData); + + TF_Operation* tf_op() const { return data->tf_op; } void _init_outputs() { - int num_outputs = TF_OperationNumOutputs(tf_op); + int num_outputs = TF_OperationNumOutputs(tf_op()); for (int i = 0; i < num_outputs; ++i) { - auto dtype = TF_OperationOutputType(TF_Output{tf_op, i}); - outputs.append(tensor_fn(AsPyObject(this), i, dtype)); + auto dtype = TF_OperationOutputType(TF_Output{tf_op(), i}); + data->outputs.append(data->tensor_fn(AsPyObject(this), i, dtype)); } } tsl::Status _add_outputs(py::list dtypes, py::list shapes); - const TF_Operation* op() { return tf_op; } - - TF_Output _tf_output(int idx) const { return TF_Output{tf_op, idx}; } - TF_Input _tf_input(int idx) const { return TF_Input{tf_op, idx}; } + TF_Output _tf_output(int idx) const { return TF_Output{tf_op(), idx}; } + TF_Input _tf_input(int idx) const { return TF_Input{tf_op(), idx}; } py::bytes node_def() { - return py::bytes(tf_op->node.def().SerializeAsString()); + return py::bytes(tf_op()->node.def().SerializeAsString()); } py::bytes op_def() const { - return py::bytes(tf_op->node.op_def().SerializeAsString()); + return py::bytes(tf_op()->node.op_def().SerializeAsString()); } - bool is_stateful() const { return tf_op->node.op_def().is_stateful(); } + bool is_stateful() const { return tf_op()->node.op_def().is_stateful(); } - const std::string& type() { return tf_op->node.type_string(); } + const std::string& type() { return tf_op()->node.type_string(); } void add_control_input(PyOperation* input) { - graph->add_control_input(&input->tf_op->node, &tf_op->node); + data->graph->add_control_input(&input->tf_op()->node, &tf_op()->node); } void add_control_inputs(py::iterable inputs); py::list control_inputs() { py::list output; - for (const auto* edge : tf_op->node.in_edges()) { + for (const auto* edge : tf_op()->node.in_edges()) { if (edge->IsControlEdge() && !edge->src()->IsSource()) { - output.append(graph->ops_by_id[edge->src()->id()]); + output.append(data->graph->data->ops_by_id[edge->src()->id()]); } } return output; } 
py::list control_outputs() { py::list output; - for (const auto* edge : tf_op->node.out_edges()) { + for (const auto* edge : tf_op()->node.out_edges()) { if (edge->IsControlEdge() && !edge->dst()->IsSink()) { - output.append(graph->ops_by_id[edge->dst()->id()]); + output.append(data->graph->data->ops_by_id[edge->dst()->id()]); } } return output; } void remove_all_control_inputs() { - graph->remove_all_control_inputs(tf_op->node); + data->graph->remove_all_control_inputs(tf_op()->node); } void set_device(const std::string& device) { - tsl::mutex_lock l(graph->graph->mu); - tf_op->node.set_requested_device(device); - graph->record_mutation(tf_op->node, "setting device"); + tsl::mutex_lock l(data->graph->tf_graph()->mu); + tf_op()->node.set_requested_device(device); + data->graph->record_mutation(tf_op()->node, "setting device"); } - const std::string& device() { return tf_op->node.requested_device(); } - const std::string& name() { return tf_op->node.name(); } + const std::string& device() { return tf_op()->node.requested_device(); } + const std::string& name() { return tf_op()->node.name(); } }; -struct PyTensor { - TFObject_HEAD(PyTensor); - +struct PyTensorData { py::object tf_output = py::none(); py::object name = py::none(); py::object dtype = py::none(); @@ -667,7 +684,7 @@ struct PyTensor { int value_index = -1; - PyTensor(PyObject* args, PyObject* kwds) { + PyTensorData(PyObject* args, PyObject* kwds) { PyObject *py_op, *py_index, *py_dtype, *py_uid; if (!PyArg_ParseTuple(args, "OOOO", &py_op, &py_index, &py_dtype, &py_uid)) { @@ -676,12 +693,13 @@ struct PyTensor { dtype = py::reinterpret_borrow(py_dtype); value_index = py::cast(py::handle(py_index)); op = py_op; - graph = op->graph; + graph = op->data->graph; name = py::str(absl::StrCat(op->name(), ":", value_index)); - tf_output = py::cast(TF_Output{op->tf_op, value_index}); + tf_output = py::cast(TF_Output{op->tf_op(), value_index}); uid = py::reinterpret_borrow(py_uid); } - ~PyTensor() { Clear(); } + + ~PyTensorData() { Clear(); } void Clear() { Py_CLEAR(tf_output.release().ptr()); @@ -703,14 +721,20 @@ struct PyTensor { Py_VISIT(uid.ptr()); return 0; } +}; + +struct PyTensor { + TFObject_HEAD(PyTensor, PyTensorData); + + int value_index() const { return data->value_index; } tsl::StatusOr shape() { tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); bool unknown_shape = false; auto dims = tensorflow::TF_GraphGetTensorShapeHelper( - graph->tf_graph(), TF_Output{op->tf_op, value_index}, status.get(), - &unknown_shape); + data->graph->tf_graph(), TF_Output{data->op->tf_op(), value_index()}, + status.get(), &unknown_shape); if (!status.get()->status.ok()) { return status.get()->status; } @@ -737,17 +761,17 @@ struct PyTensor { } } tensorflow::TF_GraphSetTensorShape_wrapper( - graph->tf_graph(), TF_Output{op->tf_op, value_index}, dims, - unknown_shape, status.get()); + data->graph->tf_graph(), TF_Output{data->op->tf_op(), value_index()}, + dims, unknown_shape, status.get()); return status.get()->status; } int64_t rank() { - tsl::mutex_lock l(graph->graph->mu); + tsl::mutex_lock l(data->graph->tf_graph()->mu); tensorflow::shape_inference::InferenceContext* ic = - graph->graph->refiner.GetContext(&op->tf_op->node); + data->graph->tf_graph()->refiner.GetContext(&data->op->tf_op()->node); - tensorflow::shape_inference::ShapeHandle shape = ic->output(value_index); + tensorflow::shape_inference::ShapeHandle shape = ic->output(value_index()); if (ic->RankKnown(shape)) { return ic->Rank(shape); } @@ -756,11 +780,11 
@@ struct PyTensor { py::list consumers() { py::list out; - for (const auto* edge : op->tf_op->node.out_edges()) { - if (edge->src_output() != value_index) { + for (const auto* edge : data->op->tf_op()->node.out_edges()) { + if (edge->src_output() != value_index()) { continue; } - out.append(graph->ops_by_id[edge->dst()->id()]); + out.append(data->graph->data->ops_by_id[edge->dst()->id()]); } return out; } @@ -770,17 +794,17 @@ PyObject* PyOperation::py_type = nullptr; PyObject* PyTensor::py_type = nullptr; PyObject* PyGraph::py_type = nullptr; -void PyOperation::Dismantle() { +void PyOperationData::Dismantle(PyOperation* py_op) { outputs = py::list(); - PyDict_Clear(dict); graph.Destroy(); + PyDict_Clear(py_op->dict); } tsl::Status PyOperation::_add_outputs(py::list dtypes, py::list shapes) { - int orig_outputs = outputs.size(); + int orig_outputs = data->outputs.size(); for (int i = 0; i < dtypes.size(); ++i) { py::object tensor = - tensor_fn(AsPyObject(this), orig_outputs + i, dtypes[i]); + data->tensor_fn(AsPyObject(this), orig_outputs + i, dtypes[i]); // The passed in `shapes` may be TensorShapes, convert them to lists if // needed. @@ -799,24 +823,25 @@ tsl::Status PyOperation::_add_outputs(py::list dtypes, py::list shapes) { } TF_RETURN_IF_ERROR( AsPyTfObject(tensor)->set_shape(dims, unknown_shape)); - outputs.append(tensor); + data->outputs.append(tensor); } return tsl::OkStatus(); } void PyOperation::add_control_inputs(py::iterable inputs) { - tsl::mutex_lock l(graph->tf_graph()->mu); + tsl::mutex_lock l(data->graph->tf_graph()->mu); for (py::handle input : inputs) { auto* input_handle = py::cast(input); - graph->tf_graph()->graph.AddControlEdge(&input_handle->tf_op->node, - &tf_op->node); + data->graph->tf_graph()->graph.AddControlEdge(&input_handle->tf_op()->node, + &tf_op()->node); } - graph->record_mutation(tf_op->node, "adding control input"); + data->graph->record_mutation(tf_op()->node, "adding control input"); } -void PyGraph::Dismantle() { +void PyGraphData::Dismantle() { for (auto& op : op_list) { - AsPyTfObject(op.ptr())->Dismantle(); + AsPyTfObjectData(op.ptr())->Dismantle( + AsPyTfObject(op.ptr())); } op_list = py::list(); ops_by_id.clear(); @@ -825,10 +850,10 @@ void PyGraph::Dismantle() { int64_t PyGraph::add_op(py::object obj) { PyOperation* op_handle = AsPyTfObject(obj); - int64_t op_id = op_handle->tf_op->node.id(); - op_list.append(obj); - ops_by_id[op_id] = obj; - ops_by_name[op_handle->name()] = obj; + int64_t op_id = op_handle->tf_op()->node.id(); + data->op_list.append(obj); + data->ops_by_id[op_id] = obj; + data->ops_by_name[op_handle->name()] = obj; return op_id; } @@ -848,7 +873,7 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { m.attr("PyGraph") = c_graph; c_graph.attr("__module__") = module_name; c_graph.attr("Dismantle") = method(c_graph, [](py::handle handle) { - AsPyTfObject(handle)->Dismantle(); + AsPyTfObjectData(handle)->Dismantle(); }); c_graph.attr("_version_def") = property_readonly([](py::handle handle) { return AsPyTfObject(handle)->version_def(); @@ -861,10 +886,10 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { return AsPyTfObject(handle)->_op_def_for_type(type); }); c_graph.attr("_nodes_by_name") = property_readonly([](py::handle handle) { - return AsPyTfObject(handle)->ops_by_name; + return AsPyTfObjectData(handle)->ops_by_name; }); c_graph.attr("_nodes_by_id") = property_readonly([](py::handle handle) { - return AsPyTfObject(handle)->ops_by_id; + return AsPyTfObjectData(handle)->ops_by_id; }); c_graph.attr("_get_operation_by_name") = 
method(c_graph, [](py::handle handle, std::string name) { @@ -919,18 +944,18 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { return AsPyTfObject(handle)->remove_all_control_inputs(); }); c_op.attr("outputs") = property_readonly([](py::handle handle) { - return AsPyTfObject(handle)->outputs; + return AsPyTfObjectData(handle)->outputs; }); c_op.attr("graph") = property( [](py::handle handle) { - return AsPyTfObject(handle)->graph.borrow(); + return AsPyTfObjectData(handle)->graph.borrow(); }, [](py::handle handle, py::handle graph) { auto op = AsPyTfObject(handle); - op->graph = graph.ptr(); + op->data->graph = graph.ptr(); }); c_op.attr("_c_op") = property_readonly([](py::handle handle) { - return AsPyTfObject(handle)->tf_op; + return AsPyTfObject(handle)->tf_op(); }); c_op.attr("_is_stateful") = property_readonly([](py::handle handle) { return AsPyTfObject(handle)->is_stateful(); @@ -983,7 +1008,7 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { m.attr("PyTensor") = c_tensor; c_tensor.attr("__module__") = module_name; c_tensor.attr("device") = property_readonly([](py::handle handle) { - return AsPyTfObject(handle)->op->device(); + return AsPyTfObjectData(handle)->op->device(); }); c_tensor.attr("ndim") = property_readonly([](py::handle handle) { return AsPyTfObject(handle)->rank(); @@ -995,40 +1020,44 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { return AsPyTfObject(handle)->shape(); }); c_tensor.attr("_dtype") = property_readonly([](py::handle handle) { - return AsPyTfObject(handle)->dtype; + return AsPyTfObjectData(handle)->dtype; }); c_tensor.attr("_name") = property( - [](py::handle handle) { return AsPyTfObject(handle)->name; }, + [](py::handle handle) { + return AsPyTfObjectData(handle)->name; + }, [](py::handle handle, py::object name) { - AsPyTfObject(handle)->name = name; + AsPyTfObjectData(handle)->name = name; }); c_tensor.attr("_shape_val") = property( [](py::handle handle) { auto py_tensor = AsPyTfObject(handle); - return py_tensor->shape_val; + return py_tensor->data->shape_val; }, [](py::handle handle, py::object shape) { - AsPyTfObject(handle)->shape_val = shape; + AsPyTfObjectData(handle)->shape_val = shape; }); c_tensor.attr("_id") = property( - [](py::handle handle) { return AsPyTfObject(handle)->uid; }, + [](py::handle handle) { + return AsPyTfObjectData(handle)->uid; + }, [](py::handle handle, py::object uid) { - AsPyTfObject(handle)->uid = uid; + AsPyTfObjectData(handle)->uid = uid; }); c_tensor.attr("graph") = property_readonly([](py::handle handle) -> py::handle { - auto& graph = AsPyTfObject(handle)->graph; + auto& graph = AsPyTfObjectData(handle)->graph; if (graph.ptr() != nullptr) { return graph.borrow(); } return py::none(); }); c_tensor.attr("_as_tf_output") = method(c_tensor, [](py::handle handle) { - return AsPyTfObject(handle)->tf_output; + return AsPyTfObjectData(handle)->tf_output; }); c_tensor.attr("_op") = property_readonly([](py::handle handle) -> py::handle { - auto& op = AsPyTfObject(handle)->op; + auto& op = AsPyTfObjectData(handle)->op; if (op.ptr() != nullptr) { return op.borrow(); } @@ -1036,7 +1065,7 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { }); c_tensor.attr("op") = property_readonly([](py::handle handle) -> py::handle { - auto& op = AsPyTfObject(handle)->op; + auto& op = AsPyTfObjectData(handle)->op; if (op.ptr() != nullptr) { return op.borrow(); } @@ -1048,7 +1077,7 @@ PYBIND11_MODULE(_pywrap_tf_session, m) { return AsPyTfObject(handle)->set_shape(shape, unknown_shape); }); c_tensor.attr("value_index") = property_readonly([](py::handle handle) { - 
return AsPyTfObject(handle)->value_index; + return AsPyTfObject(handle)->value_index(); }); c_tensor.attr("consumers") = method(c_tensor, [](py::handle handle) { return AsPyTfObject(handle)->consumers(); From 0785b6e975c570d365bb8a695275cdd0d110dc9c Mon Sep 17 00:00:00 2001 From: Chuan He Date: Wed, 15 Nov 2023 17:53:46 -0800 Subject: [PATCH 156/391] Support "metadata_buffer" field in the TFLite flatbuffer importer/exporter. PiperOrigin-RevId: 582861948 --- .../compiler/mlir/lite/flatbuffer_export.cc | 20 +------------------ .../compiler/mlir/lite/flatbuffer_import.cc | 5 ----- .../flatbuffer2mlir/metadata_buffer.mlir | 9 --------- .../mlir2flatbuffer/metadata_buffer.mlir | 11 ---------- 4 files changed, 1 insertion(+), 44 deletions(-) delete mode 100644 tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/metadata_buffer.mlir delete mode 100644 tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata_buffer.mlir diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index d9d200b7fa8f04..44e81f58ad9686 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -681,11 +681,6 @@ class Translator { std::optional>> CreateMetadataVector(); - // Encodes the `tfl.metadata_buffer` array attribute of the module to the - // metadata_buffer section in the final model. Returns empty if there isn't - // such attribute in the mlir module. - VectorBufferOffset CreateMetadataBufferVector(); - // Builds and returns list of tfl.SignatureDef sections in the model. std::optional>> CreateSignatureDefs(const std::vector& signature_defs); @@ -2655,18 +2650,6 @@ Translator::CreateMetadataVector() { return builder_.CreateVector(metadata); } -VectorBufferOffset Translator::CreateMetadataBufferVector() { - auto array_attr = - module_->getAttrOfType("tfl.metadata_buffer"); - std::vector metadata_buffer; - if (!array_attr) return 0; - for (auto value : array_attr.getAsValueRange()) { - metadata_buffer.push_back(value.getSExtValue()); - } - - return builder_.CreateVector(metadata_buffer); -} - // Helper method that returns list of all strings in a StringAttr identified // by 'attr_key' and values are separated by a comma. llvm::SmallVector GetStringsFromAttrWithSeparator( @@ -3071,8 +3054,7 @@ std::optional Translator::TranslateInternal() { // Build the model and finish the model building process. 
auto description = builder_.CreateString(model_description.data()); - VectorBufferOffset metadata_buffer = - CreateMetadataBufferVector(); // Deprecated + VectorBufferOffset metadata_buffer = 0; // Deprecated auto metadata = CreateMetadataVector(); if (!metadata) return std::nullopt; diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 61083fbe47ed29..69dd8ad342cfe1 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -1905,11 +1905,6 @@ OwningOpRef tflite::FlatBufferToMlir( mlir::UnitAttr::get(builder.getContext())); } - if (!model->metadata_buffer.empty()) { - module->setAttr("tfl.metadata_buffer", - builder.getI32ArrayAttr(model->metadata_buffer)); - } - if (use_stablehlo_constant) { module->setAttr("tfl.metadata", builder.getDictionaryAttr(builder.getNamedAttr( diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/metadata_buffer.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/metadata_buffer.mlir deleted file mode 100644 index 6b76b31c9a52bf..00000000000000 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/metadata_buffer.mlir +++ /dev/null @@ -1,9 +0,0 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s - -// CHECK: tfl.metadata_buffer = [3 : i32, 7 : i32] -module attributes {tfl.metadata_buffer = [3 : i32, 7 : i32]} { - func.func @main(%arg0: tensor, %arg1: tensor<3x2xi32>) -> tensor<3x2xi32> { - %0 = "tfl.add" (%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor, tensor<3x2xi32>) -> tensor<3x2xi32> - func.return %0 : tensor<3x2xi32> - } -} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata_buffer.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata_buffer.mlir deleted file mode 100644 index f53f3954f14211..00000000000000 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/metadata_buffer.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s - -module attributes {tfl.metadata_buffer = [3 : i32, 7 : i32]} { - func.func @main(%arg0: tensor, %arg1: tensor<3x2xi32>) -> tensor<3x2xi32> { - %0 = "tfl.add" (%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor, tensor<3x2xi32>) -> tensor<3x2xi32> - func.return %0 : tensor<3x2xi32> - } -} - -// CHECK: metadata_buffer: [ 3, 7 ], -// CHECK-NEXT: metadata: \ No newline at end of file From dc40c1ce0457ab953e7c01c678caaa605eb3050b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 20:57:13 -0800 Subject: [PATCH 157/391] Replaces 'StrategyVector' with 'StrategyGroup'. 
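For readers following the large mechanical rename in the diff below, here is a rough outline of the structure involved; the field types are approximated from how they are used in this patch, not copied from auto_sharding_strategy.h.

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "xla/hlo/ir/hlo_sharding.h"

using xla::HloSharding;

// Approximate shapes only.
struct ShardingStrategy {            // one candidate sharding for an instruction
  std::string name;
  HloSharding output_sharding;
  double compute_cost;
  double communication_cost;
  double memory_cost;
  // resharding_costs[i][j]: cost of converting operand i from its j-th
  // candidate sharding into the sharding this strategy expects.
  std::vector<std::vector<double>> resharding_costs;
  std::vector<std::optional<HloSharding>> input_shardings;
};

struct StrategyGroup {               // formerly StrategyVector
  bool is_tuple = false;             // tuple-shaped ops get one child group per element
  int64_t node_idx = -1;             // index into LeafStrategies; -1 for tuple groups
  size_t instruction_id = 0;
  std::vector<const StrategyGroup*> in_nodes;          // groups of the operands
  const StrategyGroup* following = nullptr;            // group whose choice this one mirrors
  std::vector<ShardingStrategy> leaf_vector;           // candidates (leaf groups only)
  std::vector<std::unique_ptr<StrategyGroup>> childs;  // one per tuple element
  std::optional<size_t> tuple_element_idx;
};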
PiperOrigin-RevId: 582897760 --- .../xla/hlo/experimental/auto_sharding/BUILD | 5 +- .../auto_sharding/auto_sharding.cc | 895 +++++++++--------- .../auto_sharding/auto_sharding.h | 14 +- .../auto_sharding/auto_sharding_cost_graph.h | 31 +- .../auto_sharding_dot_handler.cc | 40 +- .../auto_sharding/auto_sharding_strategy.h | 23 +- .../auto_sharding/auto_sharding_util.cc | 120 +-- 7 files changed, 582 insertions(+), 546 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index a792011b5a254b..42a95183639812 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -133,7 +133,10 @@ cc_library( deps = [ ":auto_sharding_strategy", ":matrix", + "//xla:shape_util", + "//xla/hlo/ir:hlo", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -226,12 +229,12 @@ cc_library( "//xla:array", "//xla:shape_tree", "//xla:shape_util", + "//xla:status", "//xla:statusor", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_sharding_util", "//xla/service:call_graph", - "//xla/service:hlo_cost_analysis", "//xla/service:sharding_propagation", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 7f172c0b07b251..07d97f4d92a180 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -90,52 +90,51 @@ namespace spmd { // Compute the resharding cost vector from multiple possible strategies // to a desired sharding spec. std::vector ReshardingCostVector( - const StrategyVector* strategies, const Shape& operand_shape, + const StrategyGroup* strategy_group, const Shape& operand_shape, const HloSharding& required_sharding, const ClusterEnvironment& cluster_env) { - CHECK(!strategies->is_tuple) << "Only works with strategy vector."; + CHECK(!strategy_group->is_tuple) << "Only works with strategy vector."; std::vector ret; - ret.reserve(strategies->leaf_vector.size()); + ret.reserve(strategy_group->leaf_vector.size()); auto required_sharding_for_resharding = required_sharding.IsTileMaximal() ? HloSharding::Replicate() : required_sharding; - for (const auto& x : strategies->leaf_vector) { + for (const auto& x : strategy_group->leaf_vector) { ret.push_back(cluster_env.ReshardingCost(operand_shape, x.output_sharding, required_sharding_for_resharding)); } return ret; } -// Factory functions for StrategyVector. -std::unique_ptr CreateLeafStrategyVectorWithoutInNodes( +// Factory functions for StrategyGroup. +std::unique_ptr CreateLeafStrategyGroupWithoutInNodes( size_t instruction_id, LeafStrategies& leaf_strategies) { - auto strategies = std::make_unique(); - strategies->is_tuple = false; - strategies->node_idx = leaf_strategies.size(); - leaf_strategies.push_back(strategies.get()); - strategies->instruction_id = instruction_id; - return strategies; + auto strategy_group = std::make_unique(); + strategy_group->is_tuple = false; + strategy_group->node_idx = leaf_strategies.size(); + leaf_strategies.push_back(strategy_group.get()); + strategy_group->instruction_id = instruction_id; + return strategy_group; } -// Factory functions for StrategyVector. 
-std::unique_ptr CreateLeafStrategyVector( +// Factory functions for StrategyGroup. +std::unique_ptr CreateLeafStrategyGroup( size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, LeafStrategies& leaf_strategies) { - auto strategies = - CreateLeafStrategyVectorWithoutInNodes(instruction_id, leaf_strategies); + auto strategy_group = + CreateLeafStrategyGroupWithoutInNodes(instruction_id, leaf_strategies); for (int64_t i = 0; i < ins->operand_count(); ++i) { - strategies->in_nodes.push_back(strategy_map.at(ins->operand(i)).get()); + strategy_group->in_nodes.push_back(strategy_map.at(ins->operand(i)).get()); } - return strategies; + return strategy_group; } -std::unique_ptr CreateTupleStrategyVector( - size_t instruction_id) { - auto strategies = std::make_unique(); - strategies->is_tuple = true; - strategies->node_idx = -1; - strategies->instruction_id = instruction_id; - return strategies; +std::unique_ptr CreateTupleStrategyGroup(size_t instruction_id) { + auto strategy_group = std::make_unique(); + strategy_group->is_tuple = true; + strategy_group->node_idx = -1; + strategy_group->instruction_id = instruction_id; + return strategy_group; } // ShardingPropagation::GetShardingFromUser does not handle TopK custom @@ -229,53 +228,54 @@ GenerateReshardingCostsAndShardingsForAllOperands( CHECK(sharding_optional.has_value()); } - return std::make_pair(resharding_costs, input_shardings_optional); + return {resharding_costs, input_shardings_optional}; } -std::unique_ptr MaybeFollowInsStrategyVector( - const StrategyVector* src_strategies, const Shape& shape, +std::unique_ptr MaybeFollowInsStrategyGroup( + const StrategyGroup* src_strategy_group, const Shape& shape, size_t instruction_id, bool have_memory_cost, LeafStrategies& leaf_strategies, const ClusterEnvironment& cluster_env, StableHashMap>& pretrimmed_strategy_map) { - std::unique_ptr strategies; - if (src_strategies->is_tuple) { + std::unique_ptr strategy_group; + if (src_strategy_group->is_tuple) { CHECK(shape.IsTuple()); - CHECK_EQ(shape.tuple_shapes_size(), src_strategies->childs.size()); - strategies = CreateTupleStrategyVector(instruction_id); - strategies->childs.reserve(src_strategies->childs.size()); - for (size_t i = 0; i < src_strategies->childs.size(); ++i) { - auto child_strategies = MaybeFollowInsStrategyVector( - src_strategies->childs[i].get(), shape.tuple_shapes(i), + CHECK_EQ(shape.tuple_shapes_size(), src_strategy_group->childs.size()); + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(src_strategy_group->childs.size()); + for (size_t i = 0; i < src_strategy_group->childs.size(); ++i) { + auto child_strategies = MaybeFollowInsStrategyGroup( + src_strategy_group->childs[i].get(), shape.tuple_shapes(i), instruction_id, have_memory_cost, leaf_strategies, cluster_env, pretrimmed_strategy_map); child_strategies->tuple_element_idx = i; - strategies->childs.push_back(std::move(child_strategies)); + strategy_group->childs.push_back(std::move(child_strategies)); } } else { CHECK(shape.IsArray() || shape.IsToken()); - strategies = - CreateLeafStrategyVectorWithoutInNodes(instruction_id, leaf_strategies); - strategies->in_nodes.push_back(src_strategies); + strategy_group = + CreateLeafStrategyGroupWithoutInNodes(instruction_id, leaf_strategies); + strategy_group->in_nodes.push_back(src_strategy_group); // Only follows the given strategy when there is no other strategy to be // restored. 
- if (!pretrimmed_strategy_map.contains(src_strategies->node_idx)) { - strategies->following = src_strategies; + if (!pretrimmed_strategy_map.contains(src_strategy_group->node_idx)) { + strategy_group->following = src_strategy_group; } - strategies->leaf_vector.reserve(src_strategies->leaf_vector.size()); + strategy_group->leaf_vector.reserve(src_strategy_group->leaf_vector.size()); // Creates the sharding strategies and restores the trimmed strategies if // there is any. for (int64_t sid = 0; - sid < src_strategies->leaf_vector.size() + - pretrimmed_strategy_map[src_strategies->node_idx].size(); + sid < src_strategy_group->leaf_vector.size() + + pretrimmed_strategy_map[src_strategy_group->node_idx].size(); ++sid) { const HloSharding* output_spec; - if (sid < src_strategies->leaf_vector.size()) { - output_spec = &src_strategies->leaf_vector[sid].output_sharding; + if (sid < src_strategy_group->leaf_vector.size()) { + output_spec = &src_strategy_group->leaf_vector[sid].output_sharding; } else { output_spec = - &pretrimmed_strategy_map[src_strategies->node_idx] - [sid - src_strategies->leaf_vector.size()] + &pretrimmed_strategy_map[src_strategy_group->node_idx] + [sid - + src_strategy_group->leaf_vector.size()] .output_sharding; VLOG(1) << "Adding outspec from the trimmed strategy map: " << output_spec->ToString(); @@ -284,9 +284,9 @@ std::unique_ptr MaybeFollowInsStrategyVector( double compute_cost = 0, communication_cost = 0; double memory_cost = have_memory_cost ? GetBytes(shape) / output_spec->NumTiles() : 0; - auto resharding_costs = ReshardingCostVector(src_strategies, shape, + auto resharding_costs = ReshardingCostVector(src_strategy_group, shape, *output_spec, cluster_env); - strategies->leaf_vector.push_back( + strategy_group->leaf_vector.push_back( ShardingStrategy({name, *output_spec, compute_cost, @@ -296,19 +296,19 @@ std::unique_ptr MaybeFollowInsStrategyVector( {*output_spec}})); } } - return strategies; + return strategy_group; } -StatusOr> FollowReduceStrategy( +StatusOr> FollowReduceStrategy( const HloInstruction* ins, const Shape& output_shape, const HloInstruction* operand, const HloInstruction* unit, size_t instruction_id, StrategyMap& strategy_map, LeafStrategies& leaf_strategies, const ClusterEnvironment& cluster_env, bool allow_mixed_mesh_shape, bool crash_at_error) { - std::unique_ptr strategies; + std::unique_ptr strategy_group; if (output_shape.IsTuple()) { - strategies = CreateTupleStrategyVector(instruction_id); - strategies->childs.reserve(ins->shape().tuple_shapes_size()); + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { auto child_strategy_status = FollowReduceStrategy( ins, ins->shape().tuple_shapes().at(i), ins->operand(i), @@ -319,15 +319,16 @@ StatusOr> FollowReduceStrategy( return child_strategy_status; } child_strategy_status.value()->tuple_element_idx = i; - strategies->childs.push_back(std::move(child_strategy_status.value())); + strategy_group->childs.push_back( + std::move(child_strategy_status.value())); } } else if (output_shape.IsArray()) { - strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map, - leaf_strategies); - const StrategyVector* src_strategies = strategy_map.at(operand).get(); + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, + leaf_strategies); + const StrategyGroup* src_strategy_group = strategy_map.at(operand).get(); // Follows the strategy of the 
operand. - strategies->following = src_strategies; - strategies->leaf_vector.reserve(src_strategies->leaf_vector.size()); + strategy_group->following = src_strategy_group; + strategy_group->leaf_vector.reserve(src_strategy_group->leaf_vector.size()); // Map operand dims to inst dim // Example: f32[1,16]{1,0} reduce(f32[1,16,4096]{2,1,0} %param0, f32[] // %param1), dimensions={2} @@ -339,9 +340,9 @@ StatusOr> FollowReduceStrategy( operand->shape().rank()) << "Invalid kReduce: output size + reduced dimensions size != op count"; - for (size_t sid = 0; sid < src_strategies->leaf_vector.size(); ++sid) { + for (size_t sid = 0; sid < src_strategy_group->leaf_vector.size(); ++sid) { HloSharding input_sharding = - src_strategies->leaf_vector[sid].output_sharding; + src_strategy_group->leaf_vector[sid].output_sharding; const auto& tensor_dim_to_mesh = cluster_env.GetTensorDimToMeshDimWrapper( operand->shape(), input_sharding, /* consider_reverse_device_meshes */ true, @@ -370,7 +371,7 @@ StatusOr> FollowReduceStrategy( output_shape, operand_clone.get(), unit_clone.get(), ins->dimensions(), ins->to_apply()); operand_clone->set_sharding( - src_strategies->leaf_vector[sid].output_sharding); + src_strategy_group->leaf_vector[sid].output_sharding); auto s = new_reduce->ReplaceOperandWith(0, operand_clone.get()); if (!s.ok()) { continue; @@ -411,12 +412,12 @@ StatusOr> FollowReduceStrategy( memory_cost, resharding_costs, {input_sharding}}); - strategies->leaf_vector.push_back(strategy); + strategy_group->leaf_vector.push_back(strategy); } } else { LOG(FATAL) << "Unhandled kReduce shape: " << ins->shape().ToString(); } - return strategies; + return strategy_group; } std::vector FindReplicateStrategyIndices( @@ -433,7 +434,7 @@ std::vector FindReplicateStrategyIndices( std::pair>, std::vector>> ReshardingCostsForTupleOperand(const HloInstruction* operand, - StrategyVector* operand_strategy_vector) { + StrategyGroup* operand_strategy_vector) { // TODO(yuemmawang) Support instructions with more than one tuple operand. // Creates resharding costs such that favors when operand strategies are // replicated. 
@@ -457,10 +458,9 @@ ReshardingCostsForTupleOperand(const HloInstruction* operand, resharding_costs.back().at(i) = 0.0; } } - return std::make_pair( - resharding_costs, - std::vector>( - {HloSharding::Tuple(operand->shape(), tuple_element_shardings)})); + return {resharding_costs, + std::vector>( + {HloSharding::Tuple(operand->shape(), tuple_element_shardings)})}; } std::vector> CreateZeroReshardingCostsForAllOperands( @@ -497,7 +497,7 @@ std::vector> CreateZeroReshardingCostsForAllOperands( void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategies, + std::unique_ptr& strategy_group, double replicated_penalty) { HloSharding output_spec = HloSharding::Replicate(); std::vector> resharding_costs; @@ -541,7 +541,7 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, } resharding_costs.push_back({}); double memory_cost = GetBytes(shape) / output_spec.NumTiles(); - strategies->leaf_vector.push_back(ShardingStrategy( + strategy_group->leaf_vector.push_back(ShardingStrategy( {"R", HloSharding::Replicate(), replicated_penalty, 0, memory_cost, std::move(resharding_costs), input_shardings})); } @@ -587,7 +587,7 @@ double ComputeCommunicationCost( void AddReplicatedStrategy( const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategies, double replicated_penalty, + std::unique_ptr& strategy_group, double replicated_penalty, absl::flat_hash_set operands_to_consider_all_strategies_for = {}) { HloSharding replicated_strategy = HloSharding::Replicate(); HloSharding output_spec = replicated_strategy; @@ -634,7 +634,7 @@ void AddReplicatedStrategy( for (size_t j = 0; j < possible_input_shardings.size(); ++j) { double communication_cost = ComputeCommunicationCost( ins, possible_input_shardings[j], cluster_env); - strategies->leaf_vector.push_back(ShardingStrategy( + strategy_group->leaf_vector.push_back(ShardingStrategy( {"R", replicated_strategy, replicated_penalty, communication_cost, memory_cost, std::move(possible_resharding_costs[j]), std::move(possible_input_shardings[j])})); @@ -665,7 +665,7 @@ void AddReplicatedStrategy( } } } - strategies->leaf_vector.push_back(ShardingStrategy( + strategy_group->leaf_vector.push_back(ShardingStrategy( {"R", HloSharding::Replicate(), replicated_penalty, 0, memory_cost, std::move(resharding_costs), input_shardings})); } @@ -689,7 +689,7 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, const Array& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategies, + std::unique_ptr& strategy_group, bool only_allow_divisible, const std::string& suffix, const CallGraph& call_graph) { @@ -744,7 +744,7 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, communication_cost = ComputeSortCommunicationCost( ins->operand(0)->shape().rank() - 1, i, j, shape, cluster_env); } - strategies->leaf_vector.push_back(ShardingStrategy( + strategy_group->leaf_vector.push_back(ShardingStrategy( {name, output_spec, compute_cost, communication_cost, memory_cost, std::move(resharding_costs), input_shardings})); } @@ -755,7 +755,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, const Array& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategies, + std::unique_ptr& 
strategy_group, const CallGraph& call_graph, absl::Span tensor_dims); @@ -764,7 +764,7 @@ void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, const Array& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategies, + std::unique_ptr& strategy_group, const InstructionBatchDimMap& batch_dim_map, bool only_allow_divisible, const CallGraph& call_graph, @@ -773,7 +773,7 @@ void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, const auto tensor_dims_size = tensor_dims.size(); if (tensor_dims_size == partition_dimensions) { BuildStrategyAndCostForOp(ins, shape, device_mesh, cluster_env, - strategy_map, strategies, call_graph, + strategy_map, strategy_group, call_graph, tensor_dims); return; } @@ -798,7 +798,7 @@ void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, std::vector next_tensor_dims = tensor_dims; next_tensor_dims.push_back(i); EnumerateAllPartition(ins, shape, device_mesh, cluster_env, strategy_map, - strategies, batch_dim_map, only_allow_divisible, + strategy_group, batch_dim_map, only_allow_divisible, call_graph, partition_dimensions, next_tensor_dims); } } @@ -808,7 +808,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, const Array& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - std::unique_ptr& strategies, + std::unique_ptr& strategy_group, const CallGraph& call_graph, absl::Span tensor_dims) { std::vector mesh_dims(tensor_dims.size()); @@ -861,19 +861,17 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, } } } - strategies->leaf_vector.push_back(ShardingStrategy( + strategy_group->leaf_vector.push_back(ShardingStrategy( {name, output_spec, compute_cost, communication_cost, memory_cost, std::move(resharding_costs), input_shardings})); } // Enumerate all 1d partition strategies for reshape. -void EnumerateAll1DPartitionReshape(const HloInstruction* ins, - const Array& device_mesh, - const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, - std::unique_ptr& strategies, - bool only_allow_divisible, - const std::string& suffix) { +void EnumerateAll1DPartitionReshape( + const HloInstruction* ins, const Array& device_mesh, + const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, + std::unique_ptr& strategy_group, bool only_allow_divisible, + const std::string& suffix) { const HloInstruction* operand = ins->operand(0); for (int64_t i = 0; i < ins->shape().rank(); ++i) { @@ -905,7 +903,7 @@ void EnumerateAll1DPartitionReshape(const HloInstruction* ins, std::vector> resharding_costs{ ReshardingCostVector(strategy_map.at(operand).get(), operand->shape(), *input_spec, cluster_env)}; - strategies->leaf_vector.push_back( + strategy_group->leaf_vector.push_back( ShardingStrategy({name, output_spec, compute_cost, @@ -917,12 +915,11 @@ void EnumerateAll1DPartitionReshape(const HloInstruction* ins, } } -void BuildStrategyAndCostForReshape(const HloInstruction* ins, - const Array& device_mesh, - const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, - std::unique_ptr& strategies, - absl::Span tensor_dims); +void BuildStrategyAndCostForReshape( + const HloInstruction* ins, const Array& device_mesh, + const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, + std::unique_ptr& strategy_group, + absl::Span tensor_dims); // Enumerate all partitions for reshape. Batch dim is always partitioned. 
void EnumeratePartitionReshape(const HloInstruction* ins, @@ -930,14 +927,14 @@ void EnumeratePartitionReshape(const HloInstruction* ins, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const InstructionBatchDimMap& batch_dim_map, - std::unique_ptr& strategies, + std::unique_ptr& strategy_group, bool only_allow_divisible, int64_t partition_dimensions, const std::vector& tensor_dims = {}) { const auto tensor_dims_size = tensor_dims.size(); if (tensor_dims_size == partition_dimensions) { BuildStrategyAndCostForReshape(ins, device_mesh, cluster_env, strategy_map, - strategies, tensor_dims); + strategy_group, tensor_dims); return; } auto iter = batch_dim_map.find(GetBatchDimMapKey(ins)); @@ -964,17 +961,17 @@ void EnumeratePartitionReshape(const HloInstruction* ins, std::vector next_tensor_dims = tensor_dims; next_tensor_dims.push_back(i); EnumeratePartitionReshape(ins, device_mesh, cluster_env, strategy_map, - batch_dim_map, strategies, only_allow_divisible, - partition_dimensions, next_tensor_dims); + batch_dim_map, strategy_group, + only_allow_divisible, partition_dimensions, + next_tensor_dims); } } -void BuildStrategyAndCostForReshape(const HloInstruction* ins, - const Array& device_mesh, - const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, - std::unique_ptr& strategies, - absl::Span tensor_dims) { +void BuildStrategyAndCostForReshape( + const HloInstruction* ins, const Array& device_mesh, + const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, + std::unique_ptr& strategy_group, + absl::Span tensor_dims) { const HloInstruction* operand = ins->operand(0); std::vector mesh_dims(tensor_dims.size()); std::iota(mesh_dims.begin(), mesh_dims.end(), 0); @@ -994,7 +991,7 @@ void BuildStrategyAndCostForReshape(const HloInstruction* ins, std::vector> resharding_costs{ ReshardingCostVector(strategy_map.at(operand).get(), operand->shape(), *input_spec, cluster_env)}; - strategies->leaf_vector.push_back( + strategy_group->leaf_vector.push_back( ShardingStrategy({name, output_spec, compute_cost, @@ -1007,15 +1004,16 @@ void BuildStrategyAndCostForReshape(const HloInstruction* ins, // Return the maximum number of tiles among all strategies of an instruction. int64_t MaxNumTiles(const StrategyMap& strategy_map, const HloInstruction* ins) { - const StrategyVector* strategies = strategy_map.at(ins).get(); + const StrategyGroup* strategy_group = strategy_map.at(ins).get(); // TODO(zhuohan): optimize with path compression. 
- while (strategies->following != nullptr) { - strategies = strategies->following; + while (strategy_group->following != nullptr) { + strategy_group = strategy_group->following; } int64_t max_num_tiles = -1; - for (size_t i = 0; i < strategies->leaf_vector.size(); ++i) { - max_num_tiles = std::max( - max_num_tiles, strategies->leaf_vector[i].output_sharding.NumTiles()); + for (size_t i = 0; i < strategy_group->leaf_vector.size(); ++i) { + max_num_tiles = + std::max(max_num_tiles, + strategy_group->leaf_vector[i].output_sharding.NumTiles()); } return max_num_tiles; @@ -1048,7 +1046,7 @@ std::pair ChooseOperandToFollow( for (int64_t i = 0; i < ins->operand_count(); ++i) { const HloInstruction* operand = ins->operand(i); if (operand == it->second) { - return std::make_pair(i, false); + return {i, false}; } } } @@ -1073,7 +1071,7 @@ std::pair ChooseOperandToFollow( } CHECK(follow_idx.has_value()); - return std::make_pair(*follow_idx, tie); + return {*follow_idx, tie}; } // Return whether an instruciton can follow one of its operand when @@ -1122,17 +1120,17 @@ void DisableIncompatibleMixedMeshShapeAndForceBatchDim( } } -StatusOr> CreateAllStrategiesVector( +StatusOr> CreateAllStrategiesVector( const HloInstruction* ins, const Shape& shape, size_t instruction_id, LeafStrategies& leaf_strategies, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, bool only_allow_divisible, bool create_replicated_strategies) { - std::unique_ptr strategies; + std::unique_ptr strategy_group; if (shape.IsTuple()) { - strategies = CreateTupleStrategyVector(instruction_id); - strategies->childs.reserve(shape.tuple_shapes_size()); + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(shape.tuple_shapes_size()); for (size_t i = 0; i < shape.tuple_shapes_size(); ++i) { auto child_strategies = CreateAllStrategiesVector(ins, shape.tuple_shapes(i), instruction_id, @@ -1142,62 +1140,62 @@ StatusOr> CreateAllStrategiesVector( create_replicated_strategies) .value(); child_strategies->tuple_element_idx = i; - strategies->childs.push_back(std::move(child_strategies)); + strategy_group->childs.push_back(std::move(child_strategies)); } } else if (shape.IsArray()) { - strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, + leaf_strategies); EnumerateAll1DPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, strategies, only_allow_divisible, "", - call_graph); + strategy_map, strategy_group, only_allow_divisible, + "", call_graph); // Split 2 dims if (cluster_env.IsDeviceMesh2D()) { EnumerateAllPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, strategies, batch_dim_map, + strategy_map, strategy_group, batch_dim_map, only_allow_divisible, call_graph, /*partitions*/ 2); } // Split 3 dims if (cluster_env.IsDeviceMesh3D()) { EnumerateAllPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, strategies, batch_dim_map, + strategy_map, strategy_group, batch_dim_map, only_allow_divisible, call_graph, /*partitions*/ 3); } if (option.allow_mixed_mesh_shape && cluster_env.IsDeviceMesh2D()) { // Set penalty for 1d partial tiled layout - for (size_t i = 0; i < strategies->leaf_vector.size(); ++i) { - strategies->leaf_vector[i].compute_cost += replicated_penalty * 0.8; 
+ for (size_t i = 0; i < strategy_group->leaf_vector.size(); ++i) { + strategy_group->leaf_vector[i].compute_cost += replicated_penalty * 0.8; } // Split 1 dim, but for 1d mesh EnumerateAll1DPartition(ins, shape, cluster_env.device_mesh_1d_, - cluster_env, strategy_map, strategies, + cluster_env, strategy_map, strategy_group, only_allow_divisible, " 1d", call_graph); } - if (create_replicated_strategies || strategies->leaf_vector.empty()) { - AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, strategies, - replicated_penalty); + if (create_replicated_strategies || strategy_group->leaf_vector.empty()) { + AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, + strategy_group, replicated_penalty); } // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies // and only keep the data parallel strategies. if (option.force_batch_dim_to_mesh_dim >= 0 && batch_dim_map.contains(GetBatchDimMapKey(ins))) { - TF_RETURN_IF_ERROR(FilterStrategy(ins, shape, strategies, cluster_env, + TF_RETURN_IF_ERROR(FilterStrategy(ins, shape, strategy_group, cluster_env, batch_dim_map, option)); } } else if (shape.IsToken()) { - strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map, - leaf_strategies); - AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, strategies, + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, + leaf_strategies); + AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, strategy_group, replicated_penalty); } else { LOG(FATAL) << "Unsupported instruction shape: " << shape.DebugString(); } - return strategies; + return strategy_group; } -StatusOr> CreateParameterStrategyVector( +StatusOr> CreateParameterStrategyGroup( const HloInstruction* ins, const Shape& shape, size_t instruction_id, LeafStrategies& leaf_strategies, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, @@ -1251,17 +1249,17 @@ bool ShardingIsConsistent(const HloSharding& partial_sharding, // HloSharding. // These two are distinguished by ShardingIsComplete(). void TrimOrGenerateStrategiesBasedOnExistingSharding( - const Shape& output_shape, StrategyVector* strategies, + const Shape& output_shape, StrategyGroup* strategy_group, const StrategyMap& strategy_map, const std::vector instructions, const HloSharding& existing_sharding, const ClusterEnvironment& cluster_env, StableHashMap>& pretrimmed_strategy_map, const CallGraph& call_graph, bool strict) { - if (strategies->is_tuple) { - for (size_t i = 0; i < strategies->childs.size(); ++i) { + if (strategy_group->is_tuple) { + for (size_t i = 0; i < strategy_group->childs.size(); ++i) { TrimOrGenerateStrategiesBasedOnExistingSharding( - output_shape.tuple_shapes(i), strategies->childs.at(i).get(), + output_shape.tuple_shapes(i), strategy_group->childs.at(i).get(), strategy_map, instructions, existing_sharding.tuple_elements().at(i), cluster_env, pretrimmed_strategy_map, call_graph, strict); } @@ -1269,10 +1267,11 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( if (ShardingIsComplete(existing_sharding, cluster_env.device_mesh_.num_elements())) { // Sharding provided by XLA users, we need to keep them. 
- strategies->following = nullptr; + strategy_group->following = nullptr; std::vector strategy_indices; - for (size_t i = 0; i < strategies->leaf_vector.size(); i++) { - if (strategies->leaf_vector[i].output_sharding == existing_sharding) { + for (size_t i = 0; i < strategy_group->leaf_vector.size(); i++) { + if (strategy_group->leaf_vector[i].output_sharding == + existing_sharding) { strategy_indices.push_back(i); } } @@ -1281,27 +1280,28 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( << spmd::ToString(strategy_indices); // Stores other strategies in the map, removes them in the vector and // only keeps the one we found. - pretrimmed_strategy_map[strategies->node_idx] = strategies->leaf_vector; + pretrimmed_strategy_map[strategy_group->node_idx] = + strategy_group->leaf_vector; std::vector new_leaf_vector; for (int32_t found_strategy_index : strategy_indices) { ShardingStrategy found_strategy = - strategies->leaf_vector[found_strategy_index]; + strategy_group->leaf_vector[found_strategy_index]; new_leaf_vector.push_back(found_strategy); } - strategies->leaf_vector.clear(); - strategies->leaf_vector = new_leaf_vector; + strategy_group->leaf_vector.clear(); + strategy_group->leaf_vector = new_leaf_vector; } else { VLOG(1) << "Generate a new strategy based on user sharding."; std::string name = ToStringSimple(existing_sharding); std::vector> resharding_costs; std::vector> input_shardings; - if (strategies->in_nodes.empty()) { + if (strategy_group->in_nodes.empty()) { resharding_costs = {}; } else { - HloInstruction* ins = instructions.at(strategies->instruction_id); - for (size_t i = 0; i < strategies->in_nodes.size(); i++) { + HloInstruction* ins = instructions.at(strategy_group->instruction_id); + for (size_t i = 0; i < strategy_group->in_nodes.size(); i++) { HloInstruction* operand = - instructions.at(strategies->in_nodes.at(i)->instruction_id); + instructions.at(strategy_group->in_nodes.at(i)->instruction_id); std::optional input_sharding_or = ShardingPropagation::GetShardingFromUser(*operand, *ins, 10, true, call_graph); @@ -1309,29 +1309,29 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( input_shardings.push_back(input_sharding_or.value()); } - StrategyVector* operand_strategies; + StrategyGroup* operand_strategy_group; Shape operand_shape; if (ins->opcode() == HloOpcode::kGetTupleElement) { - operand_strategies = + operand_strategy_group = strategy_map.at(operand)->childs[ins->tuple_index()].get(); operand_shape = operand->shape().tuple_shapes(ins->tuple_index()); } else { - operand_strategies = strategy_map.at(operand).get(); + operand_strategy_group = strategy_map.at(operand).get(); operand_shape = operand->shape(); } resharding_costs.push_back( - ReshardingCostVector(operand_strategies, operand_shape, + ReshardingCostVector(operand_strategy_group, operand_shape, existing_sharding, cluster_env)); } } double memory_cost = GetBytes(output_shape) / existing_sharding.NumTiles(); - if (!strategies->leaf_vector.empty()) { - pretrimmed_strategy_map[strategies->node_idx] = - strategies->leaf_vector; + if (!strategy_group->leaf_vector.empty()) { + pretrimmed_strategy_map[strategy_group->node_idx] = + strategy_group->leaf_vector; } - strategies->leaf_vector.clear(); - strategies->leaf_vector.push_back( + strategy_group->leaf_vector.clear(); + strategy_group->leaf_vector.push_back( ShardingStrategy({name, existing_sharding, 0, 0, memory_cost, resharding_costs, input_shardings})); } @@ -1339,23 +1339,23 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( // that option 
is kInfinityCost, set the cost to zero. This is okay // because there is only one option anyway, and having the costs set to // kInfinityCost is problematic for the solver. - if (strategies->leaf_vector.size() == 1) { + if (strategy_group->leaf_vector.size() == 1) { for (auto& operand_resharding_costs : - strategies->leaf_vector[0].resharding_costs) { + strategy_group->leaf_vector[0].resharding_costs) { if (operand_resharding_costs.size() == 1 && operand_resharding_costs[0] >= kInfinityCost) { operand_resharding_costs[0] = 0; } } } - } else if (!strategies->following) { + } else if (!strategy_group->following) { // If existing sharding is a partial sharding from previous iteration, // find the strategies that are 1D&&complete or align with user // sharding. // It is IMPORTANT that we do this only for instructions that do no follow // others, to keep the number of ILP variable small. std::vector new_vector; - for (const auto& strategy : strategies->leaf_vector) { + for (const auto& strategy : strategy_group->leaf_vector) { if (strategy.output_sharding.IsReplicated() || ShardingIsConsistent(existing_sharding, strategy.output_sharding, strict) || @@ -1372,29 +1372,30 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( // not have to strictly keep those shardings and the only purpose is to // reduce problem size for the last iteration. if (!new_vector.empty() && - new_vector.size() != strategies->leaf_vector.size()) { - strategies->following = nullptr; - strategies->leaf_vector = std::move(new_vector); + new_vector.size() != strategy_group->leaf_vector.size()) { + strategy_group->following = nullptr; + strategy_group->leaf_vector = std::move(new_vector); } } } } -void CheckMemoryCosts(StrategyVector* strategies, const Shape& shape) { - if (strategies->is_tuple) { - for (size_t i = 0; i < strategies->childs.size(); i++) { - CheckMemoryCosts(strategies->childs[i].get(), shape.tuple_shapes().at(i)); +void CheckMemoryCosts(StrategyGroup* strategy_group, const Shape& shape) { + if (strategy_group->is_tuple) { + for (size_t i = 0; i < strategy_group->childs.size(); i++) { + CheckMemoryCosts(strategy_group->childs[i].get(), + shape.tuple_shapes().at(i)); } } else { double full_mem = 0.0; - for (const auto& strategy : strategies->leaf_vector) { + for (const auto& strategy : strategy_group->leaf_vector) { if (strategy.output_sharding.IsReplicated()) { full_mem = strategy.memory_cost; size_t size = GetInstructionSize(shape); CHECK_EQ(strategy.memory_cost, size); } } - for (const auto& strategy : strategies->leaf_vector) { + for (const auto& strategy : strategy_group->leaf_vector) { if (!strategy.output_sharding.IsReplicated() && full_mem > 0.0) { CHECK_EQ(strategy.memory_cost * strategy.output_sharding.NumTiles(), full_mem); @@ -1404,16 +1405,17 @@ void CheckMemoryCosts(StrategyVector* strategies, const Shape& shape) { } void RemoveInvalidShardingsWithShapes(const Shape& shape, - StrategyVector* strategies, + StrategyGroup* strategy_group, bool instruction_has_user_sharding) { - if (strategies->is_tuple) { - for (size_t i = 0; i < strategies->childs.size(); i++) { + if (strategy_group->is_tuple) { + for (size_t i = 0; i < strategy_group->childs.size(); i++) { RemoveInvalidShardingsWithShapes(shape.tuple_shapes().at(i), - strategies->childs[i].get(), + strategy_group->childs[i].get(), instruction_has_user_sharding); } } else { - if (instruction_has_user_sharding && strategies->leaf_vector.size() == 1) { + if (instruction_has_user_sharding && + strategy_group->leaf_vector.size() == 1) { // If an 
instruction has a specified user sharding, and there is only a // single strategy, removing that strategy would mean we won't have any // strategy for that instruction. Further, given that the user has @@ -1422,7 +1424,7 @@ void RemoveInvalidShardingsWithShapes(const Shape& shape, return; } std::vector new_vector; - for (const auto& strategy : strategies->leaf_vector) { + for (const auto& strategy : strategy_group->leaf_vector) { if (strategy.output_sharding.IsReplicated()) { new_vector.push_back(strategy); continue; @@ -1442,49 +1444,50 @@ void RemoveInvalidShardingsWithShapes(const Shape& shape, new_vector.push_back(strategy); } } - strategies->leaf_vector = std::move(new_vector); + strategy_group->leaf_vector = std::move(new_vector); } } -void CheckReshardingCostsShape(StrategyVector* strategies) { - if (strategies->is_tuple) { - for (size_t i = 0; i < strategies->childs.size(); i++) { - CheckReshardingCostsShape(strategies->childs[i].get()); +void CheckReshardingCostsShape(StrategyGroup* strategy_group) { + if (strategy_group->is_tuple) { + for (size_t i = 0; i < strategy_group->childs.size(); i++) { + CheckReshardingCostsShape(strategy_group->childs[i].get()); } } else { - for (const auto& strategy : strategies->leaf_vector) { - if (strategies->in_nodes.size() == 1 && - strategies->in_nodes.at(0)->is_tuple) { + for (const auto& strategy : strategy_group->leaf_vector) { + if (strategy_group->in_nodes.size() == 1 && + strategy_group->in_nodes.at(0)->is_tuple) { // This is when current instruction's only operand is tuple, and the // first dimension of resharding_costs should equal its number of // tuple elements. CHECK_EQ(strategy.resharding_costs.size(), - strategies->in_nodes.at(0)->childs.size()) - << "Instruction ID: " << strategies->instruction_id << "\n" - << strategies->ToString(); + strategy_group->in_nodes.at(0)->childs.size()) + << "Instruction ID: " << strategy_group->instruction_id << "\n" + << strategy_group->ToString(); } else { // The rest of the time, the first dimension of resharding_costs // should equal its number of operands (in_nodes). 
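// Illustration of the shape being verified, assuming a non-tuple node with
// two operands that expose 3 and 4 sharding strategies respectively:
//
//   strategy.resharding_costs.size()    == strategy_group->in_nodes.size() == 2
//   strategy.resharding_costs[0].size() == 3  // one cost per strategy of operand 0
//   strategy.resharding_costs[1].size() == 4  // one cost per strategy of operand 1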
- CHECK_EQ(strategy.resharding_costs.size(), strategies->in_nodes.size()) - << "Instruction ID: " << strategies->instruction_id << "\n" - << strategies->ToString(); + CHECK_EQ(strategy.resharding_costs.size(), + strategy_group->in_nodes.size()) + << "Instruction ID: " << strategy_group->instruction_id << "\n" + << strategy_group->ToString(); } for (size_t i = 0; i < strategy.resharding_costs.size(); i++) { size_t to_compare; - if (strategies->in_nodes.size() == 1 && - strategies->in_nodes.at(0)->is_tuple) { + if (strategy_group->in_nodes.size() == 1 && + strategy_group->in_nodes.at(0)->is_tuple) { to_compare = - strategies->in_nodes.at(0)->childs.at(i)->leaf_vector.size(); - } else if (strategies->is_tuple) { - to_compare = strategies->in_nodes.at(i)->childs.size(); + strategy_group->in_nodes.at(0)->childs.at(i)->leaf_vector.size(); + } else if (strategy_group->is_tuple) { + to_compare = strategy_group->in_nodes.at(i)->childs.size(); } else { - to_compare = strategies->in_nodes.at(i)->leaf_vector.size(); + to_compare = strategy_group->in_nodes.at(i)->leaf_vector.size(); } CHECK_EQ(strategy.resharding_costs[i].size(), to_compare) << "\nIndex of resharding_costs: " << i - << "\nInstruction ID: " << strategies->instruction_id + << "\nInstruction ID: " << strategy_group->instruction_id << "\nCurrent strategies:\n" - << strategies->ToString(); + << strategy_group->ToString(); } } } @@ -1499,15 +1502,15 @@ bool LeafVectorsAreConsistent(const std::vector& one, return true; } -void ScaleCostsWithExecutionCounts(StrategyVector* strategies, +void ScaleCostsWithExecutionCounts(StrategyGroup* strategy_group, int64_t execution_count) { - if (strategies->is_tuple) { - for (size_t i = 0; i < strategies->childs.size(); ++i) { - ScaleCostsWithExecutionCounts(strategies->childs[i].get(), + if (strategy_group->is_tuple) { + for (size_t i = 0; i < strategy_group->childs.size(); ++i) { + ScaleCostsWithExecutionCounts(strategy_group->childs[i].get(), execution_count); } } else { - for (auto& strategy : strategies->leaf_vector) { + for (auto& strategy : strategy_group->leaf_vector) { strategy.compute_cost *= execution_count; strategy.communication_cost *= execution_count; for (auto i = 0; i < strategy.resharding_costs.size(); ++i) { @@ -1521,7 +1524,7 @@ void ScaleCostsWithExecutionCounts(StrategyVector* strategies, // Enumerates sharding strategies for elementwise operators by following // strategies of an operand of the elementwise op. 
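// For example, given c = add(a, b) where a is the operand chosen to follow,
// every sharding strategy of a becomes a candidate strategy for c with the
// same output sharding (sketch, assuming non-tuple operands):
//
//   for (const ShardingStrategy& s : strategy_map.at(a)->leaf_vector) {
//     // c gets a candidate with output sharding s.output_sharding; resharding
//     // costs are computed for both operands against that sharding.
//   }
//
// Tying c to a this way ("following") lets the solver merge their variables.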
-std::unique_ptr CreateElementwiseOperatorStrategies( +std::unique_ptr CreateElementwiseOperatorStrategies( size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, const InstructionDepthMap& depth_map, const AliasMap& alias_map, @@ -1529,7 +1532,7 @@ std::unique_ptr CreateElementwiseOperatorStrategies( pretrimmed_strategy_map, int64_t max_depth, LeafStrategies& leaf_strategies, AssociativeDotPairs& associative_dot_pairs) { - std::unique_ptr strategies = CreateLeafStrategyVector( + std::unique_ptr strategy_group = CreateLeafStrategyGroup( instruction_id, ins, strategy_map, leaf_strategies); // Choose an operand to follow @@ -1539,20 +1542,20 @@ std::unique_ptr CreateElementwiseOperatorStrategies( ChooseOperandToFollow(strategy_map, depth_map, alias_map, max_depth, ins); if (!tie || AllowTieFollowing(ins)) { - strategies->following = strategy_map.at(ins->operand(follow_idx)).get(); + strategy_group->following = strategy_map.at(ins->operand(follow_idx)).get(); } else { - strategies->following = nullptr; + strategy_group->following = nullptr; } // Get all possible sharding specs from operands for (int64_t i = 0; i < ins->operand_count(); ++i) { - if (strategies->following != nullptr && i != follow_idx) { + if (strategy_group->following != nullptr && i != follow_idx) { // If ins follows one operand, do not consider sharding specs from // other operands. continue; } - auto process_src_strategies = + auto process_src_strategy_group = [&](const std::vector& src_strategies_leaf_vector) { for (int64_t sid = 0; sid < src_strategies_leaf_vector.size(); ++sid) { @@ -1571,18 +1574,18 @@ std::unique_ptr CreateElementwiseOperatorStrategies( input_shardings.push_back(output_spec); } - strategies->leaf_vector.push_back(ShardingStrategy( + strategy_group->leaf_vector.push_back(ShardingStrategy( {name, output_spec, compute_cost, communication_cost, memory_cost, std::move(resharding_costs), input_shardings})); } }; - StrategyVector* src_strategies = strategy_map.at(ins->operand(i)).get(); - CHECK(!src_strategies->is_tuple); + StrategyGroup* src_strategy_group = strategy_map.at(ins->operand(i)).get(); + CHECK(!src_strategy_group->is_tuple); - process_src_strategies(src_strategies->leaf_vector); - if (pretrimmed_strategy_map.contains(src_strategies->node_idx)) { - process_src_strategies( - pretrimmed_strategy_map.at(src_strategies->node_idx)); + process_src_strategy_group(src_strategy_group->leaf_vector); + if (pretrimmed_strategy_map.contains(src_strategy_group->node_idx)) { + process_src_strategy_group( + pretrimmed_strategy_map.at(src_strategy_group->node_idx)); } } if (ins->opcode() == HloOpcode::kAdd) { @@ -1600,19 +1603,19 @@ std::unique_ptr CreateElementwiseOperatorStrategies( strategy_map.at(ins->operand(1)).get()}); } } - return strategies; + return strategy_group; } // Enumerates sharding strategies for reshape operators. The function does so by // essentially reshaping the sharding of the operand in a manner similar to the // tensor reshape itself. 
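// For example (illustrative values): an operand f32[8,16] sharded as
// {devices=[4,1]0,1,2,3} that is reshaped to f32[128] can keep its dim-0
// split, since hlo_sharding_util::ReshapeSharding maps it to
// {devices=[4]0,1,2,3} on the flattened shape. If no operand strategy can be
// mapped this way, the leaf vector stays empty and the code below falls back
// to enumerating partitions from scratch.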
-std::unique_ptr CreateReshapeStrategies( +std::unique_ptr CreateReshapeStrategies( size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, bool only_allow_divisible, double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, const AutoShardingOption& option, LeafStrategies& leaf_strategies) { - std::unique_ptr strategies = CreateLeafStrategyVector( + std::unique_ptr strategy_group = CreateLeafStrategyGroup( instruction_id, ins, strategy_map, leaf_strategies); const HloInstruction* operand = ins->operand(0); const Array& device_mesh = cluster_env.device_mesh_; @@ -1621,15 +1624,15 @@ std::unique_ptr CreateReshapeStrategies( int mesh_nn_dims = VectorGreaterThanOneElementCount(device_mesh.dimensions()); if (mesh_nn_dims < 2 || !option.allow_mixed_mesh_shape) { // Create follow strategies - const StrategyVector* src_strategies = strategy_map.at(operand).get(); - CHECK(!src_strategies->is_tuple); - strategies->following = src_strategies; + const StrategyGroup* src_strategy_group = strategy_map.at(operand).get(); + CHECK(!src_strategy_group->is_tuple); + strategy_group->following = src_strategy_group; - for (int64_t sid = 0; sid < src_strategies->leaf_vector.size(); ++sid) { + for (int64_t sid = 0; sid < src_strategy_group->leaf_vector.size(); ++sid) { std::optional output_spec = hlo_sharding_util::ReshapeSharding( operand->shape(), ins->shape(), - src_strategies->leaf_vector[sid].output_sharding); + src_strategy_group->leaf_vector[sid].output_sharding); if (!output_spec.has_value()) { continue; @@ -1646,54 +1649,56 @@ std::unique_ptr CreateReshapeStrategies( double compute_cost = 0, communication_cost = 0; double memory_cost = GetBytes(ins->shape()) / output_spec->NumTiles(); std::vector resharding_costs = ReshardingCostVector( - src_strategies, operand->shape(), - src_strategies->leaf_vector[sid].output_sharding, cluster_env); - strategies->leaf_vector.push_back(ShardingStrategy( + src_strategy_group, operand->shape(), + src_strategy_group->leaf_vector[sid].output_sharding, cluster_env); + strategy_group->leaf_vector.push_back(ShardingStrategy( {name, *output_spec, compute_cost, communication_cost, memory_cost, {resharding_costs}, - {src_strategies->leaf_vector[sid].output_sharding}})); + {src_strategy_group->leaf_vector[sid].output_sharding}})); } } // Fail to create follow strategies, enumerate all possible cases - if (strategies->leaf_vector.empty()) { - strategies->leaf_vector.clear(); - strategies->following = nullptr; + if (strategy_group->leaf_vector.empty()) { + strategy_group->leaf_vector.clear(); + strategy_group->following = nullptr; // Split 1 dim if (cluster_env.IsDeviceMesh1D()) { EnumerateAll1DPartitionReshape(ins, device_mesh, cluster_env, - strategy_map, strategies, + strategy_map, strategy_group, only_allow_divisible, ""); } if (option.allow_mixed_mesh_shape && cluster_env.IsDeviceMesh2D()) { // Split 1 dim, but for 1d mesh EnumerateAll1DPartitionReshape(ins, device_mesh_1d, cluster_env, - strategy_map, strategies, + strategy_map, strategy_group, only_allow_divisible, " 1d"); } if (cluster_env.IsDeviceMesh2D()) { // Split 2 dim, one is always the batch dim EnumeratePartitionReshape(ins, device_mesh, cluster_env, strategy_map, - batch_dim_map, strategies, only_allow_divisible, + batch_dim_map, strategy_group, + only_allow_divisible, /*partitions*/ 2); } if (cluster_env.IsDeviceMesh3D()) { // Split 3 dim, one is always the batch dim EnumeratePartitionReshape(ins, device_mesh, cluster_env, 
strategy_map, - batch_dim_map, strategies, only_allow_divisible, + batch_dim_map, strategy_group, + only_allow_divisible, /*partitions*/ 3); } // Replicate AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategies, replicated_penalty); + strategy_group, replicated_penalty); } - return strategies; + return strategy_group; } // NOLINTBEGIN(readability/fn_size) @@ -1740,7 +1745,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, const HloInstruction* ins = instructions[instruction_id]; VLOG(2) << "instruction_id = " << instruction_id << ": " << ToAdaptiveString(ins); - std::unique_ptr strategies; + std::unique_ptr strategy_group; HloOpcode opcode = ins->opcode(); @@ -1763,29 +1768,31 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, case HloOpcode::kParameter: case HloOpcode::kRngBitGenerator: case HloOpcode::kRng: { - strategies = CreateParameterStrategyVector( - ins, ins->shape(), instruction_id, leaf_strategies, - cluster_env, strategy_map, option, replicated_penalty, - batch_dim_map, call_graph, only_allow_divisible) - .value(); + strategy_group = + CreateParameterStrategyGroup( + ins, ins->shape(), instruction_id, leaf_strategies, cluster_env, + strategy_map, option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible) + .value(); break; } case HloOpcode::kConstant: { - strategies = CreateLeafStrategyVectorWithoutInNodes(instruction_id, - leaf_strategies); + strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, + leaf_strategies); AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategies, 0); + strategy_group, 0); break; } case HloOpcode::kScatter: { - strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, leaf_strategies); // We follow the first operand (the array we're scattering into) - auto src_strategies = strategy_map.at(ins->operand(0)).get(); - CHECK(!src_strategies->is_tuple); - for (int64_t sid = 0; sid < src_strategies->leaf_vector.size(); ++sid) { + auto src_strategy_group = strategy_map.at(ins->operand(0)).get(); + CHECK(!src_strategy_group->is_tuple); + for (int64_t sid = 0; sid < src_strategy_group->leaf_vector.size(); + ++sid) { HloSharding output_spec = - src_strategies->leaf_vector[sid].output_sharding; + src_strategy_group->leaf_vector[sid].output_sharding; std::string name = ToStringSimple(output_spec); double compute_cost = 0, communication_cost = 0; double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); @@ -1801,15 +1808,15 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, CHECK(sharding_optional.has_value()); } - strategies->leaf_vector.push_back(ShardingStrategy( + strategy_group->leaf_vector.push_back(ShardingStrategy( {name, output_spec, compute_cost, communication_cost, memory_cost, std::move(resharding_cost), input_shardings_optional})); } break; } case HloOpcode::kGather: { - strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, leaf_strategies); const HloInstruction* indices = ins->operand(1); const Shape& shape = ins->shape(); for (int32_t index_dim = 0; index_dim < indices->shape().rank(); @@ -1848,16 +1855,17 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, ins, output_spec, strategy_map, cluster_env, call_graph, input_shardings_optional); - 
strategies->leaf_vector.push_back(ShardingStrategy( + strategy_group->leaf_vector.push_back(ShardingStrategy( {name, output_spec, compute_cost, communication_cost, memory_cost, std::move(resharding_cost), input_shardings_optional})); } } - auto src_strategies = strategy_map.at(ins->operand(0)).get(); - for (int64_t sid = 0; sid < src_strategies->leaf_vector.size(); ++sid) { + auto src_strategy_group = strategy_map.at(ins->operand(0)).get(); + for (int64_t sid = 0; sid < src_strategy_group->leaf_vector.size(); + ++sid) { HloSharding output_spec = - src_strategies->leaf_vector[sid].output_sharding; + src_strategy_group->leaf_vector[sid].output_sharding; auto gather_parallel_dims = hlo_sharding_util::GetGatherParallelBatchDims(*ins, call_graph); absl::Span operand_parallel_dims; @@ -1882,68 +1890,71 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, GenerateReshardingCostsAndMissingShardingsForAllOperands( ins, *maybe_from_data, strategy_map, cluster_env, call_graph, input_shardings_optional); - strategies->leaf_vector.push_back(ShardingStrategy( + strategy_group->leaf_vector.push_back(ShardingStrategy( {name, *maybe_from_data, compute_cost, communication_cost, memory_cost, std::move(resharding_cost), input_shardings_optional})); } AddReplicatedStrategy( - ins, ins->shape(), cluster_env, strategy_map, strategies, 0, + ins, ins->shape(), cluster_env, strategy_map, strategy_group, 0, /* operands_to_consider_all_strategies_for */ {0}); break; } case HloOpcode::kBroadcast: { - strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, leaf_strategies); const HloInstruction* operand = ins->operand(0); - const StrategyVector* operand_strategies = + const StrategyGroup* operand_strategies = strategy_map.at(operand).get(); CHECK(!operand_strategies->is_tuple); if (ins->shape().rank() == 1 || cluster_env.IsDeviceMesh1D()) { EnumerateAll1DPartition(ins, ins->shape(), cluster_env.device_mesh_, - cluster_env, strategy_map, strategies, + cluster_env, strategy_map, strategy_group, only_allow_divisible, "", call_graph); } else { EnumerateAllPartition(ins, ins->shape(), cluster_env.device_mesh_, - cluster_env, strategy_map, strategies, + cluster_env, strategy_map, strategy_group, batch_dim_map, only_allow_divisible, call_graph, /*partitions*/ 2); if (option.allow_mixed_mesh_shape) { EnumerateAll1DPartition(ins, ins->shape(), cluster_env.device_mesh_1d_, cluster_env, - strategy_map, strategies, + strategy_map, strategy_group, only_allow_divisible, "1d", call_graph); } } AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategies, replicated_penalty); + strategy_group, replicated_penalty); break; } case HloOpcode::kReshape: { - strategies = CreateReshapeStrategies(instruction_id, ins, strategy_map, - cluster_env, only_allow_divisible, - replicated_penalty, batch_dim_map, - option, leaf_strategies); + strategy_group = CreateReshapeStrategies( + instruction_id, ins, strategy_map, cluster_env, + only_allow_divisible, replicated_penalty, batch_dim_map, option, + leaf_strategies); break; } case HloOpcode::kTranspose: case HloOpcode::kReverse: { - strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, leaf_strategies); const HloInstruction* operand = ins->operand(0); // Create follow strategies - const StrategyVector* src_strategies = 
strategy_map.at(operand).get(); - CHECK(!src_strategies->is_tuple); - strategies->following = src_strategies; + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + CHECK(!src_strategy_group->is_tuple); + strategy_group->following = src_strategy_group; - for (int64_t sid = 0; sid < src_strategies->leaf_vector.size(); ++sid) { + for (int64_t sid = 0; sid < src_strategy_group->leaf_vector.size(); + ++sid) { HloSharding output_spec = Undefined(); - auto input_spec = src_strategies->leaf_vector[sid].output_sharding; + auto input_spec = + src_strategy_group->leaf_vector[sid].output_sharding; if (opcode == HloOpcode::kTranspose) { output_spec = hlo_sharding_util::TransposeSharding( input_spec, ins->dimensions()); @@ -1956,8 +1967,8 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, double compute_cost = 0, communication_cost = 0; double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); auto resharding_costs = ReshardingCostVector( - src_strategies, operand->shape(), input_spec, cluster_env); - strategies->leaf_vector.push_back( + src_strategy_group, operand->shape(), input_spec, cluster_env); + strategy_group->leaf_vector.push_back( ShardingStrategy({name, output_spec, compute_cost, @@ -1975,8 +1986,8 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: { - strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, leaf_strategies); int64_t follow_idx; switch (opcode) { // TODO(yuemmawang) Re-evaluate the follow_idx choices for the @@ -2003,14 +2014,15 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } // Create follow strategies const HloInstruction* operand = ins->operand(follow_idx); - StrategyVector* src_strategies = strategy_map.at(operand).get(); - CHECK(!src_strategies->is_tuple); - strategies->following = src_strategies; + StrategyGroup* src_strategy_group = strategy_map.at(operand).get(); + CHECK(!src_strategy_group->is_tuple); + strategy_group->following = src_strategy_group; - for (int64_t sid = 0; sid < src_strategies->leaf_vector.size(); ++sid) { + for (int64_t sid = 0; sid < src_strategy_group->leaf_vector.size(); + ++sid) { std::optional output_spec; HloSharding input_spec = - src_strategies->leaf_vector[sid].output_sharding; + src_strategy_group->leaf_vector[sid].output_sharding; // Find output shardings. 
switch (opcode) { @@ -2054,7 +2066,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, ins, *output_spec, strategy_map, cluster_env, call_graph, input_shardings); - strategies->leaf_vector.push_back( + strategy_group->leaf_vector.push_back( ShardingStrategy({name, *output_spec, compute_cost, @@ -2064,17 +2076,17 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, {input_spec}})); } - if (strategies->leaf_vector.empty()) { - strategies->following = nullptr; + if (strategy_group->leaf_vector.empty()) { + strategy_group->following = nullptr; AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategies, 0); + strategy_group, 0); } break; } case HloOpcode::kOptimizationBarrier: { auto operand_strategies = strategy_map.at(ins->operand(0)).get(); - strategies = MaybeFollowInsStrategyVector( + strategy_group = MaybeFollowInsStrategyGroup( operand_strategies, ins->shape(), instruction_id, /* have_memory_cost */ true, leaf_strategies, cluster_env, pretrimmed_strategy_map); @@ -2082,12 +2094,12 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } case HloOpcode::kBitcast: { if (ins->shape() == ins->operand(0)->shape()) { - strategies = CreateElementwiseOperatorStrategies( + strategy_group = CreateElementwiseOperatorStrategies( instruction_id, ins, strategy_map, cluster_env, depth_map, alias_map, pretrimmed_strategy_map, max_depth, leaf_strategies, associative_dot_pairs); } else { - strategies = CreateReshapeStrategies( + strategy_group = CreateReshapeStrategies( instruction_id, ins, strategy_map, cluster_env, only_allow_divisible, replicated_penalty, batch_dim_map, option, leaf_strategies); @@ -2146,7 +2158,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // Ternary elementwise operations. 
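// kSelect and kClamp reuse CreateElementwiseOperatorStrategies, just like the
// unary and binary elementwise opcodes handled by the same helper: one operand
// is chosen to follow and its sharding strategies are mirrored for the
// ternary op.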
case HloOpcode::kSelect: case HloOpcode::kClamp: { - strategies = CreateElementwiseOperatorStrategies( + strategy_group = CreateElementwiseOperatorStrategies( instruction_id, ins, strategy_map, cluster_env, depth_map, alias_map, pretrimmed_strategy_map, max_depth, leaf_strategies, associative_dot_pairs); @@ -2158,61 +2170,61 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, strategy_map, leaf_strategies, cluster_env, option.allow_mixed_mesh_shape, !trying_multiple_mesh_shapes); if (strategies_status.ok()) { - strategies = std::move(strategies_status.value()); + strategy_group = std::move(strategies_status.value()); } else { return strategies_status.status(); } break; } case HloOpcode::kDot: { - TF_RETURN_IF_ERROR(HandleDot(strategies, leaf_strategies, strategy_map, - ins, instruction_id, cluster_env, - batch_dim_map, option, call_graph)); + TF_RETURN_IF_ERROR(HandleDot( + strategy_group, leaf_strategies, strategy_map, ins, instruction_id, + cluster_env, batch_dim_map, option, call_graph)); if (option.allow_replicated_strategy_for_dot_and_conv) { AddReplicatedStrategy( - ins, ins->shape(), cluster_env, strategy_map, strategies, + ins, ins->shape(), cluster_env, strategy_map, strategy_group, GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, sequence, hlo_cost_analysis)); } break; } case HloOpcode::kConvolution: { - TF_RETURN_IF_ERROR(HandleConv(strategies, leaf_strategies, strategy_map, - ins, instruction_id, cluster_env, - batch_dim_map, option, call_graph)); + TF_RETURN_IF_ERROR(HandleConv( + strategy_group, leaf_strategies, strategy_map, ins, instruction_id, + cluster_env, batch_dim_map, option, call_graph)); if (option.allow_replicated_strategy_for_dot_and_conv) { AddReplicatedStrategy( - ins, ins->shape(), cluster_env, strategy_map, strategies, + ins, ins->shape(), cluster_env, strategy_map, strategy_group, GetDotConvReplicationPenalty(ins, instruction_id, /* window */ 10, sequence, hlo_cost_analysis)); } break; } case HloOpcode::kRngGetAndUpdateState: { - strategies = CreateLeafStrategyVectorWithoutInNodes(instruction_id, - leaf_strategies); + strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, + leaf_strategies); AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategies, 0); + strategy_group, 0); break; } case HloOpcode::kIota: { - strategies = CreateLeafStrategyVectorWithoutInNodes(instruction_id, - leaf_strategies); + strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, + leaf_strategies); if (cluster_env.IsDeviceMesh1D()) { EnumerateAll1DPartition(ins, ins->shape(), device_mesh, cluster_env, - strategy_map, strategies, + strategy_map, strategy_group, only_allow_divisible, "", call_graph); } if (cluster_env.IsDeviceMesh2D()) { // Split 2 dims EnumerateAllPartition(ins, ins->shape(), device_mesh, cluster_env, - strategy_map, strategies, batch_dim_map, + strategy_map, strategy_group, batch_dim_map, only_allow_divisible, call_graph, /*parts*/ 2); } if (cluster_env.IsDeviceMesh3D()) { // Split 3 dims EnumerateAllPartition(ins, ins->shape(), device_mesh, cluster_env, - strategy_map, strategies, batch_dim_map, + strategy_map, strategy_group, batch_dim_map, only_allow_divisible, call_graph, /*parts*/ 3); } if (cluster_env.IsDeviceMesh2D() && option.allow_mixed_mesh_shape) { @@ -2220,37 +2232,39 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // For example, when the mesh shape is (2, 4), we add strategies for // mesh shape (1, 8) here in addition. 
EnumerateAll1DPartition(ins, ins->shape(), device_mesh_1d, - cluster_env, strategy_map, strategies, + cluster_env, strategy_map, strategy_group, only_allow_divisible, " 1d", call_graph); } // Replicate AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategies, replicated_penalty * 5); + strategy_group, replicated_penalty * 5); break; } case HloOpcode::kTuple: { - strategies = CreateTupleStrategyVector(instruction_id); - strategies->childs.reserve(ins->operand_count()); + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(ins->operand_count()); for (size_t i = 0; i < ins->operand_count(); ++i) { const HloInstruction* operand = ins->operand(i); - const StrategyVector* src_strategies = strategy_map.at(operand).get(); - auto child_strategies = MaybeFollowInsStrategyVector( - src_strategies, operand->shape(), instruction_id, + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + auto child_strategies = MaybeFollowInsStrategyGroup( + src_strategy_group, operand->shape(), instruction_id, /* have_memory_cost= */ true, leaf_strategies, cluster_env, pretrimmed_strategy_map); child_strategies->tuple_element_idx = i; - strategies->childs.push_back(std::move(child_strategies)); + strategy_group->childs.push_back(std::move(child_strategies)); } break; } case HloOpcode::kGetTupleElement: { const HloInstruction* operand = ins->operand(0); - const StrategyVector* src_strategies = strategy_map.at(operand).get(); - CHECK(src_strategies->is_tuple); - strategies = MaybeFollowInsStrategyVector( - src_strategies->childs[ins->tuple_index()].get(), ins->shape(), + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + CHECK(src_strategy_group->is_tuple); + strategy_group = MaybeFollowInsStrategyGroup( + src_strategy_group->childs[ins->tuple_index()].get(), ins->shape(), instruction_id, /* have_memory_cost= */ true, leaf_strategies, cluster_env, pretrimmed_strategy_map); @@ -2263,20 +2277,22 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, operands_to_consider_all_strategies_for = {}) { if (ins->shape().IsTuple()) { if (only_replicated) { - strategies = CreateTupleStrategyVector(instruction_id); - strategies->childs.reserve(ins->shape().tuple_shapes_size()); + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve( + ins->shape().tuple_shapes_size()); for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { - std::unique_ptr child_strategies = - CreateLeafStrategyVector(instruction_id, ins, - strategy_map, leaf_strategies); + std::unique_ptr child_strategies = + CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, leaf_strategies); AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i), cluster_env, strategy_map, child_strategies, replicated_penalty); - strategies->childs.push_back(std::move(child_strategies)); + strategy_group->childs.push_back( + std::move(child_strategies)); } } else { - strategies = + strategy_group = CreateAllStrategiesVector( ins, ins->shape(), instruction_id, leaf_strategies, cluster_env, strategy_map, option, replicated_penalty, @@ -2285,13 +2301,13 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } } else { if (only_replicated) { - strategies = CreateLeafStrategyVector( + strategy_group = CreateLeafStrategyGroup( instruction_id, ins, strategy_map, leaf_strategies); AddReplicatedStrategy(ins, ins->shape(), cluster_env, - strategy_map, strategies, + strategy_map, strategy_group, replicated_penalty); 
} else { - strategies = + strategy_group = CreateAllStrategiesVector( ins, ins->shape(), instruction_id, leaf_strategies, cluster_env, strategy_map, option, replicated_penalty, @@ -2303,10 +2319,11 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, if (IsCustomCallMarker(ins)) { const HloInstruction* operand = ins->operand(0); - const StrategyVector* src_strategies = strategy_map.at(operand).get(); - CHECK(src_strategies->is_tuple); - strategies = MaybeFollowInsStrategyVector( - src_strategies, ins->shape(), instruction_id, + const StrategyGroup* src_strategy_group = + strategy_map.at(operand).get(); + CHECK(src_strategy_group->is_tuple); + strategy_group = MaybeFollowInsStrategyGroup( + src_strategy_group, ins->shape(), instruction_id, /* have_memory_cost= */ true, leaf_strategies, cluster_env, pretrimmed_strategy_map); } else if (ins->has_sharding()) { @@ -2318,10 +2335,10 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // Follows operand 0's strategies if this custom-call op is // shardable and has the same input and output sizes. const HloInstruction* operand = ins->operand(0); - const StrategyVector* src_strategies = + const StrategyGroup* src_strategy_group = strategy_map.at(operand).get(); - strategies = MaybeFollowInsStrategyVector( - src_strategies, ins->shape(), instruction_id, + strategy_group = MaybeFollowInsStrategyGroup( + src_strategy_group, ins->shape(), instruction_id, /* have_memory_cost= */ true, leaf_strategies, cluster_env, pretrimmed_strategy_map); } @@ -2334,18 +2351,18 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, break; } case HloOpcode::kWhile: { - strategies = CreateTupleStrategyVector(instruction_id); - strategies->childs.reserve(ins->shape().tuple_shapes_size()); - const StrategyVector* src_strategies = + strategy_group = CreateTupleStrategyGroup(instruction_id); + strategy_group->childs.reserve(ins->shape().tuple_shapes_size()); + const StrategyGroup* src_strategy_group = strategy_map.at(ins->operand(0)).get(); for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { - auto child_strategies = MaybeFollowInsStrategyVector( - src_strategies->childs[i].get(), + auto child_strategies = MaybeFollowInsStrategyGroup( + src_strategy_group->childs[i].get(), ins->shape().tuple_shapes().at(i), instruction_id, /* have_memory_cost= */ true, leaf_strategies, cluster_env, pretrimmed_strategy_map); child_strategies->tuple_element_idx = i; - strategies->childs.push_back(std::move(child_strategies)); + strategy_group->childs.push_back(std::move(child_strategies)); } break; @@ -2353,90 +2370,93 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, case HloOpcode::kConditional: case HloOpcode::kInfeed: case HloOpcode::kSort: { - strategies = CreateAllStrategiesVector( - ins, ins->shape(), instruction_id, leaf_strategies, - cluster_env, strategy_map, option, replicated_penalty, - batch_dim_map, call_graph, only_allow_divisible, - /*create_replicated_strategies*/ true) - .value(); + strategy_group = + CreateAllStrategiesVector( + ins, ins->shape(), instruction_id, leaf_strategies, cluster_env, + strategy_map, option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + /*create_replicated_strategies*/ true) + .value(); break; } case HloOpcode::kOutfeed: { - strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, leaf_strategies); GenerateOutfeedStrategy(ins, ins->shape(), cluster_env, 
strategy_map, - strategies, replicated_penalty); + strategy_group, replicated_penalty); break; } case HloOpcode::kAfterAll: { - strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, + strategy_map, leaf_strategies); AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategies, replicated_penalty); + strategy_group, replicated_penalty); break; } default: LOG(FATAL) << "Unhandled instruction: " + ins->ToString(); } - RemoveDuplicatedStrategy(strategies); + RemoveDuplicatedStrategy(strategy_group); if (ins->has_sharding() && ins->opcode() != HloOpcode::kOutfeed) { // Finds the sharding strategy that aligns with the given sharding spec // Do not merge nodes if this one instruction has annotations. TrimOrGenerateStrategiesBasedOnExistingSharding( - ins->shape(), strategies.get(), strategy_map, instructions, + ins->shape(), strategy_group.get(), strategy_map, instructions, ins->sharding(), cluster_env, pretrimmed_strategy_map, call_graph, option.nd_sharding_iteratively_strict_search_space); } - if (!strategies->is_tuple && strategies->following) { + if (!strategy_group->is_tuple && strategy_group->following) { if (!LeafVectorsAreConsistent( - strategies->leaf_vector, strategies->following->leaf_vector, + strategy_group->leaf_vector, + strategy_group->following->leaf_vector, /*is_reshape*/ ins->opcode() == HloOpcode::kReshape)) { // It confuses the solver if two instructions have different number of // sharding strategies but share the same ILP variable. The solver // would run much longer and/or return infeasible solutions. // So if two strategies' leaf_vectors are inconsistent, we unfollow // them. - strategies->following = nullptr; + strategy_group->following = nullptr; } - } else if (strategies->is_tuple) { - for (size_t i = 0; i < strategies->childs.size(); i++) { - if (strategies->childs.at(i)->following && + } else if (strategy_group->is_tuple) { + for (size_t i = 0; i < strategy_group->childs.size(); i++) { + if (strategy_group->childs.at(i)->following && !LeafVectorsAreConsistent( - strategies->childs.at(i)->leaf_vector, - strategies->childs.at(i)->following->leaf_vector, + strategy_group->childs.at(i)->leaf_vector, + strategy_group->childs.at(i)->following->leaf_vector, /*is_reshape*/ ins->opcode() == HloOpcode::kReshape)) { - strategies->childs.at(i)->following = nullptr; + strategy_group->childs.at(i)->following = nullptr; } } } RemoveInvalidShardingsWithShapes( - ins->shape(), strategies.get(), + ins->shape(), strategy_group.get(), /* instruction_has_user_sharding */ ins->has_sharding()); if (instruction_execution_counts.contains(ins)) { - ScaleCostsWithExecutionCounts(strategies.get(), + ScaleCostsWithExecutionCounts(strategy_group.get(), instruction_execution_counts.at(ins)); } else { VLOG(5) << "No execution count available for " << ins->name(); } - XLA_VLOG_LINES(2, absl::StrCat("strategies:\n", strategies->ToString())); + XLA_VLOG_LINES(2, + absl::StrCat("strategies:\n", strategy_group->ToString())); // Debug options: forcibly set the strategy of some instructions. 
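// Illustration of the override, using example values for the two flags:
//
//   option.force_strategy_inst_indices = {42};        // node_idx to constrain
//   option.force_strategy_stra_names   = {"S0 @ 0"};  // example strategy name
//
// With these settings, node 42 keeps only the leaf strategies whose name is
// "S0 @ 0"; nodes not listed in force_strategy_inst_indices are untouched.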
if (option.force_strategy) { std::vector inst_indices = option.force_strategy_inst_indices; std::vector stra_names = option.force_strategy_stra_names; CHECK_EQ(inst_indices.size(), stra_names.size()); - auto it = absl::c_find(inst_indices, strategies->node_idx); + auto it = absl::c_find(inst_indices, strategy_group->node_idx); if (it != inst_indices.end()) { - CHECK(!strategies->is_tuple); + CHECK(!strategy_group->is_tuple); std::vector new_leaf_vector; int64_t idx = it - inst_indices.begin(); - for (const auto& stra : strategies->leaf_vector) { + for (const auto& stra : strategy_group->leaf_vector) { if (stra.name == stra_names[idx]) { new_leaf_vector.push_back(stra); } } - strategies->leaf_vector = std::move(new_leaf_vector); + strategy_group->leaf_vector = std::move(new_leaf_vector); } } @@ -2447,9 +2467,10 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // the mesh shape we're trying does not match with the mesh shape used in // user specified shardings. So we disable the check in that situation. if (!trying_multiple_mesh_shapes) { - CHECK(strategies->is_tuple || !strategies->leaf_vector.empty()) + CHECK(strategy_group->is_tuple || !strategy_group->leaf_vector.empty()) << ins->ToString() << " does not have any valid strategies."; - } else if (!(strategies->is_tuple || !strategies->leaf_vector.empty())) { + } else if (!(strategy_group->is_tuple || + !strategy_group->leaf_vector.empty())) { return Status(absl::StatusCode::kFailedPrecondition, "Could not generate any shardings for an instruction due " "to mismatched mesh shapes."); @@ -2457,8 +2478,8 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // Checks the shape of resharding_costs is valid. It will check fail if the // shape is not as expected. // CheckReshardingCostsShape(strategies.get()); - CheckMemoryCosts(strategies.get(), ins->shape()); - strategy_map[ins] = std::move(strategies); + CheckMemoryCosts(strategy_group.get(), ins->shape()); + strategy_map[ins] = std::move(strategy_group); } // end of for loop // If gradient accumulation is used, adjust the cost of all-reduce for @@ -2468,7 +2489,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, std::vector grad_insts = GetGradientComputationInstructions(instructions); for (const HloInstruction* inst : grad_insts) { - StrategyVector* stra_vector = strategy_map[inst].get(); + StrategyGroup* stra_vector = strategy_map[inst].get(); CHECK(!stra_vector->is_tuple); for (auto& stra : stra_vector->leaf_vector) { @@ -2523,8 +2544,9 @@ AutoShardingSolverResult CallSolver( // Serialize node costs int num_nodes_without_default = 0; for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { - const StrategyVector* strategies = leaf_strategies[node_idx]; - auto instruction_name = instructions.at(strategies->instruction_id)->name(); + const StrategyGroup* strategy_group = leaf_strategies[node_idx]; + auto instruction_name = + instructions.at(strategy_group->instruction_id)->name(); request.instruction_names.push_back( absl::StrCat(instruction_name, " (id: ", node_idx, ")")); std::vector ci, di, mi, pi; @@ -2533,14 +2555,15 @@ AutoShardingSolverResult CallSolver( if (iter != sharding_propagation_solution.end()) { CHECK(iter->second->has_sharding()) << iter->second->ToString(); default_strategy = iter->second->sharding(); - if (strategies->tuple_element_idx) { + if (strategy_group->tuple_element_idx) { const auto& tuple_elements = iter->second->sharding().tuple_elements(); - CHECK_LT(*strategies->tuple_element_idx, tuple_elements.size()); - 
default_strategy = tuple_elements.at(*strategies->tuple_element_idx); + CHECK_LT(*strategy_group->tuple_element_idx, tuple_elements.size()); + default_strategy = + tuple_elements.at(*strategy_group->tuple_element_idx); } } - for (NodeStrategyIdx j = 0; j < strategies->leaf_vector.size(); ++j) { - const ShardingStrategy& strategy = strategies->leaf_vector[j]; + for (NodeStrategyIdx j = 0; j < strategy_group->leaf_vector.size(); ++j) { + const ShardingStrategy& strategy = strategy_group->leaf_vector[j]; const HloSharding& sharding = strategy.output_sharding; ci.push_back(strategy.compute_cost); di.push_back(strategy.communication_cost + @@ -2551,7 +2574,7 @@ AutoShardingSolverResult CallSolver( if (option.use_sharding_propagation_for_default_shardings && *std::min_element(pi.begin(), pi.end()) > 0) { LOG(WARNING) << "No default strategy for {node_idx " << node_idx - << ", instruction ID " << strategies->instruction_id + << ", instruction ID " << strategy_group->instruction_id << ", instruction name " << instruction_name << "}"; ++num_nodes_without_default; } @@ -2566,14 +2589,16 @@ AutoShardingSolverResult CallSolver( // spec std::vector> new_followers; for (const auto& pair : alias_set) { - const StrategyVector* src_strategies = leaf_strategies[pair.first]; - const StrategyVector* dst_strategies = leaf_strategies[pair.second]; - Matrix raw_cost(src_strategies->leaf_vector.size(), - dst_strategies->leaf_vector.size()); - for (NodeStrategyIdx i = 0; i < src_strategies->leaf_vector.size(); ++i) { - for (NodeStrategyIdx j = 0; j < dst_strategies->leaf_vector.size(); ++j) { - if (src_strategies->leaf_vector[i].output_sharding == - dst_strategies->leaf_vector[j].output_sharding) { + const StrategyGroup* src_strategy_group = leaf_strategies[pair.first]; + const StrategyGroup* dst_strategy_group = leaf_strategies[pair.second]; + Matrix raw_cost(src_strategy_group->leaf_vector.size(), + dst_strategy_group->leaf_vector.size()); + for (NodeStrategyIdx i = 0; i < src_strategy_group->leaf_vector.size(); + ++i) { + for (NodeStrategyIdx j = 0; j < dst_strategy_group->leaf_vector.size(); + ++j) { + if (src_strategy_group->leaf_vector[i].output_sharding == + dst_strategy_group->leaf_vector[j].output_sharding) { raw_cost(i, j) = 0.0; } else { raw_cost(i, j) = 1.0; @@ -2615,9 +2640,9 @@ AutoShardingSolverResult CallSolver( if (vij[i * col_indices.size() + i] != 0.0) convertable = false; } if (convertable && option.allow_alias_to_follower_conversion) { - new_followers.push_back(std::make_pair(idx_a, idx_b)); + new_followers.push_back({idx_a, idx_b}); } else { - request.a.push_back(std::make_pair(idx_a, idx_b)); + request.a.push_back({idx_a, idx_b}); request.v.push_back(vij); } } @@ -2726,7 +2751,7 @@ void CheckHloSharding(const HloInstructionSequence& sequence, std::string str = absl::StrCat("Shardings not consistent (op size ", op_size, " GB):", ins->ToString(), "\n Operand: ", op->ToString()); - size_string.push_back(std::make_pair(op_size, std::move(str))); + size_string.push_back({op_size, std::move(str)}); } } else { LOG(INFO) << "Instruction " << op->name() @@ -2767,34 +2792,34 @@ void SetHloSharding(const HloInstructionSequence& sequence, continue; } - const StrategyVector* strategies = iter->second.get(); - if (strategies->is_tuple) { + const StrategyGroup* strategy_group = iter->second.get(); + if (strategy_group->is_tuple) { const Shape& out_shape = inst->shape(); ShapeTree output_tuple_sharding(out_shape, Undefined()); std::vector output_flattened_shardings; - std::function extract_tuple_shardings; 
+ std::function extract_tuple_shardings; bool set_tuple_sharding = true; - extract_tuple_shardings = [&](const StrategyVector* strategies) { - if (strategies->is_tuple) { - for (const auto& child_strategies : strategies->childs) { + extract_tuple_shardings = [&](const StrategyGroup* strategy_group) { + if (strategy_group->is_tuple) { + for (const auto& child_strategies : strategy_group->childs) { extract_tuple_shardings(child_strategies.get()); } } else { - NodeIdx node_idx = strategies->node_idx; + NodeIdx node_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = s_val[node_idx]; // Do not set completed sharding before the last iteration - if (strategies->leaf_vector[stra_idx] + if (strategy_group->leaf_vector[stra_idx] .output_sharding.IsReplicated() && !last_iteration) { set_tuple_sharding = false; } output_flattened_shardings.push_back( - strategies->leaf_vector[stra_idx].output_sharding); + strategy_group->leaf_vector[stra_idx].output_sharding); } }; - extract_tuple_shardings(strategies); + extract_tuple_shardings(strategy_group); // Create Tuple HloSharding. int i = 0; @@ -3131,20 +3156,20 @@ std::string PrintSolutionMemoryUsage(const LivenessSet& liveness_set, // Print the memory usage std::string str("=== Memory ===\n"); std::vector> time_memory_usage; - // Function that gets the memory usage of a StrategyVector belongs to one + // Function that gets the memory usage of a StrategyGroup belongs to one // tensor. - std::function calculate_memory_usage; - calculate_memory_usage = [&](const StrategyVector* strategies) { - if (strategies->is_tuple) { + std::function calculate_memory_usage; + calculate_memory_usage = [&](const StrategyGroup* strategy_group) { + if (strategy_group->is_tuple) { double m = 0.0; - for (const auto& child : strategies->childs) { + for (const auto& child : strategy_group->childs) { m += calculate_memory_usage(child.get()); } return m; } - NodeIdx ins_idx = strategies->node_idx; + NodeIdx ins_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = cost_graph.RemapIndex(ins_idx, s_val[ins_idx]); - const ShardingStrategy& strategy = strategies->leaf_vector[stra_idx]; + const ShardingStrategy& strategy = strategy_group->leaf_vector[stra_idx]; return strategy.memory_cost; }; for (LivenessIdx time_idx = 0; time_idx < liveness_set.size(); ++time_idx) { @@ -3161,7 +3186,7 @@ std::string PrintSolutionMemoryUsage(const LivenessSet& liveness_set, " MB; mem=", mem / (1024 * 1024), " MB\n"); } } - time_memory_usage.push_back(std::make_pair(time_idx, mem)); + time_memory_usage.push_back({time_idx, mem}); if (VLOG_IS_ON(6)) { absl::StrAppend(&str, "Time ", time_idx, ": ", mem / (1024 * 1024), " MB\n"); @@ -3191,8 +3216,8 @@ std::string PrintSolutionMemoryUsage(const LivenessSet& liveness_set, const HloInstruction* ins = val->instruction(); auto mem = calculate_memory_usage(strategy_map.at(ins).get()); if (mem > 100 * 1024 * 1024) { - instruction_mem.push_back(std::make_pair( - absl::StrCat(ins->name(), val->index().ToString()), mem)); + instruction_mem.push_back( + {absl::StrCat(ins->name(), val->index().ToString()), mem}); } } } @@ -3887,7 +3912,7 @@ void AnnotateShardingWithSimpleHeuristic( // Filter strategies according to the option.force_batch_dim_to_mesh_dim. // This can be used to forcibly generate data-parallel strategies. 
Status FilterStrategy(const HloInstruction* ins, const Shape& shape, - std::unique_ptr& strategies, + std::unique_ptr& strategy_group, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option) { @@ -3902,7 +3927,7 @@ Status FilterStrategy(const HloInstruction* ins, const Shape& shape, } std::vector new_leaf_vector; - for (auto& stra : strategies->leaf_vector) { + for (auto& stra : strategy_group->leaf_vector) { std::vector tensor_dim_to_mesh_dim = cluster_env.GetTensorDimToMeshDimWrapper(shape, stra.output_sharding); @@ -3922,7 +3947,7 @@ Status FilterStrategy(const HloInstruction* ins, const Shape& shape, } CHECK(!new_leaf_vector.empty()) << ins->ToString() << " does not have any valid strategies"; - strategies->leaf_vector = std::move(new_leaf_vector); + strategy_group->leaf_vector = std::move(new_leaf_vector); return OkStatus(); } @@ -4389,10 +4414,10 @@ StatusOr AutoShardingImplementation::RunAutoSharding( const HloInstruction* instruction = value->instruction(); const ShapeIndex& index = value->index(); if (instruction->shape().IsTuple() && index.empty()) continue; - const spmd::StrategyVector* strategies = + const spmd::StrategyGroup* strategy_group = strategy_map.at(instruction).get(); const spmd::NodeIdx node_idx = - strategies->GetSubStrategyVector(index)->node_idx; + strategy_group->GetSubStrategyGroup(index)->node_idx; if (node_idx >= 0) liveness_node_set[t].push_back(node_idx); } } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index 9fb51c270da07a..b7f2ea79aac87f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -135,37 +135,37 @@ HloSharding Tile(const Shape& shape, absl::Span tensor_dims, absl::Span mesh_dims, const Array& device_mesh); -std::vector ReshardingCostVector(const StrategyVector* strategies, +std::vector ReshardingCostVector(const StrategyGroup* strategy_group, const Shape& shape, const HloSharding& required_sharding, const ClusterEnvironment& cluster_env); std::vector FollowInsCostVector(int64_t source_len, int64_t index); -std::unique_ptr CreateLeafStrategyVector( +std::unique_ptr CreateLeafStrategyGroup( size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, LeafStrategies& leaf_strategies); -void SetInNodesWithInstruction(std::unique_ptr& strategies, +void SetInNodesWithInstruction(std::unique_ptr& strategy_group, const HloInstruction* ins, const StrategyMap& strategy_map); -void RemoveDuplicatedStrategy(std::unique_ptr& strategies); +void RemoveDuplicatedStrategy(std::unique_ptr& strategy_group); Status FilterStrategy(const HloInstruction* ins, const Shape& shape, - std::unique_ptr& strategies, + std::unique_ptr& strategy_group, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option); -Status HandleDot(std::unique_ptr& strategies, +Status HandleDot(std::unique_ptr& strategy_group, LeafStrategies& leaf_strategies, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); -Status HandleConv(std::unique_ptr& strategies, +Status HandleConv(std::unique_ptr& strategy_group, LeafStrategies& leaf_strategies, StrategyMap& strategy_map, const HloInstruction* ins, 
size_t instruction_id, const ClusterEnvironment& cluster_env, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index 6102683fe4fdf4..cf24b493f4b1d1 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -24,9 +24,12 @@ limitations under the License. #include #include "absl/log/check.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/matrix.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/shape_util.h" namespace xla { namespace spmd { @@ -113,12 +116,12 @@ class CostGraph { } Matrix CreateEdgeCost(NodeIdx src_idx, NodeIdx dst_idx, size_t in_node_idx, - StrategyVector* strategies, bool zero_cost = false) { + StrategyGroup* strategy_group, bool zero_cost = false) { CHECK_LT(src_idx, node_lens_.size()); CHECK_LT(dst_idx, node_lens_.size()); Matrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]); - for (NodeStrategyIdx k = 0; k < strategies->leaf_vector.size(); ++k) { - const ShardingStrategy& strategy = strategies->leaf_vector[k]; + for (NodeStrategyIdx k = 0; k < strategy_group->leaf_vector.size(); ++k) { + const ShardingStrategy& strategy = strategy_group->leaf_vector[k]; size_t start_idx = 0; if (strategy.resharding_costs[in_node_idx].size() > node_lens_[src_idx]) { start_idx = @@ -359,11 +362,11 @@ class CostGraph { inline const ShardingStrategy& GetShardingStrategy( const HloInstruction* inst, const StrategyMap& strategy_map, const CostGraph& cost_graph, absl::Span s_val) { - const StrategyVector* strategies = strategy_map.at(inst).get(); - CHECK(!strategies->is_tuple); - NodeIdx node_idx = strategies->node_idx; + const StrategyGroup* strategy_group = strategy_map.at(inst).get(); + CHECK(!strategy_group->is_tuple); + NodeIdx node_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]); - return strategies->leaf_vector[stra_idx]; + return strategy_group->leaf_vector[stra_idx]; } // Get the final sharding strategy according to the ilp solution. 
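// Lookup chain used by this helper and the tuple variant below (sketch):
//
//   NodeIdx node_idx          = strategy_group->node_idx;
//   NodeStrategyIdx stra_idx  = cost_graph.RemapIndex(node_idx, s_val[node_idx]);
//   const ShardingStrategy& s = strategy_group->leaf_vector[stra_idx];
//
// RemapIndex translates the solver's decision for a node that follows another
// node back into an index into this node's own leaf_vector.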
@@ -371,16 +374,16 @@ inline const ShardingStrategy& GetShardingStrategyForTuple( const HloInstruction* inst, ShapeIndex index, const StrategyMap& strategy_map, const CostGraph& cost_graph, absl::Span s_val) { - const StrategyVector* tuple_strategies = strategy_map.at(inst).get(); - CHECK(tuple_strategies->is_tuple); + const StrategyGroup* strategy_group = strategy_map.at(inst).get(); + CHECK(strategy_group->is_tuple); for (auto index_element : index) { - CHECK_LT(index_element, tuple_strategies->childs.size()); - const auto& strategies = tuple_strategies->childs[index_element]; - tuple_strategies = strategies.get(); + CHECK_LT(index_element, strategy_group->childs.size()); + const auto& strategies = strategy_group->childs[index_element]; + strategy_group = strategies.get(); } - NodeIdx node_idx = tuple_strategies->node_idx; + NodeIdx node_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]); - return tuple_strategies->leaf_vector[stra_idx]; + return strategy_group->leaf_vector[stra_idx]; } } // namespace spmd diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index a006072cc0a9f8..99390479414960 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -57,12 +57,12 @@ struct Enumeration { // Contains base functionality common to both DotHandler and ConvHandler. class HandlerBase { protected: - HandlerBase(std::unique_ptr& strategies, + HandlerBase(std::unique_ptr& strategy_group, StrategyMap& strategy_map, const HloInstruction* ins, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) - : strategies_(strategies), + : strategy_group_(strategy_group), strategy_map_(strategy_map), ins_(ins), cluster_env_(cluster_env), @@ -87,7 +87,7 @@ class HandlerBase { operand->shape(), input_specs[i], cluster_env_)); } - strategies_->leaf_vector.push_back(ShardingStrategy({ + strategy_group_->leaf_vector.push_back(ShardingStrategy({ name, output_spec, compute_cost, @@ -221,7 +221,7 @@ class HandlerBase { Enumerate(split_func, num_outer_dims, num_inner_dims, true); } - std::unique_ptr& strategies_; + std::unique_ptr& strategy_group_; StrategyMap& strategy_map_; const HloInstruction* ins_; const ClusterEnvironment& cluster_env_; @@ -237,12 +237,12 @@ class HandlerBase { class DotHandler : public HandlerBase { public: - DotHandler(std::unique_ptr& strategies, + DotHandler(std::unique_ptr& strategy_group, StrategyMap& strategy_map, const HloInstruction* ins, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) - : HandlerBase(strategies, strategy_map, ins, cluster_env, batch_map, + : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, option, call_graph), space_base_dim_( ins->dot_dimension_numbers().lhs_batch_dimensions_size()), @@ -601,7 +601,7 @@ class DotHandler : public HandlerBase { cluster_env_.non_zero_mesh_dims_.size() > 1) { // If there is a batch dim and the device mesh is multi-dimensional, // always split on batch dim. Clear all old strategies. 
- strategies_->leaf_vector.clear(); + strategy_group_->leaf_vector.clear(); } // Sb = Sb x Sb @@ -626,7 +626,7 @@ class DotHandler : public HandlerBase { [](int64_t size) { return size > 1; }) > 1) { // If there are two batch dims, always split on these two dims. // Clear all old strategies. - strategies_->leaf_vector.clear(); + strategy_group_->leaf_vector.clear(); } // Sb = Sb x Sb @@ -641,7 +641,7 @@ class DotHandler : public HandlerBase { // and only keep the data parallel strategies. if (option_.force_batch_dim_to_mesh_dim >= 0 && batch_map_.contains(GetBatchDimMapKey(ins_))) { - TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), strategies_, + TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), strategy_group_, cluster_env_, batch_map_, option_)); } @@ -658,17 +658,17 @@ class DotHandler : public HandlerBase { }; // Register strategies for dot instructions. -Status HandleDot(std::unique_ptr& strategies, +Status HandleDot(std::unique_ptr& strategy_group, LeafStrategies& leaf_strategies, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) { - strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, + leaf_strategies); - DotHandler handler(strategies, strategy_map, ins, cluster_env, batch_map, + DotHandler handler(strategy_group, strategy_map, ins, cluster_env, batch_map, option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); return OkStatus(); @@ -676,12 +676,12 @@ Status HandleDot(std::unique_ptr& strategies, class ConvHandler : public HandlerBase { public: - ConvHandler(std::unique_ptr& strategies, + ConvHandler(std::unique_ptr& strategy_group, StrategyMap& strategy_map, const HloInstruction* ins, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) - : HandlerBase(strategies, strategy_map, ins, cluster_env, batch_map, + : HandlerBase(strategy_group, strategy_map, ins, cluster_env, batch_map, option, call_graph), conv_dnums_(ins->convolution_dimension_numbers()) { lhs_batch_dim_ = conv_dnums_.input_batch_dimension(); @@ -837,7 +837,7 @@ class ConvHandler : public HandlerBase { // and only keep the data parallel strategies. if (option_.force_batch_dim_to_mesh_dim >= 0 && batch_map_.contains(GetBatchDimMapKey(ins_))) { - TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), strategies_, + TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), strategy_group_, cluster_env_, batch_map_, option_)); } @@ -852,17 +852,17 @@ class ConvHandler : public HandlerBase { }; // Register strategies for dot instructions. 
-Status HandleConv(std::unique_ptr& strategies, +Status HandleConv(std::unique_ptr& strategy_group, LeafStrategies& leaf_strategies, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) { - strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, + leaf_strategies); - ConvHandler handler(strategies, strategy_map, ins, cluster_env, batch_map, + ConvHandler handler(strategy_group, strategy_map, ins, cluster_env, batch_map, option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); return OkStatus(); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h index cd034821d8f822..2ca280330dca7f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h @@ -130,8 +130,9 @@ using EdgeStrategyIdx = int64_t; // An index into an edge's strategy vector. using LivenessIdx = int64_t; // An index into the liveness vector. using AliasIdx = int64_t; // An index into the alias vector. -// The strategy choices for each instruction. -struct StrategyVector { +// A group of strategy choices (along with details like index values) +// for each instruction. +struct StrategyGroup { bool is_tuple; // The index used in the solver. For non-leaf nodes, this is set to -1. NodeIdx node_idx; @@ -141,15 +142,15 @@ struct StrategyVector { // The size must be the same as the size of resharding cost // each element in leaf_vector's resharding_costs.size() needs to be the same // as strategies->in_nodes.size() - std::vector in_nodes; + std::vector in_nodes; // The followed strategy. Used for merging nodes. - const StrategyVector* following = nullptr; + const StrategyGroup* following = nullptr; // Used when is_tuple == False. Leaf strategy vector. // A vector of strategy choices for the non-tuple output. std::vector leaf_vector; // Used when is_tuple == True. A vector of pointers, each pointer is one - // StrategyVector for one value in the output Tuple - std::vector> childs; + // StrategyGroup for one value in the output Tuple + std::vector> childs; // The index of this instruction in the HLO operand (or tuple shape) list. std::optional tuple_element_idx; @@ -183,8 +184,8 @@ struct StrategyVector { return str; } - const StrategyVector* GetSubStrategyVector(const ShapeIndex& index) const { - const StrategyVector* result = this; + const StrategyGroup* GetSubStrategyGroup(const ShapeIndex& index) const { + const StrategyGroup* result = this; for (auto index_element : index) { CHECK_LE(index_element, result->childs.size()); result = result->childs.at(index_element).get(); @@ -199,13 +200,13 @@ using LivenessSet = std::vector>; using LivenessNodeSet = std::vector>; // Map an instruction to its strategy vector. using StrategyMap = - StableHashMap>; + StableHashMap>; // The list of all leaf strategies. -using LeafStrategies = std::vector; +using LeafStrategies = std::vector; // The list of all dot instruction pairs that can be optimized by // AllReduceReassociate pass. 
using AssociativeDotPairs = - std::vector>; + std::vector>; // The set of all alias pairs using AliasSet = StableHashSet>; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index ad4f3bc35cdae3..fe94eaf695a909 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -55,6 +55,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/xla_data.pb.h" #include "tsl/platform/status.h" @@ -874,40 +875,42 @@ bool AllInfinityCosts( // that were not intended to be replicated when being generating, but ending up // being replicated, which could happen when, for example, generating 2D // sharding for a 1D mesh shape. -void RemoveDuplicatedStrategy(std::unique_ptr& strategies) { - if (strategies->is_tuple) { - for (auto& child : strategies->childs) { +void RemoveDuplicatedStrategy(std::unique_ptr& strategy_group) { + if (strategy_group->is_tuple) { + for (auto& child : strategy_group->childs) { RemoveDuplicatedStrategy(child); } - } else if (!strategies->following) { - if (strategies->leaf_vector.empty()) return; + } else if (!strategy_group->following) { + if (strategy_group->leaf_vector.empty()) return; std::vector new_vector; std::vector deduped_replicated_strategies; absl::flat_hash_set added; size_t num_skipped_due_to_infinity_costs = 0; - for (size_t i = 0; i < strategies->leaf_vector.size(); ++i) { - if (AllInfinityCosts(strategies->leaf_vector[i].resharding_costs)) { + for (size_t i = 0; i < strategy_group->leaf_vector.size(); ++i) { + if (AllInfinityCosts(strategy_group->leaf_vector[i].resharding_costs)) { num_skipped_due_to_infinity_costs++; continue; } - std::string key = strategies->leaf_vector[i].output_sharding.ToString(); - if (!strategies->leaf_vector[i].input_shardings.empty()) { + std::string key = + strategy_group->leaf_vector[i].output_sharding.ToString(); + if (!strategy_group->leaf_vector[i].input_shardings.empty()) { for (const auto& sharding : - strategies->leaf_vector[i].input_shardings) { + strategy_group->leaf_vector[i].input_shardings) { key += "/" + (sharding.has_value() ? sharding->ToString() : "none"); } } if (!added.contains(key)) { added.insert(key); - if (!strategies->leaf_vector[i].output_sharding.IsReplicated()) { - new_vector.push_back(std::move(strategies->leaf_vector[i])); + if (!strategy_group->leaf_vector[i].output_sharding.IsReplicated()) { + new_vector.push_back(std::move(strategy_group->leaf_vector[i])); } else { deduped_replicated_strategies.push_back( - std::move(strategies->leaf_vector[i])); + std::move(strategy_group->leaf_vector[i])); } } } - CHECK_LT(num_skipped_due_to_infinity_costs, strategies->leaf_vector.size()) + CHECK_LT(num_skipped_due_to_infinity_costs, + strategy_group->leaf_vector.size()) << "All strategies removed due to infinite resharding costs"; // Keeps replicated strategies as the last ones. 
if (!deduped_replicated_strategies.empty()) { @@ -915,7 +918,7 @@ void RemoveDuplicatedStrategy(std::unique_ptr& strategies) { new_vector.push_back(std::move(deduped_replicated_strategies[i])); } } - strategies->leaf_vector = std::move(new_vector); + strategy_group->leaf_vector = std::move(new_vector); } } @@ -1742,20 +1745,21 @@ AliasSet BuildAliasSet(const HloModule* module, const HloInstruction* output_tuple = entry->root_instruction(); AliasSet alias_set; - std::function + std::function traverse_tuple_alias; - traverse_tuple_alias = [&](const StrategyVector* src_strategies, - const StrategyVector* dst_strategies) { - if (src_strategies->is_tuple) { - CHECK(dst_strategies->is_tuple); - CHECK_EQ(src_strategies->childs.size(), dst_strategies->childs.size()); - for (size_t i = 0; i < src_strategies->childs.size(); ++i) { - traverse_tuple_alias(src_strategies->childs[i].get(), - dst_strategies->childs[i].get()); + traverse_tuple_alias = [&](const StrategyGroup* src_strategy_group, + const StrategyGroup* dst_strategy_group) { + if (src_strategy_group->is_tuple) { + CHECK(dst_strategy_group->is_tuple); + CHECK_EQ(src_strategy_group->childs.size(), + dst_strategy_group->childs.size()); + for (size_t i = 0; i < src_strategy_group->childs.size(); ++i) { + traverse_tuple_alias(src_strategy_group->childs[i].get(), + dst_strategy_group->childs[i].get()); } } else { alias_set.insert( - std::make_pair(src_strategies->node_idx, dst_strategies->node_idx)); + {src_strategy_group->node_idx, dst_strategy_group->node_idx}); } }; alias_config.ForEachAlias([&](const ShapeIndex& output_index, @@ -1825,17 +1829,18 @@ void CheckAliasSetCompatibility(const AliasSet& alias_set, const std::vector& instructions = sequence.instructions(); // Checks the compatibility for (const auto& pair : alias_set) { - const StrategyVector* src_strategies = leaf_strategies[pair.first]; - const StrategyVector* dst_strategies = leaf_strategies[pair.second]; + const StrategyGroup* src_strategy_group = leaf_strategies[pair.first]; + const StrategyGroup* dst_strategy_group = leaf_strategies[pair.second]; size_t compatible_cnt = 0; bool replicated = false; - for (size_t i = 0; i < src_strategies->leaf_vector.size(); ++i) { - for (size_t j = 0; j < dst_strategies->leaf_vector.size(); ++j) { - if (src_strategies->leaf_vector[i].output_sharding == - dst_strategies->leaf_vector[j].output_sharding) { + for (size_t i = 0; i < src_strategy_group->leaf_vector.size(); ++i) { + for (size_t j = 0; j < dst_strategy_group->leaf_vector.size(); ++j) { + if (src_strategy_group->leaf_vector[i].output_sharding == + dst_strategy_group->leaf_vector[j].output_sharding) { compatible_cnt += 1; - if (src_strategies->leaf_vector[i].output_sharding.IsReplicated()) { + if (src_strategy_group->leaf_vector[i] + .output_sharding.IsReplicated()) { replicated = true; } } @@ -1843,32 +1848,31 @@ void CheckAliasSetCompatibility(const AliasSet& alias_set, } if (compatible_cnt == 1 && - (replicated && (src_strategies->leaf_vector.size() > 1 || - dst_strategies->leaf_vector.size() > 1))) { - LOG(WARNING) << "Alias pair has only replicated strategy in common. 
This " - "will result in choosing replicated strategy for these " - "tensors and may result in large memory consumption: " - << "(" - << instructions.at(src_strategies->instruction_id)->name() - << ", " - << instructions.at(dst_strategies->instruction_id)->name() - << ")" - << "\n" - << "(" << src_strategies->node_idx << ", " - << dst_strategies->node_idx << ")\n" - << src_strategies->ToString() << "\n" - << dst_strategies->ToString(); + (replicated && (src_strategy_group->leaf_vector.size() > 1 || + dst_strategy_group->leaf_vector.size() > 1))) { + LOG(WARNING) + << "Alias pair has only replicated strategy in common. This " + "will result in choosing replicated strategy for these " + "tensors and may result in large memory consumption: " + << "(" << instructions.at(src_strategy_group->instruction_id)->name() + << ", " << instructions.at(dst_strategy_group->instruction_id)->name() + << ")" + << "\n" + << "(" << src_strategy_group->node_idx << ", " + << dst_strategy_group->node_idx << ")\n" + << src_strategy_group->ToString() << "\n" + << dst_strategy_group->ToString(); } CHECK(compatible_cnt > 0) << "Alias pair does not have any sharding strategy in common: " - << "(" << instructions.at(src_strategies->instruction_id)->name() - << ", " << instructions.at(dst_strategies->instruction_id)->name() + << "(" << instructions.at(src_strategy_group->instruction_id)->name() + << ", " << instructions.at(dst_strategy_group->instruction_id)->name() << ")" << "\n" - << "(" << src_strategies->node_idx << ", " << dst_strategies->node_idx - << ")\n" - << src_strategies->ToString() << "\n" - << dst_strategies->ToString(); + << "(" << src_strategy_group->node_idx << ", " + << dst_strategy_group->node_idx << ")\n" + << src_strategy_group->ToString() << "\n" + << dst_strategy_group->ToString(); } } @@ -2015,7 +2019,7 @@ AdjustShardingWithPartialMeshShapePerElement( LOG(FATAL) << err_msg; } else { LOG(WARNING) << err_msg; - return std::make_pair(absl::InternalError(err_msg), std::nullopt); + return {absl::InternalError(err_msg), std::nullopt}; } } } @@ -2028,7 +2032,7 @@ AdjustShardingWithPartialMeshShapePerElement( if (valid_shards.find(sharding.tile_assignment().dim( sharding.tile_assignment().num_dimensions() - 1)) != valid_shards.end()) { - return std::make_pair(OkStatus(), HloSharding::Replicate()); + return {OkStatus(), HloSharding::Replicate()}; } // If replicate on other dimensions, remove the // replicate_on_last_tile @@ -2075,9 +2079,9 @@ AdjustShardingWithPartialMeshShapePerElement( std::iota(device_ids.begin(), device_ids.end(), 0); tile_assignment.SetValues(device_ids); HloSharding new_sharding = HloSharding::Tile(std::move(tile_assignment)); - return std::make_pair(OkStatus(), new_sharding); + return {OkStatus(), new_sharding}; } - return std::make_pair(OkStatus(), std::nullopt); + return {OkStatus(), std::nullopt}; } StatusOr AdjustShardingsWithPartialMeshShape( @@ -2144,7 +2148,7 @@ std::vector> DecomposeMeshShapes( std::vector> partial_mesh_shapes; std::vector> pairs(mesh_shape.size()); for (size_t i = 0; i < mesh_shape.size(); i++) { - pairs[i] = std::make_pair(mesh_shape[i], i); + pairs[i] = {mesh_shape[i], i}; } // For vector of size 3, the sorted indices happen to be the same as their // rankings. mesh_shapes over 3 elements are not supported by AutoSharding. From 6443660d1a24bea25e53aebc6accb1b2c4986581 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 15 Nov 2023 22:00:27 -0800 Subject: [PATCH 158/391] Renames the list of strategy choices to simply 'strategies' (instead of the original 'leaf_vector') PiperOrigin-RevId: 582908221 --- .../auto_sharding/auto_sharding.cc | 246 +++++++++--------- .../auto_sharding/auto_sharding_cost_graph.h | 20 +- .../auto_sharding_dot_handler.cc | 6 +- .../auto_sharding/auto_sharding_strategy.h | 10 +- .../auto_sharding/auto_sharding_util.cc | 36 +-- 5 files changed, 157 insertions(+), 161 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 07d97f4d92a180..30fb3f2a9ba6e6 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -95,11 +95,11 @@ std::vector ReshardingCostVector( const ClusterEnvironment& cluster_env) { CHECK(!strategy_group->is_tuple) << "Only works with strategy vector."; std::vector ret; - ret.reserve(strategy_group->leaf_vector.size()); + ret.reserve(strategy_group->strategies.size()); auto required_sharding_for_resharding = required_sharding.IsTileMaximal() ? HloSharding::Replicate() : required_sharding; - for (const auto& x : strategy_group->leaf_vector) { + for (const auto& x : strategy_group->strategies) { ret.push_back(cluster_env.ReshardingCost(operand_shape, x.output_sharding, required_sharding_for_resharding)); } @@ -161,7 +161,7 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( auto operand = ins->operand(k); if (operand->shape().IsToken() || operand->shape().rank() == 0) { resharding_costs.push_back(std::vector( - strategy_map.at(operand)->leaf_vector.size(), 0.0)); + strategy_map.at(operand)->strategies.size(), 0.0)); if (!input_shardings[k].has_value()) { input_shardings[k] = HloSharding::Replicate(); } @@ -201,7 +201,7 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( VLOG(2) << "Zeroing out operand 0 resharding costs for gather sharding " << output_sharding.ToString(); resharding_costs.push_back( - std::vector(operand_strategies->leaf_vector.size(), 0)); + std::vector(operand_strategies->strategies.size(), 0)); input_shardings[k] = std::nullopt; } else { resharding_costs.push_back( @@ -261,21 +261,21 @@ std::unique_ptr MaybeFollowInsStrategyGroup( if (!pretrimmed_strategy_map.contains(src_strategy_group->node_idx)) { strategy_group->following = src_strategy_group; } - strategy_group->leaf_vector.reserve(src_strategy_group->leaf_vector.size()); + strategy_group->strategies.reserve(src_strategy_group->strategies.size()); // Creates the sharding strategies and restores the trimmed strategies if // there is any. 
for (int64_t sid = 0; - sid < src_strategy_group->leaf_vector.size() + + sid < src_strategy_group->strategies.size() + pretrimmed_strategy_map[src_strategy_group->node_idx].size(); ++sid) { const HloSharding* output_spec; - if (sid < src_strategy_group->leaf_vector.size()) { - output_spec = &src_strategy_group->leaf_vector[sid].output_sharding; + if (sid < src_strategy_group->strategies.size()) { + output_spec = &src_strategy_group->strategies[sid].output_sharding; } else { output_spec = &pretrimmed_strategy_map[src_strategy_group->node_idx] [sid - - src_strategy_group->leaf_vector.size()] + src_strategy_group->strategies.size()] .output_sharding; VLOG(1) << "Adding outspec from the trimmed strategy map: " << output_spec->ToString(); @@ -286,7 +286,7 @@ std::unique_ptr MaybeFollowInsStrategyGroup( have_memory_cost ? GetBytes(shape) / output_spec->NumTiles() : 0; auto resharding_costs = ReshardingCostVector(src_strategy_group, shape, *output_spec, cluster_env); - strategy_group->leaf_vector.push_back( + strategy_group->strategies.push_back( ShardingStrategy({name, *output_spec, compute_cost, @@ -328,7 +328,7 @@ StatusOr> FollowReduceStrategy( const StrategyGroup* src_strategy_group = strategy_map.at(operand).get(); // Follows the strategy of the operand. strategy_group->following = src_strategy_group; - strategy_group->leaf_vector.reserve(src_strategy_group->leaf_vector.size()); + strategy_group->strategies.reserve(src_strategy_group->strategies.size()); // Map operand dims to inst dim // Example: f32[1,16]{1,0} reduce(f32[1,16,4096]{2,1,0} %param0, f32[] // %param1), dimensions={2} @@ -340,9 +340,9 @@ StatusOr> FollowReduceStrategy( operand->shape().rank()) << "Invalid kReduce: output size + reduced dimensions size != op count"; - for (size_t sid = 0; sid < src_strategy_group->leaf_vector.size(); ++sid) { + for (size_t sid = 0; sid < src_strategy_group->strategies.size(); ++sid) { HloSharding input_sharding = - src_strategy_group->leaf_vector[sid].output_sharding; + src_strategy_group->strategies[sid].output_sharding; const auto& tensor_dim_to_mesh = cluster_env.GetTensorDimToMeshDimWrapper( operand->shape(), input_sharding, /* consider_reverse_device_meshes */ true, @@ -371,7 +371,7 @@ StatusOr> FollowReduceStrategy( output_shape, operand_clone.get(), unit_clone.get(), ins->dimensions(), ins->to_apply()); operand_clone->set_sharding( - src_strategy_group->leaf_vector[sid].output_sharding); + src_strategy_group->strategies[sid].output_sharding); auto s = new_reduce->ReplaceOperandWith(0, operand_clone.get()); if (!s.ok()) { continue; @@ -402,7 +402,7 @@ StatusOr> FollowReduceStrategy( operand_strategies, output_shape, input_sharding, cluster_env)); } else { resharding_costs.push_back(std::vector( - strategy_map.at(cur_operand)->leaf_vector.size(), 0.0)); + strategy_map.at(cur_operand)->strategies.size(), 0.0)); } } const ShardingStrategy strategy = ShardingStrategy({name, @@ -412,7 +412,7 @@ StatusOr> FollowReduceStrategy( memory_cost, resharding_costs, {input_sharding}}); - strategy_group->leaf_vector.push_back(strategy); + strategy_group->strategies.push_back(strategy); } } else { LOG(FATAL) << "Unhandled kReduce shape: " << ins->shape().ToString(); @@ -446,13 +446,13 @@ ReshardingCostsForTupleOperand(const HloInstruction* operand, auto tuple_element_strategies = operand_strategy_vector->childs.at(tuple_element_idx).get(); std::vector indices = - FindReplicateStrategyIndices(tuple_element_strategies->leaf_vector); + FindReplicateStrategyIndices(tuple_element_strategies->strategies); 
CHECK_GT(indices.size(), 0) << "There is no replicated strategy in instruction " << operand->ToString() << ".\nStrategies:\n" << tuple_element_strategies->ToString(); resharding_costs.push_back(std::vector( - tuple_element_strategies->leaf_vector.size(), kInfinityCost)); + tuple_element_strategies->strategies.size(), kInfinityCost)); tuple_element_shardings.push_back(HloSharding::Replicate()); for (const size_t i : indices) { resharding_costs.back().at(i) = 0.0; @@ -483,12 +483,12 @@ std::vector> CreateZeroReshardingCostsForAllOperands( auto tuple_element_strategies = operand_strategies->childs.at(tuple_element_idx).get(); resharding_costs.push_back(std::vector( - tuple_element_strategies->leaf_vector.size(), 0)); + tuple_element_strategies->strategies.size(), 0)); } } } else { resharding_costs.push_back( - std::vector(operand_strategies->leaf_vector.size(), 0)); + std::vector(operand_strategies->strategies.size(), 0)); } } return resharding_costs; @@ -535,13 +535,13 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, } else { for (size_t i = 0; i < tuple_size; ++i) { resharding_costs.push_back(std::vector( - strategy_map.at(ins->operand(0))->childs[i].get()->leaf_vector.size(), + strategy_map.at(ins->operand(0))->childs[i].get()->strategies.size(), 0)); } } resharding_costs.push_back({}); double memory_cost = GetBytes(shape) / output_spec.NumTiles(); - strategy_group->leaf_vector.push_back(ShardingStrategy( + strategy_group->strategies.push_back(ShardingStrategy( {"R", HloSharding::Replicate(), replicated_penalty, 0, memory_cost, std::move(resharding_costs), input_shardings})); } @@ -602,23 +602,23 @@ void AddReplicatedStrategy( auto operand_strategies_to_consider = strategy_map.at(operand).get(); std::vector>> possible_input_shardings( - operand_strategies_to_consider->leaf_vector.size(), + operand_strategies_to_consider->strategies.size(), std::vector>(ins->operand_count())); std::vector>> possible_resharding_costs( - operand_strategies_to_consider->leaf_vector.size(), + operand_strategies_to_consider->strategies.size(), std::vector>(ins->operand_count())); for (int64_t k = 0; k < ins->operand_count(); ++k) { CHECK(!ins->operand(k)->shape().IsTuple()); if (k == operand_to_consider_all_strategies_for) { CHECK_EQ(possible_input_shardings.size(), - operand_strategies_to_consider->leaf_vector.size()); + operand_strategies_to_consider->strategies.size()); for (size_t j = 0; j < possible_input_shardings.size(); ++j) { possible_input_shardings[j][k] = - operand_strategies_to_consider->leaf_vector[j].output_sharding; + operand_strategies_to_consider->strategies[j].output_sharding; possible_resharding_costs[j][k] = ReshardingCostVector( strategy_map.at(ins->operand(k)).get(), ins->operand(k)->shape(), - operand_strategies_to_consider->leaf_vector[j].output_sharding, + operand_strategies_to_consider->strategies[j].output_sharding, cluster_env); } } else { @@ -634,7 +634,7 @@ void AddReplicatedStrategy( for (size_t j = 0; j < possible_input_shardings.size(); ++j) { double communication_cost = ComputeCommunicationCost( ins, possible_input_shardings[j], cluster_env); - strategy_group->leaf_vector.push_back(ShardingStrategy( + strategy_group->strategies.push_back(ShardingStrategy( {"R", replicated_strategy, replicated_penalty, communication_cost, memory_cost, std::move(possible_resharding_costs[j]), std::move(possible_input_shardings[j])})); @@ -656,7 +656,7 @@ void AddReplicatedStrategy( auto operand = ins->operand(k); if (ins->opcode() == HloOpcode::kConditional) { 
resharding_costs.push_back(std::vector( - strategy_map.at(operand)->leaf_vector.size(), 0)); + strategy_map.at(operand)->strategies.size(), 0)); } else { resharding_costs.push_back(ReshardingCostVector( strategy_map.at(operand).get(), ins->operand(k)->shape(), @@ -665,7 +665,7 @@ void AddReplicatedStrategy( } } } - strategy_group->leaf_vector.push_back(ShardingStrategy( + strategy_group->strategies.push_back(ShardingStrategy( {"R", HloSharding::Replicate(), replicated_penalty, 0, memory_cost, std::move(resharding_costs), input_shardings})); } @@ -744,7 +744,7 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, communication_cost = ComputeSortCommunicationCost( ins->operand(0)->shape().rank() - 1, i, j, shape, cluster_env); } - strategy_group->leaf_vector.push_back(ShardingStrategy( + strategy_group->strategies.push_back(ShardingStrategy( {name, output_spec, compute_cost, communication_cost, memory_cost, std::move(resharding_costs), input_shardings})); } @@ -861,7 +861,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, } } } - strategy_group->leaf_vector.push_back(ShardingStrategy( + strategy_group->strategies.push_back(ShardingStrategy( {name, output_spec, compute_cost, communication_cost, memory_cost, std::move(resharding_costs), input_shardings})); } @@ -903,7 +903,7 @@ void EnumerateAll1DPartitionReshape( std::vector> resharding_costs{ ReshardingCostVector(strategy_map.at(operand).get(), operand->shape(), *input_spec, cluster_env)}; - strategy_group->leaf_vector.push_back( + strategy_group->strategies.push_back( ShardingStrategy({name, output_spec, compute_cost, @@ -991,7 +991,7 @@ void BuildStrategyAndCostForReshape( std::vector> resharding_costs{ ReshardingCostVector(strategy_map.at(operand).get(), operand->shape(), *input_spec, cluster_env)}; - strategy_group->leaf_vector.push_back( + strategy_group->strategies.push_back( ShardingStrategy({name, output_spec, compute_cost, @@ -1010,10 +1010,10 @@ int64_t MaxNumTiles(const StrategyMap& strategy_map, strategy_group = strategy_group->following; } int64_t max_num_tiles = -1; - for (size_t i = 0; i < strategy_group->leaf_vector.size(); ++i) { + for (size_t i = 0; i < strategy_group->strategies.size(); ++i) { max_num_tiles = std::max(max_num_tiles, - strategy_group->leaf_vector[i].output_sharding.NumTiles()); + strategy_group->strategies[i].output_sharding.NumTiles()); } return max_num_tiles; @@ -1163,8 +1163,8 @@ StatusOr> CreateAllStrategiesVector( if (option.allow_mixed_mesh_shape && cluster_env.IsDeviceMesh2D()) { // Set penalty for 1d partial tiled layout - for (size_t i = 0; i < strategy_group->leaf_vector.size(); ++i) { - strategy_group->leaf_vector[i].compute_cost += replicated_penalty * 0.8; + for (size_t i = 0; i < strategy_group->strategies.size(); ++i) { + strategy_group->strategies[i].compute_cost += replicated_penalty * 0.8; } // Split 1 dim, but for 1d mesh @@ -1172,7 +1172,7 @@ StatusOr> CreateAllStrategiesVector( cluster_env, strategy_map, strategy_group, only_allow_divisible, " 1d", call_graph); } - if (create_replicated_strategies || strategy_group->leaf_vector.empty()) { + if (create_replicated_strategies || strategy_group->strategies.empty()) { AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, strategy_group, replicated_penalty); } @@ -1269,8 +1269,8 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( // Sharding provided by XLA users, we need to keep them. 
strategy_group->following = nullptr; std::vector strategy_indices; - for (size_t i = 0; i < strategy_group->leaf_vector.size(); i++) { - if (strategy_group->leaf_vector[i].output_sharding == + for (size_t i = 0; i < strategy_group->strategies.size(); i++) { + if (strategy_group->strategies[i].output_sharding == existing_sharding) { strategy_indices.push_back(i); } @@ -1281,15 +1281,15 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( // Stores other strategies in the map, removes them in the vector and // only keeps the one we found. pretrimmed_strategy_map[strategy_group->node_idx] = - strategy_group->leaf_vector; - std::vector new_leaf_vector; + strategy_group->strategies; + std::vector new_strategies; for (int32_t found_strategy_index : strategy_indices) { ShardingStrategy found_strategy = - strategy_group->leaf_vector[found_strategy_index]; - new_leaf_vector.push_back(found_strategy); + strategy_group->strategies[found_strategy_index]; + new_strategies.push_back(found_strategy); } - strategy_group->leaf_vector.clear(); - strategy_group->leaf_vector = new_leaf_vector; + strategy_group->strategies.clear(); + strategy_group->strategies = new_strategies; } else { VLOG(1) << "Generate a new strategy based on user sharding."; std::string name = ToStringSimple(existing_sharding); @@ -1326,12 +1326,12 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( } double memory_cost = GetBytes(output_shape) / existing_sharding.NumTiles(); - if (!strategy_group->leaf_vector.empty()) { + if (!strategy_group->strategies.empty()) { pretrimmed_strategy_map[strategy_group->node_idx] = - strategy_group->leaf_vector; + strategy_group->strategies; } - strategy_group->leaf_vector.clear(); - strategy_group->leaf_vector.push_back( + strategy_group->strategies.clear(); + strategy_group->strategies.push_back( ShardingStrategy({name, existing_sharding, 0, 0, memory_cost, resharding_costs, input_shardings})); } @@ -1339,9 +1339,9 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( // that option is kInfinityCost, set the cost to zero. This is okay // because there is only one option anyway, and having the costs set to // kInfinityCost is problematic for the solver. - if (strategy_group->leaf_vector.size() == 1) { + if (strategy_group->strategies.size() == 1) { for (auto& operand_resharding_costs : - strategy_group->leaf_vector[0].resharding_costs) { + strategy_group->strategies[0].resharding_costs) { if (operand_resharding_costs.size() == 1 && operand_resharding_costs[0] >= kInfinityCost) { operand_resharding_costs[0] = 0; @@ -1355,7 +1355,7 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( // It is IMPORTANT that we do this only for instructions that do no follow // others, to keep the number of ILP variable small. std::vector new_vector; - for (const auto& strategy : strategy_group->leaf_vector) { + for (const auto& strategy : strategy_group->strategies) { if (strategy.output_sharding.IsReplicated() || ShardingIsConsistent(existing_sharding, strategy.output_sharding, strict) || @@ -1372,9 +1372,9 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( // not have to strictly keep those shardings and the only purpose is to // reduce problem size for the last iteration. 
if (!new_vector.empty() && - new_vector.size() != strategy_group->leaf_vector.size()) { + new_vector.size() != strategy_group->strategies.size()) { strategy_group->following = nullptr; - strategy_group->leaf_vector = std::move(new_vector); + strategy_group->strategies = std::move(new_vector); } } } @@ -1388,14 +1388,14 @@ void CheckMemoryCosts(StrategyGroup* strategy_group, const Shape& shape) { } } else { double full_mem = 0.0; - for (const auto& strategy : strategy_group->leaf_vector) { + for (const auto& strategy : strategy_group->strategies) { if (strategy.output_sharding.IsReplicated()) { full_mem = strategy.memory_cost; size_t size = GetInstructionSize(shape); CHECK_EQ(strategy.memory_cost, size); } } - for (const auto& strategy : strategy_group->leaf_vector) { + for (const auto& strategy : strategy_group->strategies) { if (!strategy.output_sharding.IsReplicated() && full_mem > 0.0) { CHECK_EQ(strategy.memory_cost * strategy.output_sharding.NumTiles(), full_mem); @@ -1415,7 +1415,7 @@ void RemoveInvalidShardingsWithShapes(const Shape& shape, } } else { if (instruction_has_user_sharding && - strategy_group->leaf_vector.size() == 1) { + strategy_group->strategies.size() == 1) { // If an instruction has a specified user sharding, and there is only a // single strategy, removing that strategy would mean we won't have any // strategy for that instruction. Further, given that the user has @@ -1424,7 +1424,7 @@ void RemoveInvalidShardingsWithShapes(const Shape& shape, return; } std::vector new_vector; - for (const auto& strategy : strategy_group->leaf_vector) { + for (const auto& strategy : strategy_group->strategies) { if (strategy.output_sharding.IsReplicated()) { new_vector.push_back(strategy); continue; @@ -1444,7 +1444,7 @@ void RemoveInvalidShardingsWithShapes(const Shape& shape, new_vector.push_back(strategy); } } - strategy_group->leaf_vector = std::move(new_vector); + strategy_group->strategies = std::move(new_vector); } } @@ -1454,7 +1454,7 @@ void CheckReshardingCostsShape(StrategyGroup* strategy_group) { CheckReshardingCostsShape(strategy_group->childs[i].get()); } } else { - for (const auto& strategy : strategy_group->leaf_vector) { + for (const auto& strategy : strategy_group->strategies) { if (strategy_group->in_nodes.size() == 1 && strategy_group->in_nodes.at(0)->is_tuple) { // This is when current instruction's only operand is tuple, and the @@ -1477,11 +1477,11 @@ void CheckReshardingCostsShape(StrategyGroup* strategy_group) { if (strategy_group->in_nodes.size() == 1 && strategy_group->in_nodes.at(0)->is_tuple) { to_compare = - strategy_group->in_nodes.at(0)->childs.at(i)->leaf_vector.size(); + strategy_group->in_nodes.at(0)->childs.at(i)->strategies.size(); } else if (strategy_group->is_tuple) { to_compare = strategy_group->in_nodes.at(i)->childs.size(); } else { - to_compare = strategy_group->in_nodes.at(i)->leaf_vector.size(); + to_compare = strategy_group->in_nodes.at(i)->strategies.size(); } CHECK_EQ(strategy.resharding_costs[i].size(), to_compare) << "\nIndex of resharding_costs: " << i @@ -1510,7 +1510,7 @@ void ScaleCostsWithExecutionCounts(StrategyGroup* strategy_group, execution_count); } } else { - for (auto& strategy : strategy_group->leaf_vector) { + for (auto& strategy : strategy_group->strategies) { strategy.compute_cost *= execution_count; strategy.communication_cost *= execution_count; for (auto i = 0; i < strategy.resharding_costs.size(); ++i) { @@ -1556,11 +1556,9 @@ std::unique_ptr CreateElementwiseOperatorStrategies( } auto process_src_strategy_group = 
- [&](const std::vector& src_strategies_leaf_vector) { - for (int64_t sid = 0; sid < src_strategies_leaf_vector.size(); - ++sid) { - HloSharding output_spec = - src_strategies_leaf_vector[sid].output_sharding; + [&](const std::vector& src_strategies) { + for (int64_t sid = 0; sid < src_strategies.size(); ++sid) { + HloSharding output_spec = src_strategies[sid].output_sharding; std::string name = ToStringSimple(output_spec); double compute_cost = 0, communication_cost = 0; double memory_cost = @@ -1574,7 +1572,7 @@ std::unique_ptr CreateElementwiseOperatorStrategies( input_shardings.push_back(output_spec); } - strategy_group->leaf_vector.push_back(ShardingStrategy( + strategy_group->strategies.push_back(ShardingStrategy( {name, output_spec, compute_cost, communication_cost, memory_cost, std::move(resharding_costs), input_shardings})); } @@ -1582,7 +1580,7 @@ std::unique_ptr CreateElementwiseOperatorStrategies( StrategyGroup* src_strategy_group = strategy_map.at(ins->operand(i)).get(); CHECK(!src_strategy_group->is_tuple); - process_src_strategy_group(src_strategy_group->leaf_vector); + process_src_strategy_group(src_strategy_group->strategies); if (pretrimmed_strategy_map.contains(src_strategy_group->node_idx)) { process_src_strategy_group( pretrimmed_strategy_map.at(src_strategy_group->node_idx)); @@ -1628,11 +1626,11 @@ std::unique_ptr CreateReshapeStrategies( CHECK(!src_strategy_group->is_tuple); strategy_group->following = src_strategy_group; - for (int64_t sid = 0; sid < src_strategy_group->leaf_vector.size(); ++sid) { + for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); ++sid) { std::optional output_spec = hlo_sharding_util::ReshapeSharding( operand->shape(), ins->shape(), - src_strategy_group->leaf_vector[sid].output_sharding); + src_strategy_group->strategies[sid].output_sharding); if (!output_spec.has_value()) { continue; @@ -1650,21 +1648,21 @@ std::unique_ptr CreateReshapeStrategies( double memory_cost = GetBytes(ins->shape()) / output_spec->NumTiles(); std::vector resharding_costs = ReshardingCostVector( src_strategy_group, operand->shape(), - src_strategy_group->leaf_vector[sid].output_sharding, cluster_env); - strategy_group->leaf_vector.push_back(ShardingStrategy( + src_strategy_group->strategies[sid].output_sharding, cluster_env); + strategy_group->strategies.push_back(ShardingStrategy( {name, *output_spec, compute_cost, communication_cost, memory_cost, {resharding_costs}, - {src_strategy_group->leaf_vector[sid].output_sharding}})); + {src_strategy_group->strategies[sid].output_sharding}})); } } // Fail to create follow strategies, enumerate all possible cases - if (strategy_group->leaf_vector.empty()) { - strategy_group->leaf_vector.clear(); + if (strategy_group->strategies.empty()) { + strategy_group->strategies.clear(); strategy_group->following = nullptr; // Split 1 dim @@ -1789,10 +1787,10 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // We follow the first operand (the array we're scattering into) auto src_strategy_group = strategy_map.at(ins->operand(0)).get(); CHECK(!src_strategy_group->is_tuple); - for (int64_t sid = 0; sid < src_strategy_group->leaf_vector.size(); + for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); ++sid) { HloSharding output_spec = - src_strategy_group->leaf_vector[sid].output_sharding; + src_strategy_group->strategies[sid].output_sharding; std::string name = ToStringSimple(output_spec); double compute_cost = 0, communication_cost = 0; double memory_cost = GetBytes(ins->shape()) / 
output_spec.NumTiles(); @@ -1808,7 +1806,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, CHECK(sharding_optional.has_value()); } - strategy_group->leaf_vector.push_back(ShardingStrategy( + strategy_group->strategies.push_back(ShardingStrategy( {name, output_spec, compute_cost, communication_cost, memory_cost, std::move(resharding_cost), input_shardings_optional})); } @@ -1855,17 +1853,17 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, ins, output_spec, strategy_map, cluster_env, call_graph, input_shardings_optional); - strategy_group->leaf_vector.push_back(ShardingStrategy( + strategy_group->strategies.push_back(ShardingStrategy( {name, output_spec, compute_cost, communication_cost, memory_cost, std::move(resharding_cost), input_shardings_optional})); } } auto src_strategy_group = strategy_map.at(ins->operand(0)).get(); - for (int64_t sid = 0; sid < src_strategy_group->leaf_vector.size(); + for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); ++sid) { HloSharding output_spec = - src_strategy_group->leaf_vector[sid].output_sharding; + src_strategy_group->strategies[sid].output_sharding; auto gather_parallel_dims = hlo_sharding_util::GetGatherParallelBatchDims(*ins, call_graph); absl::Span operand_parallel_dims; @@ -1890,7 +1888,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, GenerateReshardingCostsAndMissingShardingsForAllOperands( ins, *maybe_from_data, strategy_map, cluster_env, call_graph, input_shardings_optional); - strategy_group->leaf_vector.push_back(ShardingStrategy( + strategy_group->strategies.push_back(ShardingStrategy( {name, *maybe_from_data, compute_cost, communication_cost, memory_cost, std::move(resharding_cost), input_shardings_optional})); @@ -1950,11 +1948,10 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, CHECK(!src_strategy_group->is_tuple); strategy_group->following = src_strategy_group; - for (int64_t sid = 0; sid < src_strategy_group->leaf_vector.size(); + for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); ++sid) { HloSharding output_spec = Undefined(); - auto input_spec = - src_strategy_group->leaf_vector[sid].output_sharding; + auto input_spec = src_strategy_group->strategies[sid].output_sharding; if (opcode == HloOpcode::kTranspose) { output_spec = hlo_sharding_util::TransposeSharding( input_spec, ins->dimensions()); @@ -1968,7 +1965,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); auto resharding_costs = ReshardingCostVector( src_strategy_group, operand->shape(), input_spec, cluster_env); - strategy_group->leaf_vector.push_back( + strategy_group->strategies.push_back( ShardingStrategy({name, output_spec, compute_cost, @@ -2018,11 +2015,11 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, CHECK(!src_strategy_group->is_tuple); strategy_group->following = src_strategy_group; - for (int64_t sid = 0; sid < src_strategy_group->leaf_vector.size(); + for (int64_t sid = 0; sid < src_strategy_group->strategies.size(); ++sid) { std::optional output_spec; HloSharding input_spec = - src_strategy_group->leaf_vector[sid].output_sharding; + src_strategy_group->strategies[sid].output_sharding; // Find output shardings. 
switch (opcode) { @@ -2066,7 +2063,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, ins, *output_spec, strategy_map, cluster_env, call_graph, input_shardings); - strategy_group->leaf_vector.push_back( + strategy_group->strategies.push_back( ShardingStrategy({name, *output_spec, compute_cost, @@ -2076,7 +2073,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, {input_spec}})); } - if (strategy_group->leaf_vector.empty()) { + if (strategy_group->strategies.empty()) { strategy_group->following = nullptr; AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, strategy_group, 0); @@ -2407,13 +2404,12 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } if (!strategy_group->is_tuple && strategy_group->following) { if (!LeafVectorsAreConsistent( - strategy_group->leaf_vector, - strategy_group->following->leaf_vector, + strategy_group->strategies, strategy_group->following->strategies, /*is_reshape*/ ins->opcode() == HloOpcode::kReshape)) { // It confuses the solver if two instructions have different number of // sharding strategies but share the same ILP variable. The solver // would run much longer and/or return infeasible solutions. - // So if two strategies' leaf_vectors are inconsistent, we unfollow + // So if two strategies' strategiess are inconsistent, we unfollow // them. strategy_group->following = nullptr; } @@ -2421,8 +2417,8 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, for (size_t i = 0; i < strategy_group->childs.size(); i++) { if (strategy_group->childs.at(i)->following && !LeafVectorsAreConsistent( - strategy_group->childs.at(i)->leaf_vector, - strategy_group->childs.at(i)->following->leaf_vector, + strategy_group->childs.at(i)->strategies, + strategy_group->childs.at(i)->following->strategies, /*is_reshape*/ ins->opcode() == HloOpcode::kReshape)) { strategy_group->childs.at(i)->following = nullptr; } @@ -2449,14 +2445,14 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, auto it = absl::c_find(inst_indices, strategy_group->node_idx); if (it != inst_indices.end()) { CHECK(!strategy_group->is_tuple); - std::vector new_leaf_vector; + std::vector new_strategies; int64_t idx = it - inst_indices.begin(); - for (const auto& stra : strategy_group->leaf_vector) { + for (const auto& stra : strategy_group->strategies) { if (stra.name == stra_names[idx]) { - new_leaf_vector.push_back(stra); + new_strategies.push_back(stra); } } - strategy_group->leaf_vector = std::move(new_leaf_vector); + strategy_group->strategies = std::move(new_strategies); } } @@ -2467,10 +2463,10 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // the mesh shape we're trying does not match with the mesh shape used in // user specified shardings. So we disable the check in that situation. 
if (!trying_multiple_mesh_shapes) { - CHECK(strategy_group->is_tuple || !strategy_group->leaf_vector.empty()) + CHECK(strategy_group->is_tuple || !strategy_group->strategies.empty()) << ins->ToString() << " does not have any valid strategies."; } else if (!(strategy_group->is_tuple || - !strategy_group->leaf_vector.empty())) { + !strategy_group->strategies.empty())) { return Status(absl::StatusCode::kFailedPrecondition, "Could not generate any shardings for an instruction due " "to mismatched mesh shapes."); @@ -2492,7 +2488,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, StrategyGroup* stra_vector = strategy_map[inst].get(); CHECK(!stra_vector->is_tuple); - for (auto& stra : stra_vector->leaf_vector) { + for (auto& stra : stra_vector->strategies) { if (absl::StrContains(stra.name, "allreduce")) { stra.communication_cost /= option.grad_acc_num_micro_batches; } @@ -2562,8 +2558,8 @@ AutoShardingSolverResult CallSolver( tuple_elements.at(*strategy_group->tuple_element_idx); } } - for (NodeStrategyIdx j = 0; j < strategy_group->leaf_vector.size(); ++j) { - const ShardingStrategy& strategy = strategy_group->leaf_vector[j]; + for (NodeStrategyIdx j = 0; j < strategy_group->strategies.size(); ++j) { + const ShardingStrategy& strategy = strategy_group->strategies[j]; const HloSharding& sharding = strategy.output_sharding; ci.push_back(strategy.compute_cost); di.push_back(strategy.communication_cost + @@ -2591,14 +2587,14 @@ AutoShardingSolverResult CallSolver( for (const auto& pair : alias_set) { const StrategyGroup* src_strategy_group = leaf_strategies[pair.first]; const StrategyGroup* dst_strategy_group = leaf_strategies[pair.second]; - Matrix raw_cost(src_strategy_group->leaf_vector.size(), - dst_strategy_group->leaf_vector.size()); - for (NodeStrategyIdx i = 0; i < src_strategy_group->leaf_vector.size(); + Matrix raw_cost(src_strategy_group->strategies.size(), + dst_strategy_group->strategies.size()); + for (NodeStrategyIdx i = 0; i < src_strategy_group->strategies.size(); ++i) { - for (NodeStrategyIdx j = 0; j < dst_strategy_group->leaf_vector.size(); + for (NodeStrategyIdx j = 0; j < dst_strategy_group->strategies.size(); ++j) { - if (src_strategy_group->leaf_vector[i].output_sharding == - dst_strategy_group->leaf_vector[j].output_sharding) { + if (src_strategy_group->strategies[i].output_sharding == + dst_strategy_group->strategies[j].output_sharding) { raw_cost(i, j) = 0.0; } else { raw_cost(i, j) = 1.0; @@ -2810,13 +2806,13 @@ void SetHloSharding(const HloInstructionSequence& sequence, NodeIdx node_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = s_val[node_idx]; // Do not set completed sharding before the last iteration - if (strategy_group->leaf_vector[stra_idx] + if (strategy_group->strategies[stra_idx] .output_sharding.IsReplicated() && !last_iteration) { set_tuple_sharding = false; } output_flattened_shardings.push_back( - strategy_group->leaf_vector[stra_idx].output_sharding); + strategy_group->strategies[stra_idx].output_sharding); } }; extract_tuple_shardings(strategy_group); @@ -3137,11 +3133,11 @@ std::string PrintAutoShardingSolution(const HloInstructionSequence& sequence, NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]); if (cost_graph.follow_idx_[node_idx] < 0) { absl::StrAppend( - &str, leaf_strategies[node_idx]->leaf_vector[stra_idx].ToString(), + &str, leaf_strategies[node_idx]->strategies[stra_idx].ToString(), "\n"); } else { absl::StrAppend( - &str, leaf_strategies[node_idx]->leaf_vector[stra_idx].ToString(), + 
&str, leaf_strategies[node_idx]->strategies[stra_idx].ToString(), " follow ", cost_graph.follow_idx_[node_idx], "\n"); } } @@ -3169,7 +3165,7 @@ std::string PrintSolutionMemoryUsage(const LivenessSet& liveness_set, } NodeIdx ins_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = cost_graph.RemapIndex(ins_idx, s_val[ins_idx]); - const ShardingStrategy& strategy = strategy_group->leaf_vector[stra_idx]; + const ShardingStrategy& strategy = strategy_group->strategies[stra_idx]; return strategy.memory_cost; }; for (LivenessIdx time_idx = 0; time_idx < liveness_set.size(); ++time_idx) { @@ -3926,8 +3922,8 @@ Status FilterStrategy(const HloInstruction* ins, const Shape& shape, "not divisible by the number of devices"); } - std::vector new_leaf_vector; - for (auto& stra : strategy_group->leaf_vector) { + std::vector new_strategies; + for (auto& stra : strategy_group->strategies) { std::vector tensor_dim_to_mesh_dim = cluster_env.GetTensorDimToMeshDimWrapper(shape, stra.output_sharding); @@ -3935,19 +3931,19 @@ Status FilterStrategy(const HloInstruction* ins, const Shape& shape, // If the mesh dim is not one, the output tensor must be // tiled along the mesh dim. if (tensor_dim_to_mesh_dim[batch_dim] == mesh_dim) { - new_leaf_vector.push_back(std::move(stra)); + new_strategies.push_back(std::move(stra)); } } else { // If the mesh dim is one, the output tensor must be replicated // on the mesh dim. if (tensor_dim_to_mesh_dim[batch_dim] == -1) { - new_leaf_vector.push_back(std::move(stra)); + new_strategies.push_back(std::move(stra)); } } } - CHECK(!new_leaf_vector.empty()) + CHECK(!new_strategies.empty()) << ins->ToString() << " does not have any valid strategies"; - strategy_group->leaf_vector = std::move(new_leaf_vector); + strategy_group->strategies = std::move(new_strategies); return OkStatus(); } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index cf24b493f4b1d1..c8b6611dd0d108 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -45,9 +45,9 @@ class CostGraph { // Build the cost graph for (const auto& strategies : leaf_strategies) { - node_lens_.push_back(strategies->leaf_vector.size()); + node_lens_.push_back(strategies->strategies.size()); extra_node_costs_.push_back( - std::vector(strategies->leaf_vector.size(), 0.0)); + std::vector(strategies->strategies.size(), 0.0)); for (size_t i = 0; i < strategies->in_nodes.size(); ++i) { if (!strategies->in_nodes[i]->is_tuple) { @@ -101,14 +101,14 @@ class CostGraph { Matrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]); for (NodeStrategyIdx i = 0; i < node_lens_[src_idx]; ++i) { - if (leaf_strategies[src_idx]->leaf_vector[i].communication_cost > 0) { + if (leaf_strategies[src_idx]->strategies[i].communication_cost > 0) { CHECK_LE( std::abs( - leaf_strategies[src_idx]->leaf_vector[i].communication_cost - - leaf_strategies[dst_idx]->leaf_vector[i].communication_cost), + leaf_strategies[src_idx]->strategies[i].communication_cost - + leaf_strategies[dst_idx]->strategies[i].communication_cost), 1e-6); edge_cost(i, i) = - -leaf_strategies[src_idx]->leaf_vector[i].communication_cost; + -leaf_strategies[src_idx]->strategies[i].communication_cost; } } AddEdgeCost(src_idx, dst_idx, edge_cost); @@ -120,8 +120,8 @@ class CostGraph { CHECK_LT(src_idx, node_lens_.size()); CHECK_LT(dst_idx, 
node_lens_.size()); Matrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]); - for (NodeStrategyIdx k = 0; k < strategy_group->leaf_vector.size(); ++k) { - const ShardingStrategy& strategy = strategy_group->leaf_vector[k]; + for (NodeStrategyIdx k = 0; k < strategy_group->strategies.size(); ++k) { + const ShardingStrategy& strategy = strategy_group->strategies[k]; size_t start_idx = 0; if (strategy.resharding_costs[in_node_idx].size() > node_lens_[src_idx]) { start_idx = @@ -366,7 +366,7 @@ inline const ShardingStrategy& GetShardingStrategy( CHECK(!strategy_group->is_tuple); NodeIdx node_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]); - return strategy_group->leaf_vector[stra_idx]; + return strategy_group->strategies[stra_idx]; } // Get the final sharding strategy according to the ilp solution. @@ -383,7 +383,7 @@ inline const ShardingStrategy& GetShardingStrategyForTuple( } NodeIdx node_idx = strategy_group->node_idx; NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]); - return strategy_group->leaf_vector[stra_idx]; + return strategy_group->strategies[stra_idx]; } } // namespace spmd diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 99390479414960..ffc16f0868571a 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -87,7 +87,7 @@ class HandlerBase { operand->shape(), input_specs[i], cluster_env_)); } - strategy_group_->leaf_vector.push_back(ShardingStrategy({ + strategy_group_->strategies.push_back(ShardingStrategy({ name, output_spec, compute_cost, @@ -601,7 +601,7 @@ class DotHandler : public HandlerBase { cluster_env_.non_zero_mesh_dims_.size() > 1) { // If there is a batch dim and the device mesh is multi-dimensional, // always split on batch dim. Clear all old strategies. - strategy_group_->leaf_vector.clear(); + strategy_group_->strategies.clear(); } // Sb = Sb x Sb @@ -626,7 +626,7 @@ class DotHandler : public HandlerBase { [](int64_t size) { return size > 1; }) > 1) { // If there are two batch dims, always split on these two dims. // Clear all old strategies. - strategy_group_->leaf_vector.clear(); + strategy_group_->strategies.clear(); } // Sb = Sb x Sb diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h index 2ca280330dca7f..4a8e479ef6ed9e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h @@ -136,18 +136,18 @@ struct StrategyGroup { bool is_tuple; // The index used in the solver. For non-leaf nodes, this is set to -1. NodeIdx node_idx; - // The index of the HLO instruction that this strategy vector belongs to. + // The index of the HLO instruction that this strategy group belongs to. size_t instruction_id; // The connected nodes used for resharding costs; // The size must be the same as the size of resharding cost - // each element in leaf_vector's resharding_costs.size() needs to be the same + // each element in strategies's resharding_costs.size() needs to be the same // as strategies->in_nodes.size() std::vector in_nodes; // The followed strategy. Used for merging nodes. 
const StrategyGroup* following = nullptr; // Used when is_tuple == False. Leaf strategy vector. // A vector of strategy choices for the non-tuple output. - std::vector leaf_vector; + std::vector strategies; // Used when is_tuple == True. A vector of pointers, each pointer is one // StrategyGroup for one value in the output Tuple std::vector> childs; @@ -177,7 +177,7 @@ struct StrategyGroup { absl::StrAppend(&str, childs[i]->ToString(indention + 2)); } } else { - for (const auto& strategy : leaf_vector) { + for (const auto& strategy : strategies) { absl::StrAppend(&str, indent, "Strategy ", strategy.ToStringLong()); } } @@ -198,7 +198,7 @@ struct StrategyGroup { using LivenessSet = std::vector>; // A liveness set using node indices instead of HLO values. using LivenessNodeSet = std::vector>; -// Map an instruction to its strategy vector. +// Map an instruction to its strategy group. using StrategyMap = StableHashMap>; // The list of all leaf strategies. diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index fe94eaf695a909..00b987b1e4575e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -881,36 +881,36 @@ void RemoveDuplicatedStrategy(std::unique_ptr& strategy_group) { RemoveDuplicatedStrategy(child); } } else if (!strategy_group->following) { - if (strategy_group->leaf_vector.empty()) return; + if (strategy_group->strategies.empty()) return; std::vector new_vector; std::vector deduped_replicated_strategies; absl::flat_hash_set added; size_t num_skipped_due_to_infinity_costs = 0; - for (size_t i = 0; i < strategy_group->leaf_vector.size(); ++i) { - if (AllInfinityCosts(strategy_group->leaf_vector[i].resharding_costs)) { + for (size_t i = 0; i < strategy_group->strategies.size(); ++i) { + if (AllInfinityCosts(strategy_group->strategies[i].resharding_costs)) { num_skipped_due_to_infinity_costs++; continue; } std::string key = - strategy_group->leaf_vector[i].output_sharding.ToString(); - if (!strategy_group->leaf_vector[i].input_shardings.empty()) { + strategy_group->strategies[i].output_sharding.ToString(); + if (!strategy_group->strategies[i].input_shardings.empty()) { for (const auto& sharding : - strategy_group->leaf_vector[i].input_shardings) { + strategy_group->strategies[i].input_shardings) { key += "/" + (sharding.has_value() ? sharding->ToString() : "none"); } } if (!added.contains(key)) { added.insert(key); - if (!strategy_group->leaf_vector[i].output_sharding.IsReplicated()) { - new_vector.push_back(std::move(strategy_group->leaf_vector[i])); + if (!strategy_group->strategies[i].output_sharding.IsReplicated()) { + new_vector.push_back(std::move(strategy_group->strategies[i])); } else { deduped_replicated_strategies.push_back( - std::move(strategy_group->leaf_vector[i])); + std::move(strategy_group->strategies[i])); } } } CHECK_LT(num_skipped_due_to_infinity_costs, - strategy_group->leaf_vector.size()) + strategy_group->strategies.size()) << "All strategies removed due to infinite resharding costs"; // Keeps replicated strategies as the last ones. 
if (!deduped_replicated_strategies.empty()) { @@ -918,7 +918,7 @@ void RemoveDuplicatedStrategy(std::unique_ptr& strategy_group) { new_vector.push_back(std::move(deduped_replicated_strategies[i])); } } - strategy_group->leaf_vector = std::move(new_vector); + strategy_group->strategies = std::move(new_vector); } } @@ -1834,12 +1834,12 @@ void CheckAliasSetCompatibility(const AliasSet& alias_set, size_t compatible_cnt = 0; bool replicated = false; - for (size_t i = 0; i < src_strategy_group->leaf_vector.size(); ++i) { - for (size_t j = 0; j < dst_strategy_group->leaf_vector.size(); ++j) { - if (src_strategy_group->leaf_vector[i].output_sharding == - dst_strategy_group->leaf_vector[j].output_sharding) { + for (size_t i = 0; i < src_strategy_group->strategies.size(); ++i) { + for (size_t j = 0; j < dst_strategy_group->strategies.size(); ++j) { + if (src_strategy_group->strategies[i].output_sharding == + dst_strategy_group->strategies[j].output_sharding) { compatible_cnt += 1; - if (src_strategy_group->leaf_vector[i] + if (src_strategy_group->strategies[i] .output_sharding.IsReplicated()) { replicated = true; } @@ -1848,8 +1848,8 @@ void CheckAliasSetCompatibility(const AliasSet& alias_set, } if (compatible_cnt == 1 && - (replicated && (src_strategy_group->leaf_vector.size() > 1 || - dst_strategy_group->leaf_vector.size() > 1))) { + (replicated && (src_strategy_group->strategies.size() > 1 || + dst_strategy_group->strategies.size() > 1))) { LOG(WARNING) << "Alias pair has only replicated strategy in common. This " "will result in choosing replicated strategy for these " From e63d37e07b95e5c6e8a78794547fe317454edae7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 15 Nov 2023 22:49:15 -0800 Subject: [PATCH 159/391] Remove a couple dead functions PiperOrigin-RevId: 582917032 --- .../auto_sharding/auto_sharding_util.cc | 75 ------------------- .../auto_sharding/auto_sharding_util.h | 11 --- 2 files changed, 86 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 00b987b1e4575e..262fa0663c9723 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -108,40 +108,6 @@ bool IsActivationFromAnotherStage(const HloInstruction* ins, return true; } -// Propagate sharding for broadcast. -// The output will be tiled along the broadcasted dimension the same way -// as the input for the broadcast while the other dimensions are kept -// non-tiled. -HloSharding BroadcastSharding(const HloSharding& input_spec, - const Shape& new_shape, - absl::Span dimensions) { - if (input_spec.IsReplicated()) { - return input_spec; - } - CHECK(new_shape.IsArray()); - std::vector target_tile_assignment_dimensions; - for (int64_t i = 0; i < new_shape.rank(); ++i) { - auto it = absl::c_find(dimensions, i); - if (it == dimensions.end()) { - target_tile_assignment_dimensions.push_back(1); - } else { - const int64_t source_dim = std::distance(dimensions.begin(), it); - target_tile_assignment_dimensions.push_back( - input_spec.tile_assignment().dim(source_dim)); - } - } - if (input_spec.ReplicateOnLastTileDim()) { - target_tile_assignment_dimensions.push_back( - input_spec.tile_assignment().dimensions().back()); - } - auto new_tile_assignment = - input_spec.tile_assignment().Reshape(target_tile_assignment_dimensions); - - return input_spec.ReplicateOnLastTileDim() - ? 
HloSharding::PartialTile(new_tile_assignment) - : HloSharding::Tile(new_tile_assignment); -} - // Propagate sharding for dim-wise operations (e.g., slice, pad) which works // independently on each dimension. // The sharding can successfully propagate if the operation only happens @@ -1056,47 +1022,6 @@ void UseAllReduceForGradAcc(StableHashSet& replicated_set, } } -void RemoveCustomCallMarker(HloModule* module) { - HloComputation* entry_computation = module->entry_computation(); - - std::vector get_tuple_ins; - std::vector marker_ins; - - for (HloInstruction* ins : entry_computation->instructions()) { - if (ins->opcode() == HloOpcode::kGetTupleElement && - IsCustomCallMarker(ins->operand(0))) { - get_tuple_ins.push_back(ins); - marker_ins.push_back(ins->mutable_operand(0)); - } - } - - for (HloInstruction* raw_ins : get_tuple_ins) { - HloInstruction* ins = raw_ins; - while (ins->opcode() == HloOpcode::kGetTupleElement) { - HloInstruction* custom_call = ins->mutable_operand(0); - CHECK(IsCustomCallMarker(custom_call)); - HloInstruction* tuple = custom_call->mutable_operand(0); - ins = tuple->mutable_operand(ins->tuple_index()); - } - - TF_CHECK_OK(raw_ins->ReplaceAllUsesWith(ins)); - } - - for (HloInstruction* ins : get_tuple_ins) { - TF_CHECK_OK(entry_computation->RemoveInstruction(ins)); - } - - StableHashSet removed; - for (HloInstruction* ins : marker_ins) { - if (!removed.contains(ins)) { - HloInstruction* tmp = ins->mutable_operand(0); - TF_CHECK_OK(entry_computation->RemoveInstruction(ins)); - TF_CHECK_OK(entry_computation->RemoveInstruction(tmp)); - removed.insert(ins); - } - } -} - // Gets values in |array| along |dim| while keeping indices at other // dimensions at 0, e.g., array is 2D and dim = 1, this returns array[0, 1], // array[1, 1], array [2, 1], .... diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index bbf63c84f9bdc5..290d92f45a86a0 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -380,9 +380,6 @@ std::string GetBatchDimMapKey(const HloInstruction* ins, int64_t idx = -1); InstructionBatchDimMap BuildInstructionBatchDimMap( const HloInstructionSequence& sequence); -// Remove all custom call makers in an HloModule. -void RemoveCustomCallMarker(HloModule* module); - /* * HloSharding Utility */ @@ -422,14 +419,6 @@ inline bool IsFullyTiled(const HloSharding& sharding) { return sharding.NumTiles() == sharding.tile_assignment().num_elements(); } -// Propagate sharding for broadcast. -// The output will be tiled along the broadcasted dimension the same way -// as the input for the broadcast while the other dimensions are kept -// non-tiled. -HloSharding BroadcastSharding(const HloSharding& input_spec, - const Shape& new_shape, - absl::Span dimensions); - // Propagate sharding for dim-wise operations (e.g., slice, pad) which works // independently on each dimension. // The sharding can successfully propagate if the operation only happens on From eaa08910e3e45f1ccc1aef27787356bd4b4d7aab Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 16 Nov 2023 00:08:18 -0800 Subject: [PATCH 160/391] Integrate LLVM at llvm/llvm-project@8ea8dd9a0171 Updates LLVM usage to match [8ea8dd9a0171](https://github.com/llvm/llvm-project/commit/8ea8dd9a0171) PiperOrigin-RevId: 582932648 --- .../python/mlir_wrapper/filecheck_wrapper.cc | 3 +- third_party/llvm/generated.patch | 59 +++++++++++++++---- third_party/llvm/workspace.bzl | 4 +- .../tests/Dialect/gml_st/lower_vectors.mlir | 3 +- third_party/xla/xla/service/cpu/BUILD | 5 +- .../service/cpu/hlo_xla_runtime_pipeline.cc | 18 +++--- 6 files changed, 65 insertions(+), 27 deletions(-) diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc index 6042a896709d9e..8c82fc9bc12b42 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc @@ -29,8 +29,7 @@ PYBIND11_MODULE(filecheck_wrapper, m) { llvm::SMLoc()); SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(check), llvm::SMLoc()); - llvm::Regex regex = fc.buildCheckPrefixRegex(); - fc.readCheckFile(SM, llvm::StringRef(check), regex); + fc.readCheckFile(SM, llvm::StringRef(check)); return fc.checkInput(SM, llvm::StringRef(input)); }); } diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index af1f3cebfc9100..bc734d4fe6ced1 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,12 +1,51 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/clang/test/Analysis/builtin_signbit.cpp b/clang/test/Analysis/builtin_signbit.cpp ---- a/clang/test/Analysis/builtin_signbit.cpp -+++ b/clang/test/Analysis/builtin_signbit.cpp -@@ -5,6 +5,7 @@ - // RUN: -O0 %s -o - | FileCheck %s --check-prefixes=CHECK-BE64 - // RUN: %clang -target powerpc64le-linux-gnu -emit-llvm -S -mabi=ibmlongdouble \ - // RUN: -O0 %s -o - | FileCheck %s --check-prefixes=CHECK-LE -+// REQUIRES: asserts +diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel +--- a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel +@@ -5172,6 +5172,18 @@ + ], + ) + ++gentbl( ++ name = "ReadTAPIOptsTableGen", ++ strip_include_prefix = "tools/llvm-readtapi", ++ tbl_outs = [( ++ "-gen-opt-parser-defs", ++ "tools/llvm-readtapi/TapiOpts.inc", ++ )], ++ tblgen = ":llvm-tblgen", ++ td_file = "tools/llvm-readtapi/TapiOpts.td", ++ td_srcs = ["include/llvm/Option/OptParser.td"], ++) ++ + cc_binary( + name = "llvm-readtapi", + testonly = True, +@@ -5183,6 +5195,8 @@ + stamp = 0, + deps = [ + ":Object", ++ ":Option", ++ ":ReadTAPIOptsTableGen", + ":Support", + ":TextAPI", + ], +diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +@@ -1022,6 +1022,7 @@ + ":CAPIIR", + ":CAPIQuant", + ":MLIRBindingsPythonHeadersAndDeps", ++ "@pybind11", + ], + ) + +@@ -1040,6 +1041,7 @@ + ":CAPIIR", + ":CAPISparseTensor", + ":MLIRBindingsPythonHeadersAndDeps", ++ "@pybind11", + ], + ) - bool b; - double d = -1.0; diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index d3a0e41d030015..869bcb78ea2e15 100644 --- a/third_party/llvm/workspace.bzl +++ 
b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "c5dd1bbcc37e8811e7c6050159014d084eac6438" - LLVM_SHA256 = "f374bf677707588fc07235215e2bf03e27dfb299c4b478306dc918099c60b583" + LLVM_COMMIT = "8ea8dd9a017182d167f39f521ef397afba5a0fd5" + LLVM_SHA256 = "6963e268a2e03ff956f2457a629a8dafab8682eb4bbb664ff8dc668dd3faef7b" tf_http_archive( name = name, diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/lower_vectors.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/lower_vectors.mlir index 6c10c564e522bb..78fd35133dc6c0 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/lower_vectors.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/lower_vectors.mlir @@ -161,8 +161,7 @@ func.func @optimize_pack_with_transpose(%arg0: memref<1024x1024xf32>) -> // FLATTEN-NOT: vector.transpose // FLATTEN: %[[COLLAPSE:.*]] = memref.collapse_shape %[[ALLOC]] // FLATTEN-SAME: memref<128x1024x8x1xf32> into memref<128x1024x8xf32> -// FLATTEN: %[[SHAPE_CAST:.*]] = vector.shape_cast %{{.*}} -// FLATTEN: vector.transfer_write %[[SHAPE_CAST]], %[[COLLAPSE]] +// FLATTEN: vector.transfer_write %[[READ]], %[[COLLAPSE]] // ----- diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 99383d270214d7..9f30361fa77afe 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -1682,5 +1682,8 @@ cc_library( name = "cpu_symbol_repository", hdrs = ["cpu_symbol_repository.h"], visibility = ["//visibility:public"], - deps = ["//xla/service:symbol_repository"], + deps = [ + "//xla:xla_proto_cc", + "//xla/service:symbol_repository", + ], ) diff --git a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc index a004f77306da78..535730e5dc41e7 100644 --- a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc +++ b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc @@ -91,16 +91,12 @@ void AddSparsificationPasses(mlir::OpPassManager& pm, bool new_deallocator, // Setting 1 thread means cuSPARSE libgen. // Otherwise direct CUDA codegen. const bool gpu_codegen = xla_cpu_sparse_cuda_threads > 0; + const bool gpu_libgen = xla_cpu_sparse_cuda_threads == 1; mlir::SparsificationOptions sparsification_options; sparsification_options.enableRuntimeLibrary = false; - sparsification_options.enableIndexReduction = true; - if (gpu_codegen) { - if (xla_cpu_sparse_cuda_threads == 1) { - sparsification_options.enableGPULibgen = true; - } else { - sparsification_options.parallelizationStrategy = - mlir::SparseParallelizationStrategy::kDenseOuterLoop; - } + if (gpu_codegen && !gpu_libgen) { + sparsification_options.parallelizationStrategy = + mlir::SparseParallelizationStrategy::kDenseOuterLoop; } // Sparsification set up. pm.addNestedPass(mlir::createLinalgGeneralizationPass()); @@ -113,14 +109,16 @@ void AddSparsificationPasses(mlir::OpPassManager& pm, bool new_deallocator, /*enableBufferInitialization=*/false, /*vectorLength=*/0, /*enableVLAVectorization=*/false, - /*enableSIMDIndex32*/ false)); + /*enableSIMDIndex32=*/false, + /*enableGPULibgen=*/gpu_libgen)); pm.addPass(mlir::createStorageSpecifierToLLVMPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass( mlir::bufferization::createFinalizingBufferizePass()); // Sparse GPU acceleration lowers to GPU dialect. 
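// A minimal sketch (not part of the patch) of the mode selection implied by the
// gpu_codegen / gpu_libgen flags above, with xla_cpu_sparse_cuda_threads as the
// only input; the enum and function names here are illustrative only.
enum class SparseAccelMode { kCpuOnly, kCuSparseLibgen, kCudaCodegen };
inline SparseAccelMode SelectSparseAccelMode(int threads) {
  if (threads == 0) return SparseAccelMode::kCpuOnly;         // no GPU path
  if (threads == 1) return SparseAccelMode::kCuSparseLibgen;  // enableGPULibgen
  return SparseAccelMode::kCudaCodegen;  // kDenseOuterLoop parallelization
}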
if (gpu_codegen) { - pm.addPass(mlir::createSparseGPUCodegenPass(xla_cpu_sparse_cuda_threads)); + pm.addPass( + mlir::createSparseGPUCodegenPass(xla_cpu_sparse_cuda_threads, false)); pm.addNestedPass(mlir::createStripDebugInfoPass()); pm.addNestedPass(mlir::createConvertSCFToCFPass()); pm.addNestedPass( From 85ebd9b466e7bb250a41090a145d3ca30aac1bb1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 00:42:42 -0800 Subject: [PATCH 161/391] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/06b4bc05a86e9d2464a82e96df6703252bf100a1. PiperOrigin-RevId: 582939748 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index e17daad0c7dbcf..66639a6944c897 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "0bb14f40fe30f95a190e6b932cb3cc1ed7376d8a" - TFRT_SHA256 = "0c0544c04e42a3967382c358d72d1f9dd957c939d003d7778671ed73e404f753" + TFRT_COMMIT = "06b4bc05a86e9d2464a82e96df6703252bf100a1" + TFRT_SHA256 = "84a8f80403c6a4a8281d0cd291ff1ea3ce6f0ed29ab35b2d3d155f4ffec66488" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index e17daad0c7dbcf..66639a6944c897 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "0bb14f40fe30f95a190e6b932cb3cc1ed7376d8a" - TFRT_SHA256 = "0c0544c04e42a3967382c358d72d1f9dd957c939d003d7778671ed73e404f753" + TFRT_COMMIT = "06b4bc05a86e9d2464a82e96df6703252bf100a1" + TFRT_SHA256 = "84a8f80403c6a4a8281d0cd291ff1ea3ce6f0ed29ab35b2d3d155f4ffec66488" tf_http_archive( name = "tf_runtime", From ecbef17e16b835a46bca87778fdaed1e02880c06 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 01:02:20 -0800 Subject: [PATCH 162/391] compat: Update forward compatibility horizon to 2023-11-16 PiperOrigin-RevId: 582944391 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 5bb692137843f1..2098432509a560 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 15) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 16) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 81ffa7e0b223f004dcf5d594a168fbaf12a7b173 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 01:02:25 -0800 Subject: [PATCH 163/391] Update GraphDef version to 1682. 
PiperOrigin-RevId: 582944405 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 1cbde634435938..b00916cf91259d 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1681 // Updated: 2023/11/15 +#define TF_GRAPH_DEF_VERSION 1682 // Updated: 2023/11/16 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 32f24e5310f535c8839e41c2b6218019188066a3 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 16 Nov 2023 02:09:57 -0800 Subject: [PATCH 164/391] Remove usage of deprecated dyn_cast, cast and isa member functions of AffineExpr. PiperOrigin-RevId: 582960564 --- third_party/xla/xla/mlir_hlo/BUILD | 1 + .../mhlo/analysis/shape_component_analysis.cc | 19 ++++++++++--------- .../symbolic_shape_optimization.cc | 13 +++++++------ 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index f55e0dad52f3f6..abbd1dafbb860d 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -1541,6 +1541,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", ], ) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc b/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc index e1fa72e2a80ac0..044f011c81e7a5 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/IR/AffineExpr.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" +#include "mlir/Support/LLVM.h" using namespace mlir; @@ -482,7 +483,7 @@ struct ShapeVisitor { SymbolicExpr dim; for (auto &it : in) { // For constant expressions, we can accumulate a concrete product. - if (auto cexpr = it.expr.dyn_cast()) { + if (auto cexpr = dyn_cast(it.expr)) { assert(cexpr.getValue() > 0 && "shape value must be positive"); concreteProduct *= cexpr.getValue(); continue; @@ -768,8 +769,8 @@ void ShapeComponentAnalysis::reset() { } bool SymbolicExpr::isConstant(int64_t value) const { - return expr.isa() && - expr.cast().getValue() == value; + return isa(expr) && + cast(expr).getValue() == value; } bool SymbolicExpr::isKnownNotNegativeOne() const { @@ -787,9 +788,9 @@ bool SymbolicExpr::isKnownNotNegativeOne() const { // For constants we know if it's -1 or not. Checking the sign is sufficient // here and allows for reuse below. This is correct, not complete. auto isGoodSymbolOrGoodConstantExpr = [&](AffineExpr expr) { - if (auto symExpr = expr.dyn_cast()) + if (auto symExpr = dyn_cast(expr)) return isGoodSymbol(symbols[symExpr.getPosition()]); - if (auto constExpr = expr.dyn_cast()) + if (auto constExpr = dyn_cast(expr)) return constExpr.getValue() >= 0; return false; }; @@ -799,7 +800,7 @@ bool SymbolicExpr::isKnownNotNegativeOne() const { // Multiplying non-negative symbols and non-negative constants will always // give a positive result. This is correct, not complete. // TODO(kramerb): Could the analysis provide a generic interface for this? 
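// A minimal sketch (not part of the patch) of the cast-style migration applied
// throughout this file: the deprecated member functions expr.isa<T>(),
// expr.cast<T>() and expr.dyn_cast<T>() become the free functions isa<T>(expr),
// cast<T>(expr) and dyn_cast<T>(expr) brought into scope by mlir/Support/LLVM.h.
#include "mlir/IR/AffineExpr.h"
#include "mlir/Support/LLVM.h"
inline bool IsConstantZero(mlir::AffineExpr expr) {
  // New style: free-function dyn_cast instead of expr.dyn_cast<...>().
  if (auto c = mlir::dyn_cast<mlir::AffineConstantExpr>(expr))
    return c.getValue() == 0;
  return false;
}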
- if (auto bexpr = expr.dyn_cast()) { + if (auto bexpr = dyn_cast(expr)) { return bexpr.getKind() == AffineExprKind::Mul && isGoodSymbolOrGoodConstantExpr(bexpr.getLHS()) && isGoodSymbolOrGoodConstantExpr(bexpr.getRHS()); @@ -809,15 +810,15 @@ bool SymbolicExpr::isKnownNotNegativeOne() const { } bool SymbolicExpr::isKnownNotOne() const { - if (auto constExpr = expr.dyn_cast()) { + if (auto constExpr = dyn_cast(expr)) { return constExpr.getValue() != 1; } return false; } std::optional SymbolicExpr::singleton() const { - if (expr.isa() && - expr.cast().getPosition() == 0) { + if (isa(expr) && + cast(expr).getPosition() == 0) { assert(symbols.size() == 1); return symbols[0]; } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc index 458329ac9386f5..bf36a140ce7e8c 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc @@ -36,6 +36,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -113,8 +114,8 @@ struct SimplifyBroadcasts : public mlir::OpRewritePattern { llvm::map_range(symResult, [&](const auto &symResultDim) { // If we know the dimension statically, use a constant. if (!symResultDim) return findOrCreateConstant(1); - if (auto cexpr = symResultDim->expr.expr - .template dyn_cast()) { + if (auto cexpr = + dyn_cast(symResultDim->expr.expr)) { return findOrCreateConstant(cexpr.getValue()); } @@ -243,16 +244,16 @@ struct RemoveComputeReshapeShape final bool isProduct(AffineExpr expr, llvm::function_ref cbkConstantFactor, llvm::function_ref cbkSymbolicFactor) { - auto binExpr = expr.dyn_cast(); + auto binExpr = dyn_cast(expr); if (binExpr && binExpr.getKind() == AffineExprKind::Mul) { return isProduct(binExpr.getLHS(), cbkConstantFactor, cbkSymbolicFactor) && isProduct(binExpr.getRHS(), cbkConstantFactor, cbkSymbolicFactor); } - if (auto symExpr = expr.dyn_cast()) { + if (auto symExpr = dyn_cast(expr)) { cbkSymbolicFactor(symExpr); return true; } - if (auto constExpr = expr.dyn_cast()) { + if (auto constExpr = dyn_cast(expr)) { cbkConstantFactor(constExpr); return true; } @@ -630,7 +631,7 @@ SmallVector concretizeOperandShape( for (auto it : llvm::zip(operandShape, operandShapeInfo)) { auto dimSize = std::get<0>(it); auto sExpr = std::get<1>(it); - if (auto cexpr = sExpr.expr.dyn_cast()) { + if (auto cexpr = dyn_cast(sExpr.expr)) { int64_t alsoDimSize = cexpr.getValue(); assert((ShapedType::isDynamic(dimSize) || dimSize == alsoDimSize) && "expect shape analysis result to be compatible with type"); From bdb33f1e1161d856446ca552f7200b979c6b3cf7 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 16 Nov 2023 05:29:05 -0800 Subject: [PATCH 165/391] [XLA:GPU][NFC] Add a TritonTest base class in tests. TritonFilecheckTests encompass non-GEMM tests, so we adjust the class hierarchy to reflect that. 
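In sketch form, the adjusted hierarchy looks like this (class bodies elided; the
full change is in the diff below):

  class TritonTest : public GpuCodegenTest {};       // shared helpers only
  class TritonGemmTest : public TritonTest {};       // GEMM-specific debug options
  class TritonFilecheckTest : public TritonTest {};  // no longer GEMM-specific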
PiperOrigin-RevId: 583008449 --- .../xla/xla/service/gpu/ir_emitter_triton_test.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 682c430c82d414..1a6b11b77c0bd5 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -61,7 +61,7 @@ namespace { namespace m = ::xla::match; -class TritonGemmTest : public GpuCodegenTest { +class TritonTest : public GpuCodegenTest { public: se::CudaComputeCapability GetCudaComputeCapability() { return backend() @@ -69,8 +69,12 @@ class TritonGemmTest : public GpuCodegenTest { ->GetDeviceDescription() .cuda_compute_capability(); } +}; + +class TritonGemmTest : public TritonTest { + public: DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_cublas_fallback(false); return debug_options; } @@ -85,7 +89,7 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest { } }; -class TritonFilecheckTest : public TritonGemmTest { +class TritonFilecheckTest : public TritonTest { public: StatusOr CreateTritonIrAndFileCheck( absl::string_view hlo_text, const TritonGemmConfig& config, From 9af3cf09d01f7396b3d72e4e3be6a424144ed844 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Thu, 16 Nov 2023 05:31:51 -0800 Subject: [PATCH 166/391] Priority fusion: cache HloFusionAnalyses. ~33% HLO passes speedup. PiperOrigin-RevId: 583009054 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/hlo_fusion_analysis.cc | 123 +++++++++--------- .../xla/xla/service/gpu/hlo_fusion_analysis.h | 34 ++--- .../xla/service/gpu/kernel_mapping_scheme.h | 24 ++-- third_party/xla/xla/service/gpu/model/BUILD | 29 +++++ .../gpu/model/fusion_analysis_cache.cc | 90 +++++++++++++ .../service/gpu/model/fusion_analysis_cache.h | 71 ++++++++++ .../gpu/model/fusion_analysis_cache_test.cc | 115 ++++++++++++++++ .../gpu/model/gpu_performance_model.cc | 36 ++++- .../service/gpu/model/gpu_performance_model.h | 12 +- .../gpu/model/gpu_performance_model_test.cc | 2 +- .../xla/xla/service/gpu/priority_fusion.cc | 20 ++- .../xla/xla/service/gpu/priority_fusion.h | 9 +- .../xla/service/gpu/priority_fusion_test.cc | 8 +- 14 files changed, 464 insertions(+), 110 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc create mode 100644 third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h create mode 100644 third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 58ae615215ee3b..8445dec5c78840 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2071,6 +2071,7 @@ cc_library( "//xla/service:hlo_cost_analysis", "//xla/service:hlo_pass", "//xla/service:instruction_fusion", + "//xla/service/gpu/model:fusion_analysis_cache", "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/service/gpu/model:gpu_performance_model", "//xla/stream_executor:device_description", diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index bb2fe734a63055..c064c8c9c1770c 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ 
b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -58,6 +58,35 @@ const auto kDimX = TilingScheme::DimX; const auto kLinearIndexingX = TilingScheme::LinearIndexingX; const auto kStridedIndexingX = TilingScheme::StridedIndexingX; +std::optional ComputeTransposeTilingScheme( + const std::optional& tiled_transpose) { + if (!tiled_transpose) { + return std::nullopt; + } + + constexpr int kNumRows = 4; + static_assert(WarpSize() % kNumRows == 0); + + // 3D view over the input shape. + Vector3 dims = tiled_transpose->dimensions; + Vector3 order = tiled_transpose->permutation; + + Vector3 permuted_dims = {dims[order[0]], dims[order[1]], dims[order[2]]}; + Vector3 tile_sizes{1, 1, 1}; + tile_sizes[order[2]] = WarpSize() / kNumRows; + Vector3 num_threads{1, 1, WarpSize()}; + num_threads[order[2]] = kNumRows; + + return TilingScheme( + /*permuted_dims*/ permuted_dims, + /*tile_sizes=*/tile_sizes, + /*num_threads=*/num_threads, + /*indexing_order=*/kLinearIndexingX, + /*vector_size=*/1, + /*scaling_factor=*/1, + /*tiling_dimensions=*/{order[2], 2}); +} + // Returns true if `instr` is a non-strided slice. bool IsSliceWithUnitStrides(const HloInstruction* instr) { auto slice = DynCast(instr); @@ -257,6 +286,28 @@ std::optional FindConsistentTransposeHero( } // namespace +HloFusionAnalysis::HloFusionAnalysis( + FusionBackendConfig fusion_backend_config, + std::vector fusion_roots, + FusionBoundaryFn fusion_boundary_fn, + std::vector fusion_arguments, + std::vector fusion_heroes, + const se::DeviceDescription* device_info, + std::optional tiled_transpose, bool has_4_bit_input, + bool has_4_bit_output) + : fusion_backend_config_(std::move(fusion_backend_config)), + fusion_roots_(std::move(fusion_roots)), + fusion_boundary_fn_(std::move(fusion_boundary_fn)), + fusion_arguments_(std::move(fusion_arguments)), + fusion_heroes_(std::move(fusion_heroes)), + device_info_(device_info), + tiled_transpose_(tiled_transpose), + has_4_bit_input_(has_4_bit_input), + has_4_bit_output_(has_4_bit_output), + reduction_codegen_info_(ComputeReductionCodegenInfo(FindHeroReduction())), + transpose_tiling_scheme_(ComputeTransposeTilingScheme(tiled_transpose_)), + loop_fusion_config_(ComputeLoopFusionConfig()) {} + // static StatusOr HloFusionAnalysis::Create( FusionBackendConfig backend_config, @@ -353,7 +404,7 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() return EmitterFusionKind::kLoop; } -StatusOr HloFusionAnalysis::GetLaunchDimensions() { +StatusOr HloFusionAnalysis::GetLaunchDimensions() const { auto emitter_fusion_kind = GetEmitterFusionKind(); switch (emitter_fusion_kind) { case EmitterFusionKind::kLoop: { @@ -403,7 +454,9 @@ StatusOr HloFusionAnalysis::GetLaunchDimensions() { } const HloInstruction* HloFusionAnalysis::FindHeroReduction() const { - CHECK(GetEmitterFusionKind() == EmitterFusionKind::kReduction); + if (GetEmitterFusionKind() != EmitterFusionKind::kReduction) { + return nullptr; + } auto roots = fusion_roots(); CHECK(!roots.empty()); // We always use the first reduce root that triggers unnested reduction @@ -418,57 +471,8 @@ const HloInstruction* HloFusionAnalysis::FindHeroReduction() const { LOG(FATAL) << "Did not find a hero reduction"; } -const ReductionCodegenInfo* HloFusionAnalysis::GetReductionCodegenInfo() { - if (reduction_codegen_info_.has_value()) { - return &reduction_codegen_info_.value(); - } - - const HloInstruction* hero_reduction = FindHeroReduction(); - - auto reduction_codegen_info = ComputeReductionCodegenInfo(hero_reduction); - 
reduction_codegen_info_.emplace(std::move(reduction_codegen_info)); - return &reduction_codegen_info_.value(); -} - -const TilingScheme* HloFusionAnalysis::GetTransposeTilingScheme() { - if (transpose_tiling_scheme_.has_value()) { - return &transpose_tiling_scheme_.value(); - } - - if (!tiled_transpose_) { - return nullptr; - } - - constexpr int kNumRows = 4; - static_assert(WarpSize() % kNumRows == 0); - - // 3D view over the input shape. - Vector3 dims = tiled_transpose_->dimensions; - Vector3 order = tiled_transpose_->permutation; - - Vector3 permuted_dims = {dims[order[0]], dims[order[1]], dims[order[2]]}; - Vector3 tile_sizes{1, 1, 1}; - tile_sizes[order[2]] = WarpSize() / kNumRows; - Vector3 num_threads{1, 1, WarpSize()}; - num_threads[order[2]] = kNumRows; - - TilingScheme tiling_scheme( - /*permuted_dims*/ permuted_dims, - /*tile_sizes=*/tile_sizes, - /*num_threads=*/num_threads, - /*indexing_order=*/kLinearIndexingX, - /*vector_size=*/1, - /*scaling_factor=*/1, - /*tiling_dimensions=*/{order[2], 2}); - transpose_tiling_scheme_.emplace(std::move(tiling_scheme)); - return &transpose_tiling_scheme_.value(); -} - -const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() { - if (loop_fusion_config_.has_value()) { - return &loop_fusion_config_.value(); - } - +std::optional +HloFusionAnalysis::ComputeLoopFusionConfig() const { int unroll_factor = 1; // Unrolling is good to read large inputs with small elements // due to vector loads, but increases the register pressure when one @@ -501,8 +505,7 @@ const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() { if (GetEmitterFusionKind() == EmitterFusionKind::kScatter) { // Only the unroll factor is used for scatter. - loop_fusion_config_.emplace(LaunchDimensionsConfig{unroll_factor}); - return &loop_fusion_config_.value(); + return LaunchDimensionsConfig{unroll_factor}; } bool row_vectorized; @@ -537,8 +540,7 @@ const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() { launch_config.row_vectorized = false; launch_config.few_waves = false; } - loop_fusion_config_.emplace(std::move(launch_config)); - return &loop_fusion_config_.value(); + return launch_config; } const Shape& HloFusionAnalysis::GetElementShape() const { @@ -809,8 +811,13 @@ int HloFusionAnalysis::CalculateVirtualThreadScalingFactorForReduction( return 1; } -ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo( +std::optional +HloFusionAnalysis::ComputeReductionCodegenInfo( const HloInstruction* hero_reduction) const { + if (!hero_reduction) { + return std::nullopt; + } + Shape input_shape = hero_reduction->operand(0)->shape(); ReductionDimensions reduction_dimensions = GetReductionKindAndContiguousComponents(*hero_reduction); diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h index c07819db2d3a15..1bec5ca650be47 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h @@ -68,19 +68,27 @@ class HloFusionAnalysis { // Determines the launch dimensions for the fusion. The fusion kind must not // be `kTriton`. - StatusOr GetLaunchDimensions(); + StatusOr GetLaunchDimensions() const; // Calculates the reduction information. Returns `nullptr` if the fusion is // not a reduction. - const ReductionCodegenInfo* GetReductionCodegenInfo(); + const ReductionCodegenInfo* GetReductionCodegenInfo() const { + return reduction_codegen_info_.has_value() ? 
&*reduction_codegen_info_ + : nullptr; + } // Calculates the transpose tiling information. Returns `nullptr` if the // fusion is not a transpose. - const TilingScheme* GetTransposeTilingScheme(); + const TilingScheme* GetTransposeTilingScheme() const { + return transpose_tiling_scheme_.has_value() ? &*transpose_tiling_scheme_ + : nullptr; + } // Calculates the loop fusion config. Returns `nullptr` if the fusion is not a // loop. - const LaunchDimensionsConfig* GetLoopFusionConfig(); + const LaunchDimensionsConfig* GetLoopFusionConfig() const { + return loop_fusion_config_.has_value() ? &*loop_fusion_config_ : nullptr; + } // Returns the hero reduction of the computation. const HloInstruction* FindHeroReduction() const; @@ -93,16 +101,7 @@ class HloFusionAnalysis { std::vector fusion_heroes, const se::DeviceDescription* device_info, std::optional tiled_transpose, - bool has_4_bit_input, bool has_4_bit_output) - : fusion_backend_config_(std::move(fusion_backend_config)), - fusion_roots_(std::move(fusion_roots)), - fusion_boundary_fn_(std::move(fusion_boundary_fn)), - fusion_arguments_(std::move(fusion_arguments)), - fusion_heroes_(std::move(fusion_heroes)), - device_info_(device_info), - tiled_transpose_(tiled_transpose), - has_4_bit_input_(has_4_bit_input), - has_4_bit_output_(has_4_bit_output) {} + bool has_4_bit_input, bool has_4_bit_output); const Shape& GetElementShape() const; int SmallestInputDtypeBits() const; @@ -118,8 +117,9 @@ class HloFusionAnalysis { bool reduction_is_race_free) const; int CalculateVirtualThreadScalingFactorForReduction( const ReductionDimensions& reduction_dimensions) const; - ReductionCodegenInfo ComputeReductionCodegenInfo( + std::optional ComputeReductionCodegenInfo( const HloInstruction* hero_reduction) const; + std::optional ComputeLoopFusionConfig() const; bool HasConsistentTransposeHeros() const; FusionBackendConfig fusion_backend_config_; @@ -131,8 +131,8 @@ class HloFusionAnalysis { std::vector fusion_heroes_; const se::DeviceDescription* device_info_; std::optional tiled_transpose_; - const bool has_4_bit_input_ = false; - const bool has_4_bit_output_ = false; + bool has_4_bit_input_ = false; + bool has_4_bit_output_ = false; std::optional reduction_codegen_info_; std::optional transpose_tiling_scheme_; diff --git a/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h b/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h index 4a6f0f7ae3c6fa..f7b51c42c6beaf 100644 --- a/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h +++ b/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h @@ -146,34 +146,34 @@ class TilingScheme { private: // The number of elements in each dimension. - const Vector3 dims_in_elems_; + Vector3 dims_in_elems_; // The number of elements for each dimension of a tile. - const Vector3 tile_sizes_; + Vector3 tile_sizes_; // The dimensions which are used for the shared memory tile. - const Vector2 tiling_dimensions_; + Vector2 tiling_dimensions_; // Number of threads implicitly assigned to each dimension. - const Vector3 num_threads_; + Vector3 num_threads_; - const IndexingOrder indexing_order_; + IndexingOrder indexing_order_; // Vector size for dimension X. - const int vector_size_; + int vector_size_; // Scaling apply to transform physical threadIdx into logical. 
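// A note on why the `const` qualifiers are dropped from these members (sketch,
// not part of the patch): const data members delete the implicit copy/move
// assignment operators, so the enclosing type cannot be reassigned inside
// std::optional or container values, which the analysis caching added later in
// this patch relies on (e.g. `analyses_[key] = std::move(analysis)`).
struct WithConstMember { const int x = 0; };  // not assignable: operator= deleted
struct PlainMember { int x = 0; };            // assignable, cacheable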
- const int64_t thread_id_virtual_scaling_ = 1; + int64_t thread_id_virtual_scaling_ = 1; }; class ReductionCodegenInfo { public: using IndexGroups = std::vector>; - explicit ReductionCodegenInfo(TilingScheme mapping_scheme, - int num_partial_results, bool is_row_reduction, - bool is_race_free, IndexGroups index_groups, - const HloInstruction* first_reduce) + ReductionCodegenInfo(TilingScheme mapping_scheme, int num_partial_results, + bool is_row_reduction, bool is_race_free, + IndexGroups index_groups, + const HloInstruction* first_reduce) : tiling_scheme_(mapping_scheme), num_partial_results_(num_partial_results), is_row_reduction_(is_row_reduction), @@ -198,7 +198,7 @@ class ReductionCodegenInfo { private: friend class ReductionCodegenState; - const TilingScheme tiling_scheme_; + TilingScheme tiling_scheme_; int num_partial_results_; bool is_row_reduction_; bool is_race_free_; diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 1b542e7dca447a..034f2a2f20d2ff 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -66,6 +66,34 @@ xla_test( ], ) +cc_library( + name = "fusion_analysis_cache", + srcs = ["fusion_analysis_cache.cc"], + hdrs = ["fusion_analysis_cache.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/synchronization", + ], +) + +xla_cc_test( + name = "fusion_analysis_cache_test", + srcs = ["fusion_analysis_cache_test.cc"], + deps = [ + ":fusion_analysis_cache", + "//xla/service:hlo_parser", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "gpu_cost_model_stats_collection", srcs = ["gpu_cost_model_stats_collection.cc"], @@ -149,6 +177,7 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), visibility = ["//visibility:public"], deps = [ + ":fusion_analysis_cache", ":gpu_hlo_cost_analysis", "//xla:shape_util", "//xla/hlo/ir:hlo", diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc new file mode 100644 index 00000000000000..59e9499b3c6d62 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc @@ -0,0 +1,90 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/model/fusion_analysis_cache.h" + +#include "xla/hlo/ir/hlo_instruction.h" + +namespace xla::gpu { + +const std::optional& HloFusionAnalysisCache::Get( + const HloInstruction& instruction) { + { + absl::ReaderMutexLock lock(&mutex_); + auto it = analyses_.find(&instruction); + if (it != analyses_.end()) { + return it->second; + } + } + + std::optional analysis = + AnalyzeFusion(instruction, device_info_); + absl::MutexLock lock(&mutex_); + + // If some other thread created an entry for this key concurrently, return + // that instead (the other thread is likely using the instance). + auto it = analyses_.find(&instruction); + if (it != analyses_.end()) { + return it->second; + } + + return analyses_[&instruction] = std::move(analysis); +} + +const std::optional& HloFusionAnalysisCache::Get( + const HloInstruction& producer, const HloInstruction& consumer) { + std::pair key{&producer, + &consumer}; + { + absl::ReaderMutexLock lock(&mutex_); + auto it = producer_consumer_analyses_.find(key); + if (it != producer_consumer_analyses_.end()) { + return it->second; + } + } + + std::optional analysis = + AnalyzeProducerConsumerFusion(producer, consumer, device_info_); + absl::MutexLock lock(&mutex_); + + // If some other thread created an entry for this key concurrently, return + // that instead (the other thread is likely using the instance). + auto it = producer_consumer_analyses_.find(key); + if (it != producer_consumer_analyses_.end()) { + return it->second; + } + + producers_for_consumers_[&consumer].push_back(&producer); + consumers_for_producers_[&producer].push_back(&consumer); + return producer_consumer_analyses_[key] = std::move(analysis); +} + +void HloFusionAnalysisCache::Invalidate(const HloInstruction& instruction) { + absl::MutexLock lock(&mutex_); + analyses_.erase(&instruction); + + if (auto consumers = consumers_for_producers_.extract(&instruction)) { + for (const auto* consumer : consumers.mapped()) { + producer_consumer_analyses_.erase({&instruction, consumer}); + } + } + if (auto producers = producers_for_consumers_.extract(&instruction)) { + for (const auto* producer : producers.mapped()) { + producer_consumer_analyses_.erase({producer, &instruction}); + } + } +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h new file mode 100644 index 00000000000000..f21a4fdee0fcad --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h @@ -0,0 +1,71 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
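// The lookup in HloFusionAnalysisCache::Get above, reduced to a generic sketch
// (illustration only, not part of the patch): probe under a shared lock,
// compute without holding any lock, then re-check under an exclusive lock so a
// concurrently inserted entry wins and stays stable for the threads using it.
//
//   {
//     absl::ReaderMutexLock lock(&mutex_);
//     if (auto it = cache_.find(key); it != cache_.end()) return it->second;
//   }
//   Value value = Compute(key);  // potentially expensive, done lock-free
//   absl::MutexLock lock(&mutex_);
//   if (auto it = cache_.find(key); it != cache_.end()) return it->second;
//   return cache_[key] = std::move(value);
//
// Typical wiring on the caller side, based on the priority-fusion changes later
// in this patch (device_info is assumed to be a se::DeviceDescription):
//   HloFusionAnalysisCache cache(device_info);
//   auto config = GpuPerformanceModelOptions::PriorityFusion(&cache);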
+==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ +#define XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ + +#include "absl/container/node_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/stream_executor/device_description.h" + +namespace xla::gpu { + +// Caches HloFusionAnalyses. Thread-compatible, if no threads concurrently `Get` +// and `Invalidate` the same key. Analyses are cached based on pointer-identity, +// no checking of changes is done. +class HloFusionAnalysisCache { + public: + explicit HloFusionAnalysisCache( + const stream_executor::DeviceDescription& device_info) + : device_info_(device_info) {} + + // Returns the analysis for the given instruction, creating it if it doesn't + // exist yet. Do not call concurrently with `Invalidate` for the same key. + const std::optional& Get( + const HloInstruction& instruction); + + // Returns the analysis for the given producer/consumer pair. + const std::optional& Get(const HloInstruction& producer, + const HloInstruction& consumer); + + // Removes the cache entry for the given instruction, if it exists. Also + // removes all producer-consumer fusions that involve this instruction. + void Invalidate(const HloInstruction& instruction); + + private: + const stream_executor::DeviceDescription& device_info_; + + absl::Mutex mutex_; + absl::node_hash_map> + analyses_; + absl::node_hash_map, + std::optional> + producer_consumer_analyses_; + + // For each instruction `producer`, contains the `consumer`s for which we have + // entries {`producer`, `consumer`} in `producer_consumer_analyses_`. + absl::flat_hash_map> + consumers_for_producers_; + // For each instruction `consumer`, contains the `producer`s for which we have + // entries {`producer`, `consumer`} in `producer_consumer_analyses_`. + absl::flat_hash_map> + producers_for_consumers_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc new file mode 100644 index 00000000000000..edacd6a7c8666b --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc @@ -0,0 +1,115 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/model/fusion_analysis_cache.h" + +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/hlo_parser.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla::gpu { +namespace { + +class FusionAnalysisCacheTest : public HloTestBase { + public: + stream_executor::DeviceDescription device_{ + TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + HloFusionAnalysisCache cache_{device_}; +}; + +TEST_F(FusionAnalysisCacheTest, CachesAndInvalidates) { + absl::string_view hlo_string = R"( + HloModule m + + f { + c0 = f32[] constant(0) + b0 = f32[1000] broadcast(c0) + ROOT n0 = f32[1000] negate(b0) + } + + ENTRY e { + ROOT r.1 = f32[1000] fusion(), kind=kLoop, calls=f + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto* computation = module->GetComputationWithName("f"); + auto* broadcast = computation->GetInstructionWithName("b0"); + auto* negate = computation->GetInstructionWithName("n0"); + auto* fusion = module->entry_computation()->root_instruction(); + + EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), + ::testing::ElementsAre(negate)); + + computation->set_root_instruction(broadcast); + + EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), + ::testing::ElementsAre(negate)) + << "Analysis should be cached."; + + cache_.Invalidate(*fusion); + EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), + ::testing::ElementsAre(broadcast)) + << "Analysis should have been recomputed"; +} + +TEST_F(FusionAnalysisCacheTest, CachesAndInvalidatesProducerConsumerFusions) { + absl::string_view hlo_string = R"( + HloModule m + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + f { + c0 = f32[] constant(0) + b0 = f32[1000] broadcast(c0) + ROOT r0 = f32[] reduce(b0, c0), dimensions={0}, to_apply=add + } + + ENTRY e { + f0 = f32[] fusion(), kind=kInput, calls=f + ROOT n0 = f32[] negate(f0) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto* fusion = module->entry_computation()->GetInstructionWithName("f0"); + auto* neg = module->entry_computation()->GetInstructionWithName("n0"); + + auto* computation = module->GetComputationWithName("f"); + auto* constant = computation->GetInstructionWithName("c0"); + + EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), + HloFusionAnalysis::EmitterFusionKind::kReduction); + + computation->set_root_instruction(constant); + + EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), + HloFusionAnalysis::EmitterFusionKind::kReduction) + << "Analysis should be cached."; + + cache_.Invalidate(*fusion); + EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), + HloFusionAnalysis::EmitterFusionKind::kLoop) + << "Analysis should have been recomputed"; +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index 7d5045cef03337..0c320e7b40840b 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -224,7 +224,7 @@ float GetMaxSysBwFromGpu(const se::CudaComputeCapability cc, // that the IR emitter will use. 
LaunchDimensions EstimateFusionLaunchDimensions( int64_t estimated_num_threads, - std::optional& fusion_analysis, + const std::optional& fusion_analysis, const se::DeviceDescription& device_info) { if (fusion_analysis) { // TODO(jreiffers): This is the wrong place for this DUS analysis. @@ -269,7 +269,15 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( int64_t bytes_written = cost_analysis->output_bytes_accessed(*instr); int64_t bytes_read = cost_analysis->bytes_accessed(*instr) - bytes_written; - auto fusion_analysis = AnalyzeFusion(*instr, *cost_analysis->device_info_); + // Use the analysis cache if present. + // TODO(jreiffers): Remove this once all callers use a cache. + std::optional local_analysis = + config.fusion_analysis_cache + ? std::nullopt + : AnalyzeFusion(*instr, *cost_analysis->device_info_); + const auto& fusion_analysis = config.fusion_analysis_cache + ? config.fusion_analysis_cache->Get(*instr) + : local_analysis; LaunchDimensions launch_dimensions = EstimateFusionLaunchDimensions( ShapeUtil::ElementsInRecursive(instr->shape()), fusion_analysis, *device_info); @@ -303,7 +311,7 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( const GpuHloCostAnalysis* cost_analysis, const se::DeviceDescription& gpu_device_info, int64_t num_blocks, const HloInstruction* producer, - std::optional& fusion_analysis, + const std::optional& fusion_analysis, const GpuPerformanceModelOptions& config, const HloInstruction* fused_consumer) { absl::Duration ret = absl::ZeroDuration(); @@ -413,7 +421,16 @@ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( float utilization_by_this_consumer = cost_analysis->operand_utilization( *fused_consumer, fused_consumer->operand_index(producer)); - auto analysis_unfused = AnalyzeFusion(*fused_consumer, *device_info); + // Use the analysis cache if present. + // TODO(jreiffers): Remove this once all callers use a cache. + std::optional local_analysis = + config.fusion_analysis_cache + ? std::nullopt + : AnalyzeFusion(*fused_consumer, *device_info); + const auto& analysis_unfused = + config.fusion_analysis_cache + ? config.fusion_analysis_cache->Get(*fused_consumer) + : local_analysis; LaunchDimensions launch_dimensions_unfused = EstimateFusionLaunchDimensions( ShapeUtil::ElementsInRecursive(fused_consumer->shape()), @@ -462,8 +479,15 @@ absl::Duration GpuPerformanceModel::EstimateFusedExecTime( // // TODO(shyshkov): Add calculations for consumer epilogue in the formula to // make it complete. - auto analysis_fused = - AnalyzeProducerConsumerFusion(*producer, *fused_consumer, *device_info); + std::optional local_analysis_fused = + config.fusion_analysis_cache + ? std::nullopt + : AnalyzeProducerConsumerFusion(*producer, *fused_consumer, + *device_info); + const auto& analysis_fused = + config.fusion_analysis_cache + ? config.fusion_analysis_cache->Get(*producer, *fused_consumer) + : local_analysis_fused; LaunchDimensions launch_dimensions_fused = EstimateFusionLaunchDimensions( producer_data.num_threads * utilization_by_this_consumer, diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h index b7b28fff1eeda7..0fcc8cfcb2abf2 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "absl/time/time.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/stream_executor/device_description.h" @@ -62,20 +63,25 @@ struct GpuPerformanceModelOptions { // re-reads can happen from cache. bool first_read_from_dram = false; + // If present, use this to retrieve fusion analyses. + HloFusionAnalysisCache* fusion_analysis_cache = nullptr; + static GpuPerformanceModelOptions Default() { return GpuPerformanceModelOptions(); } - static GpuPerformanceModelOptions PriorityFusion() { + static GpuPerformanceModelOptions PriorityFusion( + HloFusionAnalysisCache* fusion_analysis_cache) { GpuPerformanceModelOptions config; config.consider_coalescing = true; config.first_read_from_dram = true; + config.fusion_analysis_cache = fusion_analysis_cache; return config; } static GpuPerformanceModelOptions ForModule(const HloModule* module) { return module->config().debug_options().xla_gpu_enable_priority_fusion() - ? PriorityFusion() + ? PriorityFusion(nullptr) // Only cache within priority fusion. : Default(); } }; @@ -121,7 +127,7 @@ class GpuPerformanceModel { const GpuHloCostAnalysis* cost_analysis, const se::DeviceDescription& gpu_device_info, int64_t num_blocks, const HloInstruction* producer, - std::optional& fusion_analysis, + const std::optional& fusion_analysis, const GpuPerformanceModelOptions& config, const HloInstruction* fused_consumer = nullptr); }; diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc index 68bde4b9010382..d768bce08c55ef 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc @@ -360,7 +360,7 @@ ENTRY fusion { std::vector consumers{ module->entry_computation()->GetInstructionWithName("reduce.1")}; GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - producer, &analysis_, GpuPerformanceModelOptions::PriorityFusion(), + producer, &analysis_, GpuPerformanceModelOptions::PriorityFusion(nullptr), consumers); EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 105, 10); diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 8df808c1bfad82..10fe7924d10c2c 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -42,6 +42,7 @@ limitations under the License. 
#include "xla/service/fusion_queue.h" #include "xla/service/gpu/fusion_process_dump.pb.h" #include "xla/service/gpu/gpu_fusible.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model.h" #include "xla/service/instruction_fusion.h" @@ -79,12 +80,14 @@ class GpuPriorityFusionQueue : public FusionQueue { const GpuHloCostAnalysis::Options& cost_analysis_options, const se::DeviceDescription* device_info, const CanFuseCallback& can_fuse, FusionProcessDumpProto* fusion_process_dump, - tsl::thread::ThreadPool* thread_pool) + tsl::thread::ThreadPool* thread_pool, + HloFusionAnalysisCache& fusion_analysis_cache) : computation_(computation), cost_analysis_(cost_analysis_options, device_info), can_fuse_(can_fuse), fusion_process_dump_(fusion_process_dump), - thread_pool_(thread_pool) { + thread_pool_(thread_pool), + fusion_analysis_cache_(fusion_analysis_cache) { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); TF_CHECK_OK(computation_->Accept(&cost_analysis_)); @@ -181,6 +184,9 @@ class GpuPriorityFusionQueue : public FusionQueue { fusion_step->set_consumer_name(std::string(original_consumer->name())); } + fusion_analysis_cache_.Invalidate(*original_producer); + fusion_analysis_cache_.Invalidate(*original_consumer); + // The original consumer was replaced with the fusion, but it's pointer can // still be referenced somewhere, for example, in to_update_priority_. // Priority recomputation is called before DCE. Remove all references to @@ -289,7 +295,8 @@ class GpuPriorityFusionQueue : public FusionQueue { GpuPerformanceModel::RunTimes run_times = GpuPerformanceModel::EstimateRunTimes( producer, &cost_analysis_, - GpuPerformanceModelOptions::PriorityFusion(), producer->users()); + GpuPerformanceModelOptions::PriorityFusion(&fusion_analysis_cache_), + producer->users()); if (fusion_process_dump_) { absl::MutexLock lock(&fusion_process_dump_mutex_); auto* step = @@ -365,6 +372,8 @@ class GpuPriorityFusionQueue : public FusionQueue { absl::Mutex fusion_process_dump_mutex_; tsl::thread::ThreadPool* thread_pool_; + + HloFusionAnalysisCache& fusion_analysis_cache_; }; } // namespace @@ -502,8 +511,7 @@ HloInstruction::FusionKind GpuPriorityFusion::ChooseKind( // matter but some passes downstream still query these instead of fusion // analysis. // TODO: Don't recompute this all the time. - auto analysis = - AnalyzeProducerConsumerFusion(*producer, *consumer, device_info_); + const auto& analysis = fusion_analysis_cache_.Get(*producer, *consumer); if (!analysis) return HloInstruction::FusionKind::kLoop; switch (analysis->GetEmitterFusionKind()) { case HloFusionAnalysis::EmitterFusionKind::kLoop: @@ -544,7 +552,7 @@ std::unique_ptr GpuPriorityFusion::GetFusionQueue( [this](HloInstruction* consumer, int64_t operand_index) { return ShouldFuse(consumer, operand_index); }, - fusion_process_dump_.get(), thread_pool_)); + fusion_process_dump_.get(), thread_pool_, fusion_analysis_cache_)); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/priority_fusion.h b/third_party/xla/xla/service/gpu/priority_fusion.h index 1723766d7784c8..afc5e8f99003d4 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/priority_fusion.h @@ -29,6 +29,7 @@ limitations under the License. 
#include "xla/service/fusion_node_indexing_evaluation.h" #include "xla/service/fusion_queue.h" #include "xla/service/gpu/fusion_process_dump.pb.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_interface.h" @@ -42,12 +43,13 @@ namespace gpu { class GpuPriorityFusion : public InstructionFusion { public: GpuPriorityFusion(tsl::thread::ThreadPool* thread_pool, - const se::DeviceDescription& d, + const se::DeviceDescription& device, GpuHloCostAnalysis::Options cost_analysis_options) : InstructionFusion(GpuPriorityFusion::IsExpensive), thread_pool_(thread_pool), - device_info_(d), - cost_analysis_options_(std::move(cost_analysis_options)) {} + device_info_(device), + cost_analysis_options_(std::move(cost_analysis_options)), + fusion_analysis_cache_(device_info_) {} absl::string_view name() const override { return "priority-fusion"; } @@ -86,6 +88,7 @@ class GpuPriorityFusion : public InstructionFusion { absl::Mutex fusion_node_evaluations_mutex_; absl::flat_hash_map fusion_node_evaluations_; + HloFusionAnalysisCache fusion_analysis_cache_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/priority_fusion_test.cc index 5574173a75ef11..310340af91d391 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion_test.cc @@ -237,10 +237,10 @@ TEST_F(PriorityFusionTest, ReductionEpilogueFusionRegressionTest) { } )"; - EXPECT_THAT( - RunAndGetFusionKinds(kHlo), - ::testing::ElementsAre(HloFusionAnalysis::EmitterFusionKind::kLoop, - HloFusionAnalysis::EmitterFusionKind::kReduction)); + EXPECT_THAT(RunAndGetFusionKinds(kHlo), + ::testing::UnorderedElementsAre( + HloFusionAnalysis::EmitterFusionKind::kLoop, + HloFusionAnalysis::EmitterFusionKind::kReduction)); RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"( CHECK: ENTRY From 0396789f3989c6a68c5c2a090109b6e20301464e Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 16 Nov 2023 06:44:08 -0800 Subject: [PATCH 167/391] Refactor FloatSupport to allow configuring the high precision type. PiperOrigin-RevId: 583026800 --- .../xla/xla/service/cpu/cpu_compiler.cc | 10 ++++---- .../xla/service/float_normalization_test.cc | 23 +++++++++++-------- third_party/xla/xla/service/float_support.h | 18 +++++---------- .../xla/xla/service/gpu/gpu_compiler.cc | 10 ++++---- .../xla/xla/service/gpu/gpu_float_support.h | 5 ++-- 5 files changed, 33 insertions(+), 33 deletions(-) diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 376da95f9e8837..7b51d877fc4b2f 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -713,15 +713,15 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( // BF16/F8 lowering for most ops. 
FloatSupport bf16_support(BF16); pipeline.AddPass(&bf16_support); - FloatSupport f8e5m2_support(F8E5M2); + FloatSupport f8e5m2_support(F8E5M2, F16); pipeline.AddPass(&f8e5m2_support); - FloatSupport f8e4m3fn_support(F8E4M3FN); + FloatSupport f8e4m3fn_support(F8E4M3FN, F16); pipeline.AddPass(&f8e4m3fn_support); - FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); + FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16); pipeline.AddPass(&f8e4m3b11fnuz_support); - FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); + FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ, F16); pipeline.AddPass(&f8e5m2fnuz_support); - FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); + FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ, F16); pipeline.AddPass(&f8e4m3fnuz_support); // After canonicalization, there may be more batch dots that can be // simplified. diff --git a/third_party/xla/xla/service/float_normalization_test.cc b/third_party/xla/xla/service/float_normalization_test.cc index 2d6a976ff59df4..fb6c133a997897 100644 --- a/third_party/xla/xla/service/float_normalization_test.cc +++ b/third_party/xla/xla/service/float_normalization_test.cc @@ -35,8 +35,9 @@ namespace xla { class TestFloatSupport : public FloatSupport { public: - explicit TestFloatSupport(PrimitiveType low_precision_type) - : FloatSupport(low_precision_type) {} + explicit TestFloatSupport(PrimitiveType low_precision_type, + PrimitiveType high_precision_type) + : FloatSupport(low_precision_type, high_precision_type) {} ~TestFloatSupport() override = default; bool SupportsLowPrecisionOperand(const HloInstruction& hlo, @@ -80,8 +81,9 @@ class TestFloatSupport : public FloatSupport { // but supports some collectives. class TestFloatNoComputeSupport : public FloatSupport { public: - explicit TestFloatNoComputeSupport(PrimitiveType low_precision_type) - : FloatSupport(low_precision_type) {} + explicit TestFloatNoComputeSupport(PrimitiveType low_precision_type, + PrimitiveType high_precision_type) + : FloatSupport(low_precision_type, high_precision_type) {} ~TestFloatNoComputeSupport() override = default; bool SupportsLowPrecisionOperand(const HloInstruction& hlo, @@ -114,8 +116,9 @@ class FloatNormalizationTest : public HloTestBase { : HloTestBase(/*verifier_layout_sensitive=*/false, /*allow_mixed_precision_in_hlo_verifier=*/true) {} - bool Normalize(HloModule* module, PrimitiveType low_precision_type = BF16) { - TestFloatSupport float_support(low_precision_type); + bool Normalize(HloModule* module, PrimitiveType low_precision_type = BF16, + PrimitiveType high_precision_type = F32) { + TestFloatSupport float_support(low_precision_type, high_precision_type); FloatNormalization normalization(&float_support); StatusOr result = normalization.Run(module); EXPECT_IS_OK(result.status()); @@ -508,7 +511,7 @@ TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e5m2) { auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get(), F8E5M2)); + EXPECT_TRUE(Normalize(module.get(), F8E5M2, F16)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), mul1); @@ -519,8 +522,10 @@ TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e5m2) { class FloatNormalizationNoComputeSupportTest : public FloatNormalizationTest { protected: - bool Normalize(HloModule* module, PrimitiveType low_precision_type = BF16) { - TestFloatNoComputeSupport float_support(low_precision_type); + bool Normalize(HloModule* module, PrimitiveType 
low_precision_type = BF16, + PrimitiveType high_precision_type = F32) { + TestFloatNoComputeSupport float_support(low_precision_type, + high_precision_type); FloatNormalization normalization(&float_support); StatusOr result = normalization.Run(module); diff --git a/third_party/xla/xla/service/float_support.h b/third_party/xla/xla/service/float_support.h index 9a2691be1f27f6..4ee4e7157fa4c5 100644 --- a/third_party/xla/xla/service/float_support.h +++ b/third_party/xla/xla/service/float_support.h @@ -27,8 +27,10 @@ namespace xla { // backend. class FloatSupport { public: - explicit FloatSupport(PrimitiveType low_precision_type) - : low_precision_type_(low_precision_type) {} + explicit FloatSupport(PrimitiveType low_precision_type, + PrimitiveType high_precision_type = F32) + : low_precision_type_(low_precision_type), + high_precision_type_(high_precision_type) {} virtual ~FloatSupport() = default; // The low-precision type. Callers can use this class to query whether the @@ -38,16 +40,7 @@ class FloatSupport { // A high-precision type that should be used in place of the low-precision // type if the backend does not support the low-precision type for a certain // instruction. - PrimitiveType HighPrecisionType() const { - if (low_precision_type_ == F8E5M2 || low_precision_type_ == F8E4M3FN || - low_precision_type_ == F8E4M3B11FNUZ || - low_precision_type_ == F8E5M2FNUZ || - low_precision_type_ == F8E4M3FNUZ) { - return F16; - } - DCHECK_EQ(low_precision_type_, BF16); - return F32; - } + PrimitiveType HighPrecisionType() const { return high_precision_type_; } // Returns whether the backend supports a low-precision operand for the HLO // instruction at the given index. @@ -82,6 +75,7 @@ class FloatSupport { private: PrimitiveType low_precision_type_; + PrimitiveType high_precision_type_; }; } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 551a386c203de8..924b75ac0193b7 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1014,11 +1014,11 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( gpu_target_config)); // Lambdas and related constants: const GpuFloatSupport bf16_support(BF16); - const GpuFloatSupport f8e5m2_support(F8E5M2); - const GpuFloatSupport f8e4m3fn_support(F8E4M3FN); - const FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ); - const FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ); - const FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ); + const GpuFloatSupport f8e5m2_support(F8E5M2, F16); + const GpuFloatSupport f8e4m3fn_support(F8E4M3FN, F16); + const FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16); + const FloatSupport f8e5m2fnuz_support(F8E5M2FNUZ, F16); + const FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ, F16); auto add_float_normalization = [&](HloPassPipeline& pipeline) { auto& sub_pipeline = pipeline.AddPass("float_normalization"); diff --git a/third_party/xla/xla/service/gpu/gpu_float_support.h b/third_party/xla/xla/service/gpu/gpu_float_support.h index c9e0e2ac0c48e7..2bdc64a739f85c 100644 --- a/third_party/xla/xla/service/gpu/gpu_float_support.h +++ b/third_party/xla/xla/service/gpu/gpu_float_support.h @@ -27,8 +27,9 @@ namespace gpu { class GpuFloatSupport : public FloatSupport { public: - explicit GpuFloatSupport(PrimitiveType low_precision_type) - : FloatSupport(low_precision_type) {} + explicit GpuFloatSupport(PrimitiveType low_precision_type, + PrimitiveType high_precision_type = F32) + : 
FloatSupport(low_precision_type, high_precision_type) {} bool SupportsLowPrecisionOperand(const HloInstruction& hlo, int64_t operand_index) const override { From 284506bb05fb4c824695747801ce7b054448215c Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Thu, 16 Nov 2023 06:47:03 -0800 Subject: [PATCH 168/391] Priority fusion: cache HloFusionAnalyses. ~33% HLO passes speedup. PiperOrigin-RevId: 583027442 --- third_party/xla/xla/service/gpu/BUILD | 1 - .../xla/service/gpu/hlo_fusion_analysis.cc | 123 +++++++++--------- .../xla/xla/service/gpu/hlo_fusion_analysis.h | 34 ++--- .../xla/service/gpu/kernel_mapping_scheme.h | 24 ++-- third_party/xla/xla/service/gpu/model/BUILD | 29 ----- .../gpu/model/fusion_analysis_cache.cc | 90 ------------- .../service/gpu/model/fusion_analysis_cache.h | 71 ---------- .../gpu/model/fusion_analysis_cache_test.cc | 115 ---------------- .../gpu/model/gpu_performance_model.cc | 36 +---- .../service/gpu/model/gpu_performance_model.h | 12 +- .../gpu/model/gpu_performance_model_test.cc | 2 +- .../xla/xla/service/gpu/priority_fusion.cc | 20 +-- .../xla/xla/service/gpu/priority_fusion.h | 9 +- .../xla/service/gpu/priority_fusion_test.cc | 8 +- 14 files changed, 110 insertions(+), 464 deletions(-) delete mode 100644 third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc delete mode 100644 third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h delete mode 100644 third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 8445dec5c78840..58ae615215ee3b 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2071,7 +2071,6 @@ cc_library( "//xla/service:hlo_cost_analysis", "//xla/service:hlo_pass", "//xla/service:instruction_fusion", - "//xla/service/gpu/model:fusion_analysis_cache", "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/service/gpu/model:gpu_performance_model", "//xla/stream_executor:device_description", diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index c064c8c9c1770c..bb2fe734a63055 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -58,35 +58,6 @@ const auto kDimX = TilingScheme::DimX; const auto kLinearIndexingX = TilingScheme::LinearIndexingX; const auto kStridedIndexingX = TilingScheme::StridedIndexingX; -std::optional ComputeTransposeTilingScheme( - const std::optional& tiled_transpose) { - if (!tiled_transpose) { - return std::nullopt; - } - - constexpr int kNumRows = 4; - static_assert(WarpSize() % kNumRows == 0); - - // 3D view over the input shape. - Vector3 dims = tiled_transpose->dimensions; - Vector3 order = tiled_transpose->permutation; - - Vector3 permuted_dims = {dims[order[0]], dims[order[1]], dims[order[2]]}; - Vector3 tile_sizes{1, 1, 1}; - tile_sizes[order[2]] = WarpSize() / kNumRows; - Vector3 num_threads{1, 1, WarpSize()}; - num_threads[order[2]] = kNumRows; - - return TilingScheme( - /*permuted_dims*/ permuted_dims, - /*tile_sizes=*/tile_sizes, - /*num_threads=*/num_threads, - /*indexing_order=*/kLinearIndexingX, - /*vector_size=*/1, - /*scaling_factor=*/1, - /*tiling_dimensions=*/{order[2], 2}); -} - // Returns true if `instr` is a non-strided slice. 
bool IsSliceWithUnitStrides(const HloInstruction* instr) { auto slice = DynCast(instr); @@ -286,28 +257,6 @@ std::optional FindConsistentTransposeHero( } // namespace -HloFusionAnalysis::HloFusionAnalysis( - FusionBackendConfig fusion_backend_config, - std::vector fusion_roots, - FusionBoundaryFn fusion_boundary_fn, - std::vector fusion_arguments, - std::vector fusion_heroes, - const se::DeviceDescription* device_info, - std::optional tiled_transpose, bool has_4_bit_input, - bool has_4_bit_output) - : fusion_backend_config_(std::move(fusion_backend_config)), - fusion_roots_(std::move(fusion_roots)), - fusion_boundary_fn_(std::move(fusion_boundary_fn)), - fusion_arguments_(std::move(fusion_arguments)), - fusion_heroes_(std::move(fusion_heroes)), - device_info_(device_info), - tiled_transpose_(tiled_transpose), - has_4_bit_input_(has_4_bit_input), - has_4_bit_output_(has_4_bit_output), - reduction_codegen_info_(ComputeReductionCodegenInfo(FindHeroReduction())), - transpose_tiling_scheme_(ComputeTransposeTilingScheme(tiled_transpose_)), - loop_fusion_config_(ComputeLoopFusionConfig()) {} - // static StatusOr HloFusionAnalysis::Create( FusionBackendConfig backend_config, @@ -404,7 +353,7 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() return EmitterFusionKind::kLoop; } -StatusOr HloFusionAnalysis::GetLaunchDimensions() const { +StatusOr HloFusionAnalysis::GetLaunchDimensions() { auto emitter_fusion_kind = GetEmitterFusionKind(); switch (emitter_fusion_kind) { case EmitterFusionKind::kLoop: { @@ -454,9 +403,7 @@ StatusOr HloFusionAnalysis::GetLaunchDimensions() const { } const HloInstruction* HloFusionAnalysis::FindHeroReduction() const { - if (GetEmitterFusionKind() != EmitterFusionKind::kReduction) { - return nullptr; - } + CHECK(GetEmitterFusionKind() == EmitterFusionKind::kReduction); auto roots = fusion_roots(); CHECK(!roots.empty()); // We always use the first reduce root that triggers unnested reduction @@ -471,8 +418,57 @@ const HloInstruction* HloFusionAnalysis::FindHeroReduction() const { LOG(FATAL) << "Did not find a hero reduction"; } -std::optional -HloFusionAnalysis::ComputeLoopFusionConfig() const { +const ReductionCodegenInfo* HloFusionAnalysis::GetReductionCodegenInfo() { + if (reduction_codegen_info_.has_value()) { + return &reduction_codegen_info_.value(); + } + + const HloInstruction* hero_reduction = FindHeroReduction(); + + auto reduction_codegen_info = ComputeReductionCodegenInfo(hero_reduction); + reduction_codegen_info_.emplace(std::move(reduction_codegen_info)); + return &reduction_codegen_info_.value(); +} + +const TilingScheme* HloFusionAnalysis::GetTransposeTilingScheme() { + if (transpose_tiling_scheme_.has_value()) { + return &transpose_tiling_scheme_.value(); + } + + if (!tiled_transpose_) { + return nullptr; + } + + constexpr int kNumRows = 4; + static_assert(WarpSize() % kNumRows == 0); + + // 3D view over the input shape. 
+ Vector3 dims = tiled_transpose_->dimensions; + Vector3 order = tiled_transpose_->permutation; + + Vector3 permuted_dims = {dims[order[0]], dims[order[1]], dims[order[2]]}; + Vector3 tile_sizes{1, 1, 1}; + tile_sizes[order[2]] = WarpSize() / kNumRows; + Vector3 num_threads{1, 1, WarpSize()}; + num_threads[order[2]] = kNumRows; + + TilingScheme tiling_scheme( + /*permuted_dims*/ permuted_dims, + /*tile_sizes=*/tile_sizes, + /*num_threads=*/num_threads, + /*indexing_order=*/kLinearIndexingX, + /*vector_size=*/1, + /*scaling_factor=*/1, + /*tiling_dimensions=*/{order[2], 2}); + transpose_tiling_scheme_.emplace(std::move(tiling_scheme)); + return &transpose_tiling_scheme_.value(); +} + +const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() { + if (loop_fusion_config_.has_value()) { + return &loop_fusion_config_.value(); + } + int unroll_factor = 1; // Unrolling is good to read large inputs with small elements // due to vector loads, but increases the register pressure when one @@ -505,7 +501,8 @@ HloFusionAnalysis::ComputeLoopFusionConfig() const { if (GetEmitterFusionKind() == EmitterFusionKind::kScatter) { // Only the unroll factor is used for scatter. - return LaunchDimensionsConfig{unroll_factor}; + loop_fusion_config_.emplace(LaunchDimensionsConfig{unroll_factor}); + return &loop_fusion_config_.value(); } bool row_vectorized; @@ -540,7 +537,8 @@ HloFusionAnalysis::ComputeLoopFusionConfig() const { launch_config.row_vectorized = false; launch_config.few_waves = false; } - return launch_config; + loop_fusion_config_.emplace(std::move(launch_config)); + return &loop_fusion_config_.value(); } const Shape& HloFusionAnalysis::GetElementShape() const { @@ -811,13 +809,8 @@ int HloFusionAnalysis::CalculateVirtualThreadScalingFactorForReduction( return 1; } -std::optional -HloFusionAnalysis::ComputeReductionCodegenInfo( +ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo( const HloInstruction* hero_reduction) const { - if (!hero_reduction) { - return std::nullopt; - } - Shape input_shape = hero_reduction->operand(0)->shape(); ReductionDimensions reduction_dimensions = GetReductionKindAndContiguousComponents(*hero_reduction); diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h index 1bec5ca650be47..c07819db2d3a15 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h @@ -68,27 +68,19 @@ class HloFusionAnalysis { // Determines the launch dimensions for the fusion. The fusion kind must not // be `kTriton`. - StatusOr GetLaunchDimensions() const; + StatusOr GetLaunchDimensions(); // Calculates the reduction information. Returns `nullptr` if the fusion is // not a reduction. - const ReductionCodegenInfo* GetReductionCodegenInfo() const { - return reduction_codegen_info_.has_value() ? &*reduction_codegen_info_ - : nullptr; - } + const ReductionCodegenInfo* GetReductionCodegenInfo(); // Calculates the transpose tiling information. Returns `nullptr` if the // fusion is not a transpose. - const TilingScheme* GetTransposeTilingScheme() const { - return transpose_tiling_scheme_.has_value() ? &*transpose_tiling_scheme_ - : nullptr; - } + const TilingScheme* GetTransposeTilingScheme(); // Calculates the loop fusion config. Returns `nullptr` if the fusion is not a // loop. - const LaunchDimensionsConfig* GetLoopFusionConfig() const { - return loop_fusion_config_.has_value() ? 
&*loop_fusion_config_ : nullptr; - } + const LaunchDimensionsConfig* GetLoopFusionConfig(); // Returns the hero reduction of the computation. const HloInstruction* FindHeroReduction() const; @@ -101,7 +93,16 @@ class HloFusionAnalysis { std::vector fusion_heroes, const se::DeviceDescription* device_info, std::optional tiled_transpose, - bool has_4_bit_input, bool has_4_bit_output); + bool has_4_bit_input, bool has_4_bit_output) + : fusion_backend_config_(std::move(fusion_backend_config)), + fusion_roots_(std::move(fusion_roots)), + fusion_boundary_fn_(std::move(fusion_boundary_fn)), + fusion_arguments_(std::move(fusion_arguments)), + fusion_heroes_(std::move(fusion_heroes)), + device_info_(device_info), + tiled_transpose_(tiled_transpose), + has_4_bit_input_(has_4_bit_input), + has_4_bit_output_(has_4_bit_output) {} const Shape& GetElementShape() const; int SmallestInputDtypeBits() const; @@ -117,9 +118,8 @@ class HloFusionAnalysis { bool reduction_is_race_free) const; int CalculateVirtualThreadScalingFactorForReduction( const ReductionDimensions& reduction_dimensions) const; - std::optional ComputeReductionCodegenInfo( + ReductionCodegenInfo ComputeReductionCodegenInfo( const HloInstruction* hero_reduction) const; - std::optional ComputeLoopFusionConfig() const; bool HasConsistentTransposeHeros() const; FusionBackendConfig fusion_backend_config_; @@ -131,8 +131,8 @@ class HloFusionAnalysis { std::vector fusion_heroes_; const se::DeviceDescription* device_info_; std::optional tiled_transpose_; - bool has_4_bit_input_ = false; - bool has_4_bit_output_ = false; + const bool has_4_bit_input_ = false; + const bool has_4_bit_output_ = false; std::optional reduction_codegen_info_; std::optional transpose_tiling_scheme_; diff --git a/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h b/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h index f7b51c42c6beaf..4a6f0f7ae3c6fa 100644 --- a/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h +++ b/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h @@ -146,34 +146,34 @@ class TilingScheme { private: // The number of elements in each dimension. - Vector3 dims_in_elems_; + const Vector3 dims_in_elems_; // The number of elements for each dimension of a tile. - Vector3 tile_sizes_; + const Vector3 tile_sizes_; // The dimensions which are used for the shared memory tile. - Vector2 tiling_dimensions_; + const Vector2 tiling_dimensions_; // Number of threads implicitly assigned to each dimension. - Vector3 num_threads_; + const Vector3 num_threads_; - IndexingOrder indexing_order_; + const IndexingOrder indexing_order_; // Vector size for dimension X. - int vector_size_; + const int vector_size_; // Scaling apply to transform physical threadIdx into logical. 
- int64_t thread_id_virtual_scaling_ = 1; + const int64_t thread_id_virtual_scaling_ = 1; }; class ReductionCodegenInfo { public: using IndexGroups = std::vector>; - ReductionCodegenInfo(TilingScheme mapping_scheme, int num_partial_results, - bool is_row_reduction, bool is_race_free, - IndexGroups index_groups, - const HloInstruction* first_reduce) + explicit ReductionCodegenInfo(TilingScheme mapping_scheme, + int num_partial_results, bool is_row_reduction, + bool is_race_free, IndexGroups index_groups, + const HloInstruction* first_reduce) : tiling_scheme_(mapping_scheme), num_partial_results_(num_partial_results), is_row_reduction_(is_row_reduction), @@ -198,7 +198,7 @@ class ReductionCodegenInfo { private: friend class ReductionCodegenState; - TilingScheme tiling_scheme_; + const TilingScheme tiling_scheme_; int num_partial_results_; bool is_row_reduction_; bool is_race_free_; diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 034f2a2f20d2ff..1b542e7dca447a 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -66,34 +66,6 @@ xla_test( ], ) -cc_library( - name = "fusion_analysis_cache", - srcs = ["fusion_analysis_cache.cc"], - hdrs = ["fusion_analysis_cache.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/synchronization", - ], -) - -xla_cc_test( - name = "fusion_analysis_cache_test", - srcs = ["fusion_analysis_cache_test.cc"], - deps = [ - ":fusion_analysis_cache", - "//xla/service:hlo_parser", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], -) - cc_library( name = "gpu_cost_model_stats_collection", srcs = ["gpu_cost_model_stats_collection.cc"], @@ -177,7 +149,6 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), visibility = ["//visibility:public"], deps = [ - ":fusion_analysis_cache", ":gpu_hlo_cost_analysis", "//xla:shape_util", "//xla/hlo/ir:hlo", diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc deleted file mode 100644 index 59e9499b3c6d62..00000000000000 --- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xla/service/gpu/model/fusion_analysis_cache.h" - -#include "xla/hlo/ir/hlo_instruction.h" - -namespace xla::gpu { - -const std::optional& HloFusionAnalysisCache::Get( - const HloInstruction& instruction) { - { - absl::ReaderMutexLock lock(&mutex_); - auto it = analyses_.find(&instruction); - if (it != analyses_.end()) { - return it->second; - } - } - - std::optional analysis = - AnalyzeFusion(instruction, device_info_); - absl::MutexLock lock(&mutex_); - - // If some other thread created an entry for this key concurrently, return - // that instead (the other thread is likely using the instance). - auto it = analyses_.find(&instruction); - if (it != analyses_.end()) { - return it->second; - } - - return analyses_[&instruction] = std::move(analysis); -} - -const std::optional& HloFusionAnalysisCache::Get( - const HloInstruction& producer, const HloInstruction& consumer) { - std::pair key{&producer, - &consumer}; - { - absl::ReaderMutexLock lock(&mutex_); - auto it = producer_consumer_analyses_.find(key); - if (it != producer_consumer_analyses_.end()) { - return it->second; - } - } - - std::optional analysis = - AnalyzeProducerConsumerFusion(producer, consumer, device_info_); - absl::MutexLock lock(&mutex_); - - // If some other thread created an entry for this key concurrently, return - // that instead (the other thread is likely using the instance). - auto it = producer_consumer_analyses_.find(key); - if (it != producer_consumer_analyses_.end()) { - return it->second; - } - - producers_for_consumers_[&consumer].push_back(&producer); - consumers_for_producers_[&producer].push_back(&consumer); - return producer_consumer_analyses_[key] = std::move(analysis); -} - -void HloFusionAnalysisCache::Invalidate(const HloInstruction& instruction) { - absl::MutexLock lock(&mutex_); - analyses_.erase(&instruction); - - if (auto consumers = consumers_for_producers_.extract(&instruction)) { - for (const auto* consumer : consumers.mapped()) { - producer_consumer_analyses_.erase({&instruction, consumer}); - } - } - if (auto producers = producers_for_consumers_.extract(&instruction)) { - for (const auto* producer : producers.mapped()) { - producer_consumer_analyses_.erase({producer, &instruction}); - } - } -} - -} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h deleted file mode 100644 index f21a4fdee0fcad..00000000000000 --- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ -#define XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ - -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/stream_executor/device_description.h" - -namespace xla::gpu { - -// Caches HloFusionAnalyses. Thread-compatible, if no threads concurrently `Get` -// and `Invalidate` the same key. Analyses are cached based on pointer-identity, -// no checking of changes is done. -class HloFusionAnalysisCache { - public: - explicit HloFusionAnalysisCache( - const stream_executor::DeviceDescription& device_info) - : device_info_(device_info) {} - - // Returns the analysis for the given instruction, creating it if it doesn't - // exist yet. Do not call concurrently with `Invalidate` for the same key. - const std::optional& Get( - const HloInstruction& instruction); - - // Returns the analysis for the given producer/consumer pair. - const std::optional& Get(const HloInstruction& producer, - const HloInstruction& consumer); - - // Removes the cache entry for the given instruction, if it exists. Also - // removes all producer-consumer fusions that involve this instruction. - void Invalidate(const HloInstruction& instruction); - - private: - const stream_executor::DeviceDescription& device_info_; - - absl::Mutex mutex_; - absl::node_hash_map> - analyses_; - absl::node_hash_map, - std::optional> - producer_consumer_analyses_; - - // For each instruction `producer`, contains the `consumer`s for which we have - // entries {`producer`, `consumer`} in `producer_consumer_analyses_`. - absl::flat_hash_map> - consumers_for_producers_; - // For each instruction `consumer`, contains the `producer`s for which we have - // entries {`producer`, `consumer`} in `producer_consumer_analyses_`. - absl::flat_hash_map> - producers_for_consumers_; -}; - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc deleted file mode 100644 index edacd6a7c8666b..00000000000000 --- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xla/service/gpu/model/fusion_analysis_cache.h" - -#include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/hlo_parser.h" -#include "xla/stream_executor/device_description.h" -#include "xla/tests/hlo_test_base.h" - -namespace xla::gpu { -namespace { - -class FusionAnalysisCacheTest : public HloTestBase { - public: - stream_executor::DeviceDescription device_{ - TestGpuDeviceInfo::RTXA6000DeviceInfo()}; - HloFusionAnalysisCache cache_{device_}; -}; - -TEST_F(FusionAnalysisCacheTest, CachesAndInvalidates) { - absl::string_view hlo_string = R"( - HloModule m - - f { - c0 = f32[] constant(0) - b0 = f32[1000] broadcast(c0) - ROOT n0 = f32[1000] negate(b0) - } - - ENTRY e { - ROOT r.1 = f32[1000] fusion(), kind=kLoop, calls=f - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); - - auto* computation = module->GetComputationWithName("f"); - auto* broadcast = computation->GetInstructionWithName("b0"); - auto* negate = computation->GetInstructionWithName("n0"); - auto* fusion = module->entry_computation()->root_instruction(); - - EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), - ::testing::ElementsAre(negate)); - - computation->set_root_instruction(broadcast); - - EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), - ::testing::ElementsAre(negate)) - << "Analysis should be cached."; - - cache_.Invalidate(*fusion); - EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), - ::testing::ElementsAre(broadcast)) - << "Analysis should have been recomputed"; -} - -TEST_F(FusionAnalysisCacheTest, CachesAndInvalidatesProducerConsumerFusions) { - absl::string_view hlo_string = R"( - HloModule m - - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - - f { - c0 = f32[] constant(0) - b0 = f32[1000] broadcast(c0) - ROOT r0 = f32[] reduce(b0, c0), dimensions={0}, to_apply=add - } - - ENTRY e { - f0 = f32[] fusion(), kind=kInput, calls=f - ROOT n0 = f32[] negate(f0) - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); - - auto* fusion = module->entry_computation()->GetInstructionWithName("f0"); - auto* neg = module->entry_computation()->GetInstructionWithName("n0"); - - auto* computation = module->GetComputationWithName("f"); - auto* constant = computation->GetInstructionWithName("c0"); - - EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), - HloFusionAnalysis::EmitterFusionKind::kReduction); - - computation->set_root_instruction(constant); - - EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), - HloFusionAnalysis::EmitterFusionKind::kReduction) - << "Analysis should be cached."; - - cache_.Invalidate(*fusion); - EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), - HloFusionAnalysis::EmitterFusionKind::kLoop) - << "Analysis should have been recomputed"; -} - -} // namespace -} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index 0c320e7b40840b..7d5045cef03337 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -224,7 +224,7 @@ float GetMaxSysBwFromGpu(const se::CudaComputeCapability cc, // that the IR emitter will use. 
LaunchDimensions EstimateFusionLaunchDimensions( int64_t estimated_num_threads, - const std::optional& fusion_analysis, + std::optional& fusion_analysis, const se::DeviceDescription& device_info) { if (fusion_analysis) { // TODO(jreiffers): This is the wrong place for this DUS analysis. @@ -269,15 +269,7 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( int64_t bytes_written = cost_analysis->output_bytes_accessed(*instr); int64_t bytes_read = cost_analysis->bytes_accessed(*instr) - bytes_written; - // Use the analysis cache if present. - // TODO(jreiffers): Remove this once all callers use a cache. - std::optional local_analysis = - config.fusion_analysis_cache - ? std::nullopt - : AnalyzeFusion(*instr, *cost_analysis->device_info_); - const auto& fusion_analysis = config.fusion_analysis_cache - ? config.fusion_analysis_cache->Get(*instr) - : local_analysis; + auto fusion_analysis = AnalyzeFusion(*instr, *cost_analysis->device_info_); LaunchDimensions launch_dimensions = EstimateFusionLaunchDimensions( ShapeUtil::ElementsInRecursive(instr->shape()), fusion_analysis, *device_info); @@ -311,7 +303,7 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( const GpuHloCostAnalysis* cost_analysis, const se::DeviceDescription& gpu_device_info, int64_t num_blocks, const HloInstruction* producer, - const std::optional& fusion_analysis, + std::optional& fusion_analysis, const GpuPerformanceModelOptions& config, const HloInstruction* fused_consumer) { absl::Duration ret = absl::ZeroDuration(); @@ -421,16 +413,7 @@ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( float utilization_by_this_consumer = cost_analysis->operand_utilization( *fused_consumer, fused_consumer->operand_index(producer)); - // Use the analysis cache if present. - // TODO(jreiffers): Remove this once all callers use a cache. - std::optional local_analysis = - config.fusion_analysis_cache - ? std::nullopt - : AnalyzeFusion(*fused_consumer, *device_info); - const auto& analysis_unfused = - config.fusion_analysis_cache - ? config.fusion_analysis_cache->Get(*fused_consumer) - : local_analysis; + auto analysis_unfused = AnalyzeFusion(*fused_consumer, *device_info); LaunchDimensions launch_dimensions_unfused = EstimateFusionLaunchDimensions( ShapeUtil::ElementsInRecursive(fused_consumer->shape()), @@ -479,15 +462,8 @@ absl::Duration GpuPerformanceModel::EstimateFusedExecTime( // // TODO(shyshkov): Add calculations for consumer epilogue in the formula to // make it complete. - std::optional local_analysis_fused = - config.fusion_analysis_cache - ? std::nullopt - : AnalyzeProducerConsumerFusion(*producer, *fused_consumer, - *device_info); - const auto& analysis_fused = - config.fusion_analysis_cache - ? config.fusion_analysis_cache->Get(*producer, *fused_consumer) - : local_analysis_fused; + auto analysis_fused = + AnalyzeProducerConsumerFusion(*producer, *fused_consumer, *device_info); LaunchDimensions launch_dimensions_fused = EstimateFusionLaunchDimensions( producer_data.num_threads * utilization_by_this_consumer, diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h index 0fcc8cfcb2abf2..b7b28fff1eeda7 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h @@ -22,7 +22,6 @@ limitations under the License. 
#include "absl/time/time.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/stream_executor/device_description.h" @@ -63,25 +62,20 @@ struct GpuPerformanceModelOptions { // re-reads can happen from cache. bool first_read_from_dram = false; - // If present, use this to retrieve fusion analyses. - HloFusionAnalysisCache* fusion_analysis_cache = nullptr; - static GpuPerformanceModelOptions Default() { return GpuPerformanceModelOptions(); } - static GpuPerformanceModelOptions PriorityFusion( - HloFusionAnalysisCache* fusion_analysis_cache) { + static GpuPerformanceModelOptions PriorityFusion() { GpuPerformanceModelOptions config; config.consider_coalescing = true; config.first_read_from_dram = true; - config.fusion_analysis_cache = fusion_analysis_cache; return config; } static GpuPerformanceModelOptions ForModule(const HloModule* module) { return module->config().debug_options().xla_gpu_enable_priority_fusion() - ? PriorityFusion(nullptr) // Only cache within priority fusion. + ? PriorityFusion() : Default(); } }; @@ -127,7 +121,7 @@ class GpuPerformanceModel { const GpuHloCostAnalysis* cost_analysis, const se::DeviceDescription& gpu_device_info, int64_t num_blocks, const HloInstruction* producer, - const std::optional& fusion_analysis, + std::optional& fusion_analysis, const GpuPerformanceModelOptions& config, const HloInstruction* fused_consumer = nullptr); }; diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc index d768bce08c55ef..68bde4b9010382 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc @@ -360,7 +360,7 @@ ENTRY fusion { std::vector consumers{ module->entry_computation()->GetInstructionWithName("reduce.1")}; GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - producer, &analysis_, GpuPerformanceModelOptions::PriorityFusion(nullptr), + producer, &analysis_, GpuPerformanceModelOptions::PriorityFusion(), consumers); EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 105, 10); diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 10fe7924d10c2c..8df808c1bfad82 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -42,7 +42,6 @@ limitations under the License. 
#include "xla/service/fusion_queue.h" #include "xla/service/gpu/fusion_process_dump.pb.h" #include "xla/service/gpu/gpu_fusible.h" -#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model.h" #include "xla/service/instruction_fusion.h" @@ -80,14 +79,12 @@ class GpuPriorityFusionQueue : public FusionQueue { const GpuHloCostAnalysis::Options& cost_analysis_options, const se::DeviceDescription* device_info, const CanFuseCallback& can_fuse, FusionProcessDumpProto* fusion_process_dump, - tsl::thread::ThreadPool* thread_pool, - HloFusionAnalysisCache& fusion_analysis_cache) + tsl::thread::ThreadPool* thread_pool) : computation_(computation), cost_analysis_(cost_analysis_options, device_info), can_fuse_(can_fuse), fusion_process_dump_(fusion_process_dump), - thread_pool_(thread_pool), - fusion_analysis_cache_(fusion_analysis_cache) { + thread_pool_(thread_pool) { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); TF_CHECK_OK(computation_->Accept(&cost_analysis_)); @@ -184,9 +181,6 @@ class GpuPriorityFusionQueue : public FusionQueue { fusion_step->set_consumer_name(std::string(original_consumer->name())); } - fusion_analysis_cache_.Invalidate(*original_producer); - fusion_analysis_cache_.Invalidate(*original_consumer); - // The original consumer was replaced with the fusion, but it's pointer can // still be referenced somewhere, for example, in to_update_priority_. // Priority recomputation is called before DCE. Remove all references to @@ -295,8 +289,7 @@ class GpuPriorityFusionQueue : public FusionQueue { GpuPerformanceModel::RunTimes run_times = GpuPerformanceModel::EstimateRunTimes( producer, &cost_analysis_, - GpuPerformanceModelOptions::PriorityFusion(&fusion_analysis_cache_), - producer->users()); + GpuPerformanceModelOptions::PriorityFusion(), producer->users()); if (fusion_process_dump_) { absl::MutexLock lock(&fusion_process_dump_mutex_); auto* step = @@ -372,8 +365,6 @@ class GpuPriorityFusionQueue : public FusionQueue { absl::Mutex fusion_process_dump_mutex_; tsl::thread::ThreadPool* thread_pool_; - - HloFusionAnalysisCache& fusion_analysis_cache_; }; } // namespace @@ -511,7 +502,8 @@ HloInstruction::FusionKind GpuPriorityFusion::ChooseKind( // matter but some passes downstream still query these instead of fusion // analysis. // TODO: Don't recompute this all the time. - const auto& analysis = fusion_analysis_cache_.Get(*producer, *consumer); + auto analysis = + AnalyzeProducerConsumerFusion(*producer, *consumer, device_info_); if (!analysis) return HloInstruction::FusionKind::kLoop; switch (analysis->GetEmitterFusionKind()) { case HloFusionAnalysis::EmitterFusionKind::kLoop: @@ -552,7 +544,7 @@ std::unique_ptr GpuPriorityFusion::GetFusionQueue( [this](HloInstruction* consumer, int64_t operand_index) { return ShouldFuse(consumer, operand_index); }, - fusion_process_dump_.get(), thread_pool_, fusion_analysis_cache_)); + fusion_process_dump_.get(), thread_pool_)); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/priority_fusion.h b/third_party/xla/xla/service/gpu/priority_fusion.h index afc5e8f99003d4..1723766d7784c8 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/priority_fusion.h @@ -29,7 +29,6 @@ limitations under the License. 
#include "xla/service/fusion_node_indexing_evaluation.h" #include "xla/service/fusion_queue.h" #include "xla/service/gpu/fusion_process_dump.pb.h" -#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_interface.h" @@ -43,13 +42,12 @@ namespace gpu { class GpuPriorityFusion : public InstructionFusion { public: GpuPriorityFusion(tsl::thread::ThreadPool* thread_pool, - const se::DeviceDescription& device, + const se::DeviceDescription& d, GpuHloCostAnalysis::Options cost_analysis_options) : InstructionFusion(GpuPriorityFusion::IsExpensive), thread_pool_(thread_pool), - device_info_(device), - cost_analysis_options_(std::move(cost_analysis_options)), - fusion_analysis_cache_(device_info_) {} + device_info_(d), + cost_analysis_options_(std::move(cost_analysis_options)) {} absl::string_view name() const override { return "priority-fusion"; } @@ -88,7 +86,6 @@ class GpuPriorityFusion : public InstructionFusion { absl::Mutex fusion_node_evaluations_mutex_; absl::flat_hash_map fusion_node_evaluations_; - HloFusionAnalysisCache fusion_analysis_cache_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/priority_fusion_test.cc index 310340af91d391..5574173a75ef11 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion_test.cc @@ -237,10 +237,10 @@ TEST_F(PriorityFusionTest, ReductionEpilogueFusionRegressionTest) { } )"; - EXPECT_THAT(RunAndGetFusionKinds(kHlo), - ::testing::UnorderedElementsAre( - HloFusionAnalysis::EmitterFusionKind::kLoop, - HloFusionAnalysis::EmitterFusionKind::kReduction)); + EXPECT_THAT( + RunAndGetFusionKinds(kHlo), + ::testing::ElementsAre(HloFusionAnalysis::EmitterFusionKind::kLoop, + HloFusionAnalysis::EmitterFusionKind::kReduction)); RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"( CHECK: ENTRY From 2b65d5005c64aac5d7800cc88c9984b34d2c5e3b Mon Sep 17 00:00:00 2001 From: Tristen Allen Date: Thu, 16 Nov 2023 07:26:48 -0800 Subject: [PATCH 169/391] Update test_util decorators to be either decorators or decorator factories (not both). This change modifies most of the decorator functions in `test_util.py` to make them behave consistently as either decorators or as decorator factories. That is, the modified functions no longer conditionally return a decorator or a decorated function depending on if the passed-in function is `None`. If the decorator accepted optional parameters other than the function, it was updated to be a decorator factory (a function which returns a decorator). If it did not have any additional parameters, it was updated to only work as a decorator. For example: ``` def foo(config=True, func=None): def decorator(f): def decorated(*args, **kwargs): ... return f(*args, **kwargs) return decorated if func is not None: return decorator(func) return decorator ``` would now be: ``` def foo(config=True): def decorator(f): def decorated(*args, **kwargs): ... return f(*args, **kwargs) return decorated return decorator ``` while: ``` def bar(func=None): def decorator(f): def decorated(*args, **kwargs): ... return f(*args, **kwargs) return decorated if func is not None: return decorator(func) return decorator ``` became: ``` def bar(func): def decorated(*args, **kwargs): ... 
return func(*args, **kwargs) return decorated ``` This change also modifies internal TensorFlow code which used these decorators, if necessary, in order to comply with the changes. PiperOrigin-RevId: 583036777 --- .../data/kernel_tests/placement_test.py | 8 +- tensorflow/python/framework/test_util.py | 378 +++++++----------- .../image_ops/draw_bounding_box_op_test.py | 2 +- .../sparse_ops/sparse_xent_op_test_base.py | 2 +- .../registration/registration_saving_test.py | 2 +- 5 files changed, 154 insertions(+), 238 deletions(-) diff --git a/tensorflow/python/data/kernel_tests/placement_test.py b/tensorflow/python/data/kernel_tests/placement_test.py index 6c9efc53f2486a..35f929737c6ff6 100644 --- a/tensorflow/python/data/kernel_tests/placement_test.py +++ b/tensorflow/python/data/kernel_tests/placement_test.py @@ -198,7 +198,7 @@ def create_iter(): create_iter() @combinations.generate(test_base.graph_only_combinations()) - @test_util.run_gpu_only() + @test_util.run_gpu_only def testIteratorOnDeviceGraphModeOneShotIterator(self): self.skipTest("TODO(b/169429285): tf.data.Dataset.make_one_shot_iterator " "does not support GPU placement.") @@ -230,7 +230,7 @@ def testIteratorOnDeviceGraphModeOneShotIterator(self): self.assertIn(b"GPU:0", self.evaluate(has_value_device)) @combinations.generate(test_base.graph_only_combinations()) - @test_util.run_gpu_only() + @test_util.run_gpu_only def testIteratorOnDeviceGraphModeInitializableIterator(self): dataset = dataset_ops.Dataset.range(10) dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0")) @@ -259,7 +259,7 @@ def testIteratorOnDeviceGraphModeInitializableIterator(self): self.assertIn(b"GPU:0", self.evaluate(has_value_device)) @combinations.generate(test_base.eager_only_combinations()) - @test_util.run_gpu_only() + @test_util.run_gpu_only def testIterDatasetEagerModeWithExplicitDevice(self): @def_function.function @@ -274,7 +274,7 @@ def comp(): self.assertEqual(result.numpy(), 45) @combinations.generate(test_base.eager_only_combinations()) - @test_util.run_gpu_only() + @test_util.run_gpu_only def testFunctionInliningColocation(self): @def_function.function diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 25d3330bed779c..3ff5963a274853 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1273,7 +1273,7 @@ def wrapper(*args, **kwargs): return wrapper -def add_graph_building_optimization_tests(cls=None): +def add_graph_building_optimization_tests(cls): """Adds methods with graph_building_optimization enabled to the test suite. Example: @@ -1302,22 +1302,16 @@ def testBarWithGraphBuildingOptimization(self): cls with new test methods added. 
""" - def decorator(cls): - if flags.config().graph_building_optimization.value(): - return cls - - for name, value in cls.__dict__.copy().items(): - if (callable(value) and - (name.startswith(unittest.TestLoader.testMethodPrefix) or - name.startswith("benchmark"))): - setattr(cls, name + "WithGraphBuildingOptimization", - enable_graph_building_optimization(value)) + if flags.config().graph_building_optimization.value(): return cls - if cls is not None: - return decorator(cls) - - return decorator + for name, value in cls.__dict__.copy().items(): + if (callable(value) and + (name.startswith(unittest.TestLoader.testMethodPrefix) or + name.startswith("benchmark"))): + setattr(cls, name + "WithGraphBuildingOptimization", + enable_graph_building_optimization(value)) + return cls def disable_eager_op_as_function(unused_msg): @@ -1334,7 +1328,7 @@ def disable_eager_op_as_function(unused_msg): return _disable_test(execute_func=False) -def set_xla_env_flag(func=None, flag=""): +def set_xla_env_flag(flag=""): """Decorator for setting XLA_FLAGS prior to running a test. This function returns a decorator intended to be applied to test methods in @@ -1351,11 +1345,11 @@ def testFoo(self): ... Args: - func: The function to be wrapped. flag: The xla flag to be set in the XLA_FLAGS env variable. Returns: - The wrapped function. + A decorator which sets the configured flag in XLA_FLAGS for the decorated + function. """ def decorator(f): @@ -1377,13 +1371,10 @@ def decorated(*args, **kwargs): return decorated - if func is not None: - return decorator(func) - return decorator -def build_as_function_and_v1_graph(func=None): +def build_as_function_and_v1_graph(func): """Run a test case in v1 graph mode and inside tf.function in eager mode. WARNING: This decorator can only be used in test cases that statically checks @@ -1400,39 +1391,33 @@ def build_as_function_and_v1_graph(func=None): Decorated test case function. """ - def decorator(f): - if tf_inspect.isclass(f): - raise ValueError( - "`run_in_graph_mode_and_function` only supports test methods.") - - @parameterized.named_parameters(("_v1_graph", "v1_graph"), - ("_function", "function")) - @functools.wraps(f) - def decorated(self, run_mode, *args, **kwargs): - if run_mode == "v1_graph": - with ops.Graph().as_default(): - f(self, *args, **kwargs) - elif run_mode == "function": + if tf_inspect.isclass(func): + raise ValueError( + "`run_in_graph_mode_and_function` only supports test methods.") - @def_function.function - def function_in_eager(): - f(self, *args, **kwargs) + @parameterized.named_parameters(("_v1_graph", "v1_graph"), + ("_function", "function")) + @functools.wraps(func) + def decorated(self, run_mode, *args, **kwargs): + if run_mode == "v1_graph": + with ops.Graph().as_default(): + func(self, *args, **kwargs) + elif run_mode == "function": - # Create a new graph for the eagerly executed version of this test for - # better isolation. - graph_for_eager_test = ops.Graph() - with graph_for_eager_test.as_default(), context.eager_mode(): - function_in_eager() - ops.dismantle_graph(graph_for_eager_test) - else: - raise ValueError("Unknown run mode %s" % run_mode) + @def_function.function + def function_in_eager(): + func(self, *args, **kwargs) - return decorated - - if func is not None: - return decorator(func) + # Create a new graph for the eagerly executed version of this test for + # better isolation. 
+ graph_for_eager_test = ops.Graph() + with graph_for_eager_test.as_default(), context.eager_mode(): + function_in_eager() + ops.dismantle_graph(graph_for_eager_test) + else: + raise ValueError("Unknown run mode %s" % run_mode) - return decorator + return decorated def run_in_async_and_sync_mode(f): @@ -1573,17 +1558,13 @@ def run_eagerly(self, **kwargs): return decorator -def run_in_v1_v2(func=None, - device_to_use: str = None, +def run_in_v1_v2(device_to_use: str = None, assert_no_eager_garbage: bool = False): """Execute the decorated test in v1 and v2 modes. The overall execution is similar to that of `run_in_graph_and_eager_mode`. Args: - func: A test function/method to be decorated. If `func` is None, this method - returns a decorator the can be applied to a function. Otherwise, an - already applied decorator is returned. device_to_use: A string in the following format: "/device:CPU:0". assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage collector and asserts that no extra garbage has been created when running @@ -1600,12 +1581,11 @@ def run_in_v1_v2(func=None, A decorator that runs a given test in v1 and v2 modes. """ - decorator_tag = "wrapped_with_v1_v2_decorator" - if hasattr(func, decorator_tag): - # Already decorated with this very same decorator - return func - def decorator(f): + decorator_tag = "wrapped_with_v1_v2_decorator" + if hasattr(f, decorator_tag): + # Already decorated with this very same decorator + return f def decorated(self, *args, **kwargs): logging.info("Running %s in V1 mode.", f.__name__) @@ -1644,9 +1624,6 @@ def run_v2(self, **kwargs): tf_decorated.__dict__[decorator_tag] = True return tf_decorated - if func is not None: - return decorator(func) - return decorator @@ -1709,53 +1686,48 @@ def bound_f(): return decorated -def deprecated_graph_mode_only(func=None): +def deprecated_graph_mode_only(func): """Execute the decorated test in graph mode. - This function returns a decorator intended to be applied to tests that are not - compatible with eager mode. When this decorator is applied, the test body will - be run in an environment where API calls construct graphs instead of executing - eagerly. + This is a decorator intended to be applied to tests that are not compatible + with eager mode. When this decorator is applied, the test body will be run in + an environment where API calls construct graphs instead of executing eagerly. `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and `run_in_graph_and_eager_modes` are available decorators for different v1/v2/eager/graph combinations. Args: - func: function to be annotated. If `func` is None, this method returns a - decorator the can be applied to a function. If `func` is not None this - returns the decorator applied to `func`. + func: function or class to be annotated. + If `func` is a function this returns the decorator applied to `func`. + If `func` is a unit test class this returns that class with the decorator + applied to all test functions within that class. Returns: - Returns a decorator that will run the decorated test method in graph mode. + Returns a function or class that will run the decorated test(s) + in graph mode. 
""" - def decorator(f): - if tf_inspect.isclass(f): - setup = f.__dict__.get("setUp") - if setup is not None: - setattr(f, "setUp", decorator(setup)) - - for name, value in f.__dict__.copy().items(): - if (callable(value) and - name.startswith(unittest.TestLoader.testMethodPrefix)): - setattr(f, name, decorator(value)) - - return f + if tf_inspect.isclass(func): + setup = func.__dict__.get("setUp") + if setup is not None: + setattr(func, "setUp", deprecated_graph_mode_only(setup)) - def decorated(*args, **kwargs): - if context.executing_eagerly(): - with context.graph_mode(): - return f(*args, **kwargs) - else: - return f(*args, **kwargs) + for name, value in func.__dict__.copy().items(): + if (callable(value) and + name.startswith(unittest.TestLoader.testMethodPrefix)): + setattr(func, name, deprecated_graph_mode_only(value)) - return tf_decorator.make_decorator(f, decorated) + return func - if func is not None: - return decorator(func) + def decorated(*args, **kwargs): + if context.executing_eagerly(): + with context.graph_mode(): + return func(*args, **kwargs) + else: + return func(*args, **kwargs) - return decorator + return tf_decorator.make_decorator(func, decorated) run_deprecated_v1 = deprecated_graph_mode_only @@ -1847,73 +1819,57 @@ def run_v2_only(func=None, reason=None): return _run_vn_only(func=func, v2=True, reason=reason) -def run_gpu_only(func=None): +def run_gpu_only(func): """Execute the decorated test only if a GPU is available. This function is intended to be applied to tests that require the presence of a GPU. If a GPU is absent, it will simply be skipped. Args: - func: function to be annotated. If `func` is None, this method returns a - decorator the can be applied to a function. If `func` is not None this - returns the decorator applied to `func`. + func: function to be annotated. Returns: - Returns a decorator that will conditionally skip the decorated test method. + Returns a function that will conditionally skip the decorated test method. """ - def decorator(f): - if tf_inspect.isclass(f): - raise ValueError("`run_gpu_only` only supports test methods.") - - def decorated(self, *args, **kwargs): - if not is_gpu_available(): - self.skipTest("Test requires GPU") + if tf_inspect.isclass(func): + raise ValueError("`run_gpu_only` only supports test methods.") - return f(self, *args, **kwargs) + def decorated(self, *args, **kwargs): + if not is_gpu_available(): + self.skipTest("Test requires GPU") - return decorated + return func(self, *args, **kwargs) - if func is not None: - return decorator(func) - - return decorator + return decorated -def run_cuda_only(func=None): +def run_cuda_only(func): """Execute the decorated test only if a GPU is available. This function is intended to be applied to tests that require the presence of a CUDA GPU. If a CUDA GPU is absent, it will simply be skipped. Args: - func: function to be annotated. If `func` is None, this method returns a - decorator the can be applied to a function. If `func` is not None this - returns the decorator applied to `func`. + func: function to be annotated. Returns: - Returns a decorator that will conditionally skip the decorated test method. + Returns a function that will conditionally skip the decorated test method. 
""" - def decorator(f): - if tf_inspect.isclass(f): - raise ValueError("`run_cuda_only` only supports test methods.") - - def decorated(self, *args, **kwargs): - if not is_gpu_available(cuda_only=True): - self.skipTest("Test requires CUDA GPU") + if tf_inspect.isclass(func): + raise ValueError("`run_cuda_only` only supports test methods.") - return f(self, *args, **kwargs) + def decorated(self, *args, **kwargs): + if not is_gpu_available(cuda_only=True): + self.skipTest("Test requires CUDA GPU") - return decorated + return func(self, *args, **kwargs) - if func is not None: - return decorator(func) - - return decorator + return decorated -def run_gpu_or_tpu(func=None): +def run_gpu_or_tpu(func): """Execute the decorated test only if a physical GPU or TPU is available. This function is intended to be applied to tests that require the presence @@ -1923,30 +1879,25 @@ def run_gpu_or_tpu(func=None): - If both GPU and TPU are absent, the test will be skipped. Args: - func: function to be annotated. If `func` is None, this method returns a - decorator the can be applied to a function. If `func` is not None this - returns the decorator applied to `func`. + func: function to be annotated. Returns: - Returns a decorator that will conditionally skip the decorated test method. + Returns a function that will conditionally skip the decorated test method. """ - def decorator(f): - if tf_inspect.isclass(f): - raise ValueError("`run_gpu_or_tpu` only supports test methods.") + if tf_inspect.isclass(func): + raise ValueError("`run_gpu_or_tpu` only supports test methods.") - def decorated(self, *args, **kwargs): - if config.list_physical_devices("GPU"): - return f(self, "GPU", *args, **kwargs) + def decorated(self, *args, **kwargs): + if config.list_physical_devices("GPU"): + return func(self, "GPU", *args, **kwargs) - if config.list_physical_devices("TPU"): - return f(self, "TPU", *args, **kwargs) + if config.list_physical_devices("TPU"): + return func(self, "TPU", *args, **kwargs) - self.skipTest("Test requires GPU or TPU") + self.skipTest("Test requires GPU or TPU") - return decorated - - return decorator if func is None else decorator(func) + return decorated def with_forward_compatibility_horizons(*horizons): @@ -2182,36 +2133,29 @@ def disable_cudnn_autotune(func): Decorated function. 
""" - def decorator(f): - - def decorated(*args, **kwargs): - original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE") - os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false" - original_xla_flags = os.environ.get("XLA_FLAGS") - new_xla_flags = "--xla_gpu_autotune_level=0" - if original_xla_flags: - new_xla_flags = original_xla_flags + " " + new_xla_flags - os.environ["XLA_FLAGS"] = new_xla_flags - - result = f(*args, **kwargs) + def decorated(*args, **kwargs): + original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE") + os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false" + original_xla_flags = os.environ.get("XLA_FLAGS") + new_xla_flags = "--xla_gpu_autotune_level=0" + if original_xla_flags: + new_xla_flags = original_xla_flags + " " + new_xla_flags + os.environ["XLA_FLAGS"] = new_xla_flags - if (original_tf_cudnn_use_autotune is None): - del os.environ["TF_CUDNN_USE_AUTOTUNE"] - else: - os.environ["TF_CUDNN_USE_AUTOTUNE"] = original_tf_cudnn_use_autotune - if (original_xla_flags is None): - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = original_xla_flags - - return result + result = func(*args, **kwargs) - return tf_decorator.make_decorator(func, decorated) + if (original_tf_cudnn_use_autotune is None): + del os.environ["TF_CUDNN_USE_AUTOTUNE"] + else: + os.environ["TF_CUDNN_USE_AUTOTUNE"] = original_tf_cudnn_use_autotune + if (original_xla_flags is None): + del os.environ["XLA_FLAGS"] + else: + os.environ["XLA_FLAGS"] = original_xla_flags - if func is not None: - return decorator(func) + return result - return decorator + return tf_decorator.make_decorator(func, decorated) # The description is just for documentation purposes. @@ -2233,21 +2177,14 @@ def enable_tf_xla_constant_folding_impl(func): Decorated function. 
""" - def decorator(f): - - def decorated(*args, **kwargs): - original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled() - pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False) - result = f(*args, **kwargs) - pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var) - return result - - return tf_decorator.make_decorator(func, decorated) - - if func is not None: - return decorator(func) + def decorated(*args, **kwargs): + original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled() + pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False) + result = func(*args, **kwargs) + pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var) + return result - return decorator + return tf_decorator.make_decorator(func, decorated) return enable_tf_xla_constant_folding_impl @@ -2257,18 +2194,11 @@ def _disable_test(execute_func): def disable_test_impl(func): - def decorator(func): - - def decorated(*args, **kwargs): - if execute_func: - return func(*args, **kwargs) - - return tf_decorator.make_decorator(func, decorated) - - if func is not None: - return decorator(func) + def decorated(*args, **kwargs): + if execute_func: + return func(*args, **kwargs) - return decorator + return tf_decorator.make_decorator(func, decorated) return disable_test_impl @@ -2327,20 +2257,13 @@ def disable_tfrt_impl(cls_or_func): else: return cls_or_func else: - def decorator(func): - - def decorated(*args, **kwargs): - if tfrt_utils.enabled(): - return - else: - return func(*args, **kwargs) - - return tf_decorator.make_decorator(func, decorated) - - if cls_or_func is not None: - return decorator(cls_or_func) + def decorated(*args, **kwargs): + if tfrt_utils.enabled(): + return + else: + return cls_or_func(*args, **kwargs) - return decorator + return tf_decorator.make_decorator(cls_or_func, decorated) return disable_tfrt_impl @@ -2385,26 +2308,19 @@ def xla_allow_fallback(description): # pylint: disable=unused-argument def xla_allow_fallback_impl(func): """Allow fallback to TF even though testing xla.""" - def decorator(func): - - def decorated(*args, **kwargs): - if is_xla_enabled(): - # Update the global XLABuildOpsPassFlags to enable lazy compilation, - # which allows the compiler to fall back to TF classic. Remember the - # old value so that we can reset it. - old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True) - result = func(*args, **kwargs) - pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value) - return result - else: - return func(*args, **kwargs) - - return tf_decorator.make_decorator(func, decorated) - - if func is not None: - return decorator(func) + def decorated(*args, **kwargs): + if is_xla_enabled(): + # Update the global XLABuildOpsPassFlags to enable lazy compilation, + # which allows the compiler to fall back to TF classic. Remember the + # old value so that we can reset it. 
+ old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True) + result = func(*args, **kwargs) + pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value) + return result + else: + return func(*args, **kwargs) - return decorator + return tf_decorator.make_decorator(func, decorated) return xla_allow_fallback_impl diff --git a/tensorflow/python/kernel_tests/image_ops/draw_bounding_box_op_test.py b/tensorflow/python/kernel_tests/image_ops/draw_bounding_box_op_test.py index a66d8d8a9a2a13..f7641c63e7f7e3 100644 --- a/tensorflow/python/kernel_tests/image_ops/draw_bounding_box_op_test.py +++ b/tensorflow/python/kernel_tests/image_ops/draw_bounding_box_op_test.py @@ -135,7 +135,7 @@ def testDrawBoundingBoxHalf(self): image, dtype=dtypes.half, colors=colors) # generate_bound_box_proposals is only available on GPU. - @test_util.run_gpu_only() + @test_util.run_gpu_only def testGenerateBoundingBoxProposals(self): # Op only exists on GPU. with self.cached_session(use_gpu=True): diff --git a/tensorflow/python/kernel_tests/sparse_ops/sparse_xent_op_test_base.py b/tensorflow/python/kernel_tests/sparse_ops/sparse_xent_op_test_base.py index a30d82591da5c9..381e5c093f007e 100644 --- a/tensorflow/python/kernel_tests/sparse_ops/sparse_xent_op_test_base.py +++ b/tensorflow/python/kernel_tests/sparse_ops/sparse_xent_op_test_base.py @@ -71,7 +71,7 @@ def testSingleClass(self): self.assertAllClose([0.0, 0.0, 0.0], tf_loss) self.assertAllClose([[0.0], [0.0], [0.0]], tf_gradient) - @test_util.run_gpu_only() + @test_util.run_gpu_only def _testInvalidLabelGPU(self, invalid_label_gradient=np.nan): labels = [4, 3, 0, -1] logits = [[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 2., 3., 4.], diff --git a/tensorflow/python/saved_model/registration/registration_saving_test.py b/tensorflow/python/saved_model/registration/registration_saving_test.py index ec87f06c0a5e10..8e60cc2bf9e122 100644 --- a/tensorflow/python/saved_model/registration/registration_saving_test.py +++ b/tensorflow/python/saved_model/registration/registration_saving_test.py @@ -223,7 +223,7 @@ def test_registered_saver(self, cycles): class SingleCycleTest(test.TestCase): - @test_util.deprecated_graph_mode_only() + @test_util.deprecated_graph_mode_only def test_registered_saver_fails_in_saved_model_graph_mode(self): with context.eager_mode(): p1 = Part([1, 4]) From 756ad90f9994dbb0382f81a6f543b1dd9236b58b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Thu, 16 Nov 2023 08:14:18 -0800 Subject: [PATCH 170/391] [XLA:GPU] Fix fusion parameter limit Sometimes we fused more than allowed and it broke some internal tests, because there is a TF_RET_CHECK for this. PiperOrigin-RevId: 583049657 --- .../xla/service/gpu/gemm_rewriter_triton.cc | 10 ++++--- .../service/gpu/gemm_rewriter_triton_test.cc | 28 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc index d4ef2ea54465a9..5477e0f0fd8bd0 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc @@ -1238,8 +1238,8 @@ void FusionContext::TryToFuseWithInputsRecursively( HloInstruction* hlo = to_visit.front(); to_visit.pop(); // Watch the total number of fusion parameters. 
- if (inputs.size() >= TritonFusionAnalysis::kMaxParameterPerScope && - NumAddedParameters(*hlo) > 0) { + if (inputs.size() + NumAddedParameters(*hlo) > + TritonFusionAnalysis::kMaxParameterPerScope) { // Re-queue: the number of parameters may go down when other instructions // are processed. to_visit.push(hlo); @@ -1322,8 +1322,10 @@ StatusOr FuseDot(HloInstruction& dot, context.TryToFuseWithInputsRecursively(*dot.mutable_operand(operand_number), gpu_version, old_to_new_mapping, fusion_inputs, builder); - TF_RET_CHECK(fusion_inputs.size() - operand_count_before <= - TritonFusionAnalysis::kMaxParameterPerScope); + const int new_parameters = fusion_inputs.size() - operand_count_before; + TF_RET_CHECK(new_parameters <= TritonFusionAnalysis::kMaxParameterPerScope) + << "Too many new parameters: " << new_parameters << " > " + << TritonFusionAnalysis::kMaxParameterPerScope; return context; }; diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc index 9798fb4124554e..7a78d9523db1b2 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -1124,6 +1124,34 @@ ENTRY e { TritonFusionAnalysis::kMaxParameterPerScope * 2); } +TEST_F(GemmRewriterTritonLevel2Test, + DoNotFuseTooManyParametersWhenAnInstructionWouldAddMultipleParameters) { + // If we fuse the select, it adds 2 additional parameters at once (not 3, + // because the select instruction itself is removed from the parameters). + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +ENTRY e { + a = f32[3,49]{1,0} parameter(0) + b = f32[3,49]{1,0} parameter(1) + c = pred[3,49]{1,0} parameter(2) + d = f32[3,49]{1,0} parameter(3) + e = f32[3,49]{1,0} parameter(4) + add0 = f32[3,49]{1,0} add(a, b) + select = f32[3,49]{1,0} select(c, d, e) + add1 = f32[3,49]{1,0} add(add0, select) + f = f32[3,32]{1,0} parameter(5) + ROOT tmp_102 = f32[49,32]{1,0} dot(add1, f), lhs_contracting_dims={0}, rhs_contracting_dims={0} +})")); + + EXPECT_TRUE(GemmRewriterTriton(gpu_version_).Run(module.get()).value()); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kFusion); + EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), + HloInstruction::FusionKind::kCustom); + EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), + TritonFusionAnalysis::kMaxParameterPerScope + 1); +} + TEST_F(GemmRewriterTritonLevel2Test, OperationsAddingMoreParametersGetMultipleTries) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, From eee133576b34adc9422319d80a1fd3756a8bb6df Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Thu, 16 Nov 2023 08:18:53 -0800 Subject: [PATCH 171/391] [XLA:GPU] Add GetCommonUtilization helper. 
(NFC) PiperOrigin-RevId: 583050939 --- .../gpu/model/gpu_performance_model.cc | 77 +++++++++++-------- 1 file changed, 47 insertions(+), 30 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index 7d5045cef03337..64e9132d5ce507 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -296,6 +296,44 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( return {flops, bytes_written, num_threads, write_time, exec_time}; } +// Returns utilization `overlap` between a common operand of producer and +// consumer on merge. `utilization > 0` means that the operand will be accessed +// more efficiently after fusion. +// +// Currently covers two cases: +// 1) Producer has to use the common operand elementwise from its root if it is +// a fusion or just be an elementwise instruction. +// 2) Consumer has to have common elementwise roots for the producer and the +// common operand if it is a fusion or just be an elementwise instruction. +float GetCommonUtilization( + const HloInstruction* producer, int64_t producer_idx_of_operand, + const HloInstruction* consumer, + const ConstHloInstructionMap& consumer_operands, + const GpuHloCostAnalysis* cost_analysis) { + auto consumer_idx_of_operand = + consumer_operands.find(producer->operand(producer_idx_of_operand)); + if (consumer_idx_of_operand == consumer_operands.end()) { + return 0.f; + } + + if (producer->IsElementwise() || + (producer->opcode() == HloOpcode::kFusion && + FusionUsesParameterElementwiseFromRoot(producer, producer_idx_of_operand, + cost_analysis))) { + if (consumer->opcode() == HloOpcode::kFusion) { + int64_t consumer_idx_of_producer = consumer_operands.at(producer); + return cost_analysis->CommonElementwiseUtilization( + consumer->fused_parameter(consumer_idx_of_operand->second), + consumer->fused_parameter(consumer_idx_of_producer)); + } else { + if (consumer->IsElementwise()) { + return 1.f; + } + } + } + return 0.f; +} + // Tells input access time of the producer alone if fused_consumer // is not specified. Otherwise estimates the access time to producer's // inputs as if it is fused into the consumer. @@ -308,14 +346,14 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( const HloInstruction* fused_consumer) { absl::Duration ret = absl::ZeroDuration(); float producer_output_utilization = 1.f; - ConstHloInstructionSet consumer_operands; + ConstHloInstructionMap consumer_operands; bool consumer_transposes = false; if (fused_consumer) { consumer_transposes = TransposesMinorDimension(fused_consumer); producer_output_utilization = cost_analysis->operand_utilization( *fused_consumer, fused_consumer->operand_index(producer)); - for (const HloInstruction* op : fused_consumer->operands()) { - consumer_operands.insert(op); + for (int64_t i = 0; i < fused_consumer->operand_count(); ++i) { + consumer_operands[fused_consumer->operand(i)] = i; } } @@ -336,33 +374,12 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( int64_t n_bytes_net = std::llround(operand_bytes_accessed / std::max(operand_utilization, 1.0f)); - // Look for common operands of producer and consumer that are accessed - // more efficiently on merge: - // 1) Producer has to use the common operand elementwise from its root if - // it is a fusion or just be an elementwise instruction. 
- // 2) Consumer has to have common elementwise roots for the producer - // and the common operand if it is a fusion or just be an elementwise - // instruction. - float common_utilization = 0; - if (consumer_operands.count(producer->operand(i)) && - (producer->IsElementwise() || - (producer->opcode() == HloOpcode::kFusion && - FusionUsesParameterElementwiseFromRoot(producer, i, - cost_analysis)))) { - if (fused_consumer->opcode() == HloOpcode::kFusion) { - int64_t consumer_idx_of_common_operand = - fused_consumer->operand_index(producer->operand(i)); - int64_t consumer_idx_of_producer = - fused_consumer->operand_index(producer); - common_utilization = cost_analysis->CommonElementwiseUtilization( - fused_consumer->fused_parameter(consumer_idx_of_common_operand), - fused_consumer->fused_parameter(consumer_idx_of_producer)); - } else { - if (fused_consumer->IsElementwise()) { - common_utilization = 1.f; - } - } - } + // Look if common operand of producer and consumer will be accessed more + // efficiently on merge. + float common_utilization = + GetCommonUtilization(producer, + /*producer_idx_of_operand=*/i, fused_consumer, + consumer_operands, cost_analysis); // TODO(jreiffers): We should be checking each operand here. bool coalesced = (fusion_analysis && From dfc0ccc5075772f9292caf6f7fe269f65079f9b9 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 16 Nov 2023 08:58:42 -0800 Subject: [PATCH 172/391] Add Bazel toolchain configs for cross-compiling TensorFlow for Linux Aarch64 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This adds support for cross-compiling TensorFlow targets for Linux Aarch64 on a Linux x86 machine. We use Clang as the cross-compiler, `ld.lld` as the linker and build in a special cross-compile supported [Docker image](http://gcr.io/tensorflow-testing/ml-devinfra-linux-aarch64-cross-compile@sha256:de26c1dbddcb42b48e665972f62f128f2c69e0f1aa6f0ba6c7411dd23d4de785) that contains all the necessary build tools and the sysroots for both Linux Aarch64 and Linux x86. We need a Linux x86 toolchain because Bazel needs it to be able to build the tools used during the build—such as Protoc, llvm-tablegen, flatc—correctly for our execution platform (Linux x86). In addition, this adds support for cross-compiling using RBE. We do this by invoking a Bazel remote build from a Linux Aarch64 host which would then send build requests to remote Linux x86 VMs. The Linux x86 VMs build inside the cross-compile Docker image using the cross-compile toolchain configs to build the targets for Aarch64. The targets, once built, are automatically transferred to the host seamlessly. RBE cross-compiling is necessary for us to be able to run `bazel test` commands. Note that lot of the "host_" flags such as "host_cpu" and "host_crosstool_top" flags seem to be actually used to specify the execution platform details. It seems it is this way because these flags are old and predate the distinction between host and execution platform. The toolchain configs can be found in "tensorflow/tools/toolchains/cross_compile/cc/BUILD" and the RBE platform configs can be found in "tensorflow/tools/toolchains/cross_compile/config/BUILD". If trying to cross-compile without RBE, run your build from a Linux x86 host in the [Docker image](http://gcr.io/tensorflow-testing/ml-devinfra-linux-aarch64-cross-compile@sha256:de26c1dbddcb42b48e665972f62f128f2c69e0f1aa6f0ba6c7411dd23d4de785) and use `--config=cross_compile_linux_arm64`. 
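As an illustrative sketch only (the wheel-building target below is just an example and is not part of this change), a non-RBE cross-compile invocation from inside that image might look like:

```
# Hypothetical example: any buildable target works here; build_pip_package is only an illustration.
bazel build --config=cross_compile_linux_arm64 //tensorflow/tools/pip_package:build_pip_package
```
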
If you are trying to cross-compile with RBE, run your build from a Linux Aarch64 host and use `--config=rbe_cross_compile_linux_arm64`. Since RBE uses GCP VM instances and requires authentication, it is only available to Googlers and TF CI builds. Tests can only be run with RBE. PiperOrigin-RevId: 583062010 --- .bazelrc | 36 +++- tensorflow/opensource_only.files | 2 + .../tools/toolchains/cross_compile/cc/BUILD | 188 +++++++++++++++++ .../toolchains/cross_compile/config/BUILD | 23 +++ third_party/xla/.bazelrc | 36 +++- third_party/xla/opensource_only.files | 2 + third_party/xla/third_party/tsl/.bazelrc | 36 +++- .../xla/third_party/tsl/opensource_only.files | 2 + .../tools/toolchains/cross_compile/cc/BUILD | 191 ++++++++++++++++++ .../toolchains/cross_compile/config/BUILD | 23 +++ .../tools/toolchains/cross_compile/cc/BUILD | 191 ++++++++++++++++++ .../toolchains/cross_compile/config/BUILD | 23 +++ 12 files changed, 747 insertions(+), 6 deletions(-) create mode 100644 tensorflow/tools/toolchains/cross_compile/cc/BUILD create mode 100644 tensorflow/tools/toolchains/cross_compile/config/BUILD create mode 100644 third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD create mode 100644 third_party/xla/third_party/tsl/tools/toolchains/cross_compile/config/BUILD create mode 100644 third_party/xla/tools/toolchains/cross_compile/cc/BUILD create mode 100644 third_party/xla/tools/toolchains/cross_compile/config/BUILD diff --git a/.bazelrc b/.bazelrc index b11faac6bc96c0..1cbd781f7a54f3 100644 --- a/.bazelrc +++ b/.bazelrc @@ -765,7 +765,7 @@ test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflo test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... 
-//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 @@ -789,10 +789,42 @@ test:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_ test:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test +# CROSS-COMPILE ARM64 PYCPP +test:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test +# Tests that fail only when cross-compiled +test:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... 
-//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # END TF TEST SUITE OPTIONS + +# START LINUX AARCH64 CROSS-COMPILE CONFIGS +# Set execution platform to Linux x86 +# Note: Lot of the "host_" flags such as "host_cpu" and "host_crosstool_top" +# flags seem to be actually used to specify the execution platform details. It +# seems it is this way because these flags are old and predate the distinction +# between host and execution platform. +build:cross_compile_linux_arm64 --host_cpu=k8 +build:cross_compile_linux_arm64 --host_crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_linux_arm64 --extra_execution_platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_x86_64 + +# Set the target CPU to Aarch64 +build:cross_compile_linux_arm64 --platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_aarch64 +build:cross_compile_linux_arm64 --cpu=aarch64 +build:cross_compile_linux_arm64 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite + +# RBE configs +build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 +build:rbe_cross_compile_linux_arm64 --config=rbe_base +build:rbe_cross_compile_linux_arm64 --remote_instance_name=projects/tensorflow-testing/instances/default_instance + +# Test-related settings below this point +# We cannot run cross-compiled tests on the remote Linux x86 VMs so we need to +# force all tests to run locally on the Aarch64 host. +test:rbe_cross_compile_linux_arm64 --strategy=TestRunner=local +test:rbe_cross_compile_linux_arm64 --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors +test:rbe_cross_compile_linux_arm64 --flaky_test_attempts=3 --build_tests_only +# END LINUX AARCH64 CROSS-COMPILE CONFIGS diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 606141b99f24a4..0a3015106ef946 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -175,6 +175,8 @@ tf_staging/tensorflow/tools/toolchains/BUILD: tf_staging/tensorflow/tools/toolchains/clang6/BUILD: tf_staging/tensorflow/tools/toolchains/cpus/py/BUILD: tf_staging/tensorflow/tools/toolchains/cpus/py3/BUILD: +tf_staging/tensorflow/tools/toolchains/cross_compile/cc/BUILD: +tf_staging/tensorflow/tools/toolchains/cross_compile/config/BUILD: tf_staging/tensorflow/tools/toolchains/embedded/arm-linux/BUILD: tf_staging/tensorflow/tools/toolchains/java/BUILD: tf_staging/tensorflow/tools/toolchains/python/BUILD: diff --git a/tensorflow/tools/toolchains/cross_compile/cc/BUILD b/tensorflow/tools/toolchains/cross_compile/cc/BUILD new file mode 100644 index 00000000000000..7db2527259d026 --- /dev/null +++ b/tensorflow/tools/toolchains/cross_compile/cc/BUILD @@ -0,0 +1,188 @@ +"""Toolchain configs for cross-compiling TensorFlow""" + +load("@bazel_tools//tools/cpp:unix_cc_toolchain_config.bzl", "cc_toolchain_config") + +package(default_visibility = ["//visibility:public"]) + +licenses(["restricted"]) + +cc_toolchain_suite( + name = "cross_compile_toolchain_suite", + toolchains = { + "aarch64": ":linux_aarch64_toolchain", + "k8": ":linux_x86_toolchain", + }, +) + +filegroup(name = "empty") + +cc_toolchain( + name = "linux_x86_toolchain", + all_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = 
":linux_x86_toolchain_config", + toolchain_identifier = "linux_x86_toolchain", +) + +cc_toolchain_config( + name = "linux_x86_toolchain_config", + abi_libc_version = "local", + abi_version = "local", + builtin_sysroot = "/dt9", + compile_flags = [ + "--target=x86_64-unknown-linux-gnu", + "-fstack-protector", + "-Wall", + "-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-mavx", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "k8", + cxx_builtin_include_directories = [ + "/dt9/", + "/usr/lib/llvm-17/include/", + "/usr/lib/llvm-17/lib/clang/17/include", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=x86_64-unknown-linux-gnu", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "-Wl,--undefined-version", + ], + link_libs = [ + "-lstdc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,--gc-sections"], + supports_start_end_lib = True, + target_libc = "", + target_system_name = "x86_64-unknown-linux-gnu", + tool_paths = { + "gcc": "/usr/lib/llvm-17/bin/clang", + "ld": "/usr/lib/llvm-17/bin/ld.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-ar", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + "strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "linux_x86_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) + +cc_toolchain( + name = "linux_aarch64_toolchain", + all_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":linux_aarch64_toolchain_config", + toolchain_identifier = "linux_aarch64_toolchain", +) + +cc_toolchain_config( + name = "linux_aarch64_toolchain_config", + abi_libc_version = "local", + abi_version = "local", + builtin_sysroot = "/dt10/", + compile_flags = [ + "--target=aarch64-unknown-linux-gnu", + "-fstack-protector", + "-Wall", + "-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-mtune=generic", + "-march=armv8-a", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "aarch64", + cxx_builtin_include_directories = [ + "/dt10/", + "/usr/lib/llvm-17/include/", + "/usr/lib/llvm-17/lib/clang/17/include", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=aarch64-unknown-linux-gnu", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "-Wl,--undefined-version", + ], + link_libs = [ + "-lstdc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,--gc-sections"], + supports_start_end_lib = True, + target_libc = "", + target_system_name = "aarch64-unknown-linux-gnu", + tool_paths = { + 
"gcc": "/usr/lib/llvm-17/bin/clang", + "ld": "/usr/lib/llvm-17/bin/ld.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-ar", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + "strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "linux_aarch64_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) diff --git a/tensorflow/tools/toolchains/cross_compile/config/BUILD b/tensorflow/tools/toolchains/cross_compile/config/BUILD new file mode 100644 index 00000000000000..b6a504ba1449d6 --- /dev/null +++ b/tensorflow/tools/toolchains/cross_compile/config/BUILD @@ -0,0 +1,23 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["restricted"]) + +platform( + name = "linux_x86_64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], + exec_properties = { + "container-image": "docker://gcr.io/tensorflow-testing/ml-devinfra-linux-aarch64-cross-compile@sha256:11c5ac3b9b4e01cfa82b39b90826a9bfc5b806ccc92cd3d272e6bf861de43be1", + "OSFamily": "Linux", + }, +) + +platform( + name = "linux_aarch64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:aarch64", + ], +) diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index b11faac6bc96c0..1cbd781f7a54f3 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -765,7 +765,7 @@ test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflo test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... 
-//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 @@ -789,10 +789,42 @@ test:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_ test:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test +# CROSS-COMPILE ARM64 PYCPP +test:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test +# Tests that fail only when cross-compiled +test:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... 
-//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # END TF TEST SUITE OPTIONS + +# START LINUX AARCH64 CROSS-COMPILE CONFIGS +# Set execution platform to Linux x86 +# Note: Lot of the "host_" flags such as "host_cpu" and "host_crosstool_top" +# flags seem to be actually used to specify the execution platform details. It +# seems it is this way because these flags are old and predate the distinction +# between host and execution platform. +build:cross_compile_linux_arm64 --host_cpu=k8 +build:cross_compile_linux_arm64 --host_crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_linux_arm64 --extra_execution_platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_x86_64 + +# Set the target CPU to Aarch64 +build:cross_compile_linux_arm64 --platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_aarch64 +build:cross_compile_linux_arm64 --cpu=aarch64 +build:cross_compile_linux_arm64 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite + +# RBE configs +build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 +build:rbe_cross_compile_linux_arm64 --config=rbe_base +build:rbe_cross_compile_linux_arm64 --remote_instance_name=projects/tensorflow-testing/instances/default_instance + +# Test-related settings below this point +# We cannot run cross-compiled tests on the remote Linux x86 VMs so we need to +# force all tests to run locally on the Aarch64 host. +test:rbe_cross_compile_linux_arm64 --strategy=TestRunner=local +test:rbe_cross_compile_linux_arm64 --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors +test:rbe_cross_compile_linux_arm64 --flaky_test_attempts=3 --build_tests_only +# END LINUX AARCH64 CROSS-COMPILE CONFIGS diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files index 9abb2546fa24ed..9de7578a5801a9 100644 --- a/third_party/xla/opensource_only.files +++ b/third_party/xla/opensource_only.files @@ -26,6 +26,8 @@ tools/toolchains/BUILD: tools/toolchains/clang6/BUILD: tools/toolchains/cpus/py/BUILD: tools/toolchains/cpus/py3/BUILD: +tools/toolchains/cross_compile/cc/BUILD: +tools/toolchains/cross_compile/config/BUILD: tools/toolchains/embedded/arm-linux/BUILD: tools/toolchains/java/BUILD: tools/toolchains/python/BUILD: diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index b11faac6bc96c0..1cbd781f7a54f3 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -765,7 +765,7 @@ test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflo test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... 
-//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 @@ -789,10 +789,42 @@ test:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_ test:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... 
-//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/compiler/mlir/tfr/examples/customization:test_ops_test -//tensorflow/compiler/mlir/tfr/examples/mnist:mnist_ops_test -//tensorflow/compiler/mlir/tfr/examples/pad:pad_ops_test -//tensorflow/python/tools:aot_compiled_test +# CROSS-COMPILE ARM64 PYCPP +test:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test +# Tests that fail only when cross-compiled +test:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # END TF TEST SUITE OPTIONS + +# START LINUX AARCH64 CROSS-COMPILE CONFIGS +# Set execution platform to Linux x86 +# Note: Lot of the "host_" flags such as "host_cpu" and "host_crosstool_top" +# flags seem to be actually used to specify the execution platform details. It +# seems it is this way because these flags are old and predate the distinction +# between host and execution platform. +build:cross_compile_linux_arm64 --host_cpu=k8 +build:cross_compile_linux_arm64 --host_crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_linux_arm64 --extra_execution_platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_x86_64 + +# Set the target CPU to Aarch64 +build:cross_compile_linux_arm64 --platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_aarch64 +build:cross_compile_linux_arm64 --cpu=aarch64 +build:cross_compile_linux_arm64 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite + +# RBE configs +build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 +build:rbe_cross_compile_linux_arm64 --config=rbe_base +build:rbe_cross_compile_linux_arm64 --remote_instance_name=projects/tensorflow-testing/instances/default_instance + +# Test-related settings below this point +# We cannot run cross-compiled tests on the remote Linux x86 VMs so we need to +# force all tests to run locally on the Aarch64 host. 
+test:rbe_cross_compile_linux_arm64 --strategy=TestRunner=local +test:rbe_cross_compile_linux_arm64 --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors +test:rbe_cross_compile_linux_arm64 --flaky_test_attempts=3 --build_tests_only +# END LINUX AARCH64 CROSS-COMPILE CONFIGS diff --git a/third_party/xla/third_party/tsl/opensource_only.files b/third_party/xla/third_party/tsl/opensource_only.files index f2f2b14eba7be9..fa84f35768a5d2 100644 --- a/third_party/xla/third_party/tsl/opensource_only.files +++ b/third_party/xla/third_party/tsl/opensource_only.files @@ -131,6 +131,8 @@ tools/toolchains/BUILD: tools/toolchains/clang6/BUILD: tools/toolchains/cpus/py/BUILD: tools/toolchains/cpus/py3/BUILD: +tools/toolchains/cross_compile/cc/BUILD: +tools/toolchains/cross_compile/config/BUILD: tools/toolchains/embedded/arm-linux/BUILD: tools/toolchains/java/BUILD: tools/toolchains/python/BUILD: diff --git a/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD new file mode 100644 index 00000000000000..dc621893ac9675 --- /dev/null +++ b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/cc/BUILD @@ -0,0 +1,191 @@ +"""Toolchain configs for cross-compiling TensorFlow""" + +load("@bazel_tools//tools/cpp:unix_cc_toolchain_config.bzl", "cc_toolchain_config") + +package(default_visibility = ["//visibility:public"]) + +licenses(["restricted"]) + +cc_toolchain_suite( + name = "cross_compile_toolchain_suite", + toolchains = { + "aarch64": ":linux_aarch64_toolchain", + "k8": ":linux_x86_toolchain", + }, +) + +filegroup( + name = "empty", + visibility = ["//visibility:public"], +) + +cc_toolchain( + name = "linux_x86_toolchain", + all_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":linux_x86_toolchain_config", + toolchain_identifier = "linux_x86_toolchain", +) + +cc_toolchain_config( + name = "linux_x86_toolchain_config", + abi_libc_version = "local", + abi_version = "local", + builtin_sysroot = "/dt9", + compile_flags = [ + "--target=x86_64-unknown-linux-gnu", + "-fstack-protector", + "-Wall", + "-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-mavx", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "k8", + cxx_builtin_include_directories = [ + "/dt9/", + "/usr/lib/llvm-17/include/", + "/usr/lib/llvm-17/lib/clang/17/include", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=x86_64-unknown-linux-gnu", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "-Wl,--undefined-version", + ], + link_libs = [ + "-lstdc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,--gc-sections"], + supports_start_end_lib = True, + target_libc = "", + target_system_name = "x86_64-unknown-linux-gnu", + tool_paths = { + "gcc": "/usr/lib/llvm-17/bin/clang", + "ld": "/usr/lib/llvm-17/bin/ld.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-ar", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + 
"strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "linux_x86_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) + +cc_toolchain( + name = "linux_aarch64_toolchain", + all_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":linux_aarch64_toolchain_config", + toolchain_identifier = "linux_aarch64_toolchain", +) + +cc_toolchain_config( + name = "linux_aarch64_toolchain_config", + abi_libc_version = "local", + abi_version = "local", + builtin_sysroot = "/dt10/", + compile_flags = [ + "--target=aarch64-unknown-linux-gnu", + "-fstack-protector", + "-Wall", + "-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-mtune=generic", + "-march=armv8-a", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "aarch64", + cxx_builtin_include_directories = [ + "/dt10/", + "/usr/lib/llvm-17/include/", + "/usr/lib/llvm-17/lib/clang/17/include", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=aarch64-unknown-linux-gnu", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "-Wl,--undefined-version", + ], + link_libs = [ + "-lstdc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,--gc-sections"], + supports_start_end_lib = True, + target_libc = "", + target_system_name = "aarch64-unknown-linux-gnu", + tool_paths = { + "gcc": "/usr/lib/llvm-17/bin/clang", + "ld": "/usr/lib/llvm-17/bin/ld.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-ar", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + "strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "linux_aarch64_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) diff --git a/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/config/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/config/BUILD new file mode 100644 index 00000000000000..b6a504ba1449d6 --- /dev/null +++ b/third_party/xla/third_party/tsl/tools/toolchains/cross_compile/config/BUILD @@ -0,0 +1,23 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["restricted"]) + +platform( + name = "linux_x86_64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], + exec_properties = { + "container-image": "docker://gcr.io/tensorflow-testing/ml-devinfra-linux-aarch64-cross-compile@sha256:11c5ac3b9b4e01cfa82b39b90826a9bfc5b806ccc92cd3d272e6bf861de43be1", + "OSFamily": "Linux", + }, +) + +platform( + name = "linux_aarch64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:aarch64", + ], +) diff --git 
a/third_party/xla/tools/toolchains/cross_compile/cc/BUILD b/third_party/xla/tools/toolchains/cross_compile/cc/BUILD new file mode 100644 index 00000000000000..dc621893ac9675 --- /dev/null +++ b/third_party/xla/tools/toolchains/cross_compile/cc/BUILD @@ -0,0 +1,191 @@ +"""Toolchain configs for cross-compiling TensorFlow""" + +load("@bazel_tools//tools/cpp:unix_cc_toolchain_config.bzl", "cc_toolchain_config") + +package(default_visibility = ["//visibility:public"]) + +licenses(["restricted"]) + +cc_toolchain_suite( + name = "cross_compile_toolchain_suite", + toolchains = { + "aarch64": ":linux_aarch64_toolchain", + "k8": ":linux_x86_toolchain", + }, +) + +filegroup( + name = "empty", + visibility = ["//visibility:public"], +) + +cc_toolchain( + name = "linux_x86_toolchain", + all_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":linux_x86_toolchain_config", + toolchain_identifier = "linux_x86_toolchain", +) + +cc_toolchain_config( + name = "linux_x86_toolchain_config", + abi_libc_version = "local", + abi_version = "local", + builtin_sysroot = "/dt9", + compile_flags = [ + "--target=x86_64-unknown-linux-gnu", + "-fstack-protector", + "-Wall", + "-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-mavx", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "k8", + cxx_builtin_include_directories = [ + "/dt9/", + "/usr/lib/llvm-17/include/", + "/usr/lib/llvm-17/lib/clang/17/include", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=x86_64-unknown-linux-gnu", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "-Wl,--undefined-version", + ], + link_libs = [ + "-lstdc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,--gc-sections"], + supports_start_end_lib = True, + target_libc = "", + target_system_name = "x86_64-unknown-linux-gnu", + tool_paths = { + "gcc": "/usr/lib/llvm-17/bin/clang", + "ld": "/usr/lib/llvm-17/bin/ld.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-ar", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + "strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "linux_x86_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) + +cc_toolchain( + name = "linux_aarch64_toolchain", + all_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":linux_aarch64_toolchain_config", + toolchain_identifier = "linux_aarch64_toolchain", +) + +cc_toolchain_config( + name = "linux_aarch64_toolchain_config", + abi_libc_version = "local", + abi_version = "local", + builtin_sysroot = "/dt10/", + compile_flags = [ + "--target=aarch64-unknown-linux-gnu", + "-fstack-protector", + "-Wall", + 
"-Wthread-safety", + "-Wself-assign", + "-Wunused-but-set-parameter", + "-Wno-free-nonheap-object", + "-fcolor-diagnostics", + "-fno-omit-frame-pointer", + "-mtune=generic", + "-march=armv8-a", + ], + compiler = "clang", + coverage_compile_flags = ["--coverage"], + coverage_link_flags = ["--coverage"], + cpu = "aarch64", + cxx_builtin_include_directories = [ + "/dt10/", + "/usr/lib/llvm-17/include/", + "/usr/lib/llvm-17/lib/clang/17/include", + ], + dbg_compile_flags = ["-g"], + host_system_name = "linux", + link_flags = [ + "--target=aarch64-unknown-linux-gnu", + "-fuse-ld=lld", + "--ld-path=/usr/lib/llvm-17/bin/ld.lld", + "-Wl,--undefined-version", + ], + link_libs = [ + "-lstdc++", + "-lm", + ], + opt_compile_flags = [ + "-g0", + "-O2", + "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ], + opt_link_flags = ["-Wl,--gc-sections"], + supports_start_end_lib = True, + target_libc = "", + target_system_name = "aarch64-unknown-linux-gnu", + tool_paths = { + "gcc": "/usr/lib/llvm-17/bin/clang", + "ld": "/usr/lib/llvm-17/bin/ld.lld", + "ar": "/usr/lib/llvm-17/bin/llvm-ar", + "cpp": "/usr/lib/llvm-17/bin/clang++", + "llvm-cov": "/usr/lib/llvm-17/bin/llvm-cov", + "nm": "/usr/lib/llvm-17/bin/llvm-nm", + "objdump": "/usr/lib/llvm-17/bin/llvm-objdump", + "strip": "/usr/lib/llvm-17/bin/llvm-strip", + }, + toolchain_identifier = "linux_aarch64_toolchain", + unfiltered_compile_flags = [ + "-no-canonical-prefixes", + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + "-Wno-unused-command-line-argument", + "-Wno-gnu-offsetof-extensions", + ], +) diff --git a/third_party/xla/tools/toolchains/cross_compile/config/BUILD b/third_party/xla/tools/toolchains/cross_compile/config/BUILD new file mode 100644 index 00000000000000..b6a504ba1449d6 --- /dev/null +++ b/third_party/xla/tools/toolchains/cross_compile/config/BUILD @@ -0,0 +1,23 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["restricted"]) + +platform( + name = "linux_x86_64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], + exec_properties = { + "container-image": "docker://gcr.io/tensorflow-testing/ml-devinfra-linux-aarch64-cross-compile@sha256:11c5ac3b9b4e01cfa82b39b90826a9bfc5b806ccc92cd3d272e6bf861de43be1", + "OSFamily": "Linux", + }, +) + +platform( + name = "linux_aarch64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:aarch64", + ], +) From 38b6b3ecd2b131554f6bcecc68d6cc197b4fc873 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 09:41:03 -0800 Subject: [PATCH 173/391] Support per-channel quantization for requantize The main change in implementation is in Requantize(). This allows any combinations of per-tensor or per-channel quantized tensors as the input and/or output. Also added numerical tests. 
PiperOrigin-RevId: 583074614 --- .../bridge/convert_mhlo_quant_to_int.cc | 339 +++++++++++------- .../convert_tf_quant_to_mhlo_int_test.cc | 122 +++++++ .../bridge/convert-mhlo-quant-to-int.mlir | 85 +++++ 3 files changed, 426 insertions(+), 120 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc index a5a425d540dd30..d67336e456d864 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc @@ -53,78 +53,8 @@ limitations under the License. namespace mlir::quant::stablehlo { namespace { -// This helper function create ops to requantize `input` tensor and returns the -// output tensor. Clamping is done if output integer bit-width < 32. -// -// Requantization is essentially dequantize --> quantize. -// -// Dequantize: (input - zp) * scale -// Quantize: input / scale + zp -// -// Hence, -// output = (input - input_zp) * input_scale / output_scale + output_zp -// -// This is simplified as: -// output = input * merged_scale + merged_zp -// where: -// merged_zp = output_zp - input_zp * merged_scale. -// merged_scale = input_scale / output_scale. -Value Requantize(mlir::OpState op, Value input, - UniformQuantizedType input_quantized_type, - UniformQuantizedType output_quantized_type, - TensorType output_tensor_type, - ConversionPatternRewriter &rewriter) { - // Skip requantization when input and result have the same type. - if (input_quantized_type == output_quantized_type) { - return rewriter.create(op->getLoc(), output_tensor_type, - input); - } - - double merged_scale_fp = - input_quantized_type.getScale() / output_quantized_type.getScale(); - Value merged_scale = rewriter.create( - op->getLoc(), - rewriter.getF32FloatAttr(static_cast(merged_scale_fp))); - - auto float_tensor_type = - input.getType().cast().clone(rewriter.getF32Type()); - Value output_float = - rewriter.create(op->getLoc(), float_tensor_type, input); - - output_float = rewriter.create( - op->getLoc(), float_tensor_type, output_float, merged_scale, nullptr); - - // Add merged_zp only when it is non-zero. - double merged_zp_fp = output_quantized_type.getZeroPoint() - - input_quantized_type.getZeroPoint() * merged_scale_fp; - if (merged_zp_fp != 0) { - Value merged_zp = rewriter.create( - op->getLoc(), - rewriter.getF32FloatAttr(static_cast(merged_zp_fp))); - output_float = rewriter.create( - op->getLoc(), float_tensor_type, output_float, merged_zp, nullptr); - } - - // Clamp output if the output integer bit-width <32. - if (output_tensor_type.getElementType().cast().getWidth() < 32) { - Value quantization_min = rewriter.create( - op->getLoc(), rewriter.getF32FloatAttr(static_cast( - output_quantized_type.getStorageTypeMin()))); - Value quantization_max = rewriter.create( - op->getLoc(), rewriter.getF32FloatAttr(static_cast( - output_quantized_type.getStorageTypeMax()))); - // Clamp results by [quantization_min, quantization_max]. - output_float = rewriter.create( - op->getLoc(), float_tensor_type, quantization_min, output_float, - quantization_max); - } - - output_float = rewriter.create( - op->getLoc(), float_tensor_type, output_float); - return rewriter.create(op->getLoc(), output_tensor_type, - output_float); -} - +// TODO: b/311218165 - consider extract this to common utils and better ways to +// handle polymorphism. 
using QuantType = std::variant; FailureOr GetQuantType(Type type) { @@ -139,6 +69,22 @@ FailureOr GetQuantType(Type type) { } } +bool IsPerTensorType(QuantType quant_type) { + return std::holds_alternative(quant_type); +} + +bool IsPerChannelType(QuantType quant_type) { + return std::holds_alternative(quant_type); +} + +UniformQuantizedType GetPerTensorType(QuantType quant_type) { + return std::get(quant_type); +} + +UniformQuantizedPerAxisType GetPerChannelType(QuantType quant_type) { + return std::get(quant_type); +} + // Extract scale and zero point info from input quant type info. void GetQuantizationParams(OpBuilder &builder, Location loc, QuantType quant_type, Value &scales, @@ -161,7 +107,7 @@ void GetQuantizationParams(OpBuilder &builder, Location loc, } else { auto &quant_per_channel_type = std::get(quant_type); - llvm::SmallVector scales_vec; + SmallVector scales_vec; for (auto scale : quant_per_channel_type.getScales()) scales_vec.push_back(scale); scales = builder.create( @@ -172,7 +118,7 @@ void GetQuantizationParams(OpBuilder &builder, Location loc, builder.getF32Type()), scales_vec)); if (output_zero_point_in_fp) { - llvm::SmallVector zero_points_vec; + SmallVector zero_points_vec; for (auto zero_point : quant_per_channel_type.getZeroPoints()) zero_points_vec.push_back(zero_point); zero_points = builder.create( @@ -183,7 +129,7 @@ void GetQuantizationParams(OpBuilder &builder, Location loc, builder.getF32Type()), zero_points_vec)); } else { - llvm::SmallVector zero_points_vec; + SmallVector zero_points_vec; for (auto zero_point : quant_per_channel_type.getZeroPoints()) zero_points_vec.push_back(zero_point); zero_points = builder.create( @@ -241,6 +187,147 @@ Type GetQuantStorageType(Type type) { } } +Type GetQuantStorageType(QuantType type) { + if (IsPerTensorType(type)) { + return GetPerTensorType(type).getStorageType(); + } else { + return GetPerChannelType(type).getStorageType(); + } +} + +Value ApplyMergedScalesAndZps(OpBuilder &builder, Location loc, + QuantType input_quant_type, + QuantType output_quant_type, + Value input_float_tensor) { + // Use single merged scale and merged zp if both input and output are + // per-tensor quantized. Otherwise use a vector. + if (IsPerTensorType(input_quant_type) && IsPerTensorType(output_quant_type)) { + UniformQuantizedType input_per_tensor_tyep = + GetPerTensorType(input_quant_type); + UniformQuantizedType output_per_tensor_tyep = + GetPerTensorType(output_quant_type); + double merged_scale_fp = + input_per_tensor_tyep.getScale() / output_per_tensor_tyep.getScale(); + auto merged_scale = builder.create( + loc, builder.getF32FloatAttr(static_cast(merged_scale_fp))); + input_float_tensor = builder.create( + loc, input_float_tensor, merged_scale, + /*broadcast_dimensions=*/nullptr); + // Add merged_zp only when it is non-zero. + double merged_zp_fp = + output_per_tensor_tyep.getZeroPoint() - + input_per_tensor_tyep.getZeroPoint() * merged_scale_fp; + if (merged_zp_fp != 0) { + Value merged_zp = builder.create( + loc, builder.getF32FloatAttr(static_cast(merged_zp_fp))); + input_float_tensor = builder.create( + loc, input_float_tensor, merged_zp, /*broadcast_dimensions=*/nullptr); + } + } else { + int64_t channel_size = + IsPerChannelType(output_quant_type) + ? GetPerChannelType(output_quant_type).getScales().size() + : GetPerChannelType(input_quant_type).getScales().size(); + int64_t quantized_dimension = + IsPerChannelType(output_quant_type) + ? 
GetPerChannelType(output_quant_type).getQuantizedDimension() + : GetPerChannelType(input_quant_type).getQuantizedDimension(); + SmallVector merged_scale_double, merged_zp_double; + merged_scale_double.resize(channel_size); + merged_zp_double.resize(channel_size); + for (int i = 0; i < channel_size; ++i) { + merged_scale_double[i] = + (IsPerChannelType(input_quant_type) + ? GetPerChannelType(input_quant_type).getScales()[i] + : GetPerTensorType(input_quant_type).getScale()) / + (IsPerChannelType(output_quant_type) + ? GetPerChannelType(output_quant_type).getScales()[i] + : GetPerTensorType(output_quant_type).getScale()); + merged_zp_double[i] = + (IsPerChannelType(output_quant_type) + ? GetPerChannelType(output_quant_type).getZeroPoints()[i] + : GetPerTensorType(output_quant_type).getZeroPoint()) - + (IsPerChannelType(input_quant_type) + ? GetPerChannelType(input_quant_type).getZeroPoints()[i] + : GetPerTensorType(input_quant_type).getZeroPoint()) * + merged_scale_double[i]; + } + SmallVector merged_scale_float(merged_scale_double.begin(), + merged_scale_double.end()), + merged_zp_float(merged_zp_double.begin(), merged_zp_double.end()); + + auto broadcast_dims = DenseIntElementsAttr::get( + RankedTensorType::get({1}, builder.getI64Type()), + {quantized_dimension}); + Value merged_scale = builder.create( + loc, DenseFPElementsAttr::get( + RankedTensorType::get({channel_size}, builder.getF32Type()), + merged_scale_float)); + input_float_tensor = builder.create( + loc, input_float_tensor, merged_scale, broadcast_dims); + if (llvm::any_of(merged_zp_float, [](double zp) { return zp != 0; })) { + Value merged_zp = builder.create( + loc, DenseFPElementsAttr::get( + RankedTensorType::get({channel_size}, builder.getF32Type()), + merged_zp_float)); + input_float_tensor = builder.create( + loc, input_float_tensor, merged_zp, broadcast_dims); + } + } + return input_float_tensor; +} + +// This helper function create ops to requantize `input` tensor and returns the +// output tensor. Clamping is done if output integer bit-width < i32. It assumes +// that if both input and output tensor are per-channel quantized, they have the +// same quantization axis. +// +// Requantization is essentially dequantize --> quantize. +// +// Dequantize: (input - zp) * scale +// Quantize: input / scale + zp +// +// Hence, +// output = (input - input_zp) * input_scale / output_scale + output_zp +// +// This is simplified as: +// output = input * merged_scale + merged_zp +// where: +// merged_zp = output_zp - input_zp * merged_scale. +// merged_scale = input_scale / output_scale. +Value Requantize(mlir::OpState op, Value input, QuantType input_quant_type, + QuantType output_quant_type, TensorType output_tensor_type, + ConversionPatternRewriter &rewriter) { + // Skip requantization when input and result have the same type. + if (input_quant_type == output_quant_type) { + return rewriter.create(op->getLoc(), output_tensor_type, + input); + } + + auto float_tensor_type = output_tensor_type.clone(rewriter.getF32Type()); + Value output_float = + rewriter.create(op->getLoc(), float_tensor_type, input); + + output_float = + ApplyMergedScalesAndZps(rewriter, op->getLoc(), input_quant_type, + output_quant_type, output_float); + + // Clamp output if the output integer bit-width <32. 
+ if (output_tensor_type.getElementType().cast().getWidth() < 32) { + Value quantization_min, quantization_max; + GetQuantizationStorageInfo(rewriter, op->getLoc(), output_quant_type, + quantization_min, quantization_max); + // Clamp results by [quantization_min, quantization_max]. + output_float = rewriter.create( + op->getLoc(), quantization_min, output_float, quantization_max); + } + + output_float = rewriter.create( + op->getLoc(), float_tensor_type, output_float); + return rewriter.create(op->getLoc(), output_tensor_type, + output_float); +} + class ConvertUniformQuantizeOp : public OpConversionPattern { public: @@ -255,10 +342,24 @@ class ConvertUniformQuantizeOp if (succeeded(quant_type)) { return matchAndRewriteQuantize(op, adaptor, rewriter, *quant_type); } - } else if (input_element_type.isa()) { - return matchAndRewriteRequantize(op, adaptor, rewriter); + } else if (input_element_type.isa()) { + auto input_quant_type = GetQuantType(input_element_type); + auto output_quant_type = GetQuantType(op.getResult().getType()); + if (succeeded(input_quant_type) && succeeded(output_quant_type)) { + if (IsPerChannelType(*input_quant_type) && + IsPerChannelType(*output_quant_type) && + GetPerChannelType(*input_quant_type).getQuantizedDimension() != + GetPerChannelType(*output_quant_type).getQuantizedDimension()) { + op->emitError("Cannot requantize while changing quantization_axis"); + return failure(); + } + return matchAndRewriteRequantize(op, adaptor, rewriter, + *input_quant_type, *output_quant_type); + } } - return rewriter.notifyMatchFailure(op, "Unsupported input element type."); + op->emitError("Unsupported input element type."); + return failure(); } LogicalResult matchAndRewriteQuantize(mhlo::UniformQuantizeOp op, @@ -298,16 +399,14 @@ class ConvertUniformQuantizeOp LogicalResult matchAndRewriteRequantize( mhlo::UniformQuantizeOp op, mhlo::UniformQuantizeOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto input_quantized_type = getElementTypeOrSelf(op.getOperand().getType()) - .cast(); - auto output_quantized_type = getElementTypeOrSelf(op.getResult().getType()) - .cast(); + ConversionPatternRewriter &rewriter, QuantType input_quant_type, + QuantType output_quant_type) const { rewriter.replaceOp( - op, Requantize(op, adaptor.getOperand(), input_quantized_type, - output_quantized_type, + op, Requantize(op, adaptor.getOperand(), input_quant_type, + output_quant_type, + /*output_tensor_type=*/ op.getResult().getType().cast().clone( - output_quantized_type.getStorageType()), + GetQuantStorageType(output_quant_type)), rewriter)); return success(); } @@ -357,18 +456,18 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::AddOp op, mhlo::AddOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto lhs_element_type = - op.getLhs().getType().getElementType().dyn_cast(); - auto rhs_element_type = - op.getRhs().getType().getElementType().dyn_cast(); - auto result_element_type = op.getResult() - .getType() - .getElementType() - .dyn_cast(); + auto lhs_quant_type = + GetQuantType(getElementTypeOrSelf(op.getLhs().getType())); + auto rhs_quant_type = + GetQuantType(getElementTypeOrSelf(op.getRhs().getType())); + auto res_quant_type = + GetQuantType(getElementTypeOrSelf(op.getResult().getType())); // We only handle cases where lhs, rhs and results all have quantized // element type. 
- if (!lhs_element_type || !rhs_element_type || !result_element_type) { + if (failed(lhs_quant_type) || IsPerChannelType(*lhs_quant_type) || + failed(rhs_quant_type) || IsPerChannelType(*rhs_quant_type) || + failed(res_quant_type) || IsPerChannelType(*res_quant_type)) { op->emitError( "AddOp requires the same quantized element type for all operands and " "results"); @@ -384,17 +483,17 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { // TODO: b/260280919 - Consider avoiding conversion to int32. Value lhs = adaptor.getLhs(); Value lhs_int32_tensor = - Requantize(op, lhs, lhs_element_type, result_element_type, + Requantize(op, lhs, *lhs_quant_type, *res_quant_type, res_int32_tensor_type, rewriter); Value rhs = adaptor.getRhs(); Value rhs_int32_tensor = - Requantize(op, rhs, rhs_element_type, result_element_type, + Requantize(op, rhs, *rhs_quant_type, *res_quant_type, res_int32_tensor_type, rewriter); Value zero_point = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(static_cast( - result_element_type.getZeroPoint()))); + GetPerTensorType(*res_quant_type).getZeroPoint()))); // Now the lhs and rhs have been coverted to the same scale and zps. // Given: @@ -411,24 +510,26 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { Value res_int32 = rewriter.create( op->getLoc(), res_int32_tensor_type, add_result, zero_point, nullptr); - if (result_element_type.getStorageType().isInteger(32)) { + if (GetQuantStorageType(*res_quant_type).isInteger(32)) { // For i32, clamping is not needed. rewriter.replaceOp(op, res_int32); } else { // Clamp results by [quantization_min, quantization_max] when storage type // is not i32. Value result_quantization_min = rewriter.create( - op->getLoc(), rewriter.getI32IntegerAttr(static_cast( - result_element_type.getStorageTypeMin()))); + op->getLoc(), + rewriter.getI32IntegerAttr(static_cast( + GetPerTensorType(*res_quant_type).getStorageTypeMin()))); Value result_quantization_max = rewriter.create( - op->getLoc(), rewriter.getI32IntegerAttr(static_cast( - result_element_type.getStorageTypeMax()))); + op->getLoc(), + rewriter.getI32IntegerAttr(static_cast( + GetPerTensorType(*res_quant_type).getStorageTypeMax()))); res_int32 = rewriter.create( op->getLoc(), res_int32_tensor_type, result_quantization_min, res_int32, result_quantization_max); // Convert results back to result storage type. auto res_final_tensor_type = - res_int32_tensor_type.clone(result_element_type.getStorageType()); + res_int32_tensor_type.clone(GetQuantStorageType(*res_quant_type)); rewriter.replaceOpWithNewOp(op, res_final_tensor_type, res_int32); } @@ -512,7 +613,7 @@ Value CreateZeroPointPartialOffset(OpBuilder &builder, Location loc, // Calculate the output tensor shape. This is input tensor dims minus // contracting dims. auto ranked_tensor = tensor.getType().cast(); - llvm::SmallVector output_dims; + SmallVector output_dims; for (int64_t i = 0; i < ranked_tensor.getRank(); ++i) { if (absl::c_count(reduction_dims, i) == 0) { output_dims.push_back(ranked_tensor.getDimSize(i)); @@ -581,7 +682,7 @@ Value CalculateDynamicOutputDims(OpBuilder &builder, Location loc, Value lhs, // Calculate each output dim and concatenate into a 1D tensor. // Output dims are batching dims, spatial dims, LHS result dims, RHS result // dims. 
- llvm::SmallVector output_dims; + SmallVector output_dims; for (int64_t i = 0; i < lhs_shape.getRank(); ++i) { if (absl::c_count(dims.lhs_batching_dims, i) != 0) { output_dims.push_back(GetDimValue(builder, loc, lhs, lhs_shape, i)); @@ -612,8 +713,8 @@ Value CalculateDynamicOutputDims(OpBuilder &builder, Location loc, Value lhs, Value BroadcastZpContribution(OpBuilder &builder, Location loc, Value zp_contribution, - llvm::ArrayRef reduction_dims, - llvm::ArrayRef batching_dims, + ArrayRef reduction_dims, + ArrayRef batching_dims, int64_t non_batching_starting_idx, TensorType output_tensor_type, Value &output_dims_value, Value lhs, Value rhs, @@ -623,7 +724,7 @@ Value BroadcastZpContribution(OpBuilder &builder, Location loc, // broadcast. auto zp_contribution_rank = zp_contribution.getType().cast().getRank(); - llvm::SmallVector broadcast_dims; + SmallVector broadcast_dims; broadcast_dims.resize(zp_contribution_rank, 0); // Result tensor will have batching dims first, then LHS result dims, then // RHS result dims. So non-batching result dims index doesn't start from 0. @@ -677,9 +778,8 @@ Value CalculateZeroPointOffset(OpBuilder &builder, Location loc, Value lhs, Value output_dims_value = nullptr; // Calculate LHS contribution when RHS zp is non-zero. if (rhs_zp != 0) { - llvm::SmallVector reduction_dims = - llvm::to_vector(llvm::concat(dims.lhs_spatial_dims, - dims.lhs_contracting_dims)); + SmallVector reduction_dims = to_vector(llvm::concat( + dims.lhs_spatial_dims, dims.lhs_contracting_dims)); Value lhs_zp_contribution = CreateZeroPointPartialOffset(builder, loc, lhs, rhs_zp, reduction_dims); // Broadcast lhs ZP contribution to result tensor shape. @@ -691,9 +791,8 @@ Value CalculateZeroPointOffset(OpBuilder &builder, Location loc, Value lhs, } // Calculate RHS contribution when LHS zp is non-zero. if (lhs_zp != 0) { - llvm::SmallVector reduction_dims = - llvm::to_vector(llvm::concat(dims.rhs_spatial_dims, - dims.rhs_contracting_dims)); + SmallVector reduction_dims = to_vector(llvm::concat( + dims.rhs_spatial_dims, dims.rhs_contracting_dims)); Value rhs_zp_contribution = CreateZeroPointPartialOffset(builder, loc, rhs, lhs_zp, reduction_dims); // Broadcast rhs ZP contribution to result tensor shape. @@ -765,7 +864,7 @@ Value CreateDotLikeKernel(OpBuilder &builder, Location loc, auto original_padding = op.getPaddingAttr().getValues(); // Explicitly pad LHS with zp and update LHS value. - llvm::SmallVector new_attrs(attrs); + SmallVector new_attrs(attrs); if (llvm::any_of(original_padding, [](int64_t x) { return x != 0; })) { Value zp = builder.create( loc, @@ -779,7 +878,7 @@ Value CreateDotLikeKernel(OpBuilder &builder, Location loc, // mhlo::Convolution. But mhlo::Pad require those for all dimensions. Hence // we add 0 to the beginning and end of the padding vectors. 
int64_t rank = lhs.getType().cast().getRank(); - llvm::SmallVector padding_low(rank, 0), padding_high(rank, 0), + SmallVector padding_low(rank, 0), padding_high(rank, 0), padding_interior(rank, 0); for (int64_t i = 1; i < rank - 1; ++i) { padding_low[i] = original_padding[i * 2 - 2]; @@ -962,7 +1061,7 @@ class ConvertUniformQuantizedDotOp : public OpConversionPattern { rewriter.getContext(), /*lhsBatchingDimensions=*/{}, /*rhsBatchingDimensions=*/{}, /*lhsContractingDimensions=*/{1}, /*rhsContractingDimensions=*/{0}); - llvm::SmallVector attrs(op->getAttrs()); + SmallVector attrs(op->getAttrs()); attrs.push_back( {StringAttr::get(rewriter.getContext(), "dot_dimension_numbers"), dims}); @@ -1088,7 +1187,7 @@ FailureOr VerifyAndConstructDims( auto res_element_quant_per_channel_type = getElementTypeOrSelf(op.getResult()) .cast(); - llvm::SmallVector scale_ratios( + SmallVector scale_ratios( res_element_quant_per_channel_type.getScales().size()); for (int i = 0; i < scale_ratios.size(); ++i) { scale_ratios[i] = @@ -1177,7 +1276,7 @@ class ConvertGenericOp : public ConversionPattern { // Determine new result type: use storage type for uq types; use original // type otherwise. - llvm::SmallVector new_result_types; + SmallVector new_result_types; for (auto result_type : op->getResultTypes()) { new_result_types.push_back(GetQuantStorageType(result_type)); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc index f20d1b3609361e..f00e3402e04902 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc @@ -579,6 +579,128 @@ func.func @main(%input: tensor<10xi8>) -> tensor<10xi8> { ExecuteAndCompareResultsWithTfKernel(kProgram, {&input}); } +TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantizePerChannel) { + constexpr absl::string_view kProgram = R"mlir( +func.func @main( + %input: tensor<10x10xi8>, %input_scale: tensor<10xf32>, + %input_zp: tensor<10xi32>, %output_scale: tensor<10xf32>, + %output_zp: tensor<10xi32> + ) -> tensor<10x10xi8> { + %0 = "tf.Cast"(%input) {} : (tensor<10x10xi8>) -> tensor<10x10x!tf_type.qint8> + %1 = "tf.UniformRequantize"( + %0, %input_scale, %input_zp, %output_scale, %output_zp + ) { + Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT8", attr_map = "", + device = "", input_quantization_axis = 1, + input_quantization_max_val = 127 : i64, + input_quantization_min_val = -128 : i64, + output_quantization_axis = 1 : i64, + output_quantization_max_val = 127 : i64, + output_quantization_min_val = -128 : i64 + } : ( + tensor<10x10x!tf_type.qint8>, tensor<10xf32>, tensor<10xi32>, + tensor<10xf32>, tensor<10xi32> + ) -> tensor<10x10x!tf_type.qint8> + %2 = "tf.Cast"(%1) {} : (tensor<10x10x!tf_type.qint8>) -> tensor<10x10xi8> + return %2 : tensor<10x10xi8> +})mlir"; + TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomI8Literal({10, 10})); + TF_ASSERT_OK_AND_ASSIGN( + auto input_scale, + CreateRandomF32Literal({10}, /*min=*/0.0001, /*max=*/2)); + TF_ASSERT_OK_AND_ASSIGN(auto input_zp, CreateRandomI32Literal({10})); + TF_ASSERT_OK_AND_ASSIGN( + auto output_scale, + CreateRandomF32Literal({10}, /*min=*/0.0001, /*max=*/2)); + TF_ASSERT_OK_AND_ASSIGN(auto output_zp, CreateRandomI32Literal({10})); + // error_tolerance is set to be 1 because different rounding implementations 
+ // in TF kernel and the lowering passes may cause +/-1 differences. + ExecuteAndCompareResultsWithTfKernel( + kProgram, {&input, &input_scale, &input_zp, &output_scale, &output_zp}, + /*tf_program=*/std::nullopt, + /*error_tolerance=*/1.0); +} + +TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantizePerTensorToPerChannel) { + constexpr absl::string_view kProgram = R"mlir( +func.func @main( + %input: tensor<10x10xi8>, %input_scale: tensor, %input_zp: tensor, + %output_scale: tensor<10xf32>, %output_zp: tensor<10xi32> + ) -> tensor<10x10xi8> { + %0 = "tf.Cast"(%input) {} : (tensor<10x10xi8>) -> tensor<10x10x!tf_type.qint8> + %1 = "tf.UniformRequantize"( + %0, %input_scale, %input_zp, %output_scale, %output_zp + ) { + Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT8", attr_map = "", + device = "", input_quantization_axis = -1, + input_quantization_max_val = 127 : i64, + input_quantization_min_val = -128 : i64, + output_quantization_axis = 1 : i64, + output_quantization_max_val = 127 : i64, + output_quantization_min_val = -128 : i64 + } : ( + tensor<10x10x!tf_type.qint8>, tensor, tensor, + tensor<10xf32>, tensor<10xi32> + ) -> tensor<10x10x!tf_type.qint8> + %2 = "tf.Cast"(%1) {} : (tensor<10x10x!tf_type.qint8>) -> tensor<10x10xi8> + return %2 : tensor<10x10xi8> +})mlir"; + TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomI8Literal({10, 10})); + TF_ASSERT_OK_AND_ASSIGN( + auto input_scale, CreateRandomF32Literal({}, /*min=*/0.0001, /*max=*/2)); + TF_ASSERT_OK_AND_ASSIGN(auto input_zp, CreateRandomI32Literal({})); + TF_ASSERT_OK_AND_ASSIGN( + auto output_scale, + CreateRandomF32Literal({10}, /*min=*/0.0001, /*max=*/2)); + TF_ASSERT_OK_AND_ASSIGN(auto output_zp, CreateRandomI32Literal({10})); + // error_tolerance is set to be 1 because different rounding implementations + // in TF kernel and the lowering passes may cause +/-1 differences. 
+ ExecuteAndCompareResultsWithTfKernel( + kProgram, {&input, &input_scale, &input_zp, &output_scale, &output_zp}, + /*tf_program=*/std::nullopt, + /*error_tolerance=*/1.0); +} + +TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantizePerChannelToPerTensor) { + constexpr absl::string_view kProgram = R"mlir( +func.func @main( + %input: tensor<10x10xi8>, %input_scale: tensor<10xf32>, + %input_zp: tensor<10xi32>, %output_scale: tensor, %output_zp: tensor + ) -> tensor<10x10xi8> { + %0 = "tf.Cast"(%input) {} : (tensor<10x10xi8>) -> tensor<10x10x!tf_type.qint8> + %1 = "tf.UniformRequantize"( + %0, %input_scale, %input_zp, %output_scale, %output_zp + ) { + Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT8", attr_map = "", + device = "", input_quantization_axis = 1, + input_quantization_max_val = 127 : i64, + input_quantization_min_val = -128 : i64, + output_quantization_axis = -1 : i64, + output_quantization_max_val = 127 : i64, + output_quantization_min_val = -128 : i64 + } : ( + tensor<10x10x!tf_type.qint8>, tensor<10xf32>, tensor<10xi32>, + tensor, tensor + ) -> tensor<10x10x!tf_type.qint8> + %2 = "tf.Cast"(%1) {} : (tensor<10x10x!tf_type.qint8>) -> tensor<10x10xi8> + return %2 : tensor<10x10xi8> +})mlir"; + TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomI8Literal({10, 10})); + TF_ASSERT_OK_AND_ASSIGN( + auto input_scale, + CreateRandomF32Literal({10}, /*min=*/0.0001, /*max=*/2)); + TF_ASSERT_OK_AND_ASSIGN(auto input_zp, CreateRandomI32Literal({10})); + TF_ASSERT_OK_AND_ASSIGN( + auto output_scale, CreateRandomF32Literal({}, /*min=*/0.0001, /*max=*/2)); + TF_ASSERT_OK_AND_ASSIGN(auto output_zp, CreateRandomI32Literal({})); + // error_tolerance is set to be 1 because different rounding implementations + // in TF kernel and the lowering passes may cause +/-1 differences. 
+ ExecuteAndCompareResultsWithTfKernel( + kProgram, {&input, &input_scale, &input_zp, &output_scale, &output_zp}, + /*tf_program=*/std::nullopt, + /*error_tolerance=*/1.0); +} + TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAdd) { constexpr absl::string_view kProgram = R"mlir( func.func @main(%lhs: tensor<10x10xi32>, %rhs: tensor<10x10xi32>) -> tensor<10x10xi32> { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir index e022bcd81c447a..d943e9c1b04fdb 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir @@ -341,6 +341,91 @@ func.func @requantize_merged_zp_zero( // ----- +// CHECK-LABEL: func @requantize_per_channel +func.func @requantize_per_channel( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[2.000000e+00, 5.000000e-01]> : tensor<2xf32> + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-5.000000e+00, -2.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even %[[VAL4]] : tensor<2x2xf32> + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor<2x2xf32>) -> tensor<2x2xi8> + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func @requantize_per_channel_to_per_tensor +func.func @requantize_per_channel_to_per_tensor( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-5.000000e+00, -1.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even %[[VAL4]] : tensor<2x2xf32> + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor<2x2xf32>) -> tensor<2x2xi8> + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func 
@requantize_per_tensor_to_per_channel +func.func @requantize_per_tensor_to_per_channel( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[VAL1:.*]] = mhlo.convert %arg0 : (tensor<2x2xi8>) -> tensor<2x2xf32> + // CHECK-DAG: %[[MERGED_SCALE:.*]] = mhlo.constant dense<[1.000000e+00, 5.000000e-01]> : tensor<2xf32> + // CHECK: %[[VAL2:.*]] = chlo.broadcast_multiply %[[VAL1]], %[[MERGED_SCALE]] + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-DAG: %[[MERGED_ZP:.*]] = mhlo.constant dense<[-1.000000e+00, -2.000000e+00]> : tensor<2xf32> + // CHECK: %[[VAL3:.*]] = chlo.broadcast_add %[[VAL2]], %[[MERGED_ZP]] + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK: %[[VAL4:.*]] = mhlo.clamp %[[QUANT_MIN]], %[[VAL3]], %[[QUANT_MAX]] + // CHECK: %[[VAL5:.*]] = mhlo.round_nearest_even %[[VAL4]] : tensor<2x2xf32> + // CHECK: %[[VAL6:.*]] = mhlo.convert %[[VAL5]] : (tensor<2x2xf32>) -> tensor<2x2xi8> + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +func.func @requantize_per_channel_change_axis( + %arg0: tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> { + // expected-error@+2 {{Cannot requantize while changing quantization_axis}} + // expected-error@+1 {{failed to legalize operation 'mhlo.uniform_quantize' that was explicitly marked illegal}} + %0 = mhlo.uniform_quantize %arg0 : ( + tensor<2x2x!quant.uniform> + ) -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + // CHECK-LABEL: func @dot func.func @dot(%arg0: tensor<2x2x!quant.uniform>, %arg1: tensor<2x2x!quant.uniform> From c24de56bdeb834cdcedb6ee137428d734b7fe5d9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 10:10:13 -0800 Subject: [PATCH 174/391] Integrate LLVM at llvm/llvm-project@865f54e50173 Updates LLVM usage to match [865f54e50173](https://github.com/llvm/llvm-project/commit/865f54e50173) PiperOrigin-RevId: 583084484 --- third_party/llvm/generated.patch | 50 -------------------------------- third_party/llvm/workspace.bzl | 4 +-- 2 files changed, 2 insertions(+), 52 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index bc734d4fe6ced1..509398da979e83 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,51 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. 
-diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel -@@ -5172,6 +5172,18 @@ - ], - ) - -+gentbl( -+ name = "ReadTAPIOptsTableGen", -+ strip_include_prefix = "tools/llvm-readtapi", -+ tbl_outs = [( -+ "-gen-opt-parser-defs", -+ "tools/llvm-readtapi/TapiOpts.inc", -+ )], -+ tblgen = ":llvm-tblgen", -+ td_file = "tools/llvm-readtapi/TapiOpts.td", -+ td_srcs = ["include/llvm/Option/OptParser.td"], -+) -+ - cc_binary( - name = "llvm-readtapi", - testonly = True, -@@ -5183,6 +5195,8 @@ - stamp = 0, - deps = [ - ":Object", -+ ":Option", -+ ":ReadTAPIOptsTableGen", - ":Support", - ":TextAPI", - ], -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -@@ -1022,6 +1022,7 @@ - ":CAPIIR", - ":CAPIQuant", - ":MLIRBindingsPythonHeadersAndDeps", -+ "@pybind11", - ], - ) - -@@ -1040,6 +1041,7 @@ - ":CAPIIR", - ":CAPISparseTensor", - ":MLIRBindingsPythonHeadersAndDeps", -+ "@pybind11", - ], - ) - diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 869bcb78ea2e15..f7c7984832623d 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "8ea8dd9a017182d167f39f521ef397afba5a0fd5" - LLVM_SHA256 = "6963e268a2e03ff956f2457a629a8dafab8682eb4bbb664ff8dc668dd3faef7b" + LLVM_COMMIT = "865f54e501739f382d33866baebfd0f9aaad01bb" + LLVM_SHA256 = "16dc3aa4f7688f11e20d1f506419e99217018aa8b9ae02453d63b95b76541a2a" tf_http_archive( name = name, From 4417253ed0550dbfe028429fa191af9f7b499e29 Mon Sep 17 00:00:00 2001 From: Katherine Wu Date: Thu, 16 Nov 2023 10:30:22 -0800 Subject: [PATCH 175/391] Account for extra encoding bytes and extract all repeated nodes into separate chunks when splitting repeated fields. PiperOrigin-RevId: 583090688 --- tensorflow/tools/proto_splitter/cc/BUILD | 9 ++++-- .../cc/composable_splitter_base.cc | 2 ++ .../cc/graph_def_splitter_test.cc | 12 +++++-- .../cc/repeated_field_splitter.cc | 31 ++++++++++--------- 4 files changed, 34 insertions(+), 20 deletions(-) diff --git a/tensorflow/tools/proto_splitter/cc/BUILD b/tensorflow/tools/proto_splitter/cc/BUILD index 266d69479ff8a0..716b7fa317fafe 100644 --- a/tensorflow/tools/proto_splitter/cc/BUILD +++ b/tensorflow/tools/proto_splitter/cc/BUILD @@ -1,12 +1,13 @@ -# Description: -# Utilities for splitting and joining large protos > 2GB. -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "if_oss", "tf_cc_test", ) +# Description: +# Utilities for splitting and joining large protos > 2GB. 
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ @@ -43,6 +44,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/tools/proto_splitter:chunk_proto_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -197,6 +199,7 @@ cc_library( ":composable_splitter", ":max_size", ":size_splitter", + ":split", ":util", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/status", diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc index 8a9ee3091a1366..4fdaa2777d9f8f 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "riegeli/bytes/fd_writer.h" // from @riegeli @@ -127,6 +128,7 @@ absl::Status ComposableSplitterBase::Write(std::string file_prefix) { chunk)) { auto msg_chunk = std::get>(chunk); + LOG(INFO) << "Writing chunk of size " << msg_chunk->ByteSizeLong(); writer.WriteRecord(*msg_chunk); chunk_metadata->set_size(msg_chunk->ByteSizeLong()); chunk_metadata->set_type(::proto_splitter::ChunkInfo::MESSAGE); diff --git a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc index 036c90fde04a94..bbb2587a2d3c39 100644 --- a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc +++ b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/tools/proto_splitter/cc/graph_def_splitter.h" +#include #include #include #include @@ -179,7 +180,12 @@ TEST(GraphDefSplitterTest, TestLotsNodes) { const std::string graph_def_path = io::JoinPath(testing::TensorFlowSrcRoot(), "tools/proto_splitter/testdata", "split-lots-nodes.pb"); - int64_t max_size = 500; + + // split-lots-nodes.pb has 15 nodes that are 95 or 96 bytes each. The max size + // is set to "exactly" the size of 5 nodes, but with the extra encoding bytes, + // only 4 nodes should fit in each chunk. Thus, there should be exactly 4 + // chunks created for all 15 nodes. + int64_t max_size = 96 * 5; DebugSetMaxSize(max_size); TF_EXPECT_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), @@ -196,7 +202,9 @@ TEST(GraphDefSplitterTest, TestLotsNodes) { *chunked_message, EqualsProto(R"pb(chunk_index: 0 chunked_fields { message { chunk_index: 1 } } - chunked_fields { message { chunk_index: 2 } })pb")); + chunked_fields { message { chunk_index: 2 } } + chunked_fields { message { chunk_index: 3 } } + chunked_fields { message { chunk_index: 4 } })pb")); auto chunks = x.first; EXPECT_CHUNK_SIZES(chunks, max_size); diff --git a/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc b/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc index e836556d569974..01601c7e22a1fc 100644 --- a/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc +++ b/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc @@ -14,7 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h" +#include #include +#include #include #include "absl/status/status.h" @@ -23,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/tools/proto_splitter/cc/max_size.h" +#include "tensorflow/tools/proto_splitter/cc/split.h" #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" @@ -31,6 +34,10 @@ limitations under the License. namespace tensorflow { namespace tools::proto_splitter { +// Additional bytes added to each node to account for the extra info needed to +// encode the field key (realistically 3 but making it 5 for some wiggle room). +constexpr int kExtraBytes = 5; + template absl::StatusOr> RepeatedFieldSplitters::Create( @@ -65,13 +72,8 @@ absl::StatusOr RepeatedFieldSplitters< // List of indices at which to split the repeated field. For example, [3, 5] // means that the field list is split into: [:3], [3:5], [5:] - std::vector repeated_msg_split = {}; - // Should be the same length as the list above. Contains new protos to hold - // the elements that are split from the original proto. - // From the [3, 5] example above, the messages in this list contain nodes - // [3:5] and [5:] - std::vector> repeated_new_msg; - // Track the total size of the current node split. + std::vector repeated_msg_split = {0}; + // Track the total byte size of the current node split. uint64_t total_size = 0; // Linearly iterate through all nodes. It may be possible to optimize this @@ -99,17 +101,12 @@ absl::StatusOr RepeatedFieldSplitters< } if (total_size + node_size > max_size) { repeated_msg_split.push_back(i); - auto new_chunk = std::make_shared(); - repeated_new_msg.push_back(new_chunk); - std::vector empty_fields = {}; - auto x = std::make_unique(new_chunk); - TF_RETURN_IF_ERROR(AddChunk(std::move(x), &empty_fields)); total_size = 0; } - total_size += node_size; + total_size += node_size + kExtraBytes; } - if (!repeated_msg_split.empty()) { + if (repeated_msg_split.size() > 1) { auto repeated_nodes_ptrs = ret.parent->GetReflection() ->template MutableRepeatedPtrField(ret.parent, @@ -127,7 +124,11 @@ absl::StatusOr RepeatedFieldSplitters< for (int i = 1; i < repeated_msg_split.size(); ++i) { start = repeated_msg_split[i - 1]; int end = repeated_msg_split[i]; - std::shared_ptr new_msg = repeated_new_msg[i - 1]; + + auto new_msg = std::make_shared(); + std::vector empty_fields; + auto x = std::make_unique(new_msg); + TF_RETURN_IF_ERROR(AddChunk(std::move(x), &empty_fields)); // Move nodes into new_msg. TF_ASSIGN_OR_RETURN(auto new_ret, From be283d6b5a01a7878283990b65132ccf5bbb7264 Mon Sep 17 00:00:00 2001 From: Tristen Allen Date: Thu, 16 Nov 2023 10:37:58 -0800 Subject: [PATCH 176/391] Add type annotations for modified test_util decorators. Adds type annotations for the recently refactored decorators in `test_util.py`. Type annotations for other objects in this file are forthcoming. 
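As background for the annotations in the diff below, the pattern used throughout is a TypeVar-parameterized decorator (or decorator factory). A minimal standalone sketch, not taken from test_util.py (the name `set_flag` is made up for illustration):

from collections.abc import Callable
import functools
from typing import TypeVar

_R = TypeVar("_R")

def set_flag(flag: str = "") -> Callable[[Callable[..., _R]], Callable[..., _R]]:
  """Decorator factory typed like the test_util decorators."""

  def decorator(f: Callable[..., _R]) -> Callable[..., _R]:
    @functools.wraps(f)
    def decorated(*args, **kwargs) -> _R:
      # Set up state derived from `flag`, run the wrapped test, restore state.
      return f(*args, **kwargs)
    return decorated

  return decorator

Class decorators in the diff use a separate `_TC` TypeVar bound to the test-case class, and decorators that accept either a function or a class are annotated with `typing.overload`.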
PiperOrigin-RevId: 583093098 --- tensorflow/python/framework/test_util.py | 142 ++++++++++++++++------- 1 file changed, 99 insertions(+), 43 deletions(-) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 3ff5963a274853..7c86ed1709e152 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -18,7 +18,7 @@ import collections from collections import OrderedDict -from collections.abc import Iterator +from collections.abc import Callable, Iterator import contextlib import functools import gc @@ -30,7 +30,7 @@ import tempfile import threading import time -from typing import Union +from typing import Any, cast, Optional, overload, TypeVar, Union import unittest from absl.testing import parameterized @@ -98,6 +98,9 @@ from tensorflow.python.util.protobuf import compare from tensorflow.python.util.tf_export import tf_export +_TC = TypeVar("_TC", bound=type["TensorFlowTestCase"]) +_R = TypeVar("_R") + # If the below import is made available through the BUILD rule, then this # function is overridden and will instead return True and cause Tensorflow @@ -1273,7 +1276,7 @@ def wrapper(*args, **kwargs): return wrapper -def add_graph_building_optimization_tests(cls): +def add_graph_building_optimization_tests(cls: _TC) -> _TC: """Adds methods with graph_building_optimization enabled to the test suite. Example: @@ -1328,7 +1331,9 @@ def disable_eager_op_as_function(unused_msg): return _disable_test(execute_func=False) -def set_xla_env_flag(flag=""): +def set_xla_env_flag( + flag: str = "", +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: """Decorator for setting XLA_FLAGS prior to running a test. This function returns a decorator intended to be applied to test methods in @@ -1352,10 +1357,10 @@ def testFoo(self): function. """ - def decorator(f): + def decorator(f: Callable[..., _R]) -> Callable[..., _R]: @functools.wraps(f) - def decorated(*args, **kwargs): + def decorated(*args, **kwargs) -> _R: original_xla_flags = os.environ.get("XLA_FLAGS") new_xla_flags = flag if original_xla_flags: @@ -1374,7 +1379,9 @@ def decorated(*args, **kwargs): return decorator -def build_as_function_and_v1_graph(func): +def build_as_function_and_v1_graph( + func: Callable[..., Any], +) -> Callable[..., None]: """Run a test case in v1 graph mode and inside tf.function in eager mode. WARNING: This decorator can only be used in test cases that statically checks @@ -1398,7 +1405,12 @@ def build_as_function_and_v1_graph(func): @parameterized.named_parameters(("_v1_graph", "v1_graph"), ("_function", "function")) @functools.wraps(func) - def decorated(self, run_mode, *args, **kwargs): + def decorated( + self: "TensorFlowTestCase", + run_mode: str, + *args, + **kwargs, + ) -> None: if run_mode == "v1_graph": with ops.Graph().as_default(): func(self, *args, **kwargs) @@ -1558,8 +1570,10 @@ def run_eagerly(self, **kwargs): return decorator -def run_in_v1_v2(device_to_use: str = None, - assert_no_eager_garbage: bool = False): +def run_in_v1_v2( + device_to_use: Optional[str] = None, + assert_no_eager_garbage: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., None]]: """Execute the decorated test in v1 and v2 modes. The overall execution is similar to that of `run_in_graph_and_eager_mode`. @@ -1581,13 +1595,13 @@ def run_in_v1_v2(device_to_use: str = None, A decorator that runs a given test in v1 and v2 modes. 
""" - def decorator(f): + def decorator(f: Callable[..., Any]) -> Callable[..., None]: decorator_tag = "wrapped_with_v1_v2_decorator" if hasattr(f, decorator_tag): # Already decorated with this very same decorator return f - def decorated(self, *args, **kwargs): + def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> None: logging.info("Running %s in V1 mode.", f.__name__) try: with self.subTest("V1_mode"): @@ -1686,7 +1700,19 @@ def bound_f(): return decorated -def deprecated_graph_mode_only(func): +@overload +def deprecated_graph_mode_only(func: Callable[..., _R]) -> Callable[..., _R]: + ... + + +@overload +def deprecated_graph_mode_only(func: _TC) -> Optional[_TC]: + ... + + +def deprecated_graph_mode_only( + func: Union[_TC, Callable[..., _R]], +) -> Union[_TC, Callable[..., _R]]: """Execute the decorated test in graph mode. This is a decorator intended to be applied to tests that are not compatible @@ -1819,7 +1845,7 @@ def run_v2_only(func=None, reason=None): return _run_vn_only(func=func, v2=True, reason=reason) -def run_gpu_only(func): +def run_gpu_only(func: Callable[..., _R]) -> Callable[..., _R]: """Execute the decorated test only if a GPU is available. This function is intended to be applied to tests that require the presence @@ -1835,7 +1861,7 @@ def run_gpu_only(func): if tf_inspect.isclass(func): raise ValueError("`run_gpu_only` only supports test methods.") - def decorated(self, *args, **kwargs): + def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> _R: if not is_gpu_available(): self.skipTest("Test requires GPU") @@ -1844,7 +1870,7 @@ def decorated(self, *args, **kwargs): return decorated -def run_cuda_only(func): +def run_cuda_only(func: Callable[..., _R]) -> Callable[..., _R]: """Execute the decorated test only if a GPU is available. This function is intended to be applied to tests that require the presence @@ -1860,7 +1886,7 @@ def run_cuda_only(func): if tf_inspect.isclass(func): raise ValueError("`run_cuda_only` only supports test methods.") - def decorated(self, *args, **kwargs): + def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> _R: if not is_gpu_available(cuda_only=True): self.skipTest("Test requires CUDA GPU") @@ -1869,7 +1895,7 @@ def decorated(self, *args, **kwargs): return decorated -def run_gpu_or_tpu(func): +def run_gpu_or_tpu(func: Callable[..., _R]) -> Callable[..., _R]: """Execute the decorated test only if a physical GPU or TPU is available. This function is intended to be applied to tests that require the presence @@ -1888,7 +1914,7 @@ def run_gpu_or_tpu(func): if tf_inspect.isclass(func): raise ValueError("`run_gpu_or_tpu` only supports test methods.") - def decorated(self, *args, **kwargs): + def decorated(self: "TensorFlowTestCase", *args, **kwargs) -> _R: if config.list_physical_devices("GPU"): return func(self, "GPU", *args, **kwargs) @@ -2120,7 +2146,7 @@ def run(self, *args, **kwargs): raise -def disable_cudnn_autotune(func): +def disable_cudnn_autotune(func: Callable[..., _R]) -> Callable[..., _R]: """Disable autotuning during the call to this function. Some tests want to base assertions on a graph being isomorphic with a copy. @@ -2133,7 +2159,7 @@ def disable_cudnn_autotune(func): Decorated function. 
""" - def decorated(*args, **kwargs): + def decorated(*args, **kwargs) -> _R: original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE") os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false" original_xla_flags = os.environ.get("XLA_FLAGS") @@ -2159,13 +2185,17 @@ def decorated(*args, **kwargs): # The description is just for documentation purposes. -def enable_tf_xla_constant_folding(description): +def enable_tf_xla_constant_folding( + description: str, +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: if not isinstance(description, str): raise ValueError("'description' should be string, got {}".format( type(description))) - def enable_tf_xla_constant_folding_impl(func): + def enable_tf_xla_constant_folding_impl( + func: Callable[..., _R], + ) -> Callable[..., _R]: """Enable constant folding during the call to this function. Some tests fail without constant folding. @@ -2177,7 +2207,7 @@ def enable_tf_xla_constant_folding_impl(func): Decorated function. """ - def decorated(*args, **kwargs): + def decorated(*args, **kwargs) -> _R: original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled() pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False) result = func(*args, **kwargs) @@ -2190,11 +2220,13 @@ def decorated(*args, **kwargs): # Updates test function by selectively disabling it. -def _disable_test(execute_func): +def _disable_test( + execute_func: bool, +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: - def disable_test_impl(func): + def disable_test_impl(func: Callable[..., _R]) -> Callable[..., _R]: - def decorated(*args, **kwargs): + def decorated(*args, **kwargs) -> _R: if execute_func: return func(*args, **kwargs) @@ -2204,64 +2236,84 @@ def decorated(*args, **kwargs): # The description is just for documentation purposes. -def disable_xla(description): # pylint: disable=unused-argument +def disable_xla( + description: str, # pylint: disable=unused-argument +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: """Execute the test method only if xla is not enabled.""" execute_func = not is_xla_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_mlir_bridge(description): # pylint: disable=unused-argument +def disable_mlir_bridge( + description: str, # pylint: disable=unused-argument +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: """Execute the test method only if MLIR bridge is not enabled.""" execute_func = not is_mlir_bridge_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_asan(description): # pylint: disable=unused-argument +def disable_asan( + description: str, # pylint: disable=unused-argument +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: """Execute the test method only if ASAN is not enabled.""" execute_func = not is_asan_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_msan(description): # pylint: disable=unused-argument +def disable_msan( + description: str, # pylint: disable=unused-argument +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: """Execute the test method only if MSAN is not enabled.""" execute_func = not is_msan_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. 
-def disable_tsan(description): # pylint: disable=unused-argument +def disable_tsan( + description: str, # pylint: disable=unused-argument +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: """Execute the test method only if TSAN is not enabled.""" execute_func = not is_tsan_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_ubsan(description): # pylint: disable=unused-argument +def disable_ubsan( + description: str, # pylint: disable=unused-argument +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: """Execute the test method only if UBSAN is not enabled.""" execute_func = not is_ubsan_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def disable_tfrt(unused_description): - - def disable_tfrt_impl(cls_or_func): +def disable_tfrt( + unused_description: str, # pylint: disable=unused-argument +) -> Callable[ + [Union[_TC, Callable[..., _R]]], + Union[_TC, Callable[..., _R], None] +]: + + def disable_tfrt_impl( + cls_or_func: Union[_TC, Callable[..., _R]] + ) -> Union[_TC, Callable[..., _R], None]: """Execute the test only if tfrt is not enabled.""" if tf_inspect.isclass(cls_or_func): if tfrt_utils.enabled(): return None else: - return cls_or_func + return cast(_TC, cls_or_func) else: - def decorated(*args, **kwargs): + func = cast(Callable[..., _R], cls_or_func) + def decorated(*args, **kwargs) -> _R: if tfrt_utils.enabled(): return else: - return cls_or_func(*args, **kwargs) + return func(*args, **kwargs) return tf_decorator.make_decorator(cls_or_func, decorated) @@ -2296,19 +2348,23 @@ def all_test_methods_impl(cls): # The description is just for documentation purposes. -def no_xla_auto_jit(description): # pylint: disable=unused-argument +def no_xla_auto_jit( + description: str, # pylint: disable=unused-argument +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: """This test is not intended to be run with XLA auto jit enabled.""" execute_func = not is_xla_enabled() return _disable_test(execute_func) # The description is just for documentation purposes. -def xla_allow_fallback(description): # pylint: disable=unused-argument +def xla_allow_fallback( + description: str, # pylint: disable=unused-argument +): - def xla_allow_fallback_impl(func): + def xla_allow_fallback_impl(func: Callable[..., _R]) -> Callable[..., _R]: """Allow fallback to TF even though testing xla.""" - def decorated(*args, **kwargs): + def decorated(*args, **kwargs) -> _R: if is_xla_enabled(): # Update the global XLABuildOpsPassFlags to enable lazy compilation, # which allows the compiler to fall back to TF classic. Remember the From 7c824b23530e4a90624a32f1669f15ae56bb6280 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Thu, 16 Nov 2023 10:39:41 -0800 Subject: [PATCH 177/391] PR #6657: [XLA:GPU ] add cuDNN flash attention support in XLA (2nd PR with only MLIR lowering and thunk/runtime) Imported from GitHub PR https://github.com/openxla/xla/pull/6657 This is the 2nd PR of splitting https://github.com/openxla/xla/pull/5910 with only MLIR lowering and thunk/r... 
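The hunks that follow route every fused-attention backward call through one of two custom-call target families: "xla.gpu.fused.attention.backward." for the existing path and "xla.gpu.flash.attention.backward." for the new flash-attention path, with the suffix picked from the DAG signature and the operand count. Below is a minimal standalone sketch of that selection logic in plain C++; the enum, function names, and the error handling are hypothetical simplifications, not the actual MLIR rewrite pattern, and only the flash-attention operand counts shown in the lowering (12 for BMM_Softmax_BMM, 13 for BMM_Bias_Softmax_BMM) are modeled.

#include <iostream>
#include <optional>
#include <string>

// Hypothetical mirror of the backward DAG signatures handled by the lowering.
enum class BackwardDag { kSoftmax, kScaleBiasSoftmax };

// Returns the runtime custom-call target, or std::nullopt if the operand
// count does not match the expected flash-attention signature.
std::optional<std::string> BackwardTarget(bool is_flash_attention,
                                          BackwardDag dag, int num_operands) {
  std::string target = is_flash_attention
                           ? "xla.gpu.flash.attention.backward."
                           : "xla.gpu.fused.attention.backward.";
  if (is_flash_attention) {
    if (dag == BackwardDag::kSoftmax && num_operands == 12)
      return target + "scale.softmax";
    if (dag == BackwardDag::kScaleBiasSoftmax && num_operands == 13)
      return target + "scale.bias.softmax";
    return std::nullopt;  // unexpected operand count for flash attention
  }
  // Suffix selection for the non-flash path (10/11 operands, dbias variants)
  // is elided in this sketch.
  return std::nullopt;
}

int main() {
  auto t = BackwardTarget(/*is_flash_attention=*/true,
                          BackwardDag::kScaleBiasSoftmax, /*num_operands=*/13);
  std::cout << (t ? *t : std::string("error: unexpected operand count")) << "\n";
}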
PiperOrigin-RevId: 583093648 --- .../transforms/lmhlo_gpu_to_gpu_runtime.cc | 51 +++- .../xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td | 17 +- .../lhlo_gpu/IR/lhlo_gpu_ops_enums.td | 8 +- .../xla/xla/service/gpu/backend_configs.proto | 6 + .../xla/xla/service/gpu/fused_mha_thunk.cc | 64 +++-- .../xla/xla/service/gpu/fused_mha_thunk.h | 12 +- .../xla/service/gpu/gpu_fused_mha_runner.cc | 184 +++++++++----- .../xla/service/gpu/gpu_fused_mha_runner.h | 68 ++++-- .../xla/service/gpu/ir_emitter_unnested.cc | 87 ++++++- .../xla/xla/service/gpu/nvptx_compiler.cc | 2 +- .../service/gpu/runtime/fused_attention.cc | 227 +++++++++++++++--- .../mhlo_to_lhlo_with_xla.cc | 121 +++++++++- 12 files changed, 670 insertions(+), 177 deletions(-) diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc index d3d52e4516b46a..3a74ed46de4244 100644 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc +++ b/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc @@ -741,7 +741,8 @@ class FusedAttentionForwardLowering set_attr("fmha_scale", op.getFmhaScaleAttr()); set_attr("dropout_rate", op.getDropoutRateAttr()); set_attr("seed", op.getSeedAttr()); - + set_attr("is_flash_attention", op.getIsFlashAttentionAttr()); + set_attr("is_causal_mask", op.getIsCausalMaskAttr()); set_attr("fused_mha_dag", op.getFusedMhaDagAttr()); set_attr("algorithm_config", op.getAlgorithmConfigAttr()); set_attr("bmm1_dot_dimension_numbers", op.getBmm1DotDimensionNumbers()); @@ -784,8 +785,10 @@ template class FusedAttentionBackwardLowering : public OpRewritePattern { private: - static constexpr const char kCustomCallTarget[] = + static constexpr const char kFusedAttentionCustomCallTarget[] = "xla.gpu.fused.attention.backward."; + static constexpr const char kFlashAttentionCustomCallTarget[] = + "xla.gpu.flash.attention.backward."; public: explicit FusedAttentionBackwardLowering(MLIRContext* ctx, UidGenerator& uid, @@ -797,11 +800,36 @@ class FusedAttentionBackwardLowering LogicalResult matchAndRewrite(FusedDotAttentionBackward op, PatternRewriter& rewriter) const override { // Get the custom call target. - std::string fused_attention = kCustomCallTarget; + bool is_flash_attention = op.getIsFlashAttention(); + std::string fused_attention = is_flash_attention + ? 
kFlashAttentionCustomCallTarget + : kFusedAttentionCustomCallTarget; auto num_operands = op.getNumOperands(); switch (op.getFusedMhaDag()) { + case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax: + if (is_flash_attention) { + if (num_operands == 12) { + fused_attention += "scale.softmax"; + } else { + return op.emitOpError( + "unexpected number of operands for flash attention backward - " + "BMM_Softmax_BMM"); + } + } + break; + case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: BackwardScaleBiasSoftmax: + if (is_flash_attention) { + if (num_operands == 13) { + fused_attention += "scale.bias.softmax"; + } else { + return op.emitOpError( + "unexpected number of operands for flash attention backward - " + "BMM_Bias_Softmax_BMM"); + } + break; + } if (num_operands == 10) { fused_attention += "scale.softmax"; } else if (num_operands == 11) { @@ -877,7 +905,8 @@ class FusedAttentionBackwardLowering set_attr("fmha_scale", op.getFmhaScaleAttr()); set_attr("dropout_rate", op.getDropoutRateAttr()); set_attr("seed", op.getSeedAttr()); - + set_attr("is_flash_attention", op.getIsFlashAttentionAttr()); + set_attr("is_causal_mask", op.getIsCausalMaskAttr()); set_attr("fused_mha_dag", op.getFusedMhaDagAttr()); set_attr("algorithm_config", op.getAlgorithmConfigAttr()); set_attr("bmm1_grad_gemm1_dot_dimension_numbers", @@ -889,6 +918,20 @@ class FusedAttentionBackwardLowering set_attr("bmm2_grad_gemm2_dot_dimension_numbers", op.getBmm2GradGemm2DotDimensionNumbers()); + auto set_xi64 = [&](StringRef name, mlir::ArrayAttr array) { + int rank = array.size(); + SmallVector values; + for (int i = 0; i < rank; i++) { + mlir::IntegerAttr attr = array[i].dyn_cast(); + values.push_back(attr.getInt()); + } + set_attr(name, b.getI64TensorAttr(values)); + }; + + set_xi64("intermediate_tensor_dimensions", + op.getIntermediateTensorDimensions()); + set_xi64("intermediate_tensor_layout", op.getIntermediateTensorLayout()); + // Erase the original fused dot attention operation. 
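The two set_xi64 attributes added just above carry the intermediate (softmax) tensor's dimensions together with its minor-to-major layout, which is enough to reconstruct a dense shape on the runtime side. As a small self-contained illustration of what a minor-to-major order means in terms of element strides, here is a plain C++ sketch with a hypothetical helper name; the real code builds an xla::Shape via ShapeUtil::MakeShapeWithDenseLayout rather than computing strides directly.

#include <cstdint>
#include <iostream>
#include <vector>

// Given dimensions and a minor_to_major order (entry 0 = most-minor
// dimension index), compute the dense element stride of each dimension.
// Hypothetical helper, only to illustrate the attribute contents.
std::vector<int64_t> DenseStrides(const std::vector<int64_t>& dims,
                                  const std::vector<int64_t>& minor_to_major) {
  std::vector<int64_t> strides(dims.size(), 0);
  int64_t stride = 1;
  for (int64_t dim : minor_to_major) {
    strides[dim] = stride;
    stride *= dims[dim];
  }
  return strides;
}

int main() {
  // e.g. a hypothetical [batch, heads, q_len, kv_len] intermediate tensor,
  // stored row-major (minor_to_major = {3, 2, 1, 0}).
  std::vector<int64_t> dims = {2, 8, 128, 128};
  std::vector<int64_t> minor_to_major = {3, 2, 1, 0};
  for (int64_t s : DenseStrides(dims, minor_to_major)) std::cout << s << " ";
  std::cout << "\n";  // prints: 131072 16384 128 1
}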
rewriter.eraseOp(op); diff --git a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td index 3091ed06f4fec8..e56d2964d767ac 100644 --- a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td +++ b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td @@ -362,7 +362,9 @@ def LHLOGPU_fusedMHAOp : LHLOGPU_Op<"fMHA", [AttrSizedOperandSegments]> { FusedMhaDagSignatureAttr:$fused_mha_dag, FusedMHAAlgorithmConfigAttr:$algorithm_config, OptionalAttr:$dropout_rate, - OptionalAttr:$seed + OptionalAttr:$seed, + BoolAttr:$is_flash_attention, + BoolAttr:$is_causal_mask ); } @@ -374,21 +376,30 @@ def LHLOGPU_fusedMHABackwardOp : LHLOGPU_Op<"fMHABackward", [AttrSizedOperandSeg Arg:$bmm2_grad_gemm1_lhs, Arg:$d_output, Arg, "", [MemRead]>:$mask, + Arg, "", [MemRead]>:$bias, + Arg, "", [MemRead]>:$fwd_output, Arg:$d_bmm1_lhs, Arg:$d_bmm1_rhs, Arg:$d_bmm2_rhs, - Arg:$d_S, + Arg, "", [MemWrite]>:$d_S, + Arg, "", [MemWrite]>:$softmax_sum, + Arg, "", [MemWrite]>:$d_Q_accum, Arg:$scratch, Arg, "", [MemWrite]>:$d_bias, MHLO_DotDimensionNumbers:$bmm1_grad_gemm1_dot_dimension_numbers, MHLO_DotDimensionNumbers:$bmm1_grad_gemm2_dot_dimension_numbers, MHLO_DotDimensionNumbers:$bmm2_grad_gemm1_dot_dimension_numbers, MHLO_DotDimensionNumbers:$bmm2_grad_gemm2_dot_dimension_numbers, + I64ArrayAttr:$intermediate_tensor_dimensions, + I64ArrayAttr:$intermediate_tensor_layout, F64Attr:$fmha_scale, FusedMhaBackwardDagSignatureAttr:$fused_mha_dag, FusedMHAAlgorithmConfigAttr:$algorithm_config, OptionalAttr:$dropout_rate, - OptionalAttr:$seed); + OptionalAttr:$seed, + BoolAttr:$is_flash_attention, + BoolAttr:$is_causal_mask + ); } def LHLOGPU_RadixSortOp: LHLOGPU_Op<"radix_sort", [SameVariadicOperandSize]> { diff --git a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td index 8ab0646a44a6d3..7ce614e43b8597 100644 --- a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td +++ b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td @@ -153,6 +153,8 @@ def FusedMhaBackwardDagScaleBiasSoftmaxDropout : I32EnumAttrCase<"BackwardScaleB def FusedMhaBackwardDagScaleBiasSoftmax : I32EnumAttrCase<"BackwardScaleBiasSoftmax", 1>; def FusedMhaBackwardDagScaleBiasMaskSoftmax : I32EnumAttrCase<"BackwardScaleBiasMaskSoftmax", 2>; def FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout : I32EnumAttrCase<"BackwardScaleBiasMaskSoftmaxDropout", 3>; +def FusedMhaBackwardDagSoftmax : I32EnumAttrCase<"BackwardSoftmax", 4>; +def FusedMhaBackwardDagSoftmaxDropout : I32EnumAttrCase<"BackwardSoftmaxDropout", 5>; def FusedMhaDagSignature: I32EnumAttr<"FusedMhaDagSignature", "DAG configuration for Fused Multi-Headed Attention", @@ -175,11 +177,13 @@ def FusedMhaBackwardDagSignature: I32EnumAttr<"FusedMhaBackwardDagSignature", FusedMhaBackwardDagScaleBiasSoftmaxDropout, FusedMhaBackwardDagScaleBiasSoftmax, FusedMhaBackwardDagScaleBiasMaskSoftmax, - FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout]> { + FusedMhaBackwardDagScaleBiasMaskSoftmaxDropout, + FusedMhaBackwardDagSoftmax, + FusedMhaBackwardDagSoftmaxDropout]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::lmhlo_gpu"; } def FusedMhaDagSignatureAttr : EnumAttr; def FusedMhaBackwardDagSignatureAttr : EnumAttr; -#endif // LHLO_GPU_OPS_ENUMS \ No newline at end of file +#endif // LHLO_GPU_OPS_ENUMS diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index 
00504c2764b0f0..867f7f8fe2af77 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -189,4 +189,10 @@ message CudnnfMHABackendConfig { // Random seed used by dropout int64 seed = 15; + + // Is flash attention + bool is_flash_attention = 20; + + // Is causal mask + bool is_causal_mask = 21; } diff --git a/third_party/xla/xla/service/gpu/fused_mha_thunk.cc b/third_party/xla/xla/service/gpu/fused_mha_thunk.cc index 96562a8e2f9403..f0ba6f3fbd1774 100644 --- a/third_party/xla/xla/service/gpu/fused_mha_thunk.cc +++ b/third_party/xla/xla/service/gpu/fused_mha_thunk.cc @@ -61,6 +61,15 @@ FusedMultiHeadedAttentionRunner& FusedMHAThunk::GetOrCreateRunner( return *it->second; } +std::optional AssignBufferIfNotNull( + const BufferAllocations& buffer_allocations, + BufferAllocation::Slice& slice) { + return slice.allocation() != nullptr + ? std::optional{buffer_allocations + .GetDeviceAddress(slice)} + : std::nullopt; +} + Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) { const auto& buffer_allocations = *params.buffer_allocations; se::DeviceMemoryBase lhs_bmm1_buffer = @@ -74,19 +83,12 @@ Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) { se::DeviceMemoryBase scratch_buffer = buffer_allocations.GetDeviceAddress(scratch_buffer_); - std::optional mask_buffer; - if (mask_buffer_.allocation() != nullptr) { - mask_buffer = buffer_allocations.GetDeviceAddress(mask_buffer_); - } - std::optional bias_buffer; - if (bias_buffer_.allocation() != nullptr) { - bias_buffer = buffer_allocations.GetDeviceAddress(bias_buffer_); - } - - std::optional activation_buffer; - if (activation_buffer_.allocation() != nullptr) { - activation_buffer = buffer_allocations.GetDeviceAddress(activation_buffer_); - } + std::optional mask_buffer = + AssignBufferIfNotNull(buffer_allocations, mask_buffer_); + std::optional bias_buffer = + AssignBufferIfNotNull(buffer_allocations, bias_buffer_); + std::optional activation_buffer = + AssignBufferIfNotNull(buffer_allocations, activation_buffer_); RunFusedMHAOptions opts; opts.runner_cache = &GetOrCreateRunner(params.stream); @@ -109,7 +111,9 @@ FusedMHABackwardThunk::FusedMHABackwardThunk( BufferAllocation::Slice d_output, BufferAllocation::Slice scratch, BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs, BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s, - BufferAllocation::Slice mask, BufferAllocation::Slice d_bias) + BufferAllocation::Slice softmax_sum, BufferAllocation::Slice d_Q_accum, + BufferAllocation::Slice mask, BufferAllocation::Slice d_bias, + BufferAllocation::Slice fwd_output, BufferAllocation::Slice bias) : Thunk(Kind::kFusedMHA, thunk_info), bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs), bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs), @@ -121,8 +125,12 @@ FusedMHABackwardThunk::FusedMHABackwardThunk( d_bmm1_rhs_buffer_(d_bmm1_rhs), d_bmm2_rhs_buffer_(d_bmm2_rhs), d_s_buffer_(d_s), + softmax_sum_buffer_(softmax_sum), + d_Q_accum_buffer_(d_Q_accum), mask_buffer_(mask), d_bias_buffer_(d_bias), + fwd_output_buffer_(fwd_output), + bias_buffer_(bias), config_(std::move(config)) {} FusedMultiHeadedAttentionBackwardRunner& @@ -169,18 +177,21 @@ Status FusedMHABackwardThunk::ExecuteOnStream(const ExecuteParams& params) { se::DeviceMemoryBase d_bmm2_rhs_buffer = buffer_allocations.GetDeviceAddress(d_bmm2_rhs_buffer_); - se::DeviceMemoryBase d_S_buffer = - buffer_allocations.GetDeviceAddress(d_s_buffer_); + std::optional d_s_buffer = 
+ AssignBufferIfNotNull(buffer_allocations, d_s_buffer_); + std::optional softmax_sum_buffer = + AssignBufferIfNotNull(buffer_allocations, softmax_sum_buffer_); + std::optional d_Q_accum_buffer = + AssignBufferIfNotNull(buffer_allocations, d_Q_accum_buffer_); + std::optional mask_buffer = + AssignBufferIfNotNull(buffer_allocations, mask_buffer_); + std::optional d_bias_buffer = + AssignBufferIfNotNull(buffer_allocations, d_bias_buffer_); + std::optional fwd_output_buffer = + AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_); + std::optional bias_buffer = + AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional mask_buffer; - if (mask_buffer_.allocation() != nullptr) { - mask_buffer = buffer_allocations.GetDeviceAddress(mask_buffer_); - } - - std::optional d_bias_buffer; - if (d_bias_buffer_.allocation() != nullptr) { - d_bias_buffer = buffer_allocations.GetDeviceAddress(d_bias_buffer_); - } RunFusedMHABackwardOptions opts; opts.runner_cache = &GetOrCreateRunner(params.stream); @@ -189,7 +200,8 @@ Status FusedMHABackwardThunk::ExecuteOnStream(const ExecuteParams& params) { config_, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, scratch_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, - d_S_buffer, mask_buffer, d_bias_buffer, params.stream, opts)); + d_s_buffer, softmax_sum_buffer, d_Q_accum_buffer, mask_buffer, + d_bias_buffer, fwd_output_buffer, bias_buffer, params.stream, opts)); if (!params.stream->ok()) { return InternalError("FusedMHABackwardThunk::ExecuteOnStream failed."); } diff --git a/third_party/xla/xla/service/gpu/fused_mha_thunk.h b/third_party/xla/xla/service/gpu/fused_mha_thunk.h index a1db1d23e16c9e..a0d9e58aa0e648 100644 --- a/third_party/xla/xla/service/gpu/fused_mha_thunk.h +++ b/third_party/xla/xla/service/gpu/fused_mha_thunk.h @@ -91,9 +91,13 @@ class FusedMHABackwardThunk : public Thunk { BufferAllocation::Slice d_bmm1_lhs_slice, BufferAllocation::Slice d_bmm1_rhs_slice, BufferAllocation::Slice d_bmm2_rhs_slice, - BufferAllocation::Slice d_S_slice, + BufferAllocation::Slice d_s_slice, + BufferAllocation::Slice softmax_sum_slice, + BufferAllocation::Slice d_Q_accum_slice, BufferAllocation::Slice mask_slice, - BufferAllocation::Slice d_bias_slice); + BufferAllocation::Slice d_bias_slice, + BufferAllocation::Slice fwd_output_slice, + BufferAllocation::Slice bias_slice); FusedMHABackwardThunk(const FusedMHABackwardThunk&) = delete; FusedMHABackwardThunk& operator=(const FusedMHABackwardThunk&) = delete; @@ -111,8 +115,12 @@ class FusedMHABackwardThunk : public Thunk { BufferAllocation::Slice d_bmm1_rhs_buffer_; BufferAllocation::Slice d_bmm2_rhs_buffer_; BufferAllocation::Slice d_s_buffer_; + BufferAllocation::Slice softmax_sum_buffer_; + BufferAllocation::Slice d_Q_accum_buffer_; BufferAllocation::Slice mask_buffer_; BufferAllocation::Slice d_bias_buffer_; + BufferAllocation::Slice fwd_output_buffer_; + BufferAllocation::Slice bias_buffer_; FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner( const stream_executor::Stream* stream); diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc index 74c3d839750a5a..a796f0b5652ca0 100644 --- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc +++ b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc @@ -83,8 +83,8 @@ Status RunFusedMHA(GpufMHAParams params, se::Stream *stream, params.config->activation, dropout_rate, 
seed, - false, - false}; + params.config->is_flash_attention, + params.config->is_causal_mask}; TF_ASSIGN_OR_RETURN(auto *runner, lazy_runner->GetOrCreateRunner(config, stream)); return (*runner)(stream, options.profile_result, scratch_memory, @@ -201,20 +201,21 @@ void AssignSeed(GpufMHAConfig &config, } template -Status RunFusedMHABackward(GpufMHABackwardParams params, se::Stream *stream, - RunFusedMHABackwardOptions options, - DeviceMemory bmm1_grad_gemm1_rhs_buffer, - DeviceMemory bmm1_grad_gemm2_rhs_buffer, - DeviceMemory bmm2_grad_gemm1_lhs_buffer, - DeviceMemory bmm2_grad_gemm2_rhs_buffer, - DeviceMemory d_output_buffer, - DeviceMemory d_bmm1_lhs_buffer, - DeviceMemory d_bmm1_rhs_buffer, - DeviceMemory d_bmm2_rhs_buffer, - DeviceMemory d_s_buffer, - DeviceMemoryBase mask_buffer, - DeviceMemoryBase d_bias_buffer, - DeviceMemoryBase scratch_memory) { +Status RunFusedMHABackward( + GpufMHABackwardParams params, se::Stream *stream, + RunFusedMHABackwardOptions options, + DeviceMemory bmm1_grad_gemm1_rhs_buffer, + DeviceMemory bmm1_grad_gemm2_rhs_buffer, + DeviceMemory bmm2_grad_gemm1_lhs_buffer, + DeviceMemory bmm2_grad_gemm2_rhs_buffer, + DeviceMemory d_output_buffer, + DeviceMemory d_bmm1_lhs_buffer, + DeviceMemory d_bmm1_rhs_buffer, + DeviceMemory d_bmm2_rhs_buffer, DeviceMemoryBase d_s_buffer, + DeviceMemoryBase softmax_buffer, DeviceMemoryBase d_Q_accum_buffer, + DeviceMemoryBase mask_buffer, DeviceMemoryBase d_bias_buffer, + DeviceMemoryBase fwd_output_buffer, DeviceMemoryBase bias_buffer, + DeviceMemoryBase scratch_memory) { se::dnn::LazyOpRunner *lazy_runner = options.runner_cache->AsFusedMHABackwardRunner(); std::optional> @@ -223,6 +224,7 @@ Status RunFusedMHABackward(GpufMHABackwardParams params, se::Stream *stream, local_runner.emplace(params.config->algorithm); lazy_runner = &*local_runner; } + // FMHA TODO: add GetDNNFusedMHAKindFromCudnnfMHAKind here TF_ASSIGN_OR_RETURN(se::dnn::FusedMHAKind kind, GetDNNFusedMHAKindFromCudnnfMHAKind(params.config->kind)); std::optional dropout_rate; @@ -239,27 +241,25 @@ Status RunFusedMHABackward(GpufMHABackwardParams params, se::Stream *stream, if (params.config->seed) { seed = *params.config->seed; } - // TODO: set is_flash_attention to real value, set it to false for now - se::dnn::FusedMHABackwardOp::Config config{ - kind, - scale, - params.config->bmm1_grad_gemm1_rhs, - params.config->bmm1_grad_gemm2_rhs, - params.config->bmm2_grad_gemm1_lhs, - params.config->bmm2_grad_gemm2_rhs, - params.config->d_output, - params.config->d_bmm1_lhs, - params.config->d_bmm1_rhs, - params.config->d_bmm2_rhs, - std::optional(params.config->d_s), - params.config->mask, - params.config->d_bias, - std::nullopt, - std::nullopt, - dropout_rate, - seed, - false, - false}; + se::dnn::FusedMHABackwardOp::Config config{kind, + scale, + params.config->bmm1_grad_gemm1_rhs, + params.config->bmm1_grad_gemm2_rhs, + params.config->bmm2_grad_gemm1_lhs, + params.config->bmm2_grad_gemm2_rhs, + params.config->d_output, + params.config->d_bmm1_lhs, + params.config->d_bmm1_rhs, + params.config->d_bmm2_rhs, + params.config->d_s, + params.config->mask, + params.config->d_bias, + params.config->fwd_output, + params.config->bias, + dropout_rate, + seed, + params.config->is_flash_attention, + params.config->is_causal_mask}; TF_ASSIGN_OR_RETURN(auto *runner, lazy_runner->GetOrCreateRunner(config, stream)); // TODO: pass in real softmax_sum, dQ_accum, fwd_output @@ -267,9 +267,10 @@ Status RunFusedMHABackward(GpufMHABackwardParams params, se::Stream *stream, 
bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, - d_bmm2_rhs_buffer, d_s_buffer, se::DeviceMemoryBase(), - se::DeviceMemoryBase(), mask_buffer, d_bias_buffer, - se::DeviceMemoryBase(), se::DeviceMemoryBase()); + d_bmm2_rhs_buffer, d_s_buffer, softmax_buffer, + d_Q_accum_buffer, mask_buffer, d_bias_buffer, + fwd_output_buffer, bias_buffer); + return OkStatus(); } template @@ -292,7 +293,20 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, se::DeviceMemory(params.d_bmm1_rhs_buffer); auto d_bmm2_rhs_buffer = se::DeviceMemory(params.d_bmm2_rhs_buffer); - auto d_s_buffer = se::DeviceMemory(params.d_s_buffer); + + // optional buffers + auto d_s_buffer = params.d_s_buffer.has_value() + ? se::DeviceMemory(*params.d_s_buffer) + : se::DeviceMemoryBase(); + auto softmax_sum_buffer = + params.softmax_sum_buffer.has_value() + ? se::DeviceMemory(*params.softmax_sum_buffer) + : se::DeviceMemoryBase(); + + auto d_Q_accum_buffer = + params.d_Q_accum_buffer.has_value() + ? se::DeviceMemory(*params.d_Q_accum_buffer) + : se::DeviceMemoryBase(); auto mask_buffer = params.mask_buffer.has_value() ? se::DeviceMemory(*params.mask_buffer) @@ -302,6 +316,15 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, ? se::DeviceMemory(*params.d_bias_buffer) : se::DeviceMemoryBase(); + auto fwd_output_buffer = + params.fwd_output_buffer.has_value() + ? se::DeviceMemory(*params.fwd_output_buffer) + : se::DeviceMemoryBase(); + + auto bias_buffer = params.bias_buffer.has_value() + ? se::DeviceMemory(*params.bias_buffer) + : se::DeviceMemoryBase(); + se::dnn::AlgorithmDesc algorithm = params.config->algorithm; if (options.runner_cache) { algorithm = options.runner_cache->ToAlgorithmDesc(); @@ -322,8 +345,9 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, params, stream, options, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer, - d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, mask_buffer, - d_bias_buffer, scratch_memory); + d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, softmax_sum_buffer, + d_Q_accum_buffer, mask_buffer, d_bias_buffer, fwd_output_buffer, + bias_buffer, scratch_memory); break; default: return InternalError("Invalid cuDNN fMHA kind"); @@ -428,6 +452,8 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, bias_shape.layout().minor_to_major()); } config.kind = desc.kind; + config.is_flash_attention = desc.is_flash_attention; + config.is_causal_mask = desc.is_causal_mask; const CudnnfMHABackendConfig &backend_config = desc.backend_config; config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); @@ -449,7 +475,6 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, const Shape &d_bmm1_lhs_shape = desc.d_bmm1_lhs_shape; const Shape &d_bmm1_rhs_shape = desc.d_bmm1_rhs_shape; const Shape &d_bmm2_rhs_shape = desc.d_bmm2_rhs_shape; - // Get DNN dtype from primtive types TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm1_rhs_type, GetDNNDataTypeFromPrimitiveType( @@ -537,7 +562,6 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, if (desc.d_bias_shape) { const Shape &d_bias_shape = *desc.d_bias_shape; - // Get DNN dtype from primtive types TF_ASSIGN_OR_RETURN(DataType d_bias_type, GetDNNDataTypeFromPrimitiveType( d_bias_shape.element_type())); @@ -553,7 +577,27 @@ Status RunGpuFMHABackwardImpl(const 
GpufMHABackwardParams ¶ms, config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), mask_shape.layout().minor_to_major()); } + if (desc.fwd_output_shape) { + const Shape &fwd_output_shape = *desc.fwd_output_shape; + TF_ASSIGN_OR_RETURN( + DataType fwd_output_type, + GetDNNDataTypeFromPrimitiveType(fwd_output_shape.element_type())); + config.fwd_output = + TensorDescriptor::For(fwd_output_type, fwd_output_shape.dimensions(), + fwd_output_shape.layout().minor_to_major()); + } + + if (desc.bias_shape) { + const Shape &bias_shape = *desc.bias_shape; + TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( + bias_shape.element_type())); + config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), + bias_shape.layout().minor_to_major()); + } + config.kind = desc.kind; + config.is_flash_attention = desc.is_flash_attention; + config.is_causal_mask = desc.is_causal_mask; const CudnnfMHABackendConfig &backend_config = desc.backend_config; config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); @@ -601,9 +645,14 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase d_bmm1_lhs_buffer, se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, se::DeviceMemoryBase d_s_buffer, + se::DeviceMemoryBase d_bmm2_rhs_buffer, + std::optional d_s_buffer, + std::optional softmax_sum_buffer, + std::optional d_Q_accum_buffer, std::optional mask_buffer, - std::optional d_bias_buffer) { + std::optional d_bias_buffer, + std::optional fwd_output_buffer, + std::optional bias_buffer) { GpufMHABackwardParams params; params.config = &config; params.bmm1_grad_gemm1_rhs_buffer = bmm1_grad_gemm1_rhs_buffer; @@ -615,9 +664,12 @@ Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, params.d_bmm1_rhs_buffer = d_bmm1_rhs_buffer; params.d_bmm2_rhs_buffer = d_bmm2_rhs_buffer; params.d_s_buffer = d_s_buffer; + params.softmax_sum_buffer = softmax_sum_buffer; + params.d_Q_accum_buffer = d_Q_accum_buffer; params.mask_buffer = mask_buffer; params.d_bias_buffer = d_bias_buffer; - + params.fwd_output_buffer = fwd_output_buffer; + params.bias_buffer = bias_buffer; return params; } @@ -651,28 +703,32 @@ Status RunGpuFMHA(const GpufMHAConfig &fmha_config, return OkStatus(); } -Status RunGpuFMHABackward(const GpufMHABackwardConfig &fmha_config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, - se::DeviceMemoryBase scratch_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - se::DeviceMemoryBase d_s_buffer, - std::optional mask_buffer, - std::optional d_bias_buffer, - se::Stream *stream, - RunFusedMHABackwardOptions options) { +Status RunGpuFMHABackward( + const GpufMHABackwardConfig &fmha_config, + se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, + se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, + se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, + se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, + se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer, + se::DeviceMemoryBase d_bmm1_lhs_buffer, + se::DeviceMemoryBase d_bmm1_rhs_buffer, + se::DeviceMemoryBase d_bmm2_rhs_buffer, + std::optional d_s_buffer, + std::optional softmax_sum_buffer, + std::optional d_Q_accum_buffer, + std::optional 
mask_buffer, + std::optional d_bias_buffer, + std::optional fwd_output_buffer, + std::optional bias_buffer, se::Stream *stream, + RunFusedMHABackwardOptions options) { TF_ASSIGN_OR_RETURN( GpufMHABackwardParams params, GpufMHABackwardParams::For( fmha_config, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, - d_bmm2_rhs_buffer, d_s_buffer, mask_buffer, d_bias_buffer)); + d_bmm2_rhs_buffer, d_s_buffer, softmax_sum_buffer, d_Q_accum_buffer, + mask_buffer, d_bias_buffer, fwd_output_buffer, bias_buffer)); PrimitiveType input_primitive_type = fmha_config.input_type; switch (input_primitive_type) { case F16: diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h index 637a3c474c4f7e..041993431030c7 100644 --- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h +++ b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h @@ -46,6 +46,8 @@ namespace gpu { struct GpufMHADescriptor { CudnnfMHAKind kind; CudnnfMHABackendConfig backend_config; + bool is_flash_attention; + bool is_causal_mask; Shape lhs_bmm1_shape; Shape rhs_bmm1_shape; Shape rhs_bmm2_shape; @@ -62,6 +64,8 @@ struct GpufMHADescriptor { struct GpufMHABackwardDescriptor { CudnnfMHAKind kind; CudnnfMHABackendConfig backend_config; + bool is_flash_attention; + bool is_causal_mask; Shape bmm1_grad_gemm1_rhs_shape; Shape bmm1_grad_gemm2_rhs_shape; Shape bmm2_grad_gemm1_lhs_shape; @@ -75,8 +79,11 @@ struct GpufMHABackwardDescriptor { DotDimensionNumbers bmm2_grad_gemm1_dnums; DotDimensionNumbers bmm2_grad_gemm2_dnums; + std::optional d_s_shape; + std::optional fwd_output_shape; std::optional mask_shape; std::optional d_bias_shape; + std::optional bias_shape; }; // Structure to describe static properties of a GPU fused Multi-Headed // Attention. @@ -91,7 +98,8 @@ struct GpufMHAConfig { std::optional seed; se::dnn::AlgorithmDesc algorithm; - + bool is_flash_attention; + bool is_causal_mask; // bias -> [1, num_attn_heads, q_seq_len, kv_seq_len] // mask -> [batch_size, 1, q_seq_len, kv_seq_len] se::dnn::MatmulTensorDescriptor lhs_bmm1; @@ -119,7 +127,8 @@ struct GpufMHABackwardConfig { std::optional seed; se::dnn::AlgorithmDesc algorithm; - + bool is_flash_attention; + bool is_causal_mask; // mask -> [batch_size, 1, q_seq_len, kv_seq_len] // d_bias -> [1, num_heads, q_seq_len, kv_seq_len] se::dnn::MatmulTensorDescriptor bmm1_grad_gemm1_rhs; @@ -130,9 +139,11 @@ struct GpufMHABackwardConfig { se::dnn::TensorDescriptor d_bmm1_lhs; se::dnn::TensorDescriptor d_bmm1_rhs; se::dnn::TensorDescriptor d_bmm2_rhs; - se::dnn::TensorDescriptor d_s; - std::optional d_bias; + std::optional d_s; std::optional mask; + std::optional d_bias; + std::optional fwd_output; + std::optional bias; }; // Implementation struct exposed for debugging and log analysis. 
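Several backward operands are now optional (d_S and mask on the existing path; softmax_sum, d_Q_accum, fwd_output and bias on the flash path), so the thunk resolves each slice to a device address only when its allocation is actually set, via the AssignBufferIfNotNull helper introduced earlier in this patch. The following is a minimal standalone sketch of that pattern; Slice and DeviceMemoryBase here are hypothetical stand-ins for BufferAllocation::Slice and se::DeviceMemoryBase, not the real XLA types.

#include <cstdint>
#include <iostream>
#include <optional>

// Hypothetical stand-in for se::DeviceMemoryBase.
struct DeviceMemoryBase {
  void* ptr = nullptr;
  uint64_t size = 0;
};

// Hypothetical stand-in for BufferAllocation::Slice: a null base plays the
// role of slice.allocation() == nullptr, i.e. the operand was not provided.
struct Slice {
  void* base = nullptr;
  uint64_t offset = 0;
  uint64_t size = 0;
};

std::optional<DeviceMemoryBase> AssignBufferIfNotNull(const Slice& slice) {
  if (slice.base == nullptr) return std::nullopt;
  return DeviceMemoryBase{static_cast<char*>(slice.base) + slice.offset,
                          slice.size};
}

int main() {
  char backing[256];
  Slice fwd_output{backing, 64, 128};  // present on the flash-attention path
  Slice d_s{};                         // absent on the flash-attention path
  std::cout << "fwd_output present: "
            << AssignBufferIfNotNull(fwd_output).has_value()
            << ", d_S present: " << AssignBufferIfNotNull(d_s).has_value()
            << "\n";
}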
@@ -165,9 +176,14 @@ struct GpufMHABackwardParams { se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase d_bmm1_lhs_buffer, se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, se::DeviceMemoryBase d_s_buffer, + se::DeviceMemoryBase d_bmm2_rhs_buffer, + std::optional d_s_buffer, + std::optional softmax_sum_buffer, + std::optional d_Q_accum_buffer, std::optional mask_buffer, - std::optional d_bias_buffer); + std::optional d_bias_buffer, + std::optional fwd_output_buffer, + std::optional bias_buffer); const GpufMHABackwardConfig* config; // Not owned se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer; @@ -178,9 +194,13 @@ struct GpufMHABackwardParams { se::DeviceMemoryBase d_bmm1_lhs_buffer; se::DeviceMemoryBase d_bmm1_rhs_buffer; se::DeviceMemoryBase d_bmm2_rhs_buffer; - se::DeviceMemoryBase d_s_buffer; - std::optional d_bias_buffer; + std::optional d_s_buffer; + std::optional softmax_sum_buffer; + std::optional d_Q_accum_buffer; std::optional mask_buffer; + std::optional d_bias_buffer; + std::optional fwd_output_buffer; + std::optional bias_buffer; }; class FusedMultiHeadedAttentionRunner { @@ -371,20 +391,24 @@ Status RunGpuFMHA(const GpufMHAConfig& fmha_config, std::optional activation_buffer, se::Stream* stream, RunFusedMHAOptions = {}); -Status RunGpuFMHABackward(const GpufMHABackwardConfig& fmha_config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, - se::DeviceMemoryBase scratch_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - se::DeviceMemoryBase d_s_buffer, - std::optional mask_buffer, - std::optional d_bias_buffer, - se::Stream* stream, RunFusedMHABackwardOptions = {}); +Status RunGpuFMHABackward( + const GpufMHABackwardConfig& fmha_config, + se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, + se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, + se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, + se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, + se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer, + se::DeviceMemoryBase d_bmm1_lhs_buffer, + se::DeviceMemoryBase d_bmm1_rhs_buffer, + se::DeviceMemoryBase d_bmm2_rhs_buffer, + std::optional d_s_buffer, + std::optional softmax_sum_buffer, + std::optional d_Q_accum_buffer, + std::optional mask_buffer, + std::optional d_bias_buffer, + std::optional fwd_output_buffer, + std::optional bias_buffer, se::Stream* stream, + RunFusedMHABackwardOptions = {}); std::string ToString(const GpufMHAConfig& config); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 4c2ca90915f744..9921870f1b6853 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -248,6 +248,13 @@ StatusOr AsCudnnBackwardfMHAKind( case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: BackwardScaleBiasMaskSoftmaxDropout: return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout; + break; + case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax: + return xla::gpu::CudnnfMHAKind::kBackwardSoftmax; + break; + case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmaxDropout: + return xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout; + break; default: return 
xla::InternalError("Unsupported fused_mha_backward_dag_signature"); } @@ -1226,6 +1233,11 @@ Status IrEmitterUnnested::EmitFusedMHAThunk(mlir::Operation* op) { ShapeUtil::MakeShapeWithDenseLayout( GetShape(fmha.getOutput()).element_type(), intermediate_tensor_dims_array, intermediate_tensor_layout_array); + + // set if flash attention here + descriptor.is_flash_attention = fmha.getIsFlashAttention(); + // set if causal mask here + descriptor.is_causal_mask = fmha.getIsCausalMask(); return OkStatus(); }; @@ -1253,9 +1265,9 @@ Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { GpufMHABackwardDescriptor descriptor; BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice, bmm1_grad_gemm2_rhs_slice, bmm2_grad_gemm1_lhs_slice, bmm2_grad_gemm2_rhs_slice, d_output_slice, - scratch_slice, mask_slice; + scratch_slice, mask_slice, fwd_output_slice, bias_slice; BufferAllocation::Slice d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, - d_S_slice, d_bias_slice; + d_s_slice, softmax_sum_slice, d_Q_accum_slice, d_bias_slice; auto populate_common = [&](auto fmha) -> Status { descriptor.backend_config.set_fmha_scale( @@ -1285,6 +1297,10 @@ Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { algorithm->mutable_workspace_size()->set_value(workspace_size); } + // set if flash attention here + descriptor.is_flash_attention = fmha.getIsFlashAttention(); + // set if causal mask here + descriptor.is_causal_mask = fmha.getIsCausalMask(); descriptor.bmm1_grad_gemm1_dnums = ConvertDotDimensionNumbers(fmha.getBmm1GradGemm1DotDimensionNumbers()); descriptor.bmm1_grad_gemm2_dnums = @@ -1308,10 +1324,31 @@ Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { TF_ASSIGN_OR_RETURN(bmm1_grad_gemm2_rhs_slice, GetAllocationSlice(fmha.getBmm1GradGemm2Rhs())); - descriptor.bmm2_grad_gemm1_lhs_shape = ShapeUtil::MakeShapeWithDenseLayout( - GetShape(fmha.getBmm2GradGemm1Lhs()).element_type(), - GetShape(fmha.getBmm2GradGemm1Lhs()).dimensions(), - GetShape(fmha.getBmm2GradGemm1Lhs()).layout().minor_to_major()); + // fwd activation + // fmha.getBmm2GradGemm1Lhs() could be bmm2_grad_gemm1_lhs for regular + // attention or softmax stats for flash attention here we set the shape to + // be bmm2_grad_gemm1_lhs even it is flash attention + if (descriptor.is_flash_attention) { + // flash attention TODO: make sure the layout is correct for + // bmm2_grad_gemm1_lhs + TF_ASSIGN_OR_RETURN(auto intermediate_tensor_dims_array, + ConvertMlirArrayAttrToInt64Array( + fmha.getIntermediateTensorDimensions())); + TF_ASSIGN_OR_RETURN( + auto intermediate_tensor_layout_array, + ConvertMlirArrayAttrToInt64Array(fmha.getIntermediateTensorLayout())); + + descriptor.bmm2_grad_gemm1_lhs_shape = + ShapeUtil::MakeShapeWithDenseLayout( + GetShape(fmha.getDOutput()).element_type(), + intermediate_tensor_dims_array, intermediate_tensor_layout_array); + } else { + descriptor.bmm2_grad_gemm1_lhs_shape = + ShapeUtil::MakeShapeWithDenseLayout( + GetShape(fmha.getBmm2GradGemm1Lhs()).element_type(), + GetShape(fmha.getBmm2GradGemm1Lhs()).dimensions(), + GetShape(fmha.getBmm2GradGemm1Lhs()).layout().minor_to_major()); + } TF_ASSIGN_OR_RETURN(bmm2_grad_gemm1_lhs_slice, GetAllocationSlice(fmha.getBmm2GradGemm1Lhs())); @@ -1350,7 +1387,13 @@ Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { TF_ASSIGN_OR_RETURN(scratch_slice, GetAllocationSlice(fmha.getScratch())); - TF_ASSIGN_OR_RETURN(d_S_slice, GetAllocationSlice(fmha.getD_S())); + if (fmha.getD_S() != nullptr) { + 
descriptor.d_s_shape = ShapeUtil::MakeShapeWithDenseLayout( + GetShape(fmha.getD_S()).element_type(), + GetShape(fmha.getD_S()).dimensions(), + GetShape(fmha.getD_S()).layout().minor_to_major()); + TF_ASSIGN_OR_RETURN(d_s_slice, GetAllocationSlice(fmha.getD_S())); + } if (fmha.getDBias() != nullptr) { descriptor.d_bias_shape = ShapeUtil::MakeShapeWithDenseLayout( @@ -1374,6 +1417,33 @@ Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { TF_ASSIGN_OR_RETURN(mask_slice, GetAllocationSlice(fmha.getMask())); } + // add flash attention backward related slice here + if (fmha.getBias() != nullptr) { + descriptor.bias_shape = ShapeUtil::MakeShapeWithDenseLayout( + GetShape(fmha.getBias()).element_type(), + GetShape(fmha.getBias()).dimensions(), + GetShape(fmha.getBias()).layout().minor_to_major()); + TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSlice(fmha.getBias())); + } + + if (fmha.getSoftmaxSum() != nullptr) { + TF_ASSIGN_OR_RETURN(softmax_sum_slice, + GetAllocationSlice(fmha.getSoftmaxSum())); + } + + if (fmha.getD_QAccum() != nullptr) { + TF_ASSIGN_OR_RETURN(d_Q_accum_slice, + GetAllocationSlice(fmha.getD_QAccum())); + } + + if (fmha.getFwdOutput() != nullptr) { + descriptor.fwd_output_shape = ShapeUtil::MakeShapeWithDenseLayout( + GetShape(fmha.getFwdOutput()).element_type(), + GetShape(fmha.getFwdOutput()).dimensions(), + GetShape(fmha.getFwdOutput()).layout().minor_to_major()); + TF_ASSIGN_OR_RETURN(fwd_output_slice, + GetAllocationSlice(fmha.getFwdOutput())); + } return OkStatus(); }; @@ -1395,7 +1465,8 @@ Status IrEmitterUnnested::EmitFusedMHABackwardThunk(mlir::Operation* op) { bmm1_grad_gemm1_rhs_slice, bmm1_grad_gemm2_rhs_slice, bmm2_grad_gemm1_lhs_slice, bmm2_grad_gemm2_rhs_slice, d_output_slice, scratch_slice, d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, - d_S_slice, mask_slice, d_bias_slice)); + d_s_slice, softmax_sum_slice, d_Q_accum_slice, mask_slice, d_bias_slice, + fwd_output_slice, bias_slice)); return OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index 6354250dddf602..fa6dca5b2eafb6 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -227,11 +227,11 @@ Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( mha_fusion_pipeline.AddPass(); mha_fusion_pipeline.AddPass(); } + mha_fusion_pipeline.AddPass(/*is_layout_sensitive=*/true); mha_fusion_pipeline.AddPass>( alg_sim_options); mha_fusion_pipeline.AddPass(/*is_layout_sensitive=*/true); - // Rewrite Multi-Headed Attention modules to Fused MHA custom-calls. 
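On the flash-attention path the bmm2_grad_gemm1_lhs operand holds the forward softmax statistics rather than the P matrix, so the emitter rebuilds that shape from the intermediate-tensor dimension and layout attributes instead of taking the operand's own shape. The following sketch mirrors only that branch, with a hypothetical simplified Shape record standing in for xla::Shape.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical, simplified shape: element type tag, dims, minor_to_major.
struct Shape {
  int element_type = 0;
  std::vector<int64_t> dims;
  std::vector<int64_t> minor_to_major;
};

// Mirrors the choice made in EmitFusedMHABackwardThunk: for flash attention
// the shape comes from the intermediate-tensor attributes (P's logical
// shape, element type taken from d_output); otherwise from the operand.
Shape Bmm2GradGemm1LhsShape(bool is_flash_attention, const Shape& operand_shape,
                            int d_output_element_type,
                            const std::vector<int64_t>& intermediate_dims,
                            const std::vector<int64_t>& intermediate_layout) {
  if (is_flash_attention) {
    return Shape{d_output_element_type, intermediate_dims, intermediate_layout};
  }
  return operand_shape;
}

int main() {
  Shape softmax_stats{/*element_type=*/1, {2, 8, 128}, {2, 1, 0}};
  Shape p = Bmm2GradGemm1LhsShape(/*is_flash_attention=*/true, softmax_stats,
                                  /*d_output_element_type=*/2,
                                  {2, 8, 128, 128}, {3, 2, 1, 0});
  std::cout << "rebuilt P rank: " << p.dims.size() << "\n";  // prints 4
}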
if (stream_exec) { mha_fusion_pipeline.AddPass( diff --git a/third_party/xla/xla/service/gpu/runtime/fused_attention.cc b/third_party/xla/xla/service/gpu/runtime/fused_attention.cc index 3174445ef1ee5e..9dbd3e6aa1907f 100644 --- a/third_party/xla/xla/service/gpu/runtime/fused_attention.cc +++ b/third_party/xla/xla/service/gpu/runtime/fused_attention.cc @@ -114,6 +114,10 @@ static auto EncodeFusedAttentionBackwardDAGSignature( lmhlo_gpu::FusedMhaBackwardDagSignature signature) { switch (signature) { // backward + case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax: + return xla::gpu::CudnnfMHAKind::kBackwardSoftmax; + case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmaxDropout: + return xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout; case mlir::lmhlo_gpu::FusedMhaBackwardDagSignature:: BackwardScaleBiasSoftmax: return xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmax; @@ -193,8 +197,8 @@ static GpufMHADescriptor GetGpufMHADescriptor( absl::Span intermediate_tensor_dimensions, absl::Span intermediate_tensor_layout, AlgorithmConfig algo, DotDimensionNumbers bmm1_dot_dimension_numbers, - DotDimensionNumbers bmm2_dot_dimension_numbers, - std::optional dropout = std::nullopt) { + DotDimensionNumbers bmm2_dot_dimension_numbers, bool is_flash_attention, + bool is_causal_mask, std::optional dropout = std::nullopt) { GpufMHADescriptor descriptor; descriptor.backend_config.set_fmha_scale(fmha_scale); @@ -250,7 +254,8 @@ static GpufMHADescriptor GetGpufMHADescriptor( } descriptor.kind = kind; - + descriptor.is_flash_attention = is_flash_attention; + descriptor.is_causal_mask = is_causal_mask; return descriptor; } @@ -262,11 +267,19 @@ static GpufMHABackwardDescriptor GetGpufMHABackwardDescriptor( std::optional mask, std::optional d_bias, StridedMemrefView d_bmm1_lhs, StridedMemrefView d_bmm1_rhs, StridedMemrefView d_bmm2_rhs, - StridedMemrefView d_S, double fmha_scale, AlgorithmConfig algo, + std::optional d_S, + std::optional softmax_sum, + std::optional d_Q_accum, + std::optional fwd_output, + std::optional bias, double fmha_scale, + AlgorithmConfig algo, DotDimensionNumbers bmm1_grad_gemm1_dot_dimension_numbers, DotDimensionNumbers bmm1_grad_gemm2_dot_dimension_numbers, DotDimensionNumbers bmm2_grad_gemm1_dot_dimension_numbers, DotDimensionNumbers bmm2_grad_gemm2_dot_dimension_numbers, + absl::Span intermediate_tensor_dimensions, + absl::Span intermediate_tensor_layout, + bool is_flash_attention, bool is_causal_mask, std::optional dropout_attrs = std::nullopt) { GpufMHABackwardDescriptor descriptor; descriptor.backend_config.set_fmha_scale(fmha_scale); @@ -313,7 +326,15 @@ static GpufMHABackwardDescriptor GetGpufMHABackwardDescriptor( descriptor.bmm1_grad_gemm1_rhs_shape = apply_shape(bmm1_grad_gemm1_rhs); descriptor.bmm1_grad_gemm2_rhs_shape = apply_shape(bmm1_grad_gemm2_rhs); descriptor.bmm2_grad_gemm2_rhs_shape = apply_shape(bmm2_grad_gemm2_rhs); - descriptor.bmm2_grad_gemm1_lhs_shape = apply_shape(bmm2_grad_gemm1_lhs); + if (is_flash_attention) { + // if it is flash attention then bmm2_grad_gemm1_lhs will be softmax_stats + // instead of P we need to use real P layout + descriptor.bmm2_grad_gemm1_lhs_shape = ShapeUtil::MakeShapeWithDenseLayout( + descriptor.bmm2_grad_gemm2_rhs_shape.element_type(), + intermediate_tensor_dimensions, intermediate_tensor_layout); + } else { + descriptor.bmm2_grad_gemm1_lhs_shape = apply_shape(bmm2_grad_gemm1_lhs); + } descriptor.d_output_shape = apply_shape(d_output); descriptor.d_bmm1_lhs_shape = apply_shape(d_bmm1_lhs); @@ 
-326,14 +347,20 @@ static GpufMHABackwardDescriptor GetGpufMHABackwardDescriptor( if (d_bias.has_value()) { descriptor.d_bias_shape = apply_shape(*d_bias); } - + if (fwd_output.has_value()) { + descriptor.fwd_output_shape = apply_shape(*fwd_output); + } + if (bias.has_value()) { + descriptor.bias_shape = apply_shape(*bias); + } if (dropout_attrs.has_value()) { descriptor.backend_config.set_dropout_rate(dropout_attrs->dropout_rate); descriptor.backend_config.set_seed(dropout_attrs->seed); } descriptor.kind = kind; - + descriptor.is_flash_attention = is_flash_attention; + descriptor.is_causal_mask = is_causal_mask; return descriptor; } @@ -344,7 +371,8 @@ static absl::Status FusedAttentionForwardImpl( StridedMemrefView rhs_bmm2, std::optional mask, std::optional bias, StridedMemrefView output, FlatMemrefView scratch, std::optional activation, - int64_t uid, double fmha_scale, + int64_t uid, double fmha_scale, bool is_flash_attention, + bool is_causal_mask, absl::Span intermediate_tensor_dimensions, absl::Span intermediate_tensor_layout, DotDimensionNumbers bmm1_dot_dimension_numbers, @@ -364,7 +392,7 @@ static absl::Status FusedAttentionForwardImpl( fmha_scale, intermediate_tensor_dimensions, intermediate_tensor_layout, algorithm_config, bmm1_dot_dimension_numbers, bmm2_dot_dimension_numbers, - dropout_attrs); + is_flash_attention, is_causal_mask, dropout_attrs); StatusOr config = GpufMHAConfig::For(descriptor); if (!config.ok()) return tsl::ToAbslStatus(config.status()); @@ -414,10 +442,17 @@ static absl::Status FusedAttentionBackwardImpl( StridedMemrefView bmm1_grad_gemm2_rhs, StridedMemrefView bmm2_grad_gemm2_rhs, StridedMemrefView bmm2_grad_gemm1_lhs, StridedMemrefView d_output, - std::optional mask, StridedMemrefView d_bmm1_lhs, + std::optional mask, + std::optional bias, + std::optional fwd_output, StridedMemrefView d_bmm1_lhs, StridedMemrefView d_bmm1_rhs, StridedMemrefView d_bmm2_rhs, - StridedMemrefView d_S, FlatMemrefView scratch, + std::optional d_S, + std::optional softmax_sum, + std::optional d_Q_accum, FlatMemrefView scratch, std::optional d_bias, int64_t uid, double fmha_scale, + bool is_flash_attention, bool is_causal_mask, + absl::Span intermediate_tensor_dimensions, + absl::Span intermediate_tensor_layout, DotDimensionNumbers bmm1_grad_gemm1_dot_dimension_numbers, DotDimensionNumbers bmm1_grad_gemm2_dot_dimension_numbers, DotDimensionNumbers bmm2_grad_gemm1_dot_dimension_numbers, @@ -436,12 +471,13 @@ static absl::Status FusedAttentionBackwardImpl( GpufMHABackwardDescriptor descriptor = GetGpufMHABackwardDescriptor( kind, bmm1_grad_gemm1_rhs, bmm1_grad_gemm2_rhs, bmm2_grad_gemm2_rhs, bmm2_grad_gemm1_lhs, d_output, mask, d_bias, d_bmm1_lhs, d_bmm1_rhs, - d_bmm2_rhs, d_S, fmha_scale, algorithm_config, - bmm1_grad_gemm1_dot_dimension_numbers, + d_bmm2_rhs, d_S, softmax_sum, d_Q_accum, fwd_output, bias, + fmha_scale, algorithm_config, bmm1_grad_gemm1_dot_dimension_numbers, bmm1_grad_gemm2_dot_dimension_numbers, bmm2_grad_gemm1_dot_dimension_numbers, - bmm2_grad_gemm2_dot_dimension_numbers, dropout_attrs); - + bmm2_grad_gemm2_dot_dimension_numbers, + intermediate_tensor_dimensions, intermediate_tensor_layout, + is_flash_attention, is_causal_mask, dropout_attrs); StatusOr config = GpufMHABackwardConfig::For(descriptor); if (!config.ok()) return tsl::ToAbslStatus(config.status()); @@ -463,9 +499,13 @@ static absl::Status FusedAttentionBackwardImpl( se::DeviceMemoryBase d_bmm1_lhs_buffer = GetDeviceAddress(d_bmm1_lhs); se::DeviceMemoryBase d_bmm1_rhs_buffer = 
GetDeviceAddress(d_bmm1_rhs); se::DeviceMemoryBase d_bmm2_rhs_buffer = GetDeviceAddress(d_bmm2_rhs); - se::DeviceMemoryBase d_S_buffer = GetDeviceAddress(d_S); se::DeviceMemoryBase scratch_buffer = GetDeviceAddress(scratch); + se::DeviceMemoryBase d_S_buffer; + if (d_S.has_value()) { + d_S_buffer = GetDeviceAddress(*d_S); + } + se::DeviceMemoryBase mask_buffer; if (mask.has_value()) { mask_buffer = GetDeviceAddress(*mask); @@ -476,6 +516,26 @@ static absl::Status FusedAttentionBackwardImpl( d_bias_buffer = GetDeviceAddress(*d_bias); } + se::DeviceMemoryBase softmax_sum_buffer; + if (softmax_sum.has_value()) { + softmax_sum_buffer = GetDeviceAddress(*softmax_sum); + } + + se::DeviceMemoryBase d_Q_accum_buffer; + if (d_Q_accum.has_value()) { + d_Q_accum_buffer = GetDeviceAddress(*d_Q_accum); + } + + se::DeviceMemoryBase fwd_output_buffer; + if (fwd_output.has_value()) { + fwd_output_buffer = GetDeviceAddress(*fwd_output); + } + + se::DeviceMemoryBase bias_buffer; + if (bias.has_value()) { + bias_buffer = GetDeviceAddress(*bias); + } + RunFusedMHABackwardOptions opts; opts.runner_cache = &(*fda)->runner; @@ -484,7 +544,9 @@ static absl::Status FusedAttentionBackwardImpl( (*fda)->config, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, scratch_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, - d_S_buffer, mask_buffer, d_bias_buffer, run_options->stream(), opts); + d_S_buffer, softmax_sum_buffer, d_Q_accum_buffer, mask_buffer, + d_bias_buffer, fwd_output_buffer, bias_buffer, run_options->stream(), + opts); if (!st.ok() || !run_options->stream()->ok()) { return tsl::ToAbslStatus(st); } @@ -500,6 +562,8 @@ auto BindFusedAttentionAttributes(runtime::CustomCallBinding binding) { return std::move(binding) .template Attr("uid") .template Attr("fmha_scale") + .template Attr("is_flash_attention") + .template Attr("is_causal_mask") .template Attr>( "intermediate_tensor_dimensions") .template Attr>("intermediate_tensor_layout") @@ -805,6 +869,11 @@ auto BindFusedAttentionBackwardAttributes( return std::move(binding) .template Attr("uid") .template Attr("fmha_scale") + .template Attr("is_flash_attention") + .template Attr("is_causal_mask") + .template Attr>( + "intermediate_tensor_dimensions") + .template Attr>("intermediate_tensor_layout") .template Attr( "bmm1_grad_gemm1_dot_dimension_numbers") .template Attr( @@ -822,11 +891,11 @@ auto FusedAttentionBackwardCall(const char* name) { .UserData() .UserData() .State("uid") - .Arg() // bmm1_grad_gemm1_rhs - .Arg() // bmm1_grad_gemm2_rhs - .Arg() // bmm2_grad_gemm2_rhs - .Arg() // bmm2_grad_gemm1_lhs - .Arg(); + .Arg() // bmm1_grad_gemm1_rhs + .Arg() // bmm1_grad_gemm2_rhs + .Arg() // bmm2_grad_gemm2_rhs + .Arg() // bmm2_grad_gemm1_lhs + .Arg(); // d_output } XLA_RUNTIME_DEFINE_CUSTOM_CALL( @@ -836,10 +905,14 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.dbias.softmax") .Value(std::optional()) // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output .Arg() // d_bmm1_lhs .Arg() // d_bmm1_rhs .Arg() // d_bmm2_rhs .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum .Arg() // scratch .Arg() // d_bias ) @@ -854,10 +927,14 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.softmax") .Value(std::optional()) // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output .Arg() // 
d_bmm1_lhs .Arg() // d_bmm1_rhs .Arg() // d_bmm2_rhs .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum .Arg() // scratch .Value(std::optional()) // d_bias ) @@ -872,10 +949,14 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.dbias.softmax.dropout") .Value(std::optional()) // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output .Arg() // d_bmm1_lhs .Arg() // d_bmm1_rhs .Arg() // d_bmm2_rhs .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum .Arg() // scratch .Arg() // d_bias ) @@ -890,10 +971,14 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.softmax.dropout") .Value(std::optional()) // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output .Arg() // d_bmm1_lhs .Arg() // d_bmm1_rhs .Arg() // d_bmm2_rhs .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum .Arg() // scratch .Value(std::optional()) // d_bias ) @@ -907,13 +992,17 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( BindFusedAttentionBackwardAttributes( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.dbias.mask.softmax") - .Arg() // mask - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Arg() // scratch - .Arg() // d_bias + .Arg() // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output + .Arg() // d_bmm1_lhs + .Arg() // d_bmm1_rhs + .Arg() // d_bmm2_rhs + .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum + .Arg() // scratch + .Arg() // d_bias ) .Value(std::optional()) // dropout_rate .Value(std::optional()) // seed @@ -926,10 +1015,14 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.mask.softmax") .Arg() // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output .Arg() // d_bmm1_lhs .Arg() // d_bmm1_rhs .Arg() // d_bmm2_rhs .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum .Arg() // scratch .Value(std::optional()) // d_bias ) @@ -943,13 +1036,17 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( BindFusedAttentionBackwardAttributes( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.dbias.mask.softmax.dropout") - .Arg() // mask - .Arg() // d_bmm1_lhs - .Arg() // d_bmm1_rhs - .Arg() // d_bmm2_rhs - .Arg() // d_S - .Arg() // scratch - .Arg() // d_bias + .Arg() // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output + .Arg() // d_bmm1_lhs + .Arg() // d_bmm1_rhs + .Arg() // d_bmm2_rhs + .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum + .Arg() // scratch + .Arg() // d_bias ) .Attr("dropout_rate") // dropout_rate .Attr("seed") // seed @@ -962,16 +1059,66 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( FusedAttentionBackwardCall( "xla.gpu.fused.attention.backward.scale.mask.softmax.dropout") .Arg() // mask + .Value(std::optional()) // bias + .Value(std::optional()) // fwd_output .Arg() // d_bmm1_lhs .Arg() // d_bmm1_rhs .Arg() // d_bmm2_rhs .Arg() // d_S + .Value(std::optional()) // softmax_sum + .Value(std::optional()) // d_Q_accum .Arg() // scratch .Value(std::optional()) // d_bias ) .Attr("dropout_rate") // dropout_rate .Attr("seed") // seed ); + +// flash attention backward custom call +XLA_RUNTIME_DEFINE_CUSTOM_CALL( + FlashAttentionScaleBiasSoftmaxBackward, + 
FunctionWrapper(), checks, + BindFusedAttentionBackwardAttributes( + FusedAttentionBackwardCall( + "xla.gpu.flash.attention.backward.scale.bias.softmax") + .Value(std::optional()) // mask + .Arg() // bias + .Arg() // fwd_output + .Arg() // d_bmm1_lhs + .Arg() // d_bmm1_rhs + .Arg() // d_bmm2_rhs + .Value(std::optional()) // d_S + .Arg() // softmax_sum + .Arg() // d_Q_accum + .Arg() // scratch + .Value(std::optional()) // d_bias + ) + .Value(std::optional()) // dropout_rate + .Value(std::optional()) // seed +); + +XLA_RUNTIME_DEFINE_CUSTOM_CALL( + FlashAttentionScaleSoftmaxBackward, + FunctionWrapper(), checks, + BindFusedAttentionBackwardAttributes( + FusedAttentionBackwardCall( + "xla.gpu.flash.attention.backward.scale.softmax") + .Value(std::optional()) // mask + .Value(std::optional()) // bias + .Arg() // fwd_output + .Arg() // d_bmm1_lhs + .Arg() // d_bmm1_rhs + .Arg() // d_bmm2_rhs + .Value(std::optional()) // d_S + .Arg() // softmax_sum + .Arg() // d_Q_accum + .Arg() // scratch + .Value(std::optional()) // d_bias + ) + .Value(std::optional()) // dropout_rate + .Value(std::optional()) // seed +); + //===----------------------------------------------------------------------===// // cuBLASLt custom calls bindings and registration. //===----------------------------------------------------------------------===// @@ -1040,6 +1187,14 @@ void RegisterFusedAttentionBackwardCustomCalls( FusedAttentionScaleBiasMaskSoftmaxDropoutBackward); registry.Register(fused_attention("scale.mask.softmax.dropout"), FusedAttentionScaleMaskSoftmaxDropoutBackward); + // flash attention bwd + auto flash_attention = [](std::string name) { + return "xla.gpu.flash.attention.backward." + name; + }; + registry.Register(flash_attention("scale.bias.softmax"), + FlashAttentionScaleBiasSoftmaxBackward); + registry.Register(flash_attention("scale.softmax"), + FlashAttentionScaleSoftmaxBackward); } } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc index ac7c223b874268..1a340c51cb326e 100644 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc @@ -780,6 +780,12 @@ AsLhloFusedMhaBackwardDagSignature(xla::gpu::CudnnfMHAKind kind) { return lmhlo_gpu::FusedMhaBackwardDagSignature:: BackwardScaleBiasMaskSoftmaxDropout; break; + case xla::gpu::CudnnfMHAKind::kBackwardSoftmax: + return lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmax; + break; + case xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout: + return lmhlo_gpu::FusedMhaBackwardDagSignature::BackwardSoftmaxDropout; + break; default: return xla::InternalError("unknown cudnn fmha bwd kind"); } @@ -1270,13 +1276,17 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHA( has_activation ? 
1 : 0}; op->setAttr(op.getOperandSegmentSizeAttr(), builder_.getDenseI32ArrayAttr(operand_sizes)); + // set is flash attention here + op.setIsFlashAttentionAttr( + builder_.getBoolAttr(config.is_flash_attention())); + // set is causal mask here + op.setIsCausalMaskAttr(builder_.getBoolAttr(config.is_causal_mask())); return op.getOperation(); }; llvm::SmallVector operands; TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); - switch (kind) { case xla::gpu::CudnnfMHAKind::kBmmBmm: case xla::gpu::CudnnfMHAKind::kSoftmax: { @@ -1365,8 +1375,11 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, xla::gpu::GetCudnnfMHAKind(custom_call)); - bool has_dbias = custom_call->shape().tuple_shapes().size() == 6; + bool is_flash_attention = config.is_flash_attention(); + bool has_dbias = + custom_call->shape().tuple_shapes().size() == 6 && !is_flash_attention; bool has_mask = false; + bool has_bias = false; auto set_common_fmha_backward_attributes = [&, this](auto op) -> tsl::StatusOr { @@ -1384,13 +1397,44 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( op.setBmm2GradGemm2DotDimensionNumbersAttr(GetDotDimensionNumbersAttr( builder_, config.bmm2_grad_gemm2_dot_dimension_numbers())); + auto intermediate_tensor_shape = Shape(config.intermediate_tensor_shape()); + auto arrayref = [](absl::Span array) { + return llvm::ArrayRef{array.data(), array.size()}; + }; + auto intermediate_tensor_dims = builder_.getI64ArrayAttr( + arrayref(intermediate_tensor_shape.dimensions())); + op.setIntermediateTensorDimensionsAttr(intermediate_tensor_dims); + + auto intermediate_tensor_layout = builder_.getI64ArrayAttr( + arrayref(intermediate_tensor_shape.layout().minor_to_major())); + op.setIntermediateTensorLayoutAttr(intermediate_tensor_layout); + op.setFmhaScaleAttr(builder_.getF64FloatAttr(config.fmha_scale())); - int32_t operand_sizes[] = {1, 1, 1, 1, 1, has_mask ? 1 : 0, - 1, 1, 1, 1, 1, has_dbias ? 1 : 0}; + int32_t operand_sizes[] = {1, + 1, + 1, + 1, + 1, + has_mask ? 1 : 0, + has_bias ? 1 : 0, + is_flash_attention ? 1 : 0, // fwd_output + 1, + 1, + 1, + is_flash_attention ? 0 : 1, // d_S + is_flash_attention ? 1 : 0, // softmax_sum + is_flash_attention ? 1 : 0, // d_Q_accum + 1, + has_dbias ? 
1 : 0}; op->setAttr(op.getOperandSegmentSizeAttr(), builder_.getDenseI32ArrayAttr(operand_sizes)); + // set is flash attention here + op.setIsFlashAttentionAttr( + builder_.getBoolAttr(config.is_flash_attention())); + // set is causal mask here + op.setIsCausalMaskAttr(builder_.getBoolAttr(config.is_causal_mask())); const auto& algorithm = config.algorithm(); std::vector knob_ids; std::vector knob_values; @@ -1406,7 +1450,7 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( return op.getOperation(); }; - llvm::SmallVector operands; + llvm::SmallVector operands; TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(0), &operands)); TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(1), &operands)); TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(2), &operands)); @@ -1415,15 +1459,35 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( switch (kind) { case xla::gpu::CudnnfMHAKind::kBackwardBmmBmm: - case xla::gpu::CudnnfMHAKind::kBackwardSoftmax: + case xla::gpu::CudnnfMHAKind::kBackwardSoftmax: { + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + } + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); + auto fmha_backward = CreateOpWithoutAttrs( + custom_call, operands); + return set_common_fmha_backward_attributes(fmha_backward); + } case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmax: { + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + has_bias = true; + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); + } TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); auto fmha_backward = CreateOpWithoutAttrs( custom_call, operands); return set_common_fmha_backward_attributes(fmha_backward); } - case xla::gpu::CudnnfMHAKind::kBackwardSoftmaxDropout: case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: { + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + has_bias = true; + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); + } TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); auto fmha_backward = CreateOpWithoutAttrs( custom_call, operands); @@ -1433,9 +1497,26 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( return set_common_fmha_backward_attributes(fmha_backward); } - case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmax: + case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmax: { + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); + } + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); + has_mask = true; + auto fmha_backward = CreateOpWithoutAttrs( + custom_call, operands); + return set_common_fmha_backward_attributes(fmha_backward); + } case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmax: { TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + has_bias = true; + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); + 
TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(7), &operands)); + } TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); has_mask = true; auto fmha_backward = CreateOpWithoutAttrs( @@ -1443,9 +1524,31 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnfMHABackward( return set_common_fmha_backward_attributes(fmha_backward); } - case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmaxDropout: + case xla::gpu::CudnnfMHAKind::kBackwardScaleMaskSoftmaxDropout: { + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(6), &operands)); + } + TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); + has_mask = true; + auto fmha_backward = CreateOpWithoutAttrs( + custom_call, operands); + fmha_backward.setDropoutRateAttr( + builder_.getF64FloatAttr(config.dropout_rate())); + fmha_backward.setSeedAttr(builder_.getI64IntegerAttr(config.seed())); + return set_common_fmha_backward_attributes(fmha_backward); + } case xla::gpu::CudnnfMHAKind::kBackwardScaleBiasMaskSoftmaxDropout: { TF_RETURN_IF_ERROR(GetOrCreateView(custom_call->operand(5), &operands)); + // push fwd output for bwd here if it is flash attention + if (config.is_flash_attention()) { + has_bias = true; + TF_RETURN_IF_ERROR( + GetOrCreateView(custom_call->operand(6), &operands)); // bias + TF_RETURN_IF_ERROR( + GetOrCreateView(custom_call->operand(7), &operands)); // fwd_output + } TF_RETURN_IF_ERROR(GetOrCreateView(custom_call, &operands)); has_mask = true; auto fmha_backward = CreateOpWithoutAttrs( From 3c0b8f66809c63a20c368d26a2175445a0e4e991 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 10:46:55 -0800 Subject: [PATCH 178/391] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/a2f7558daa1c026305ba9148bce0d34ed0a83a6e. PiperOrigin-RevId: 583096340 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 66639a6944c897..8fecb534923045 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "06b4bc05a86e9d2464a82e96df6703252bf100a1" - TFRT_SHA256 = "84a8f80403c6a4a8281d0cd291ff1ea3ce6f0ed29ab35b2d3d155f4ffec66488" + TFRT_COMMIT = "a2f7558daa1c026305ba9148bce0d34ed0a83a6e" + TFRT_SHA256 = "6c4d2c1e9835ec186dcd813cfec4d68537d7dbefab42bc0fabf2c0813c0b64e0" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 66639a6944c897..8fecb534923045 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. 
- TFRT_COMMIT = "06b4bc05a86e9d2464a82e96df6703252bf100a1" - TFRT_SHA256 = "84a8f80403c6a4a8281d0cd291ff1ea3ce6f0ed29ab35b2d3d155f4ffec66488" + TFRT_COMMIT = "a2f7558daa1c026305ba9148bce0d34ed0a83a6e" + TFRT_SHA256 = "6c4d2c1e9835ec186dcd813cfec4d68537d7dbefab42bc0fabf2c0813c0b64e0" tf_http_archive( name = "tf_runtime", From 535a6fddac76d243b5cd078d14bf9c5e4aa409c1 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Thu, 16 Nov 2023 10:54:06 -0800 Subject: [PATCH 179/391] Clean up generate_compile_commands.py Remove extraneous references to clang_tidy PiperOrigin-RevId: 583098899 --- .../build_tools/lint/generate_compile_commands.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/third_party/xla/build_tools/lint/generate_compile_commands.py b/third_party/xla/build_tools/lint/generate_compile_commands.py index 2deb1fc3f177c5..c52c4cc93203be 100644 --- a/third_party/xla/build_tools/lint/generate_compile_commands.py +++ b/third_party/xla/build_tools/lint/generate_compile_commands.py @@ -32,14 +32,14 @@ @dataclasses.dataclass -class ClangTidyCommand: - """Represents a clang-tidy command with options on a specific file.""" +class CompileCommand: + """Represents a compilation command with options on a specific file.""" file: str arguments: list[str] @classmethod - def from_args_list(cls, args_list: list[str]) -> "ClangTidyCommand": + def from_args_list(cls, args_list: list[str]) -> "CompileCommand": """Alternative constructor which uses the args_list from `bazel aquery`. This collects arguments and the file being run on from the output of @@ -75,21 +75,21 @@ def to_dumpable_json(self, directory: str) -> _JSONDict: def extract_compile_commands( parsed_aquery_output: _JSONDict, -) -> list[ClangTidyCommand]: - """Gathers clang-tidy commands to run from `bazel aquery` JSON output. +) -> list[CompileCommand]: + """Gathers compile commands to run from `bazel aquery` JSON output. Arguments: parsed_aquery_output: Parsed JSON representing the output of `bazel aquery --output=jsonproto`. Returns: - The list of ClangTidyCommands that should be executed. + The list of CompileCommands that should be executed. 
""" actions = parsed_aquery_output["actions"] commands = [] for action in actions: - command = ClangTidyCommand.from_args_list(action["arguments"]) + command = CompileCommand.from_args_list(action["arguments"]) commands.append(command) return commands @@ -99,7 +99,6 @@ def main(): logging.basicConfig() logging.getLogger().setLevel(logging.INFO) - # Gather and run clang-tidy invocations logging.info("Reading `bazel aquery` output from stdin...") parsed_aquery_output = json.loads(sys.stdin.read()) From f54d79268930b45a55f29fc69bc6dc64fcf7f695 Mon Sep 17 00:00:00 2001 From: Samuel Agyakwa Date: Thu, 16 Nov 2023 11:19:48 -0800 Subject: [PATCH 180/391] [PJRT C API] Fix pybind11 implicit conversion from boolean to integral when getting the the type from an std::variant PiperOrigin-RevId: 583107380 --- third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc | 2 +- third_party/xla/xla/pjrt/pjrt_common.h | 5 ++++- third_party/xla/xla/python/py_executable.h | 7 +++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc index 65af68ba2bb9f9..d36f3cf9cd3d1d 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc @@ -129,7 +129,7 @@ TEST(PjRtCApiHelperTest, InvalidOptionTypeIndex) { EXPECT_NE(status, tsl::OkStatus()); EXPECT_THAT(status.message(), HasSubstr("Option passed to PJRT_Client_Create with name string " - "has type index 1 but expected type index is 0")); + "has type index 2 but expected type index is 0")); } TEST(PjRtCApiHelperTest, Callback) { diff --git a/third_party/xla/xla/pjrt/pjrt_common.h b/third_party/xla/xla/pjrt/pjrt_common.h index 263aa115113108..0de187db139b2e 100644 --- a/third_party/xla/xla/pjrt/pjrt_common.h +++ b/third_party/xla/xla/pjrt/pjrt_common.h @@ -23,8 +23,11 @@ limitations under the License. namespace xla { +// bool comes before int64_t because when pybind11 tries to convert a Python +// object to a C++ type, it will try to convert it to the first type in the list +// of possible types that it can be converted to (b/309163973). using PjRtValueType = - std::variant, float, bool>; + std::variant, float>; } // namespace xla diff --git a/third_party/xla/xla/python/py_executable.h b/third_party/xla/xla/python/py_executable.h index 6f0440356d304e..8e722c908235be 100644 --- a/third_party/xla/xla/python/py_executable.h +++ b/third_party/xla/xla/python/py_executable.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/types/span.h" #include "pybind11/gil.h" // from @pybind11 #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/py_array.h" #include "xla/python/py_client.h" @@ -139,10 +140,8 @@ class PyLoadedExecutable return ifrt_loaded_executable_->GetCompiledMemoryStats(); } - StatusOr, float, bool>>> - GetCostAnalysis() const { + StatusOr> GetCostAnalysis() + const { return ifrt_loaded_executable_->GetCostAnalysis(); } From be63ae329866d5eb4e59baf8d41a57127b998d56 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 11:25:55 -0800 Subject: [PATCH 181/391] Renames 'LeafStrategies' to 'StrategyGroups' (since each of its members is a StrategyGroup). 
PiperOrigin-RevId: 583109427 --- .../auto_sharding/auto_sharding.cc | 162 +++++++++--------- .../auto_sharding/auto_sharding.h | 10 +- .../auto_sharding/auto_sharding_cost_graph.h | 18 +- .../auto_sharding_dot_handler.cc | 8 +- .../auto_sharding/auto_sharding_impl.cc | 4 +- .../auto_sharding/auto_sharding_strategy.h | 4 +- .../auto_sharding/auto_sharding_util.cc | 6 +- .../auto_sharding/auto_sharding_wrapper.h | 2 +- 8 files changed, 107 insertions(+), 107 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 30fb3f2a9ba6e6..aa23a2160babdb 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -108,11 +108,11 @@ std::vector ReshardingCostVector( // Factory functions for StrategyGroup. std::unique_ptr CreateLeafStrategyGroupWithoutInNodes( - size_t instruction_id, LeafStrategies& leaf_strategies) { + size_t instruction_id, StrategyGroups& strategy_groups) { auto strategy_group = std::make_unique(); strategy_group->is_tuple = false; - strategy_group->node_idx = leaf_strategies.size(); - leaf_strategies.push_back(strategy_group.get()); + strategy_group->node_idx = strategy_groups.size(); + strategy_groups.push_back(strategy_group.get()); strategy_group->instruction_id = instruction_id; return strategy_group; } @@ -120,9 +120,9 @@ std::unique_ptr CreateLeafStrategyGroupWithoutInNodes( // Factory functions for StrategyGroup. std::unique_ptr CreateLeafStrategyGroup( size_t instruction_id, const HloInstruction* ins, - const StrategyMap& strategy_map, LeafStrategies& leaf_strategies) { + const StrategyMap& strategy_map, StrategyGroups& strategy_groups) { auto strategy_group = - CreateLeafStrategyGroupWithoutInNodes(instruction_id, leaf_strategies); + CreateLeafStrategyGroupWithoutInNodes(instruction_id, strategy_groups); for (int64_t i = 0; i < ins->operand_count(); ++i) { strategy_group->in_nodes.push_back(strategy_map.at(ins->operand(i)).get()); } @@ -234,7 +234,7 @@ GenerateReshardingCostsAndShardingsForAllOperands( std::unique_ptr MaybeFollowInsStrategyGroup( const StrategyGroup* src_strategy_group, const Shape& shape, size_t instruction_id, bool have_memory_cost, - LeafStrategies& leaf_strategies, const ClusterEnvironment& cluster_env, + StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, StableHashMap>& pretrimmed_strategy_map) { std::unique_ptr strategy_group; @@ -246,7 +246,7 @@ std::unique_ptr MaybeFollowInsStrategyGroup( for (size_t i = 0; i < src_strategy_group->childs.size(); ++i) { auto child_strategies = MaybeFollowInsStrategyGroup( src_strategy_group->childs[i].get(), shape.tuple_shapes(i), - instruction_id, have_memory_cost, leaf_strategies, cluster_env, + instruction_id, have_memory_cost, strategy_groups, cluster_env, pretrimmed_strategy_map); child_strategies->tuple_element_idx = i; strategy_group->childs.push_back(std::move(child_strategies)); @@ -254,7 +254,7 @@ std::unique_ptr MaybeFollowInsStrategyGroup( } else { CHECK(shape.IsArray() || shape.IsToken()); strategy_group = - CreateLeafStrategyGroupWithoutInNodes(instruction_id, leaf_strategies); + CreateLeafStrategyGroupWithoutInNodes(instruction_id, strategy_groups); strategy_group->in_nodes.push_back(src_strategy_group); // Only follows the given strategy when there is no other strategy to be // restored. 
@@ -303,7 +303,7 @@ StatusOr> FollowReduceStrategy( const HloInstruction* ins, const Shape& output_shape, const HloInstruction* operand, const HloInstruction* unit, size_t instruction_id, StrategyMap& strategy_map, - LeafStrategies& leaf_strategies, const ClusterEnvironment& cluster_env, + StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, bool allow_mixed_mesh_shape, bool crash_at_error) { std::unique_ptr strategy_group; if (output_shape.IsTuple()) { @@ -313,7 +313,7 @@ StatusOr> FollowReduceStrategy( auto child_strategy_status = FollowReduceStrategy( ins, ins->shape().tuple_shapes().at(i), ins->operand(i), ins->operand(i + ins->shape().tuple_shapes_size()), instruction_id, - strategy_map, leaf_strategies, cluster_env, allow_mixed_mesh_shape, + strategy_map, strategy_groups, cluster_env, allow_mixed_mesh_shape, crash_at_error); if (!child_strategy_status.ok()) { return child_strategy_status; @@ -324,7 +324,7 @@ StatusOr> FollowReduceStrategy( } } else if (output_shape.IsArray()) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_groups); const StrategyGroup* src_strategy_group = strategy_map.at(operand).get(); // Follows the strategy of the operand. strategy_group->following = src_strategy_group; @@ -1120,9 +1120,9 @@ void DisableIncompatibleMixedMeshShapeAndForceBatchDim( } } -StatusOr> CreateAllStrategiesVector( +StatusOr> CreateAllStrategiesGroup( const HloInstruction* ins, const Shape& shape, size_t instruction_id, - LeafStrategies& leaf_strategies, const ClusterEnvironment& cluster_env, + StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, bool only_allow_divisible, @@ -1133,18 +1133,18 @@ StatusOr> CreateAllStrategiesVector( strategy_group->childs.reserve(shape.tuple_shapes_size()); for (size_t i = 0; i < shape.tuple_shapes_size(); ++i) { auto child_strategies = - CreateAllStrategiesVector(ins, shape.tuple_shapes(i), instruction_id, - leaf_strategies, cluster_env, strategy_map, - option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - create_replicated_strategies) + CreateAllStrategiesGroup(ins, shape.tuple_shapes(i), instruction_id, + strategy_groups, cluster_env, strategy_map, + option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + create_replicated_strategies) .value(); child_strategies->tuple_element_idx = i; strategy_group->childs.push_back(std::move(child_strategies)); } } else if (shape.IsArray()) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_groups); EnumerateAll1DPartition(ins, shape, cluster_env.device_mesh_, cluster_env, strategy_map, strategy_group, only_allow_divisible, "", call_graph); @@ -1186,7 +1186,7 @@ StatusOr> CreateAllStrategiesVector( } } else if (shape.IsToken()) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_groups); AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, strategy_group, replicated_penalty); } else { @@ -1197,12 +1197,12 @@ StatusOr> CreateAllStrategiesVector( StatusOr> CreateParameterStrategyGroup( const HloInstruction* ins, const Shape& shape, size_t instruction_id, - LeafStrategies& leaf_strategies, const ClusterEnvironment& cluster_env, + StrategyGroups& strategy_groups, const ClusterEnvironment& 
cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, bool only_allow_divisible) { - return CreateAllStrategiesVector( - ins, shape, instruction_id, leaf_strategies, cluster_env, strategy_map, + return CreateAllStrategiesGroup( + ins, shape, instruction_id, strategy_groups, cluster_env, strategy_map, option, replicated_penalty, batch_dim_map, call_graph, only_allow_divisible, option.allow_replicated_parameters); } @@ -1530,10 +1530,10 @@ std::unique_ptr CreateElementwiseOperatorStrategies( const InstructionDepthMap& depth_map, const AliasMap& alias_map, const StableHashMap>& pretrimmed_strategy_map, - int64_t max_depth, LeafStrategies& leaf_strategies, + int64_t max_depth, StrategyGroups& strategy_groups, AssociativeDotPairs& associative_dot_pairs) { std::unique_ptr strategy_group = CreateLeafStrategyGroup( - instruction_id, ins, strategy_map, leaf_strategies); + instruction_id, ins, strategy_map, strategy_groups); // Choose an operand to follow int64_t follow_idx; @@ -1612,9 +1612,9 @@ std::unique_ptr CreateReshapeStrategies( const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, bool only_allow_divisible, double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, - const AutoShardingOption& option, LeafStrategies& leaf_strategies) { + const AutoShardingOption& option, StrategyGroups& strategy_groups) { std::unique_ptr strategy_group = CreateLeafStrategyGroup( - instruction_id, ins, strategy_map, leaf_strategies); + instruction_id, ins, strategy_map, strategy_groups); const HloInstruction* operand = ins->operand(0); const Array& device_mesh = cluster_env.device_mesh_; const Array& device_mesh_1d = cluster_env.device_mesh_1d_; @@ -1702,7 +1702,7 @@ std::unique_ptr CreateReshapeStrategies( // NOLINTBEGIN(readability/fn_size) // TODO(zhuohan): Decompose this function into smaller pieces // Build possible sharding strategies and their costs for all instructions. -StatusOr> +StatusOr> BuildStrategyAndCost(const HloInstructionSequence& sequence, const HloModule* module, const absl::flat_hash_map& @@ -1723,7 +1723,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // doesn't need to strictly follow it. We restore the trimmed strategies in // this situation. 
StableHashMap> pretrimmed_strategy_map; - LeafStrategies leaf_strategies; + StrategyGroups strategy_groups; AssociativeDotPairs associative_dot_pairs; const std::vector& instructions = sequence.instructions(); @@ -1768,7 +1768,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, case HloOpcode::kRng: { strategy_group = CreateParameterStrategyGroup( - ins, ins->shape(), instruction_id, leaf_strategies, cluster_env, + ins, ins->shape(), instruction_id, strategy_groups, cluster_env, strategy_map, option, replicated_penalty, batch_dim_map, call_graph, only_allow_divisible) .value(); @@ -1776,14 +1776,14 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } case HloOpcode::kConstant: { strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, - leaf_strategies); + strategy_groups); AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, strategy_group, 0); break; } case HloOpcode::kScatter: { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, leaf_strategies); + strategy_map, strategy_groups); // We follow the first operand (the array we're scattering into) auto src_strategy_group = strategy_map.at(ins->operand(0)).get(); CHECK(!src_strategy_group->is_tuple); @@ -1814,7 +1814,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } case HloOpcode::kGather: { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, leaf_strategies); + strategy_map, strategy_groups); const HloInstruction* indices = ins->operand(1); const Shape& shape = ins->shape(); for (int32_t index_dim = 0; index_dim < indices->shape().rank(); @@ -1900,7 +1900,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } case HloOpcode::kBroadcast: { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, leaf_strategies); + strategy_map, strategy_groups); const HloInstruction* operand = ins->operand(0); @@ -1932,13 +1932,13 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, strategy_group = CreateReshapeStrategies( instruction_id, ins, strategy_map, cluster_env, only_allow_divisible, replicated_penalty, batch_dim_map, option, - leaf_strategies); + strategy_groups); break; } case HloOpcode::kTranspose: case HloOpcode::kReverse: { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, leaf_strategies); + strategy_map, strategy_groups); const HloInstruction* operand = ins->operand(0); @@ -1984,7 +1984,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, leaf_strategies); + strategy_map, strategy_groups); int64_t follow_idx; switch (opcode) { // TODO(yuemmawang) Re-evaluate the follow_idx choices for the @@ -2085,7 +2085,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, auto operand_strategies = strategy_map.at(ins->operand(0)).get(); strategy_group = MaybeFollowInsStrategyGroup( operand_strategies, ins->shape(), instruction_id, - /* have_memory_cost */ true, leaf_strategies, cluster_env, + /* have_memory_cost */ true, strategy_groups, cluster_env, pretrimmed_strategy_map); break; } @@ -2093,13 +2093,13 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, if (ins->shape() == ins->operand(0)->shape()) { strategy_group = CreateElementwiseOperatorStrategies( instruction_id, ins, strategy_map, cluster_env, depth_map, - alias_map, pretrimmed_strategy_map, max_depth, leaf_strategies, 
+ alias_map, pretrimmed_strategy_map, max_depth, strategy_groups, associative_dot_pairs); } else { strategy_group = CreateReshapeStrategies( instruction_id, ins, strategy_map, cluster_env, only_allow_divisible, replicated_penalty, batch_dim_map, option, - leaf_strategies); + strategy_groups); } break; } @@ -2157,14 +2157,14 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, case HloOpcode::kClamp: { strategy_group = CreateElementwiseOperatorStrategies( instruction_id, ins, strategy_map, cluster_env, depth_map, - alias_map, pretrimmed_strategy_map, max_depth, leaf_strategies, + alias_map, pretrimmed_strategy_map, max_depth, strategy_groups, associative_dot_pairs); break; } case HloOpcode::kReduce: { auto strategies_status = FollowReduceStrategy( ins, ins->shape(), ins->operand(0), ins->operand(1), instruction_id, - strategy_map, leaf_strategies, cluster_env, + strategy_map, strategy_groups, cluster_env, option.allow_mixed_mesh_shape, !trying_multiple_mesh_shapes); if (strategies_status.ok()) { strategy_group = std::move(strategies_status.value()); @@ -2175,7 +2175,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } case HloOpcode::kDot: { TF_RETURN_IF_ERROR(HandleDot( - strategy_group, leaf_strategies, strategy_map, ins, instruction_id, + strategy_group, strategy_groups, strategy_map, ins, instruction_id, cluster_env, batch_dim_map, option, call_graph)); if (option.allow_replicated_strategy_for_dot_and_conv) { AddReplicatedStrategy( @@ -2187,7 +2187,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } case HloOpcode::kConvolution: { TF_RETURN_IF_ERROR(HandleConv( - strategy_group, leaf_strategies, strategy_map, ins, instruction_id, + strategy_group, strategy_groups, strategy_map, ins, instruction_id, cluster_env, batch_dim_map, option, call_graph)); if (option.allow_replicated_strategy_for_dot_and_conv) { AddReplicatedStrategy( @@ -2199,14 +2199,14 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } case HloOpcode::kRngGetAndUpdateState: { strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, - leaf_strategies); + strategy_groups); AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, strategy_group, 0); break; } case HloOpcode::kIota: { strategy_group = CreateLeafStrategyGroupWithoutInNodes(instruction_id, - leaf_strategies); + strategy_groups); if (cluster_env.IsDeviceMesh1D()) { EnumerateAll1DPartition(ins, ins->shape(), device_mesh, cluster_env, strategy_map, strategy_group, @@ -2248,7 +2248,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, strategy_map.at(operand).get(); auto child_strategies = MaybeFollowInsStrategyGroup( src_strategy_group, operand->shape(), instruction_id, - /* have_memory_cost= */ true, leaf_strategies, cluster_env, + /* have_memory_cost= */ true, strategy_groups, cluster_env, pretrimmed_strategy_map); child_strategies->tuple_element_idx = i; strategy_group->childs.push_back(std::move(child_strategies)); @@ -2263,7 +2263,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, strategy_group = MaybeFollowInsStrategyGroup( src_strategy_group->childs[ins->tuple_index()].get(), ins->shape(), instruction_id, - /* have_memory_cost= */ true, leaf_strategies, cluster_env, + /* have_memory_cost= */ true, strategy_groups, cluster_env, pretrimmed_strategy_map); break; } @@ -2281,7 +2281,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, ++i) { std::unique_ptr child_strategies = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, 
leaf_strategies); + strategy_map, strategy_groups); AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i), cluster_env, strategy_map, child_strategies, replicated_penalty); @@ -2290,8 +2290,8 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } } else { strategy_group = - CreateAllStrategiesVector( - ins, ins->shape(), instruction_id, leaf_strategies, + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, cluster_env, strategy_map, option, replicated_penalty, batch_dim_map, call_graph, only_allow_divisible, true) .value(); @@ -2299,14 +2299,14 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } else { if (only_replicated) { strategy_group = CreateLeafStrategyGroup( - instruction_id, ins, strategy_map, leaf_strategies); + instruction_id, ins, strategy_map, strategy_groups); AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, strategy_group, replicated_penalty); } else { strategy_group = - CreateAllStrategiesVector( - ins, ins->shape(), instruction_id, leaf_strategies, + CreateAllStrategiesGroup( + ins, ins->shape(), instruction_id, strategy_groups, cluster_env, strategy_map, option, replicated_penalty, batch_dim_map, call_graph, only_allow_divisible, true) .value(); @@ -2321,7 +2321,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, CHECK(src_strategy_group->is_tuple); strategy_group = MaybeFollowInsStrategyGroup( src_strategy_group, ins->shape(), instruction_id, - /* have_memory_cost= */ true, leaf_strategies, cluster_env, + /* have_memory_cost= */ true, strategy_groups, cluster_env, pretrimmed_strategy_map); } else if (ins->has_sharding()) { generate_non_following_strategies(false); @@ -2336,7 +2336,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, strategy_map.at(operand).get(); strategy_group = MaybeFollowInsStrategyGroup( src_strategy_group, ins->shape(), instruction_id, - /* have_memory_cost= */ true, leaf_strategies, cluster_env, + /* have_memory_cost= */ true, strategy_groups, cluster_env, pretrimmed_strategy_map); } } else if (IsTopKCustomCall(ins)) { @@ -2356,7 +2356,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, auto child_strategies = MaybeFollowInsStrategyGroup( src_strategy_group->childs[i].get(), ins->shape().tuple_shapes().at(i), instruction_id, - /* have_memory_cost= */ true, leaf_strategies, cluster_env, + /* have_memory_cost= */ true, strategy_groups, cluster_env, pretrimmed_strategy_map); child_strategies->tuple_element_idx = i; strategy_group->childs.push_back(std::move(child_strategies)); @@ -2368,24 +2368,24 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, case HloOpcode::kInfeed: case HloOpcode::kSort: { strategy_group = - CreateAllStrategiesVector( - ins, ins->shape(), instruction_id, leaf_strategies, cluster_env, - strategy_map, option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - /*create_replicated_strategies*/ true) + CreateAllStrategiesGroup(ins, ins->shape(), instruction_id, + strategy_groups, cluster_env, strategy_map, + option, replicated_penalty, batch_dim_map, + call_graph, only_allow_divisible, + /*create_replicated_strategies*/ true) .value(); break; } case HloOpcode::kOutfeed: { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, leaf_strategies); + strategy_map, strategy_groups); GenerateOutfeedStrategy(ins, ins->shape(), cluster_env, strategy_map, strategy_group, replicated_penalty); break; } case HloOpcode::kAfterAll: { strategy_group = 
CreateLeafStrategyGroup(instruction_id, ins, - strategy_map, leaf_strategies); + strategy_map, strategy_groups); AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, strategy_group, replicated_penalty); break; @@ -2496,7 +2496,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } } - return std::make_tuple(std::move(strategy_map), std::move(leaf_strategies), + return std::make_tuple(std::move(strategy_map), std::move(strategy_groups), std::move(associative_dot_pairs)); } @@ -2505,7 +2505,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, AutoShardingSolverResult CallSolver( const HloLiveRange& hlo_live_range, const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, - const LeafStrategies& leaf_strategies, const CostGraph& cost_graph, + const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, const std::vector& s_hint, bool compute_iis, int64_t solver_timeout_in_seconds, const AutoShardingOption& option, @@ -2513,7 +2513,7 @@ AutoShardingSolverResult CallSolver( sharding_propagation_solution) { // Serialize edges and edge costs to 1d numpy arrays AutoShardingSolverRequest request; - request.num_nodes = leaf_strategies.size(); + request.num_nodes = strategy_groups.size(); request.memory_budget = option.memory_budget_per_device; request.s_len = cost_graph.node_lens_; request.s_follow = cost_graph.follow_idx_; @@ -2540,7 +2540,7 @@ AutoShardingSolverResult CallSolver( // Serialize node costs int num_nodes_without_default = 0; for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) { - const StrategyGroup* strategy_group = leaf_strategies[node_idx]; + const StrategyGroup* strategy_group = strategy_groups[node_idx]; auto instruction_name = instructions.at(strategy_group->instruction_id)->name(); request.instruction_names.push_back( @@ -2585,8 +2585,8 @@ AutoShardingSolverResult CallSolver( // spec std::vector> new_followers; for (const auto& pair : alias_set) { - const StrategyGroup* src_strategy_group = leaf_strategies[pair.first]; - const StrategyGroup* dst_strategy_group = leaf_strategies[pair.second]; + const StrategyGroup* src_strategy_group = strategy_groups[pair.first]; + const StrategyGroup* dst_strategy_group = strategy_groups[pair.second]; Matrix raw_cost(src_strategy_group->strategies.size(), dst_strategy_group->strategies.size()); for (NodeStrategyIdx i = 0; i < src_strategy_group->strategies.size(); @@ -3115,29 +3115,29 @@ std::string PrintStrategyMap(const StrategyMap& strategy_map, std::string PrintAutoShardingSolution(const HloInstructionSequence& sequence, const LivenessSet& liveness_set, const StrategyMap& strategy_map, - const LeafStrategies& leaf_strategies, + const StrategyGroups& strategy_groups, const CostGraph& cost_graph, absl::Span s_val, double objective) { std::string str("=== Auto sharding strategy ===\n"); const std::vector& instructions = sequence.instructions(); - size_t N = leaf_strategies.size(); + size_t N = strategy_groups.size(); // Print the chosen strategy for (NodeIdx node_idx = 0; node_idx < N; ++node_idx) { absl::StrAppend( &str, node_idx, " ", ToAdaptiveString( - instructions[leaf_strategies[node_idx]->instruction_id]), + instructions[strategy_groups[node_idx]->instruction_id]), " "); NodeStrategyIdx stra_idx = cost_graph.RemapIndex(node_idx, s_val[node_idx]); if (cost_graph.follow_idx_[node_idx] < 0) { absl::StrAppend( - &str, leaf_strategies[node_idx]->strategies[stra_idx].ToString(), + &str, 
strategy_groups[node_idx]->strategies[stra_idx].ToString(), "\n"); } else { absl::StrAppend( - &str, leaf_strategies[node_idx]->strategies[stra_idx].ToString(), + &str, strategy_groups[node_idx]->strategies[stra_idx].ToString(), " follow ", cost_graph.follow_idx_[node_idx], "\n"); } } @@ -4386,21 +4386,21 @@ StatusOr AutoShardingImplementation::RunAutoSharding( ins_depth_map = spmd::BuildInstructionDepthMap(sequence, batch_dim_map); // ----- Build strategies and costs ----- spmd::StrategyMap strategy_map; - spmd::LeafStrategies leaf_strategies; + spmd::StrategyGroups strategy_groups; spmd::AssociativeDotPairs associative_dot_pairs; TF_ASSIGN_OR_RETURN( - std::tie(strategy_map, leaf_strategies, associative_dot_pairs), + std::tie(strategy_map, strategy_groups, associative_dot_pairs), BuildStrategyAndCost( sequence, module, instruction_execution_counts, ins_depth_map, batch_dim_map, alias_map, cluster_env, option_, *call_graph, hlo_cost_analysis, option_.try_multiple_mesh_shapes)); spmd::AliasSet alias_set = spmd::BuildAliasSet(module, strategy_map); - CheckAliasSetCompatibility(alias_set, leaf_strategies, sequence); + CheckAliasSetCompatibility(alias_set, strategy_groups, sequence); XLA_VLOG_LINES(8, PrintStrategyMap(strategy_map, sequence)); // ----- Build cost graph and merge unimportant nodes ----- - spmd::CostGraph cost_graph(leaf_strategies, associative_dot_pairs); + spmd::CostGraph cost_graph(strategy_groups, associative_dot_pairs); cost_graph.Simplify(option_.simplify_graph); // ----- Build the liveness node set ----- @@ -4424,7 +4424,7 @@ StatusOr AutoShardingImplementation::RunAutoSharding( double objective = -1.0; if (!option_.load_solution_vector) { auto solver_result = Solve( - *hlo_live_range, liveness_node_set, strategy_map, leaf_strategies, + *hlo_live_range, liveness_node_set, strategy_map, strategy_groups, cost_graph, alias_set, option_, sharding_propagation_solution); if (solver_result.skip_auto_sharding) { return AutoShardingResult::kModuleUnchangedNoShardingPerfomed; @@ -4442,7 +4442,7 @@ StatusOr AutoShardingImplementation::RunAutoSharding( } XLA_VLOG_LINES(5, PrintAutoShardingSolution(sequence, liveness_set, - strategy_map, leaf_strategies, + strategy_map, strategy_groups, cost_graph, s_val, objective)); XLA_VLOG_LINES(1, PrintSolutionMemoryUsage(liveness_set, strategy_map, cost_graph, s_val)); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index b7f2ea79aac87f..fa1a19c8d34f3f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -144,7 +144,7 @@ std::vector FollowInsCostVector(int64_t source_len, int64_t index); std::unique_ptr CreateLeafStrategyGroup( size_t instruction_id, const HloInstruction* ins, - const StrategyMap& strategy_map, LeafStrategies& leaf_strategies); + const StrategyMap& strategy_map, StrategyGroups& strategy_groups); void SetInNodesWithInstruction(std::unique_ptr& strategy_group, const HloInstruction* ins, @@ -159,14 +159,14 @@ Status FilterStrategy(const HloInstruction* ins, const Shape& shape, const AutoShardingOption& option); Status HandleDot(std::unique_ptr& strategy_group, - LeafStrategies& leaf_strategies, StrategyMap& strategy_map, + StrategyGroups& strategy_groups, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const 
AutoShardingOption& option, const CallGraph& call_graph); Status HandleConv(std::unique_ptr& strategy_group, - LeafStrategies& leaf_strategies, StrategyMap& strategy_map, + StrategyGroups& strategy_groups, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, @@ -192,7 +192,7 @@ AliasSet BuildAliasSet(const HloModule* module, const StrategyMap& strategy_map); void CheckAliasSetCompatibility(const AliasSet& alias_set, - const LeafStrategies& leaf_strategies, + const StrategyGroups& strategy_groups, const HloInstructionSequence& sequence); void GenerateReduceScatter( @@ -214,7 +214,7 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, AutoShardingSolverResult Solve( const HloLiveRange& hlo_live_range, const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, - const LeafStrategies& leaf_strategies, const CostGraph& cost_graph, + const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, const AutoShardingOption& option, const absl::flat_hash_map& sharding_propagation_solution = {}); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index c8b6611dd0d108..fd627355089688 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -37,14 +37,14 @@ namespace spmd { // It merges nodes and does path compression. class CostGraph { public: - CostGraph(const LeafStrategies& leaf_strategies, + CostGraph(const StrategyGroups& strategy_groups, const AssociativeDotPairs& associative_dot_pairs) { - node_lens_.reserve(leaf_strategies.size()); - extra_node_costs_.reserve(leaf_strategies.size()); - adjacency_.assign(leaf_strategies.size(), StableHashSet()); + node_lens_.reserve(strategy_groups.size()); + extra_node_costs_.reserve(strategy_groups.size()); + adjacency_.assign(strategy_groups.size(), StableHashSet()); // Build the cost graph - for (const auto& strategies : leaf_strategies) { + for (const auto& strategies : strategy_groups) { node_lens_.push_back(strategies->strategies.size()); extra_node_costs_.push_back( std::vector(strategies->strategies.size(), 0.0)); @@ -101,14 +101,14 @@ class CostGraph { Matrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]); for (NodeStrategyIdx i = 0; i < node_lens_[src_idx]; ++i) { - if (leaf_strategies[src_idx]->strategies[i].communication_cost > 0) { + if (strategy_groups[src_idx]->strategies[i].communication_cost > 0) { CHECK_LE( std::abs( - leaf_strategies[src_idx]->strategies[i].communication_cost - - leaf_strategies[dst_idx]->strategies[i].communication_cost), + strategy_groups[src_idx]->strategies[i].communication_cost - + strategy_groups[dst_idx]->strategies[i].communication_cost), 1e-6); edge_cost(i, i) = - -leaf_strategies[src_idx]->strategies[i].communication_cost; + -strategy_groups[src_idx]->strategies[i].communication_cost; } } AddEdgeCost(src_idx, dst_idx, edge_cost); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index ffc16f0868571a..159e3de1625a86 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc 
@@ -659,14 +659,14 @@ class DotHandler : public HandlerBase { // Register strategies for dot instructions. Status HandleDot(std::unique_ptr& strategy_group, - LeafStrategies& leaf_strategies, StrategyMap& strategy_map, + StrategyGroups& strategy_groups, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_groups); DotHandler handler(strategy_group, strategy_map, ins, cluster_env, batch_map, option, call_graph); @@ -853,14 +853,14 @@ class ConvHandler : public HandlerBase { // Register strategies for dot instructions. Status HandleConv(std::unique_ptr& strategy_group, - LeafStrategies& leaf_strategies, StrategyMap& strategy_map, + StrategyGroups& strategy_groups, StrategyMap& strategy_map, const HloInstruction* ins, size_t instruction_id, const ClusterEnvironment& cluster_env, const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, - leaf_strategies); + strategy_groups); ConvHandler handler(strategy_group, strategy_map, ins, cluster_env, batch_map, option, call_graph); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc index 027426f1331b84..546470f6ad7f25 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc @@ -33,12 +33,12 @@ namespace spmd { AutoShardingSolverResult Solve( const HloLiveRange& hlo_live_range, const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, - const LeafStrategies& leaf_strategies, const CostGraph& cost_graph, + const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, const AutoShardingOption& option, const absl::flat_hash_map& sharding_propagation_solution) { return CallSolver(hlo_live_range, liveness_node_set, strategy_map, - leaf_strategies, cost_graph, alias_set, /*s_hint*/ {}, + strategy_groups, cost_graph, alias_set, /*s_hint*/ {}, /*compute_iis*/ true, option.solver_timeout_in_seconds, option, sharding_propagation_solution); } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h index 4a8e479ef6ed9e..24fa988df57299 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h @@ -201,8 +201,8 @@ using LivenessNodeSet = std::vector>; // Map an instruction to its strategy group. using StrategyMap = StableHashMap>; -// The list of all leaf strategies. -using LeafStrategies = std::vector; +// The list of all strategy groups. +using StrategyGroups = std::vector; // The list of all dot instruction pairs that can be optimized by // AllReduceReassociate pass. 
using AssociativeDotPairs = diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 262fa0663c9723..88305332ff2f6a 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -1749,13 +1749,13 @@ AliasSet BuildAliasSet(const HloModule* module, } void CheckAliasSetCompatibility(const AliasSet& alias_set, - const LeafStrategies& leaf_strategies, + const StrategyGroups& strategy_groups, const HloInstructionSequence& sequence) { const std::vector& instructions = sequence.instructions(); // Checks the compatibility for (const auto& pair : alias_set) { - const StrategyGroup* src_strategy_group = leaf_strategies[pair.first]; - const StrategyGroup* dst_strategy_group = leaf_strategies[pair.second]; + const StrategyGroup* src_strategy_group = strategy_groups[pair.first]; + const StrategyGroup* dst_strategy_group = strategy_groups[pair.second]; size_t compatible_cnt = 0; bool replicated = false; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h index 374fb1b37f7e19..0308d636d6835f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h @@ -38,7 +38,7 @@ namespace spmd { AutoShardingSolverResult CallSolver( const HloLiveRange& hlo_live_range, const LivenessNodeSet& liveness_node_set, const StrategyMap& strategy_map, - const LeafStrategies& leaf_strategies, const CostGraph& cost_graph, + const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, const std::vector& s_hint, bool compute_iis, int64_t solver_timeout_in_seconds, const AutoShardingOption& option, From cd4080b947ef454a28b4cd2cd08aa9c1c9003083 Mon Sep 17 00:00:00 2001 From: Steven Toribio Date: Thu, 16 Nov 2023 11:32:48 -0800 Subject: [PATCH 182/391] gpu_delegate: Link nativewindow PiperOrigin-RevId: 583111445 --- tensorflow/lite/delegates/gpu/BUILD | 1 + .../lite/delegates/gpu/async_buffers_test.cc | 5 ++++- tensorflow/lite/delegates/gpu/build_defs.bzl | 21 +------------------ 3 files changed, 6 insertions(+), 21 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index f46efca9e8d007..098bfd034e9184 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -344,6 +344,7 @@ cc_test( "tflite_not_portable_ios", ], deps = [ + ":android_hardware_buffer", ":async_buffers", ":delegate", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/lite/delegates/gpu/async_buffers_test.cc b/tensorflow/lite/delegates/gpu/async_buffers_test.cc index 91e8d5212c1d2a..649c41f4be6797 100644 --- a/tensorflow/lite/delegates/gpu/async_buffers_test.cc +++ b/tensorflow/lite/delegates/gpu/async_buffers_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "tensorflow/lite/delegates/gpu/android_hardware_buffer.h" #include "tensorflow/lite/delegates/gpu/api.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" @@ -29,6 +30,7 @@ namespace { TEST(AsyncBufferTest, DuplicateTest) { if (__builtin_available(android 26, *)) { + auto Instance = OptionalAndroidHardwareBuffer::Instance; // Create tie TensorObjectDef* tie = new TensorObjectDef(); tie->object_def.data_type = DataType::FLOAT32; @@ -45,7 +47,8 @@ TEST(AsyncBufferTest, DuplicateTest) { AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER; AHardwareBuffer* ahwb; - EXPECT_EQ(AHardwareBuffer_allocate(&buffDesc, &ahwb), 0); + EXPECT_TRUE(Instance().IsSupported(&buffDesc)); + EXPECT_EQ(Instance().Allocate(&buffDesc, &ahwb), 0); // Init GL Env to properly use gl fcns std::unique_ptr env; diff --git a/tensorflow/lite/delegates/gpu/build_defs.bzl b/tensorflow/lite/delegates/gpu/build_defs.bzl index 9caa5f56743183..d98a201551176a 100644 --- a/tensorflow/lite/delegates/gpu/build_defs.bzl +++ b/tensorflow/lite/delegates/gpu/build_defs.bzl @@ -1,24 +1,5 @@ """Additional build options needed for the GPU Delegate.""" -# copybara:uncomment_begin(google-only) -# load("//third_party/android/ndk/platforms:grte_top.bzl", "min_supported_ndk_api") -# copybara:uncomment_end - -def nativewindow_linkopts(): - # copybara:uncomment_begin(google-only) - # return min_supported_ndk_api("26", ["-lnativewindow"]) - # copybara:uncomment_end - # copybara:comment_begin(oss-only) - return select({ - "//tensorflow:android": [ - # TODO: should only link against -lnativewindow - # if Android min supported NDK API Level is at least 26? - "-lnativewindow", - ], - "//conditions:default": [], - }) - # copybara:comment_end - def gpu_delegate_linkopts(): """Additional link options needed when linking in the GPU Delegate.""" return select({ @@ -31,7 +12,7 @@ def gpu_delegate_linkopts(): "-lGLESv2", ], "//conditions:default": [], - }) + nativewindow_linkopts() + }) def tflite_angle_heapcheck_deps(): # copybara:uncomment_begin(google-only) From fef4c9c056dc3bf00f14cf8c946eb50defc6c2ef Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 11:38:58 -0800 Subject: [PATCH 183/391] Temporarily install CMake on MacOS, for Python 3.12. It's currently required to build `dm-tree` for 3.12. PiperOrigin-RevId: 583113314 --- ci/official/utilities/setup_macos.sh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ci/official/utilities/setup_macos.sh b/ci/official/utilities/setup_macos.sh index ab76d22cfaeece..7f00dea7838197 100644 --- a/ci/official/utilities/setup_macos.sh +++ b/ci/official/utilities/setup_macos.sh @@ -54,4 +54,11 @@ elif [[ "${TFCI_WHL_BAZEL_TEST_ENABLE}" == 1 ]]; then echo 'run all the tests. Please make sure your system has enough disk space' echo 'You can control where Bazel stores test artifacts by setting the' echo '`TEST_TMPDIR` environment variable.' -fi \ No newline at end of file +fi + +if [[ "${TFCI_PYTHON_VERSION}" == "3.12" ]]; then + # dm-tree (Keras v3 dependency) doesn't have pre-built wheels for 3.12 yet. + # Having CMake allows building them. + # Once the wheels are added, this should be removed - b/308399490. 
+ sudo apt-get install -y --no-install-recommends cmake +fi From f7ccf984c6210585d30854937e551338e44d0755 Mon Sep 17 00:00:00 2001 From: Robert David Date: Thu, 16 Nov 2023 11:57:57 -0800 Subject: [PATCH 184/391] Refactor `SimpleMemoryArena`, extract code that deals with the resizable aligned buffer to a separate class. PiperOrigin-RevId: 583118803 --- tensorflow/lite/simple_memory_arena.cc | 121 +++++++++++++------------ tensorflow/lite/simple_memory_arena.h | 62 +++++++++---- 2 files changed, 109 insertions(+), 74 deletions(-) diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index 694215115297bc..9c6a596ed82d10 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -15,19 +15,18 @@ limitations under the License. #include "tensorflow/lite/simple_memory_arena.h" -#include -#include - #include #include #include #include #include #include +#include #include #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/macros.h" + #ifdef TF_LITE_TENSORFLOW_PROFILER #include "tensorflow/lite/tensorflow_profiler_logger.h" #endif // TF_LITE_TENSORFLOW_PROFILER @@ -44,6 +43,56 @@ T AlignTo(size_t alignment, T offset) { namespace tflite { +bool ResizableAlignedBuffer::Resize(size_t new_size) { + const size_t new_allocation_size = RequiredAllocationSize(new_size); + if (new_allocation_size <= allocation_size_) { + // Skip reallocation when resizing down. + return false; + } +#ifdef TF_LITE_TENSORFLOW_PROFILER + PauseHeapMonitoring(/*pause=*/true); + OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), + new_allocation_size); +#endif + auto new_buffer = std::unique_ptr(new char[new_allocation_size]); + char* new_aligned_ptr = reinterpret_cast( + AlignTo(alignment_, reinterpret_cast(new_buffer.get()))); + if (new_size > 0 && allocation_size_ > 0) { + // Copy data when both old and new buffers are bigger than 0 bytes. 
+ const size_t new_alloc_alignment_adjustment = + new_aligned_ptr - new_buffer.get(); + const size_t old_alloc_alignment_adjustment = aligned_ptr_ - buffer_.get(); + const size_t copy_amount = + std::min(allocation_size_ - old_alloc_alignment_adjustment, + new_allocation_size - new_alloc_alignment_adjustment); + std::memcpy(new_aligned_ptr, aligned_ptr_, copy_amount); + } + buffer_ = std::move(new_buffer); + aligned_ptr_ = new_aligned_ptr; +#ifdef TF_LITE_TENSORFLOW_PROFILER + if (allocation_size_ > 0) { + OnTfLiteArenaDealloc(subgraph_index_, + reinterpret_cast(this), + allocation_size_); + } +#endif + allocation_size_ = new_allocation_size; +#ifdef TF_LITE_TENSORFLOW_PROFILER + PauseHeapMonitoring(/*pause=*/false); +#endif + return true; +} + +void ResizableAlignedBuffer::Release() { +#ifdef TF_LITE_TENSORFLOW_PROFILER + OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), + allocation_size_); +#endif + buffer_.reset(); + allocation_size_ = 0; + aligned_ptr_ = nullptr; +} + void SimpleMemoryArena::PurgeAfter(int32_t node) { for (int i = 0; i < active_allocs_.size(); ++i) { if (active_allocs_[i].first_node > node) { @@ -91,7 +140,7 @@ TfLiteStatus SimpleMemoryArena::Allocate( TfLiteContext* context, size_t alignment, size_t size, int32_t tensor, int32_t first_node, int32_t last_node, ArenaAllocWithUsageInterval* new_alloc) { - TF_LITE_ENSURE(context, alignment <= arena_alignment_); + TF_LITE_ENSURE(context, alignment <= underlying_buffer_.GetAlignment()); new_alloc->tensor = tensor; new_alloc->first_node = first_node; new_alloc->last_node = last_node; @@ -142,48 +191,12 @@ TfLiteStatus SimpleMemoryArena::Allocate( } TfLiteStatus SimpleMemoryArena::Commit(bool* arena_reallocated) { - size_t required_size = RequiredBufferSize(); - if (required_size > underlying_buffer_size_) { - *arena_reallocated = true; -#ifdef TF_LITE_TENSORFLOW_PROFILER - PauseHeapMonitoring(/*pause=*/true); - OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), - required_size); -#endif - char* new_alloc = new char[required_size]; - char* new_underlying_buffer_aligned_ptr = reinterpret_cast( - AlignTo(arena_alignment_, reinterpret_cast(new_alloc))); - - // If the arena had been previously allocated, copy over the old memory. - // Since Alloc pointers are offset based, they will remain valid in the new - // memory block. - if (high_water_mark_ > 0 && underlying_buffer_size_ > 0) { - size_t copy_amount = std::min( - underlying_buffer_.get() + underlying_buffer_size_ - - underlying_buffer_aligned_ptr_, - new_alloc + required_size - new_underlying_buffer_aligned_ptr); - memcpy(new_underlying_buffer_aligned_ptr, underlying_buffer_aligned_ptr_, - copy_amount); - } - -#ifdef TF_LITE_TENSORFLOW_PROFILER - if (underlying_buffer_size_ > 0) { - OnTfLiteArenaDealloc(subgraph_index_, - reinterpret_cast(this), - underlying_buffer_size_); - } -#endif - underlying_buffer_.reset(new_alloc); - underlying_buffer_size_ = required_size; - underlying_buffer_aligned_ptr_ = new_underlying_buffer_aligned_ptr; -#ifdef TF_LITE_TENSORFLOW_PROFILER - PauseHeapMonitoring(/*pause=*/false); -#endif - } else { - *arena_reallocated = false; - } + // Resize the arena to the high water mark (calculated by Allocate), retaining + // old contents and alignment in the process. Since Alloc pointers are offset + // based, they will remain valid in the new memory block. + *arena_reallocated = underlying_buffer_.Resize(high_water_mark_); committed_ = true; - return underlying_buffer_ != nullptr ? 
kTfLiteOk : kTfLiteError; + return kTfLiteOk; } TfLiteStatus SimpleMemoryArena::ResolveAlloc( @@ -191,12 +204,12 @@ TfLiteStatus SimpleMemoryArena::ResolveAlloc( char** output_ptr) { TF_LITE_ENSURE(context, committed_); TF_LITE_ENSURE(context, output_ptr != nullptr); - TF_LITE_ENSURE(context, - underlying_buffer_size_ >= (alloc.offset + alloc.size)); + TF_LITE_ENSURE(context, underlying_buffer_.GetAllocationSize() >= + (alloc.offset + alloc.size)); if (alloc.size == 0) { *output_ptr = nullptr; } else { - *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; + *output_ptr = underlying_buffer_.GetPtr() + alloc.offset; } return kTfLiteOk; } @@ -210,13 +223,7 @@ TfLiteStatus SimpleMemoryArena::ClearPlan() { TfLiteStatus SimpleMemoryArena::ReleaseBuffer() { committed_ = false; -#ifdef TF_LITE_TENSORFLOW_PROFILER - OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), - underlying_buffer_size_); -#endif - underlying_buffer_size_ = 0; - underlying_buffer_aligned_ptr_ = nullptr; - underlying_buffer_.reset(); + underlying_buffer_.Release(); return kTfLiteOk; } @@ -228,8 +235,8 @@ TFLITE_ATTRIBUTE_WEAK void DumpArenaInfo( void SimpleMemoryArena::DumpDebugInfo( const std::string& name, const std::vector& execution_plan) const { - tflite::DumpArenaInfo(name, execution_plan, underlying_buffer_size_, - active_allocs_); + tflite::DumpArenaInfo(name, execution_plan, + underlying_buffer_.GetAllocationSize(), active_allocs_); } } // namespace tflite diff --git a/tensorflow/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h index 0e527df9ac98b1..05bb52e6a225e4 100644 --- a/tensorflow/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -55,6 +55,44 @@ struct ArenaAllocWithUsageInterval { } }; +class ResizableAlignedBuffer { + public: + explicit ResizableAlignedBuffer(size_t alignment, int subgraph_index) + : allocation_size_(0), + alignment_(alignment), + subgraph_index_(subgraph_index) { + // To silence unused private member warning, only used with + // TF_LITE_TENSORFLOW_PROFILER + (void)subgraph_index_; + } + + // Resizes the buffer to make sure new_size bytes fit in the buffer. Keeps + // alignment and any existing the data. Returns true when any external + // pointers into the data array need to be adjusted (the buffer was moved). + bool Resize(size_t new_size); + // Releases any allocated memory. + void Release(); + + // Pointer to the data array. + char* GetPtr() const { return aligned_ptr_; } + // Size of the allocation (NOT of the data array). + size_t GetAllocationSize() const { return allocation_size_; } + // Alignment of the data array. + size_t GetAlignment() const { return alignment_; } + + private: + size_t RequiredAllocationSize(size_t data_array_size) const { + return data_array_size + alignment_ - 1; + } + + std::unique_ptr buffer_; + size_t allocation_size_; + size_t alignment_; + char* aligned_ptr_; + + int subgraph_index_; +}; + // This small class is responsible for allocating, deallocating and reusing // dynamic memory from a common underlying buffer. 
The arena can be used in // scenarios when the pattern of memory allocations and deallocations is @@ -63,11 +101,9 @@ struct ArenaAllocWithUsageInterval { class SimpleMemoryArena { public: explicit SimpleMemoryArena(size_t arena_alignment, int subgraph_index = 0) - : subgraph_index_(subgraph_index), - committed_(false), - arena_alignment_(arena_alignment), + : committed_(false), high_water_mark_(0), - underlying_buffer_size_(0), + underlying_buffer_(arena_alignment, subgraph_index), active_allocs_() {} // Delete all allocs. This should be called when allocating the first node of @@ -99,10 +135,6 @@ class SimpleMemoryArena { int32_t tensor, int32_t first_node, int32_t last_node, ArenaAllocWithUsageInterval* new_alloc); - inline size_t RequiredBufferSize() { - return high_water_mark_ + arena_alignment_ - 1; - } - TfLiteStatus Commit(bool* arena_reallocated); TfLiteStatus ResolveAlloc(TfLiteContext* context, @@ -119,10 +151,12 @@ class SimpleMemoryArena { // again until Commit() is called & tensor allocations are resolved. TfLiteStatus ReleaseBuffer(); - size_t GetBufferSize() const { return underlying_buffer_size_; } + size_t GetBufferSize() const { + return underlying_buffer_.GetAllocationSize(); + } std::intptr_t BasePointer() const { - return reinterpret_cast(underlying_buffer_aligned_ptr_); + return reinterpret_cast(underlying_buffer_.GetPtr()); } // Dumps the memory allocation information of this memory arena (which could @@ -142,16 +176,10 @@ class SimpleMemoryArena { void DumpDebugInfo(const std::string& name, const std::vector& execution_plan) const; - protected: - int subgraph_index_; - private: bool committed_; - size_t arena_alignment_; size_t high_water_mark_; - std::unique_ptr underlying_buffer_; - size_t underlying_buffer_size_; - char* underlying_buffer_aligned_ptr_; + ResizableAlignedBuffer underlying_buffer_; std::vector active_allocs_; }; From 7dc97232ee5748de795f29aecc16b66f951523f9 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 16 Nov 2023 12:35:27 -0800 Subject: [PATCH 185/391] py::cast> forms a reference to a temporary inside the cast, lifting to a parameter doesn't have that problem. PiperOrigin-RevId: 583129785 --- third_party/xla/xla/python/xla_compiler.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc index d92c2ca020e6a5..2342ace4ca1c7c 100644 --- a/third_party/xla/xla/python/xla_compiler.cc +++ b/third_party/xla/xla/python/xla_compiler.cc @@ -288,9 +288,8 @@ void BuildXlaCompilerSubmodule(py::module& m) { // Shapes py::class_ layout_class(m, "Layout"); layout_class - .def(py::init([](py::object minor_to_major) { - return std::make_unique( - py::cast>(minor_to_major)); + .def(py::init([](absl::Span minor_to_major) { + return std::make_unique(minor_to_major); })) .def("minor_to_major", [](Layout layout) { return SpanToTuple(layout.minor_to_major()); }) From c9e5566952a7991d3dc67d618c5e8885748949ae Mon Sep 17 00:00:00 2001 From: Yilei Yang Date: Thu, 16 Nov 2023 12:38:46 -0800 Subject: [PATCH 186/391] Support assignment expressions / walrus operator in autograph. 
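For readers skimming the series, the next change teaches autograph's control-flow-graph builder to visit `gast.NamedExpr` nodes. Below is a minimal illustration of the construct being handled; it is a sketch for orientation only, not code taken from the change, and it assumes Python 3.8+ with gast 0.5.1+ as the patch itself notes.

    # An assignment expression ("walrus operator") binds a name and yields the
    # value in a single expression; the CFG builder below now processes the
    # resulting NamedExpr node like any other basic statement.
    def first_long_sequence(values):
      if (n := len(values)) > 10:
        return n
      return 0
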
PiperOrigin-RevId: 583130593 --- tensorflow/python/autograph/pyct/cfg.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py index fd8ddf046d29e9..3c4f0ac15919e6 100644 --- a/tensorflow/python/autograph/pyct/cfg.py +++ b/tensorflow/python/autograph/pyct/cfg.py @@ -780,6 +780,11 @@ def visit_ImportFrom(self, node): def visit_Expr(self, node): self._process_basic_statement(node) + def visit_NamedExpr(self, node): + # TODO(yileiyang): Add a test case once we have a newer astunparse version. + # NamedExpr was introduced in Python 3.8 and supported in gast 0.5.1+. + self._process_basic_statement(node) + def visit_Assign(self, node): self._process_basic_statement(node) From cd8bd49635a1d487ad60bb2291259153e21af868 Mon Sep 17 00:00:00 2001 From: Robert David Date: Thu, 16 Nov 2023 12:46:57 -0800 Subject: [PATCH 187/391] Use `malloc` instead of `new` to allocate buffers to reduce overhead needed to ensure alignment. PiperOrigin-RevId: 583132918 --- tensorflow/lite/simple_memory_arena.cc | 22 +++++++++++++--------- tensorflow/lite/simple_memory_arena.h | 16 ++++++++++------ 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index 9c6a596ed82d10..80f216072dada6 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -16,12 +16,12 @@ limitations under the License. #include "tensorflow/lite/simple_memory_arena.h" #include +#include #include +#include #include #include -#include #include -#include #include #include "tensorflow/lite/core/c/common.h" @@ -54,20 +54,20 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), new_allocation_size); #endif - auto new_buffer = std::unique_ptr(new char[new_allocation_size]); + char* new_buffer = reinterpret_cast(std::malloc(new_allocation_size)); char* new_aligned_ptr = reinterpret_cast( - AlignTo(alignment_, reinterpret_cast(new_buffer.get()))); + AlignTo(alignment_, reinterpret_cast(new_buffer))); if (new_size > 0 && allocation_size_ > 0) { // Copy data when both old and new buffers are bigger than 0 bytes. 
- const size_t new_alloc_alignment_adjustment = - new_aligned_ptr - new_buffer.get(); - const size_t old_alloc_alignment_adjustment = aligned_ptr_ - buffer_.get(); + const size_t new_alloc_alignment_adjustment = new_aligned_ptr - new_buffer; + const size_t old_alloc_alignment_adjustment = aligned_ptr_ - buffer_; const size_t copy_amount = std::min(allocation_size_ - old_alloc_alignment_adjustment, new_allocation_size - new_alloc_alignment_adjustment); std::memcpy(new_aligned_ptr, aligned_ptr_, copy_amount); } - buffer_ = std::move(new_buffer); + std::free(buffer_); + buffer_ = new_buffer; aligned_ptr_ = new_aligned_ptr; #ifdef TF_LITE_TENSORFLOW_PROFILER if (allocation_size_ > 0) { @@ -84,11 +84,15 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { } void ResizableAlignedBuffer::Release() { + if (buffer_ == nullptr) { + return; + } #ifdef TF_LITE_TENSORFLOW_PROFILER OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), allocation_size_); #endif - buffer_.reset(); + std::free(buffer_); + buffer_ = nullptr; allocation_size_ = 0; aligned_ptr_ = nullptr; } diff --git a/tensorflow/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h index 05bb52e6a225e4..87603a26c32e78 100644 --- a/tensorflow/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -15,10 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_SIMPLE_MEMORY_ARENA_H_ #define TENSORFLOW_LITE_SIMPLE_MEMORY_ARENA_H_ -#include - +#include +#include #include -#include #include #include @@ -58,7 +57,8 @@ struct ArenaAllocWithUsageInterval { class ResizableAlignedBuffer { public: explicit ResizableAlignedBuffer(size_t alignment, int subgraph_index) - : allocation_size_(0), + : buffer_(nullptr), + allocation_size_(0), alignment_(alignment), subgraph_index_(subgraph_index) { // To silence unused private member warning, only used with @@ -66,6 +66,8 @@ class ResizableAlignedBuffer { (void)subgraph_index_; } + ~ResizableAlignedBuffer() { Release(); } + // Resizes the buffer to make sure new_size bytes fit in the buffer. Keeps // alignment and any existing the data. Returns true when any external // pointers into the data array need to be adjusted (the buffer was moved). @@ -82,10 +84,12 @@ class ResizableAlignedBuffer { private: size_t RequiredAllocationSize(size_t data_array_size) const { - return data_array_size + alignment_ - 1; + // malloc guarantees returned pointers are aligned to at least max_align_t. + return data_array_size + + std::max(std::size_t{0}, alignment_ - alignof(std::max_align_t)); } - std::unique_ptr buffer_; + char* buffer_; size_t allocation_size_; size_t alignment_; char* aligned_ptr_; From 5c77c0046ed37d25538e1c5910496a0c7bbc4bbb Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Thu, 16 Nov 2023 13:04:00 -0800 Subject: [PATCH 188/391] Symlink `bazel-xla/external` to `external` in generate_compile_commands.py This is necessary for headers outside of XLA to be found in the include paths given by `bazel aquery`. 
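As a rough sketch of why the symlink matters (paths are assumed, not taken from the patch): `bazel aquery` reports include directories such as `external/...` relative to the workspace root, so tools that consume `compile_commands.json` need an `external` entry next to the sources. A later patch in this series makes the link conditional and corrects its direction; the standalone sketch below follows that final shape.

    from pathlib import Path

    xla_src_root = Path("/path/to/xla")  # assumed checkout location

    # Create <root>/external -> <root>/bazel-xla/external (only if missing) so
    # the `external/...` include paths emitted by `bazel aquery` resolve when
    # clangd or clang-tidy reads compile_commands.json.
    external = xla_src_root / "external"
    if not external.exists():
      external.symlink_to(xla_src_root / "bazel-xla" / "external")
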
PiperOrigin-RevId: 583138128 --- .../xla/build_tools/lint/generate_compile_commands.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/third_party/xla/build_tools/lint/generate_compile_commands.py b/third_party/xla/build_tools/lint/generate_compile_commands.py index c52c4cc93203be..702439e418b693 100644 --- a/third_party/xla/build_tools/lint/generate_compile_commands.py +++ b/third_party/xla/build_tools/lint/generate_compile_commands.py @@ -14,6 +14,12 @@ # ============================================================================ r"""Produces a `compile_commands.json` from the output of `bazel aquery`. +This tool requires that a build has been completed for all targets in the +query (e.g., for the example usage below `bazel build //xla/...`). This is due +to generated files like proto headers and files generated via tablegen. So if +LSP or other tools get out of date, it may be necessary to rebuild or regenerate +`compile_commands.json`, or both. + Example usage: bazel aquery "mnemonic(CppCompile, //xla/...)" --output=jsonproto | \ python3 build_tools/lint/generate_compile_commands.py @@ -99,6 +105,11 @@ def main(): logging.basicConfig() logging.getLogger().setLevel(logging.INFO) + # Setup external symlink so headers can be found in include paths + logging.info("Symlinking `xla/bazel-xla/external` to `xla/external`") + bazel_xla_external = _XLA_SRC_ROOT / "bazel-xla" / "external" + bazel_xla_external.symlink_to(_XLA_SRC_ROOT / "external") + logging.info("Reading `bazel aquery` output from stdin...") parsed_aquery_output = json.loads(sys.stdin.read()) From 9dee341d93cb3ca5a1474ff8c3c9baeaa1ba933d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 13:13:30 -0800 Subject: [PATCH 189/391] Update API Goldens with recent changes in Keras API PiperOrigin-RevId: 583141180 --- RELEASE.md | 3 +++ .../tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt | 2 +- .../golden/v2/tensorflow.optimizers.experimental.-adam.pbtxt | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 3e87a523116c95..f2456bfcb0b1e9 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -48,6 +48,9 @@ table maintained by the layer. If this layer is not used in conjunction with `UpdateEmbeddingCallback` the behavior of the layer would be same as `keras.layers.Embedding`. +* `keras.optimizers.Adam` + * Added the option to set adaptive epsilon to match implementations with Jax + and PyTorch equivalents. 
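A hedged usage sketch for the release note above: the keyword name `adaptive_epsilon` comes from the updated API goldens later in this change, while the call site and values here are illustrative assumptions rather than documented behavior.

    import tensorflow as tf

    # Opt in to the adaptive epsilon handling; leaving the flag at its default
    # (False per the goldens) keeps the existing fixed-epsilon Adam update.
    opt = tf.keras.optimizers.Adam(learning_rate=1e-3, adaptive_epsilon=True)
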
### Breaking Changes diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt index 675bb89d694de6..15cdd2e274e29b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt @@ -24,7 +24,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'amsgrad\', \'weight_decay\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'False\', \'None\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'Adam\'], " + argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'adaptive_epsilon\', \'amsgrad\', \'weight_decay\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'False\', \'False\', \'None\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'Adam\'], " } member_method { name: "add_variable" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam.pbtxt index d31bab3e3d8c7d..fb2ea437049b45 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam.pbtxt @@ -24,7 +24,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'amsgrad\', \'weight_decay\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'False\', \'None\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'Adam\'], " + argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'adaptive_epsilon\', \'amsgrad\', \'weight_decay\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'False\', \'False\', \'None\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'Adam\'], " } member_method { name: "add_variable" From 3097a30a93cdafdbc6475a06a862de6c53983a91 Mon Sep 17 00:00:00 2001 From: Richard Levasseur Date: Thu, 16 Nov 2023 13:54:27 -0800 Subject: [PATCH 190/391] Internal Code Change PiperOrigin-RevId: 583153780 --- tensorflow/python/tools/api/generator2/generate_api.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/tools/api/generator2/generate_api.bzl b/tensorflow/python/tools/api/generator2/generate_api.bzl index 64e9b96276eebe..c2a96438576d22 100644 --- a/tensorflow/python/tools/api/generator2/generate_api.bzl +++ b/tensorflow/python/tools/api/generator2/generate_api.bzl @@ -1,5 +1,6 @@ """Rules to generate the TensorFlow public API from annotated files.""" +# Placeholder: load PyInfo load("@bazel_skylib//lib:paths.bzl", "paths") 
load("//tensorflow/python/tools/api/generator:api_init_files.bzl", "TENSORFLOW_API_INIT_FILES") load(":apis.bzl", _APIS = "APIS") From 1c3bcc2dfcbb75ce31c7254f8f23691d9346a041 Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Thu, 16 Nov 2023 13:57:43 -0800 Subject: [PATCH 191/391] Enable TF-TPU wheels upload to PyPI PiperOrigin-RevId: 583154934 --- ci/official/envs/nightly_linux_x86_tpu_py310 | 2 +- ci/official/envs/nightly_linux_x86_tpu_py311 | 2 +- ci/official/envs/nightly_linux_x86_tpu_py312 | 2 +- ci/official/envs/nightly_linux_x86_tpu_py39 | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ci/official/envs/nightly_linux_x86_tpu_py310 b/ci/official/envs/nightly_linux_x86_tpu_py310 index 31f7a15bd3ca00..8331f324b6517f 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py310 +++ b/ci/official/envs/nightly_linux_x86_tpu_py310 @@ -1,5 +1,5 @@ # Disable tpu uploads while being worked on -source ci/official/envs/disable_all_uploads +source ci/official/envs/ci_nightly_uploads TFCI_BAZEL_COMMON_ARGS="--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=release_cpu_linux --config=tpu" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu TFCI_BUILD_PIP_PACKAGE_ARGS="--tpu --nightly_flag" diff --git a/ci/official/envs/nightly_linux_x86_tpu_py311 b/ci/official/envs/nightly_linux_x86_tpu_py311 index 4061b0330e6e93..9a93c1d5fda548 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py311 +++ b/ci/official/envs/nightly_linux_x86_tpu_py311 @@ -1,5 +1,5 @@ # Disable tpu uploads while being worked on -source ci/official/envs/disable_all_uploads +source ci/official/envs/ci_nightly_uploads TFCI_BAZEL_COMMON_ARGS="--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=release_cpu_linux --config=tpu" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu TFCI_BUILD_PIP_PACKAGE_ARGS="--tpu --nightly_flag" diff --git a/ci/official/envs/nightly_linux_x86_tpu_py312 b/ci/official/envs/nightly_linux_x86_tpu_py312 index 31afd3709e4b92..086c53046c9fd6 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py312 +++ b/ci/official/envs/nightly_linux_x86_tpu_py312 @@ -1,5 +1,5 @@ # Disable tpu uploads while being worked on -source ci/official/envs/disable_all_uploads +source ci/official/envs/ci_nightly_uploads TFCI_BAZEL_COMMON_ARGS="--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=release_cpu_linux --config=tpu" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu TFCI_BUILD_PIP_PACKAGE_ARGS="--tpu --nightly_flag" diff --git a/ci/official/envs/nightly_linux_x86_tpu_py39 b/ci/official/envs/nightly_linux_x86_tpu_py39 index 645eeed5827e6e..012206b18a7ca9 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py39 +++ b/ci/official/envs/nightly_linux_x86_tpu_py39 @@ -1,5 +1,5 @@ # Disable tpu uploads while being worked on -source ci/official/envs/disable_all_uploads +source ci/official/envs/ci_nightly_uploads TFCI_BAZEL_COMMON_ARGS="--config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=release_cpu_linux --config=tpu" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu TFCI_BUILD_PIP_PACKAGE_ARGS="--tpu --nightly_flag" From 787d11107e7ca420a9914dbe6392b3b1cbc01af0 Mon Sep 17 00:00:00 2001 From: Juhyun Lee Date: Thu, 16 Nov 2023 14:18:20 -0800 Subject: [PATCH 192/391] Improve readability in model_builder_help_test. 
PiperOrigin-RevId: 583162632 --- .../delegates/gpu/common/model_builder_helper_test.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_helper_test.cc index 0b6819333d6b39..f13bc539785467 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_helper_test.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper_test.cc @@ -25,18 +25,20 @@ namespace tflite { namespace gpu { namespace { +using ::testing::ElementsAre; + TEST(ModelBuilderHelperTest, CreateVectorCopyDataDifferentSize) { TfLiteTensor tflite_tensor; tflite_tensor.type = kTfLiteInt32; int32_t src_data[4] = {1, 2, 3, 4}; tflite_tensor.data.i32 = src_data; tflite_tensor.dims = TfLiteIntArrayCreate(1); - tflite_tensor.dims->data[0] = 4; - tflite_tensor.bytes = 4 * sizeof(int32_t); + tflite_tensor.dims->data[0] = sizeof(src_data) / sizeof(src_data[0]); + tflite_tensor.bytes = sizeof(src_data); int16_t dst[4]; ASSERT_OK(CreateVectorCopyData(tflite_tensor, dst)); - EXPECT_THAT(dst, testing::ElementsAre(1, 2, 3, 4)); + EXPECT_THAT(dst, ElementsAre(1, 2, 3, 4)); TfLiteIntArrayFree(tflite_tensor.dims); } From 74c016e8cd65e5693fa5c4ea6ba09ce3a1c58a19 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 14:25:10 -0800 Subject: [PATCH 193/391] Allow tensor dialect in stablehlo_quant_opt PiperOrigin-RevId: 583164946 --- tensorflow/compiler/mlir/quantization/stablehlo/BUILD | 1 + .../mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 15dce89f4aac99..501653f6bcdb2e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -543,6 +543,7 @@ tf_cc_binary( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", "@local_xla//xla/mlir_hlo:mhlo_passes", "@stablehlo//:stablehlo_ops", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc index 6883fbababb535..a55b1a88e5d964 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/InitAllPasses.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -53,7 +54,7 @@ int main(int argc, char** argv) { mlir::tf_saved_model::TensorFlowSavedModelDialect, mlir::func::FuncDialect, mlir::shape::ShapeDialect, mlir::arith::ArithDialect, mlir::tf_type::TFTypeDialect, - mlir::quant::QuantizationDialect, + mlir::quant::QuantizationDialect, mlir::tensor::TensorDialect, mlir::quantfork::QuantizationForkDialect, mlir::stablehlo::StablehloDialect, mlir::tf_executor::TensorFlowExecutorDialect>(); From 078f5a6635e9528cc1b2eb941bb1168fee50727e Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 16 Nov 2023 14:51:47 -0800 Subject: [PATCH 194/391] Rolling back as some internal tests are failing. PiperOrigin-RevId: 583173175 --- .../tensorflow/transforms/shape_inference.cc | 3 +- third_party/xla/xla/BUILD | 1 + third_party/xla/xla/service/BUILD | 1 + third_party/xla/xla/service/hlo_lexer.cc | 9 ++++- third_party/xla/xla/service/hlo_lexer.h | 15 ++++---- third_party/xla/xla/service/hlo_parser.cc | 15 +++++--- .../xla/xla/service/hlo_parser_test.cc | 11 ++++++ third_party/xla/xla/service/hlo_verifier.cc | 4 +++ third_party/xla/xla/service/hlo_verifier.h | 8 +++++ .../xla/xla/service/hlo_verifier_test.cc | 29 ++++++++++++++++ third_party/xla/xla/shape.cc | 10 ++++++ third_party/xla/xla/shape.h | 13 ++++++- third_party/xla/xla/shape_test.cc | 34 +++++++++++++++++-- third_party/xla/xla/shape_util.cc | 27 +++++++++++---- .../xla/xla/translate/hlo_to_mhlo/hlo_utils.h | 12 ++++--- .../translate/hlo_to_mhlo/tests/import.hlotxt | 9 +++++ .../translate/mhlo_to_hlo/tests/export.mlir | 14 ++++++++ .../translate/mhlo_to_hlo/type_to_shape.cc | 7 ++-- .../mhlo_to_hlo/type_to_shape_test.cc | 11 ++++-- 19 files changed, 199 insertions(+), 34 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index c458b2c6cd8725..afad8871b399ba 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -1896,7 +1896,8 @@ bool ShapeInference::InferShapeForXlaSelectAndScatterOp( bool ShapeInference::InferShapeForXlaGatherOp(XlaGatherOp op) { xla::Shape input_shape = xla::TypeToShape(op.getOperand().getType()); - if (input_shape == xla::Shape()) return false; + if (input_shape == xla::Shape() || input_shape.is_unbounded_dynamic()) + return false; xla::Shape start_indices_shape = xla::TypeToShape(op.getStartIndices().getType()); diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 24100c8dbcdc5d..caf77930363679 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -470,6 +470,7 @@ xla_cc_test( ":shape_util", ":test", ":xla_data_proto_cc", + "//xla:status", "@com_google_absl//absl/hash:hash_testing", "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 788e7c110ef4e4..6591a9797cee8f 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4894,6 +4894,7 @@ xla_cc_test( 
"@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/hlo_lexer.cc b/third_party/xla/xla/service/hlo_lexer.cc index 0e53e6f844e99f..bd516129caa7a4 100644 --- a/third_party/xla/xla/service/hlo_lexer.cc +++ b/third_party/xla/xla/service/hlo_lexer.cc @@ -96,6 +96,7 @@ TokKind HloLexer::LexToken() { token_state_.token_start = current_ptr_; int current_char = GetNextChar(); + TokKind tmp; switch (current_char) { default: // [a-zA-Z_] @@ -132,7 +133,11 @@ TokKind HloLexer::LexToken() { current_ptr_++; return TokKind::kArrow; } - return LexNumberOrPattern(); + tmp = LexNumberOrPattern(); + if (tmp == TokKind::kError && current_char == '?') { + return TokKind::kQuestionMark; + } + return tmp; case '=': return TokKind::kEqual; case '<': @@ -569,6 +574,8 @@ std::string TokKindToString(TokKind kind) { return "kColon"; case TokKind::kAsterisk: return "kAsterisk"; + case TokKind::kQuestionMark: + return "kQuestionMark"; case TokKind::kOctothorp: return "kOctothorp"; case TokKind::kPlus: diff --git a/third_party/xla/xla/service/hlo_lexer.h b/third_party/xla/xla/service/hlo_lexer.h index 031ec1ae295330..5681818c07162c 100644 --- a/third_party/xla/xla/service/hlo_lexer.h +++ b/third_party/xla/xla/service/hlo_lexer.h @@ -39,13 +39,14 @@ enum class TokKind { kError, // Tokens with no info. - kEqual, // = - kComma, // , - kColon, // : - kAsterisk, // * - kOctothorp, // # - kPlus, // + - kTilde, // ~ + kEqual, // = + kComma, // , + kColon, // : + kAsterisk, // * + kQuestionMark, // ? + kOctothorp, // # + kPlus, // + + kTilde, // ~ kLsquare, kRsquare, // [ ] kLbrace, diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc index d7105c6ddb5be1..223b246b9bbe9b 100644 --- a/third_party/xla/xla/service/hlo_parser.cc +++ b/third_party/xla/xla/service/hlo_parser.cc @@ -5396,6 +5396,7 @@ bool HloParserImpl::ParseParamList() { // dimension_sizes ::= '[' dimension_list ']' // dimension_list // ::= /*empty*/ +// ::= '?' // ::= <=? int64_t (',' param)* // param ::= name shape bool HloParserImpl::ParseDimensionSizes(std::vector* dimension_sizes, @@ -5403,12 +5404,18 @@ bool HloParserImpl::ParseDimensionSizes(std::vector* dimension_sizes, auto parse_and_add_item = [&]() { int64_t i; bool is_dynamic = false; - if (lexer_.GetKind() == TokKind::kLeq) { + if (lexer_.GetKind() == TokKind::kQuestionMark) { + i = Shape::kUnboundedSize; is_dynamic = true; lexer_.Lex(); - } - if (!ParseInt64(&i)) { - return false; + } else { + if (lexer_.GetKind() == TokKind::kLeq) { + is_dynamic = true; + lexer_.Lex(); + } + if (!ParseInt64(&i)) { + return false; + } } dimension_sizes->push_back(i); dynamic_dimensions->push_back(is_dynamic); diff --git a/third_party/xla/xla/service/hlo_parser_test.cc b/third_party/xla/xla/service/hlo_parser_test.cc index 3e069e3371d8e0..3f95e0461e4e1f 100644 --- a/third_party/xla/xla/service/hlo_parser_test.cc +++ b/third_party/xla/xla/service/hlo_parser_test.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_sharding.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/verified_hlo_module.h" #include "xla/window_util.h" @@ -4069,6 +4070,16 @@ TEST_F(HloParserTest, ParseShapeStringR2F32) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST_F(HloParserTest, ParseShapeStringUnbounded) { + std::string shape_string = "f32[?,784]"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = + ShapeUtil::MakeShape(F32, {Shape::kUnboundedSize, 784}, {true, false}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) { std::string shape_string = "(f32[1572864],s8[5120,1024])"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index ab4eb980047004..17bd5c3d1512ee 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -160,6 +160,10 @@ Status ShapeVerifier::Preprocess(HloInstruction* hlo) { if (arity) { TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity)); } + if (!opts_.allow_unbounded_dynamism && hlo->shape().is_unbounded_dynamic()) { + return InvalidArgument("Unbounded dynamism is disabled for instruction: %s", + hlo->ToString()); + } return OkStatus(); } diff --git a/third_party/xla/xla/service/hlo_verifier.h b/third_party/xla/xla/service/hlo_verifier.h index 813b3ba30d01b1..29744af10982bb 100644 --- a/third_party/xla/xla/service/hlo_verifier.h +++ b/third_party/xla/xla/service/hlo_verifier.h @@ -91,6 +91,11 @@ struct HloVerifierOpts { return std::move(*this); } + HloVerifierOpts&& WithAllowUnboundedDynamism(bool allow) { + allow_unbounded_dynamism = allow; + return std::move(*this); + } + bool IsLayoutSensitive() const { return layout_sensitive; } bool AllowMixedPrecision() const { return allow_mixed_precision; } @@ -131,6 +136,9 @@ struct HloVerifierOpts { // Whether bitcast should have the same size, including all paddings. bool allow_bitcast_to_have_different_size = false; + // Whether unbounded dynamic sizes should be allowed for shapes. + bool allow_unbounded_dynamism = false; + HloPredicate instruction_can_change_layout; // Returns a target-specific shape size. diff --git a/third_party/xla/xla/service/hlo_verifier_test.cc b/third_party/xla/xla/service/hlo_verifier_test.cc index fe36e7318ad956..d6faca7cc585ea 100644 --- a/third_party/xla/xla/service/hlo_verifier_test.cc +++ b/third_party/xla/xla/service/hlo_verifier_test.cc @@ -37,6 +37,7 @@ limitations under the License. 
#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -2933,5 +2934,33 @@ ENTRY entry { TF_ASSERT_OK(status); } +TEST_F(HloVerifierTest, UnboundedDynamism) { + const char* const hlo = R"( + HloModule Module + + ENTRY entry { + ROOT param0 = f32[?,784] parameter(0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.message(), HasSubstr("Unbounded dynamism is disabled")); +} + +TEST_F(HloVerifierTest, EnableUnboundedDynamism) { + const char* const hlo = R"( + HloModule Module + + ENTRY entry { + ROOT param0 = f32[?,784] parameter(0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + HloVerifier verifier{HloVerifierOpts{}.WithAllowUnboundedDynamism(true)}; + auto status = verifier.Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 1909291788aa92..0ad4897f320e9f 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -137,6 +137,16 @@ bool Shape::is_static() const { return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; }); } +bool Shape::is_unbounded_dynamic() const { + if (IsTuple() && absl::c_any_of(tuple_shapes_, [](const Shape& subshape) { + return subshape.is_unbounded_dynamic(); + })) { + return true; + } + return absl::c_any_of(dimensions_, + [](int64_t dim) { return dim == kUnboundedSize; }); +} + void Shape::DeleteDimension(int64_t dim_to_delete) { CHECK(IsArray()); CHECK_GE(dim_to_delete, 0); diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index 214b87a0f3b505..9386fc18043ac7 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_SHAPE_H_ #define XLA_SHAPE_H_ -#include +#include #include #include #include @@ -91,6 +91,17 @@ class Shape { bool is_dynamic() const { return !is_static(); } + // Unbounded dynamism. + // If `dimensions(axis) == kUnboundedSize && is_dynamic_dimension(axis)`, + // this means that the axis has unbounded dynamic size. + // The sentinel value for kUnboundedSize is chosen to be exactly the same + // as the sentinel value mlir::ShapedType::kDynamic. + static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); + + // Returns true if the shape has one or more dimensions with unbounded sizes. + // Tuple shapes are traversed recursively. + bool is_unbounded_dynamic() const; + // Returns true if the given dimension is dynamically-sized. 
bool is_dynamic_dimension(int dimension) const { return dynamic_dimensions_.at(dimension); diff --git a/third_party/xla/xla/shape_test.cc b/third_party/xla/xla/shape_test.cc index d691ee64b17079..322f02e4773f67 100644 --- a/third_party/xla/xla/shape_test.cc +++ b/third_party/xla/xla/shape_test.cc @@ -41,11 +41,14 @@ class ShapeTest : public ::testing::Test { ShapeUtil::MakeTupleShape({tuple_, matrix_, token_}); const Shape dynamic_matrix_ = ShapeUtil::MakeShape(S32, {5, 2}, {true, false}); + const Shape unbounded_ = + ShapeUtil::MakeShape(F32, {Shape::kUnboundedSize, 784}, {true, false}); }; TEST_F(ShapeTest, ShapeToFromProto) { - for (const Shape& shape : {opaque_, token_, scalar_, matrix_, matrix2_, - tuple_, nested_tuple_, dynamic_matrix_}) { + for (const Shape& shape : + {opaque_, token_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_, + dynamic_matrix_, unbounded_}) { Shape shape_copy(shape.ToProto()); EXPECT_TRUE(ShapeUtil::Equal(shape, shape_copy)) << shape << " != " << shape_copy; @@ -83,6 +86,8 @@ TEST_F(ShapeTest, DynamicShapeToString) { array_shape.set_dynamic_dimension(2, false); EXPECT_EQ("f32[<=23,44,55]", array_shape.ToString()); + + EXPECT_EQ("f32[?,784]", unbounded_.ToString()); } TEST_F(ShapeTest, EqualityTest) { @@ -120,6 +125,28 @@ TEST_F(ShapeTest, IsStatic) { ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2}) ->set_dynamic_dimension(1, true); EXPECT_FALSE(dynamic_tuple.is_static()); + + EXPECT_FALSE(unbounded_.is_static()); +} + +TEST_F(ShapeTest, IsDynamic) { + EXPECT_FALSE(matrix_.is_dynamic()); + EXPECT_FALSE(matrix_.is_unbounded_dynamic()); + + EXPECT_TRUE(dynamic_matrix_.is_dynamic()); + EXPECT_FALSE(dynamic_matrix_.is_unbounded_dynamic()); + + EXPECT_TRUE(unbounded_.is_dynamic()); + EXPECT_TRUE(unbounded_.is_unbounded_dynamic()); + + Shape unbounded_tuple = tuple_; + EXPECT_FALSE(unbounded_tuple.is_unbounded_dynamic()); + ShapeUtil::GetMutableSubshape(&unbounded_tuple, {2}) + ->set_dynamic_dimension(1, true); + EXPECT_FALSE(unbounded_tuple.is_unbounded_dynamic()); + ShapeUtil::GetMutableSubshape(&unbounded_tuple, {2}) + ->set_dimensions(1, Shape::kUnboundedSize); + EXPECT_TRUE(unbounded_tuple.is_unbounded_dynamic()); } TEST_F(ShapeTest, IsDynamicDimension) { @@ -133,6 +160,9 @@ TEST_F(ShapeTest, IsDynamicDimension) { ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2}) ->set_dynamic_dimension(1, true); EXPECT_FALSE(dynamic_tuple.is_static()); + + EXPECT_TRUE(unbounded_.is_dynamic_dimension(0)); + EXPECT_FALSE(unbounded_.is_dynamic_dimension(1)); } TEST_F(ShapeTest, ProgramShapeToFromProto) { diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 55882a59e7cbdc..004fd94ea5fcdc 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -248,14 +248,18 @@ Shape MakeTupleShapeImpl(absl::Span shapes) { const int ndims = dimensions.size(); auto layout = shape->mutable_layout(); auto* minor_to_major = layout->mutable_minor_to_major(); + auto is_unbounded_dynamic = absl::c_any_of( + dimensions, [](int64_t dim) { return dim == Shape::kUnboundedSize; }); for (int i = 0; i < ndims; i++) { const int64_t d = dimensions[i]; - if (d < 0) { + if (d < 0 && d != Shape::kUnboundedSize) { return false; } - dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d); - if (dense_shape_size < 0) { - return false; + if (!is_unbounded_dynamic) { + dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d); + if (dense_shape_size < 0) { + return false; + } } shape->add_dimensions(d); @@ -698,9 +702,14 @@ Shape 
ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { printer->Append("["); auto print_one = [&](int i) { if (shape.is_dynamic_dimension(i)) { - printer->Append("<="); + if (shape.dimensions(i) != Shape::kUnboundedSize) { + printer->Append(StrCat("<=", shape.dimensions(i))); + } else { + printer->Append("?"); + } + } else { + printer->Append(shape.dimensions(i)); } - printer->Append(shape.dimensions(i)); }; print_one(0); for (int i = 1, n = shape.dimensions_size(); i < n; ++i) { @@ -926,7 +935,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { for (int64_t i = 0; i < shape.rank(); ++i) { int64_t dimension = shape.dimensions(i); - if (dimension < 0) { + if (dimension < 0 && dimension != Shape::kUnboundedSize) { return InvalidArgument( "shape's dimensions must not be < 0; dimension at index %d was %d", i, dimension); @@ -944,6 +953,10 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return OkStatus(); } + if (shape.is_unbounded_dynamic()) { + return OkStatus(); + } + int64_t shape_size = [&]() { int64_t dense_shape_size = 1; if (shape.dimensions().empty()) { diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h index 42682a251385ae..275f59a43ebdba 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h @@ -59,22 +59,24 @@ static StatusOr ConvertTensorShapeToType(const Shape& xla_ty, ConvertPrimitiveTypeToMLIRType(xla_ty.element_type(), builder); if (!element_type_or.ok()) return element_type_or.status(); - bool is_dynamic = false; + bool is_bounded_dynamic = false; int64_t rank = xla_ty.rank(); llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); llvm::SmallVector bounds(rank, mlir::ShapedType::kDynamic); for (int64_t dim = 0; dim < rank; ++dim) { int64_t dim_size = xla_ty.dimensions(dim); if (xla_ty.is_dynamic_dimension(dim)) { - bounds[dim] = dim_size; - is_dynamic = true; + if (dim_size != Shape::kUnboundedSize) { + bounds[dim] = dim_size; + is_bounded_dynamic = true; + } } else { shape[dim] = dim_size; } } using mlir::mhlo::TypeExtensionsAttr; mlir::Attribute encoding; - if (is_dynamic) { + if (is_bounded_dynamic) { encoding = TypeExtensionsAttr::get(builder.getContext(), bounds); } @@ -89,7 +91,7 @@ static StatusOr ConvertTensorShapeToType(const Shape& xla_ty, if (xla_ty.has_layout()) { auto layout = xla_ty.layout(); if (LayoutUtil::IsSparse(layout)) { - if (is_dynamic) + if (is_bounded_dynamic) return Unimplemented( "MHLO doesn't support bounded dynamic shapes for sparse tensors"); llvm::SmallVector dlts; diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt index 6e8ef58022478a..8344c204ab4f89 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt +++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt @@ -1838,3 +1838,12 @@ add { %b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true} ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b) } + +// CHECK-LABEL: func.func private @unbounded(%arg0: tensor) -> tensor { +// CHECK-NEXT: [[VAL0:%.*]] = mhlo.abs %arg0 : tensor +// CHECK-NEXT: return [[VAL0]] : tensor +// CHECK-NEXT: } +%unbounded (Arg_0.1: f32[?,784]) -> f32[?,784] { + %Arg_0.1 = f32[?,784] parameter(0) + ROOT %abs.2 = f32[?,784] abs(f32[?,784] %Arg_0.1) +} diff --git 
a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir index 7cd4212027046d..9bcb03e1ea747a 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir @@ -3048,3 +3048,17 @@ func.func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3x func.func @main(%arg0: tensor {mhlo.parameter_replication = [true]}, %arg1: tuple, tuple>> {mhlo.parameter_replication = [false, true]}) -> tensor { return %arg0 : tensor } + +// ----- + +func.func @main(%operand: tensor) -> tensor { + %0 = mhlo.abs %operand : tensor + func.return %0 : tensor +} + +// CHECK: HloModule {{.*}}, entry_computation_layout={(f32[?,784]{1,0})->f32[?,784]{1,0}} +// CHECK-EMPTY: +// CHECK-NEXT: ENTRY {{.*}} ([[ARG0:.*]]: f32[?,784]) -> f32[?,784] { +// CHECK-NEXT: [[ARG0]] = f32[?,784] parameter(0) +// CHECK-NEXT: ROOT {{.*}} = f32[?,784] abs(f32[?,784] %Arg_0.1), {{.*}} +// CHECK-NEXT: } diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc b/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc index 8ccc406b756828..27fbcb2ad60e85 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc @@ -178,12 +178,11 @@ Shape TypeToShape(mlir::Type type) { llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); std::vector is_dynamic(rank, false); for (int64_t dim = 0; dim < rank; ++dim) { - // Only fully static shapes are supported. - // TODO(b/115638799): Update once xla::Shape can support dynamic shapes. int64_t size = t.getDimSize(dim); if (size == ShapedType::kDynamic) { - if (bounds[dim] == ShapedType::kDynamic) return {}; - shape[dim] = bounds[dim]; + shape[dim] = bounds[dim] != ShapedType::kDynamic + ? bounds[dim] + : Shape::kUnboundedSize; is_dynamic[dim] = true; } else { if (bounds[dim] != ShapedType::kDynamic) return {}; diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape_test.cc b/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape_test.cc index 37d82730cb881f..e38dbc355d0426 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape_test.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape_test.cc @@ -138,8 +138,15 @@ TEST(TypeToShapeTest, ConvertTensorTypeToTypes) { ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}, {true, false}) .ToProto())); - // Shape cannot represent dynamic shapes. - // TODO(b/115638799): Update once Shape can support dynamic shapes. 
+ EXPECT_THAT( + TypeToShape(RankedTensorType::get({mlir::ShapedType::kDynamic, 784}, + b.getF32Type())) + .ToProto(), + EqualsProto(ShapeUtil::MakeShape(PrimitiveType::F32, + {Shape::kUnboundedSize, 784}, + {true, false}) + .ToProto())); + EXPECT_THAT(TypeToShape(UnrankedTensorType::get(b.getF32Type())).ToProto(), EqualsProto(Shape().ToProto())); From aaf3af3b3e52d8b72780a63fa9a081456a0596a4 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Thu, 16 Nov 2023 15:24:42 -0800 Subject: [PATCH 195/391] Fix symlink in generate_compile_commands.py and only create symlink when necessary PiperOrigin-RevId: 583182954 --- .../xla/build_tools/lint/generate_compile_commands.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/xla/build_tools/lint/generate_compile_commands.py b/third_party/xla/build_tools/lint/generate_compile_commands.py index 702439e418b693..735fc53f8aa8a6 100644 --- a/third_party/xla/build_tools/lint/generate_compile_commands.py +++ b/third_party/xla/build_tools/lint/generate_compile_commands.py @@ -105,10 +105,10 @@ def main(): logging.basicConfig() logging.getLogger().setLevel(logging.INFO) - # Setup external symlink so headers can be found in include paths - logging.info("Symlinking `xla/bazel-xla/external` to `xla/external`") - bazel_xla_external = _XLA_SRC_ROOT / "bazel-xla" / "external" - bazel_xla_external.symlink_to(_XLA_SRC_ROOT / "external") + # Setup external symlink if necessary so headers can be found in include paths + if not (external := _XLA_SRC_ROOT / "external").exists(): + logging.info("Symlinking `xla/bazel-xla/external` to `xla/external`") + external.symlink_to(_XLA_SRC_ROOT / "bazel-xla" / "external") logging.info("Reading `bazel aquery` output from stdin...") parsed_aquery_output = json.loads(sys.stdin.read()) From 046bdd9efc312724b503d74e5a338c2ba9670404 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 15:29:30 -0800 Subject: [PATCH 196/391] Reports the model fingerprint & solution fingerprint to help diagnose issues with determinism. PiperOrigin-RevId: 583184433 --- .../xla/hlo/experimental/auto_sharding/BUILD | 1 + .../auto_sharding/auto_sharding_solver.cc | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index 42a95183639812..7dbfa6a114694d 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -102,6 +102,7 @@ cc_library( "@com_google_absl//absl/time", "@com_google_ortools//ortools/linear_solver", "@com_google_ortools//ortools/linear_solver:linear_solver_cc_proto", + "@local_tsl//tsl/platform:fingerprint", "@local_tsl//tsl/platform:hash", "@local_tsl//tsl/platform:types", ] + auto_sharding_solver_deps(), diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 80b2ce58957355..e86ca08613a246 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -39,6 +39,7 @@ limitations under the License. 
#include "absl/time/time.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/util.h" +#include "tsl/platform/fingerprint.h" #include "tsl/platform/hash.h" #include "tsl/platform/types.h" #include "ortools/linear_solver/linear_solver.h" @@ -555,8 +556,19 @@ AutoShardingSolverResult SolveAndExtractSolution( return AutoShardingSolverResult(absl::InternalError(err_msg), true); } + // Fingerprint the model & solution (useful when checking for determinism). + // We use TensorFlow's fingerprint library here, which differs from CP-SAT's. + operations_research::MPModelProto model_proto; + solver.ExportModelToProto(&model_proto); + uint64_t model_fprint = tsl::Fingerprint64(model_proto.SerializeAsString()); + operations_research::MPSolutionResponse response; + solver.FillSolutionResponseProto(&response); + uint64_t solution_fprint = tsl::Fingerprint64(response.SerializeAsString()); + LOG(INFO) << "Solver Status: " << status - << " Objective value: " << solver.Objective().Value(); + << " Objective value: " << solver.Objective().Value() + << " Model fingerprint: " << model_fprint + << " Solution fingerprint: " << solution_fprint; if (solver.Objective().Value() >= kInfinityCost) { LOG(WARNING) << "Objective (" << solver.Objective().Value() << ") is larger than kInfinityCost. It means the solver " @@ -566,13 +578,9 @@ AutoShardingSolverResult SolveAndExtractSolution( if (VLOG_IS_ON(10)) { // Print solver information for debugging. This hasn't been useful so far, // so leave it at VLOG level 10. - operations_research::MPModelProto model_proto; - solver.ExportModelToProto(&model_proto); VLOG(10) << "MODEL:"; XLA_VLOG_LINES(10, model_proto.DebugString()); VLOG(10) << "RESPONSE:"; - operations_research::MPSolutionResponse response; - solver.FillSolutionResponseProto(&response); XLA_VLOG_LINES(10, response.DebugString()); } From 508e03cc944fafc4518ebbbe38dc91d2ade20f01 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 15:30:50 -0800 Subject: [PATCH 197/391] Sorts the elements of the liveness (node) set to ensure the consistent creation of identical CP-SAT instances. PiperOrigin-RevId: 583184857 --- .../xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index aa23a2160babdb..3e5411a6a685f9 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -4416,6 +4416,7 @@ StatusOr AutoShardingImplementation::RunAutoSharding( strategy_group->GetSubStrategyGroup(index)->node_idx; if (node_idx >= 0) liveness_node_set[t].push_back(node_idx); } + std::sort(liveness_node_set[t].begin(), liveness_node_set[t].end()); } // ----- Call the ILP Solver ----- From f7b098a78825c77e028a3c970da466d9ddafcf8d Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 16 Nov 2023 15:46:59 -0800 Subject: [PATCH 198/391] [PJRT] Add utilities for attaching layout mode info to HLO module. 
PiperOrigin-RevId: 583189180 --- third_party/xla/xla/pjrt/utils.cc | 76 +++++++++++++++++++ third_party/xla/xla/pjrt/utils.h | 16 ++++ third_party/xla/xla/python/xla_client_test.py | 8 +- 3 files changed, 96 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/pjrt/utils.cc b/third_party/xla/xla/pjrt/utils.cc index 3967c91ffe002b..59d12ced99c16b 100644 --- a/third_party/xla/xla/pjrt/utils.cc +++ b/third_party/xla/xla/pjrt/utils.cc @@ -22,12 +22,15 @@ limitations under the License. #include #include #include +#include #include #include #include "absl/algorithm/container.h" #include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "absl/types/span.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -260,6 +263,79 @@ StatusOr> GetOutputLayoutModes(mlir::ModuleOp module) { return MlirAttrsToLayoutModes(main.getAllResultAttrs(), main.getNumResults()); } +// Make sure to choose delimiter that will never show up in Layout strings. +static const char* kLayoutModeDelimiter = ";"; + +static std::string GetFrontendAttr(absl::Span layout_modes) { + return absl::StrJoin(layout_modes, kLayoutModeDelimiter, + [](std::string* out, const LayoutMode& mode) { + absl::StrAppend(out, mode.ToString()); + }); +} + +Status AddLayoutModesToFrontendAttrs(mlir::ModuleOp module, + XlaComputation& xla_computation) { + TF_ASSIGN_OR_RETURN(std::vector arg_layout_modes, + GetArgLayoutModes(module)); + TF_ASSIGN_OR_RETURN(std::vector out_layout_modes, + GetOutputLayoutModes(module)); + + // Type is string->string proto map. Using auto here to deal with different + // build environments. + auto& frontend_attrs = *xla_computation.mutable_proto() + ->mutable_frontend_attributes() + ->mutable_map(); + frontend_attrs["arg_layout_modes"] = GetFrontendAttr(arg_layout_modes); + frontend_attrs["out_layout_modes"] = GetFrontendAttr(out_layout_modes); + return OkStatus(); +} + +static StatusOr> GetLayoutModesFromFrontendAttr( + absl::string_view attr) { + // SkipEmpty() needed to avoid returning the empty string when attr is empty. + std::vector str_modes = + absl::StrSplit(attr, kLayoutModeDelimiter, absl::SkipEmpty()); + std::vector result; + for (const std::string& str_mode : str_modes) { + TF_ASSIGN_OR_RETURN(LayoutMode mode, LayoutMode::FromString(str_mode)); + result.emplace_back(std::move(mode)); + } + return result; +} + +static StatusOr> GetLayoutModes( + const XlaComputation& computation, absl::string_view frontend_attr_name, + size_t num_values) { + const auto& frontend_attrs = computation.proto().frontend_attributes().map(); + auto iter = frontend_attrs.find(frontend_attr_name); + if (iter == frontend_attrs.end()) { + // Return all default layouts if frontend attr isn't present. + return std::vector(num_values); + } + return GetLayoutModesFromFrontendAttr(iter->second); +} + +StatusOr> GetArgLayoutModes( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation.GetProgramShape()); + size_t num_args = program_shape.parameters_size() == 1 && + program_shape.parameters(0).IsTuple() + ? 
program_shape.parameters(0).tuple_shapes_size() + : program_shape.parameters_size(); + return GetLayoutModes(computation, "arg_layout_modes", num_args); +} + +StatusOr> GetOutputLayoutModes( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation.GetProgramShape()); + size_t num_outputs = program_shape.result().IsTuple() + ? program_shape.result().tuple_shapes_size() + : 1; + return GetLayoutModes(computation, "out_layout_modes", num_outputs); +} + static StatusOr LayoutModeToXlaShape( const LayoutMode& layout_mode, const Shape& unsharded_shape, const Shape& sharded_shape, diff --git a/third_party/xla/xla/pjrt/utils.h b/third_party/xla/xla/pjrt/utils.h index ae6129b1e94adb..7c423afc6dafb6 100644 --- a/third_party/xla/xla/pjrt/utils.h +++ b/third_party/xla/xla/pjrt/utils.h @@ -55,6 +55,22 @@ StatusOr> GetArgLayoutModes(mlir::ModuleOp module); // LayoutMode::Mode::kDefault. StatusOr> GetOutputLayoutModes(mlir::ModuleOp module); +// Populates the frontend attributes "arg_layout_mode" and "out_layout_mode" in +// xla_computation based on `module`. This function must be called before the +// LayoutMode getters below work correctly on `computation`. +Status AddLayoutModesToFrontendAttrs(mlir::ModuleOp module, + XlaComputation& xla_computation); +// Returns the LayoutMode for each argument of the computations. Checks for the +// "arg_layout_mode" frontend attribute, and if not present, assumes +// LayoutMode::Mode::kDefault. +StatusOr> GetArgLayoutModes( + const XlaComputation& computation); +// Returns the LayoutMode for each argument of the computations. Checks for the +// "out_layout_mode" frontend attribute, and if not present, assumes +// LayoutMode::Mode::kDefault. +StatusOr> GetOutputLayoutModes( + const XlaComputation& computation); + // Returns (arg shapes, output shape) with properly-set Layouts that can // be passed to XLA to reflect arg_layout_modes and out_layout_modes. 
StatusOr, Shape>> LayoutModesToXlaShapes( diff --git a/third_party/xla/xla/python/xla_client_test.py b/third_party/xla/xla/python/xla_client_test.py index 2708311aaab9d6..4bb0c60f97c608 100644 --- a/third_party/xla/xla/python/xla_client_test.py +++ b/third_party/xla/xla/python/xla_client_test.py @@ -599,7 +599,7 @@ def testGetOutputLayouts(self): self.assertEmpty(layouts[1].minor_to_major()) self.assertLen(layouts[2].minor_to_major(), 1) - @unittest.skipIf(pathways or pathways_ifrt, "not implemented") + @unittest.skipIf(pathways, "not implemented") def testSetArgumentLayouts(self): # TODO(b/309682374): implement on CPU and GPU if self.backend.platform != "tpu": @@ -688,7 +688,7 @@ def MakeArg(shape, dtype, layout): self.assertEqual(actual.minor_to_major(), expected.layout().minor_to_major()) - @unittest.skipIf(pathways or pathways_ifrt, "not implemented") + @unittest.skipIf(pathways, "not implemented") def testSetOutputLayouts(self): # TODO(b/309682374): implement on CPU and GPU if self.backend.platform != "tpu": @@ -787,7 +787,7 @@ def SetLayoutsSharded(self): output_layouts[0].minor_to_major(), default_executable.get_output_layouts()[0].minor_to_major()) - @unittest.skipIf(pathways or pathways_ifrt, "not implemented") + @unittest.skipIf(pathways, "not implemented") def testAutoArgumentLayouts(self): # TODO(b/309682374): implement on CPU and GPU if self.backend.platform != "tpu": @@ -832,7 +832,7 @@ def testAutoArgumentLayouts(self): default_executable.get_parameter_layouts()[1].minor_to_major(), ) - @unittest.skipIf(pathways or pathways_ifrt, "not implemented") + @unittest.skipIf(pathways, "not implemented") def testAutoOutputLayouts(self): # TODO(b/309682374): implement on CPU and GPU if self.backend.platform != "tpu": From 6b33f2ad070d74c026365e13fa791eecc942be75 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Thu, 16 Nov 2023 15:57:44 -0800 Subject: [PATCH 199/391] Cleanup: Remove unused `_get_min_max_from_calibrator`. It has been previously migrated to `py_function_lib`. PiperOrigin-RevId: 583191898 --- .../tensorflow/python/quantize_model.py | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index 125f51e0d0b9e4..1aaf3d9aa62bab 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -21,9 +21,6 @@ from tensorflow.compiler.mlir.quantization.tensorflow import exported_model_pb2 from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2 -from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_algorithm -from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 as calib_stats_pb2 -from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import pywrap_calibration from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_quantize_model from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset @@ -137,31 +134,6 @@ def _run_static_range_qat( ) -def _get_min_max_from_calibrator( - node_id: bytes, - calib_opts: quant_opts_pb2.CalibrationOptions, -) -> tuple[float, float]: - """Calculate min and max from statistics using calibration options. 
- - Args: - node_id: bytes of node id. - calib_opts: Calibration options used for calculating min and max. - - Returns: - (min_value, max_value): Min and max calculated using calib_opts. - - Raises: - ValueError: Unsupported calibration method is given. - """ - statistics: calib_stats_pb2.CalibrationStatistics = ( - pywrap_calibration.get_statistics_from_calibrator(node_id) - ) - min_value, max_value = calibration_algorithm.get_min_max_value( - statistics, calib_opts - ) - return min_value, max_value - - def _enable_dump_tensor(graph_def: graph_pb2.GraphDef) -> None: """Enable DumpTensor in the graph def. From 9561130d6bc69369d69014c5905f4a7fc37f21ae Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Thu, 16 Nov 2023 16:00:43 -0800 Subject: [PATCH 200/391] Rollback https://github.com/openxla/xla/commit/15177468f898efe4b27a1ba7d79ccc4557fd48d8. We can now set the CUBLASLT_MATMUL_DESC_FAST_ACCUM flag because Flax and Praxis now set the fast accumulation flag. So this rollbacks the change the disabled setting the flag. Also a comment is added. PiperOrigin-RevId: 583192694 --- third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc index a8bc2b00d4c2bc..36f01e48ab8bf0 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc @@ -211,9 +211,11 @@ cudaDataType_t BlasLt::MatrixLayout::type() const { AsCublasOperation(trans_b))); TF_ASSIGN_OR_RETURN(cublasLtEpilogue_t epi, AsCublasLtEpilogue(epilogue)); TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, epi)); - // TODO(b/259609697): Set the CUBLASLT_MATMUL_DESC_FAST_ACCUM attribute if - // enable_fast_accum is true, once Flax/Praxis properly pass a PrecisionConfig - // of HIGH or HIGHEST on the backwards pass. + // The CUBLASLT_MATMUL_DESC_FAST_ACCUM flag only impacts FP8 gemms. It speeds + // up gemms at the expense of accumulation precision. In practice, it is safe + // to set on the forward pass but not the backward pass. + TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, + static_cast(enable_fast_accum))); return std::move(desc); } From 7eca2971aa9c6be6490e7c619cfe27b76d1569f5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 16:04:37 -0800 Subject: [PATCH 201/391] Re-enable layering_check for target. 
PiperOrigin-RevId: 583193828 --- third_party/xla/third_party/tsl/tsl/platform/default/BUILD | 1 - .../xla/third_party/tsl/tsl/platform/profile_utils/BUILD | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD index e56abd66607093..428af13748f2cb 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD @@ -362,7 +362,6 @@ cc_library( "//tsl:with_numa_support": ["TENSORFLOW_USE_NUMA"], "//conditions:default": [], }), - features = ["-layering_check"], tags = [ "manual", "no_oss", diff --git a/third_party/xla/third_party/tsl/tsl/platform/profile_utils/BUILD b/third_party/xla/third_party/tsl/tsl/platform/profile_utils/BUILD index 1c04a4d2d4a1a7..1c5558dbb9faeb 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/profile_utils/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/profile_utils/BUILD @@ -44,10 +44,10 @@ cc_library( srcs = [ "android_armv7a_cpu_utils_helper.h", "cpu_utils.cc", - "i_cpu_utils_helper.h", ], hdrs = [ "cpu_utils.h", + "i_cpu_utils_helper.h", ], copts = tsl_copts(), visibility = ["//visibility:public"], From a260900a774404c7392b5bc4fba73416fbd27079 Mon Sep 17 00:00:00 2001 From: Robert David Date: Thu, 16 Nov 2023 16:59:23 -0800 Subject: [PATCH 202/391] Store implicitly if the nonpersistent arena has allocated memory. PiperOrigin-RevId: 583207643 --- tensorflow/lite/arena_planner.cc | 6 +++++- tensorflow/lite/arena_planner.h | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/arena_planner.cc b/tensorflow/lite/arena_planner.cc index e3f74acf06cc7f..8fd1a794369b50 100644 --- a/tensorflow/lite/arena_planner.cc +++ b/tensorflow/lite/arena_planner.cc @@ -41,6 +41,7 @@ ArenaPlanner::ArenaPlanner(TfLiteContext* context, : context_(context), graph_info_(std::move(graph_info)), arena_(kDefaultArenaAlignment, subgraph_index), + has_nonpersistent_memory_(false), persistent_arena_(kDefaultArenaAlignment, subgraph_index), preserve_all_tensors_(preserve_all_tensors), tensor_alignment_(tensor_alignment), @@ -379,6 +380,7 @@ TfLiteStatus ArenaPlanner::ExecuteAllocations(int first_node, int last_node) { TfLiteStatus ArenaPlanner::ReleaseNonPersistentMemory() { // Clear non-persistent arena's buffer. TF_LITE_ENSURE_STATUS(arena_.ReleaseBuffer()); + has_nonpersistent_memory_ = false; // Set data pointers for all non-persistent tensors to nullptr. TfLiteTensor* tensors = graph_info_->tensors(); for (int i = 0; i < static_cast(graph_info_->num_tensors()); ++i) { @@ -394,6 +396,7 @@ TfLiteStatus ArenaPlanner::AcquireNonPersistentMemory() { // First commit arena_ to allocate underlying buffer. bool reallocated; TF_LITE_ENSURE_STATUS(arena_.Commit(&reallocated)); + has_nonpersistent_memory_ = true; // Resolve allocations for all tensors not on the persistent arena. 
TfLiteTensor* tensors = graph_info_->tensors(); for (int i = 0; i < static_cast(graph_info_->num_tensors()); ++i) { @@ -406,7 +409,7 @@ TfLiteStatus ArenaPlanner::AcquireNonPersistentMemory() { } bool ArenaPlanner::HasNonPersistentMemory() { - return arena_.GetBufferSize() != 0; + return has_nonpersistent_memory_; } void ArenaPlanner::DumpDebugInfo(const std::vector& execution_plan) const { @@ -424,6 +427,7 @@ void ArenaPlanner::GetAllocInfo(size_t* arena_size, TfLiteStatus ArenaPlanner::Commit(bool* reallocated) { bool arena_reallocated, persistent_arena_reallocated; TF_LITE_ENSURE_STATUS(arena_.Commit(&arena_reallocated)); + has_nonpersistent_memory_ = true; TF_LITE_ENSURE_STATUS( persistent_arena_.Commit(&persistent_arena_reallocated)); *reallocated = arena_reallocated; diff --git a/tensorflow/lite/arena_planner.h b/tensorflow/lite/arena_planner.h index f8547c352a8fc5..f4644d15986fab 100644 --- a/tensorflow/lite/arena_planner.h +++ b/tensorflow/lite/arena_planner.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_ARENA_PLANNER_H_ #define TENSORFLOW_LITE_ARENA_PLANNER_H_ +#include #include #include #include @@ -30,7 +31,6 @@ limitations under the License. namespace tflite { constexpr const int kDefaultArenaAlignment = 64; -struct AllocationInfo; // A memory planner that makes all the allocations using arenas. // @@ -141,6 +141,8 @@ class ArenaPlanner : public MemoryPlanner { // Raw memory buffer that is allocated for all temporary and graph outputs // that are declared kTfLiteArenaRw. SimpleMemoryArena arena_; + // True when the arena_ has allocated memory (Commit was called). + bool has_nonpersistent_memory_; // Raw memory buffer that is allocated for persistent tensors that are // declared as kTfLiteArenaRwPersistent. From 83862362f3d494452e0e7c21a35f9ce5b874c169 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Thu, 16 Nov 2023 17:00:34 -0800 Subject: [PATCH 203/391] [TF:PLUGIN] Fix CPluginOpKernelConstruction::GetInt32AttrList to fill up return vector properly. PiperOrigin-RevId: 583207934 --- .../common_runtime/next_pluggable_device/c_plugin_op_kernel.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc index a266fe7bcf8f3a..109d9ed62b95b2 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc @@ -83,7 +83,7 @@ Status CPluginOpKernelConstruction::GetInt32AttrList( &total_size, status); TF_RETURN_IF_ERROR(StatusFromTF_Status(status)); - value->reserve(list_size); + value->resize(list_size); TF_OpKernelConstruction_GetAttrInt32List( ctx_, attr_name.data(), value->data(), /*max_vals=*/list_size, status); From f12c0699eb94d97def21b3b169f862fa716dd203 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 17:13:51 -0800 Subject: [PATCH 204/391] Improve StableHLO Graph Partitioning Algorithm for Quantizer * Add algorithm to trace defining ops of the subgraph's operands and stop when it hits a branch. * Add procedure to duplicate small constants that may be used for shape inference. Some constants may be needed for shape inference for multiple subgraphs. This allows the constants to be partitioned into each subgraphs. * Use SetVector instead of DenseSet so that traversal order is deterministic. 
PiperOrigin-RevId: 583211319 --- ..._main_function_with_xla_call_module_ops.cc | 122 ++++++++++++++---- ...ain_function_with_xla_call_module_ops.mlir | 94 ++++++++++++++ .../tensorflow/quantize_passes.cc | 3 + 3 files changed, 192 insertions(+), 27 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc index 829831224a1b35..79adb081e7db3a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc @@ -16,11 +16,13 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project @@ -29,6 +31,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" @@ -131,7 +134,7 @@ class LiveOuts { // Delete the current op from liveouts and moves on to the parent ops. void update(Operation& op) { for (Value result_value : op.getResults()) { - liveouts_.erase(result_value); + liveouts_.remove(result_value); } for (Value operand : op.getOperands()) { liveouts_.insert(operand); @@ -142,14 +145,15 @@ class LiveOuts { void snapshot_previous_state() { prev_liveouts_ = liveouts_; } // Return the current live values. - const DenseSet& get() const { return liveouts_; } + const SetVector& get() const { return liveouts_; } // Return the previous live values. - const DenseSet& get_previous() const { return prev_liveouts_; } + const SetVector& get_previous() const { return prev_liveouts_; } private: - DenseSet liveouts_; - DenseSet prev_liveouts_; + // Use SerVector to ensure deterministic traversal order. + SetVector liveouts_; + SetVector prev_liveouts_; }; // Creates the tf.XlaCallModuleOp from attributes. @@ -258,18 +262,18 @@ void ReplaceStablehloOpsWithXlaCallModuleOp( // Contains the actual logic for updating states and replacing StableHLO ops // with tf.XlaCallModuleOps. 
void UpdateStatesAndReplaceStablehloOps( - const DenseSet& operands, const DenseSet& defined_values, + const SetVector& operands, const SetVector& defined_values, const LiveOuts& liveouts, ModuleOp module_op, ArrayRef reverse_subgraph, const int stablehlo_func_id, func::FuncOp main_func, const bool is_last_subgraph = false) { - DenseSet inputs = operands; + SetVector inputs = operands; for (Value defined_value : defined_values) { - inputs.erase(defined_value); + inputs.remove(defined_value); } - DenseSet outputs = liveouts.get_previous(); + SetVector outputs = liveouts.get_previous(); for (Value live_value : liveouts.get()) { - outputs.erase(live_value); + outputs.remove(live_value); } if (is_last_subgraph) { @@ -277,7 +281,7 @@ void UpdateStatesAndReplaceStablehloOps( // throughout (functions as an invisible op above the very first op that // returns the arguments). for (const BlockArgument arg : main_func.getArguments()) { - outputs.erase(arg); + outputs.remove(arg); } } @@ -305,20 +309,65 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOps( // statement is not included in any subgraph (e.g. XlaCallModuleOp) and is // untouched. SmallVector reverse_main_func_block_ops; + SetVector ops_to_add; for (Operation& main_func_block_op : llvm::reverse(main_func_block.without_terminator())) { reverse_main_func_block_ops.push_back(&main_func_block_op); + ops_to_add.insert(&main_func_block_op); } // Create a separate subgraph invoked with XlaCallModuleOp per each // set of StableHLO ops in the main func block. SmallVector reverse_subgraph; - DenseSet operands; - DenseSet defined_values; + SetVector operands; + SetVector defined_values; + + // Add op to the subgraph. + auto add_to_subgraph = [&](Operation* op) { + // Move on to the parent ops. + liveouts.update(*op); + ops_to_add.remove(op); + + if (!IsStablehloOp(op)) { + // Always update the liveouts when the subgraph isn't being continued. + liveouts.snapshot_previous_state(); + return; + } + + reverse_subgraph.push_back(op); + defined_values.insert(op->getResults().begin(), op->getResults().end()); + operands.insert(op->getOperands().begin(), op->getOperands().end()); + }; int stablehlo_func_id = -1; for (Operation* op : reverse_main_func_block_ops) { + if (!ops_to_add.contains(op)) continue; + // When hitting a non-StableHLO op, i.e. tf.CustomAggregatorOp, start + // recursively tracing defining ops of the current subgraph's operands. This + // makes sure that all dependencies needed for shape inference are included + // in the subgraph. Tracing stops when hitting a non-StableHLO ops or an op + // with multiple uses. In case of the latter scenario, we have to stop + // because otherwise other users of the op will become dangling references. + // TODO: b/311239049 - Consider rewrite this using BFS. if (!IsStablehloOp(op)) { + bool should_add_op = true; + while (should_add_op) { + should_add_op = false; + Operation* defining_op = nullptr; + for (Value v : operands) { + if (defined_values.contains(v)) continue; + // Check if op has branch and skip if so. + if (v.getDefiningOp() && IsStablehloOp(v.getDefiningOp()) && + v.getDefiningOp()->hasOneUse()) { + defining_op = v.getDefiningOp(); + should_add_op = true; + break; + } + } + if (should_add_op) { + add_to_subgraph(defining_op); + } + } // Create an XlaCallModuleOp if reverse_subgraph isn't empty. 
if (!reverse_subgraph.empty()) { UpdateStatesAndReplaceStablehloOps(operands, defined_values, liveouts, @@ -331,20 +380,7 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOps( defined_values.clear(); } } - - // Move on to the parent ops. - liveouts.update(*op); - - if (!IsStablehloOp(op)) { - // Always update the liveouts when the subgraph isn't being continued. - liveouts.snapshot_previous_state(); - continue; - } - - reverse_subgraph.push_back(op); - - defined_values.insert(op->getResults().begin(), op->getResults().end()); - operands.insert(op->getOperands().begin(), op->getOperands().end()); + add_to_subgraph(op); } // Create the last subgraph if it isn't empty. @@ -355,6 +391,37 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOps( } } +// Duplicate small constants for each use. +// +// In the subsequent graph partitioning, constants for shape inference need to +// be in the same subgraph. But graph partitioning stops at ops with multiple +// uses. So here we duplicate small constants for each use so that if a +// constant is useful for shape inference for multiple subgraphs, they can be +// included in each subgraphs. If duplicate constants are accidentally created +// in the same subgraph, they can be easily removed with a canonicalizer pass. +// +// We set a size limit since constants needed for shape inference are no +// larger than tensor rank. This avoids duplicating large constants. +void DuplicateSmallConstantOps(ModuleOp module_op, func::FuncOp main_func) { + OpBuilder builder(main_func.getContext()); + for (auto constant_op : + main_func.getBody().getOps()) { + builder.setInsertionPointAfter(constant_op); + if (constant_op.getResult().use_empty() || + constant_op.getResult().hasOneUse()) + continue; + // Do not duplicate constant op if the size is too large. + // 32 is chosen to be larger than all constants useful for shape references, + // while not too large to possibly significantly increase model size. + if (constant_op.getValue().getNumElements() > 32) continue; + while (!constant_op.getResult().hasOneUse()) { + auto new_constant_op = builder.clone(*constant_op.getOperation()); + constant_op.getResult().getUses().begin()->assign( + dyn_cast(new_constant_op)); + } + } +} + void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass:: runOnOperation() { ModuleOp module_op = getOperation(); @@ -362,6 +429,7 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass:: func::FuncOp main_func = GetMainFunc(module_op); if (!main_func) return; + DuplicateSmallConstantOps(module_op, main_func); ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOps(module_op, main_func); // TODO - b/298966126: Currently quantizable functions are identified in TF diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir index 60454fcde06389..036f0709611bf2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir @@ -163,3 +163,97 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p return %0 : tensor<1x3xf32> } } + +// ----- + +// Tests where StableHLO graph in main has a small constant to be duplicated. 
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func private @_stablehlo_main_1() -> tensor<1024x3xf32> attributes {_from_xla_call_module} + // CHECK: %[[CONSTANT1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT1:.*]] + // CHECK: } + + // CHECK: func private @_stablehlo_main_0 + // CHECK-SAME: %[[INPUT1:.*]]: tensor<1024x3xf32>, %[[INPUT2:.*]]: tensor<1024x3xf32> + // CHECK: %[[CONSTANT2:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[ADD:.*]] = stablehlo.add %[[INPUT1]], %[[CONSTANT2]] : tensor<1024x3xf32> + // CHECK: %[[MUL:.*]] = stablehlo.multiply %[[INPUT1]], %[[INPUT2]] : tensor<1024x3xf32> + // CHECK: return %[[ADD]], %[[MUL]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1024x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1024x3xf32> {tf_saved_model.index_path = ["output1"]}, tensor<1024x3xf32> {tf_saved_model.index_path = ["output2"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %1 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + %2 = "tf.XlaCallModule"(%1, %0) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %3 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> + %4 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %5 = stablehlo.add %3, %4 : tensor<1024x3xf32> + %6 = stablehlo.multiply %3, %0 : tensor<1024x3xf32> + return %5, %6 : tensor<1024x3xf32>, tensor<1024x3xf32> + } + + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%arg0) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_2:.*]]:2 = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_2]]#0, %[[SUBGRAPH_2]]#1 + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-NOT: 
tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +// Tests where StableHLO graph in main has branches. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func private @_stablehlo_main_1(%[[INPUT:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: %[[CONSTANT1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<3x3xf32> + // CHECK: %[[ADD:.*]] = stablehlo.add %[[CONSTANT1]], %[[INPUT]] : tensor<3x3xf32> + // CHECK: return %[[ADD:.*]] + // CHECK: } + + // CHECK: func private @_stablehlo_main_0 + // CHECK-SAME: (%[[INPUT1:.*]]: tensor<3x3xf32>, %[[INPUT2:.*]]: tensor<3x3xf32>) + // CHECK-SAME: -> tensor<3x3xf32> + // CHECK: %[[CONSTANT2:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<3x3xf32> + // CHECK: %[[ADD:.*]] = stablehlo.add %[[INPUT1]], %[[INPUT2]] : tensor<3x3xf32> + // CHECK: %[[MUL:.*]] = stablehlo.multiply %[[ADD]], %[[CONSTANT2]] : tensor<3x3xf32> + // CHECK: return %[[MUL]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<3x3xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<3x3xf32> {tf_saved_model.index_path = ["output1"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<3x3xf32> + %1 = stablehlo.add %0, %arg0 : tensor<3x3xf32> + %2 = "tf.CustomAggregator"(%1) {calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<3x3xf32>) -> tensor<3x3xf32> + %3 = "tf.XlaCallModule"(%2, %2) {Sout = [#tf_type.shape<3x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + %4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<3x3xf32>) -> tensor<3x3xf32> + %5 = stablehlo.add %4, %1 : tensor<3x3xf32> + %6 = stablehlo.multiply %5, %0 : tensor<3x3xf32> + return %6 : tensor<3x3xf32> + } + + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"(%[[INPUT:.*]]) <{Sout = [#tf_type.shape<3x3>], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]] = "tf.CustomAggregator"(%[[SUBGRAPH_1]]) <{id = "0"}> {calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<3x3xf32>) -> tensor<3x3xf32> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[CUSTOM_AGGREGATOR_1]]) <{Sout = [#tf_type.shape<3x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" + // CHECK: 
%[[CUSTOM_AGGREGATOR_2:.*]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_2:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<3x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_2]] + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-NOT: tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc index 0b9cdc09ca5b93..497ae346bfb5eb 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc @@ -70,6 +70,9 @@ void AddCallModuleSerializationPasses(mlir::PassManager &pm) { pm.addPass( mlir::quant::stablehlo:: createReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass()); + // ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass may create + // duplicate constants. Add canonicalizer to deduplicate. + pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::TF::CreateXlaCallModuleSerializationPass()); } } // namespace From 7ed714571144c2cc53426d138b2ab74086ad8c08 Mon Sep 17 00:00:00 2001 From: Tristen Allen Date: Thu, 16 Nov 2023 17:31:12 -0800 Subject: [PATCH 205/391] Refactor assert_no_new_pyobjects_executing_eagerly into a decorator factory. Previously, `test_util.assert_no_new_pyobjects_executing_eagerly` could be used either as a decorator or as a function which returned a decorator. This change removes the (optional) `func` argument from the function and changes the behavior such that it returns a decorator in all cases. This effectively transforms the function into a pure decorator factory. This change also updates uses of the decorator to all use the decorator factory form (`@assert_no_new_pyobjects_executing_eagerly()`). 
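A minimal sketch of the resulting usage pattern (illustrative only; `run_repeatedly` below is a hypothetical stand-in for the real decorator, not TensorFlow code):

  import functools

  def run_repeatedly(warmup_iters: int = 2):
    """Pure decorator factory: must always be applied with parentheses."""

    def wrap(test_fn):
      @functools.wraps(test_fn)
      def wrapper(self, *args, **kwargs):
        for _ in range(warmup_iters):      # warmup runs, results ignored
          test_fn(self, *args, **kwargs)
        return test_fn(self, *args, **kwargs)  # measured run
      return wrapper

    return wrap

  class ExampleTest:

    @run_repeatedly()  # factory form only; bare @run_repeatedly is no longer supported
    def test_something(self):
      pass
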
PiperOrigin-RevId: 583214989 --- .../function/trace_type/trace_type_test.py | 12 ++-- tensorflow/python/eager/backprop_test.py | 6 +- tensorflow/python/eager/forwardprop_test.py | 48 +++++++------- .../python/eager/memory_tests/memory_test.py | 2 +- .../polymorphic_function_test.py | 2 +- .../tracing_compilation_test.py | 2 +- tensorflow/python/eager/tensor_test.py | 16 ++--- tensorflow/python/framework/test_util.py | 24 ++++--- tensorflow/python/framework/test_util_test.py | 4 +- .../kernel_tests/array_ops/array_ops_test.py | 4 +- .../array_ops/constant_op_test.py | 2 +- .../python/kernel_tests/nn_ops/losses_test.py | 4 +- .../python/kernel_tests/nn_ops/rnn_test.py | 2 +- .../summary_ops/summary_ops_test.py | 2 +- tensorflow/python/saved_model/load_test.py | 2 +- tensorflow/python/util/nest_test.py | 66 +++++++++---------- 16 files changed, 98 insertions(+), 100 deletions(-) diff --git a/tensorflow/core/function/trace_type/trace_type_test.py b/tensorflow/core/function/trace_type/trace_type_test.py index 0ef6e8d8d75adc..3e9c7dbe06b05a 100644 --- a/tensorflow/core/function/trace_type/trace_type_test.py +++ b/tensorflow/core/function/trace_type/trace_type_test.py @@ -439,29 +439,29 @@ def testDictofTensorSpecs(self): class TraceTypeMemoryTest(test.TestCase): - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testGeneric(self): trace_type.from_value(1) trace_type.from_value(DummyGenericClass()) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testTensor(self): tensor = array_ops.zeros([10]) trace_type.from_value(tensor) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testTuple(self): trace_type.from_value((1, 2, 3)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDict(self): trace_type.from_value({1: 1, 2: 2, 3: 3}) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testList(self): trace_type.from_value([1, 2, 3]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testAttrs(self): trace_type.from_value(TestAttrsClass(1, 2)) diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index c4dc6d228c9bf3..a81fb37b013616 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -447,7 +447,7 @@ def testTapeNoOpGradient2By2(self): self.assertAllEqual(dy_dy.numpy(), constant_op.constant(1.0, shape=[2, 2]).numpy()) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testTapeNoOpGradientMultiTarget2By2(self): a_2_by_2 = constant_op.constant(2.0, shape=[2, 2]) with backprop.GradientTape(persistent=True) as tape: @@ -1648,7 +1648,7 @@ def grad_fn(x): self.assertIn('gradient_tape/my_scope/', op.name) self.assertEqual(num_sin_ops_found, 2) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testRecomputeGradWithDifferentShape(self): if sys.version_info.major == 3 and sys.version_info.minor in (11, 12): # TODO(b/264947738) @@ -1681,7 +1681,7 @@ def outer_dict(x): self.assertAllEqual(y[1], 2.0) @parameterized.parameters([(True), (False)]) - @test_util.assert_no_new_pyobjects_executing_eagerly + 
@test_util.assert_no_new_pyobjects_executing_eagerly() def testRecomputeGradWithNestedFunctionAndWhileLoop(self, reduce_retracing): if sys.version_info.major == 3 and sys.version_info.minor in (11, 12): # TODO(b/264947738) diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py index 70f6e0e90877b5..82cac2d18a53c8 100644 --- a/tensorflow/python/eager/forwardprop_test.py +++ b/tensorflow/python/eager/forwardprop_test.py @@ -336,7 +336,7 @@ def testJVPFunctionUsedByAccumulatorForOps(self): finally: pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFunctionCacheLimited(self): # Every time this loop is executed, it will create a slightly larger Tensor # and push it through Add's gradient. @@ -357,7 +357,7 @@ def testVariableUnwatchedZero(self): self.assertIsNone(acc.jvp(v)) self.assertAllClose([[0.]], acc.jvp(v, unconnected_gradients="zero")) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFunctionReturnsResource(self): v = variables.Variable([[1.]]) x = constant_op.constant(1.) @@ -371,7 +371,7 @@ def f(a): y, _ = f(x) self.assertAllClose(2., acc.jvp(y)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testMultipleWatchesAdd(self): x = constant_op.constant(-2.) with self.assertRaisesRegex(ValueError, "multiple times"): @@ -387,7 +387,7 @@ def testMultipleWatchesAdd(self): self.assertAllClose(24., acc.jvp(x)) self.assertAllClose(24. * 3., acc.jvp(y)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testReenter(self): x = constant_op.constant(-2.) with forwardprop.ForwardAccumulator(x, 1.5) as acc: @@ -403,7 +403,7 @@ def testReenter(self): yy = y * y self.assertAllClose(6. * -8. 
* 2., acc.jvp(yy)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDeadTensorsJVPCleared(self): x = array_ops.ones([100]) x_weak = weakref.ref(x) @@ -424,14 +424,14 @@ def testDeadTensorsJVPCleared(self): self.assertIsNone(derived_tensor_weak()) self.assertIsNone(derived_tensor_grad_weak()) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testJVPManual(self): primal, tangent = _jvp(math_ops.sin, (constant_op.constant(0.1),), (constant_op.constant(0.2),)) self.assertAllClose(math_ops.sin(0.1), primal) self.assertAllClose(math_ops.cos(0.1) * 0.2, tangent) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNumericHigherOrder(self): def f(x): @@ -448,7 +448,7 @@ def f(x): satol=1e-3, ) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNumericHigherOrderFloat64(self): def f(x): @@ -462,7 +462,7 @@ def f(x): [constant_op.constant([[2.0, 3.0], [1.0, 4.0]], dtype=dtypes.float64)], order=3) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testCustomGradient(self): @custom_gradient.custom_gradient @@ -475,7 +475,7 @@ def grad(dy): _test_gradients(self, f, [constant_op.constant([1., 2.])], order=3) - # TODO(allenl): investigate why assert_no_new_pyobjects_executing_eagerly + # TODO(allenl): investigate why assert_no_new_pyobjects_executing_eagerly() # fails around this test? def testExceptionCustomGradientRecomputeGradForward(self): @@ -563,7 +563,7 @@ def grad(dy): ("Order{}".format(order), order, expected) for order, expected in enumerate(_X11_35_DERIVATIVES) ]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testHigherOrderPureForward(self, order, expected): def _forwardgrad(f): @@ -606,7 +606,7 @@ def f(x): self.assertAllClose(3.5 * 2.5 * 1.1**1.5, outer_jvp) self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out))) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testJVPPacking(self): two = constant_op.constant(2.) primal_in = constant_op.constant(1.) @@ -688,7 +688,7 @@ def _expected(mat, tangent): self.assertAllClose(_expected(m1, tangent1), acc.jvp(result1)) self.assertAllClose(_expected(m2, tangent2), acc.jvp(result2)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testHVPMemory(self): def fun(x): @@ -698,7 +698,7 @@ def fun(x): tangents = constant_op.constant([3., 4., 5.]) _hvp(fun, (primals,), (tangents,)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testHVPCorrectness(self): def fun(x): @@ -725,7 +725,7 @@ def fun(x): self.assertAllClose(backback_hvp, forwardback_hvp_eager) self.assertAllClose(backback_hvp, forwardback_hvp_function) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testShouldRecordAndStopRecord(self): c = constant_op.constant(1.) c_tangent = constant_op.constant(2.) 
@@ -747,7 +747,7 @@ def testShouldRecordAndStopRecord(self): self.assertIsNone(acc.jvp(d)) self.assertIsNone(tape.gradient(d, c)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testRecordingSelectively(self): c = constant_op.constant(1.) c_tangent = constant_op.constant(2.) @@ -774,7 +774,7 @@ def testRecordingSelectively(self): self.assertIsNone(tape.gradient(d, c)) self.assertAllClose(3., tape.gradient(e, c)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testOpWithNoTrainableOutputs(self): if sys.version_info.major == 3 and sys.version_info.minor in (11, 12): # TODO(b/264947738) @@ -847,7 +847,7 @@ def testBackwardOverForward(self, forward_prop_first): self.assertTrue(record.should_record_backprop((acc.jvp(d),))) self.assertAllClose(-.1 * math_ops.cos(1.), tape.gradient(acc.jvp(d), c)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testRecordingWithJVPIndices(self): c = constant_op.constant(1.) with forwardprop.ForwardAccumulator(c, 10.) as acc: @@ -861,7 +861,7 @@ def testRecordingWithJVPIndices(self): None, (((0, 1),),)) self.assertAllClose(3., acc.jvp(d)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testSpecialForwardFunctionUsed(self): c = constant_op.constant(1.) d = constant_op.constant(2.) @@ -875,7 +875,7 @@ def testSpecialForwardFunctionUsed(self): lambda x: [x]) self.assertAllClose(-20., acc.jvp(e)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testVariableWatched(self): if sys.version_info.major == 3 and sys.version_info.minor in (11, 12): # TODO(b/264947738) @@ -1015,25 +1015,25 @@ def _fprop_cond(k, y): class ControlFlowTests(test.TestCase): - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testOfFunctionWhile(self): y = constant_op.constant(1.) with forwardprop.ForwardAccumulator(y, 1.) as acc: self.assertAllClose(10., acc.jvp(_has_loop(constant_op.constant(5), y))) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testOfFunctionCond(self): y = constant_op.constant(1.) with forwardprop.ForwardAccumulator(y, 1.) 
as acc: self.assertAllClose(3., acc.jvp(_has_cond(constant_op.constant(5), y))) self.assertAllClose(0., acc.jvp(_has_cond(constant_op.constant(0), y))) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testInFunctionWhile(self): self.assertAllClose( 10., _fprop_while(constant_op.constant(5), constant_op.constant(1.))) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testInFunctionCond(self): self.assertAllClose( 3., _fprop_cond(constant_op.constant(5), constant_op.constant(1.))) diff --git a/tensorflow/python/eager/memory_tests/memory_test.py b/tensorflow/python/eager/memory_tests/memory_test.py index ee5104ef27b343..3503058b0012cd 100644 --- a/tensorflow/python/eager/memory_tests/memory_test.py +++ b/tensorflow/python/eager/memory_tests/memory_test.py @@ -61,7 +61,7 @@ def graph(x): memory_test_util.assert_no_leak( f, num_iters=1000, increase_threshold_absolute_mb=30) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNestedFunctionsDeleted(self): @def_function.function diff --git a/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py b/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py index 64aab16798ebf4..663562a347b59f 100644 --- a/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py +++ b/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py @@ -3833,7 +3833,7 @@ def testMethodReferenceCycles(self): # function itself is not involved in a reference cycle. self.assertIs(None, weak_fn()) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testErrorMessageWhenGraphTensorIsPassedToEager(self): @polymorphic_function.function diff --git a/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py b/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py index 96ea55beeb8077..42d8091ed960d5 100644 --- a/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py +++ b/tensorflow/python/eager/polymorphic_function/tracing_compilation_test.py @@ -385,7 +385,7 @@ def sum_gather(): expected = self.evaluate(sum_gather()) self.assertAllEqual(expected, self.evaluate(defined())) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testCallOptionsMemory(self): @compiled_fn def model(x): diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index ea5e6006b9fa24..532d7f1555521f 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -86,7 +86,7 @@ def testNumpyValue(self): t = _create_tensor(values) self.assertAllEqual(values, t) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNumpyDtypeSurvivesThroughTensorConversion(self): scalar_creators = [np.int32, np.int64, np.float32, np.float64] conversion_functions = [ops.convert_to_tensor, constant_op.constant] @@ -359,7 +359,7 @@ def testConvertToTensorAllowsOverflow(self): _ = ops.convert_to_tensor(123456789, dtype=dtypes.uint8) @test_util.run_in_graph_and_eager_modes - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testConvertToTensorNumpyZeroDim(self): for np_type, dtype in [(np.int32, 
dtypes.int32), (np.half, dtypes.half), (np.float32, dtypes.float32)]: @@ -370,7 +370,7 @@ def testConvertToTensorNumpyZeroDim(self): self.assertAllEqual(x, [65, 16]) @test_util.run_in_graph_and_eager_modes - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testConvertToTensorNumpyScalar(self): x = ops.convert_to_tensor([ np.array(321, dtype=np.int64).item(), @@ -422,19 +422,19 @@ def testMemoryviewIsReadonly(self): t = constant_op.constant([0.0]) self.assertTrue(memoryview(t).readonly) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testMemoryviewScalar(self): t = constant_op.constant(42.0) self.assertAllEqual( np.array(memoryview(t)), np.array(42.0, dtype=np.float32)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testMemoryviewEmpty(self): t = constant_op.constant([], dtype=np.float32) self.assertAllEqual(np.array(memoryview(t)), np.array([])) @test_util.run_gpu_only - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testMemoryviewCopyToCPU(self): with ops.device("/device:GPU:0"): t = constant_op.constant([0.0]) @@ -620,7 +620,7 @@ def testSliceDimOutOfRange(self): "but tensor at index 2 has rank 0"): pywrap_tfe.TFE_Py_TensorShapeSlice([t2, t1, t3], 0) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testTensorDir(self): t = array_ops.ones(1) t.test_attr = "Test" @@ -639,7 +639,7 @@ def testNonRectangularPackAsConstant(self): with self.assertRaisesRegex(ValueError, "non-rectangular Python sequence"): constant_op.constant(l) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFloatAndIntAreConvertibleToComplex(self): a = [[1., 1], [1j, 2j]] np_value = np.array(a, dtype=np.complex128) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 7c86ed1709e152..a4085a0f2b7571 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -672,26 +672,27 @@ def wrapper(*args, **kwargs): return wrapper -def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2): +def assert_no_new_pyobjects_executing_eagerly( + warmup_iters: int = 2, +) -> Callable[[Callable[..., Any]], Callable[..., None]]: """Decorator for asserting that no new Python objects persist after a test. - Runs the test multiple times executing eagerly, first as a warmup and then to - let objects accumulate. The warmup helps ignore caches which do not grow as - the test is run repeatedly. + Returns a decorator that runs the test multiple times executing eagerly, + first as a warmup and then to let objects accumulate. The warmup helps ignore + caches which do not grow as the test is run repeatedly. Useful for checking that there are no missing Py_DECREFs in the C exercised by a bit of Python. Args: - func: The function to test. warmup_iters: The numer of warmup iterations, excluded from measuring. Returns: - The wrapped function performing the test. + A decorator function which can be applied to the test function. 
""" - def wrap_f(f): - def decorator(self, *args, **kwargs): + def wrap_f(f: Callable[..., Any]) -> Callable[..., None]: + def decorator(self: "TensorFlowTestCase", *args, **kwargs) -> None: """Warms up, gets object counts, runs the test, checks for new objects.""" with context.eager_mode(): gc.disable() @@ -783,12 +784,9 @@ def decorator(self, *args, **kwargs): "The following objects were newly created: %s" % str(obj_count_by_type)) gc.enable() - return decorator + return tf_decorator.make_decorator(f, decorator) - if func is None: - return wrap_f - else: - return wrap_f(func) + return wrap_f def assert_no_new_tensors(f): diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 69680857a3b037..1407aa328b7056 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -1196,11 +1196,11 @@ def __init__(self, *args, **kwargs): self.accumulation = [] @unittest.expectedFailure - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def test_has_leak(self): self.accumulation.append([1.]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def test_has_no_leak(self): self.not_accumulating = [1.] diff --git a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py index 0a3e6a2eb29b0c..9da1992d921b27 100644 --- a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py @@ -688,7 +688,7 @@ def testInt64GPU(self): s = array_ops.strided_slice(x, begin, end, strides) self.assertAllEqual([3.], self.evaluate(s)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() @test_util.assert_no_garbage_created def testTensorSliceEagerMemory(self): with context.eager_mode(): @@ -697,7 +697,7 @@ def testTensorSliceEagerMemory(self): # Tests that slicing an EagerTensor doesn't leak memory inputs[0] # pylint: disable=pointless-statement - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() @test_util.assert_no_garbage_created def testVariableSliceEagerMemory(self): if sys.version_info.major == 3 and sys.version_info.minor in (11, 12): diff --git a/tensorflow/python/kernel_tests/array_ops/constant_op_test.py b/tensorflow/python/kernel_tests/array_ops/constant_op_test.py index 5fb4fb659d8f19..55cb3e049c0b32 100644 --- a/tensorflow/python/kernel_tests/array_ops/constant_op_test.py +++ b/tensorflow/python/kernel_tests/array_ops/constant_op_test.py @@ -208,7 +208,7 @@ def testExplicitShapeNumPy(self): shape=[2, 3, 5]) self.assertEqual(c.get_shape(), [2, 3, 5]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testEagerMemory(self): """Tests PyObject refs are managed correctly when executing eagerly.""" constant_op.constant([[1.]]) diff --git a/tensorflow/python/kernel_tests/nn_ops/losses_test.py b/tensorflow/python/kernel_tests/nn_ops/losses_test.py index 7da91f686a849c..b339738b485800 100644 --- a/tensorflow/python/kernel_tests/nn_ops/losses_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/losses_test.py @@ -101,7 +101,7 @@ def testLossWithSampleSpecificWeightsAllZero(self): with self.cached_session(): self.assertAlmostEqual(0.0, self.evaluate(loss), 3) - 
@test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testEagerNoMemoryLeaked(self): # This is a somewhat convoluted way of testing that nothing gets added to # a global collection. @@ -244,7 +244,7 @@ def testAllCorrectInt32Labels(self): self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') self.assertAlmostEqual(self.evaluate(loss), 0.0, 3) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testEagerNoMemoryLeaked(self): logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) diff --git a/tensorflow/python/kernel_tests/nn_ops/rnn_test.py b/tensorflow/python/kernel_tests/nn_ops/rnn_test.py index e517f4ecc8864c..f13a2521d44516 100644 --- a/tensorflow/python/kernel_tests/nn_ops/rnn_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/rnn_test.py @@ -240,7 +240,7 @@ def testUnbalancedOutputIsAccepted(self): self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1]) self.assertAllEqual(4, state) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testEagerMemory(self): with context.eager_mode(): cell = TensorArrayStateRNNCell() diff --git a/tensorflow/python/kernel_tests/summary_ops/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops/summary_ops_test.py index ae41fc42fb0260..dfae887a60341c 100644 --- a/tensorflow/python/kernel_tests/summary_ops/summary_ops_test.py +++ b/tensorflow/python/kernel_tests/summary_ops/summary_ops_test.py @@ -996,7 +996,7 @@ def testNoMemoryLeak_graphMode(self): with context.graph_mode(), ops.Graph().as_default(): summary_ops.create_file_writer_v2(logdir) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNoMemoryLeak_eagerMode(self): logdir = self.get_temp_dir() with summary_ops.create_file_writer_v2(logdir).as_default(): diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index 8f498936328e86..a08d41f5fcf499 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -2970,7 +2970,7 @@ def increment_v(x): # TODO(allenl, kkb): Use the new memory checker here once it's fast enough (3 # iterations took hundreds of seconds). It would be really nice to check # allocations at a lower level. 
- @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def test_functions_cleaned(self, use_cpp_bindings): # TODO(b/264869753) Fix SingleCycleTest if use_cpp_bindings: diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 26341624c06619..0378076cba247b 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -154,24 +154,24 @@ class UnsortedSampleAttr(object): field1 = attr.ib() field2 = attr.ib() - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassCustomProtocol(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) self.assertIsInstance(mt, CustomNestProtocol) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassIsNested(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) self.assertTrue(nest.is_nested(mt)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlatten(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) leaves = nest.flatten(mt) self.assertLen(leaves, 1) self.assertAllEqual(leaves[0], [1]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenUpToCompatible(self): simple_list = [2] mt = MaskedTensor(mask=True, value=constant_op.constant([1])) @@ -200,7 +200,7 @@ def testDataclassFlattenUpToCompatible(self): ) self.assertAllEqual(flat_path_nested_list, [2]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenUpToIncompatible(self): simple_list = [2] mt = MaskedTensor(mask=True, value=constant_op.constant([1])) @@ -239,7 +239,7 @@ def testDataclassFlattenUpToIncompatible(self): shallow_tree=nested_list, input_tree=mt, check_types=False ) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenWithTuplePathsUpToCompatible(self): simple_list = [2] mt = MaskedTensor(mask=True, value=constant_op.constant([1])) @@ -271,7 +271,7 @@ def testDataclassFlattenWithTuplePathsUpToCompatible(self): ) self.assertAllEqual(flat_path_nested_list, [[(0, 0), 2]]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenWithTuplePathsUpToIncompatible(self): simple_list = [2] mt = MaskedTensor(mask=True, value=constant_op.constant([1])) @@ -311,7 +311,7 @@ def testDataclassFlattenWithTuplePathsUpToIncompatible(self): shallow_tree=nested_list2, input_tree=nmt, check_types=False ) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenAndPack(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) leaves = nest.flatten(mt) @@ -319,7 +319,7 @@ def testDataclassFlattenAndPack(self): self.assertIsInstance(reconstructed_mt, MaskedTensor) self.assertEqual(reconstructed_mt, mt) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassMapStructure(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) mt_doubled = nest.map_structure(lambda x: x * 2, mt) @@ -327,7 +327,7 @@ def 
testDataclassMapStructure(self): self.assertEqual(mt_doubled.mask, True) self.assertAllEqual(mt_doubled.value, [2]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassMapStructureWithPaths(self): mt = MaskedTensor(mask=False, value=constant_op.constant([1])) mt2 = MaskedTensor(mask=True, value=constant_op.constant([2])) @@ -360,7 +360,7 @@ def path_sum(path, *tensors): self.assertAllEqual(nmt_combined_with_path.value.value[0], "0/0") self.assertAllEqual(nmt_combined_with_path.value.value[1], [9]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassMapStructureWithTuplePaths(self): mt = MaskedTensor(mask=False, value=constant_op.constant([1])) mt2 = MaskedTensor(mask=True, value=constant_op.constant([2])) @@ -395,7 +395,7 @@ def tuple_path_sum(tuple_path, *tensors): self.assertAllEqual(nmt_combined_with_path.value.value[0], (0, 0)) self.assertAllEqual(nmt_combined_with_path.value.value[1], [9]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassMapStructureUpTo(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) mt2 = MaskedTensor(mask=True, value=constant_op.constant([2])) @@ -431,7 +431,7 @@ def sum_tensors(*tensors): self.assertEqual(nmt_combined_with_path.value.mask, True) self.assertAllEqual(nmt_combined_with_path.value.value, [9]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassMapStructureWithTuplePathsUoTo(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) mt2 = MaskedTensor(mask=True, value=constant_op.constant([2])) @@ -470,7 +470,7 @@ def tuple_path_sum(tuple_path, *tensors): self.assertAllEqual(nmt_combined_with_path.value.value[0], (0, 0)) self.assertAllEqual(nmt_combined_with_path.value.value[1], [9]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNestedDataclassIsNested(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) self.assertTrue(nest.is_nested(mt)) @@ -480,7 +480,7 @@ def testNestedDataclassIsNested(self): ) self.assertTrue(nest.is_nested(nmt)) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassAssertShallowStructure(self): # These assertions are expected to pass: two dataclasses with the same # component size are considered to have the same shallow structure. 
@@ -535,7 +535,7 @@ def testDataclassAssertShallowStructure(self): shallow_tree=nmt, input_tree=mt, check_types=False ) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassGetTraverseShallowStructure(self): nmt = NestedMaskedTensor.nested_masked_tensor_with_opposite_masks( mask=True, inner_value=constant_op.constant([1]) @@ -568,7 +568,7 @@ def testDataclassGetTraverseShallowStructure(self): self.assertEqual(traverse_result3, False) nest.assert_shallow_structure(traverse_result3, nmt) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNestedDataclassFlatten(self): nmt = NestedMaskedTensor.nested_masked_tensor_with_opposite_masks( mask=True, inner_value=constant_op.constant([1]) @@ -577,7 +577,7 @@ def testNestedDataclassFlatten(self): self.assertLen(leaves, 1) self.assertAllEqual(leaves[0], [1]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNestedDataclassFlattenAndPack(self): nmt = NestedMaskedTensor.nested_masked_tensor_with_opposite_masks( mask=True, inner_value=constant_op.constant([1]) @@ -587,7 +587,7 @@ def testNestedDataclassFlattenAndPack(self): self.assertIsInstance(reconstructed_mt, NestedMaskedTensor) self.assertEqual(reconstructed_mt, nmt) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testNestedDataclassMapStructure(self): nmt = NestedMaskedTensor.nested_masked_tensor_with_opposite_masks( mask=True, inner_value=constant_op.constant([1]) @@ -602,7 +602,7 @@ def testNestedDataclassMapStructure(self): self.assertEqual(mt_doubled.value.mask, expected.value.mask) self.assertAllEqual(mt_doubled.value.value, expected.value.value) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassYieldFlatPaths(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) mt_flat_paths = list(nest.yield_flat_paths(mt)) @@ -626,7 +626,7 @@ def testDataclassYieldFlatPaths(self): ], ) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenWithStringPaths(self): sep = "/" mt = MaskedTensor(mask=True, value=constant_op.constant([1])) @@ -650,7 +650,7 @@ def testDataclassFlattenWithStringPaths(self): self.assertEqual(dict_mt_nmt_flat_paths[1][0], "nmt/0/0") self.assertAllEqual(dict_mt_nmt_flat_paths[1][1], [2]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassFlattenWithTuplePaths(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) mt_flat_paths = nest.flatten_with_tuple_paths(mt) @@ -671,7 +671,7 @@ def testDataclassFlattenWithTuplePaths(self): self.assertEqual(dict_mt_nmt_flat_paths[1][0], ("nmt", 0, 0)) self.assertAllEqual(dict_mt_nmt_flat_paths[1][1], [2]) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testDataclassListToTuple(self): mt = MaskedTensor(mask=True, value=constant_op.constant([1])) nmt = NestedMaskedTensor.nested_masked_tensor_with_opposite_masks( @@ -690,7 +690,7 @@ def testDataclassListToTuple(self): ) nest.assert_same_structure(results, expected) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def 
testAttrsFlattenAndPack(self): if attr is None: self.skipTest("attr module is unavailable.") @@ -715,7 +715,7 @@ def testAttrsFlattenAndPack(self): {"values": [(1, 2), [3, 4], 5]}, {"values": [PointXY(1, 2), 3, 4]}, ) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testAttrsMapStructure(self, values): if attr is None: self.skipTest("attr module is unavailable.") @@ -724,7 +724,7 @@ def testAttrsMapStructure(self, values): new_structure = nest.map_structure(lambda x: x, structure) self.assertEqual(structure, new_structure) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFlattenAndPack(self): structure = ((3, 4), 5, (6, 7, (9, 10), 8)) flat = ["a", "b", "c", "d", "e", "f", "g", "h"] @@ -761,7 +761,7 @@ def testFlattenAndPack(self): @parameterized.parameters({"mapping_type": collections.OrderedDict}, {"mapping_type": _CustomMapping}) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFlattenDictOrder(self, mapping_type): """`flatten` orders dicts by key, including OrderedDicts.""" ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) @@ -787,7 +787,7 @@ def testPackDictOrder(self, mapping_type): custom_reconstruction) self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFlattenAndPackMappingViews(self): """`flatten` orders dicts by key, including OrderedDicts.""" ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) @@ -806,7 +806,7 @@ def testFlattenAndPackMappingViews(self): Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testFlattenAndPack_withDicts(self): # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. 
mess = [ @@ -889,7 +889,7 @@ def testPackSequenceAs_CompositeTensor(self): ValueError, "Structure had 2 atoms, but flat_sequence had 1 items."): nest.pack_sequence_as(val, [val], expand_composites=True) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testIsNested(self): self.assertFalse(nest.is_nested("1234")) self.assertTrue(nest.is_nested([1, 3, [4, 5]])) @@ -942,7 +942,7 @@ def testFlattenDictItems(self, mapping_type): class SameNamedType1(SameNameab): pass - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testAssertSameStructure(self): structure1 = (((1, 2), 3), 4, (5, 6)) structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) @@ -1053,7 +1053,7 @@ def testHeterogeneousComparison(self): nest.assert_same_structure({"a": 4}, _CustomMapping(a=3)) nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testMapStructure(self): structure1 = (((1, 2), 3), 4, (5, 6)) structure2 = (((7, 8), 9), 10, (11, 12)) @@ -1129,7 +1129,7 @@ def testMapStructure(self): ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name - @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.assert_no_new_pyobjects_executing_eagerly() def testMapStructureWithStrings(self): inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) inp_b = NestTest.ABTuple(a=2, b=(1, 3)) From f93a789fc06399ffa20a279201e7c12ba77a7215 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 17:49:26 -0800 Subject: [PATCH 206/391] Integrate LLVM at llvm/llvm-project@46396108deb2 Updates LLVM usage to match [46396108deb2](https://github.com/llvm/llvm-project/commit/46396108deb2) PiperOrigin-RevId: 583219072 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index f7c7984832623d..ac8776f3fc9e7c 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "865f54e501739f382d33866baebfd0f9aaad01bb" - LLVM_SHA256 = "16dc3aa4f7688f11e20d1f506419e99217018aa8b9ae02453d63b95b76541a2a" + LLVM_COMMIT = "46396108deb24564159c441c6e6ebfac26714d7b" + LLVM_SHA256 = "be8a1460e9e8d1eb96eae8065e5d32376f6ed721872033974c7069a35096f9b3" tf_http_archive( name = name, From 456a890cae4824c1da4954be4a460b8b012cb758 Mon Sep 17 00:00:00 2001 From: "Jiyoun (Jen) Ha" Date: Thu, 16 Nov 2023 17:58:51 -0800 Subject: [PATCH 207/391] Support stablehlo.dot_general -> tfl.fully_connected legalization of quantized simple dot_general op. 
Follow-up CLs will cover the following: - bias and activation fusion - stablehlo.dot_general -> tfl.batch_matmul case - stablehlo.convolution -> tfl.conv_2d case PiperOrigin-RevId: 583221525 --- .../uniform-quantized-stablehlo-to-tfl.mlir | 102 ++-- ...uniform_quantized_stablehlo_to_tfl_pass.cc | 460 ++++++++++++++---- .../stablehlo/uniform_quantized_types.cc | 14 + .../stablehlo/uniform_quantized_types.h | 5 + .../stablehlo/uniform_quantized_types_test.cc | 37 ++ 5 files changed, 470 insertions(+), 148 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir index 706106dad3c24b..7272b5f17301cb 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir @@ -104,8 +104,8 @@ func.func @uniform_dequantize_op_return_f64(%arg: tensor<2x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { +// CHECK-LABEL: convolution_upstream_full_integer +func.func @convolution_upstream_full_integer(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> return %1 : tensor<1x3x3x2x!quant.uniform> @@ -123,8 +123,8 @@ func.func @convolution_op(%arg0: tensor<1x3x3x4x!quant.uniform>, %arg1: tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { +// CHECK-LABEL: convolution_upstream_full_integer_non_const_filter +func.func @convolution_upstream_full_integer_non_const_filter(%arg0: tensor<1x3x3x4x!quant.uniform>, %arg1: tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> return %0 : tensor<1x3x3x2x!quant.uniform> } @@ -139,8 +139,8 @@ func.func @convolution_op_non_const_filter(%arg0: tensor<1x3x3x4x!quant.uniform< // Test that if the window padding contains values of 0, tfl.pad op is not // created and the `padding` attribute is set as "VALID". 
-// CHECK-LABEL: convolution_op_valid_padding -func.func @convolution_op_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { +// CHECK-LABEL: convolution_upstream_full_integer_valid_padding +func.func @convolution_upstream_full_integer_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 0], [0, 0]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> return %1 : tensor<1x1x1x2x!quant.uniform> @@ -157,8 +157,8 @@ func.func @convolution_op_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { +// CHECK-LABEL: convolution_upstream_full_integer_valid_padding +func.func @convolution_upstream_full_integer_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> // The `window` attribute is empty. %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> @@ -175,8 +175,8 @@ func.func @convolution_op_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> { +// CHECK-LABEL: convolution_upstream_full_integer_strides +func.func @convolution_upstream_full_integer_strides(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> // The stride value is explicitly set to [1, 2]. %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 2], pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> @@ -195,8 +195,8 @@ func.func @convolution_strides(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_asym_input +func.func @dot_general_upstream_full_integer_asym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -216,8 +216,8 @@ func.func @dot_general_full_integer_asym_input(%arg0: tensor<1x2x3x4x!quant.unif // Test full integer quantized dot_general with symmetric quantized input. 
-// CHECK-LABEL: dot_general_full_integer_sym_input -func.func @dot_general_full_integer_sym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_sym_input +func.func @dot_general_upstream_full_integer_sym_input(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -242,8 +242,8 @@ func.func @dot_general_full_integer_sym_input(%arg0: tensor<1x2x3x4x!quant.unifo // are quantized upstream. Other cases should be handled by regular quantized // stablehlo.dot_general case. -// CHECK-LABEL: dot_general_op_i32_output -func.func @dot_general_op_i32_output(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_i32_output +func.func @dot_general_upstream_full_integer_i32_output(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -263,8 +263,8 @@ func.func @dot_general_op_i32_output(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_activation_rhs +func.func @dot_general_upstream_full_integer_activation_rhs(%arg0: tensor<1x2x3x4x!quant.uniform>, %arg1: tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0, 1], @@ -282,8 +282,8 @@ func.func @dot_general_full_integer_activation_rhs(%arg0: tensor<1x2x3x4x!quant. 
// Test full integer quantized dot_general with adj_x -// CHECK-LABEL: dot_general_full_integer_adj_x -func.func @dot_general_full_integer_adj_x(%arg0: tensor<1x2x4x3x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_adj_x +func.func @dot_general_upstream_full_integer_adj_x(%arg0: tensor<1x2x4x3x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -306,8 +306,8 @@ func.func @dot_general_full_integer_adj_x(%arg0: tensor<1x2x4x3x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_adj_y +func.func @dot_general_upstream_full_integer_adj_y(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x5x4xi8>} : () -> tensor<1x2x5x4x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -330,8 +330,8 @@ func.func @dot_general_full_integer_adj_y(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x1x1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_too_many_batches +func.func @dot_general_upstream_full_integer_too_many_batches(%arg0: tensor<1x1x1x2x3x4x!quant.uniform>) -> tensor<1x1x1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x1x1x2x4x5xi8>} : () -> tensor<1x1x1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -352,8 +352,8 @@ func.func @dot_general_full_integer_too_many_batches(%arg0: tensor<1x1x1x2x3x4x! // Test full integer quantized dot_general with too many contracting dimension -// CHECK-LABEL: dot_general_full_integer_too_many_contractions -func.func @dot_general_full_integer_too_many_contractions(%arg0: tensor<1x2x3x4x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_too_many_contractions +func.func @dot_general_upstream_full_integer_too_many_contractions(%arg0: tensor<1x2x3x4x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x4x5xi8>} : () -> tensor<1x2x4x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -374,8 +374,8 @@ func.func @dot_general_full_integer_too_many_contractions(%arg0: tensor<1x2x3x4x // Test full integer quantized dot_general with unsupported contracting dim -// CHECK-LABEL: dot_general_full_integer_wrong_contracting -func.func @dot_general_full_integer_wrong_contracting(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x4x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_wrong_contracting +func.func @dot_general_upstream_full_integer_wrong_contracting(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x4x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) { dot_dimension_numbers = #stablehlo.dot< @@ -397,8 +397,8 @@ func.func @dot_general_full_integer_wrong_contracting(%arg0: tensor<1x2x3x4x!qua // Test full integer quantized dot_general with float operands -// CHECK-LABEL: dot_general_full_integer_float_operands -func.func @dot_general_full_integer_float_operands(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32> { 
+// CHECK-LABEL: dot_general_upstream_full_integer_float_operands +func.func @dot_general_upstream_full_integer_float_operands(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0, 1], @@ -418,8 +418,8 @@ func.func @dot_general_full_integer_float_operands(%arg0: tensor<1x2x3x4xf32>, % // Test full integer quantized dot_general with asymmetric weight (rhs). -// CHECK-LABEL: dot_general_full_integer_asym_weight -func.func @dot_general_full_integer_asym_weight(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_asym_weight +func.func @dot_general_upstream_full_integer_asym_weight(%arg0: tensor<1x2x3x4x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> %1 = "stablehlo.dot_general"(%arg0, %0) {dot_dimension_numbers = #stablehlo.dot, precision_config = [#stablehlo, #stablehlo]} : (tensor<1x2x3x4x!quant.uniform>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5x!quant.uniform> return %1 : tensor<1x2x3x5x!quant.uniform> @@ -433,8 +433,8 @@ func.func @dot_general_full_integer_asym_weight(%arg0: tensor<1x2x3x4x!quant.uni // Test that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized, it is converted to `tfl.fully_connected` op. -// CHECK-LABEL: dot_general_per_axis_quantized_filter -func.func @dot_general_per_axis_quantized_filter(%arg0: tensor<1x3x!quant.uniform>) -> tensor<1x2x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter +func.func @dot_general_upstream_full_integer_per_axis_quantized_filter(%arg0: tensor<1x3x!quant.uniform>) -> tensor<1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<3x2xi8>} : () -> tensor<3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x3x!quant.uniform>, tensor<3x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> return %1 : tensor<1x2x!quant.uniform> @@ -452,8 +452,8 @@ func.func @dot_general_per_axis_quantized_filter(%arg0: tensor<1x3x!quant.unifor // Test that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized but has a batch dimension, it is not converted. -// CHECK-LABEL: dot_general_per_axis_quantized_filter_with_batch_dim -func.func @dot_general_per_axis_quantized_filter_with_batch_dim(%arg0: tensor<1x1x3x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter_with_batch_dim +func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_with_batch_dim(%arg0: tensor<1x1x3x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x3x2xi8>} : () -> tensor<1x3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<1x1x3x!quant.uniform>, tensor<1x3x2x!quant.uniform>) -> tensor<1x1x2x!quant.uniform> return %1 : tensor<1x1x2x!quant.uniform> @@ -468,8 +468,8 @@ func.func @dot_general_per_axis_quantized_filter_with_batch_dim(%arg0: tensor<1x // Test that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized but has a batch dim > 1, it is not converted. 
-// CHECK-LABEL: dot_general_per_axis_quantized_filter_multibatch -func.func @dot_general_per_axis_quantized_filter_multibatch(%arg0: tensor<3x1x3x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter_multibatch +func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_multibatch(%arg0: tensor<3x1x3x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<3x3x2xi8>} : () -> tensor<3x3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<3x1x3x!quant.uniform>, tensor<3x3x2x!quant.uniform>) -> tensor<3x1x2x!quant.uniform> return %1 : tensor<3x1x2x!quant.uniform> @@ -484,8 +484,8 @@ func.func @dot_general_per_axis_quantized_filter_multibatch(%arg0: tensor<3x1x3x // Test that when the weight tensor for `stablehlo.dot_general` is per-axis // quantized but has more than one contracting dimension, it is not converted. -// CHECK-LABEL: dot_general_per_axis_quantized_filter_with_multiple_contracting_dims -func.func @dot_general_per_axis_quantized_filter_with_multiple_contracting_dims(%arg0: tensor<1x2x3x!quant.uniform>) -> tensor<1x1x!quant.uniform> { +// CHECK-LABEL: dot_general_upstream_full_integer_per_axis_quantized_filter_with_multiple_contracting_dims +func.func @dot_general_upstream_full_integer_per_axis_quantized_filter_with_multiple_contracting_dims(%arg0: tensor<1x2x3x!quant.uniform>) -> tensor<1x1x!quant.uniform> { %0 = stablehlo.constant() {value = dense<1> : tensor<1x3x2xi8>} : () -> tensor<1x3x2x!quant.uniform> %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1, 2] x [2, 1] : (tensor<1x2x3x!quant.uniform>, tensor<1x3x2x!quant.uniform>) -> tensor<1x1x!quant.uniform> return %1 : tensor<1x1x!quant.uniform> @@ -494,3 +494,25 @@ func.func @dot_general_per_axis_quantized_filter_with_multiple_contracting_dims( // CHECK: stablehlo.dot_general // CHECK-NOT: tfl.fully_connected // CHECK-NOT: tfl.batch_matmul + +// ----- + +// Test that a simple per-tensor quantized stablehlo.dot_general is properly +// fused with a subsequent requantize (qi32->qi8) op then legalized. 
+// Supports the following format: (lhs: qi8, rhs: qi8) -> result: qi32 + +// CHECK-LABEL: dot_general_full_integer +// CHECK-SAME: (%[[ARG_1:.*]]: tensor<1x1024x!quant.uniform + func.func @dot_general_full_integer(%arg0: tensor<1x1024x!quant.uniform> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) { + %0 = stablehlo.constant() {value = dense<1> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<1x1024x!quant.uniform>, tensor<1024x3x!quant.uniform:f32, 2.000000e+0:0>>) -> tensor<1x3x!quant.uniform> + %2 = stablehlo.uniform_quantize %1 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + +// CHECK-NOT: stablehlo.dot_general +// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<3x1024x!quant.uniform>, value = dense<1> : tensor<3x1024xi8>} : () -> tensor<3x1024x!quant.uniform> +// CHECK: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform:f32, 2.000000e+00>>, value = dense<0> : tensor<3xi32>} : () -> tensor<3x!quant.uniform:f32, 2.000000e+00>> +// CHECK: "tfl.fully_connected"(%[[ARG_1]], %[[QCONST_0]], %[[QCONST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1024x!quant.uniform>, tensor<3x1024x!quant.uniform>, tensor<3x!quant.uniform:f32, 2.000000e+00>>) -> tensor<1x3x!quant.uniform> +// CHECK-NOT: tfl.batch_matmul diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index cba64e220f82f0..3ba5ad97ad579e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -47,8 +47,13 @@ namespace mlir { namespace odml { namespace { +// TODO: b/311029361: Add e2e test for verifying this legalization once +// StableHLO Quantizer API migration is complete. + +using ::mlir::quant::IsI32F32UniformQuantizedType; using ::mlir::quant::IsI8F32UniformQuantizedPerAxisType; using ::mlir::quant::IsI8F32UniformQuantizedType; +using ::mlir::quant::IsSupportedByTfliteQuantizeOrDequantizeOps; using ::mlir::quant::QuantizedType; using ::mlir::quant::UniformQuantizedPerAxisType; using ::mlir::quant::UniformQuantizedType; @@ -63,21 +68,6 @@ class UniformQuantizedStablehloToTflPass void runOnOperation() override; }; -// Determines whether the storage type of a quantized type is supported by -// `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. -bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type) { - if ((storage_type.isSigned() && - !(storage_type.getWidth() == 8 || storage_type.getWidth() == 16)) || - (!storage_type.isSigned() && storage_type.getWidth() != 8)) { - LLVM_DEBUG(llvm::dbgs() - << "Uniform quantize / dequantize op only supports ui8, i8 or " - "i16 for the storage type of uniform quantized type. Got: " - << storage_type << ".\n"); - return false; - } - return true; -} - // Bias scales for matmul-like ops should be input scale * filter scale. Here it // is assumed that the input is per-tensor quantized and filter is per-channel // quantized. 
@@ -91,6 +81,153 @@ SmallVector GetBiasScales(const double input_scale, return bias_scales; } +// Returns a bias scale for matmul-like ops. Here it is assumed that both input +// and filter are per-tensor quantized. +double GetBiasScale(const double input_scale, const double filter_scale) { + return filter_scale * input_scale; +} + +// Creates a new `tfl.qconst` op for the quantized filter. Transposes the +// filter value from [i, o] -> [o, i]. This is because we assume `[i, o]` +// format for `stablehlo.dot_general` (i.e. contracting dimension == 1) +// whereas `tfl.fully_connected` accepts an OI format. +TFL::QConstOp CreateTflConstOpForFilter( + stablehlo::ConstantOp filter_constant_op, PatternRewriter& rewriter, + bool is_per_axis) { + const auto filter_values = filter_constant_op.getValue() + .cast() + .getValues(); + + ArrayRef filter_shape = + filter_constant_op.getType().cast().getShape(); + + // Reverse the shapes. This makes sense, assuming that the filter tensor has a + // rank of 2 (no batch dimension). + SmallVector new_filter_shape(filter_shape.rbegin(), + filter_shape.rend()); + + // Construct the value array of transposed filter. Assumes 2D matrix. + SmallVector new_filter_values(filter_values.size(), /*Value=*/0); + for (int i = 0; i < filter_shape[0]; ++i) { + for (int j = 0; j < filter_shape[1]; ++j) { + const int old_idx = i * filter_shape[1] + j; + const int new_idx = j * filter_shape[0] + i; + new_filter_values[new_idx] = filter_values[old_idx]; + } + } + + auto new_filter_value_attr_type = RankedTensorType::getChecked( + filter_constant_op.getLoc(), new_filter_shape, + /*elementType=*/rewriter.getI8Type()); + + Type new_filter_quantized_type; + + if (is_per_axis) { + auto filter_quantized_type = filter_constant_op.getResult() + .getType() + .cast() + .getElementType() + .cast(); + + new_filter_quantized_type = UniformQuantizedPerAxisType::getChecked( + filter_constant_op.getLoc(), /*flags=*/true, + /*storageType=*/filter_quantized_type.getStorageType(), + /*expressedType=*/filter_quantized_type.getExpressedType(), + /*scales=*/filter_quantized_type.getScales(), + /*zeroPoints=*/filter_quantized_type.getZeroPoints(), + /*quantizedDimension=*/0, /*storageTypeMin=*/llvm::minIntN(8), + /*storageTypeMax=*/llvm::maxIntN(8)); + } else { + auto filter_quantized_type = filter_constant_op.getResult() + .getType() + .cast() + .getElementType() + .cast(); + new_filter_quantized_type = UniformQuantizedType::getChecked( + filter_constant_op.getLoc(), /*flags=*/true, + /*storageType=*/filter_quantized_type.getStorageType(), + /*expressedType=*/filter_quantized_type.getExpressedType(), + /*scale=*/filter_quantized_type.getScale(), + /*zeroPoint=*/filter_quantized_type.getZeroPoint(), + /*storageTypeMin=*/llvm::minIntN(8), + /*storageTypeMax=*/llvm::maxIntN(8)); + } + + // Required because the quantized dimension is changed from 3 -> 0. + auto new_filter_result_type = RankedTensorType::getChecked( + filter_constant_op.getLoc(), /*shape=*/new_filter_shape, + /*type=*/new_filter_quantized_type); + + auto new_filter_constant_value_attr = + DenseIntElementsAttr::get(new_filter_value_attr_type, new_filter_values); + return rewriter.create( + filter_constant_op.getLoc(), + /*output=*/TypeAttr::get(new_filter_result_type), + /*value=*/new_filter_constant_value_attr); +} + +// Creates a new `tfl.qconst` op for the bias. The bias values are 0s, because +// this bias a dummy bias (note that bias fusion is not considered for this +// transformation). 
The quantization scale for the bias is input scale * +// filter scale. `filter_const_op` is used to retrieve the filter scales and +// the size of the bias constant. +// TODO - b/309896242: Support bias fusion legalization. +TFL::QConstOp CreateTflConstOpForDummyBias(const Location loc, + const double input_scale, + TFL::QConstOp filter_const_op, + PatternRewriter& rewriter, + bool is_per_axis) { + const ArrayRef filter_shape = + filter_const_op.getResult().getType().getShape(); + + Type bias_quantized_type; + if (is_per_axis) { + const auto filter_quantized_element_type = + filter_const_op.getResult() + .getType() + .getElementType() + .cast(); + + // The storage type is i32 for bias, which is the precision used for + // accumulation. + bias_quantized_type = UniformQuantizedPerAxisType::getChecked( + loc, /*flags=*/true, /*storageType=*/rewriter.getI32Type(), + /*expressedType=*/rewriter.getF32Type(), /*scales=*/ + GetBiasScales(input_scale, filter_quantized_element_type.getScales()), + /*zeroPoints=*/filter_quantized_element_type.getZeroPoints(), + /*quantizedDimension=*/0, /*storageTypeMin=*/llvm::minIntN(8), + /*storageTypeMax=*/llvm::maxIntN(8)); + } else { + const auto filter_quantized_element_type = + filter_const_op.getResult() + .getType() + .getElementType() + .cast(); + + // The storage type is i32 for bias, which is the precision used for + // accumulation. + bias_quantized_type = UniformQuantizedType::getChecked( + loc, /*flags=*/true, /*storageType=*/rewriter.getI32Type(), + /*expressedType=*/rewriter.getF32Type(), /*scale=*/ + GetBiasScale(input_scale, filter_quantized_element_type.getScale()), + /*zeroPoint=*/filter_quantized_element_type.getZeroPoint(), + /*storageTypeMin=*/llvm::minIntN(8), + /*storageTypeMax=*/llvm::maxIntN(8)); + } + + SmallVector bias_shape = {filter_shape[0]}; + auto bias_type = + RankedTensorType::getChecked(loc, bias_shape, bias_quantized_type); + + auto bias_value_type = RankedTensorType::getChecked( + loc, std::move(bias_shape), rewriter.getI32Type()); + auto bias_value = DenseIntElementsAttr::get( + bias_value_type, APInt(/*numBits=*/32, /*value=*/0, /*isSigned=*/true)); + + return rewriter.create( + loc, /*output=*/TypeAttr::get(bias_type), /*value=*/bias_value); +} + // stablehlo.uniform_quantize -> tfl.quantize class RewriteUniformQuantizeOp : public OpRewritePattern { @@ -103,10 +240,11 @@ class RewriteUniformQuantizeOp LogicalResult match(stablehlo::UniformQuantizeOp op) const override { const Type input_element_type = op.getOperand().getType().cast().getElementType(); - if (!input_element_type.isa()) { - LLVM_DEBUG(llvm::dbgs() - << "Uniform quantize op's input should be a float type. Got: " - << input_element_type << ".\n"); + if (!(input_element_type.isa() || + IsI32F32UniformQuantizedType(input_element_type))) { + LLVM_DEBUG(llvm::dbgs() << "Uniform quantize op's input should be a " + "float type or int32. Got: " + << input_element_type << ".\n"); return failure(); } @@ -775,7 +913,7 @@ class RewriteUpstreamQuantizedDotGeneralOpToBatchMatmulOp // * Does not consider bias add fusion. // // TODO: b/294983811 - Merge this pattern into -// `RewriteUpstreamQuantizedDotGeneralOpToBatchMatmulOp`. +// `RewriteFullIntegerQuantizedDotGeneralOp`. // TODO: b/295264927 - `stablehlo.dot_general` with per-axis quantized operands // is not specified in the StableHLO dialect. Update the spec to allow this. 
class RewriteUpstreamQuantizedDotGeneralOpToTflFullyConnectedOp @@ -826,15 +964,17 @@ class RewriteUpstreamQuantizedDotGeneralOpToTflFullyConnectedOp cast(op.getOperand(1).getDefiningOp()); TFL::QConstOp new_filter_constant_op = - CreateTflConstOpForFilter(filter_constant_op, rewriter); + CreateTflConstOpForFilter(filter_constant_op, rewriter, + /*is_per_axis=*/true); const Value input_value = op.getOperand(0); const double input_scale = input_value.getType() .cast() .getElementType() .cast() .getScale(); - TFL::QConstOp bias_constant_op = CreateTflConstOpForBias( - op.getLoc(), input_scale, new_filter_constant_op, rewriter); + TFL::QConstOp bias_constant_op = CreateTflConstOpForDummyBias( + op.getLoc(), input_scale, new_filter_constant_op, rewriter, + /*is_per_axis=*/true); const Value result_value = op.getResult(); // Set to `nullptr` because this attribute only matters when the input is @@ -921,106 +1061,208 @@ class RewriteUpstreamQuantizedDotGeneralOpToTflFullyConnectedOp return success(); } +}; - // Creates a new `tfl.qconst` op for the quantized filter. Transposes the - // filter value from [i, o] -> [o, i]. This is because we assume `[i, o]` - // format for `stablehlo.dot_general` (i.e. contracting dimension == 1) - // whereas `tfl.fully_connected` accepts an OI format. - TFL::QConstOp CreateTflConstOpForFilter( - stablehlo::ConstantOp filter_constant_op, - PatternRewriter& rewriter) const { - const auto filter_values = filter_constant_op.getValue() - .cast() - .getValues(); +// Rewrites `stablehlo.dot_general` to `tfl.fully_connected` or +// `tfl.batch_matmul` when it accepts uniform quantized tensors. +// +// Conditions for `tfl.fully_connected` conversion: +// * Input and output tensors are per-tensor uniform quantized (i8->f32) +// tensors. +// * The filter tensor is constant a per-tensor uniform quantized (i8->f32) +// tensor. The quantization dimension should be 1 (the non-contracting +// dimension). +// * The input tensor's rank is either 2 or 3. The last dimension of the input +// tensor should be the contracting dimension, i.e. [..., c_x, r_x]. +// * The filter tensor's rank is 2. The contracting dimension should be the +// first dimension (dim 0), i.e. [c_y, r_y] where c_y == r_x. +// * Does not consider activation fusion. +// * Does not consider bias add fusion. +// TODO: b/580909703 - Include conversion conditions for `tfl.batch_matmul` op. +// +// TODO: b/295264927 - `stablehlo.dot_general` with per-axis quantized operands +// is not specified in the StableHLO dialect. Update the spec to allow this. +class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - ArrayRef filter_shape = - filter_constant_op.getType().cast().getShape(); + public: + LogicalResult match(stablehlo::DotGeneralOp op) const override { + const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = + op.getDotDimensionNumbers(); + if (const int num_rhs_contracting_dims = + dot_dimension_nums.getRhsContractingDimensions().size(); + num_rhs_contracting_dims != 1) { + LLVM_DEBUG(llvm::dbgs() + << "Expected number of contracting dimensions to be 1. Got: " + << num_rhs_contracting_dims << ".\n"); + return failure(); + } - // Reverse the shapes. This makes sense because it assumes that the filter - // tensor has rank of 2 (no batch dimension). 
- SmallVector new_filter_shape(filter_shape.rbegin(), - filter_shape.rend()); + if (failed(MatchInput(op.getOperand(0)))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match input for quantized dot_general op.\n"); + return failure(); + } - // Construct the value array of transposed filter. Assumes 2D matrix. - SmallVector new_filter_values(filter_values.size(), /*Value=*/0); - for (int i = 0; i < filter_shape[0]; ++i) { - for (int j = 0; j < filter_shape[1]; ++j) { - const int old_idx = i * filter_shape[1] + j; - const int new_idx = j * filter_shape[0] + i; - new_filter_values[new_idx] = filter_values[old_idx]; - } + if (failed(MatchFilter(op.getOperand(1)))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match filter for quantized dot_general op.\n"); + return failure(); } - auto new_filter_value_attr_type = RankedTensorType::getChecked( - filter_constant_op.getLoc(), new_filter_shape, - /*elementType=*/rewriter.getI8Type()); + if (failed(MatchOutput(op.getResult()))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match output for quantized dot_general op.\n"); + return failure(); + } - auto filter_quantized_type = filter_constant_op.getResult() - .getType() - .cast() - .getElementType() - .cast(); + if (failed(MatchUsers(op.getResult()))) { + LLVM_DEBUG(llvm::dbgs() << "Failed to match subsequent requantize for " + "quantized dot_general op.\n"); + return failure(); + } - auto new_filter_quantized_type = UniformQuantizedPerAxisType::getChecked( - filter_constant_op.getLoc(), /*flags=*/true, - /*storageType=*/filter_quantized_type.getStorageType(), - /*expressedType=*/filter_quantized_type.getExpressedType(), - /*scales=*/filter_quantized_type.getScales(), - /*zeroPoints=*/filter_quantized_type.getZeroPoints(), - /*quantizedDimension=*/0, /*storageTypeMin=*/llvm::minIntN(8), - /*storageTypeMax=*/llvm::maxIntN(8)); + return success(); + } - // Required because the quantized dimension is changed from 3 -> 0. - auto new_filter_result_type = RankedTensorType::getChecked( - filter_constant_op.getLoc(), /*shape=*/new_filter_shape, - /*type=*/new_filter_quantized_type); + void rewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const override { + // Create the new filter constant - transpose filter value + // from [i, o] -> [o, i]. This is because we assume `[i, o]` format for + // `stablehlo.dot_general` (i.e. contracting dimension == 1) whereas + // `tfl.fully_connected` accepts an OI format. + auto filter_constant_op = + cast(op.getOperand(1).getDefiningOp()); - auto new_filter_constant_value_attr = DenseIntElementsAttr::get( - new_filter_value_attr_type, new_filter_values); - return rewriter.create( - filter_constant_op.getLoc(), - /*output=*/TypeAttr::get(new_filter_result_type), - /*value=*/new_filter_constant_value_attr); + TFL::QConstOp new_filter_constant_op = CreateTflConstOpForFilter( + filter_constant_op, rewriter, /*is_per_axis=*/false); + const Value input_value = op.getOperand(0); + const double input_scale = input_value.getType() + .cast() + .getElementType() + .cast() + .getScale(); + TFL::QConstOp bias_constant_op = CreateTflConstOpForDummyBias( + op.getLoc(), input_scale, new_filter_constant_op, rewriter, + /*is_per_axis=*/false); + + auto output_op = op.getResult().getDefiningOp(); + Operation* requantize_op = *output_op->getResult(0).getUsers().begin(); + Operation* dequantize_op = *requantize_op->getResult(0).getUsers().begin(); + + // Set to `nullptr` because this attribute only matters when the input is + // dynamic-range quantized. 
+ const BoolAttr asymmetric_quantize_inputs = nullptr; + auto tfl_fully_connected_op = rewriter.create( + op.getLoc(), + /*output=*/ + requantize_op->getResult(0).getType(), // result_value.getType(), + /*input=*/input_value, /*filter=*/new_filter_constant_op.getResult(), + /*bias=*/bias_constant_op.getResult(), + /*fused_activation_function=*/rewriter.getStringAttr("NONE"), + /*weights_format=*/rewriter.getStringAttr("DEFAULT"), + /*keep_num_dims=*/rewriter.getBoolAttr(false), + asymmetric_quantize_inputs); + + auto tfl_dequantize_op = rewriter.create( + op.getLoc(), dequantize_op->getResult(0).getType(), + tfl_fully_connected_op->getResult(0)); + + rewriter.replaceAllUsesWith(dequantize_op->getResult(0), + tfl_dequantize_op->getResult(0)); + + rewriter.replaceAllUsesWith(op.getResult(), + tfl_fully_connected_op.getResult(0)); + + rewriter.eraseOp(op); } - // Creates a new `tfl.qconst` op for the bias. The bias values are 0s, because - // this bias a dummy bias (note that bias fusion is not considered for this - // transformation). The quantization scale for the bias is input scale * - // filter scale. `filter_const_op` is used to retrieve the filter scales and - // the size of the bias constant. - TFL::QConstOp CreateTflConstOpForBias(const Location loc, - const double input_scale, - TFL::QConstOp filter_const_op, - PatternRewriter& rewriter) const { - const ArrayRef filter_shape = - filter_const_op.getResult().getType().getShape(); - const auto filter_quantized_element_type = - filter_const_op.getResult() - .getType() - .getElementType() - .cast(); + private: + static LogicalResult MatchInput(Value input) { + auto input_type = input.getType().cast(); + if (!input_type.hasRank() || + !(input_type.getRank() == 2 || input_type.getRank() == 3)) { + LLVM_DEBUG(llvm::dbgs() << "Input expected to have rank of 2 or 3. Got: " + << input_type << ".\n"); + return failure(); + } - // The storage type is i32 for bias, which is the precision used for - // accumulation. - auto bias_quantized_type = UniformQuantizedPerAxisType::getChecked( - loc, /*flags=*/true, /*storageType=*/rewriter.getI32Type(), - /*expressedType=*/rewriter.getF32Type(), /*scales=*/ - GetBiasScales(input_scale, filter_quantized_element_type.getScales()), - /*zeroPoints=*/filter_quantized_element_type.getZeroPoints(), - /*quantizedDimension=*/0, /*storageTypeMin=*/llvm::minIntN(8), - /*storageTypeMax=*/llvm::maxIntN(8)); + if (const auto input_element_type = input_type.getElementType(); + !IsI8F32UniformQuantizedType(input_element_type)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected an i8->f32 uniform quantized type. Got: " + << input_element_type << ".\n"); + return failure(); + } - SmallVector bias_shape = {filter_shape[0]}; - auto bias_type = - RankedTensorType::getChecked(loc, bias_shape, bias_quantized_type); + return success(); + } - auto bias_value_type = RankedTensorType::getChecked( - loc, std::move(bias_shape), rewriter.getI32Type()); - auto bias_value = DenseIntElementsAttr::get( - bias_value_type, APInt(/*numBits=*/32, /*value=*/0, /*isSigned=*/true)); + static LogicalResult MatchFilter(Value filter) { + auto filter_type = filter.getType().cast(); + if (!filter_type.hasRank() || filter_type.getRank() != 2) { + LLVM_DEBUG(llvm::dbgs() + << "Filter tensor expected to have a tensor rank of 2. 
Got: " + << filter_type << ".\n"); + return failure(); + } + + const Type filter_element_type = filter_type.getElementType(); + if (!IsI8F32UniformQuantizedType(filter_element_type)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized (i8->f32) type. Got: " + << filter_element_type << "\n"); + return failure(); + } + + if (Operation* filter_op = filter.getDefiningOp(); + filter_op == nullptr || !isa(filter_op)) { + LLVM_DEBUG(llvm::dbgs() << "Filter should be a constant.\n"); + return failure(); + } + + return success(); + } + + static LogicalResult MatchOutput(Value output) { + const Type output_element_type = + output.getType().cast().getElementType(); + if (!IsI32F32UniformQuantizedType(output_element_type)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized (i32->f32) type. Got: " + << output_element_type << ".\n"); + return failure(); + } + return success(); + } - return rewriter.create( - loc, /*output=*/TypeAttr::get(bias_type), /*value=*/bias_value); + static LogicalResult MatchUsers(Value output) { + auto output_op = output.getDefiningOp(); + + if (!output_op->hasOneUse()) { + LLVM_DEBUG(llvm::dbgs() << "Expected output to be used only once.\n"); + return failure(); + } + // TODO: b/309896242 - Add support for fused op case. + if (Operation* requantize_op = dyn_cast_or_null( + *output_op->getResult(0).getUsers().begin())) { + const Type requantize_element_type = requantize_op->getResult(0) + .getType() + .cast() + .getElementType(); + if (!IsI8F32UniformQuantizedType(requantize_element_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected a quantize (i8->f32) type. Got: " + << requantize_element_type << ".\n"); + return failure(); + } + if (!isa( + *requantize_op->getResult(0).getUsers().begin())) { + LLVM_DEBUG(llvm::dbgs() << "Expected a dequantize type.\n"); + return failure(); + } + } + return success(); } }; @@ -1032,7 +1274,9 @@ void UniformQuantizedStablehloToTflPass::runOnOperation() { patterns.add(&ctx); + RewriteUpstreamQuantizedDotGeneralOpToTflFullyConnectedOp, + RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp>( + &ctx); if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { func_op.emitError() << "Failed to convert stablehlo ops with uniform " diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc index 5c1d362110799b..f8064220786442 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc @@ -151,5 +151,19 @@ bool IsI32F32UniformQuantizedType(const Type type) { return true; } +// Determines whether the storage type of a quantized type is supported by +// `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. +bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type) { + if (storage_type.getWidth() == 8 || + (storage_type.isSigned() && storage_type.getWidth() == 16)) { + return true; + } + LLVM_DEBUG(llvm::dbgs() + << "Uniform quantize / dequantize op only supports ui8, i8 or " + "i16 for the storage type of uniform quantized type. 
Got: " + << storage_type << ".\n"); + return false; +} + } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h index d938c3a235343a..c422439e8472dc 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project @@ -73,6 +74,10 @@ bool IsI8F32UniformQuantizedPerAxisType(Type type); // 32-bit integer and expressed type is f32. bool IsI32F32UniformQuantizedType(Type type); +// Determines whether the storage type of a quantized type is supported by +// `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. +bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type); + } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc index ec91997fb9dc14..43b78f505564fb 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc @@ -349,6 +349,43 @@ TEST_F(IsI32F32UniformQuantizedTypeTest, IsExpressedTypeF32) { EXPECT_TRUE(IsExpressedTypeF32(qi32_per_axis_type)); } +class IsSupportedByTfliteQuantizeOrDequantizeOpsTest : public ::testing::Test { + protected: + IsSupportedByTfliteQuantizeOrDequantizeOpsTest() { + ctx_.loadDialect(); + } + + MLIRContext ctx_; + OpBuilder builder_{&ctx_}; +}; + +TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, IsI8) { + auto qi8_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getIntegerType(8, /*isSigned=*/true), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsSupportedByTfliteQuantizeOrDequantizeOps( + dyn_cast_or_null(qi8_type.getStorageType()))); +} + +TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, IsI16) { + auto qi16_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getIntegerType(16, /*isSigned=*/true), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsSupportedByTfliteQuantizeOrDequantizeOps( + dyn_cast_or_null(qi16_type.getStorageType()))); +} + +TEST_F(IsSupportedByTfliteQuantizeOrDequantizeOpsTest, IsUI8) { + auto qi8_type = quant::UniformQuantizedType::get( + /*flags=*/0, builder_.getIntegerType(8, /*isSigned=*/false), + builder_.getF32Type(), /*scale=*/1.0, + /*zeroPoint=*/0, /*storageTypeMin=*/0, /*storageTypeMax=*/255); + EXPECT_TRUE(IsSupportedByTfliteQuantizeOrDequantizeOps( + dyn_cast_or_null(qi8_type.getStorageType()))); +} + } // namespace } // namespace quant } // namespace mlir From 742d37c562f328f6340c92be4075dd5de2f4b235 Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Thu, 16 Nov 2023 18:34:54 -0800 Subject: [PATCH 208/391] [xla:gpu] Use xla_gpu_enable_command_buffer flag to control command buffer scheduling pass #6528 PiperOrigin-RevId: 583229032 --- 
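[Reviewer note, not part of the patch] With this change, whether a fusion is treated as a
command is driven by the repeated DebugOptions field xla_gpu_enable_command_buffer instead
of being hard-coded in IsCommand(). Assuming the standard protobuf-generated mutators and
the HloModule/HloModuleConfig accessors already used in the diff below, enabling FUSION
commands for a module would look roughly like this sketch:

  // Hedged C++ sketch; only the accessors that appear in the diff are taken as given.
  DebugOptions debug_options = module->config().debug_options();
  debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
  module->mutable_config().set_debug_options(debug_options);
  // CommandBufferScheduling builds its is_command predicate from this repeated field,
  // so fusions are outlined into command buffers only when FUSION is listed.

End users can presumably request the same behavior through the debug flag, e.g.
XLA_FLAGS=--xla_gpu_enable_command_buffer=FUSION (assumed flag spelling).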
.../service/gpu/command_buffer_scheduling.cc | 35 +++++++++++++------ .../service/gpu/command_buffer_scheduling.h | 4 ++- .../gpu/command_buffer_scheduling_test.cc | 5 ++- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc index 7cf4d4038bd9fa..8c6036c70feae7 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/command_buffer_scheduling.h" #include +#include #include #include @@ -44,11 +45,6 @@ namespace { // category. // 2. Intermediates: Instructions that produce intermediate values that are // used by commands. -bool IsCommand(const HloInstruction* inst) { - // TODO(anlunx): Add support for conditionals and while loops. - return inst->opcode() == HloOpcode::kFusion; -} - bool IsIntermediate(const HloInstruction* inst) { switch (inst->opcode()) { case HloOpcode::kConstant: @@ -79,7 +75,8 @@ constexpr int kMinNumCommands = 2; // subsequences that will be extracted as command buffers. std::vector CommandBufferScheduling::CollectCommandBufferSequences( - const HloInstructionSequence inst_sequence) { + const HloInstructionSequence inst_sequence, + std::function is_command) { struct Accumulator { std::vector sequences; HloInstructionSequence current_seq; @@ -96,10 +93,10 @@ CommandBufferScheduling::CollectCommandBufferSequences( return acc; }; - auto process_instruction = [&start_new_sequence]( + auto process_instruction = [&start_new_sequence, &is_command]( Accumulator* acc, HloInstruction* inst) -> Accumulator* { - if (IsCommand(inst)) { + if (is_command(inst)) { acc->current_seq.push_back(inst); acc->num_commands_in_current_seq += 1; return acc; @@ -239,8 +236,26 @@ StatusOr CommandBufferScheduling::Run( } HloComputation* entry = module->entry_computation(); MoveParametersToFront(entry); - std::vector sequences = - CollectCommandBufferSequences(module->schedule().sequence(entry)); + + absl::flat_hash_set command_types; + for (auto cmd_type_num : + module->config().debug_options().xla_gpu_enable_command_buffer()) { + DebugOptions::CommandBufferCmdType cmd_type = + static_cast(cmd_type_num); + command_types.insert(cmd_type); + } + + std::function is_command = + [&command_types = + std::as_const(command_types)](const HloInstruction* inst) { + if (inst->opcode() == HloOpcode::kFusion) { + if (command_types.contains(DebugOptions::FUSION)) return true; + } + return false; + }; + + std::vector sequences = CollectCommandBufferSequences( + module->schedule().sequence(entry), is_command); for (const HloInstructionSequence& seq : sequences) { TF_ASSIGN_OR_RETURN(BuildCommandBufferResult result, diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h b/third_party/xla/xla/service/gpu/command_buffer_scheduling.h index 601e72c9f984bf..ad0844a207c3c7 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h +++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling.h @@ -16,6 +16,7 @@ limitations under the License. 
#define XLA_SERVICE_GPU_COMMAND_BUFFER_SCHEDULING_H_ #include +#include #include #include @@ -79,7 +80,8 @@ class CommandBufferScheduling : public HloModulePass { const absl::flat_hash_set& execution_threads) override; static std::vector CollectCommandBufferSequences( - HloInstructionSequence inst_sequence); + HloInstructionSequence inst_sequence, + std::function is_command); static void MoveParametersToFront(HloComputation* computation); struct BuildCommandBufferResult { diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc index 2a556b6f0ef2bf..aa63b7e40d2c25 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc @@ -222,7 +222,10 @@ TEST_F(CommandBufferSchedulingTest, CollectCommandBufferSequence) { EXPECT_EQ(seq.size(), 10); std::vector command_buffer_sequences = - CommandBufferScheduling::CollectCommandBufferSequences(seq); + CommandBufferScheduling::CollectCommandBufferSequences( + seq, [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kFusion; + }); EXPECT_EQ(command_buffer_sequences.size(), 2); std::vector seq_0 = From 094532fc413cc8128eb29faa392c10fa163a9a29 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 16 Nov 2023 19:28:26 -0800 Subject: [PATCH 209/391] Catch all python errors in custom call sharding and convert to FATAL error rather than leaving the error uncaught. PiperOrigin-RevId: 583237551 --- .../xla/xla/python/custom_call_sharding.cc | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/python/custom_call_sharding.cc b/third_party/xla/xla/python/custom_call_sharding.cc index 318164a87375a6..9b5e5e1dfef810 100644 --- a/third_party/xla/xla/python/custom_call_sharding.cc +++ b/third_party/xla/xla/python/custom_call_sharding.cc @@ -186,27 +186,37 @@ class PyCustomCallPartitioner : public CustomCallPartitioner { const HloInstruction* instruction, const HloInstruction* user, const HloSharding& sharding) const override { py::gil_scoped_acquire gil; - // TODO(parkers): expand this API to handle the `user` sharding. - // The user is used when the custom call returns a Tuple and - // the user is a get-tuple-element. In this case we must update only - // part of the sharding spec. - auto result = py::cast(prop_user_sharding_( - sharding, instruction->shape(), - py::bytes(instruction->raw_backend_config_string()))); - return result; + try { + // TODO(parkers): expand this API to handle the `user` sharding. + // The user is used when the custom call returns a Tuple and + // the user is a get-tuple-element. In this case we must update only + // part of the sharding spec. 
+ auto result = py::cast(prop_user_sharding_( + sharding, instruction->shape(), + py::bytes(instruction->raw_backend_config_string()))); + return result; + } catch (const pybind11::error_already_set& e) { + LOG(FATAL) << absl::StrFormat("custom_partitioner: %s", e.what()); + } } std::optional InferShardingFromOperands( const HloInstruction* instruction) const override { + std::optional result; std::vector arg_shapes = GetArgShapes(instruction); auto arg_shardings = GetArgShardings(instruction); py::gil_scoped_acquire gil; - auto py_result = infer_sharding_from_operands_( - arg_shapes, arg_shardings, instruction->shape(), - py::bytes(instruction->raw_backend_config_string())); - if (py_result.is_none()) { - return std::nullopt; + try { + auto py_result = infer_sharding_from_operands_( + arg_shapes, arg_shardings, instruction->shape(), + py::bytes(instruction->raw_backend_config_string())); + if (py_result.is_none()) { + return std::nullopt; + } + return py::cast(py_result); + } catch (const pybind11::error_already_set& e) { + LOG(FATAL) << absl::StrFormat("custom_partitioner: %s", e.what()); } - return py::cast(py_result); + return result; } bool IsCustomCallShardable(const HloInstruction* instruction) const override { return true; From 4d494341920578a081ba4a0d6f6c5757d1cac82f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 16 Nov 2023 19:52:23 -0800 Subject: [PATCH 210/391] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/2e40c24ac028946d3c22289c1e8d7c8dc31204f2. PiperOrigin-RevId: 583241393 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 8fecb534923045..19b22547b3342a 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "a2f7558daa1c026305ba9148bce0d34ed0a83a6e" - TFRT_SHA256 = "6c4d2c1e9835ec186dcd813cfec4d68537d7dbefab42bc0fabf2c0813c0b64e0" + TFRT_COMMIT = "2e40c24ac028946d3c22289c1e8d7c8dc31204f2" + TFRT_SHA256 = "3afbe32abb11e76509f5b0313c0a83ff3be7ea5c39760e74982be7e6d94f9fd1" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 8fecb534923045..19b22547b3342a 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. 
- TFRT_COMMIT = "a2f7558daa1c026305ba9148bce0d34ed0a83a6e" - TFRT_SHA256 = "6c4d2c1e9835ec186dcd813cfec4d68537d7dbefab42bc0fabf2c0813c0b64e0" + TFRT_COMMIT = "2e40c24ac028946d3c22289c1e8d7c8dc31204f2" + TFRT_SHA256 = "3afbe32abb11e76509f5b0313c0a83ff3be7ea5c39760e74982be7e6d94f9fd1" tf_http_archive( name = "tf_runtime", From 557e90f50e96917b0287373c22e869c215ac4d3f Mon Sep 17 00:00:00 2001 From: Yaning Liang Date: Thu, 16 Nov 2023 21:00:36 -0800 Subject: [PATCH 211/391] [xla:gpu] Use xla_gpu_enable_command_buffer flag to control command buffer scheduling pass #6528 PiperOrigin-RevId: 583254995 --- .../service/gpu/command_buffer_scheduling.cc | 35 ++++++------------- .../service/gpu/command_buffer_scheduling.h | 4 +-- .../gpu/command_buffer_scheduling_test.cc | 5 +-- 3 files changed, 12 insertions(+), 32 deletions(-) diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc index 8c6036c70feae7..7cf4d4038bd9fa 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/gpu/command_buffer_scheduling.h" #include -#include #include #include @@ -45,6 +44,11 @@ namespace { // category. // 2. Intermediates: Instructions that produce intermediate values that are // used by commands. +bool IsCommand(const HloInstruction* inst) { + // TODO(anlunx): Add support for conditionals and while loops. + return inst->opcode() == HloOpcode::kFusion; +} + bool IsIntermediate(const HloInstruction* inst) { switch (inst->opcode()) { case HloOpcode::kConstant: @@ -75,8 +79,7 @@ constexpr int kMinNumCommands = 2; // subsequences that will be extracted as command buffers. 
std::vector CommandBufferScheduling::CollectCommandBufferSequences( - const HloInstructionSequence inst_sequence, - std::function is_command) { + const HloInstructionSequence inst_sequence) { struct Accumulator { std::vector sequences; HloInstructionSequence current_seq; @@ -93,10 +96,10 @@ CommandBufferScheduling::CollectCommandBufferSequences( return acc; }; - auto process_instruction = [&start_new_sequence, &is_command]( + auto process_instruction = [&start_new_sequence]( Accumulator* acc, HloInstruction* inst) -> Accumulator* { - if (is_command(inst)) { + if (IsCommand(inst)) { acc->current_seq.push_back(inst); acc->num_commands_in_current_seq += 1; return acc; @@ -236,26 +239,8 @@ StatusOr CommandBufferScheduling::Run( } HloComputation* entry = module->entry_computation(); MoveParametersToFront(entry); - - absl::flat_hash_set command_types; - for (auto cmd_type_num : - module->config().debug_options().xla_gpu_enable_command_buffer()) { - DebugOptions::CommandBufferCmdType cmd_type = - static_cast(cmd_type_num); - command_types.insert(cmd_type); - } - - std::function is_command = - [&command_types = - std::as_const(command_types)](const HloInstruction* inst) { - if (inst->opcode() == HloOpcode::kFusion) { - if (command_types.contains(DebugOptions::FUSION)) return true; - } - return false; - }; - - std::vector sequences = CollectCommandBufferSequences( - module->schedule().sequence(entry), is_command); + std::vector sequences = + CollectCommandBufferSequences(module->schedule().sequence(entry)); for (const HloInstructionSequence& seq : sequences) { TF_ASSIGN_OR_RETURN(BuildCommandBufferResult result, diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h b/third_party/xla/xla/service/gpu/command_buffer_scheduling.h index ad0844a207c3c7..601e72c9f984bf 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h +++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling.h @@ -16,7 +16,6 @@ limitations under the License. #define XLA_SERVICE_GPU_COMMAND_BUFFER_SCHEDULING_H_ #include -#include #include #include @@ -80,8 +79,7 @@ class CommandBufferScheduling : public HloModulePass { const absl::flat_hash_set& execution_threads) override; static std::vector CollectCommandBufferSequences( - HloInstructionSequence inst_sequence, - std::function is_command); + HloInstructionSequence inst_sequence); static void MoveParametersToFront(HloComputation* computation); struct BuildCommandBufferResult { diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc index aa63b7e40d2c25..2a556b6f0ef2bf 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc @@ -222,10 +222,7 @@ TEST_F(CommandBufferSchedulingTest, CollectCommandBufferSequence) { EXPECT_EQ(seq.size(), 10); std::vector command_buffer_sequences = - CommandBufferScheduling::CollectCommandBufferSequences( - seq, [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kFusion; - }); + CommandBufferScheduling::CollectCommandBufferSequences(seq); EXPECT_EQ(command_buffer_sequences.size(), 2); std::vector seq_0 = From cd8c33bc133db2e45a8d1d5fd0b772aff1bfdfb1 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 16 Nov 2023 23:40:20 -0800 Subject: [PATCH 212/391] Integrate LLVM at llvm/llvm-project@00da5eb86ed0 Updates LLVM usage to match [00da5eb86ed0](https://github.com/llvm/llvm-project/commit/00da5eb86ed0) PiperOrigin-RevId: 583281718 --- .../bridge/convert-mhlo-quant-to-int.mlir | 13 ++- third_party/llvm/workspace.bzl | 4 +- third_party/stablehlo/temporary.patch | 51 +++++++++ .../xla/third_party/stablehlo/temporary.patch | 51 +++++++++ .../chlo/sparse_chlo_legalize_to_linalg.mlir | 25 +++-- .../Dialect/mhlo/hlo-legalize-to-linalg.mlir | 5 +- .../mhlo/hlo-legalize-to-stablehlo.mlir | 13 ++- .../Dialect/mhlo/mhlo_ops_prettyprint.mlir | 14 ++- .../Dialect/mhlo/sparse_gendot_lower.mlir | 50 +++++---- .../tests/Dialect/mhlo/sparse_lower.mlir | 106 +++++++++--------- .../tests/Dialect/mhlo/sparse_rewriting.mlir | 63 ++++++----- .../tests/Dialect/mhlo/sparse_transpose.mlir | 12 +- .../mhlo/stablehlo-legalize-to-hlo.mlir | 13 ++- .../translate/hlo_to_mhlo/tests/import.hlotxt | 7 +- 14 files changed, 284 insertions(+), 143 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir index d943e9c1b04fdb..3782030c3a314a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir @@ -95,12 +95,15 @@ func.func @uniform_quantize_and_dequantize_type_exensions(%arg0: tensor (d0 : compressed) }> + +// CHECK: #[[$SV:.*]] = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> // CHECK-LABEL: func @uniform_quantize_and_dequantize_sparse_tensor_encoding -func.func @uniform_quantize_and_dequantize_sparse_tensor_encoding(%arg0: tensor (d0 : compressed) }>>) -> () { - // CHECK: %[[QUANTIZED:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor (d0 : compressed) }>>) -> tensor (d0 : compressed) }>> - %0 = mhlo.uniform_quantize %arg0 : (tensor (d0 : compressed) }>>) -> tensor, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>> - // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] : (tensor (d0 : compressed) }>>, tensor) -> tensor (d0 : compressed) }>> - %1 = mhlo.uniform_dequantize %0 : (tensor, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>) -> tensor (d0 : compressed) }>> +func.func @uniform_quantize_and_dequantize_sparse_tensor_encoding(%arg0: tensor) -> () { + // CHECK: %[[QUANTIZED:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor + %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor, #SV> + // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] : (tensor, tensor) -> tensor + %1 = mhlo.uniform_dequantize %0 : (tensor, #SV>) -> tensor return } diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index ac8776f3fc9e7c..1694b4045c6cc5 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "46396108deb24564159c441c6e6ebfac26714d7b" - LLVM_SHA256 = "be8a1460e9e8d1eb96eae8065e5d32376f6ed721872033974c7069a35096f9b3" + LLVM_COMMIT = "00da5eb86ed0b86002b0947643f7da72faa4fd42" + LLVM_SHA256 = "fb2c08c558cb28d16be3d21ecbb600c4a481a5796c985d5b9e677d757b6021c1" tf_http_archive( name = name, diff --git a/third_party/stablehlo/temporary.patch 
b/third_party/stablehlo/temporary.patch index ba7bcb6b5de29c..c90f940606d5f1 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -3959,6 +3959,57 @@ diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/ func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> { // @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}} %0 = stablehlo.constant dense<[-1]> : tensor<1xi64> +diff --ruN a/stablehlo/stablehlo/tests/print_stablehlo.mlir b/stablehlo/stablehlo/tests/print_stablehlo.mlir +--- stablehlo/stablehlo/tests/print_stablehlo.mlir ++++ stablehlo/stablehlo/tests/print_stablehlo.mlir +@@ -1,5 +1,5 @@ +-// RUN: stablehlo-opt %s | FileCheck %s +-// RUN: stablehlo-opt %s | stablehlo-opt | FileCheck %s ++// RUN: stablehlo-opt %s --split-input-file | FileCheck %s ++// RUN: stablehlo-opt %s --split-input-file | stablehlo-opt --split-input-file | FileCheck %s + + // CHECK-LABEL: func @zero_input + func.func @zero_input() -> !stablehlo.token { +@@ -291,6 +291,8 @@ + "stablehlo.return"() : () -> () + } + ++// ----- ++ + #CSR = #sparse_tensor.encoding<{ + map = (d0, d1) -> (d0 : dense, d1 : compressed) + }> +@@ -299,14 +301,16 @@ + map = (d0, d1) -> (d0 : compressed, d1 : compressed) + }> + ++// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> ++// CHECK: #[[$DCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> + // CHECK-LABEL: func @encodings + func.func @encodings(%arg0: tensor<10x20xf32, #CSR>, + %arg1: tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> { +- // CHECK: %0 = stablehlo.add %arg0, %arg1 : (tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>, tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>>) -> tensor<10x20xf32> +- // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>> +- // CHECK-NEXT: %2 = stablehlo.abs %arg0 : (tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) -> tensor<10x20xf32> +- // CHECK-NEXT: %3 = stablehlo.abs %arg0 : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>> +- // CHECK-NEXT: %4 = stablehlo.complex %arg0, %arg0 : (tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>, tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) -> tensor<10x20xcomplex> ++ // CHECK: %0 = stablehlo.add %arg0, %arg1 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$DCSR]]>) -> tensor<10x20xf32> ++ // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<10x20xf32, #[[$DCSR]]> ++ // CHECK-NEXT: %2 = stablehlo.abs %arg0 : (tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xf32> ++ // CHECK-NEXT: %3 = stablehlo.abs %arg0 : tensor<10x20xf32, #[[$CSR]]> ++ // CHECK-NEXT: %4 = stablehlo.complex %arg0, %arg0 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xcomplex> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<10x20xf32, #CSR>, + tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> + %1 = "stablehlo.add"(%arg1, %arg1) : (tensor<10x20xf32, #DCSR>, +@@ -316,6 +320,8 @@ + %4 = "stablehlo.complex"(%arg0, %arg0) : (tensor<10x20xf32, #CSR>, tensor<10x20xf32, #CSR>) -> tensor<10x20xcomplex> + func.return %0 : 
tensor<10x20xf32> + } ++ ++// ----- + + func.func @dot_general(%arg0: tensor<2x2x2xi8>, %arg1: tensor<2x2x3xi8>, %arg2: tensor<2x2xi8>, %arg3: tensor<2x3xi8>) -> tensor<2x2x3xi32> { + // CHECK: {{%.*}} = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<2x2x2xi8>, tensor<2x2x3xi8>) -> tensor<2x2x3xi32> diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h --- stablehlo/stablehlo/transforms/Passes.h +++ stablehlo/stablehlo/transforms/Passes.h diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index f132c3c6a10ce1..ceb3e74472fc74 100644 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -3959,6 +3959,57 @@ diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/ func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> { // @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}} %0 = stablehlo.constant dense<[-1]> : tensor<1xi64> +diff --ruN a/stablehlo/stablehlo/tests/print_stablehlo.mlir b/stablehlo/stablehlo/tests/print_stablehlo.mlir +--- stablehlo/stablehlo/tests/print_stablehlo.mlir ++++ stablehlo/stablehlo/tests/print_stablehlo.mlir +@@ -1,5 +1,5 @@ +-// RUN: stablehlo-opt %s | FileCheck %s +-// RUN: stablehlo-opt %s | stablehlo-opt | FileCheck %s ++// RUN: stablehlo-opt %s --split-input-file | FileCheck %s ++// RUN: stablehlo-opt %s --split-input-file | stablehlo-opt --split-input-file | FileCheck %s + + // CHECK-LABEL: func @zero_input + func.func @zero_input() -> !stablehlo.token { +@@ -291,6 +291,8 @@ + "stablehlo.return"() : () -> () + } + ++// ----- ++ + #CSR = #sparse_tensor.encoding<{ + map = (d0, d1) -> (d0 : dense, d1 : compressed) + }> +@@ -299,14 +301,16 @@ + map = (d0, d1) -> (d0 : compressed, d1 : compressed) + }> + ++// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> ++// CHECK: #[[$DCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> + // CHECK-LABEL: func @encodings + func.func @encodings(%arg0: tensor<10x20xf32, #CSR>, + %arg1: tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> { +- // CHECK: %0 = stablehlo.add %arg0, %arg1 : (tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>, tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>>) -> tensor<10x20xf32> +- // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>> +- // CHECK-NEXT: %2 = stablehlo.abs %arg0 : (tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) -> tensor<10x20xf32> +- // CHECK-NEXT: %3 = stablehlo.abs %arg0 : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>> +- // CHECK-NEXT: %4 = stablehlo.complex %arg0, %arg0 : (tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>, tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) -> tensor<10x20xcomplex> ++ // CHECK: %0 = stablehlo.add %arg0, %arg1 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$DCSR]]>) -> tensor<10x20xf32> ++ // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<10x20xf32, #[[$DCSR]]> 
++ // CHECK-NEXT: %2 = stablehlo.abs %arg0 : (tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xf32> ++ // CHECK-NEXT: %3 = stablehlo.abs %arg0 : tensor<10x20xf32, #[[$CSR]]> ++ // CHECK-NEXT: %4 = stablehlo.complex %arg0, %arg0 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xcomplex> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<10x20xf32, #CSR>, + tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> + %1 = "stablehlo.add"(%arg1, %arg1) : (tensor<10x20xf32, #DCSR>, +@@ -316,6 +320,8 @@ + %4 = "stablehlo.complex"(%arg0, %arg0) : (tensor<10x20xf32, #CSR>, tensor<10x20xf32, #CSR>) -> tensor<10x20xcomplex> + func.return %0 : tensor<10x20xf32> + } ++ ++// ----- + + func.func @dot_general(%arg0: tensor<2x2x2xi8>, %arg1: tensor<2x2x3xi8>, %arg2: tensor<2x2xi8>, %arg3: tensor<2x3xi8>) -> tensor<2x2x3xi32> { + // CHECK: {{%.*}} = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<2x2x2xi8>, tensor<2x2x3xi8>) -> tensor<2x2x3xi32> diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h --- stablehlo/stablehlo/transforms/Passes.h +++ stablehlo/stablehlo/transforms/Passes.h diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/sparse_chlo_legalize_to_linalg.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/sparse_chlo_legalize_to_linalg.mlir index 118e043d484243..1b7602b71dd3a5 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/sparse_chlo_legalize_to_linalg.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/sparse_chlo_legalize_to_linalg.mlir @@ -4,6 +4,8 @@ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> + // CHECK-LABEL: @asinh_scalar( // CHECK-SAME: %[[ARG:.*]]: tensor) -> tensor { // CHECK: %[[RESULT:.*]] = chlo.asinh %[[ARG]] : tensor -> tensor @@ -14,13 +16,13 @@ func.func @asinh_scalar(%arg : tensor) -> tensor { } // CHECK-LABEL: @asinh_tensor( -// CHECK-SAME: %[[ARG:.*]]: tensor<10x20xf32, #{{.*}}>) -> -// CHECK-SAME: tensor<10x20xf32, #{{.*}}> { +// CHECK-SAME: %[[ARG:.*]]: tensor<10x20xf32, #[[$CSR]]>) -> +// CHECK-SAME: tensor<10x20xf32, #[[$CSR]]> { // CHECK: %[[OUT:.*]] = bufferization.alloc_tensor() : -// CHECK-SAME: tensor<10x20xf32, #{{.*}}> +// CHECK-SAME: tensor<10x20xf32, #[[$CSR]]> // CHECK: %[[VAL:.*]] = linalg.generic // CHECK-SAME: ins(%[[ARG]] : tensor<10x20xf32, -// CHECK-SAME: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) outs(%[[OUT]] +// CHECK-SAME: #sparse>) outs(%[[OUT]] // CHECK: sparse_tensor.unary %{{.*}} : f32 to f32 // CHECK: present = { // CHECK: tensor.from_elements @@ -38,13 +40,12 @@ func.func @asinh_tensor(%arg : tensor<10x20xf32, #CSR>) func.return %result : tensor<10x20xf32, #CSR> } - // CHECK-LABEL: func.func @tan_tensor( -// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<10x20xf32, -// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, +// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<10x20xf32, #[[$CSR]] +// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #[[$CSR]] // CHECK: %[[TMP_1:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} -// CHECK-SAME: ins(%[[TMP_arg0]] : tensor<10x20xf32, -// CHECK-SAME: outs(%[[TMP_0]] : tensor<10x20xf32, +// CHECK-SAME: ins(%[[TMP_arg0]] : tensor<10x20xf32, #[[$CSR]] +// CHECK-SAME: outs(%[[TMP_0]] : tensor<10x20xf32, #[[$CSR]] // CHECK: ^bb0(%[[TMP_arg1:.*]]: f32, 
%[[TMP_arg2:.*]]: f32): // CHECK: %[[TMP_2:.*]] = sparse_tensor.unary %[[TMP_arg1]] : f32 to f32 // CHECK: present = { @@ -68,10 +69,10 @@ func.func @tan_tensor(%arg : tensor<10x20xf32, #CSR>) // CHECK-LABEL: func.func @sinh_tensor( // CHECK-SAME: %[[TMP_arg0:.*]]: tensor<10x20xf32, -// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, +// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #[[$CSR]] // CHECK: %[[TMP_1:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} -// CHECK-SAME: ins(%[[TMP_arg0]] : tensor<10x20xf32, -// CHECK-SAME: outs(%[[TMP_0]] : tensor<10x20xf32, +// CHECK-SAME: ins(%[[TMP_arg0]] : tensor<10x20xf32, #[[$CSR]] +// CHECK-SAME: outs(%[[TMP_0]] : tensor<10x20xf32, #[[$CSR]] // CHECK: ^bb0(%[[TMP_arg1:.*]]: f32, %[[TMP_arg2:.*]]: f32): // CHECK: %[[TMP_2:.*]] = sparse_tensor.unary %[[TMP_arg1]] : f32 to f32 // CHECK: present = { diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir index a5deb447d2187a..7726ec47587c3a 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir @@ -6130,7 +6130,8 @@ func.func @clamp_complex(%min: tensor<8xcomplex>, } // ----- - +// CHECK: #[[$ST_3D:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }> +// CHECK: #[[$ST_4D:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed) }> // CHECK-LABEL: func @reshape_sparse_encoding // CHECK-PRIMITIVE-LABEL: func @reshape_sparse_encoding @@ -6146,4 +6147,4 @@ func.func @reshape_sparse_encoding(%arg0: tensor<1x49x16xf32, #ST_3D>) -> tensor %0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32, #ST_3D>) -> tensor<1x784x1x1xf32, #ST_4D> func.return %0 : tensor<1x784x1x1xf32, #ST_4D> } -// CHECK: tensor.reshape %{{.*}} : (tensor<1x49x16xf32, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>>, tensor<4xi64>) -> tensor<1x784x1x1xf32, #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed) }>> +// CHECK: tensor.reshape %{{.*}} : (tensor<1x49x16xf32, #[[$ST_3D]]>, tensor<4xi64>) -> tensor<1x784x1x1xf32, #[[$ST_4D]]> diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 6de9c1a64b37b7..25093f046797d1 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -1834,13 +1834,20 @@ func.func @type_quantization(%arg0: tensor>, %ar func.return %0 : tensor } +// ----- + +#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> + +// CHECK: #[[$SV:.*]] = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> // CHECK-LABEL: "type_sparsity" -func.func @type_sparsity(%arg0: tensor<16xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>) -> tensor<16xf32> { - // CHECK: "stablehlo.abs"(%arg0) : (tensor<16xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>) -> tensor<16xf32> - %0 = "mhlo.abs"(%arg0) : (tensor<16xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>) -> 
tensor<16xf32> +func.func @type_sparsity(%arg0: tensor<16xf32, #SV>) -> tensor<16xf32> { + // CHECK: "stablehlo.abs"(%arg0) : (tensor<16xf32, #[[$SV]]>) -> tensor<16xf32> + %0 = "mhlo.abs"(%arg0) : (tensor<16xf32, #SV>) -> tensor<16xf32> func.return %0 : tensor<16xf32> } +// ----- + // AsyncBundle aka !mhlo.async_bundle is unsupported at the moment (see negative test below). func.func @type_token_callee(%arg0: !mhlo.token) -> !mhlo.token { diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_ops_prettyprint.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_ops_prettyprint.mlir index b7d36ebba62c81..7aad7a84cc5155 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_ops_prettyprint.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_ops_prettyprint.mlir @@ -1,5 +1,5 @@ // RUN: mlir-hlo-opt -split-input-file %s | FileCheck %s -// RUN: mlir-hlo-opt -split-input-file %s | mlir-hlo-opt | FileCheck %s +// RUN: mlir-hlo-opt -split-input-file %s | mlir-hlo-opt -split-input-file | FileCheck %s // ----- @@ -238,14 +238,16 @@ func.func @extensions(%arg0 : tensor (d0 : compressed, d1 : compressed) }> +// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +// CHECK: #[[$DCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> // CHECK-LABEL: func @encodings func.func @encodings(%arg0: tensor<10x20xf32, #CSR>, %arg1: tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> { - // CHECK: %0 = mhlo.add %arg0, %arg1 : (tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>, tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>) -> tensor<10x20xf32> - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>> - // CHECK-NEXT: %2 = mhlo.abs %arg0 : (tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>) -> tensor<10x20xf32> - // CHECK-NEXT: %3 = mhlo.abs %arg0 : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>> - // CHECK-NEXT: %4 = mhlo.complex %arg0, %arg0 : (tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>, tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) -> tensor<10x20xcomplex> + // CHECK: %0 = mhlo.add %arg0, %arg1 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$DCSR]]>) -> tensor<10x20xf32> + // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<10x20xf32, #[[$DCSR]] + // CHECK-NEXT: %2 = mhlo.abs %arg0 : (tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xf32> + // CHECK-NEXT: %3 = mhlo.abs %arg0 : tensor<10x20xf32, #[[$CSR]]> + // CHECK-NEXT: %4 = mhlo.complex %arg0, %arg0 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xcomplex> %0 = "mhlo.add"(%arg0, %arg1) : (tensor<10x20xf32, #CSR>, tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> %1 = "mhlo.add"(%arg1, %arg1) : (tensor<10x20xf32, #DCSR>, diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_gendot_lower.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_gendot_lower.mlir index 4e3a8dd7569710..ec98a7cc1a4c68 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_gendot_lower.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_gendot_lower.mlir @@ -3,16 +3,20 @@ // RUN: --mhlo-test-lower-general-dot --canonicalize | FileCheck %s #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> -#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> +#DCSR = #sparse_tensor.encoding<{ map = (d0, 
d1) -> (d0 : compressed, d1 : compressed) }> #COO = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique), d2 : singleton) }> +// CHECK: #[[$SV:.*]] = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> +// CHECK: #[[$DCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> +// CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique), d2 : singleton) }> + // // Vector-vector gendot. // // CHECK-LABEL: func.func @sparse_vecvec( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10xf64, #sparse_tensor.encoding<{{{.*}}}>>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<10xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor { -// CHECK: %[[DOT:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) {precision_config = [#mhlo, #mhlo]} : (tensor<10xf64, #sparse_tensor.encoding<{{{.*}}}>>, tensor<10xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor +// CHECK-SAME: %[[ARG0:.*]]: tensor<10xf64, #[[$SV]]>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<10xf64, #[[$SV]]>) -> tensor { +// CHECK: %[[DOT:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) {precision_config = [#mhlo, #mhlo]} : (tensor<10xf64, #[[$SV]]>, tensor<10xf64, #[[$SV]]>) -> tensor // CHECK: return %[[DOT]] : tensor // CHECK: } // @@ -32,20 +36,20 @@ func.func @sparse_vecvec(%arg0: tensor<10xf64, #SV>, // Matrix-vector gendot. // // CHECK-LABEL: func.func @sparse_matvec( -// CHECK-SAME: %[[ARG0:.*]]: tensor<3x5xf64, #sparse_tensor.encoding<{{{.*}}}>>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<5xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<3xf64> { -// CHECK: %[[DOT:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) {precision_config = [#mhlo, #mhlo]} : (tensor<3x5xf64, #sparse_tensor.encoding<{{{.*}}}>>, tensor<5xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<3xf64> +// CHECK-SAME: %[[ARG0:.*]]: tensor<3x5xf64, #[[$DCSR]]>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<5xf64, #[[$SV]]>) -> tensor<3xf64> { +// CHECK: %[[DOT:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) {precision_config = [#mhlo, #mhlo]} : (tensor<3x5xf64, #[[$DCSR]]>, tensor<5xf64, #[[$SV]]>) -> tensor<3xf64> // CHECK: return %[[DOT]] : tensor<3xf64> // CHECK: } // -func.func @sparse_matvec(%arg0: tensor<3x5xf64, #CSR>, +func.func @sparse_matvec(%arg0: tensor<3x5xf64, #DCSR>, %arg1: tensor<5xf64, #SV>) -> tensor<3xf64> { %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} - : (tensor<3x5xf64, #CSR>, + : (tensor<3x5xf64, #DCSR>, tensor<5xf64, #SV>) -> tensor<3xf64> return %0 : tensor<3xf64> } @@ -54,20 +58,20 @@ func.func @sparse_matvec(%arg0: tensor<3x5xf64, #CSR>, // Matrix-matrix gendot, one sparse operand. 
// // CHECK-LABEL: func.func @sparse_matmat_1s( -// CHECK-SAME: %[[ARG0:.*]]: tensor<16x32xf64, #sparse_tensor.encoding<{{{.*}}}>>, +// CHECK-SAME: %[[ARG0:.*]]: tensor<16x32xf64, #[[$DCSR]]>, // CHECK-SAME: %[[ARG1:.*]]: tensor<32x64xf64>) -> tensor<16x64xf64> { -// CHECK: %[[DOT:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) {precision_config = [#mhlo, #mhlo]} : (tensor<16x32xf64, #sparse_tensor.encoding<{{{.*}}}>>, tensor<32x64xf64>) -> tensor<16x64xf64> +// CHECK: %[[DOT:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) {precision_config = [#mhlo, #mhlo]} : (tensor<16x32xf64, #[[$DCSR]]>, tensor<32x64xf64>) -> tensor<16x64xf64> // CHECK: return %[[DOT]] : tensor<16x64xf64> // CHECK: } // -func.func @sparse_matmat_1s(%arg0: tensor<16x32xf64, #CSR>, +func.func @sparse_matmat_1s(%arg0: tensor<16x32xf64, #DCSR>, %arg1: tensor<32x64xf64>) -> tensor<16x64xf64> { %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} - : (tensor<16x32xf64, #CSR>, + : (tensor<16x32xf64, #DCSR>, tensor<32x64xf64>) -> tensor<16x64xf64> return %0 : tensor<16x64xf64> } @@ -76,22 +80,22 @@ func.func @sparse_matmat_1s(%arg0: tensor<16x32xf64, #CSR>, // Matrix-matrix gendot, everything sparse. // // CHECK-LABEL: func.func @sparse_matmat_as( -// CHECK-SAME: %[[ARG0:.*]]: tensor<16x32xf64, #sparse_tensor.encoding<{{{.*}}}>>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<16x64xf64, #sparse_tensor.encoding<{{{.*}}}>> { -// CHECK: %[[DOT:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) {precision_config = [#mhlo, #mhlo]} : (tensor<16x32xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>>, tensor<32x64xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>>) -> tensor<16x64xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>> -// CHECK: return %[[DOT]] : tensor<16x64xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK-SAME: %[[ARG0:.*]]: tensor<16x32xf64, #[[$DCSR]]>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<32x64xf64, #[[$DCSR]]>) -> tensor<16x64xf64, #[[$DCSR]]> { +// CHECK: %[[DOT:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) {precision_config = [#mhlo, #mhlo]} : (tensor<16x32xf64, #[[$DCSR]]>, tensor<32x64xf64, #[[$DCSR]]>) -> tensor<16x64xf64, #[[$DCSR]]> +// CHECK: return %[[DOT]] : tensor<16x64xf64, #[[$DCSR]]> // CHECK: } // -func.func @sparse_matmat_as(%arg0: tensor<16x32xf64, #CSR>, - %arg1: tensor<32x64xf64, #CSR>) -> tensor<16x64xf64, #CSR> { +func.func @sparse_matmat_as(%arg0: tensor<16x32xf64, #DCSR>, + %arg1: tensor<32x64xf64, #DCSR>) -> tensor<16x64xf64, #DCSR> { %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} - : (tensor<16x32xf64, #CSR>, - tensor<32x64xf64, #CSR>) -> tensor<16x64xf64, #CSR> - return %0 : tensor<16x64xf64, #CSR> + : (tensor<16x32xf64, #DCSR>, + tensor<32x64xf64, #DCSR>) -> tensor<16x64xf64, #DCSR> + return %0 : tensor<16x64xf64, #DCSR> } // @@ -101,7 +105,7 @@ func.func @sparse_matmat_as(%arg0: tensor<16x32xf64, #CSR>, // // CHECK-LABEL: func.func @sparse_tensor( // CHECK-SAME: %[[ARG0:.*]]: tensor<197x12x64xf32>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<12x64x768xf32, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<197x768xf32> { +// CHECK-SAME: %[[ARG1:.*]]: tensor<12x64x768xf32, #[[$COO]]>) -> tensor<197x768xf32> { // CHECK: %[[R:.*]] = "mhlo.dot_general"(%[[ARG0]], %[[ARG1]]) // CHECK: return %[[R]] : tensor<197x768xf32> func.func @sparse_tensor(%arg0: 
tensor<197x12x64xf32>, diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_lower.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_lower.mlir index 960248ebe8267d..eb5e7d98cdf6ea 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_lower.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_lower.mlir @@ -22,15 +22,21 @@ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }> +// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +// CHECK: #[[$DCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> +// CHECK: #[[$ST:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }> +// CHECK: #[[$SV:.*]] = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> + + // CHECK-LABEL: func @sparse_abs_eltwise( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x20xf32, #{{.*}}>) -> tensor<10x20xf32, #{{.*}}> { -// CHECK: %[[OUT:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #{{.*}}> -// CHECK: %[[VAL:.*]] = linalg.generic {{{.*}} ins(%[[ARG0]] : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) outs(%[[OUT]] : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>>) +// CHECK-SAME: %[[ARG0:.*]]: tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xf32, #[[$DCSR]]> { +// CHECK: %[[OUT:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #[[$DCSR]]> +// CHECK: %[[VAL:.*]] = linalg.generic {{{.*}} ins(%[[ARG0]] : tensor<10x20xf32, #[[$CSR]]>) outs(%[[OUT]] : tensor<10x20xf32, #[[$DCSR]]>) // CHECK: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32): // CHECK: %[[ABS:.*]] = math.absf %[[A]] : f32 // CHECK: linalg.yield %[[ABS]] : f32 -// CHECK: } -> tensor<10x20xf32, #{{.*}}> -// CHECK: return %[[VAL:.*]] : tensor<10x20xf32, #{{.*}}> +// CHECK: } -> tensor<10x20xf32, #[[$DCSR]]> +// CHECK: return %[[VAL:.*]] : tensor<10x20xf32, #[[$DCSR]]> // CHECK: } func.func @sparse_abs_eltwise(%arg0: tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #DCSR> { @@ -40,15 +46,15 @@ func.func @sparse_abs_eltwise(%arg0: tensor<10x20xf32, #CSR>) } // CHECK-LABEL: func @sparse_add_eltwise( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x20xf32, #{{.*}}>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<10x20xf32, #{{.*}}>) -> tensor<10x20xf32, #{{.*}}> { -// CHECK: %[[OUT:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #{{.*}}> -// CHECK: %[[VAL:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]], %[[ARG1]] : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>, tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>>) outs(%[[OUT]] : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) { +// CHECK-SAME: %[[ARG0:.*]]: tensor<10x20xf32, #[[$CSR]]>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<10x20xf32, #[[$DCSR]]>) -> tensor<10x20xf32, #[[$CSR]]> { +// CHECK: %[[OUT:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #[[$CSR]]> +// CHECK: %[[VAL:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]], %[[ARG1]] : tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$DCSR]]>) outs(%[[OUT]] : tensor<10x20xf32, #[[$CSR]]>) { // CHECK: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32): // CHECK: %[[ADD:.*]] = arith.addf %[[A]], %[[B]] : f32 // CHECK: linalg.yield %[[ADD]] : f32 -// CHECK: } -> tensor<10x20xf32, #{{.*}}> -// CHECK: return %[[VAL:.*]] : 
tensor<10x20xf32, #{{.*}}> +// CHECK: } -> tensor<10x20xf32, #[[$CSR]]> +// CHECK: return %[[VAL:.*]] : tensor<10x20xf32, #[[$CSR]]> // CHECK: } func.func @sparse_add_eltwise(%arg0: tensor<10x20xf32, #CSR>, %arg1: tensor<10x20xf32, #DCSR>) @@ -60,15 +66,15 @@ func.func @sparse_add_eltwise(%arg0: tensor<10x20xf32, #CSR>, } // CHECK-LABEL: func @sparse_mul_eltwise( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x20xf32, #{{.*}}>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<10x20xf32, #{{.*}}>) -> tensor<10x20xf32, #{{.*}}> { -// CHECK: %[[OUT:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #{{.*}}> -// CHECK: %[[VAL:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]], %[[ARG1]] : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>, tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>>) outs(%[[OUT]] : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) { +// CHECK-SAME: %[[ARG0:.*]]: tensor<10x20xf32, #[[$CSR]]>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<10x20xf32, #[[$DCSR]]>) -> tensor<10x20xf32, #[[$CSR]]> { +// CHECK: %[[OUT:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #[[$CSR]]> +// CHECK: %[[VAL:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]], %[[ARG1]] : tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$DCSR]]>) outs(%[[OUT]] : tensor<10x20xf32, #[[$CSR]]>) { // CHECK: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32): // CHECK: %[[ADD:.*]] = arith.mulf %[[A]], %[[B]] : f32 // CHECK: linalg.yield %[[ADD]] : f32 -// CHECK: } -> tensor<10x20xf32, #{{.*}}> -// CHECK: return %[[VAL:.*]] : tensor<10x20xf32, #{{.*}}> +// CHECK: } -> tensor<10x20xf32, #[[$CSR]]> +// CHECK: return %[[VAL:.*]] : tensor<10x20xf32, #[[$CSR]]> // CHECK: } func.func @sparse_mul_eltwise(%arg0: tensor<10x20xf32, #CSR>, %arg1: tensor<10x20xf32, #DCSR>) @@ -80,20 +86,20 @@ func.func @sparse_mul_eltwise(%arg0: tensor<10x20xf32, #CSR>, } // CHECK-LABEL: func @sparse_math( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x20x30xf64, #{{.*}}>) -> tensor<10x20x30xf64, #{{.*}}> { -// CHECK: %[[T0:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]] : tensor<10x20x30xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>>) outs +// CHECK-SAME: %[[ARG0:.*]]: tensor<10x20x30xf64, #[[$ST]]>) -> tensor<10x20x30xf64, #[[$ST]]> { +// CHECK: %[[T0:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]] : tensor<10x20x30xf64, #[[$ST]]>) outs // CHECK: math.absf // CHECK: } -// CHECK: %[[T1:.*]] = linalg.generic {{{.*}}} ins(%[[T0]] : tensor<10x20x30xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>>) outs +// CHECK: %[[T1:.*]] = linalg.generic {{{.*}}} ins(%[[T0]] : tensor<10x20x30xf64, #[[$ST]]>) outs // CHECK: math.expm1 // CHECK: } -// CHECK: %[[T2:.*]] = linalg.generic {{{.*}}} ins(%[[T1]] : tensor<10x20x30xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>>) outs +// CHECK: %[[T2:.*]] = linalg.generic {{{.*}}} ins(%[[T1]] : tensor<10x20x30xf64, #[[$ST]]>) outs // CHECK: math.log1p // CHECK: } -// CHECK: %[[T3:.*]] = linalg.generic {{{.*}}} ins(%[[T2]] : tensor<10x20x30xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>>) outs +// CHECK: %[[T3:.*]] = linalg.generic {{{.*}}} ins(%[[T2]] : tensor<10x20x30xf64, #[[$ST]]>) outs // CHECK: arith.negf // CHECK: } -// CHECK: %[[T4:.*]] = linalg.generic {{{.*}}} 
ins(%[[T3]] : tensor<10x20x30xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>>) outs +// CHECK: %[[T4:.*]] = linalg.generic {{{.*}}} ins(%[[T3]] : tensor<10x20x30xf64, #[[$ST]]>) outs // CHECK: sparse_tensor.unary %{{.*}} : f64 to f64 // CHECK: present = { // CHECK: math.copysign @@ -102,22 +108,22 @@ func.func @sparse_mul_eltwise(%arg0: tensor<10x20xf32, #CSR>, // CHECK: absent = { // CHECK: } // CHECK: } -// CHECK: %[[T5:.*]] = linalg.generic {{{.*}}} ins(%[[T4]] : tensor<10x20x30xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>>) outs +// CHECK: %[[T5:.*]] = linalg.generic {{{.*}}} ins(%[[T4]] : tensor<10x20x30xf64, #[[$ST]]>) outs // CHECK: math.sin // CHECK: } -// CHECK: %[[T6:.*]] = linalg.generic {{{.*}}} ins(%[[T5]] : tensor<10x20x30xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>>) outs +// CHECK: %[[T6:.*]] = linalg.generic {{{.*}}} ins(%[[T5]] : tensor<10x20x30xf64, #[[$ST]]>) outs // CHECK: math.sqrt // CHECK: } -// CHECK: %[[T7:.*]] = linalg.generic {{{.*}}} ins(%[[T6]] : tensor<10x20x30xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>>) outs +// CHECK: %[[T7:.*]] = linalg.generic {{{.*}}} ins(%[[T6]] : tensor<10x20x30xf64, #[[$ST]]>) outs // CHECK: math.tanh // CHECK: } -// CHECK: %[[T8:.*]] = linalg.generic {{{.*}}} ins(%[[T7]] : tensor<10x20x30xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>>) outs +// CHECK: %[[T8:.*]] = linalg.generic {{{.*}}} ins(%[[T7]] : tensor<10x20x30xf64, #[[$ST]]>) outs // CHECK: math.ceil // CHECK: } -// CHECK: %[[T9:.*]] = linalg.generic {{{.*}}} ins(%[[T8]] : tensor<10x20x30xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>>) outs +// CHECK: %[[T9:.*]] = linalg.generic {{{.*}}} ins(%[[T8]] : tensor<10x20x30xf64, #[[$ST]]>) outs // CHECK: math.floor // CHECK: } -// CHECK: return %[[T9]] : tensor<10x20x30xf64, #{{.*}}> +// CHECK: return %[[T9]] : tensor<10x20x30xf64, #[[$ST]]> // CHECK: } func.func @sparse_math(%arg0: tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, #ST> { %0 = mhlo.abs %arg0 : (tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, #ST> @@ -134,8 +140,8 @@ func.func @sparse_math(%arg0: tensor<10x20x30xf64, #ST>) -> tensor<10x20x30xf64, } // CHECK-LABEL: func @sparse_sign( -// CHECK-SAME: %[[A:.*]]: tensor<100xi32, #{{.*}}>) -> tensor<100xi32> { -// CHECK: %[[T:.*]] = linalg.generic {{{.*}}} ins(%[[A]] : tensor<100xi32, #{{.*}}>) +// CHECK-SAME: %[[A:.*]]: tensor<100xi32, #[[$SV]]>) -> tensor<100xi32> { +// CHECK: %[[T:.*]] = linalg.generic {{{.*}}} ins(%[[A]] : tensor<100xi32, #[[$SV]]>) // CHECK: %[[U:.*]] = sparse_tensor.unary %{{.*}} : i32 to i32 // CHECK: present = { // CHECK: arith.cmpi eq @@ -153,8 +159,8 @@ func.func @sparse_sign(%arg0: tensor<100xi32, #SV>) -> tensor<100xi32> { } // CHECK-LABEL: func @sparse_int_abs( -// CHECK-SAME: %[[A:.*]]: tensor<100xi64, #{{.*}}>) -> tensor<100xi64> { -// CHECK: %[[T:.*]] = linalg.generic {{{.*}}} ins(%[[A]] : tensor<100xi64, #{{.*}}>) +// CHECK-SAME: %[[A:.*]]: tensor<100xi64, #[[$SV]]>) -> tensor<100xi64> { +// CHECK: %[[T:.*]] = linalg.generic {{{.*}}} ins(%[[A]] : tensor<100xi64, #[[$SV]]>) // CHECK: %[[U:.*]] = sparse_tensor.unary %{{.*}} : i64 to i64 // CHECK: present = { // CHECK: arith.cmpi sge @@ -174,8 +180,8 @@ func.func 
@sparse_int_abs(%arg0: tensor<100xi64, #SV>) -> tensor<100xi64> { } // CHECK-LABEL: func @sparse_reduce( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10xi64, #{{.*}}>) -> tensor { -// CHECK: %[[T0:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]] : tensor<10xi64, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>) +// CHECK-SAME: %[[ARG0:.*]]: tensor<10xi64, #[[$SV]]>) -> tensor { +// CHECK: %[[T0:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]] : tensor<10xi64, #[[$SV]]>) // CHECK: arith.addi // CHECK: } // CHECK: return %[[T0]] : tensor @@ -191,9 +197,9 @@ func.func @sparse_reduce(%arg0: tensor<10xi64, #SV>) -> tensor { } // CHECK-LABEL: func @sparse_dot( -// CHECK-SAME: %[[ARG0:.*]]: tensor, -// CHECK-SAME: %[[ARG1:.*]]: tensor) -> tensor { -// CHECK: %[[T0:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]], %[[ARG1]] : tensor (d0 : compressed) }>>, tensor (d0 : compressed) }>>) +// CHECK-SAME: %[[ARG0:.*]]: tensor, +// CHECK-SAME: %[[ARG1:.*]]: tensor) -> tensor { +// CHECK: %[[T0:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) // CHECK: arith.mulf // CHECK: arith.addf // CHECK: } @@ -211,12 +217,12 @@ func.func @sparse_dot(%arg0: tensor, } // CHECK-LABEL: func @sparse_transpose( -// CHECK-SAME: %[[ARG0:.*]]: tensor<100x200xf64, #{{.*}}>) -> tensor<200x100xf64, #{{.*}}> { -// CHECK: %[[T0:.*]] = bufferization.alloc_tensor() : tensor<200x100xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>> -// CHECK: %[[T1:.*]] = linalg.generic {{.*}} ins(%[[ARG0]] : tensor<100x200xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) outs(%[[T0]] : tensor<200x100xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>>) { +// CHECK-SAME: %[[ARG0:.*]]: tensor<100x200xf64, #[[$CSR]]>) -> tensor<200x100xf64, #[[$DCSR]]> { +// CHECK: %[[T0:.*]] = bufferization.alloc_tensor() : tensor<200x100xf64, #[[$DCSR]]> +// CHECK: %[[T1:.*]] = linalg.generic {{.*}} ins(%[[ARG0]] : tensor<100x200xf64, #[[$CSR]]>) outs(%[[T0]] : tensor<200x100xf64, #[[$DCSR]]>) { // CHECK: linalg.yield // CHECK: } -// CHECK: return %[[T1]] : tensor<200x100xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>> +// CHECK: return %[[T1]] : tensor<200x100xf64, #[[$DCSR]]> // CHECK: } func.func @sparse_transpose(%arg0: tensor<100x200xf64, #CSR>) -> tensor<200x100xf64, #DCSR> { @@ -226,20 +232,20 @@ func.func @sparse_transpose(%arg0: tensor<100x200xf64, #CSR>) } // CHECK-LABEL: func @sparse_expand( -// CHECK-SAME: %[[ARG0:.*]]: tensor<100xf64, #{{.*}}>) -> tensor<10x10xf64, #{{.*}}> { +// CHECK-SAME: %[[ARG0:.*]]: tensor<100xf64, #[[$SV]]>) -> tensor<10x10xf64, #[[$CSR]]> { // CHECK: %[[CST:.*]] = arith.constant dense<10> : tensor<2xi64> -// CHECK: %[[OUT:.*]] = tensor.reshape %[[ARG0]](%[[CST]]) : (tensor<100xf64, #{{.*}}>, tensor<2xi64>) -> tensor<10x10xf64, #{{.*}}> -// CHECK: return %[[OUT]] : tensor<10x10xf64, #{{.*}}> +// CHECK: %[[OUT:.*]] = tensor.reshape %[[ARG0]](%[[CST]]) : (tensor<100xf64, #[[$SV]]>, tensor<2xi64>) -> tensor<10x10xf64, #[[$CSR]]> +// CHECK: return %[[OUT]] : tensor<10x10xf64, #[[$CSR]]> func.func @sparse_expand(%arg0: tensor<100xf64, #SV>) -> tensor<10x10xf64, #CSR> { %0 = "mhlo.reshape"(%arg0) : (tensor<100xf64, #SV>) -> tensor<10x10xf64, #CSR> return %0 : tensor<10x10xf64, #CSR> } // CHECK-LABEL: func @sparse_collapse( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf64, #{{.*}}>) -> tensor<100xf64, #{{.*}}> { +// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf64, #[[$CSR]]>) -> 
tensor<100xf64, #[[$SV]]> { // CHECK: %[[CST:.*]] = arith.constant dense<100> : tensor<1xi64> -// CHECK: %[[OUT:.*]] = tensor.reshape %[[ARG0]](%[[CST]]) : (tensor<10x10xf64, #{{.*}}>, tensor<1xi64>) -> tensor<100xf64, #{{.*}}> -// CHECK: return %[[OUT]] : tensor<100xf64, #{{.*}}> +// CHECK: %[[OUT:.*]] = tensor.reshape %[[ARG0]](%[[CST]]) : (tensor<10x10xf64, #[[$CSR]]>, tensor<1xi64>) -> tensor<100xf64, #[[$SV]]> +// CHECK: return %[[OUT]] : tensor<100xf64, #[[$SV]]> func.func @sparse_collapse(%arg0: tensor<10x10xf64, #CSR>) -> tensor<100xf64, #SV> { %0 = "mhlo.reshape"(%arg0) : (tensor<10x10xf64, #CSR>) -> tensor<100xf64, #SV> return %0 : tensor<100xf64, #SV> @@ -247,12 +253,12 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #CSR>) -> tensor<100xf64, #S // CHECK-LABEL: func @sparse_tensor_dot( // CHECK-SAME: %[[ARG0:.*]]: tensor<197x12x64xf32>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<12x64x768xf32, #{{.*}}>) -> tensor<197x768xf32, #{{.*}}> { +// CHECK-SAME: %[[ARG1:.*]]: tensor<12x64x768xf32, #[[$ST]]>) -> tensor<197x768xf32, #[[$CSR]]> { // CHECK: %[[T0:.*]] = linalg.generic {{{.*}}} ins(%[[ARG0]], %[[ARG1]] : // CHECK: arith.mulf // CHECK: arith.addf // CHECK: } -// CHECK: return %[[T0]] : tensor<197x768xf32, #{{.*}}> +// CHECK: return %[[T0]] : tensor<197x768xf32, #[[$CSR]]> // CHECK: } func.func @sparse_tensor_dot(%arg0: tensor<197x12x64xf32>, %arg1: tensor<12x64x768xf32, #ST>) -> tensor<197x768xf32, #CSR> { diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_rewriting.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_rewriting.mlir index d88b9b990a880d..c0386c1de2ac94 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_rewriting.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_rewriting.mlir @@ -14,10 +14,15 @@ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> +// CHECK: #[[$SV:.*]] = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> +// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +// CHECK: #[[$DCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> + + // CHECK-LABEL: func @rewrite_unary( -// CHECK-SAME: %[[ARG0:.*]]: tensor<100xf64>) -> tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> { -// CHECK: %[[VAL:.*]] = mhlo.abs %[[ARG0]] : (tensor<100xf64>) -> tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK-NEXT: return %[[VAL:.*]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK-SAME: %[[ARG0:.*]]: tensor<100xf64>) -> tensor<100xf64, #[[$SV]]> { +// CHECK: %[[VAL:.*]] = mhlo.abs %[[ARG0]] : (tensor<100xf64>) -> tensor<100xf64, #[[$SV]]> +// CHECK-NEXT: return %[[VAL:.*]] : tensor<100xf64, #[[$SV]]> func.func @rewrite_unary(%arg0: tensor<100xf64>) -> tensor<100xf64, #SV> { %0 = mhlo.abs %arg0 : tensor<100xf64> %1 = sparse_tensor.convert %0 : tensor<100xf64> to tensor<100xf64, #SV> @@ -26,9 +31,9 @@ func.func @rewrite_unary(%arg0: tensor<100xf64>) -> tensor<100xf64, #SV> { // CHECK-LABEL: func @rewrite_binary( // CHECK-SAME: %[[ARG0:.*]]: tensor<100xf64>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> { -// CHECK: %[[VAL:.*]] = mhlo.multiply %[[ARG0]], %[[ARG1]] : (tensor<100xf64>, tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK-NEXT: return %[[VAL:.*]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK-SAME: %[[ARG1:.*]]: tensor<100xf64, #[[$SV]]>) -> tensor<100xf64, 
#[[$SV]]> { +// CHECK: %[[VAL:.*]] = mhlo.multiply %[[ARG0]], %[[ARG1]] : (tensor<100xf64>, tensor<100xf64, #[[$SV]]> +// CHECK-NEXT: return %[[VAL:.*]] : tensor<100xf64, #[[$SV]]> func.func @rewrite_binary(%arg0: tensor<100xf64>, %arg1: tensor<100xf64, #SV>) -> tensor<100xf64, #SV> { %0 = mhlo.multiply %arg0, %arg1 : (tensor<100xf64>, tensor<100xf64, #SV>) -> tensor<100xf64> @@ -37,10 +42,10 @@ func.func @rewrite_binary(%arg0: tensor<100xf64>, } // CHECK-LABEL: func @rewrite_binary_override( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> { -// CHECK: %[[VAL:.*]] = mhlo.multiply %[[ARG0]], %[[ARG1]] : (tensor<10x10xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>, tensor<10x10xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) -> tensor<10x10xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>> -// CHECK-NEXT: return %[[VAL:.*]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf64, #[[$CSR]]>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<10x10xf64, #[[$CSR]]>) -> tensor<10x10xf64, #[[$DCSR]]> { +// CHECK: %[[VAL:.*]] = mhlo.multiply %[[ARG0]], %[[ARG1]] : (tensor<10x10xf64, #[[$CSR]]>, tensor<10x10xf64, #[[$CSR]]>) -> tensor<10x10xf64, #[[$DCSR]]> +// CHECK-NEXT: return %[[VAL:.*]] : tensor<10x10xf64, #[[$DCSR]]> func.func @rewrite_binary_override(%arg0: tensor<10x10xf64, #CSR>, %arg1: tensor<10x10xf64, #CSR>) -> tensor<10x10xf64, #DCSR> { %0 = mhlo.multiply %arg0, %arg1 : (tensor<10x10xf64, #CSR>, tensor<10x10xf64, #CSR>) -> tensor<10x10xf64, #CSR> @@ -49,9 +54,9 @@ func.func @rewrite_binary_override(%arg0: tensor<10x10xf64, #CSR>, } // CHECK-LABEL: func @rewrite_convert( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf64>) -> tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> { -// CHECK: %[[VAL:.*]] = sparse_tensor.convert %[[ARG0]] : tensor<10x10xf64> to tensor<10x10xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>> -// CHECK-NEXT: return %[[VAL:.*]] : tensor<10x10xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>> +// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf64>) -> tensor<10x10xf64, #[[$CSR]]> { +// CHECK: %[[VAL:.*]] = sparse_tensor.convert %[[ARG0]] : tensor<10x10xf64> to tensor<10x10xf64, #[[$CSR]]> +// CHECK-NEXT: return %[[VAL:.*]] : tensor<10x10xf64, #[[$CSR]]> func.func @rewrite_convert(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64, #CSR> { %0 = sparse_tensor.convert %arg0 : tensor<10x10xf64> to tensor<10x10xf64, #DCSR> %1 = sparse_tensor.convert %0 : tensor<10x10xf64, #DCSR> to tensor<10x10xf64, #CSR> @@ -60,8 +65,8 @@ func.func @rewrite_convert(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64, #CSR> } // CHECK-LABEL: func @rewrite_convert_nop( -// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) -> tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK-NEXT: return %[[ARG0]] : tensor<10x10xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>> +// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf64, #[[$CSR]]>) -> tensor<10x10xf64, #[[$CSR]]> +// CHECK-NEXT: return %[[ARG0]] : tensor<10x10xf64, #[[$CSR]]> func.func @rewrite_convert_nop(%arg0: tensor<10x10xf64, #CSR>) -> tensor<10x10xf64, #CSR> { %0 = 
sparse_tensor.convert %arg0 : tensor<10x10xf64, #CSR> to tensor<10x10xf64, #DCSR> %1 = sparse_tensor.convert %0 : tensor<10x10xf64, #DCSR> to tensor<10x10xf64, #CSR> @@ -70,9 +75,9 @@ func.func @rewrite_convert_nop(%arg0: tensor<10x10xf64, #CSR>) -> tensor<10x10xf } // CHECK-LABEL: func @rewrite_transpose( -// CHECK-SAME: %[[ARG0:.*]]: tensor<100x200xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<200x100xf64, #sparse_tensor.encoding<{{{.*}}}>> { -// CHECK: %[[VAL:.*]] = "mhlo.transpose"(%[[ARG0]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<100x200xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK-NEXT: return %[[VAL:.*]] : tensor<200x100xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK-SAME: %[[ARG0:.*]]: tensor<100x200xf64, #[[$CSR]]>) -> tensor<200x100xf64, #[[$CSR]]> { +// CHECK: %[[VAL:.*]] = "mhlo.transpose"(%[[ARG0]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<100x200xf64, #[[$CSR]]> +// CHECK-NEXT: return %[[VAL:.*]] : tensor<200x100xf64, #[[$CSR]]> func.func @rewrite_transpose(%arg0: tensor<100x200xf64, #CSR>) -> tensor<200x100xf64, #CSR> { %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<100x200xf64, #CSR>) -> tensor<200x100xf64> %1 = sparse_tensor.convert %0 : tensor<200x100xf64> to tensor<200x100xf64, #CSR> @@ -80,10 +85,10 @@ func.func @rewrite_transpose(%arg0: tensor<100x200xf64, #CSR>) -> tensor<200x100 } // CHECK-LABEL: func.func @rewrite_dot( -// CHECK-SAME: %[[ARG0:.*0]]: tensor<5x5xf64, #sparse_tensor.encoding<{{{.*}}}>>, -// CHECK-SAME: %[[ARG1:.*1]]: tensor<5x5xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<5x5xf64, #sparse_tensor.encoding<{{{.*}}}>> { +// CHECK-SAME: %[[ARG0:.*0]]: tensor<5x5xf64, #[[$CSR]]>, +// CHECK-SAME: %[[ARG1:.*1]]: tensor<5x5xf64, #[[$CSR]]>) -> tensor<5x5xf64, #[[$CSR]]> { // CHECK: %[[VAL:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) -// CHECK: return %[[VAL]] : tensor<5x5xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>> +// CHECK: return %[[VAL]] : tensor<5x5xf64, #[[$CSR]]> func.func @rewrite_dot(%arg0: tensor<5x5xf64, #CSR>, %arg1: tensor<5x5xf64, #CSR>) -> tensor<5x5xf64, #CSR> { %0 = "mhlo.dot"(%arg0, %arg1) @@ -96,10 +101,10 @@ func.func @rewrite_dot(%arg0: tensor<5x5xf64, #CSR>, } // CHECK-LABEL: func.func @rewrite_general_dot( -// CHECK-SAME: %[[ARG0:.*0]]: tensor<5x5xf64, #sparse_tensor.encoding<{{{.*}}}>>, -// CHECK-SAME: %[[ARG1:.*1]]: tensor<5x5xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<5x5xf64, #sparse_tensor.encoding<{{{.*}}}>> { +// CHECK-SAME: %[[ARG0:.*0]]: tensor<5x5xf64, #[[$CSR]]>, +// CHECK-SAME: %[[ARG1:.*1]]: tensor<5x5xf64, #[[$CSR]]>) -> tensor<5x5xf64, #[[$CSR]]> { // CHECK: %[[VAL:.*]] = "mhlo.dot_general"(%[[ARG0]], %[[ARG1]]) -// CHECK: return %[[VAL]] : tensor<5x5xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>> +// CHECK: return %[[VAL]] : tensor<5x5xf64, #[[$CSR]]> func.func @rewrite_general_dot(%arg0: tensor<5x5xf64, #CSR>, %arg1: tensor<5x5xf64, #CSR>) -> tensor<5x5xf64, #CSR> { %0 = "mhlo.dot_general"(%arg0, %arg1) @@ -114,19 +119,19 @@ func.func @rewrite_general_dot(%arg0: tensor<5x5xf64, #CSR>, } // CHECK-LABEL: func.func @rewrite_elt_convert( -// CHECK-SAME: %[[ARG0:.*0]]: tensor<5x5xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<5x5xf32, #sparse_tensor.encoding<{{{.*}}}>> { +// CHECK-SAME: %[[ARG0:.*0]]: tensor<5x5xf64, #[[$CSR]]>) -> tensor<5x5xf32, #[[$CSR]]> { // CHECK: %[[VAL:.*]] = sparse_tensor.convert %[[ARG0]] -// CHECK: return %[[VAL]] : 
tensor<5x5xf32, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: return %[[VAL]] : tensor<5x5xf32, #[[$CSR]]> func.func @rewrite_elt_convert(%arg0: tensor<5x5xf64, #CSR>) -> tensor<5x5xf32, #CSR> { %0 = "mhlo.convert"(%arg0) : (tensor<5x5xf64, #CSR>) -> tensor<5x5xf32, #CSR> return %0 : tensor<5x5xf32, #CSR> } // CHECK-LABEL: func.func @concatenate_sparse( -// CHECK-SAME: %[[ARG0:.*0]]: tensor<100x100xf64, #sparse_tensor.encoding<{{{.*}}}>>, -// CHECK-SAME: %[[ARG1:.*1]]: tensor<100x100xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<200x100xf64, #sparse_tensor.encoding<{{{.*}}}>> { +// CHECK-SAME: %[[ARG0:.*0]]: tensor<100x100xf64, #[[$CSR]]>, +// CHECK-SAME: %[[ARG1:.*1]]: tensor<100x100xf64, #[[$CSR]]>) -> tensor<200x100xf64, #[[$CSR]]> { // CHECK: %[[VAL:.*]] = sparse_tensor.concatenate %[[ARG0]], %[[ARG1]] {dimension = 0 -// CHECK: return %[[VAL]] : tensor<200x100xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: return %[[VAL]] : tensor<200x100xf64, #[[$CSR]]> func.func @concatenate_sparse(%arg0: tensor<100x100xf64, #CSR>, %arg1: tensor<100x100xf64, #CSR>) -> tensor<200x100xf64, #CSR> { %0 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<100x100xf64, #CSR>, tensor<100x100xf64, #CSR>) -> tensor<200x100xf64, #CSR> return %0 : tensor<200x100xf64, #CSR> diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_transpose.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_transpose.mlir index 4d10863b90baee..ae8dc3213c7f3b 100755 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_transpose.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/sparse_transpose.mlir @@ -23,8 +23,8 @@ func.func @transpose1(%arg0: tensor<100x100xf64>) } // CHECK-LABEL: func @transpose2( -// CHECK-SAME: %[[A:.*]]: tensor<100x100xf64, #sparse_tensor.encoding<{{{.*}}}>>) -// CHECK: return %[[A]] : tensor<100x100xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK-SAME: %[[A:.*]]: tensor<100x100xf64, #sparse{{[0-9]*}}>) +// CHECK: return %[[A]] : tensor<100x100xf64, #sparse{{[0-9]*}}> func.func @transpose2(%arg0: tensor<100x100xf64, #DCSR>) -> tensor<100x100xf64, #DCSR> { %0 = "mhlo.transpose"(%arg0) @@ -34,8 +34,8 @@ func.func @transpose2(%arg0: tensor<100x100xf64, #DCSR>) } // CHECK-LABEL: func @transpose3( -// CHECK-SAME: %[[A:.*]]: tensor<100x100xf64, #sparse_tensor.encoding<{{{.*}}}>>) -// CHECK: %[[R:.*]] = mhlo.reshape %[[A]] : (tensor<100x100xf64, #sparse_tensor.encoding<{{.*}}}>>) -> tensor<100x100xf64> +// CHECK-SAME: %[[A:.*]]: tensor<100x100xf64, #sparse{{[0-9]*}}>) +// CHECK: %[[R:.*]] = mhlo.reshape %[[A]] : (tensor<100x100xf64, #sparse{{[0-9]*}}>) -> tensor<100x100xf64> // CHECK: return %[[R]] : tensor<100x100xf64> func.func @transpose3(%arg0: tensor<100x100xf64, #DCSR>) -> tensor<100x100xf64> { @@ -47,8 +47,8 @@ func.func @transpose3(%arg0: tensor<100x100xf64, #DCSR>) // CHECK-LABEL: func @transpose4( // CHECK-SAME: %[[A:.*]]: tensor<100x100xf64>) -// CHECK: %[[R:.*]] = mhlo.reshape %[[A]] : (tensor<100x100xf64>) -> tensor<100x100xf64, #sparse_tensor.encoding<{{.*}}}>> -// CHECK: return %[[R]] : tensor<100x100xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: %[[R:.*]] = mhlo.reshape %[[A]] : (tensor<100x100xf64>) -> tensor<100x100xf64, #sparse{{[0-9]*}}> +// CHECK: return %[[R]] : tensor<100x100xf64, #sparse{{[0-9]*}}> func.func @transpose4(%arg0: tensor<100x100xf64>) -> tensor<100x100xf64, #DCSR> { %0 = "mhlo.transpose"(%arg0) diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir 
b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index af8419890635ab..e9ce4e3ce68b2e 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -1838,13 +1838,20 @@ func.func @type_quantization(%arg0: tensor>, %ar func.return %0 : tensor } +// ----- + +#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> + +// CHECK: #[[$SV:.*]] = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> // CHECK-LABEL: "type_sparsity" -func.func @type_sparsity(%arg0: tensor<16xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>) -> tensor<16xf32> { - // CHECK: "mhlo.abs"(%arg0) : (tensor<16xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>) -> tensor<16xf32> - %0 = "stablehlo.abs"(%arg0) : (tensor<16xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>) -> tensor<16xf32> +func.func @type_sparsity(%arg0: tensor<16xf32, #SV>) -> tensor<16xf32> { + // CHECK: "mhlo.abs"(%arg0) : (tensor<16xf32, #[[$SV]]>) -> tensor<16xf32> + %0 = "stablehlo.abs"(%arg0) : (tensor<16xf32, #SV>) -> tensor<16xf32> func.return %0 : tensor<16xf32> } +// ----- + func.func @type_token_callee(%arg0: !stablehlo.token) -> !stablehlo.token { // CHECK: function_type = (!mhlo.token) -> !mhlo.token, sym_name = "type_token_callee" // CHECK: "func.return"(%arg0) : (!mhlo.token) -> () diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt index 8344c204ab4f89..396fdbc4c78f71 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt +++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt @@ -1,6 +1,9 @@ // RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s // RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -check-prefix=NO_DEAD_FUNCTION +// CHECK: #[[$DC:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 32, crdWidth = 32 }> +// CHECK: #[[$CSS:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, nonordered), d2 : singleton(nonordered)), posWidth = 32, crdWidth = 32 }> + // NO_DEAD_FUNCTION-NOT: @test // CHECK: module @foobar @@ -1757,13 +1760,13 @@ add { } // CHECK-LABEL : func private @sparse -// CHECK: tensor<10x10xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 32, crdWidth = 32 }>>) -> tensor<10x10xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 32, crdWidth = 32 }>> +// CHECK: tensor<10x10xf32, #[[$DC]]>) -> tensor<10x10xf32, #[[$DC]]> %sparse { ROOT root = f32[10,10]{1,0:D(D,C)} parameter(0) } // CHECK-LABEL : func private @sparse_nu_no -// CHECK: tensor<3x4x5xf32, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, nonordered), d2 : singleton(nonordered)), posWidth = 32, crdWidth = 32 }>>) -> tensor<3x4x5xf32, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, nonordered), d2 : singleton(nonordered)), posWidth = 32, crdWidth = 32 }>> +// CHECK: tensor<3x4x5xf32, #[[$CSS]]>) -> tensor<3x4x5xf32, #[[$CSS]]> %sparse_nu_no { ROOT root = f32[3,4,5]{2,1,0:D(C+,S+~,S~)} parameter(0) } From b695b20d2bf4c5705d1d849f2c26f876e0b7d1c3 Mon Sep 17 
00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 00:12:09 -0800 Subject: [PATCH 213/391] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/a3959294a297645c34a6adbdd639d7df5c84d691. PiperOrigin-RevId: 583287741 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 19b22547b3342a..4958ad4de72977 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "2e40c24ac028946d3c22289c1e8d7c8dc31204f2" - TFRT_SHA256 = "3afbe32abb11e76509f5b0313c0a83ff3be7ea5c39760e74982be7e6d94f9fd1" + TFRT_COMMIT = "a3959294a297645c34a6adbdd639d7df5c84d691" + TFRT_SHA256 = "3765b313e3da83774a50f01ba52ac155229359d0bfb170beb9714197ca64ec4a" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 19b22547b3342a..4958ad4de72977 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "2e40c24ac028946d3c22289c1e8d7c8dc31204f2" - TFRT_SHA256 = "3afbe32abb11e76509f5b0313c0a83ff3be7ea5c39760e74982be7e6d94f9fd1" + TFRT_COMMIT = "a3959294a297645c34a6adbdd639d7df5c84d691" + TFRT_SHA256 = "3765b313e3da83774a50f01ba52ac155229359d0bfb170beb9714197ca64ec4a" tf_http_archive( name = "tf_runtime", From 49192e78b6a2d64665a1a93256da9697a2a93633 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 01:02:01 -0800 Subject: [PATCH 214/391] compat: Update forward compatibility horizon to 2023-11-17 PiperOrigin-RevId: 583298085 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 2098432509a560..d2c52044b7901b 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 16) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 17) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 1f27207f1a8918696bd3751d22da03bbadafca1f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 01:02:10 -0800 Subject: [PATCH 215/391] Update GraphDef version to 1683. PiperOrigin-RevId: 583298148 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index b00916cf91259d..9c5b17b70cee85 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. 
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1682 // Updated: 2023/11/16 +#define TF_GRAPH_DEF_VERSION 1683 // Updated: 2023/11/17 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 87681cbb46a32243ec7feca93adf63ce20344d42 Mon Sep 17 00:00:00 2001 From: TJ Xu Date: Fri, 17 Nov 2023 01:45:54 -0800 Subject: [PATCH 216/391] PR #7025: [NVIDIA XLA] Add 2 new attributes to hlo instructions to specify the operation queue to run current instruction. Imported from GitHub PR https://github.com/openxla/xla/pull/7025 In CUDA term, operation queue id will be the cuda stream ID. This is to support upcoming cases such as collective matmul where we may need to run compute kernels on multiple streams. operation_queue_id means the queue(stream) to run this instruction. wait_on_operation_queues means the queues(streams) to wait for input data, this signals the backend to insert EventWait primitives. Copybara import of the project: -- 9059ecfe88ca18c4c23abd5af6bf5ceaf650249e by TJ : Add 2 new attributes to hlo instructions to specify the operation queue to run current instruction. In CUDA term, this will be the cuda stream ID. Merging this change closes #7025 PiperOrigin-RevId: 583307710 --- third_party/xla/xla/hlo/ir/hlo_instruction.cc | 17 +++++ third_party/xla/xla/hlo/ir/hlo_instruction.h | 26 +++++++ third_party/xla/xla/service/hlo.proto | 14 +++- .../xla/xla/service/hlo_instruction_test.cc | 75 +++++++++++++++++++ third_party/xla/xla/service/hlo_parser.cc | 15 ++++ 5 files changed, 146 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index a8b1a377d834bd..bc375c570b0909 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -3599,6 +3599,23 @@ void HloInstruction::PrintExtraAttributes( AppendCat(printer, "statistics=", StatisticsVizToString(statistics_viz_)); }); } + + if (operation_queue_id_) { + printer.Next([this](Printer* printer) { + AppendCat(printer, "operation_queue_id=", *operation_queue_id_); + }); + } + + if (wait_on_operation_queues_.size() > 0) { + printer.Next([this, &options](Printer* printer) { + printer->Append("wait_on_operation_queues={"); + AppendJoin(printer, wait_on_operation_queues_, ", ", + [&](Printer* printer, int64_t queue_id) { + printer->Append(queue_id); + }); + printer->Append("}"); + }); + } } std::vector HloInstruction::ExtraAttributesToString( diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 37101f793551f5..47ed8792b8ad21 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -1909,6 +1909,26 @@ class HloInstruction { bool is_default_config() const { return is_default_config_; } void set_default_config() { is_default_config_ = true; } + void set_operation_queue_id(int64_t operation_queue_id) { + operation_queue_id_ = operation_queue_id; + } + + const std::optional operation_queue_id() const { + return operation_queue_id_; + } + + void set_wait_on_operation_queues(std::vector& operation_queue_ids) { + wait_on_operation_queues_ = operation_queue_ids; + } + + const std::vector wait_on_operation_queues() const { + return wait_on_operation_queues_; + } + + void add_wait_on_operation_queues(int64_t operation_queue_id) { + wait_on_operation_queues_.push_back(operation_queue_id); + } + // Returns a string 
representation of a proto in the format used by // raw_backend_config_string. // @@ -2520,6 +2540,12 @@ class HloInstruction { // Intrusive flag used by HloComputation, whether this instruction has // been marked as dead. bool marked_as_dead_; + + // ID of the operation queue to run this instruction. + std::optional operation_queue_id_; + + // IDs of operation queues to await before running this instruction. + std::vector wait_on_operation_queues_; }; // Explicit instantiations in hlo_instruction.cc. diff --git a/third_party/xla/xla/service/hlo.proto b/third_party/xla/xla/service/hlo.proto index 22d46f62ff1a76..547f6bcd171758 100644 --- a/third_party/xla/xla/service/hlo.proto +++ b/third_party/xla/xla/service/hlo.proto @@ -111,7 +111,7 @@ enum CustomCallApiVersion { } // Serialization of HloInstruction. -// Next ID: 83 +// Next ID: 85 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -367,6 +367,18 @@ message HloInstructionProto { // Represents the information for tracking propagation of values within HLO // graph. xla.StatisticsViz statistics_viz = 82; + + // Specifies which operation queue the current instruction will run on. + // A backend may have multiple operation queues to run instructions + // concurrently, use this to signal the backend which queue to dispatch to. + // The backend should keep a mapping of + // operation_queue_id->actual_hardware_queue_id if runtime will create + // different IDs. + int64 operation_queue_id = 83; + + // Specifies which operation queues to await for data when running with + // multiple operation queues. + repeated int64 wait_on_operation_queues = 84; } // Serialization of HloComputation. diff --git a/third_party/xla/xla/service/hlo_instruction_test.cc b/third_party/xla/xla/service/hlo_instruction_test.cc index e5ab4db8acba4c..22f07bcfc880d2 100644 --- a/third_party/xla/xla/service/hlo_instruction_test.cc +++ b/third_party/xla/xla/service/hlo_instruction_test.cc @@ -2541,5 +2541,80 @@ TEST_F(HloInstructionTest, PrintCycle) { ASSERT_IS_OK(send_done->DropAllControlDeps()); } +TEST_F(HloInstructionTest, SetOperationQueueId) { + std::unique_ptr main_computation; + HloComputation::Builder main_builder("Entry"); + const Shape scalar_shape = ShapeUtil::MakeScalarShape(F32); + HloInstruction* param0 = main_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* param1 = main_builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + + HloInstruction* add = + main_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param0, param1)); + add->set_operation_queue_id(3); + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(main_builder.Build()); + + auto options = HloPrintOptions().set_print_metadata(false); + EXPECT_EQ(module->entry_computation()->root_instruction()->ToString(options), + "%add = f32[] add(f32[] %p0, f32[] %p1), operation_queue_id=3"); +} + +TEST_F(HloInstructionTest, SetWaitOnOperationQueues) { + std::unique_ptr main_computation; + HloComputation::Builder main_builder("Entry"); + const Shape scalar_shape = ShapeUtil::MakeScalarShape(F32); + HloInstruction* param0 = main_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* param1 = main_builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + + HloInstruction* add = + main_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param0, param1)); + 
std::vector wait_on_queues = {0, 2}; + add->set_wait_on_operation_queues(wait_on_queues); + add->add_wait_on_operation_queues(5); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(main_builder.Build()); + + auto options = HloPrintOptions().set_print_metadata(false); + EXPECT_EQ(module->entry_computation()->root_instruction()->ToString(options), + "%add = f32[] add(f32[] %p0, f32[] %p1), " + "wait_on_operation_queues={0, 2, 5}"); +} + +TEST_F(HloInstructionTest, ParseOperationQueueId) { + constexpr char kHloString[] = R"( + ENTRY main { + c0 = f32[] constant(0) + ROOT add0 = f32[] add(c0, c0), operation_queue_id=2 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + EXPECT_EQ( + module->entry_computation()->root_instruction()->operation_queue_id(), 2); +} + +TEST_F(HloInstructionTest, ParseWaitOnOperationQueues) { + constexpr char kHloString[] = R"( + ENTRY main { + c0 = f32[] constant(0) + ROOT add0 = f32[] add(c0, c0), wait_on_operation_queues={0,2} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + std::vector expected_wait_on_queue_ids = {0, 2}; + for (int64_t i = 0; i < expected_wait_on_queue_ids.size(); i++) { + EXPECT_EQ(expected_wait_on_queue_ids[i], + module->entry_computation() + ->root_instruction() + ->wait_on_operation_queues()[i]); + } +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc index 223b246b9bbe9b..c7a6433952d0ea 100644 --- a/third_party/xla/xla/service/hlo_parser.cc +++ b/third_party/xla/xla/service/hlo_parser.cc @@ -1322,6 +1322,14 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, attrs["backend_config"] = {/*required=*/false, AttrTy::kStringOrJsonDict, &backend_config}; + optional operation_queue_id; + attrs["operation_queue_id"] = {/*required=*/false, AttrTy::kInt64, + &operation_queue_id}; + + optional> wait_on_operation_queues; + attrs["wait_on_operation_queues"] = { + /*required=*/false, AttrTy::kBracedInt64List, &wait_on_operation_queues}; + std::optional maybe_shape; if (parse_shape) { maybe_shape = shape; @@ -1393,6 +1401,13 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, if (statistics_viz) { instruction->set_statistics_viz(*statistics_viz); } + if (operation_queue_id) { + instruction->set_operation_queue_id(*operation_queue_id); + } + if (wait_on_operation_queues) { + instruction->set_wait_on_operation_queues(*wait_on_operation_queues); + } + return AddInstruction(name, instruction, name_loc); } From ac012e26d4331919335d4bceb8abe22b68ed5434 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 17 Nov 2023 02:01:33 -0800 Subject: [PATCH 217/391] [xla:python] Don't assume GPU clients are PJRT-compatible. Instead explicitly check for compatibility. PiperOrigin-RevId: 583311176 --- third_party/xla/xla/python/py_client.cc | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/python/py_client.cc b/third_party/xla/xla/python/py_client.cc index 752edd05cedb3f..3f03c99d601e6e 100644 --- a/third_party/xla/xla/python/py_client.cc +++ b/third_party/xla/xla/python/py_client.cc @@ -372,15 +372,19 @@ StatusOr> PyClient::Compile( std::vector host_callbacks) { // Pass allocated device memory size to compile options for pjrt compatible // backends. 
- if ((ifrt_client_->platform_id() == xla::CudaId() || - ifrt_client_->platform_id() == xla::RocmId()) && - !pjrt_client()->devices().empty()) { - auto maybe_stats = pjrt_client()->devices()[0]->GetAllocatorStats(); - if (maybe_stats.ok() && maybe_stats->bytes_limit) { - options.executable_build_options.set_device_memory_size( - *maybe_stats->bytes_limit); + auto* pjrt_compatible_client = + llvm::dyn_cast_or_null(ifrt_client_.get()); + if (pjrt_compatible_client != nullptr) { + auto devices = pjrt_compatible_client->pjrt_client()->devices(); + if (!devices.empty()) { + auto stats = devices[0]->GetAllocatorStats(); + if (stats.ok() && stats->bytes_limit) { + options.executable_build_options.set_device_memory_size( + *stats->bytes_limit); + } } } + std::unique_ptr ifrt_loaded_executable; std::optional fingerprint; auto ifrt_compile_options = From 3f0e719bdb8d581af71f35c9c84e972e62d73b6d Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Fri, 17 Nov 2023 03:00:38 -0800 Subject: [PATCH 218/391] Remove is_scheduled from .hlo files. hlo-opt runs the HLO passes, and if we have is_scheduled=true, the scheduling pass will be skipped which currently makes the test fail. PiperOrigin-RevId: 583324486 --- third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo | 2 +- third_party/xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo index 6abb61fcc51d9b..b99bba81b2a394 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo +++ b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo @@ -1,6 +1,6 @@ // RUN: hlo-opt %s --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=%S/gpu_specs/a100.txtpb | FileCheck %s -HloModule m, is_scheduled=true +HloModule m add { a = f16[] parameter(0) diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo index 9ed67ee237db0e..b9ba42564d8227 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo +++ b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo @@ -1,6 +1,6 @@ // RUN: hlo-opt %s --platform=CUDA --stage=ptx --xla_gpu_target_config_filename=%S/gpu_specs/a100.txtpb | FileCheck %s --dump-input-filter=all -HloModule m, is_scheduled=true +HloModule m add { a = f16[] parameter(0) From 2d643457aebf8e4b155b3e749b8b9d06d423684d Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 17 Nov 2023 04:17:56 -0800 Subject: [PATCH 219/391] Priority fusion: cache HloFusionAnalyses. 
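A minimal usage sketch of the new cache (illustrative only: HloFusionAnalysisCache, Get,
Invalidate, and GpuPerformanceModelOptions::PriorityFusion are the names introduced by this
change, while device_info, producer, and consumer are assumed to be supplied by the calling
pass):

  // Sketch: consulting the cache from a fusion pass. Assumes `device_info` is a
  // stream_executor::DeviceDescription that outlives the cache, and `producer` /
  // `consumer` are HloInstruction references from the module being fused.
  xla::gpu::HloFusionAnalysisCache cache(device_info);

  // The producer/consumer analysis is computed on first use and memoized by the
  // instructions' unique IDs.
  const auto& fused = cache.Get(producer, consumer);
  if (fused && fused->GetEmitterFusionKind() ==
                   xla::gpu::HloFusionAnalysis::EmitterFusionKind::kReduction) {
    // ... use the analysis to decide whether this fusion is profitable ...
  }

  // The priority-fusion cost model can reuse the same cache.
  auto options = xla::gpu::GpuPerformanceModelOptions::PriorityFusion(&cache);

  // Once the pass rewrites `producer`, drop stale entries that mention it.
  cache.Invalidate(producer);
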
PiperOrigin-RevId: 583340116 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/hlo_fusion_analysis.cc | 123 +++++++++--------- .../xla/xla/service/gpu/hlo_fusion_analysis.h | 34 ++--- .../xla/service/gpu/kernel_mapping_scheme.h | 24 ++-- third_party/xla/xla/service/gpu/model/BUILD | 29 +++++ .../gpu/model/fusion_analysis_cache.cc | 93 +++++++++++++ .../service/gpu/model/fusion_analysis_cache.h | 69 ++++++++++ .../gpu/model/fusion_analysis_cache_test.cc | 115 ++++++++++++++++ .../gpu/model/gpu_performance_model.cc | 36 ++++- .../service/gpu/model/gpu_performance_model.h | 12 +- .../gpu/model/gpu_performance_model_test.cc | 2 +- .../xla/xla/service/gpu/priority_fusion.cc | 21 ++- .../xla/xla/service/gpu/priority_fusion.h | 9 +- .../xla/service/gpu/priority_fusion_test.cc | 8 +- 14 files changed, 466 insertions(+), 110 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc create mode 100644 third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h create mode 100644 third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 58ae615215ee3b..8445dec5c78840 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2071,6 +2071,7 @@ cc_library( "//xla/service:hlo_cost_analysis", "//xla/service:hlo_pass", "//xla/service:instruction_fusion", + "//xla/service/gpu/model:fusion_analysis_cache", "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/service/gpu/model:gpu_performance_model", "//xla/stream_executor:device_description", diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index bb2fe734a63055..c064c8c9c1770c 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -58,6 +58,35 @@ const auto kDimX = TilingScheme::DimX; const auto kLinearIndexingX = TilingScheme::LinearIndexingX; const auto kStridedIndexingX = TilingScheme::StridedIndexingX; +std::optional ComputeTransposeTilingScheme( + const std::optional& tiled_transpose) { + if (!tiled_transpose) { + return std::nullopt; + } + + constexpr int kNumRows = 4; + static_assert(WarpSize() % kNumRows == 0); + + // 3D view over the input shape. + Vector3 dims = tiled_transpose->dimensions; + Vector3 order = tiled_transpose->permutation; + + Vector3 permuted_dims = {dims[order[0]], dims[order[1]], dims[order[2]]}; + Vector3 tile_sizes{1, 1, 1}; + tile_sizes[order[2]] = WarpSize() / kNumRows; + Vector3 num_threads{1, 1, WarpSize()}; + num_threads[order[2]] = kNumRows; + + return TilingScheme( + /*permuted_dims*/ permuted_dims, + /*tile_sizes=*/tile_sizes, + /*num_threads=*/num_threads, + /*indexing_order=*/kLinearIndexingX, + /*vector_size=*/1, + /*scaling_factor=*/1, + /*tiling_dimensions=*/{order[2], 2}); +} + // Returns true if `instr` is a non-strided slice. 
bool IsSliceWithUnitStrides(const HloInstruction* instr) { auto slice = DynCast(instr); @@ -257,6 +286,28 @@ std::optional FindConsistentTransposeHero( } // namespace +HloFusionAnalysis::HloFusionAnalysis( + FusionBackendConfig fusion_backend_config, + std::vector fusion_roots, + FusionBoundaryFn fusion_boundary_fn, + std::vector fusion_arguments, + std::vector fusion_heroes, + const se::DeviceDescription* device_info, + std::optional tiled_transpose, bool has_4_bit_input, + bool has_4_bit_output) + : fusion_backend_config_(std::move(fusion_backend_config)), + fusion_roots_(std::move(fusion_roots)), + fusion_boundary_fn_(std::move(fusion_boundary_fn)), + fusion_arguments_(std::move(fusion_arguments)), + fusion_heroes_(std::move(fusion_heroes)), + device_info_(device_info), + tiled_transpose_(tiled_transpose), + has_4_bit_input_(has_4_bit_input), + has_4_bit_output_(has_4_bit_output), + reduction_codegen_info_(ComputeReductionCodegenInfo(FindHeroReduction())), + transpose_tiling_scheme_(ComputeTransposeTilingScheme(tiled_transpose_)), + loop_fusion_config_(ComputeLoopFusionConfig()) {} + // static StatusOr HloFusionAnalysis::Create( FusionBackendConfig backend_config, @@ -353,7 +404,7 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() return EmitterFusionKind::kLoop; } -StatusOr HloFusionAnalysis::GetLaunchDimensions() { +StatusOr HloFusionAnalysis::GetLaunchDimensions() const { auto emitter_fusion_kind = GetEmitterFusionKind(); switch (emitter_fusion_kind) { case EmitterFusionKind::kLoop: { @@ -403,7 +454,9 @@ StatusOr HloFusionAnalysis::GetLaunchDimensions() { } const HloInstruction* HloFusionAnalysis::FindHeroReduction() const { - CHECK(GetEmitterFusionKind() == EmitterFusionKind::kReduction); + if (GetEmitterFusionKind() != EmitterFusionKind::kReduction) { + return nullptr; + } auto roots = fusion_roots(); CHECK(!roots.empty()); // We always use the first reduce root that triggers unnested reduction @@ -418,57 +471,8 @@ const HloInstruction* HloFusionAnalysis::FindHeroReduction() const { LOG(FATAL) << "Did not find a hero reduction"; } -const ReductionCodegenInfo* HloFusionAnalysis::GetReductionCodegenInfo() { - if (reduction_codegen_info_.has_value()) { - return &reduction_codegen_info_.value(); - } - - const HloInstruction* hero_reduction = FindHeroReduction(); - - auto reduction_codegen_info = ComputeReductionCodegenInfo(hero_reduction); - reduction_codegen_info_.emplace(std::move(reduction_codegen_info)); - return &reduction_codegen_info_.value(); -} - -const TilingScheme* HloFusionAnalysis::GetTransposeTilingScheme() { - if (transpose_tiling_scheme_.has_value()) { - return &transpose_tiling_scheme_.value(); - } - - if (!tiled_transpose_) { - return nullptr; - } - - constexpr int kNumRows = 4; - static_assert(WarpSize() % kNumRows == 0); - - // 3D view over the input shape. 
- Vector3 dims = tiled_transpose_->dimensions; - Vector3 order = tiled_transpose_->permutation; - - Vector3 permuted_dims = {dims[order[0]], dims[order[1]], dims[order[2]]}; - Vector3 tile_sizes{1, 1, 1}; - tile_sizes[order[2]] = WarpSize() / kNumRows; - Vector3 num_threads{1, 1, WarpSize()}; - num_threads[order[2]] = kNumRows; - - TilingScheme tiling_scheme( - /*permuted_dims*/ permuted_dims, - /*tile_sizes=*/tile_sizes, - /*num_threads=*/num_threads, - /*indexing_order=*/kLinearIndexingX, - /*vector_size=*/1, - /*scaling_factor=*/1, - /*tiling_dimensions=*/{order[2], 2}); - transpose_tiling_scheme_.emplace(std::move(tiling_scheme)); - return &transpose_tiling_scheme_.value(); -} - -const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() { - if (loop_fusion_config_.has_value()) { - return &loop_fusion_config_.value(); - } - +std::optional +HloFusionAnalysis::ComputeLoopFusionConfig() const { int unroll_factor = 1; // Unrolling is good to read large inputs with small elements // due to vector loads, but increases the register pressure when one @@ -501,8 +505,7 @@ const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() { if (GetEmitterFusionKind() == EmitterFusionKind::kScatter) { // Only the unroll factor is used for scatter. - loop_fusion_config_.emplace(LaunchDimensionsConfig{unroll_factor}); - return &loop_fusion_config_.value(); + return LaunchDimensionsConfig{unroll_factor}; } bool row_vectorized; @@ -537,8 +540,7 @@ const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() { launch_config.row_vectorized = false; launch_config.few_waves = false; } - loop_fusion_config_.emplace(std::move(launch_config)); - return &loop_fusion_config_.value(); + return launch_config; } const Shape& HloFusionAnalysis::GetElementShape() const { @@ -809,8 +811,13 @@ int HloFusionAnalysis::CalculateVirtualThreadScalingFactorForReduction( return 1; } -ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo( +std::optional +HloFusionAnalysis::ComputeReductionCodegenInfo( const HloInstruction* hero_reduction) const { + if (!hero_reduction) { + return std::nullopt; + } + Shape input_shape = hero_reduction->operand(0)->shape(); ReductionDimensions reduction_dimensions = GetReductionKindAndContiguousComponents(*hero_reduction); diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h index c07819db2d3a15..1bec5ca650be47 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h @@ -68,19 +68,27 @@ class HloFusionAnalysis { // Determines the launch dimensions for the fusion. The fusion kind must not // be `kTriton`. - StatusOr GetLaunchDimensions(); + StatusOr GetLaunchDimensions() const; // Calculates the reduction information. Returns `nullptr` if the fusion is // not a reduction. - const ReductionCodegenInfo* GetReductionCodegenInfo(); + const ReductionCodegenInfo* GetReductionCodegenInfo() const { + return reduction_codegen_info_.has_value() ? &*reduction_codegen_info_ + : nullptr; + } // Calculates the transpose tiling information. Returns `nullptr` if the // fusion is not a transpose. - const TilingScheme* GetTransposeTilingScheme(); + const TilingScheme* GetTransposeTilingScheme() const { + return transpose_tiling_scheme_.has_value() ? &*transpose_tiling_scheme_ + : nullptr; + } // Calculates the loop fusion config. Returns `nullptr` if the fusion is not a // loop. 
- const LaunchDimensionsConfig* GetLoopFusionConfig(); + const LaunchDimensionsConfig* GetLoopFusionConfig() const { + return loop_fusion_config_.has_value() ? &*loop_fusion_config_ : nullptr; + } // Returns the hero reduction of the computation. const HloInstruction* FindHeroReduction() const; @@ -93,16 +101,7 @@ class HloFusionAnalysis { std::vector fusion_heroes, const se::DeviceDescription* device_info, std::optional tiled_transpose, - bool has_4_bit_input, bool has_4_bit_output) - : fusion_backend_config_(std::move(fusion_backend_config)), - fusion_roots_(std::move(fusion_roots)), - fusion_boundary_fn_(std::move(fusion_boundary_fn)), - fusion_arguments_(std::move(fusion_arguments)), - fusion_heroes_(std::move(fusion_heroes)), - device_info_(device_info), - tiled_transpose_(tiled_transpose), - has_4_bit_input_(has_4_bit_input), - has_4_bit_output_(has_4_bit_output) {} + bool has_4_bit_input, bool has_4_bit_output); const Shape& GetElementShape() const; int SmallestInputDtypeBits() const; @@ -118,8 +117,9 @@ class HloFusionAnalysis { bool reduction_is_race_free) const; int CalculateVirtualThreadScalingFactorForReduction( const ReductionDimensions& reduction_dimensions) const; - ReductionCodegenInfo ComputeReductionCodegenInfo( + std::optional ComputeReductionCodegenInfo( const HloInstruction* hero_reduction) const; + std::optional ComputeLoopFusionConfig() const; bool HasConsistentTransposeHeros() const; FusionBackendConfig fusion_backend_config_; @@ -131,8 +131,8 @@ class HloFusionAnalysis { std::vector fusion_heroes_; const se::DeviceDescription* device_info_; std::optional tiled_transpose_; - const bool has_4_bit_input_ = false; - const bool has_4_bit_output_ = false; + bool has_4_bit_input_ = false; + bool has_4_bit_output_ = false; std::optional reduction_codegen_info_; std::optional transpose_tiling_scheme_; diff --git a/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h b/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h index 4a6f0f7ae3c6fa..f7b51c42c6beaf 100644 --- a/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h +++ b/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h @@ -146,34 +146,34 @@ class TilingScheme { private: // The number of elements in each dimension. - const Vector3 dims_in_elems_; + Vector3 dims_in_elems_; // The number of elements for each dimension of a tile. - const Vector3 tile_sizes_; + Vector3 tile_sizes_; // The dimensions which are used for the shared memory tile. - const Vector2 tiling_dimensions_; + Vector2 tiling_dimensions_; // Number of threads implicitly assigned to each dimension. - const Vector3 num_threads_; + Vector3 num_threads_; - const IndexingOrder indexing_order_; + IndexingOrder indexing_order_; // Vector size for dimension X. - const int vector_size_; + int vector_size_; // Scaling apply to transform physical threadIdx into logical. 
- const int64_t thread_id_virtual_scaling_ = 1; + int64_t thread_id_virtual_scaling_ = 1; }; class ReductionCodegenInfo { public: using IndexGroups = std::vector>; - explicit ReductionCodegenInfo(TilingScheme mapping_scheme, - int num_partial_results, bool is_row_reduction, - bool is_race_free, IndexGroups index_groups, - const HloInstruction* first_reduce) + ReductionCodegenInfo(TilingScheme mapping_scheme, int num_partial_results, + bool is_row_reduction, bool is_race_free, + IndexGroups index_groups, + const HloInstruction* first_reduce) : tiling_scheme_(mapping_scheme), num_partial_results_(num_partial_results), is_row_reduction_(is_row_reduction), @@ -198,7 +198,7 @@ class ReductionCodegenInfo { private: friend class ReductionCodegenState; - const TilingScheme tiling_scheme_; + TilingScheme tiling_scheme_; int num_partial_results_; bool is_row_reduction_; bool is_race_free_; diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 1b542e7dca447a..034f2a2f20d2ff 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -66,6 +66,34 @@ xla_test( ], ) +cc_library( + name = "fusion_analysis_cache", + srcs = ["fusion_analysis_cache.cc"], + hdrs = ["fusion_analysis_cache.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/synchronization", + ], +) + +xla_cc_test( + name = "fusion_analysis_cache_test", + srcs = ["fusion_analysis_cache_test.cc"], + deps = [ + ":fusion_analysis_cache", + "//xla/service:hlo_parser", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "gpu_cost_model_stats_collection", srcs = ["gpu_cost_model_stats_collection.cc"], @@ -149,6 +177,7 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), visibility = ["//visibility:public"], deps = [ + ":fusion_analysis_cache", ":gpu_hlo_cost_analysis", "//xla:shape_util", "//xla/hlo/ir:hlo", diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc new file mode 100644 index 00000000000000..00a294413506ac --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc @@ -0,0 +1,93 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/model/fusion_analysis_cache.h" + +#include "xla/hlo/ir/hlo_instruction.h" + +namespace xla::gpu { + +const std::optional& HloFusionAnalysisCache::Get( + const HloInstruction& instruction) { + { + absl::ReaderMutexLock lock(&mutex_); + auto it = analyses_.find(instruction.unique_id()); + if (it != analyses_.end()) { + return it->second; + } + } + + std::optional analysis = + AnalyzeFusion(instruction, device_info_); + absl::MutexLock lock(&mutex_); + + // If some other thread created an entry for this key concurrently, return + // that instead (the other thread is likely using the instance). + auto it = analyses_.find(instruction.unique_id()); + if (it != analyses_.end()) { + return it->second; + } + + return analyses_[instruction.unique_id()] = std::move(analysis); +} + +const std::optional& HloFusionAnalysisCache::Get( + const HloInstruction& producer, const HloInstruction& consumer) { + std::pair key{producer.unique_id(), consumer.unique_id()}; + { + absl::ReaderMutexLock lock(&mutex_); + auto it = producer_consumer_analyses_.find(key); + if (it != producer_consumer_analyses_.end()) { + return it->second; + } + } + + std::optional analysis = + AnalyzeProducerConsumerFusion(producer, consumer, device_info_); + absl::MutexLock lock(&mutex_); + + // If some other thread created an entry for this key concurrently, return + // that instead (the other thread is likely using the instance). + auto it = producer_consumer_analyses_.find(key); + if (it != producer_consumer_analyses_.end()) { + return it->second; + } + + producers_for_consumers_[consumer.unique_id()].push_back( + producer.unique_id()); + consumers_for_producers_[producer.unique_id()].push_back( + consumer.unique_id()); + return producer_consumer_analyses_[key] = std::move(analysis); +} + +void HloFusionAnalysisCache::Invalidate(const HloInstruction& instruction) { + absl::MutexLock lock(&mutex_); + analyses_.erase(instruction.unique_id()); + + if (auto consumers = + consumers_for_producers_.extract(instruction.unique_id())) { + for (const auto consumer : consumers.mapped()) { + producer_consumer_analyses_.erase({instruction.unique_id(), consumer}); + } + } + if (auto producers = + producers_for_consumers_.extract(instruction.unique_id())) { + for (const auto producer : producers.mapped()) { + producer_consumer_analyses_.erase({producer, instruction.unique_id()}); + } + } +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h new file mode 100644 index 00000000000000..b13c0a102f3704 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h @@ -0,0 +1,69 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ +#define XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ + +#include "absl/container/node_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/stream_executor/device_description.h" + +namespace xla::gpu { + +// Caches HloFusionAnalyses. Thread-compatible, if no threads concurrently `Get` +// and `Invalidate` the same key. Analyses are cached based on unique_ids, no +// checking or tracking of changes is done. +class HloFusionAnalysisCache { + public: + explicit HloFusionAnalysisCache( + const stream_executor::DeviceDescription& device_info) + : device_info_(device_info) {} + + // Returns the analysis for the given instruction, creating it if it doesn't + // exist yet. Do not call concurrently with `Invalidate` for the same key. + const std::optional& Get( + const HloInstruction& instruction); + + // Returns the analysis for the given producer/consumer pair. + const std::optional& Get(const HloInstruction& producer, + const HloInstruction& consumer); + + // Removes the cache entry for the given instruction, if it exists. Also + // removes all producer-consumer fusions that involve this instruction. + void Invalidate(const HloInstruction& instruction); + + private: + const stream_executor::DeviceDescription& device_info_; + + absl::Mutex mutex_; + +// All `int` keys and values here are unique instruction IDs. + absl::node_hash_map> analyses_; + absl::node_hash_map, std::optional> + producer_consumer_analyses_; + + // For each instruction `producer`, contains the `consumer`s for which we have + // entries {`producer`, `consumer`} in `producer_consumer_analyses_`. + absl::flat_hash_map> consumers_for_producers_; + // For each instruction `consumer`, contains the `producer`s for which we have + // entries {`producer`, `consumer`} in `producer_consumer_analyses_`. + absl::flat_hash_map> producers_for_consumers_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc new file mode 100644 index 00000000000000..edacd6a7c8666b --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc @@ -0,0 +1,115 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/gpu/model/fusion_analysis_cache.h" + +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/hlo_parser.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla::gpu { +namespace { + +class FusionAnalysisCacheTest : public HloTestBase { + public: + stream_executor::DeviceDescription device_{ + TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + HloFusionAnalysisCache cache_{device_}; +}; + +TEST_F(FusionAnalysisCacheTest, CachesAndInvalidates) { + absl::string_view hlo_string = R"( + HloModule m + + f { + c0 = f32[] constant(0) + b0 = f32[1000] broadcast(c0) + ROOT n0 = f32[1000] negate(b0) + } + + ENTRY e { + ROOT r.1 = f32[1000] fusion(), kind=kLoop, calls=f + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto* computation = module->GetComputationWithName("f"); + auto* broadcast = computation->GetInstructionWithName("b0"); + auto* negate = computation->GetInstructionWithName("n0"); + auto* fusion = module->entry_computation()->root_instruction(); + + EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), + ::testing::ElementsAre(negate)); + + computation->set_root_instruction(broadcast); + + EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), + ::testing::ElementsAre(negate)) + << "Analysis should be cached."; + + cache_.Invalidate(*fusion); + EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), + ::testing::ElementsAre(broadcast)) + << "Analysis should have been recomputed"; +} + +TEST_F(FusionAnalysisCacheTest, CachesAndInvalidatesProducerConsumerFusions) { + absl::string_view hlo_string = R"( + HloModule m + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + f { + c0 = f32[] constant(0) + b0 = f32[1000] broadcast(c0) + ROOT r0 = f32[] reduce(b0, c0), dimensions={0}, to_apply=add + } + + ENTRY e { + f0 = f32[] fusion(), kind=kInput, calls=f + ROOT n0 = f32[] negate(f0) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto* fusion = module->entry_computation()->GetInstructionWithName("f0"); + auto* neg = module->entry_computation()->GetInstructionWithName("n0"); + + auto* computation = module->GetComputationWithName("f"); + auto* constant = computation->GetInstructionWithName("c0"); + + EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), + HloFusionAnalysis::EmitterFusionKind::kReduction); + + computation->set_root_instruction(constant); + + EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), + HloFusionAnalysis::EmitterFusionKind::kReduction) + << "Analysis should be cached."; + + cache_.Invalidate(*fusion); + EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), + HloFusionAnalysis::EmitterFusionKind::kLoop) + << "Analysis should have been recomputed"; +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index 64e9132d5ce507..c9acce7f65c3ca 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -224,7 +224,7 @@ float GetMaxSysBwFromGpu(const se::CudaComputeCapability cc, // that the IR emitter will use. 
 LaunchDimensions EstimateFusionLaunchDimensions(
     int64_t estimated_num_threads,
-    std::optional<HloFusionAnalysis>& fusion_analysis,
+    const std::optional<HloFusionAnalysis>& fusion_analysis,
     const se::DeviceDescription& device_info) {
   if (fusion_analysis) {
     // TODO(jreiffers): This is the wrong place for this DUS analysis.
@@ -269,7 +269,15 @@ GpuPerformanceModel::EstimateRunTimeForInstruction(
   int64_t bytes_written = cost_analysis->output_bytes_accessed(*instr);
   int64_t bytes_read = cost_analysis->bytes_accessed(*instr) - bytes_written;
 
-  auto fusion_analysis = AnalyzeFusion(*instr, *cost_analysis->device_info_);
+  // Use the analysis cache if present.
+  // TODO(jreiffers): Remove this once all callers use a cache.
+  std::optional<HloFusionAnalysis> local_analysis =
+      config.fusion_analysis_cache
+          ? std::nullopt
+          : AnalyzeFusion(*instr, *cost_analysis->device_info_);
+  const auto& fusion_analysis = config.fusion_analysis_cache
+                                    ? config.fusion_analysis_cache->Get(*instr)
+                                    : local_analysis;
   LaunchDimensions launch_dimensions = EstimateFusionLaunchDimensions(
       ShapeUtil::ElementsInRecursive(instr->shape()), fusion_analysis,
       *device_info);
@@ -341,7 +349,7 @@ float GetCommonUtilization(
     const GpuHloCostAnalysis* cost_analysis,
     const se::DeviceDescription& gpu_device_info, int64_t num_blocks,
     const HloInstruction* producer,
-    std::optional<HloFusionAnalysis>& fusion_analysis,
+    const std::optional<HloFusionAnalysis>& fusion_analysis,
     const GpuPerformanceModelOptions& config,
     const HloInstruction* fused_consumer) {
   absl::Duration ret = absl::ZeroDuration();
@@ -430,7 +438,16 @@ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime(
     float utilization_by_this_consumer = cost_analysis->operand_utilization(
        *fused_consumer, fused_consumer->operand_index(producer));
 
-    auto analysis_unfused = AnalyzeFusion(*fused_consumer, *device_info);
+    // Use the analysis cache if present.
+    // TODO(jreiffers): Remove this once all callers use a cache.
+    std::optional<HloFusionAnalysis> local_analysis =
+        config.fusion_analysis_cache
+            ? std::nullopt
+            : AnalyzeFusion(*fused_consumer, *device_info);
+    const auto& analysis_unfused =
+        config.fusion_analysis_cache
+            ? config.fusion_analysis_cache->Get(*fused_consumer)
+            : local_analysis;
 
     LaunchDimensions launch_dimensions_unfused = EstimateFusionLaunchDimensions(
         ShapeUtil::ElementsInRecursive(fused_consumer->shape()),
@@ -479,8 +496,15 @@ absl::Duration GpuPerformanceModel::EstimateFusedExecTime(
     //
     // TODO(shyshkov): Add calculations for consumer epilogue in the formula to
     // make it complete.
-    auto analysis_fused =
-        AnalyzeProducerConsumerFusion(*producer, *fused_consumer, *device_info);
+    std::optional<HloFusionAnalysis> local_analysis_fused =
+        config.fusion_analysis_cache
+            ? std::nullopt
+            : AnalyzeProducerConsumerFusion(*producer, *fused_consumer,
+                                            *device_info);
+    const auto& analysis_fused =
+        config.fusion_analysis_cache
+            ? config.fusion_analysis_cache->Get(*producer, *fused_consumer)
+            : local_analysis_fused;
 
     LaunchDimensions launch_dimensions_fused = EstimateFusionLaunchDimensions(
         producer_data.num_threads * utilization_by_this_consumer,
diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h
index b7b28fff1eeda7..0fcc8cfcb2abf2 100644
--- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h
+++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "absl/time/time.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/stream_executor/device_description.h" @@ -62,20 +63,25 @@ struct GpuPerformanceModelOptions { // re-reads can happen from cache. bool first_read_from_dram = false; + // If present, use this to retrieve fusion analyses. + HloFusionAnalysisCache* fusion_analysis_cache = nullptr; + static GpuPerformanceModelOptions Default() { return GpuPerformanceModelOptions(); } - static GpuPerformanceModelOptions PriorityFusion() { + static GpuPerformanceModelOptions PriorityFusion( + HloFusionAnalysisCache* fusion_analysis_cache) { GpuPerformanceModelOptions config; config.consider_coalescing = true; config.first_read_from_dram = true; + config.fusion_analysis_cache = fusion_analysis_cache; return config; } static GpuPerformanceModelOptions ForModule(const HloModule* module) { return module->config().debug_options().xla_gpu_enable_priority_fusion() - ? PriorityFusion() + ? PriorityFusion(nullptr) // Only cache within priority fusion. : Default(); } }; @@ -121,7 +127,7 @@ class GpuPerformanceModel { const GpuHloCostAnalysis* cost_analysis, const se::DeviceDescription& gpu_device_info, int64_t num_blocks, const HloInstruction* producer, - std::optional& fusion_analysis, + const std::optional& fusion_analysis, const GpuPerformanceModelOptions& config, const HloInstruction* fused_consumer = nullptr); }; diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc index 68bde4b9010382..d768bce08c55ef 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc @@ -360,7 +360,7 @@ ENTRY fusion { std::vector consumers{ module->entry_computation()->GetInstructionWithName("reduce.1")}; GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - producer, &analysis_, GpuPerformanceModelOptions::PriorityFusion(), + producer, &analysis_, GpuPerformanceModelOptions::PriorityFusion(nullptr), consumers); EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 105, 10); diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 8df808c1bfad82..1c2e7c93970e30 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -42,6 +42,7 @@ limitations under the License. 
#include "xla/service/fusion_queue.h" #include "xla/service/gpu/fusion_process_dump.pb.h" #include "xla/service/gpu/gpu_fusible.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model.h" #include "xla/service/instruction_fusion.h" @@ -79,12 +80,14 @@ class GpuPriorityFusionQueue : public FusionQueue { const GpuHloCostAnalysis::Options& cost_analysis_options, const se::DeviceDescription* device_info, const CanFuseCallback& can_fuse, FusionProcessDumpProto* fusion_process_dump, - tsl::thread::ThreadPool* thread_pool) + tsl::thread::ThreadPool* thread_pool, + HloFusionAnalysisCache& fusion_analysis_cache) : computation_(computation), cost_analysis_(cost_analysis_options, device_info), can_fuse_(can_fuse), fusion_process_dump_(fusion_process_dump), - thread_pool_(thread_pool) { + thread_pool_(thread_pool), + fusion_analysis_cache_(fusion_analysis_cache) { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); TF_CHECK_OK(computation_->Accept(&cost_analysis_)); @@ -181,6 +184,9 @@ class GpuPriorityFusionQueue : public FusionQueue { fusion_step->set_consumer_name(std::string(original_consumer->name())); } + fusion_analysis_cache_.Invalidate(*fusion); + fusion_analysis_cache_.Invalidate(*original_producer); + // The original consumer was replaced with the fusion, but it's pointer can // still be referenced somewhere, for example, in to_update_priority_. // Priority recomputation is called before DCE. Remove all references to @@ -258,6 +264,7 @@ class GpuPriorityFusionQueue : public FusionQueue { void RemoveInstruction(HloInstruction* instruction) override { to_update_priority_.erase(instruction); producer_user_count_.erase(instruction); + fusion_analysis_cache_.Invalidate(*instruction); auto reverse_it = reverse_map_.find(instruction); if (reverse_it == reverse_map_.end()) { @@ -289,7 +296,8 @@ class GpuPriorityFusionQueue : public FusionQueue { GpuPerformanceModel::RunTimes run_times = GpuPerformanceModel::EstimateRunTimes( producer, &cost_analysis_, - GpuPerformanceModelOptions::PriorityFusion(), producer->users()); + GpuPerformanceModelOptions::PriorityFusion(&fusion_analysis_cache_), + producer->users()); if (fusion_process_dump_) { absl::MutexLock lock(&fusion_process_dump_mutex_); auto* step = @@ -365,6 +373,8 @@ class GpuPriorityFusionQueue : public FusionQueue { absl::Mutex fusion_process_dump_mutex_; tsl::thread::ThreadPool* thread_pool_; + + HloFusionAnalysisCache& fusion_analysis_cache_; }; } // namespace @@ -502,8 +512,7 @@ HloInstruction::FusionKind GpuPriorityFusion::ChooseKind( // matter but some passes downstream still query these instead of fusion // analysis. // TODO: Don't recompute this all the time. 
-  auto analysis =
-      AnalyzeProducerConsumerFusion(*producer, *consumer, device_info_);
+  const auto& analysis = fusion_analysis_cache_.Get(*producer, *consumer);
   if (!analysis) return HloInstruction::FusionKind::kLoop;
   switch (analysis->GetEmitterFusionKind()) {
     case HloFusionAnalysis::EmitterFusionKind::kLoop:
@@ -544,7 +553,7 @@ std::unique_ptr<FusionQueue> GpuPriorityFusion::GetFusionQueue(
       [this](HloInstruction* consumer, int64_t operand_index) {
         return ShouldFuse(consumer, operand_index);
       },
-      fusion_process_dump_.get(), thread_pool_));
+      fusion_process_dump_.get(), thread_pool_, fusion_analysis_cache_));
 }
 
 }  // namespace gpu
diff --git a/third_party/xla/xla/service/gpu/priority_fusion.h b/third_party/xla/xla/service/gpu/priority_fusion.h
index 1723766d7784c8..afc5e8f99003d4 100644
--- a/third_party/xla/xla/service/gpu/priority_fusion.h
+++ b/third_party/xla/xla/service/gpu/priority_fusion.h
@@ -29,6 +29,7 @@ limitations under the License.
 #include "xla/service/fusion_node_indexing_evaluation.h"
 #include "xla/service/fusion_queue.h"
 #include "xla/service/gpu/fusion_process_dump.pb.h"
+#include "xla/service/gpu/model/fusion_analysis_cache.h"
 #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
 #include "xla/service/hlo_cost_analysis.h"
 #include "xla/service/hlo_pass_interface.h"
@@ -42,12 +43,13 @@ namespace gpu {
 class GpuPriorityFusion : public InstructionFusion {
  public:
   GpuPriorityFusion(tsl::thread::ThreadPool* thread_pool,
-                    const se::DeviceDescription& d,
+                    const se::DeviceDescription& device,
                     GpuHloCostAnalysis::Options cost_analysis_options)
       : InstructionFusion(GpuPriorityFusion::IsExpensive),
         thread_pool_(thread_pool),
-        device_info_(d),
-        cost_analysis_options_(std::move(cost_analysis_options)) {}
+        device_info_(device),
+        cost_analysis_options_(std::move(cost_analysis_options)),
+        fusion_analysis_cache_(device_info_) {}
 
   absl::string_view name() const override { return "priority-fusion"; }
 
@@ -86,6 +88,7 @@ class GpuPriorityFusion : public InstructionFusion {
   absl::Mutex fusion_node_evaluations_mutex_;
   absl::flat_hash_map<const HloInstruction*, FusionNodeIndexingEvaluation>
       fusion_node_evaluations_;
+  HloFusionAnalysisCache fusion_analysis_cache_;
 };
 
 }  // namespace gpu
diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/priority_fusion_test.cc
index 5574173a75ef11..310340af91d391 100644
--- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc
+++ b/third_party/xla/xla/service/gpu/priority_fusion_test.cc
@@ -237,10 +237,10 @@ TEST_F(PriorityFusionTest, ReductionEpilogueFusionRegressionTest) {
   }
 )";
 
-  EXPECT_THAT(
-      RunAndGetFusionKinds(kHlo),
-      ::testing::ElementsAre(HloFusionAnalysis::EmitterFusionKind::kLoop,
-                             HloFusionAnalysis::EmitterFusionKind::kReduction));
+  EXPECT_THAT(RunAndGetFusionKinds(kHlo),
+              ::testing::UnorderedElementsAre(
+                  HloFusionAnalysis::EmitterFusionKind::kLoop,
+                  HloFusionAnalysis::EmitterFusionKind::kReduction));
 
   RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"(
 CHECK: ENTRY
From 54cfd9082c78351f7615598e931b0fd84fcd2d67 Mon Sep 17 00:00:00 2001
From: "A.
Unique TensorFlower" Date: Fri, 17 Nov 2023 04:40:43 -0800 Subject: [PATCH 220/391] Integrate LLVM at llvm/llvm-project@de176d8c5496 Updates LLVM usage to match [de176d8c5496](https://github.com/llvm/llvm-project/commit/de176d8c5496) PiperOrigin-RevId: 583344920 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 1694b4045c6cc5..1cd80656fe497f 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "00da5eb86ed0b86002b0947643f7da72faa4fd42" - LLVM_SHA256 = "fb2c08c558cb28d16be3d21ecbb600c4a481a5796c985d5b9e677d757b6021c1" + LLVM_COMMIT = "de176d8c5496d6cf20e82aface98e102c593dbe2" + LLVM_SHA256 = "83239b51d91f9b07d110f66ddea740f028efb61b1bdcf0d0cd0f53ec859a000d" tf_http_archive( name = name, From d10d70bd76bbad6db89b4f8a53e4bbedcedb7e62 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 05:02:26 -0800 Subject: [PATCH 221/391] Import openai/triton from GitHub. PiperOrigin-RevId: 583348989 --- third_party/triton/cl577369732.patch | 116 -- third_party/triton/cl580481414.patch | 1512 ----------------- third_party/triton/cl580550344.patch | 304 ---- third_party/triton/cl580852372.patch | 15 - third_party/triton/workspace.bzl | 7 +- .../xla/third_party/triton/cl577369732.patch | 116 -- .../xla/third_party/triton/cl580481414.patch | 1512 ----------------- .../xla/third_party/triton/cl580550344.patch | 304 ---- .../xla/third_party/triton/cl580852372.patch | 15 - .../xla/third_party/triton/workspace.bzl | 7 +- 10 files changed, 4 insertions(+), 3904 deletions(-) delete mode 100644 third_party/triton/cl577369732.patch delete mode 100644 third_party/triton/cl580481414.patch delete mode 100644 third_party/triton/cl580550344.patch delete mode 100644 third_party/triton/cl580852372.patch delete mode 100644 third_party/xla/third_party/triton/cl577369732.patch delete mode 100644 third_party/xla/third_party/triton/cl580481414.patch delete mode 100644 third_party/xla/third_party/triton/cl580550344.patch delete mode 100644 third_party/xla/third_party/triton/cl580852372.patch diff --git a/third_party/triton/cl577369732.patch b/third_party/triton/cl577369732.patch deleted file mode 100644 index e63b9f3804974b..00000000000000 --- a/third_party/triton/cl577369732.patch +++ /dev/null @@ -1,116 +0,0 @@ -==== triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp#19 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp 2023-10-19 14:55:11.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -759,7 +759,7 @@ - OpBuilder builder(forOp); - // Get init operands for loop carried values - for (BlockArgument &arg : forOp.getRegionIterArgs()) { -- OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); -+ OpOperand &operand = *forOp.getTiedLoopInit(arg); - setValueMapping(arg, operand.get(), 0); - } - -==== triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp#10 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp 2023-10-19 
14:55:11.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -188,7 +188,7 @@ - auto getIncomingOp = [this](Value v) -> Value { - if (auto arg = v.dyn_cast()) - if (arg.getOwner()->getParentOp() == forOp.getOperation()) -- return forOp.getOpOperandForRegionIterArg(arg).get(); -+ return forOp.getTiedLoopInit(arg)->get(); - return Value(); - }; - -@@ -298,10 +298,10 @@ - Operation *firstDot = builder.clone(*dot, mapping); - if (Value a = operand2headPrefetch.lookup(dot.getA())) - firstDot->setOperand( -- 0, newForOp.getRegionIterArgForOpOperand(*a.use_begin())); -+ 0, newForOp.getTiedLoopRegionIterArg(&*a.use_begin())); - if (Value b = operand2headPrefetch.lookup(dot.getB())) - firstDot->setOperand( -- 1, newForOp.getRegionIterArgForOpOperand(*b.use_begin())); -+ 1, newForOp.getTiedLoopRegionIterArg(&*b.use_begin())); - - // remaining part - int64_t kOff = prefetchWidth; -==== triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp#18 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp 2023-10-24 18:31:01.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -245,7 +245,7 @@ - for (OpOperand &use : value.getUses()) { - Operation *user = use.getOwner(); - if (auto forOp = dyn_cast(user)) { -- Value arg = forOp.getRegionIterArgForOpOperand(use); -+ Value arg = forOp.getTiedLoopRegionIterArg(&use); - Value result = forOp.getResultForOpOperand(use); - setEncoding({arg, result}, info, changed, user); - continue; -@@ -767,7 +767,7 @@ - SmallVector newOperands; - for (auto arg : forOp.getRegionIterArgs()) { - if (slice.count(arg)) { -- OpOperand &initVal = forOp.getOpOperandForRegionIterArg(arg); -+ OpOperand &initVal = *forOp.getTiedLoopInit(arg); - argMapping.push_back(std::make_pair( - forOp.getResultForOpOperand(initVal).getResultNumber(), - forOp.getInitArgs().size() + newOperands.size())); -==== triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp#16 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2023-10-24 18:31:01.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -430,10 +430,10 @@ - Block *block = blockArg.getOwner(); - Operation *parentOp = block->getParentOp(); - if (auto forOp = dyn_cast(parentOp)) { -- OpOperand &initOperand = forOp.getOpOperandForRegionIterArg(blockArg); -+ OpOperand *initOperand = forOp.getTiedLoopInit(blockArg); - Value yieldOperand = forOp.getBody()->getTerminator()->getOperand( - blockArg.getArgNumber() - forOp.getNumInductionVars()); -- queue.push_back({initOperand.get(), encoding}); -+ queue.push_back({initOperand->get(), encoding}); - queue.push_back({yieldOperand, encoding}); - continue; - } -==== triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp#1 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp 2023-10-12 01:35:16.000000000 -0700 -+++ 
triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -88,9 +88,8 @@ - auto parentOp = blockArg.getOwner()->getParentOp(); - if (auto forOp = dyn_cast(parentOp)) { - if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { -- if (failed(getDependentPointers( -- forOp.getOpOperandForRegionIterArg(blockArg).get(), -- dependentSet, processedSet))) -+ if (failed(getDependentPointers(forOp.getTiedLoopInit(blockArg)->get(), -+ dependentSet, processedSet))) - return failure(); - - unsigned operandIdx = -@@ -383,7 +382,7 @@ - if (failed(addControlOperandsForForOp(forOp))) - return failure(); - if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { -- Value operand = forOp.getOpOperandForRegionIterArg(blockArg).get(); -+ Value operand = forOp.getTiedLoopInit(blockArg)->get(); - if (failed(tryInsertAndPropagate(operand))) - return failure(); - -==== triton/test/lib/Analysis/TestAlias.cpp#5 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/test/lib/Analysis/TestAlias.cpp ==== -# action=edit type=text ---- triton/test/lib/Analysis/TestAlias.cpp 2023-10-19 14:55:11.000000000 -0700 -+++ triton/test/lib/Analysis/TestAlias.cpp 2023-10-27 20:17:47.000000000 -0700 -@@ -87,7 +87,7 @@ - } - if (auto forOp = dyn_cast(op)) { - for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) { -- auto operand = forOp.getOpOperandForRegionIterArg(arg.value()).get(); -+ auto operand = forOp.getTiedLoopInit(arg.value())->get(); - auto opNames = getAllocOpNames(operand); - auto argName = getValueOperandName(arg.value(), state); - print(argName, opNames, os); diff --git a/third_party/triton/cl580481414.patch b/third_party/triton/cl580481414.patch deleted file mode 100644 index 130e4860a4e9cd..00000000000000 --- a/third_party/triton/cl580481414.patch +++ /dev/null @@ -1,1512 +0,0 @@ -diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td ---- a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td -+++ b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td -@@ -29,9 +29,8 @@ include "mlir/IR/EnumAttr.td" - include "mlir/Dialect/LLVMIR/LLVMOpBase.td" - include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType - --def I8Ptr_global : LLVM_IntPtrBase<8, 1>; --def I8Ptr_shared : LLVM_IntPtrBase<8, 3>; --def I64Ptr_shared : LLVM_IntPtrBase<64, 3>; -+def LLVM_PointerGlobal : LLVM_OpaquePointerInAddressSpace<1>; -+def LLVM_PointerShared : LLVM_OpaquePointerInAddressSpace<3>; - - class NVGPU_Op traits = []> : - LLVM_OpBase; -@@ -55,7 +54,7 @@ def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"w - } - - def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier_init", [MemoryEffects<[MemWrite]>]> { -- let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, I32Attr:$count); -+ let arguments = (ins LLVM_PointerShared:$mbarrier, I1:$pred, I32Attr:$count); - let assemblyFormat = "$mbarrier `,` $pred attr-dict `:` type($mbarrier)"; - } - -@@ -71,12 +70,12 @@ def MBarrier_ArriveTypeAttr : I32EnumAtt - } - - def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier_arrive", []> { -- let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, Optional:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr:$txCount); -+ let arguments = (ins LLVM_PointerShared:$mbarrier, I1:$pred, Optional:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr:$txCount); - let assemblyFormat = "$mbarrier `,` $pred (`,` $ctaId^)? 
attr-dict `:` type($mbarrier)"; - } - - def NVGPU_MBarrierWaitOp : NVGPU_Op<"mbarrier_wait", []> { -- let arguments = (ins I64Ptr_shared:$mbarrier, I1:$phase); -+ let arguments = (ins LLVM_PointerShared:$mbarrier, I1:$phase); - let assemblyFormat = "$mbarrier `,` $phase attr-dict `:` type(operands)"; - } - -@@ -116,13 +115,13 @@ def NVGPU_WGMMADescCreateOp : NVGPU_Op<" - } - - def NVGPU_TMALoadTiledOp : NVGPU_Op<"tma_load_tiled", [AttrSizedOperandSegments]> { -- let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc, -+ let arguments = (ins LLVM_PointerShared:$dst, LLVM_PointerShared:$mbarrier, LLVM_PointerGlobal:$tmaDesc, I64:$l2Desc, - I1:$pred, Variadic:$coords, Optional:$mcastMask); - let assemblyFormat = "operands attr-dict `:` type(operands)"; - } - - def NVGPU_TMALoadIm2colOp : NVGPU_Op<"tma_load_im2col", []> { -- let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc, LLVM_AnyStruct:$im2colOffsets, I1:$pred, Variadic:$coords, I16Attr:$mcastMask); -+ let arguments = (ins LLVM_PointerShared:$dst, LLVM_PointerShared:$mbarrier, LLVM_PointerGlobal:$tmaDesc, I64:$l2Desc, LLVM_AnyStruct:$im2colOffsets, I1:$pred, Variadic:$coords, I16Attr:$mcastMask); - let assemblyFormat = "operands attr-dict `:` type(operands)"; - } - -@@ -217,12 +216,12 @@ def NVGPU_ClusterWaitOp : NVGPU_Op<"clus - } - - def NVGPU_TMAStoreTiledOp : NVGPU_Op<"tma_store_tiled", [MemoryEffects<[MemWrite]>]> { -- let arguments = (ins I8Ptr_global:$tmaDesc, I8Ptr_shared:$src, I1:$pred, Variadic:$coords); -+ let arguments = (ins LLVM_PointerGlobal:$tmaDesc, LLVM_PointerShared:$src, I1:$pred, Variadic:$coords); - let assemblyFormat = "operands attr-dict `:` type(operands)"; - } - - def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> { -- let arguments = (ins I8Ptr_shared:$addr, Variadic:$datas); -+ let arguments = (ins LLVM_PointerShared:$addr, Variadic:$datas); - let assemblyFormat = "operands attr-dict `:` type(operands)"; - } - -diff --git a/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp -@@ -41,13 +41,16 @@ struct AllocMBarrierOpConversion : publi - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult()); - auto resultTy = op.getType(); - auto resultTensorTy = resultTy.dyn_cast(); -- Type elemPtrTy; -+ Type elemPtrTy = ptr_ty(rewriter.getContext(), 3); -+ Type llvmElemTy; - if (resultTensorTy) { -- auto llvmElemTy = -+ llvmElemTy = - getTypeConverter()->convertType(resultTensorTy.getElementType()); -- elemPtrTy = ptr_ty(llvmElemTy, 3); - } else { -- elemPtrTy = getTypeConverter()->convertType(resultTy); -+ auto resultPtrTy = resultTy.dyn_cast(); -+ assert(resultPtrTy && "Unknown type for AllocMBarrierOp"); -+ llvmElemTy = -+ getTypeConverter()->convertType(resultPtrTy.getPointeeType()); - } - smemBase = bitcast(smemBase, elemPtrTy); - auto threadId = getThreadId(rewriter, loc); -@@ -61,7 +64,7 @@ struct AllocMBarrierOpConversion : publi - for (int i = 0; i < numMBarriers; ++i) { - Value smem = smemBase; - if (i > 0) { -- smem = gep(elemPtrTy, smem, i32_val(i)); -+ smem = gep(elemPtrTy, llvmElemTy, smem, i32_val(i)); - } - rewriter.create(loc, smem, pred, - op.getCount()); -@@ -142,11 +145,11 @@ struct ExtractMBarrierOpConversion - op.getTensor().getType().cast().getElementType(); - auto tensorStruct = adaptor.getTensor(); - auto 
index = adaptor.getIndex(); -- auto ptrTy = -- LLVM::LLVMPointerType::get(getTypeConverter()->convertType(elemTy), 3); -+ auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - auto basePtr = - extract_val(ptrTy, tensorStruct, rewriter.getDenseI64ArrayAttr(0)); -- Value result = gep(ptrTy, basePtr, index); -+ Value result = -+ gep(ptrTy, getTypeConverter()->convertType(elemTy), basePtr, index); - rewriter.replaceOp(op, result); - return success(); - } -diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp -@@ -310,10 +310,10 @@ private: - shapePerCTA); - Value offset = linearize(rewriter, loc, multiDimOffsetWrapped, - paddedRepShape, outOrd); -- auto elemPtrTy = ptr_ty(llvmElemTy, 3); -- Value ptr = gep(elemPtrTy, smemBase, offset); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); -+ Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); - auto vecTy = vec_ty(llvmElemTy, vec); -- ptr = bitcast(ptr, ptr_ty(vecTy, 3)); -+ ptr = bitcast(ptr, ptr_ty(rewriter.getContext(), 3)); - if (stNotRd) { - Value valVec = undef(vecTy); - for (unsigned v = 0; v < vec; ++v) { -@@ -326,7 +326,7 @@ private: - } - store(valVec, ptr); - } else { -- Value valVec = load(ptr); -+ Value valVec = load(vecTy, ptr); - for (unsigned v = 0; v < vec; ++v) { - Value currVal = extract_element(llvmElemTy, valVec, i32_val(v)); - if (isInt1) -@@ -423,10 +423,10 @@ private: - for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { - auto coord = coord2valT[elemId].first; - Value offset = linearize(rewriter, loc, coord, paddedRepShape, outOrd); -- auto elemPtrTy = ptr_ty(elemTy, 3); -- Value ptr = gep(elemPtrTy, smemBase, offset); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); -+ Value ptr = gep(elemPtrTy, elemTy, smemBase, offset); - auto vecTy = vec_ty(elemTy, vec); -- ptr = bitcast(ptr, ptr_ty(vecTy, 3)); -+ ptr = bitcast(ptr, ptr_ty(rewriter.getContext(), 3)); - if (stNotRd) { - Value valVec = undef(vecTy); - for (unsigned v = 0; v < vec; ++v) { -@@ -435,7 +435,7 @@ private: - } - store(valVec, ptr); - } else { -- Value valVec = load(ptr); -+ Value valVec = load(vecTy, ptr); - for (unsigned v = 0; v < vec; ++v) { - Value currVal = extract_element(elemTy, valVec, i32_val(v)); - vals[elemId + v] = currVal; -@@ -462,7 +462,7 @@ private: - unsigned rank = srcShapePerCTA.size(); - - auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); -- auto elemPtrTy = ptr_ty(llvmElemTy, 3); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); - - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); - smemBase = bitcast(smemBase, elemPtrTy); -@@ -480,7 +480,7 @@ private: - - for (unsigned i = 0; i < inIndices.size(); ++i) { - Value offset = linearize(rewriter, loc, inIndices[i], smemShape); -- Value ptr = gep(elemPtrTy, smemBase, offset); -+ Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); - store(inVals[i], ptr); - } - } -@@ -513,8 +513,8 @@ private: - linearize(rewriter, loc, multiDimCTAId, srcCTAsPerCGA, srcCTAOrder); - Value localOffset = linearize(rewriter, loc, localCoord, smemShape); - -- Value ptr = gep(elemPtrTy, smemBase, localOffset); -- outVals.push_back(load_dsmem(ptr, remoteCTAId)); -+ Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, localOffset); -+ outVals.push_back(load_dsmem(ptr, remoteCTAId, llvmElemTy)); - } - - Value result = -@@ 
-545,10 +545,8 @@ private: - - if (shouldUseDistSmem(srcLayout, dstLayout)) - return lowerDistToDistWithDistSmem(op, adaptor, rewriter); -- -- auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); -- auto elemPtrTy = ptr_ty(llvmElemTy, 3); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); - smemBase = bitcast(smemBase, elemPtrTy); - auto shape = dstTy.getShape(); - unsigned rank = dstTy.getRank(); -@@ -747,7 +745,7 @@ private: - auto outOrd = dstSharedLayout.getOrder(); - Value smemBase = getSharedMemoryBase(loc, rewriter, dst); - auto elemTy = getTypeConverter()->convertType(srcTy.getElementType()); -- auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); - smemBase = bitcast(smemBase, elemPtrTy); - - int32_t elemSize = elemTy.getIntOrFloatBitWidth(); -@@ -774,8 +772,7 @@ private: - unsigned leadingDimOffset = - numElemsPerSwizzlingRow * srcShapePerCTA[outOrd[1]]; - -- auto ptrI8SharedTy = LLVM::LLVMPointerType::get( -- typeConverter->convertType(rewriter.getI8Type()), 3); -+ auto ptrSharedTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - - uint32_t rowsPerRep = getShapePerCTATile(mmaLayout)[0]; - -@@ -804,7 +801,8 @@ private: - loc, i32_ty, threadId, rowOfWarp, i32_val(idx), leadingDimOffset, - numElemsPerSwizzlingRow, true); - -- Value addr = gep(elemPtrTy, smemBase, offset); -+ Value addr = gep(elemPtrTy, getTypeConverter()->convertType(elemTy), -+ smemBase, offset); - - Value words[4]; - for (unsigned i = 0; i < 8; ++i) { -@@ -815,7 +813,7 @@ private: - } - - rewriter.create( -- loc, bitcast(addr, ptrI8SharedTy), -+ loc, bitcast(addr, ptrSharedTy), - ValueRange{bitcast(words[0], i32_ty), bitcast(words[1], i32_ty), - bitcast(words[2], i32_ty), bitcast(words[3], i32_ty)}); - } -diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp -@@ -133,10 +133,10 @@ Value loadAFMA(Value A, Value llA, Block - auto elemTy = typeConverter->convertType( - A.getType().cast().getElementType()); - -- Type ptrTy = ptr_ty(elemTy, 3); -+ Type ptrTy = ptr_ty(rewriter.getContext(), 3); - SmallVector aPtrs(aNumPtr); - for (int i = 0; i < aNumPtr; ++i) -- aPtrs[i] = gep(ptrTy, aSmem.base, aOff[i]); -+ aPtrs[i] = gep(ptrTy, elemTy, aSmem.base, aOff[i]); - - SmallVector vas; - -@@ -148,8 +148,8 @@ Value loadAFMA(Value A, Value llA, Block - for (unsigned mm = 0; mm < mSizePerThread; ++mm) { - Value offset = - add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK)); -- Value pa = gep(ptrTy, aPtrs[0], offset); -- Value va = load(pa); -+ Value pa = gep(ptrTy, elemTy, aPtrs[0], offset); -+ Value va = load(elemTy, pa); - vas.emplace_back(va); - } - -@@ -200,10 +200,10 @@ Value loadBFMA(Value B, Value llB, Block - auto elemTy = typeConverter->convertType( - B.getType().cast().getElementType()); - -- Type ptrTy = ptr_ty(elemTy, 3); -+ Type ptrTy = ptr_ty(rewriter.getContext(), 3); - SmallVector bPtrs(bNumPtr); - for (int i = 0; i < bNumPtr; ++i) -- bPtrs[i] = gep(ptrTy, bSmem.base, bOff[i]); -+ bPtrs[i] = gep(ptrTy, elemTy, bSmem.base, bOff[i]); - - SmallVector vbs; - -@@ -215,8 +215,8 @@ Value loadBFMA(Value B, Value llB, Block - for (unsigned nn = 0; nn 
< nSizePerThread; ++nn) { - Value offset = - add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK)); -- Value pb = gep(ptrTy, bPtrs[0], offset); -- Value vb = load(pb); -+ Value pb = gep(ptrTy, elemTy, bPtrs[0], offset); -+ Value vb = load(elemTy, pb); - vbs.emplace_back(vb); - } - -diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp -@@ -150,10 +150,10 @@ static Value loadA(Value tensor, const S - } - - Type elemX2Ty = vec_ty(f16_ty, 2); -- Type elemPtrTy = ptr_ty(f16_ty, 3); -+ Type elemTy = f16_ty; - if (tensorTy.getElementType().isBF16()) { - elemX2Ty = vec_ty(i16_ty, 2); -- elemPtrTy = ptr_ty(i16_ty, 3); -+ elemTy = i16_ty; - } - - // prepare arguments -@@ -161,22 +161,23 @@ static Value loadA(Value tensor, const S - - std::map, std::pair> has; - for (int i = 0; i < numPtrA; i++) -- ptrA[i] = gep(ptr_ty(f16_ty, 3), smemBase, offA[i]); -+ ptrA[i] = gep(ptr_ty(ctx, 3), f16_ty, smemBase, offA[i]); - - auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) { - vals[{m, k}] = {val0, val1}; - }; - auto loadA = [&](int m, int k) { - int offidx = (isARow ? k / 4 : m) % numPtrA; -- Value thePtrA = gep(elemPtrTy, smemBase, offA[offidx]); -+ Value thePtrA = gep(ptr_ty(ctx, 3), elemTy, smemBase, offA[offidx]); - - int stepAM = isARow ? m : m / numPtrA * numPtrA; - int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k; - Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM), - mul(i32_val(stepAK), strideAK)); -- Value pa = gep(elemPtrTy, thePtrA, offset); -- Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max(vecA / 2, 1)), 3); -- Value ha = load(bitcast(pa, aPtrTy)); -+ Value pa = gep(ptr_ty(ctx, 3), elemTy, thePtrA, offset); -+ Type vecTy = vec_ty(i32_ty, std::max(vecA / 2, 1)); -+ Type aPtrTy = ptr_ty(ctx, 3); -+ Value ha = load(vecTy, bitcast(pa, aPtrTy)); - // record lds that needs to be moved - Value ha00 = bitcast(extract_element(ha, i32_val(0)), elemX2Ty); - Value ha01 = bitcast(extract_element(ha, i32_val(1)), elemX2Ty); -@@ -273,17 +274,17 @@ static Value loadB(Value tensor, const S - offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1)); - } - -- Type elemPtrTy = ptr_ty(f16_ty, 3); -+ Type elemTy = f16_ty; - Type elemX2Ty = vec_ty(f16_ty, 2); - if (tensorTy.getElementType().isBF16()) { -- elemPtrTy = ptr_ty(i16_ty, 3); -+ elemTy = i16_ty; - elemX2Ty = vec_ty(i16_ty, 2); - } - - SmallVector ptrB(numPtrB); - ValueTable hbs; - for (int i = 0; i < numPtrB; ++i) -- ptrB[i] = gep(ptr_ty(f16_ty, 3), smem, offB[i]); -+ ptrB[i] = gep(ptr_ty(ctx, 3), f16_ty, smem, offB[i]); - - auto ld = [&](decltype(hbs) &vals, int m, int k, Value val0, Value val1) { - vals[{m, k}] = {val0, val1}; -@@ -297,10 +298,10 @@ static Value loadB(Value tensor, const S - int stepBK = isBRow ? 
K : K / (numPtrB * vecB) * (numPtrB * vecB); - Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN), - mul(i32_val(stepBK), strideBK)); -- Value pb = gep(elemPtrTy, thePtrB, offset); -+ Value pb = gep(ptr_ty(ctx, 3), elemTy, thePtrB, offset); - -- Value hb = -- load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3))); -+ Type vecTy = vec_ty(i32_ty, std::max(vecB / 2, 1)); -+ Value hb = load(vecTy, bitcast(pb, ptr_ty(ctx, 3))); - // record lds that needs to be moved - Value hb00 = bitcast(extract_element(hb, i32_val(0)), elemX2Ty); - Value hb01 = bitcast(extract_element(hb, i32_val(1)), elemX2Ty); -diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp -@@ -280,9 +280,8 @@ SmallVector MMA16816SmemLoader::c - return offs; - } - --std::tuple --MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef ptrs, Type matTy, -- Type shemPtrTy) const { -+std::tuple MMA16816SmemLoader::loadX4( -+ int mat0, int mat1, ArrayRef ptrs, Type matTy, Type shemTy) const { - assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned"); - int matIdx[2] = {mat0, mat1}; - -@@ -321,7 +320,7 @@ MMA16816SmemLoader::loadX4(int mat0, int - Value stridedOffset = - mul(i32_val(matIdx[order[1]] * stridedLoadMatOffset * stridedMatShape), - stridedSmemOffset); -- Value readPtr = gep(shemPtrTy, ptr, stridedOffset); -+ Value readPtr = gep(ptr_ty(ctx, 3), shemTy, ptr, stridedOffset); - - PTXBuilder builder; - // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a -@@ -363,7 +362,7 @@ MMA16816SmemLoader::loadX4(int mat0, int - - for (int i = 0; i < 4; ++i) - for (int j = 0; j < vecWidth; ++j) { -- vptrs[i][j] = gep(shemPtrTy, ptrs[i / 2][j], ii[i % 2]); -+ vptrs[i][j] = gep(ptr_ty(ctx, 3), shemTy, ptrs[i / 2][j], ii[i % 2]); - } - // row + trans and col + no-trans are equivalent - bool isActualTrans = -@@ -381,8 +380,8 @@ MMA16816SmemLoader::loadX4(int mat0, int - int e = em % vecWidth; - int m = em / vecWidth; - int idx = m * 2 + r; -- Value ptr = bitcast(vptrs[idx][e], ptr_ty(packedTy, 3)); -- Value val = load(ptr); -+ Value ptr = bitcast(vptrs[idx][e], ptr_ty(ctx, 3)); -+ Value val = load(packedTy, ptr); - Value canonval = bitcast(val, vec_ty(canonInt, canonWidth)); - for (int w = 0; w < canonWidth; ++w) { - int ridx = idx + w * kWidth / vecWidth; -@@ -455,16 +454,16 @@ MMA16816SmemLoader::MMA16816SmemLoader( - warpMatOffset = instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]; - } - --Type getSharedMemPtrTy(Type argType) { -+Type getSharedMemTy(Type argType) { - MLIRContext *ctx = argType.getContext(); - if (argType.isF16()) -- return ptr_ty(type::f16Ty(ctx), 3); -+ return type::f16Ty(ctx); - else if (argType.isBF16()) -- return ptr_ty(type::i16Ty(ctx), 3); -+ return type::i16Ty(ctx); - else if (argType.isF32()) -- return ptr_ty(type::f32Ty(ctx), 3); -+ return type::f32Ty(ctx); - else if (argType.getIntOrFloatBitWidth() == 8) -- return ptr_ty(type::i8Ty(ctx), 3); -+ return type::i8Ty(ctx); - else - llvm::report_fatal_error("mma16816 data type not supported"); - } -@@ -531,15 +530,16 @@ std::function getLoadMat - const int numPtrs = loader.getNumPtrs(); - SmallVector ptrs(numPtrs); - Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter); -- Type smemPtrTy = 
getSharedMemPtrTy(eltTy); -+ Type smemTy = getSharedMemTy(eltTy); - for (int i = 0; i < numPtrs; ++i) -- ptrs[i] = bitcast(gep(smemPtrTy, smemBase, offs[i]), smemPtrTy); -+ ptrs[i] = -+ gep(ptr_ty(rewriter.getContext(), 3), smemTy, smemBase, offs[i]); - // actually load from shared memory - auto matTy = LLVM::LLVMStructType::getLiteral(eltTy.getContext(), - SmallVector(4, i32_ty)); - auto [ha0, ha1, ha2, ha3] = loader.loadX4( - (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, ptrs, -- matTy, getSharedMemPtrTy(eltTy)); -+ matTy, getSharedMemTy(eltTy)); - if (!isA) - std::swap(ha1, ha2); - // the following is incorrect -diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp -@@ -561,8 +561,7 @@ struct StoreAsyncOpConversion - - Value tmaDesc = - llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx); -- auto ptrI8SharedTy = LLVM::LLVMPointerType::get( -- typeConverter->convertType(rewriter.getI8Type()), 3); -+ auto ptrSharedTy = LLVM::LLVMPointerType::get(ctx, 3); - - auto threadId = getThreadId(rewriter, loc); - Value pred = icmp_eq(threadId, i32_val(0)); -@@ -599,9 +598,10 @@ struct StoreAsyncOpConversion - } - } - Value srcOffset = i32_val(b * boxStride); -- auto srcPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); -- Value srcPtrBase = gep(srcPtrTy, smemObj.base, srcOffset); -- auto addr = bitcast(srcPtrBase, ptrI8SharedTy); -+ auto srcPtrTy = ptr_ty(ctx, 3); -+ Value srcPtrBase = gep(srcPtrTy, getTypeConverter()->convertType(elemTy), -+ smemObj.base, srcOffset); -+ auto addr = bitcast(srcPtrBase, ptrSharedTy); - rewriter.create(loc, tmaDesc, addr, pred, - coord); - } -@@ -749,7 +749,7 @@ struct StoreAsyncOpConversion - Value llDst = adaptor.getDst(); - Value llSrc = adaptor.getSrc(); - auto srcShape = srcTy.getShape(); -- auto dstElemPtrTy = ptr_ty(getTypeConverter()->convertType(dstElemTy), 3); -+ auto dstElemPtrTy = ptr_ty(ctx, 3); - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); - smemBase = bitcast(smemBase, dstElemPtrTy); - -@@ -760,8 +760,7 @@ struct StoreAsyncOpConversion - - Value tmaDesc = - llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx); -- auto ptrI8SharedTy = LLVM::LLVMPointerType::get( -- typeConverter->convertType(rewriter.getI8Type()), 3); -+ auto ptrSharedTy = LLVM::LLVMPointerType::get(ctx, 3); - - auto threadId = getThreadId(rewriter, loc); - Value pred = int_val(1, 1); -@@ -817,7 +816,9 @@ struct StoreAsyncOpConversion - i32_val(b * numElemsPerRep / numBox + idx), leadingDimOffset, - numElemsPerSwizzlingRow, true); - -- Value addr = gep(dstElemPtrTy, smemBase, offset); -+ Value addr = -+ gep(dstElemPtrTy, getTypeConverter()->convertType(dstElemTy), -+ smemBase, offset); - Value words[4]; - for (unsigned i = 0; i < 8; ++i) { - if (i % minVec == 0) -@@ -827,7 +828,7 @@ struct StoreAsyncOpConversion - } - - rewriter.create( -- loc, bitcast(addr, ptrI8SharedTy), -+ loc, bitcast(addr, ptrSharedTy), - ValueRange{bitcast(words[0], i32_ty), bitcast(words[1], i32_ty), - bitcast(words[2], i32_ty), bitcast(words[3], i32_ty)}); - } -@@ -860,9 +861,11 @@ struct StoreAsyncOpConversion - instrShape[1] * warpsPerCTA[1] / - numBox), - mul(warpId0, i32_val(instrShape[0] * numElemsPerSwizzlingRow))); -- auto srcPtrTy = ptr_ty(getTypeConverter()->convertType(dstElemTy), 3); -- Value srcPtrBase = gep(srcPtrTy, smemBase, srcOffset); -- 
auto addr = bitcast(srcPtrBase, ptrI8SharedTy); -+ auto srcPtrTy = ptr_ty(ctx, 3); -+ Value srcPtrBase = -+ gep(srcPtrTy, getTypeConverter()->convertType(dstElemTy), smemBase, -+ srcOffset); -+ auto addr = bitcast(srcPtrBase, ptrSharedTy); - rewriter.create(loc, tmaDesc, addr, - pred, coord); - } -@@ -1022,7 +1025,7 @@ struct AtomicCASOpConversion - auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); - createBarrier(rewriter, loc, numCTAs); - Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); -- atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3)); -+ atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3)); - // Only threads with mask = True store the result - PTXBuilder ptxBuilderStore; - auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r"); -@@ -1033,7 +1036,7 @@ struct AtomicCASOpConversion - auto ASMReturnTy = void_ty(ctx); - ptxBuilderStore.launch(rewriter, loc, ASMReturnTy); - createBarrier(rewriter, loc, numCTAs); -- Value ret = load(atomPtr); -+ Value ret = load(valueElemTy, atomPtr); - createBarrier(rewriter, loc, numCTAs); - rewriter.replaceOp(op, {ret}); - } -@@ -1194,7 +1197,7 @@ struct AtomicRMWOpConversion - return success(); - } - Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); -- atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3)); -+ atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3)); - // Only threads with rmwMask = True store the result - PTXBuilder ptxBuilderStore; - auto &storeShared = -@@ -1204,7 +1207,7 @@ struct AtomicRMWOpConversion - storeShared(ptrOpr, valOpr).predicate(rmwMask); - ptxBuilderStore.launch(rewriter, loc, void_ty(ctx)); - createBarrier(rewriter, loc, numCTAs); -- Value ret = load(atomPtr); -+ Value ret = load(valueElemTy, atomPtr); - createBarrier(rewriter, loc, numCTAs); - rewriter.replaceOp(op, {ret}); - } -@@ -1273,8 +1276,8 @@ struct InsertSliceOpConversion - // object - auto offset = dot(rewriter, loc, offsets, smemObj.strides); - auto elemTy = getTypeConverter()->convertType(dstTy.getElementType()); -- auto elemPtrTy = ptr_ty(elemTy, 3); -- auto smemBase = gep(elemPtrTy, smemObj.base, offset); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); -+ auto smemBase = gep(elemPtrTy, elemTy, smemObj.base, offset); - - auto llSrc = adaptor.getSource(); - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy); -@@ -1357,8 +1360,8 @@ struct InsertSliceAsyncOpConversion - // Compute the offset based on the original dimensions of the shared - // memory object - auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides); -- auto dstPtrTy = ptr_ty(resElemTy, 3); -- Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset); -+ auto dstPtrTy = ptr_ty(rewriter.getContext(), 3); -+ Value dstPtrBase = gep(dstPtrTy, resElemTy, smemObj.base, dstOffset); - - // %mask - SmallVector maskElems; -@@ -1638,7 +1641,7 @@ struct InsertSliceAsyncV2OpConversion - // currently only support rank == 2. 
- dstOffsetCommon = - add(dstOffsetCommon, mul(sliceCoord, i32_val(boxDims[0]))); -- auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); -+ auto dstPtrTy = ptr_ty(rewriter.getContext(), 3); - - Value tmaDesc = - llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx); -@@ -1646,8 +1649,7 @@ struct InsertSliceAsyncV2OpConversion - // cache-policy modes - Value l2Desc = int_val(64, 0x1000000000000000ll); - -- auto ptrI8SharedTy = LLVM::LLVMPointerType::get( -- typeConverter->convertType(rewriter.getI8Type()), 3); -+ auto ptrSharedTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - - SmallVector coordCommon; - auto llCoord = getTypeConverter()->unpackLLElements( -@@ -1688,11 +1690,12 @@ struct InsertSliceAsyncV2OpConversion - for (size_t i = 0; i < numBoxes; ++i) { - Value dstOffset = - add(dstOffsetCommon, i32_val(i * elemsPerBox * accNumMcast)); -- Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset); -+ Value dstPtrBase = gep(dstPtrTy, getTypeConverter()->convertType(elemTy), -+ smemObj.base, dstOffset); - SmallVector coord = coordCommon; - coord[0] = add(coordCommon[0], i32_val(i * boxDims[0])); - rewriter.create( -- loc, bitcast(dstPtrBase, ptrI8SharedTy), adaptor.getMbar(), tmaDesc, -+ loc, bitcast(dstPtrBase, ptrSharedTy), adaptor.getMbar(), tmaDesc, - l2Desc, pred, coord, mcastMask); - } - -diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp -@@ -149,13 +149,11 @@ private: - // Assign base index to each operand in their order in indices - std::map indexToBase; - indexToBase[indices[0]] = -- bitcast(getSharedMemoryBase(loc, rewriter, op.getOperation()), -- getElementPtrType(op, indices[0])); -+ getSharedMemoryBase(loc, rewriter, op.getOperation()); - for (unsigned i = 1; i < op.getNumOperands(); ++i) { -- indexToBase[indices[i]] = -- bitcast(gep(getElementPtrType(op, indices[i - 1]), -- indexToBase[indices[i - 1]], i32_val(elems)), -- getElementPtrType(op, indices[i])); -+ indexToBase[indices[i]] = gep( -+ ptr_ty(rewriter.getContext(), 3), getElementType(op, indices[i - 1]), -+ indexToBase[indices[i - 1]], i32_val(elems)); - } - // smemBases[k] is the base pointer for the k-th operand - SmallVector smemBases(op.getNumOperands()); -@@ -335,11 +333,10 @@ private: - rewriter.replaceOp(op, results); - } - -- // Return the type of the shared memory pointer for operand i. -- Type getElementPtrType(triton::ReduceOp op, int i) const { -+ // Return the pointee type of the shared memory pointer for operand i. 
-+ Type getElementType(triton::ReduceOp op, int i) const { - auto ty = op.getInputTypes()[i].getElementType(); -- auto llvmElemTy = getTypeConverter()->convertType(ty); -- return LLVM::LLVMPointerType::get(llvmElemTy, 3); -+ return getTypeConverter()->convertType(ty); - } - - SmallVector -@@ -408,8 +405,9 @@ private: - Value writeOffset = - linearize(rewriter, loc, writeIdx, smemShape, smemOrder); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { -- auto elemPtrTy = getElementPtrType(op, i); -- Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset); -+ auto elemTy = getElementType(op, i); -+ Value writePtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, -+ smemBases[i], writeOffset); - storeShared(rewriter, loc, writePtr, acc[i], laneZero); - } - } -@@ -442,17 +440,19 @@ private: - for (unsigned round = 0; round < elemsPerThread; ++round) { - SmallVector acc(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { -- auto elemPtrTy = getElementPtrType(op, i); -- Value readPtr = gep(elemPtrTy, smemBases[i], readOffset); -- acc[i] = loadShared(rewriter, loc, readPtr, threadIsNeeded); -+ auto elemTy = getElementType(op, i); -+ Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, -+ smemBases[i], readOffset); -+ acc[i] = loadShared(rewriter, loc, readPtr, elemTy, threadIsNeeded); - } - warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */); - // only the first thread in each sizeInterWarps is writing - Value writeOffset = readOffset; - SmallVector writePtrs(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { -- auto elemPtrTy = getElementPtrType(op, i); -- writePtrs[i] = gep(elemPtrTy, smemBases[i], writeOffset); -+ auto elemTy = getElementType(op, i); -+ writePtrs[i] = gep(ptr_ty(rewriter.getContext(), 3), elemTy, -+ smemBases[i], writeOffset); - } - - Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); -@@ -483,6 +483,7 @@ private: - auto smemOrder = helper.getOrderWithAxisAtBeginning(); - SmallVector results(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { -+ auto elemTy = getElementType(op, i); - if (auto resultTy = - op.getResult()[i].getType().dyn_cast()) { - // nd-tensor where n >= 1 -@@ -497,16 +498,16 @@ private: - readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0)); - Value readOffset = - linearize(rewriter, loc, readIdx, smemShape, smemOrder); -- Value readPtr = -- gep(getElementPtrType(op, i), smemBases[i], readOffset); -- resultVals[j] = load(readPtr); -+ Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, -+ smemBases[i], readOffset); -+ resultVals[j] = load(elemTy, readPtr); - } - - results[i] = getTypeConverter()->packLLElements(loc, resultVals, - rewriter, resultTy); - } else { - // 0d-tensor -> scalar -- results[i] = load(smemBases[i]); -+ results[i] = load(elemTy, smemBases[i]); - } - } - rewriter.replaceOp(op, results); -diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp -@@ -112,7 +112,8 @@ static void storeWarpAccumulator(SmallVe - Value mask = icmp_eq(laneId, i32_val(scanDim - 1)); - Value index = add(parallelLaneId, mul(warpId, i32_val(numParallelLane))); - index = add(index, i32_val(chunkId * numParallelLane * axisNumWarps)); -- Value writePtr = gep(baseSharedMemPtr.getType(), baseSharedMemPtr, index); -+ Value writePtr = gep(baseSharedMemPtr.getType(), 
lastElement.getType(), -+ baseSharedMemPtr, index); - storeShared(rewriter, loc, writePtr, lastElement, mask); - chunkId++; - } -@@ -170,8 +171,9 @@ static void AddPartialReduce(SmallVector - for (unsigned i = 0; i < axisNumWarps; ++i) { - Value index = add(parallelLaneId, i32_val(numParallelLane * - (i + chunkId * axisNumWarps))); -- Value ptr = gep(sharedMemoryPtr.getType(), sharedMemoryPtr, index); -- Value partialReduce = load(ptr); -+ Value ptr = gep(sharedMemoryPtr.getType(), srcValues[srcIndex].getType(), -+ sharedMemoryPtr, index); -+ Value partialReduce = load(srcValues[srcIndex].getType(), ptr); - if (!accumulator.acc) { - accumulator.acc = partialReduce; - accumulator.maskedAcc = partialReduce; -@@ -411,7 +413,7 @@ ScanOpConversion::emitFastScan(triton::S - if (axisNumWarps > 1) { - // Slow path for the case where there are multiple warps with unique data on - // the axis. -- Type elemPtrTys = LLVM::LLVMPointerType::get(srcValues[0].getType(), 3); -+ Type elemPtrTys = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - Value baseSharedMemPtr = bitcast( - getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys); - // Store the partial reducing for each warp into shared memory. -diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp -@@ -305,8 +305,7 @@ struct PrintOpConversion - - auto *context = rewriter.getContext(); - -- SmallVector argsType{ptr_ty(IntegerType::get(context, 8)), -- ptr_ty(IntegerType::get(context, 8))}; -+ SmallVector argsType{ptr_ty(context), ptr_ty(context)}; - auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType); - - ConversionPatternRewriter::InsertionGuard guard(rewriter); -@@ -359,9 +358,8 @@ struct PrintOpConversion - - static void llPrintf(Value msg, ValueRange args, - ConversionPatternRewriter &rewriter) { -- Type int8Ptr = ptr_ty(i8_ty); -- - auto *ctx = rewriter.getContext(); -+ Type ptr = ptr_ty(ctx); - auto moduleOp = - rewriter.getBlock()->getParent()->getParentOfType(); - auto funcOp = getVprintfDeclaration(rewriter); -@@ -370,7 +368,7 @@ struct PrintOpConversion - Value one = i32_val(1); - Value zero = i32_val(0); - -- Value bufferPtr = null(int8Ptr); -+ Value bufferPtr = null(ptr); - - SmallVector newArgs; - if (args.size() >= 1) { -@@ -385,16 +383,16 @@ struct PrintOpConversion - - Type structTy = LLVM::LLVMStructType::getLiteral(ctx, argTypes); - auto allocated = -- rewriter.create(loc, ptr_ty(structTy), one, -+ rewriter.create(loc, ptr_ty(ctx), structTy, one, - /*alignment=*/0); - - for (const auto &entry : llvm::enumerate(newArgs)) { - auto index = i32_val(entry.index()); -- auto fieldPtr = gep(ptr_ty(argTypes[entry.index()]), allocated, -+ auto fieldPtr = gep(ptr_ty(ctx), argTypes[entry.index()], allocated, - ArrayRef{zero, index}); - store(entry.value(), fieldPtr); - } -- bufferPtr = bitcast(allocated, int8Ptr); -+ bufferPtr = bitcast(allocated, ptr); - } - - SmallVector operands{msg, bufferPtr}; -@@ -488,8 +486,7 @@ struct AssertOpConversion - // void __assert_fail(const char * assertion, const char * file, unsigned - // int line, const char * function); - auto *ctx = rewriter.getContext(); -- SmallVector argsType{ptr_ty(i8_ty), ptr_ty(i8_ty), i32_ty, -- ptr_ty(i8_ty), -+ SmallVector argsType{ptr_ty(ctx), ptr_ty(ctx), i32_ty, ptr_ty(ctx), - rewriter.getIntegerType(sizeof(size_t) * 8)}; - auto funcType = 
LLVM::LLVMFunctionType::get(void_ty(ctx), argsType); - -@@ -623,11 +620,14 @@ struct AddPtrOpConversion - Location loc = op->getLoc(); - auto resultTy = op.getType(); - auto offsetTy = op.getOffset().getType(); -- auto ptrTy = op.getPtr().getType(); - auto resultTensorTy = resultTy.dyn_cast(); - if (resultTensorTy) { - unsigned elems = getTotalElemsPerThread(resultTy); - Type elemTy = -+ getTypeConverter()->convertType(resultTensorTy.getElementType() -+ .cast() -+ .getPointeeType()); -+ Type ptrTy = - getTypeConverter()->convertType(resultTensorTy.getElementType()); - auto ptrs = getTypeConverter()->unpackLLElements(loc, adaptor.getPtr(), - rewriter, ptrTy); -@@ -635,15 +635,18 @@ struct AddPtrOpConversion - loc, adaptor.getOffset(), rewriter, offsetTy); - SmallVector resultVals(elems); - for (unsigned i = 0; i < elems; ++i) { -- resultVals[i] = gep(elemTy, ptrs[i], offsets[i]); -+ resultVals[i] = gep(ptrTy, elemTy, ptrs[i], offsets[i]); - } - Value view = getTypeConverter()->packLLElements(loc, resultVals, rewriter, - resultTy); - rewriter.replaceOp(op, view); - } else { - assert(resultTy.isa()); -- Type llResultTy = getTypeConverter()->convertType(resultTy); -- Value result = gep(llResultTy, adaptor.getPtr(), adaptor.getOffset()); -+ auto resultPtrTy = getTypeConverter()->convertType(resultTy); -+ auto resultElemTy = getTypeConverter()->convertType( -+ resultTy.cast().getPointeeType()); -+ Value result = -+ gep(resultPtrTy, resultElemTy, adaptor.getPtr(), adaptor.getOffset()); - rewriter.replaceOp(op, result); - } - return success(); -@@ -661,9 +664,7 @@ struct AllocTensorOpConversion - Location loc = op->getLoc(); - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult()); - auto resultTy = op.getType().dyn_cast(); -- auto llvmElemTy = -- getTypeConverter()->convertType(resultTy.getElementType()); -- auto elemPtrTy = ptr_ty(llvmElemTy, 3); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); - smemBase = bitcast(smemBase, elemPtrTy); - auto sharedLayout = resultTy.getEncoding().cast(); - auto order = sharedLayout.getOrder(); -@@ -679,6 +680,8 @@ struct AllocTensorOpConversion - newOrder = SmallVector(order.begin(), order.end()); - } - -+ auto llvmElemTy = -+ getTypeConverter()->convertType(resultTy.getElementType()); - auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); - auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA, - newOrder, loc, rewriter); -@@ -737,9 +740,10 @@ struct ExtractSliceOpConversion - } - } - -- auto elemPtrTy = ptr_ty(llvmElemTy, 3); -- smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), -- llvmElemTy, strideVals, offsetVals); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); -+ smemObj = -+ SharedMemoryObject(gep(elemPtrTy, llvmElemTy, smemObj.base, offset), -+ llvmElemTy, strideVals, offsetVals); - auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); - rewriter.replaceOp(op, retVal); - return success(); -diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h ---- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h -+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h -@@ -261,8 +261,7 @@ public: - template - Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter, - T value) const { -- auto ptrTy = LLVM::LLVMPointerType::get( -- this->getTypeConverter()->convertType(rewriter.getI8Type()), 3); -+ auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - FunctionOpInterface 
funcOp; - if constexpr (std::is_pointer_v) - funcOp = value->template getParentOfType(); -@@ -275,7 +274,9 @@ public: - assert(bufferId != Allocation::InvalidBufferId && "BufferId not found"); - size_t offset = funcAllocation->getOffset(bufferId); - Value offVal = i32_val(offset); -- Value base = gep(ptrTy, smem, offVal); -+ Value base = -+ gep(ptrTy, this->getTypeConverter()->convertType(rewriter.getI8Type()), -+ smem, offVal); - return base; - } - -@@ -312,9 +313,10 @@ public: - // then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y - // This means that we can use some immediate offsets for shared memory - // operations. -- auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(resElemTy), 3); -+ auto dstPtrTy = ptr_ty(rewriter.getContext(), 3); - auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides); -- Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset); -+ Value dstPtrBase = gep(dstPtrTy, getTypeConverter()->convertType(resElemTy), -+ smemObj.base, dstOffset); - - auto srcEncoding = srcTy.getEncoding(); - auto srcShape = srcTy.getShape(); -@@ -423,7 +425,8 @@ public: - Value colOff = add(colOffSwizzled, colOffOrdered); - // compute non-immediate offset - offset = add(offset, add(rowOff, mul(colOff, strideCol))); -- Value currPtr = gep(dstPtrTy, dstPtrBase, offset); -+ Value currPtr = gep(dstPtrTy, getTypeConverter()->convertType(resElemTy), -+ dstPtrBase, offset); - // compute immediate offset - Value immediateOff; - if (outOrder.size() == 2) { -@@ -434,7 +437,8 @@ public: - immediateOff = i32_val(immedateOffCol); - } - -- ret[elemIdx] = gep(dstPtrTy, currPtr, immediateOff); -+ ret[elemIdx] = gep(dstPtrTy, getTypeConverter()->convertType(resElemTy), -+ currPtr, immediateOff); - } - return ret; - } -@@ -479,8 +483,8 @@ public: - SmallVector outVals(outElems); - for (unsigned i = 0; i < numVecs; ++i) { - Value smemAddr = sharedPtrs[i * minVec]; -- smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3)); -- Value valVec = load(smemAddr); -+ smemAddr = bitcast(smemAddr, ptr_ty(rewriter.getContext(), 3)); -+ Value valVec = load(wordTy, smemAddr); - for (unsigned v = 0; v < minVec; ++v) { - Value currVal = extract_element(dstElemTy, valVec, i32_val(v)); - outVals[i * minVec + v] = currVal; -@@ -537,7 +541,7 @@ public: - word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec)); - if (i % minVec == minVec - 1) { - Value smemAddr = sharedPtrs[i / minVec * minVec]; -- smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3)); -+ smemAddr = bitcast(smemAddr, ptr_ty(rewriter.getContext(), 3)); - store(word, smemAddr); - } - } -diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp ---- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp -@@ -161,8 +161,7 @@ struct FuncOpConversion : public FuncOpC - // memory to the function arguments. - auto loc = funcOp.getLoc(); - auto ctx = funcOp->getContext(); -- auto ptrTy = LLVM::LLVMPointerType::get( -- this->getTypeConverter()->convertType(rewriter.getI8Type()), 3); -+ auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - // 1. Modify the function type to add the new argument. 
- auto funcTy = funcOp.getFunctionType(); - auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); -@@ -232,15 +231,14 @@ struct FuncOpConversion : public FuncOpC - allocation.mapFuncOp(funcOp, newFuncOp); - - // Append arguments to receive TMADesc in global memory in the runtime -- auto i8PtrTy = LLVM::LLVMPointerType::get( -- this->getTypeConverter()->convertType(rewriter.getI8Type()), 1); -+ auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1); - auto numArgs = newFuncOp.getBody().front().getNumArguments(); - auto funcTy = newFuncOp.getFunctionType().cast(); - SmallVector newInputsTy(funcTy.getParams().begin(), - funcTy.getParams().end()); - for (unsigned i = 0; i < numTMA; ++i) { -- newFuncOp.getBody().front().addArgument(i8PtrTy, funcOp.getLoc()); -- newInputsTy.push_back(i8PtrTy); -+ newFuncOp.getBody().front().addArgument(ptrTy, funcOp.getLoc()); -+ newInputsTy.push_back(ptrTy); - } - newFuncOp.setType( - LLVM::LLVMFunctionType::get(funcTy.getReturnType(), newInputsTy)); -@@ -296,9 +294,8 @@ private: - // of shared memory and append it to the operands of the callOp. - auto loc = callOp.getLoc(); - auto caller = callOp->getParentOfType(); -- auto ptrTy = LLVM::LLVMPointerType::get( -- this->getTypeConverter()->convertType(rewriter.getI8Type()), -- NVVM::kSharedMemorySpace); -+ auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), -+ NVVM::kSharedMemorySpace); - auto promotedOperands = this->getTypeConverter()->promoteOperands( - callOp.getLoc(), /*opOperands=*/callOp->getOperands(), - adaptor.getOperands(), rewriter); -@@ -312,7 +309,9 @@ private: - } - // function has a shared mem buffer - auto offset = funcAllocation->getOffset(bufferId); -- auto offsetValue = gep(ptrTy, base, i32_val(offset)); -+ auto offsetValue = -+ gep(ptrTy, this->getTypeConverter()->convertType(rewriter.getI8Type()), -+ base, i32_val(offset)); - promotedOperands.push_back(offsetValue); - return promotedOperands; - } -@@ -612,9 +611,8 @@ private: - } else { - funcSmem = funcOp.getArgument(funcOp.getNumArguments() - 1); - } -- auto ptrTy = -- LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()), -- NVVM::NVVMMemorySpace::kSharedMemorySpace); -+ auto ptrTy = LLVM::LLVMPointerType::get( -+ ctx, NVVM::NVVMMemorySpace::kSharedMemorySpace); - funcSmem = b.create(loc, ptrTy, funcSmem); - allocation.setFunctionSharedMemoryValue(funcOp, funcSmem); - }); -diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp ---- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp -@@ -60,13 +60,11 @@ Type TritonGPUToLLVMTypeConverter::conve - for (size_t i = 0; i < 2 * shape.size(); ++i) - types.push_back(IntegerType::get(ctx, 64)); - -- types.push_back( -- LLVM::LLVMPointerType::get(eleType, type.getAddressSpace())); -+ types.push_back(LLVM::LLVMPointerType::get(ctx, type.getAddressSpace())); - - return LLVM::LLVMStructType::getLiteral(ctx, types); - } -- return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()), -- type.getAddressSpace()); -+ return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); - } - - Value TritonGPUToLLVMTypeConverter::packLLElements( -@@ -145,7 +143,7 @@ Type TritonGPUToLLVMTypeConverter::conve - if (auto shared_layout = layout.dyn_cast()) { - SmallVector types; - // base ptr -- auto ptrType = LLVM::LLVMPointerType::get(eltType, 3); -+ auto ptrType = LLVM::LLVMPointerType::get(ctx, 3); - types.push_back(ptrType); - // shape dims - auto 
rank = type.getRank(); -diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp ---- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp -@@ -46,12 +46,11 @@ Value createLLVMIntegerConstant(OpBuilde - // (2) Create LoadDSmemOp - // (3) Bitcast result from dataTy (u16/u32/u64) back to elemTy - Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, -- Value ctaId) { -+ Value ctaId, Type elemTy) { - assert(addr.getType().isa() && - "addr must be a pointer type"); - auto ptrTy = addr.getType().cast(); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); -- auto elemTy = ptrTy.getElementType(); - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); - Value ret = - rewriter.create(loc, addr, ctaId, bitwidth); -@@ -63,12 +62,12 @@ Value createLoadDSmem(Location loc, Patt - // (2) Create LoadDSmemOp and extract results from retStruct - // (3) Bitcast results from dataTy (u16/u32/u64) back to elemTy - SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, -- Value addr, Value ctaId, unsigned vec) { -+ Value addr, Value ctaId, unsigned vec, -+ Type elemTy) { - assert(addr.getType().isa() && - "addr must be a pointer type"); - auto ptrTy = addr.getType().cast(); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); -- auto elemTy = ptrTy.getElementType(); - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); - Value retStruct = rewriter.create( - loc, addr, ctaId, bitwidth, vec); -@@ -91,8 +90,7 @@ void createStoreDSmem(Location loc, Patt - "addr must be a pointer type"); - auto ptrTy = addr.getType().cast(); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); -- auto elemTy = ptrTy.getElementType(); -- unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); -+ unsigned bitwidth = value.getType().getIntOrFloatBitWidth(); - auto dataTy = rewriter.getIntegerType(bitwidth); - Value data = bitcast(value, dataTy); - rewriter.create(loc, addr, ctaId, data, pred); -@@ -115,8 +113,10 @@ void createStoreDSmem(Location loc, Patt - "addr must be a pointer type"); - auto ptrTy = addr.getType().cast(); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); -- auto elemTy = ptrTy.getElementType(); -- unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); -+ unsigned bitwidth = 0; -+ if (!values.empty()) { -+ bitwidth = values.back().getType().getIntOrFloatBitWidth(); -+ } - auto dataTy = rewriter.getIntegerType(bitwidth); - SmallVector data; - for (unsigned i = 0; i < values.size(); ++i) -@@ -253,11 +253,10 @@ Value storeShared(ConversionPatternRewri - } - - Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, -- Value pred) { -+ Type elemTy, Value pred) { - MLIRContext *ctx = rewriter.getContext(); - auto ptrTy = ptr.getType().cast(); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for loadShared"); -- auto elemTy = ptrTy.getElementType(); - unsigned bitwidth = std::max(8u, elemTy.getIntOrFloatBitWidth()); - - const char *c = bitwidth == 64 ? "=l" : (bitwidth == 16 ? 
"=h" : "=r"); -@@ -363,12 +362,11 @@ Value addStringToModule(Location loc, Co - } - - Value zero = i32_val(0); -- Type globalPtrType = -- LLVM::LLVMPointerType::get(globalType, global.getAddrSpace()); -+ Type globalPtrType = LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()); - Value globalPtr = rewriter.create( - UnknownLoc::get(ctx), globalPtrType, global.getSymName()); - Value stringStart = -- rewriter.create(UnknownLoc::get(ctx), ptr_ty(i8_ty), -+ rewriter.create(UnknownLoc::get(ctx), ptr_ty(ctx), i8_ty, - globalPtr, SmallVector({zero, zero})); - return stringStart; - } -diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h ---- a/lib/Conversion/TritonGPUToLLVM/Utility.h -+++ b/lib/Conversion/TritonGPUToLLVM/Utility.h -@@ -209,9 +209,10 @@ Value createLLVMIntegerConstant(OpBuilde - /// (1) load_dsmem(addr, ctaId) - /// (2) load_dsmem(addr, ctaId, vec) - Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, -- Value ctaId); -+ Value ctaId, Type elemTy); - SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, -- Value addr, Value ctaId, unsigned vec); -+ Value addr, Value ctaId, unsigned vec, -+ Type elemTy); - - /// Usage of macro store_dsmem - /// (1) store_dsmem(addr, ctaId, value, pred) -@@ -257,17 +258,12 @@ struct SharedMemoryObject { - : base(base), - baseElemType(baseElemType), - strides(strides.begin(), strides.end()), -- offsets(offsets.begin(), offsets.end()) { -- assert(baseElemType == -- base.getType().cast().getElementType()); -- } -+ offsets(offsets.begin(), offsets.end()) {} - - SharedMemoryObject(Value base, Type baseElemType, ArrayRef shape, - ArrayRef order, Location loc, - ConversionPatternRewriter &rewriter) - : base(base), baseElemType(baseElemType) { -- assert(baseElemType == -- base.getType().cast().getElementType()); - strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter); - offsets.append(order.size(), i32_val(0)); - } -@@ -332,7 +328,7 @@ Value storeShared(ConversionPatternRewri - Value val, Value pred); - - Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, -- Value pred); -+ Type elemTy, Value pred); - - Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -diff --git a/lib/Dialect/NVGPU/IR/Dialect.cpp b/lib/Dialect/NVGPU/IR/Dialect.cpp ---- a/lib/Dialect/NVGPU/IR/Dialect.cpp -+++ b/lib/Dialect/NVGPU/IR/Dialect.cpp -@@ -73,7 +73,8 @@ void StoreDSmemOp::build(OpBuilder &buil - unsigned StoreDSmemOp::getBitwidth() { - auto addrTy = getAddr().getType(); - assert(addrTy.isa() && "addr must be a pointer type"); -- auto elemTy = addrTy.cast().getElementType(); -+ if (getValues().empty()) return 0; -+ auto elemTy = getValues().back().getType(); - return elemTy.getIntOrFloatBitWidth(); - } - -diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir ---- a/test/Conversion/tritongpu_to_llvm.mlir -+++ b/test/Conversion/tritongpu_to_llvm.mlir -@@ -1,7 +1,7 @@ - // RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" | FileCheck %s - - module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { -- // CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr) -+ // CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>) - // Here the 128 comes from the 4 in module attribute multiples 32 - // CHECK: nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32] - tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr) { -@@ 
-560,9 +560,9 @@ module attributes {"triton_gpu.num-ctas" - %index = arith.constant 1 : i32 - - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr, #AL> -> tensor<2x16x64xf16, #A> - tt.return - } -@@ -752,38 +752,38 @@ module attributes {"triton_gpu.num-ctas" - tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { - // CHECK: llvm.mlir.addressof @global_smem - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1> - tt.return - } -@@ -799,14 +799,14 @@ module attributes {"triton_gpu.num-ctas" - tt.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { - // CHECK: llvm.mlir.addressof @global_smem - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1> - tt.return - } -@@ -822,20 +822,20 @@ module attributes {"triton_gpu.num-ctas" - tt.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { - // CHECK: llvm.mlir.addressof @global_smem - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: 
!llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1> - tt.return - } -@@ -889,12 +889,12 @@ module attributes {"triton_gpu.num-ctas" - // CHECK-LABEL: convert_layout_mmav2_block - tt.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) { - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0> - tt.return - } -@@ -909,16 +909,16 @@ module attributes {"triton_gpu.num-ctas" - // CHECK-LABEL: convert_layout_mmav1_block - tt.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) { - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.convert_layout %arg0 : (tensor<32x64xf32, #mma>) -> tensor<32x64xf32, #blocked> - tt.return - } -@@ -932,9 +932,9 @@ module attributes {"triton_gpu.num-ctas" - // CHECK-LABEL: convert_layout_blocked_shared - tt.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) { - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0> - tt.return - } -@@ -947,7 +947,7 @@ module attributes {"triton_gpu.num-ctas" - module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { - // CHECK-LABEL: convert_blocked1d_to_slice0 - tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { -- // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr, 3> -+ // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<3> - %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - tt.return - } -@@ -960,7 +960,7 @@ module attributes {"triton_gpu.num-ctas" - module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { - // CHECK-LABEL: convert_blocked1d_to_slice1 - tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { -- // CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr, 3> -+ // CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr<3> - %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - tt.return - } -diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir ---- a/test/Conversion/tritongpu_to_llvm_hopper.mlir -+++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir -@@ -12,7 +12,7 @@ module attributes {"triton_gpu.num-ctas" - %dst = triton_gpu.alloc_tensor : tensor<1x64x64xf16, #shared> - %c0 = arith.constant 0 : i32 - %src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, 
%coord1] {order = array} : !tt.ptr, 1> -- // CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operandSegmentSizes = array} : !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32 -+ // CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operandSegmentSizes = array} : !llvm.ptr<3>, !llvm.ptr<3>, !llvm.ptr<1>, i64, i1, i32, i32 - %res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr -> tensor<1x64x64xf16, #shared> - tt.return - } -@@ -73,7 +73,7 @@ module attributes {"triton_gpu.num-ctas" - %src = triton_gpu.alloc_tensor : tensor<64x64xf32, #shared> - %c0 = arith.constant 0 : i32 - %dst = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array} : !tt.ptr, 1> -- // CHECK: nvgpu.tma_store_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, !llvm.ptr, i1, i32, i32 -+ // CHECK: nvgpu.tma_store_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr<1>, !llvm.ptr<3>, i1, i32, i32 - triton_nvidia_gpu.store_async %dst, %src {cache = 1 : i32} : !tt.ptr, 1>, tensor<64x64xf32, #shared> - tt.return - } -diff --git a/test/NVGPU/test_cga.mlir b/test/NVGPU/test_cga.mlir ---- a/test/NVGPU/test_cga.mlir -+++ b/test/NVGPU/test_cga.mlir -@@ -14,11 +14,11 @@ module attributes {"triton_gpu.num-warps - nvgpu.cga_barrier_arrive - nvgpu.cga_barrier_wait - -- %ptr = llvm.mlir.zero : !llvm.ptr -+ %ptr = llvm.mlir.zero : !llvm.ptr<3> - - // CHECK: llvm.inline_asm - %v = nvgpu.cluster_id -- llvm.store %v, %ptr : !llvm.ptr -+ llvm.store %v, %ptr : i32, !llvm.ptr<3> - - tt.return - } -diff --git a/test/NVGPU/test_mbarrier.mlir b/test/NVGPU/test_mbarrier.mlir ---- a/test/NVGPU/test_mbarrier.mlir -+++ b/test/NVGPU/test_mbarrier.mlir -@@ -2,18 +2,18 @@ - #SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> - module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { - tt.func @test_mbarrier() { -- %mbarrier = llvm.mlir.zero : !llvm.ptr -+ %mbarrier = llvm.mlir.zero : !llvm.ptr<3> - %pred = arith.constant 1 : i1 - // CHECK: llvm.inline_asm -- nvgpu.mbarrier_init %mbarrier, %pred { count = 32 : i32 } : !llvm.ptr -+ nvgpu.mbarrier_init %mbarrier, %pred { count = 32 : i32 } : !llvm.ptr<3> - // CHECK: llvm.inline_asm -- nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 1 : i32}: !llvm.ptr -+ nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 1 : i32}: !llvm.ptr<3> - // CHECK: llvm.inline_asm -- nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 0 : i32}: !llvm.ptr -+ nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 0 : i32}: !llvm.ptr<3> - // CHECK: llvm.inline_asm -- nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 2 : i32, txCount = 128 : i32}: !llvm.ptr -+ nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 2 : i32, txCount = 128 : i32}: !llvm.ptr<3> - // CHECK: llvm.inline_asm -- nvgpu.mbarrier_wait %mbarrier, %pred : !llvm.ptr, i1 -+ nvgpu.mbarrier_wait %mbarrier, %pred : !llvm.ptr<3>, i1 - tt.return - } - } // end module -diff --git a/test/NVGPU/test_tma.mlir b/test/NVGPU/test_tma.mlir ---- a/test/NVGPU/test_tma.mlir -+++ b/test/NVGPU/test_tma.mlir -@@ -2,9 +2,9 @@ - #SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> - module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { - 
tt.func @test_tma(%im2colOffsets0 : !llvm.struct<(i16, i16)>, %im2colOffsets1 : !llvm.struct<(i16, i16, i16)>) { -- %mbarrier = llvm.mlir.zero : !llvm.ptr -- %tmaDesc = llvm.mlir.zero : !llvm.ptr -- %dst = llvm.mlir.zero : !llvm.ptr -+ %mbarrier = llvm.mlir.zero : !llvm.ptr<3> -+ %tmaDesc = llvm.mlir.zero : !llvm.ptr<1> -+ %dst = llvm.mlir.zero : !llvm.ptr<3> - %l2desc = arith.constant 0 : i64 - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 -@@ -16,13 +16,13 @@ module attributes {"triton_gpu.num-warps - - // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint - // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint -- nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1 {operandSegmentSizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32 -- nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i32, i32 -+ nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1 {operandSegmentSizes = array}: !llvm.ptr<3>, !llvm.ptr<3>, !llvm.ptr<1>, i64, i1, i32, i32 -+ nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array}: !llvm.ptr<3>, !llvm.ptr<3>, !llvm.ptr<1>, i64, i1, i32, i32, i32, i32 - - // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint - // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint -- nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %mask {operandSegmentSizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i16 -- nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i32, i32 -+ nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %mask {operandSegmentSizes = array}: !llvm.ptr<3>, !llvm.ptr<3>, !llvm.ptr<1>, i64, i1, i32, i32, i16 -+ nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array}: !llvm.ptr<3>, !llvm.ptr<3>, !llvm.ptr<1>, i64, i1, i32, i32, i32, i32 - - tt.return - } diff --git a/third_party/triton/cl580550344.patch b/third_party/triton/cl580550344.patch deleted file mode 100644 index c1d598de65dc7e..00000000000000 --- a/third_party/triton/cl580550344.patch +++ /dev/null @@ -1,304 +0,0 @@ -diff --git a/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp -@@ -67,8 +67,10 @@ struct AllocMBarrierOpConversion : publi - op.getCount()); - } - if (resultTensorTy) { -- auto smemObj = SharedMemoryObject(smemBase, resultTensorTy.getShape(), -- {0}, loc, rewriter); -+ auto llvmElemTy = -+ getTypeConverter()->convertType(resultTensorTy.getElementType()); -+ auto smemObj = SharedMemoryObject( -+ smemBase, llvmElemTy, resultTensorTy.getShape(), {0}, loc, rewriter); - auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); - rewriter.replaceOp(op, retVal); - } else { -diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp 
b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp -@@ -707,8 +707,9 @@ private: - auto dstLayout = dstTy.getEncoding(); - auto inOrd = getOrder(srcSharedLayout); - -- auto smemObj = -- getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter); -+ auto smemObj = getSharedMemoryObjectFromStruct( -+ loc, adaptor.getSrc(), -+ getTypeConverter()->convertType(srcTy.getElementType()), rewriter); - auto elemTy = getTypeConverter()->convertType(dstTy.getElementType()); - - auto srcStrides = -@@ -843,8 +844,8 @@ private: - storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, - dst, smemBase, elemTy, loc, rewriter); - } -- auto smemObj = -- SharedMemoryObject(smemBase, dstShapePerCTA, outOrd, loc, rewriter); -+ auto smemObj = SharedMemoryObject(smemBase, elemTy, dstShapePerCTA, outOrd, -+ loc, rewriter); - auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); - rewriter.replaceOp(op, retVal); - return success(); -@@ -1013,8 +1014,11 @@ private: - Value dst = op.getResult(); - bool isMMA = supportMMA(dst, mmaLayout.getVersionMajor()); - -- auto smemObj = -- getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter); -+ auto llvmElemTy = getTypeConverter()->convertType( -+ src.getType().cast().getElementType()); -+ -+ auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), -+ llvmElemTy, rewriter); - Value res; - if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2 - res = SharedToDotOperandMMAv2::convertLayout( -diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp -@@ -101,7 +101,9 @@ Value loadAFMA(Value A, Value llA, Block - - bool isARow = aOrder[0] == 1; - -- auto aSmem = getSharedMemoryObjectFromStruct(loc, llA, rewriter); -+ auto aSmem = getSharedMemoryObjectFromStruct( -+ loc, llA, typeConverter->convertType(aTensorTy.getElementType()), -+ rewriter); - Value strideAM = aSmem.strides[0]; - Value strideAK = aSmem.strides[1]; - Value strideA0 = isARow ? strideAK : strideAM; -@@ -166,7 +168,9 @@ Value loadBFMA(Value B, Value llB, Block - - bool isBRow = bOrder[0] == 1; - -- auto bSmem = getSharedMemoryObjectFromStruct(loc, llB, rewriter); -+ auto bSmem = getSharedMemoryObjectFromStruct( -+ loc, llB, typeConverter->convertType(bTensorTy.getElementType()), -+ rewriter); - Value strideBN = bSmem.strides[1]; - Value strideBK = bSmem.strides[0]; - Value strideB0 = isBRow ? 
strideBN : strideBK; -diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp ---- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp -@@ -332,8 +332,15 @@ LogicalResult convertDot(TritonGPUToLLVM - Value baseA; - Value baseB; - if (aSharedLayout) -- baseA = getSharedMemoryObjectFromStruct(loc, loadedA, rewriter).base; -- baseB = getSharedMemoryObjectFromStruct(loc, loadedB, rewriter).base; -+ baseA = -+ getSharedMemoryObjectFromStruct( -+ loc, loadedA, -+ typeConverter->convertType(aTensorTy.getElementType()), rewriter) -+ .base; -+ baseB = getSharedMemoryObjectFromStruct( -+ loc, loadedB, -+ typeConverter->convertType(bTensorTy.getElementType()), rewriter) -+ .base; - if (aSharedLayout) { - auto aOrd = aSharedLayout.getOrder(); - transA = aOrd[0] == 0; -diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp -@@ -551,7 +551,8 @@ struct StoreAsyncOpConversion - Value llDst = adaptor.getDst(); - Value llSrc = adaptor.getSrc(); - auto srcShape = srcTy.getShape(); -- auto smemObj = getSharedMemoryObjectFromStruct(loc, llSrc, rewriter); -+ auto smemObj = -+ getSharedMemoryObjectFromStruct(loc, llSrc, elemTy, rewriter); - - SmallVector offsetVals; - for (auto i = 0; i < srcShape.size(); ++i) { -@@ -1250,7 +1251,8 @@ struct InsertSliceOpConversion - - // newBase = base + offset - // Triton support either static and dynamic offsets -- auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter); -+ auto smemObj = getSharedMemoryObjectFromStruct( -+ loc, llDst, dstTy.getElementType(), rewriter); - SmallVector offsets; - SmallVector srcStrides; - auto mixedOffsets = op.getMixedOffsets(); -@@ -1339,7 +1341,8 @@ struct InsertSliceAsyncOpConversion - // %dst - auto dstTy = dst.getType().cast(); - auto dstShape = dstTy.getShape(); -- auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter); -+ auto smemObj = -+ getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter); - auto axis = op->getAttrOfType("axis").getInt(); - SmallVector offsetVals; - SmallVector srcStrides; -@@ -1601,7 +1604,9 @@ struct InsertSliceAsyncV2OpConversion - Value dst = op.getDst(); - auto dstTy = dst.getType().cast(); - auto dstShape = dstTy.getShape(); -- auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter); -+ auto smemObj = getSharedMemoryObjectFromStruct( -+ loc, llDst, typeConverter->convertType(dstTy.getElementType()), -+ rewriter); - - // the offset of coord considering multicast slicing - SmallVector mcastOffsetVals; -diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp -@@ -680,8 +680,8 @@ struct AllocTensorOpConversion - } - - auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); -- auto smemObj = -- SharedMemoryObject(smemBase, shapePerCTA, newOrder, loc, rewriter); -+ auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA, -+ newOrder, loc, rewriter); - auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); - rewriter.replaceOp(op, retVal); - return success(); -@@ -704,10 +704,12 @@ struct ExtractSliceOpConversion - assert(op.hasUnitStride() && 
- "Only unit stride supported by ExtractSliceOpConversion"); - -+ auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); -+ - // newBase = base + offset - // Triton supports either static and dynamic offsets -- auto smemObj = -- getSharedMemoryObjectFromStruct(loc, adaptor.getSource(), rewriter); -+ auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSource(), -+ llvmElemTy, rewriter); - SmallVector opOffsetVals; - SmallVector offsetVals; - auto mixedOffsets = op.getMixedOffsets(); -@@ -735,10 +737,9 @@ struct ExtractSliceOpConversion - } - } - -- auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); - auto elemPtrTy = ptr_ty(llvmElemTy, 3); - smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), -- strideVals, offsetVals); -+ llvmElemTy, strideVals, offsetVals); - auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); - rewriter.replaceOp(op, retVal); - return success(); -diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h ---- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h -+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h -@@ -525,7 +525,7 @@ public: - - SmallVector srcStrides = {dstStrides[0], dstStrides[1]}; - SmallVector offsetVals = {i32_val(0), i32_val(0)}; -- SharedMemoryObject smemObj(smemBase, srcStrides, offsetVals); -+ SharedMemoryObject smemObj(smemBase, elemTy, srcStrides, offsetVals); - - DenseMap sharedPtrs = - getSwizzledSharedPtrs(loc, inVec, srcTy, dstSharedLayout, dstElemTy, -diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp ---- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp -@@ -131,9 +131,9 @@ void createStoreDSmem(Location loc, Patt - createStoreDSmem(loc, rewriter, addr, ctaId, values, pred); - } - --SharedMemoryObject --getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, -- ConversionPatternRewriter &rewriter) { -+SharedMemoryObject getSharedMemoryObjectFromStruct( -+ Location loc, Value llvmStruct, Type elemTy, -+ ConversionPatternRewriter &rewriter) { - ArrayRef types = - llvmStruct.getType().cast().getBody(); - SmallVector elems(types.size()); -@@ -144,6 +144,7 @@ getSharedMemoryObjectFromStruct(Location - - auto rank = (elems.size() - 1) / 2; - return {/*base=*/elems[0], -+ /*baseElemType=*/elemTy, - /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank}, - /*offsets=*/{elems.begin() + 1 + rank, elems.end()}}; - } -diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h ---- a/lib/Conversion/TritonGPUToLLVM/Utility.h -+++ b/lib/Conversion/TritonGPUToLLVM/Utility.h -@@ -234,6 +234,7 @@ getStridesFromShapeAndOrder(ArrayRef strides, -+ SharedMemoryObject(Value base, Type baseElemType, ArrayRef strides, - ArrayRef offsets) -- : base(base), strides(strides.begin(), strides.end()), -- offsets(offsets.begin(), offsets.end()) {} -+ : base(base), -+ baseElemType(baseElemType), -+ strides(strides.begin(), strides.end()), -+ offsets(offsets.begin(), offsets.end()) { -+ assert(baseElemType == -+ base.getType().cast().getElementType()); -+ } - -- SharedMemoryObject(Value base, ArrayRef shape, -+ SharedMemoryObject(Value base, Type baseElemType, ArrayRef shape, - ArrayRef order, Location loc, - ConversionPatternRewriter &rewriter) -- : base(base) { -+ : base(base), baseElemType(baseElemType) { -+ assert(baseElemType == -+ 
base.getType().cast().getElementType()); - strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter); - offsets.append(order.size(), i32_val(0)); - } -@@ -290,13 +298,13 @@ struct SharedMemoryObject { - Value cSwizzleOffset = getCSwizzleOffset(order); - Value offset = sub(i32_val(0), cSwizzleOffset); - Type type = base.getType(); -- return gep(type, base, offset); -+ return gep(type, baseElemType, base, offset); - } - }; - --SharedMemoryObject --getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, -- ConversionPatternRewriter &rewriter); -+SharedMemoryObject getSharedMemoryObjectFromStruct( -+ Location loc, Value llvmStruct, Type elemTy, -+ ConversionPatternRewriter &rewriter); - - // Convert an \param index to a multi-dim coordinate given \param shape and - // \param order. -diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp -@@ -211,14 +211,16 @@ struct TransOpConversion - matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); -- auto srcSmemObj = -- getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter); -+ auto llvmElemTy = getTypeConverter()->convertType( -+ op.getType().cast().getElementType()); -+ auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), -+ llvmElemTy, rewriter); - SmallVector dstStrides = {srcSmemObj.strides[1], - srcSmemObj.strides[0]}; - SmallVector dstOffsets = {srcSmemObj.offsets[1], - srcSmemObj.offsets[0]}; -- auto dstSmemObj = -- SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets); -+ auto dstSmemObj = SharedMemoryObject( -+ srcSmemObj.base, srcSmemObj.baseElemType, dstStrides, dstOffsets); - auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); - rewriter.replaceOp(op, retVal); - return success(); diff --git a/third_party/triton/cl580852372.patch b/third_party/triton/cl580852372.patch deleted file mode 100644 index d15ec833fc6f25..00000000000000 --- a/third_party/triton/cl580852372.patch +++ /dev/null @@ -1,15 +0,0 @@ -==== triton/include/triton/Dialect/NVGPU/IR/NVGPUOps.td#3 - /google/src/cloud/shyshkov/mlir_4983432f17eb4b445e161c5f8278c6ea4d5d1241_1699531174/triton/include/triton/Dialect/NVGPU/IR/NVGPUOps.td ==== -# action=edit type=text ---- triton/include/triton/Dialect/NVGPU/IR/NVGPUOps.td 2023-11-09 03:52:05.000000000 -0800 -+++ triton/include/triton/Dialect/NVGPU/IR/NVGPUOps.td 2023-11-09 04:00:14.000000000 -0800 -@@ -29,8 +29,8 @@ - include "mlir/Dialect/LLVMIR/LLVMOpBase.td" - include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType - --def LLVM_PointerGlobal : LLVM_OpaquePointerInAddressSpace<1>; --def LLVM_PointerShared : LLVM_OpaquePointerInAddressSpace<3>; -+def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>; -+def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>; - - class NVGPU_Op traits = []> : - LLVM_OpBase; diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 60f0c56799b5dc..3795c89bb75563 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -5,8 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl578837341" - TRITON_SHA256 = "0d8112bb31d48b5beadbfc2e13c52770a95d3759b312b15cf26dd72e71410568" + TRITON_COMMIT = "cl580208989" + TRITON_SHA256 = 
"bcf6e99a73c8797720325b0f2e48447cdae7f68c53c68bfe04c39104db542562" tf_http_archive( name = "triton", @@ -17,8 +17,5 @@ def repo(): patch_file = [ "//third_party/triton:cl568176943.patch", "//third_party/triton:b304456327.patch", - "//third_party/triton:cl580550344.patch", - "//third_party/triton:cl580481414.patch", - "//third_party/triton:cl580852372.patch", ], ) diff --git a/third_party/xla/third_party/triton/cl577369732.patch b/third_party/xla/third_party/triton/cl577369732.patch deleted file mode 100644 index e63b9f3804974b..00000000000000 --- a/third_party/xla/third_party/triton/cl577369732.patch +++ /dev/null @@ -1,116 +0,0 @@ -==== triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp#19 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp 2023-10-19 14:55:11.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -759,7 +759,7 @@ - OpBuilder builder(forOp); - // Get init operands for loop carried values - for (BlockArgument &arg : forOp.getRegionIterArgs()) { -- OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); -+ OpOperand &operand = *forOp.getTiedLoopInit(arg); - setValueMapping(arg, operand.get(), 0); - } - -==== triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp#10 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp 2023-10-19 14:55:11.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -188,7 +188,7 @@ - auto getIncomingOp = [this](Value v) -> Value { - if (auto arg = v.dyn_cast()) - if (arg.getOwner()->getParentOp() == forOp.getOperation()) -- return forOp.getOpOperandForRegionIterArg(arg).get(); -+ return forOp.getTiedLoopInit(arg)->get(); - return Value(); - }; - -@@ -298,10 +298,10 @@ - Operation *firstDot = builder.clone(*dot, mapping); - if (Value a = operand2headPrefetch.lookup(dot.getA())) - firstDot->setOperand( -- 0, newForOp.getRegionIterArgForOpOperand(*a.use_begin())); -+ 0, newForOp.getTiedLoopRegionIterArg(&*a.use_begin())); - if (Value b = operand2headPrefetch.lookup(dot.getB())) - firstDot->setOperand( -- 1, newForOp.getRegionIterArgForOpOperand(*b.use_begin())); -+ 1, newForOp.getTiedLoopRegionIterArg(&*b.use_begin())); - - // remaining part - int64_t kOff = prefetchWidth; -==== triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp#18 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp 2023-10-24 18:31:01.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -245,7 +245,7 @@ - for (OpOperand &use : value.getUses()) { - Operation *user = use.getOwner(); - if (auto forOp = dyn_cast(user)) { -- Value arg = forOp.getRegionIterArgForOpOperand(use); -+ Value arg = forOp.getTiedLoopRegionIterArg(&use); - Value result = forOp.getResultForOpOperand(use); - setEncoding({arg, result}, info, changed, user); - continue; -@@ -767,7 +767,7 @@ - SmallVector newOperands; - for (auto arg : 
forOp.getRegionIterArgs()) { - if (slice.count(arg)) { -- OpOperand &initVal = forOp.getOpOperandForRegionIterArg(arg); -+ OpOperand &initVal = *forOp.getTiedLoopInit(arg); - argMapping.push_back(std::make_pair( - forOp.getResultForOpOperand(initVal).getResultNumber(), - forOp.getInitArgs().size() + newOperands.size())); -==== triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp#16 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2023-10-24 18:31:01.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -430,10 +430,10 @@ - Block *block = blockArg.getOwner(); - Operation *parentOp = block->getParentOp(); - if (auto forOp = dyn_cast(parentOp)) { -- OpOperand &initOperand = forOp.getOpOperandForRegionIterArg(blockArg); -+ OpOperand *initOperand = forOp.getTiedLoopInit(blockArg); - Value yieldOperand = forOp.getBody()->getTerminator()->getOperand( - blockArg.getArgNumber() - forOp.getNumInductionVars()); -- queue.push_back({initOperand.get(), encoding}); -+ queue.push_back({initOperand->get(), encoding}); - queue.push_back({yieldOperand, encoding}); - continue; - } -==== triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp#1 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp 2023-10-12 01:35:16.000000000 -0700 -+++ triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp 2023-10-27 20:17:46.000000000 -0700 -@@ -88,9 +88,8 @@ - auto parentOp = blockArg.getOwner()->getParentOp(); - if (auto forOp = dyn_cast(parentOp)) { - if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { -- if (failed(getDependentPointers( -- forOp.getOpOperandForRegionIterArg(blockArg).get(), -- dependentSet, processedSet))) -+ if (failed(getDependentPointers(forOp.getTiedLoopInit(blockArg)->get(), -+ dependentSet, processedSet))) - return failure(); - - unsigned operandIdx = -@@ -383,7 +382,7 @@ - if (failed(addControlOperandsForForOp(forOp))) - return failure(); - if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { -- Value operand = forOp.getOpOperandForRegionIterArg(blockArg).get(); -+ Value operand = forOp.getTiedLoopInit(blockArg)->get(); - if (failed(tryInsertAndPropagate(operand))) - return failure(); - -==== triton/test/lib/Analysis/TestAlias.cpp#5 - /google/src/cloud/springerm/mlir_3cd2a0bc1a2dcf851f1821765946b77d0e65bd2e_1698463035/triton/test/lib/Analysis/TestAlias.cpp ==== -# action=edit type=text ---- triton/test/lib/Analysis/TestAlias.cpp 2023-10-19 14:55:11.000000000 -0700 -+++ triton/test/lib/Analysis/TestAlias.cpp 2023-10-27 20:17:47.000000000 -0700 -@@ -87,7 +87,7 @@ - } - if (auto forOp = dyn_cast(op)) { - for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) { -- auto operand = forOp.getOpOperandForRegionIterArg(arg.value()).get(); -+ auto operand = forOp.getTiedLoopInit(arg.value())->get(); - auto opNames = getAllocOpNames(operand); - auto argName = getValueOperandName(arg.value(), state); - print(argName, opNames, os); diff --git a/third_party/xla/third_party/triton/cl580481414.patch b/third_party/xla/third_party/triton/cl580481414.patch deleted file mode 100644 index 130e4860a4e9cd..00000000000000 --- a/third_party/xla/third_party/triton/cl580481414.patch 
+++ /dev/null @@ -1,1512 +0,0 @@ -diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td ---- a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td -+++ b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td -@@ -29,9 +29,8 @@ include "mlir/IR/EnumAttr.td" - include "mlir/Dialect/LLVMIR/LLVMOpBase.td" - include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType - --def I8Ptr_global : LLVM_IntPtrBase<8, 1>; --def I8Ptr_shared : LLVM_IntPtrBase<8, 3>; --def I64Ptr_shared : LLVM_IntPtrBase<64, 3>; -+def LLVM_PointerGlobal : LLVM_OpaquePointerInAddressSpace<1>; -+def LLVM_PointerShared : LLVM_OpaquePointerInAddressSpace<3>; - - class NVGPU_Op traits = []> : - LLVM_OpBase; -@@ -55,7 +54,7 @@ def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"w - } - - def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier_init", [MemoryEffects<[MemWrite]>]> { -- let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, I32Attr:$count); -+ let arguments = (ins LLVM_PointerShared:$mbarrier, I1:$pred, I32Attr:$count); - let assemblyFormat = "$mbarrier `,` $pred attr-dict `:` type($mbarrier)"; - } - -@@ -71,12 +70,12 @@ def MBarrier_ArriveTypeAttr : I32EnumAtt - } - - def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier_arrive", []> { -- let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, Optional:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr:$txCount); -+ let arguments = (ins LLVM_PointerShared:$mbarrier, I1:$pred, Optional:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr:$txCount); - let assemblyFormat = "$mbarrier `,` $pred (`,` $ctaId^)? attr-dict `:` type($mbarrier)"; - } - - def NVGPU_MBarrierWaitOp : NVGPU_Op<"mbarrier_wait", []> { -- let arguments = (ins I64Ptr_shared:$mbarrier, I1:$phase); -+ let arguments = (ins LLVM_PointerShared:$mbarrier, I1:$phase); - let assemblyFormat = "$mbarrier `,` $phase attr-dict `:` type(operands)"; - } - -@@ -116,13 +115,13 @@ def NVGPU_WGMMADescCreateOp : NVGPU_Op<" - } - - def NVGPU_TMALoadTiledOp : NVGPU_Op<"tma_load_tiled", [AttrSizedOperandSegments]> { -- let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc, -+ let arguments = (ins LLVM_PointerShared:$dst, LLVM_PointerShared:$mbarrier, LLVM_PointerGlobal:$tmaDesc, I64:$l2Desc, - I1:$pred, Variadic:$coords, Optional:$mcastMask); - let assemblyFormat = "operands attr-dict `:` type(operands)"; - } - - def NVGPU_TMALoadIm2colOp : NVGPU_Op<"tma_load_im2col", []> { -- let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc, LLVM_AnyStruct:$im2colOffsets, I1:$pred, Variadic:$coords, I16Attr:$mcastMask); -+ let arguments = (ins LLVM_PointerShared:$dst, LLVM_PointerShared:$mbarrier, LLVM_PointerGlobal:$tmaDesc, I64:$l2Desc, LLVM_AnyStruct:$im2colOffsets, I1:$pred, Variadic:$coords, I16Attr:$mcastMask); - let assemblyFormat = "operands attr-dict `:` type(operands)"; - } - -@@ -217,12 +216,12 @@ def NVGPU_ClusterWaitOp : NVGPU_Op<"clus - } - - def NVGPU_TMAStoreTiledOp : NVGPU_Op<"tma_store_tiled", [MemoryEffects<[MemWrite]>]> { -- let arguments = (ins I8Ptr_global:$tmaDesc, I8Ptr_shared:$src, I1:$pred, Variadic:$coords); -+ let arguments = (ins LLVM_PointerGlobal:$tmaDesc, LLVM_PointerShared:$src, I1:$pred, Variadic:$coords); - let assemblyFormat = "operands attr-dict `:` type(operands)"; - } - - def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> { -- let arguments = (ins I8Ptr_shared:$addr, Variadic:$datas); -+ let arguments = (ins LLVM_PointerShared:$addr, 
Variadic:$datas); - let assemblyFormat = "operands attr-dict `:` type(operands)"; - } - -diff --git a/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp -@@ -41,13 +41,16 @@ struct AllocMBarrierOpConversion : publi - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult()); - auto resultTy = op.getType(); - auto resultTensorTy = resultTy.dyn_cast(); -- Type elemPtrTy; -+ Type elemPtrTy = ptr_ty(rewriter.getContext(), 3); -+ Type llvmElemTy; - if (resultTensorTy) { -- auto llvmElemTy = -+ llvmElemTy = - getTypeConverter()->convertType(resultTensorTy.getElementType()); -- elemPtrTy = ptr_ty(llvmElemTy, 3); - } else { -- elemPtrTy = getTypeConverter()->convertType(resultTy); -+ auto resultPtrTy = resultTy.dyn_cast(); -+ assert(resultPtrTy && "Unknown type for AllocMBarrierOp"); -+ llvmElemTy = -+ getTypeConverter()->convertType(resultPtrTy.getPointeeType()); - } - smemBase = bitcast(smemBase, elemPtrTy); - auto threadId = getThreadId(rewriter, loc); -@@ -61,7 +64,7 @@ struct AllocMBarrierOpConversion : publi - for (int i = 0; i < numMBarriers; ++i) { - Value smem = smemBase; - if (i > 0) { -- smem = gep(elemPtrTy, smem, i32_val(i)); -+ smem = gep(elemPtrTy, llvmElemTy, smem, i32_val(i)); - } - rewriter.create(loc, smem, pred, - op.getCount()); -@@ -142,11 +145,11 @@ struct ExtractMBarrierOpConversion - op.getTensor().getType().cast().getElementType(); - auto tensorStruct = adaptor.getTensor(); - auto index = adaptor.getIndex(); -- auto ptrTy = -- LLVM::LLVMPointerType::get(getTypeConverter()->convertType(elemTy), 3); -+ auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - auto basePtr = - extract_val(ptrTy, tensorStruct, rewriter.getDenseI64ArrayAttr(0)); -- Value result = gep(ptrTy, basePtr, index); -+ Value result = -+ gep(ptrTy, getTypeConverter()->convertType(elemTy), basePtr, index); - rewriter.replaceOp(op, result); - return success(); - } -diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp -@@ -310,10 +310,10 @@ private: - shapePerCTA); - Value offset = linearize(rewriter, loc, multiDimOffsetWrapped, - paddedRepShape, outOrd); -- auto elemPtrTy = ptr_ty(llvmElemTy, 3); -- Value ptr = gep(elemPtrTy, smemBase, offset); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); -+ Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); - auto vecTy = vec_ty(llvmElemTy, vec); -- ptr = bitcast(ptr, ptr_ty(vecTy, 3)); -+ ptr = bitcast(ptr, ptr_ty(rewriter.getContext(), 3)); - if (stNotRd) { - Value valVec = undef(vecTy); - for (unsigned v = 0; v < vec; ++v) { -@@ -326,7 +326,7 @@ private: - } - store(valVec, ptr); - } else { -- Value valVec = load(ptr); -+ Value valVec = load(vecTy, ptr); - for (unsigned v = 0; v < vec; ++v) { - Value currVal = extract_element(llvmElemTy, valVec, i32_val(v)); - if (isInt1) -@@ -423,10 +423,10 @@ private: - for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { - auto coord = coord2valT[elemId].first; - Value offset = linearize(rewriter, loc, coord, paddedRepShape, outOrd); -- auto elemPtrTy = ptr_ty(elemTy, 3); -- Value ptr = gep(elemPtrTy, smemBase, offset); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); -+ Value ptr = gep(elemPtrTy, elemTy, 
smemBase, offset); - auto vecTy = vec_ty(elemTy, vec); -- ptr = bitcast(ptr, ptr_ty(vecTy, 3)); -+ ptr = bitcast(ptr, ptr_ty(rewriter.getContext(), 3)); - if (stNotRd) { - Value valVec = undef(vecTy); - for (unsigned v = 0; v < vec; ++v) { -@@ -435,7 +435,7 @@ private: - } - store(valVec, ptr); - } else { -- Value valVec = load(ptr); -+ Value valVec = load(vecTy, ptr); - for (unsigned v = 0; v < vec; ++v) { - Value currVal = extract_element(elemTy, valVec, i32_val(v)); - vals[elemId + v] = currVal; -@@ -462,7 +462,7 @@ private: - unsigned rank = srcShapePerCTA.size(); - - auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); -- auto elemPtrTy = ptr_ty(llvmElemTy, 3); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); - - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); - smemBase = bitcast(smemBase, elemPtrTy); -@@ -480,7 +480,7 @@ private: - - for (unsigned i = 0; i < inIndices.size(); ++i) { - Value offset = linearize(rewriter, loc, inIndices[i], smemShape); -- Value ptr = gep(elemPtrTy, smemBase, offset); -+ Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); - store(inVals[i], ptr); - } - } -@@ -513,8 +513,8 @@ private: - linearize(rewriter, loc, multiDimCTAId, srcCTAsPerCGA, srcCTAOrder); - Value localOffset = linearize(rewriter, loc, localCoord, smemShape); - -- Value ptr = gep(elemPtrTy, smemBase, localOffset); -- outVals.push_back(load_dsmem(ptr, remoteCTAId)); -+ Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, localOffset); -+ outVals.push_back(load_dsmem(ptr, remoteCTAId, llvmElemTy)); - } - - Value result = -@@ -545,10 +545,8 @@ private: - - if (shouldUseDistSmem(srcLayout, dstLayout)) - return lowerDistToDistWithDistSmem(op, adaptor, rewriter); -- -- auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); -- auto elemPtrTy = ptr_ty(llvmElemTy, 3); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); - smemBase = bitcast(smemBase, elemPtrTy); - auto shape = dstTy.getShape(); - unsigned rank = dstTy.getRank(); -@@ -747,7 +745,7 @@ private: - auto outOrd = dstSharedLayout.getOrder(); - Value smemBase = getSharedMemoryBase(loc, rewriter, dst); - auto elemTy = getTypeConverter()->convertType(srcTy.getElementType()); -- auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); - smemBase = bitcast(smemBase, elemPtrTy); - - int32_t elemSize = elemTy.getIntOrFloatBitWidth(); -@@ -774,8 +772,7 @@ private: - unsigned leadingDimOffset = - numElemsPerSwizzlingRow * srcShapePerCTA[outOrd[1]]; - -- auto ptrI8SharedTy = LLVM::LLVMPointerType::get( -- typeConverter->convertType(rewriter.getI8Type()), 3); -+ auto ptrSharedTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - - uint32_t rowsPerRep = getShapePerCTATile(mmaLayout)[0]; - -@@ -804,7 +801,8 @@ private: - loc, i32_ty, threadId, rowOfWarp, i32_val(idx), leadingDimOffset, - numElemsPerSwizzlingRow, true); - -- Value addr = gep(elemPtrTy, smemBase, offset); -+ Value addr = gep(elemPtrTy, getTypeConverter()->convertType(elemTy), -+ smemBase, offset); - - Value words[4]; - for (unsigned i = 0; i < 8; ++i) { -@@ -815,7 +813,7 @@ private: - } - - rewriter.create( -- loc, bitcast(addr, ptrI8SharedTy), -+ loc, bitcast(addr, ptrSharedTy), - ValueRange{bitcast(words[0], i32_ty), bitcast(words[1], i32_ty), - bitcast(words[2], i32_ty), bitcast(words[3], i32_ty)}); - } -diff --git 
a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp -@@ -133,10 +133,10 @@ Value loadAFMA(Value A, Value llA, Block - auto elemTy = typeConverter->convertType( - A.getType().cast().getElementType()); - -- Type ptrTy = ptr_ty(elemTy, 3); -+ Type ptrTy = ptr_ty(rewriter.getContext(), 3); - SmallVector aPtrs(aNumPtr); - for (int i = 0; i < aNumPtr; ++i) -- aPtrs[i] = gep(ptrTy, aSmem.base, aOff[i]); -+ aPtrs[i] = gep(ptrTy, elemTy, aSmem.base, aOff[i]); - - SmallVector vas; - -@@ -148,8 +148,8 @@ Value loadAFMA(Value A, Value llA, Block - for (unsigned mm = 0; mm < mSizePerThread; ++mm) { - Value offset = - add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK)); -- Value pa = gep(ptrTy, aPtrs[0], offset); -- Value va = load(pa); -+ Value pa = gep(ptrTy, elemTy, aPtrs[0], offset); -+ Value va = load(elemTy, pa); - vas.emplace_back(va); - } - -@@ -200,10 +200,10 @@ Value loadBFMA(Value B, Value llB, Block - auto elemTy = typeConverter->convertType( - B.getType().cast().getElementType()); - -- Type ptrTy = ptr_ty(elemTy, 3); -+ Type ptrTy = ptr_ty(rewriter.getContext(), 3); - SmallVector bPtrs(bNumPtr); - for (int i = 0; i < bNumPtr; ++i) -- bPtrs[i] = gep(ptrTy, bSmem.base, bOff[i]); -+ bPtrs[i] = gep(ptrTy, elemTy, bSmem.base, bOff[i]); - - SmallVector vbs; - -@@ -215,8 +215,8 @@ Value loadBFMA(Value B, Value llB, Block - for (unsigned nn = 0; nn < nSizePerThread; ++nn) { - Value offset = - add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK)); -- Value pb = gep(ptrTy, bPtrs[0], offset); -- Value vb = load(pb); -+ Value pb = gep(ptrTy, elemTy, bPtrs[0], offset); -+ Value vb = load(elemTy, pb); - vbs.emplace_back(vb); - } - -diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp -@@ -150,10 +150,10 @@ static Value loadA(Value tensor, const S - } - - Type elemX2Ty = vec_ty(f16_ty, 2); -- Type elemPtrTy = ptr_ty(f16_ty, 3); -+ Type elemTy = f16_ty; - if (tensorTy.getElementType().isBF16()) { - elemX2Ty = vec_ty(i16_ty, 2); -- elemPtrTy = ptr_ty(i16_ty, 3); -+ elemTy = i16_ty; - } - - // prepare arguments -@@ -161,22 +161,23 @@ static Value loadA(Value tensor, const S - - std::map, std::pair> has; - for (int i = 0; i < numPtrA; i++) -- ptrA[i] = gep(ptr_ty(f16_ty, 3), smemBase, offA[i]); -+ ptrA[i] = gep(ptr_ty(ctx, 3), f16_ty, smemBase, offA[i]); - - auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) { - vals[{m, k}] = {val0, val1}; - }; - auto loadA = [&](int m, int k) { - int offidx = (isARow ? k / 4 : m) % numPtrA; -- Value thePtrA = gep(elemPtrTy, smemBase, offA[offidx]); -+ Value thePtrA = gep(ptr_ty(ctx, 3), elemTy, smemBase, offA[offidx]); - - int stepAM = isARow ? m : m / numPtrA * numPtrA; - int stepAK = isARow ? 
k / (numPtrA * vecA) * (numPtrA * vecA) : k; - Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM), - mul(i32_val(stepAK), strideAK)); -- Value pa = gep(elemPtrTy, thePtrA, offset); -- Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max(vecA / 2, 1)), 3); -- Value ha = load(bitcast(pa, aPtrTy)); -+ Value pa = gep(ptr_ty(ctx, 3), elemTy, thePtrA, offset); -+ Type vecTy = vec_ty(i32_ty, std::max(vecA / 2, 1)); -+ Type aPtrTy = ptr_ty(ctx, 3); -+ Value ha = load(vecTy, bitcast(pa, aPtrTy)); - // record lds that needs to be moved - Value ha00 = bitcast(extract_element(ha, i32_val(0)), elemX2Ty); - Value ha01 = bitcast(extract_element(ha, i32_val(1)), elemX2Ty); -@@ -273,17 +274,17 @@ static Value loadB(Value tensor, const S - offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1)); - } - -- Type elemPtrTy = ptr_ty(f16_ty, 3); -+ Type elemTy = f16_ty; - Type elemX2Ty = vec_ty(f16_ty, 2); - if (tensorTy.getElementType().isBF16()) { -- elemPtrTy = ptr_ty(i16_ty, 3); -+ elemTy = i16_ty; - elemX2Ty = vec_ty(i16_ty, 2); - } - - SmallVector ptrB(numPtrB); - ValueTable hbs; - for (int i = 0; i < numPtrB; ++i) -- ptrB[i] = gep(ptr_ty(f16_ty, 3), smem, offB[i]); -+ ptrB[i] = gep(ptr_ty(ctx, 3), f16_ty, smem, offB[i]); - - auto ld = [&](decltype(hbs) &vals, int m, int k, Value val0, Value val1) { - vals[{m, k}] = {val0, val1}; -@@ -297,10 +298,10 @@ static Value loadB(Value tensor, const S - int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB); - Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN), - mul(i32_val(stepBK), strideBK)); -- Value pb = gep(elemPtrTy, thePtrB, offset); -+ Value pb = gep(ptr_ty(ctx, 3), elemTy, thePtrB, offset); - -- Value hb = -- load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3))); -+ Type vecTy = vec_ty(i32_ty, std::max(vecB / 2, 1)); -+ Value hb = load(vecTy, bitcast(pb, ptr_ty(ctx, 3))); - // record lds that needs to be moved - Value hb00 = bitcast(extract_element(hb, i32_val(0)), elemX2Ty); - Value hb01 = bitcast(extract_element(hb, i32_val(1)), elemX2Ty); -diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp -@@ -280,9 +280,8 @@ SmallVector MMA16816SmemLoader::c - return offs; - } - --std::tuple --MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef ptrs, Type matTy, -- Type shemPtrTy) const { -+std::tuple MMA16816SmemLoader::loadX4( -+ int mat0, int mat1, ArrayRef ptrs, Type matTy, Type shemTy) const { - assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned"); - int matIdx[2] = {mat0, mat1}; - -@@ -321,7 +320,7 @@ MMA16816SmemLoader::loadX4(int mat0, int - Value stridedOffset = - mul(i32_val(matIdx[order[1]] * stridedLoadMatOffset * stridedMatShape), - stridedSmemOffset); -- Value readPtr = gep(shemPtrTy, ptr, stridedOffset); -+ Value readPtr = gep(ptr_ty(ctx, 3), shemTy, ptr, stridedOffset); - - PTXBuilder builder; - // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a -@@ -363,7 +362,7 @@ MMA16816SmemLoader::loadX4(int mat0, int - - for (int i = 0; i < 4; ++i) - for (int j = 0; j < vecWidth; ++j) { -- vptrs[i][j] = gep(shemPtrTy, ptrs[i / 2][j], ii[i % 2]); -+ vptrs[i][j] = gep(ptr_ty(ctx, 3), shemTy, ptrs[i / 2][j], ii[i % 2]); - } - // row + trans and col + no-trans are equivalent - bool 
isActualTrans = -@@ -381,8 +380,8 @@ MMA16816SmemLoader::loadX4(int mat0, int - int e = em % vecWidth; - int m = em / vecWidth; - int idx = m * 2 + r; -- Value ptr = bitcast(vptrs[idx][e], ptr_ty(packedTy, 3)); -- Value val = load(ptr); -+ Value ptr = bitcast(vptrs[idx][e], ptr_ty(ctx, 3)); -+ Value val = load(packedTy, ptr); - Value canonval = bitcast(val, vec_ty(canonInt, canonWidth)); - for (int w = 0; w < canonWidth; ++w) { - int ridx = idx + w * kWidth / vecWidth; -@@ -455,16 +454,16 @@ MMA16816SmemLoader::MMA16816SmemLoader( - warpMatOffset = instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]; - } - --Type getSharedMemPtrTy(Type argType) { -+Type getSharedMemTy(Type argType) { - MLIRContext *ctx = argType.getContext(); - if (argType.isF16()) -- return ptr_ty(type::f16Ty(ctx), 3); -+ return type::f16Ty(ctx); - else if (argType.isBF16()) -- return ptr_ty(type::i16Ty(ctx), 3); -+ return type::i16Ty(ctx); - else if (argType.isF32()) -- return ptr_ty(type::f32Ty(ctx), 3); -+ return type::f32Ty(ctx); - else if (argType.getIntOrFloatBitWidth() == 8) -- return ptr_ty(type::i8Ty(ctx), 3); -+ return type::i8Ty(ctx); - else - llvm::report_fatal_error("mma16816 data type not supported"); - } -@@ -531,15 +530,16 @@ std::function getLoadMat - const int numPtrs = loader.getNumPtrs(); - SmallVector ptrs(numPtrs); - Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter); -- Type smemPtrTy = getSharedMemPtrTy(eltTy); -+ Type smemTy = getSharedMemTy(eltTy); - for (int i = 0; i < numPtrs; ++i) -- ptrs[i] = bitcast(gep(smemPtrTy, smemBase, offs[i]), smemPtrTy); -+ ptrs[i] = -+ gep(ptr_ty(rewriter.getContext(), 3), smemTy, smemBase, offs[i]); - // actually load from shared memory - auto matTy = LLVM::LLVMStructType::getLiteral(eltTy.getContext(), - SmallVector(4, i32_ty)); - auto [ha0, ha1, ha2, ha3] = loader.loadX4( - (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? 
b : a /*mat1*/, ptrs, -- matTy, getSharedMemPtrTy(eltTy)); -+ matTy, getSharedMemTy(eltTy)); - if (!isA) - std::swap(ha1, ha2); - // the following is incorrect -diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp -@@ -561,8 +561,7 @@ struct StoreAsyncOpConversion - - Value tmaDesc = - llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx); -- auto ptrI8SharedTy = LLVM::LLVMPointerType::get( -- typeConverter->convertType(rewriter.getI8Type()), 3); -+ auto ptrSharedTy = LLVM::LLVMPointerType::get(ctx, 3); - - auto threadId = getThreadId(rewriter, loc); - Value pred = icmp_eq(threadId, i32_val(0)); -@@ -599,9 +598,10 @@ struct StoreAsyncOpConversion - } - } - Value srcOffset = i32_val(b * boxStride); -- auto srcPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); -- Value srcPtrBase = gep(srcPtrTy, smemObj.base, srcOffset); -- auto addr = bitcast(srcPtrBase, ptrI8SharedTy); -+ auto srcPtrTy = ptr_ty(ctx, 3); -+ Value srcPtrBase = gep(srcPtrTy, getTypeConverter()->convertType(elemTy), -+ smemObj.base, srcOffset); -+ auto addr = bitcast(srcPtrBase, ptrSharedTy); - rewriter.create(loc, tmaDesc, addr, pred, - coord); - } -@@ -749,7 +749,7 @@ struct StoreAsyncOpConversion - Value llDst = adaptor.getDst(); - Value llSrc = adaptor.getSrc(); - auto srcShape = srcTy.getShape(); -- auto dstElemPtrTy = ptr_ty(getTypeConverter()->convertType(dstElemTy), 3); -+ auto dstElemPtrTy = ptr_ty(ctx, 3); - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); - smemBase = bitcast(smemBase, dstElemPtrTy); - -@@ -760,8 +760,7 @@ struct StoreAsyncOpConversion - - Value tmaDesc = - llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx); -- auto ptrI8SharedTy = LLVM::LLVMPointerType::get( -- typeConverter->convertType(rewriter.getI8Type()), 3); -+ auto ptrSharedTy = LLVM::LLVMPointerType::get(ctx, 3); - - auto threadId = getThreadId(rewriter, loc); - Value pred = int_val(1, 1); -@@ -817,7 +816,9 @@ struct StoreAsyncOpConversion - i32_val(b * numElemsPerRep / numBox + idx), leadingDimOffset, - numElemsPerSwizzlingRow, true); - -- Value addr = gep(dstElemPtrTy, smemBase, offset); -+ Value addr = -+ gep(dstElemPtrTy, getTypeConverter()->convertType(dstElemTy), -+ smemBase, offset); - Value words[4]; - for (unsigned i = 0; i < 8; ++i) { - if (i % minVec == 0) -@@ -827,7 +828,7 @@ struct StoreAsyncOpConversion - } - - rewriter.create( -- loc, bitcast(addr, ptrI8SharedTy), -+ loc, bitcast(addr, ptrSharedTy), - ValueRange{bitcast(words[0], i32_ty), bitcast(words[1], i32_ty), - bitcast(words[2], i32_ty), bitcast(words[3], i32_ty)}); - } -@@ -860,9 +861,11 @@ struct StoreAsyncOpConversion - instrShape[1] * warpsPerCTA[1] / - numBox), - mul(warpId0, i32_val(instrShape[0] * numElemsPerSwizzlingRow))); -- auto srcPtrTy = ptr_ty(getTypeConverter()->convertType(dstElemTy), 3); -- Value srcPtrBase = gep(srcPtrTy, smemBase, srcOffset); -- auto addr = bitcast(srcPtrBase, ptrI8SharedTy); -+ auto srcPtrTy = ptr_ty(ctx, 3); -+ Value srcPtrBase = -+ gep(srcPtrTy, getTypeConverter()->convertType(dstElemTy), smemBase, -+ srcOffset); -+ auto addr = bitcast(srcPtrBase, ptrSharedTy); - rewriter.create(loc, tmaDesc, addr, - pred, coord); - } -@@ -1022,7 +1025,7 @@ struct AtomicCASOpConversion - auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); - createBarrier(rewriter, loc, numCTAs); - Value atomPtr = 
getSharedMemoryBase(loc, rewriter, op.getOperation()); -- atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3)); -+ atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3)); - // Only threads with mask = True store the result - PTXBuilder ptxBuilderStore; - auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r"); -@@ -1033,7 +1036,7 @@ struct AtomicCASOpConversion - auto ASMReturnTy = void_ty(ctx); - ptxBuilderStore.launch(rewriter, loc, ASMReturnTy); - createBarrier(rewriter, loc, numCTAs); -- Value ret = load(atomPtr); -+ Value ret = load(valueElemTy, atomPtr); - createBarrier(rewriter, loc, numCTAs); - rewriter.replaceOp(op, {ret}); - } -@@ -1194,7 +1197,7 @@ struct AtomicRMWOpConversion - return success(); - } - Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); -- atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3)); -+ atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3)); - // Only threads with rmwMask = True store the result - PTXBuilder ptxBuilderStore; - auto &storeShared = -@@ -1204,7 +1207,7 @@ struct AtomicRMWOpConversion - storeShared(ptrOpr, valOpr).predicate(rmwMask); - ptxBuilderStore.launch(rewriter, loc, void_ty(ctx)); - createBarrier(rewriter, loc, numCTAs); -- Value ret = load(atomPtr); -+ Value ret = load(valueElemTy, atomPtr); - createBarrier(rewriter, loc, numCTAs); - rewriter.replaceOp(op, {ret}); - } -@@ -1273,8 +1276,8 @@ struct InsertSliceOpConversion - // object - auto offset = dot(rewriter, loc, offsets, smemObj.strides); - auto elemTy = getTypeConverter()->convertType(dstTy.getElementType()); -- auto elemPtrTy = ptr_ty(elemTy, 3); -- auto smemBase = gep(elemPtrTy, smemObj.base, offset); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); -+ auto smemBase = gep(elemPtrTy, elemTy, smemObj.base, offset); - - auto llSrc = adaptor.getSource(); - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy); -@@ -1357,8 +1360,8 @@ struct InsertSliceAsyncOpConversion - // Compute the offset based on the original dimensions of the shared - // memory object - auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides); -- auto dstPtrTy = ptr_ty(resElemTy, 3); -- Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset); -+ auto dstPtrTy = ptr_ty(rewriter.getContext(), 3); -+ Value dstPtrBase = gep(dstPtrTy, resElemTy, smemObj.base, dstOffset); - - // %mask - SmallVector maskElems; -@@ -1638,7 +1641,7 @@ struct InsertSliceAsyncV2OpConversion - // currently only support rank == 2. 
- dstOffsetCommon = - add(dstOffsetCommon, mul(sliceCoord, i32_val(boxDims[0]))); -- auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); -+ auto dstPtrTy = ptr_ty(rewriter.getContext(), 3); - - Value tmaDesc = - llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx); -@@ -1646,8 +1649,7 @@ struct InsertSliceAsyncV2OpConversion - // cache-policy modes - Value l2Desc = int_val(64, 0x1000000000000000ll); - -- auto ptrI8SharedTy = LLVM::LLVMPointerType::get( -- typeConverter->convertType(rewriter.getI8Type()), 3); -+ auto ptrSharedTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - - SmallVector coordCommon; - auto llCoord = getTypeConverter()->unpackLLElements( -@@ -1688,11 +1690,12 @@ struct InsertSliceAsyncV2OpConversion - for (size_t i = 0; i < numBoxes; ++i) { - Value dstOffset = - add(dstOffsetCommon, i32_val(i * elemsPerBox * accNumMcast)); -- Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset); -+ Value dstPtrBase = gep(dstPtrTy, getTypeConverter()->convertType(elemTy), -+ smemObj.base, dstOffset); - SmallVector coord = coordCommon; - coord[0] = add(coordCommon[0], i32_val(i * boxDims[0])); - rewriter.create( -- loc, bitcast(dstPtrBase, ptrI8SharedTy), adaptor.getMbar(), tmaDesc, -+ loc, bitcast(dstPtrBase, ptrSharedTy), adaptor.getMbar(), tmaDesc, - l2Desc, pred, coord, mcastMask); - } - -diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp -@@ -149,13 +149,11 @@ private: - // Assign base index to each operand in their order in indices - std::map indexToBase; - indexToBase[indices[0]] = -- bitcast(getSharedMemoryBase(loc, rewriter, op.getOperation()), -- getElementPtrType(op, indices[0])); -+ getSharedMemoryBase(loc, rewriter, op.getOperation()); - for (unsigned i = 1; i < op.getNumOperands(); ++i) { -- indexToBase[indices[i]] = -- bitcast(gep(getElementPtrType(op, indices[i - 1]), -- indexToBase[indices[i - 1]], i32_val(elems)), -- getElementPtrType(op, indices[i])); -+ indexToBase[indices[i]] = gep( -+ ptr_ty(rewriter.getContext(), 3), getElementType(op, indices[i - 1]), -+ indexToBase[indices[i - 1]], i32_val(elems)); - } - // smemBases[k] is the base pointer for the k-th operand - SmallVector smemBases(op.getNumOperands()); -@@ -335,11 +333,10 @@ private: - rewriter.replaceOp(op, results); - } - -- // Return the type of the shared memory pointer for operand i. -- Type getElementPtrType(triton::ReduceOp op, int i) const { -+ // Return the pointee type of the shared memory pointer for operand i. 
-+ Type getElementType(triton::ReduceOp op, int i) const { - auto ty = op.getInputTypes()[i].getElementType(); -- auto llvmElemTy = getTypeConverter()->convertType(ty); -- return LLVM::LLVMPointerType::get(llvmElemTy, 3); -+ return getTypeConverter()->convertType(ty); - } - - SmallVector -@@ -408,8 +405,9 @@ private: - Value writeOffset = - linearize(rewriter, loc, writeIdx, smemShape, smemOrder); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { -- auto elemPtrTy = getElementPtrType(op, i); -- Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset); -+ auto elemTy = getElementType(op, i); -+ Value writePtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, -+ smemBases[i], writeOffset); - storeShared(rewriter, loc, writePtr, acc[i], laneZero); - } - } -@@ -442,17 +440,19 @@ private: - for (unsigned round = 0; round < elemsPerThread; ++round) { - SmallVector acc(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { -- auto elemPtrTy = getElementPtrType(op, i); -- Value readPtr = gep(elemPtrTy, smemBases[i], readOffset); -- acc[i] = loadShared(rewriter, loc, readPtr, threadIsNeeded); -+ auto elemTy = getElementType(op, i); -+ Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, -+ smemBases[i], readOffset); -+ acc[i] = loadShared(rewriter, loc, readPtr, elemTy, threadIsNeeded); - } - warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */); - // only the first thread in each sizeInterWarps is writing - Value writeOffset = readOffset; - SmallVector writePtrs(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { -- auto elemPtrTy = getElementPtrType(op, i); -- writePtrs[i] = gep(elemPtrTy, smemBases[i], writeOffset); -+ auto elemTy = getElementType(op, i); -+ writePtrs[i] = gep(ptr_ty(rewriter.getContext(), 3), elemTy, -+ smemBases[i], writeOffset); - } - - Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); -@@ -483,6 +483,7 @@ private: - auto smemOrder = helper.getOrderWithAxisAtBeginning(); - SmallVector results(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { -+ auto elemTy = getElementType(op, i); - if (auto resultTy = - op.getResult()[i].getType().dyn_cast()) { - // nd-tensor where n >= 1 -@@ -497,16 +498,16 @@ private: - readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0)); - Value readOffset = - linearize(rewriter, loc, readIdx, smemShape, smemOrder); -- Value readPtr = -- gep(getElementPtrType(op, i), smemBases[i], readOffset); -- resultVals[j] = load(readPtr); -+ Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, -+ smemBases[i], readOffset); -+ resultVals[j] = load(elemTy, readPtr); - } - - results[i] = getTypeConverter()->packLLElements(loc, resultVals, - rewriter, resultTy); - } else { - // 0d-tensor -> scalar -- results[i] = load(smemBases[i]); -+ results[i] = load(elemTy, smemBases[i]); - } - } - rewriter.replaceOp(op, results); -diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp -@@ -112,7 +112,8 @@ static void storeWarpAccumulator(SmallVe - Value mask = icmp_eq(laneId, i32_val(scanDim - 1)); - Value index = add(parallelLaneId, mul(warpId, i32_val(numParallelLane))); - index = add(index, i32_val(chunkId * numParallelLane * axisNumWarps)); -- Value writePtr = gep(baseSharedMemPtr.getType(), baseSharedMemPtr, index); -+ Value writePtr = gep(baseSharedMemPtr.getType(), 
lastElement.getType(), -+ baseSharedMemPtr, index); - storeShared(rewriter, loc, writePtr, lastElement, mask); - chunkId++; - } -@@ -170,8 +171,9 @@ static void AddPartialReduce(SmallVector - for (unsigned i = 0; i < axisNumWarps; ++i) { - Value index = add(parallelLaneId, i32_val(numParallelLane * - (i + chunkId * axisNumWarps))); -- Value ptr = gep(sharedMemoryPtr.getType(), sharedMemoryPtr, index); -- Value partialReduce = load(ptr); -+ Value ptr = gep(sharedMemoryPtr.getType(), srcValues[srcIndex].getType(), -+ sharedMemoryPtr, index); -+ Value partialReduce = load(srcValues[srcIndex].getType(), ptr); - if (!accumulator.acc) { - accumulator.acc = partialReduce; - accumulator.maskedAcc = partialReduce; -@@ -411,7 +413,7 @@ ScanOpConversion::emitFastScan(triton::S - if (axisNumWarps > 1) { - // Slow path for the case where there are multiple warps with unique data on - // the axis. -- Type elemPtrTys = LLVM::LLVMPointerType::get(srcValues[0].getType(), 3); -+ Type elemPtrTys = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - Value baseSharedMemPtr = bitcast( - getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys); - // Store the partial reducing for each warp into shared memory. -diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp -@@ -305,8 +305,7 @@ struct PrintOpConversion - - auto *context = rewriter.getContext(); - -- SmallVector argsType{ptr_ty(IntegerType::get(context, 8)), -- ptr_ty(IntegerType::get(context, 8))}; -+ SmallVector argsType{ptr_ty(context), ptr_ty(context)}; - auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType); - - ConversionPatternRewriter::InsertionGuard guard(rewriter); -@@ -359,9 +358,8 @@ struct PrintOpConversion - - static void llPrintf(Value msg, ValueRange args, - ConversionPatternRewriter &rewriter) { -- Type int8Ptr = ptr_ty(i8_ty); -- - auto *ctx = rewriter.getContext(); -+ Type ptr = ptr_ty(ctx); - auto moduleOp = - rewriter.getBlock()->getParent()->getParentOfType(); - auto funcOp = getVprintfDeclaration(rewriter); -@@ -370,7 +368,7 @@ struct PrintOpConversion - Value one = i32_val(1); - Value zero = i32_val(0); - -- Value bufferPtr = null(int8Ptr); -+ Value bufferPtr = null(ptr); - - SmallVector newArgs; - if (args.size() >= 1) { -@@ -385,16 +383,16 @@ struct PrintOpConversion - - Type structTy = LLVM::LLVMStructType::getLiteral(ctx, argTypes); - auto allocated = -- rewriter.create(loc, ptr_ty(structTy), one, -+ rewriter.create(loc, ptr_ty(ctx), structTy, one, - /*alignment=*/0); - - for (const auto &entry : llvm::enumerate(newArgs)) { - auto index = i32_val(entry.index()); -- auto fieldPtr = gep(ptr_ty(argTypes[entry.index()]), allocated, -+ auto fieldPtr = gep(ptr_ty(ctx), argTypes[entry.index()], allocated, - ArrayRef{zero, index}); - store(entry.value(), fieldPtr); - } -- bufferPtr = bitcast(allocated, int8Ptr); -+ bufferPtr = bitcast(allocated, ptr); - } - - SmallVector operands{msg, bufferPtr}; -@@ -488,8 +486,7 @@ struct AssertOpConversion - // void __assert_fail(const char * assertion, const char * file, unsigned - // int line, const char * function); - auto *ctx = rewriter.getContext(); -- SmallVector argsType{ptr_ty(i8_ty), ptr_ty(i8_ty), i32_ty, -- ptr_ty(i8_ty), -+ SmallVector argsType{ptr_ty(ctx), ptr_ty(ctx), i32_ty, ptr_ty(ctx), - rewriter.getIntegerType(sizeof(size_t) * 8)}; - auto funcType = 
LLVM::LLVMFunctionType::get(void_ty(ctx), argsType); - -@@ -623,11 +620,14 @@ struct AddPtrOpConversion - Location loc = op->getLoc(); - auto resultTy = op.getType(); - auto offsetTy = op.getOffset().getType(); -- auto ptrTy = op.getPtr().getType(); - auto resultTensorTy = resultTy.dyn_cast(); - if (resultTensorTy) { - unsigned elems = getTotalElemsPerThread(resultTy); - Type elemTy = -+ getTypeConverter()->convertType(resultTensorTy.getElementType() -+ .cast() -+ .getPointeeType()); -+ Type ptrTy = - getTypeConverter()->convertType(resultTensorTy.getElementType()); - auto ptrs = getTypeConverter()->unpackLLElements(loc, adaptor.getPtr(), - rewriter, ptrTy); -@@ -635,15 +635,18 @@ struct AddPtrOpConversion - loc, adaptor.getOffset(), rewriter, offsetTy); - SmallVector resultVals(elems); - for (unsigned i = 0; i < elems; ++i) { -- resultVals[i] = gep(elemTy, ptrs[i], offsets[i]); -+ resultVals[i] = gep(ptrTy, elemTy, ptrs[i], offsets[i]); - } - Value view = getTypeConverter()->packLLElements(loc, resultVals, rewriter, - resultTy); - rewriter.replaceOp(op, view); - } else { - assert(resultTy.isa()); -- Type llResultTy = getTypeConverter()->convertType(resultTy); -- Value result = gep(llResultTy, adaptor.getPtr(), adaptor.getOffset()); -+ auto resultPtrTy = getTypeConverter()->convertType(resultTy); -+ auto resultElemTy = getTypeConverter()->convertType( -+ resultTy.cast().getPointeeType()); -+ Value result = -+ gep(resultPtrTy, resultElemTy, adaptor.getPtr(), adaptor.getOffset()); - rewriter.replaceOp(op, result); - } - return success(); -@@ -661,9 +664,7 @@ struct AllocTensorOpConversion - Location loc = op->getLoc(); - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult()); - auto resultTy = op.getType().dyn_cast(); -- auto llvmElemTy = -- getTypeConverter()->convertType(resultTy.getElementType()); -- auto elemPtrTy = ptr_ty(llvmElemTy, 3); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); - smemBase = bitcast(smemBase, elemPtrTy); - auto sharedLayout = resultTy.getEncoding().cast(); - auto order = sharedLayout.getOrder(); -@@ -679,6 +680,8 @@ struct AllocTensorOpConversion - newOrder = SmallVector(order.begin(), order.end()); - } - -+ auto llvmElemTy = -+ getTypeConverter()->convertType(resultTy.getElementType()); - auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); - auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA, - newOrder, loc, rewriter); -@@ -737,9 +740,10 @@ struct ExtractSliceOpConversion - } - } - -- auto elemPtrTy = ptr_ty(llvmElemTy, 3); -- smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), -- llvmElemTy, strideVals, offsetVals); -+ auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); -+ smemObj = -+ SharedMemoryObject(gep(elemPtrTy, llvmElemTy, smemObj.base, offset), -+ llvmElemTy, strideVals, offsetVals); - auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); - rewriter.replaceOp(op, retVal); - return success(); -diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h ---- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h -+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h -@@ -261,8 +261,7 @@ public: - template - Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter, - T value) const { -- auto ptrTy = LLVM::LLVMPointerType::get( -- this->getTypeConverter()->convertType(rewriter.getI8Type()), 3); -+ auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - FunctionOpInterface 
funcOp; - if constexpr (std::is_pointer_v) - funcOp = value->template getParentOfType(); -@@ -275,7 +274,9 @@ public: - assert(bufferId != Allocation::InvalidBufferId && "BufferId not found"); - size_t offset = funcAllocation->getOffset(bufferId); - Value offVal = i32_val(offset); -- Value base = gep(ptrTy, smem, offVal); -+ Value base = -+ gep(ptrTy, this->getTypeConverter()->convertType(rewriter.getI8Type()), -+ smem, offVal); - return base; - } - -@@ -312,9 +313,10 @@ public: - // then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y - // This means that we can use some immediate offsets for shared memory - // operations. -- auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(resElemTy), 3); -+ auto dstPtrTy = ptr_ty(rewriter.getContext(), 3); - auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides); -- Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset); -+ Value dstPtrBase = gep(dstPtrTy, getTypeConverter()->convertType(resElemTy), -+ smemObj.base, dstOffset); - - auto srcEncoding = srcTy.getEncoding(); - auto srcShape = srcTy.getShape(); -@@ -423,7 +425,8 @@ public: - Value colOff = add(colOffSwizzled, colOffOrdered); - // compute non-immediate offset - offset = add(offset, add(rowOff, mul(colOff, strideCol))); -- Value currPtr = gep(dstPtrTy, dstPtrBase, offset); -+ Value currPtr = gep(dstPtrTy, getTypeConverter()->convertType(resElemTy), -+ dstPtrBase, offset); - // compute immediate offset - Value immediateOff; - if (outOrder.size() == 2) { -@@ -434,7 +437,8 @@ public: - immediateOff = i32_val(immedateOffCol); - } - -- ret[elemIdx] = gep(dstPtrTy, currPtr, immediateOff); -+ ret[elemIdx] = gep(dstPtrTy, getTypeConverter()->convertType(resElemTy), -+ currPtr, immediateOff); - } - return ret; - } -@@ -479,8 +483,8 @@ public: - SmallVector outVals(outElems); - for (unsigned i = 0; i < numVecs; ++i) { - Value smemAddr = sharedPtrs[i * minVec]; -- smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3)); -- Value valVec = load(smemAddr); -+ smemAddr = bitcast(smemAddr, ptr_ty(rewriter.getContext(), 3)); -+ Value valVec = load(wordTy, smemAddr); - for (unsigned v = 0; v < minVec; ++v) { - Value currVal = extract_element(dstElemTy, valVec, i32_val(v)); - outVals[i * minVec + v] = currVal; -@@ -537,7 +541,7 @@ public: - word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec)); - if (i % minVec == minVec - 1) { - Value smemAddr = sharedPtrs[i / minVec * minVec]; -- smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3)); -+ smemAddr = bitcast(smemAddr, ptr_ty(rewriter.getContext(), 3)); - store(word, smemAddr); - } - } -diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp ---- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp -@@ -161,8 +161,7 @@ struct FuncOpConversion : public FuncOpC - // memory to the function arguments. - auto loc = funcOp.getLoc(); - auto ctx = funcOp->getContext(); -- auto ptrTy = LLVM::LLVMPointerType::get( -- this->getTypeConverter()->convertType(rewriter.getI8Type()), 3); -+ auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - // 1. Modify the function type to add the new argument. 
- auto funcTy = funcOp.getFunctionType(); - auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); -@@ -232,15 +231,14 @@ struct FuncOpConversion : public FuncOpC - allocation.mapFuncOp(funcOp, newFuncOp); - - // Append arguments to receive TMADesc in global memory in the runtime -- auto i8PtrTy = LLVM::LLVMPointerType::get( -- this->getTypeConverter()->convertType(rewriter.getI8Type()), 1); -+ auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1); - auto numArgs = newFuncOp.getBody().front().getNumArguments(); - auto funcTy = newFuncOp.getFunctionType().cast(); - SmallVector newInputsTy(funcTy.getParams().begin(), - funcTy.getParams().end()); - for (unsigned i = 0; i < numTMA; ++i) { -- newFuncOp.getBody().front().addArgument(i8PtrTy, funcOp.getLoc()); -- newInputsTy.push_back(i8PtrTy); -+ newFuncOp.getBody().front().addArgument(ptrTy, funcOp.getLoc()); -+ newInputsTy.push_back(ptrTy); - } - newFuncOp.setType( - LLVM::LLVMFunctionType::get(funcTy.getReturnType(), newInputsTy)); -@@ -296,9 +294,8 @@ private: - // of shared memory and append it to the operands of the callOp. - auto loc = callOp.getLoc(); - auto caller = callOp->getParentOfType(); -- auto ptrTy = LLVM::LLVMPointerType::get( -- this->getTypeConverter()->convertType(rewriter.getI8Type()), -- NVVM::kSharedMemorySpace); -+ auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), -+ NVVM::kSharedMemorySpace); - auto promotedOperands = this->getTypeConverter()->promoteOperands( - callOp.getLoc(), /*opOperands=*/callOp->getOperands(), - adaptor.getOperands(), rewriter); -@@ -312,7 +309,9 @@ private: - } - // function has a shared mem buffer - auto offset = funcAllocation->getOffset(bufferId); -- auto offsetValue = gep(ptrTy, base, i32_val(offset)); -+ auto offsetValue = -+ gep(ptrTy, this->getTypeConverter()->convertType(rewriter.getI8Type()), -+ base, i32_val(offset)); - promotedOperands.push_back(offsetValue); - return promotedOperands; - } -@@ -612,9 +611,8 @@ private: - } else { - funcSmem = funcOp.getArgument(funcOp.getNumArguments() - 1); - } -- auto ptrTy = -- LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()), -- NVVM::NVVMMemorySpace::kSharedMemorySpace); -+ auto ptrTy = LLVM::LLVMPointerType::get( -+ ctx, NVVM::NVVMMemorySpace::kSharedMemorySpace); - funcSmem = b.create(loc, ptrTy, funcSmem); - allocation.setFunctionSharedMemoryValue(funcOp, funcSmem); - }); -diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp ---- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp -@@ -60,13 +60,11 @@ Type TritonGPUToLLVMTypeConverter::conve - for (size_t i = 0; i < 2 * shape.size(); ++i) - types.push_back(IntegerType::get(ctx, 64)); - -- types.push_back( -- LLVM::LLVMPointerType::get(eleType, type.getAddressSpace())); -+ types.push_back(LLVM::LLVMPointerType::get(ctx, type.getAddressSpace())); - - return LLVM::LLVMStructType::getLiteral(ctx, types); - } -- return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()), -- type.getAddressSpace()); -+ return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); - } - - Value TritonGPUToLLVMTypeConverter::packLLElements( -@@ -145,7 +143,7 @@ Type TritonGPUToLLVMTypeConverter::conve - if (auto shared_layout = layout.dyn_cast()) { - SmallVector types; - // base ptr -- auto ptrType = LLVM::LLVMPointerType::get(eltType, 3); -+ auto ptrType = LLVM::LLVMPointerType::get(ctx, 3); - types.push_back(ptrType); - // shape dims - auto 
rank = type.getRank(); -diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp ---- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp -@@ -46,12 +46,11 @@ Value createLLVMIntegerConstant(OpBuilde - // (2) Create LoadDSmemOp - // (3) Bitcast result from dataTy (u16/u32/u64) back to elemTy - Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, -- Value ctaId) { -+ Value ctaId, Type elemTy) { - assert(addr.getType().isa() && - "addr must be a pointer type"); - auto ptrTy = addr.getType().cast(); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); -- auto elemTy = ptrTy.getElementType(); - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); - Value ret = - rewriter.create(loc, addr, ctaId, bitwidth); -@@ -63,12 +62,12 @@ Value createLoadDSmem(Location loc, Patt - // (2) Create LoadDSmemOp and extract results from retStruct - // (3) Bitcast results from dataTy (u16/u32/u64) back to elemTy - SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, -- Value addr, Value ctaId, unsigned vec) { -+ Value addr, Value ctaId, unsigned vec, -+ Type elemTy) { - assert(addr.getType().isa() && - "addr must be a pointer type"); - auto ptrTy = addr.getType().cast(); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); -- auto elemTy = ptrTy.getElementType(); - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); - Value retStruct = rewriter.create( - loc, addr, ctaId, bitwidth, vec); -@@ -91,8 +90,7 @@ void createStoreDSmem(Location loc, Patt - "addr must be a pointer type"); - auto ptrTy = addr.getType().cast(); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); -- auto elemTy = ptrTy.getElementType(); -- unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); -+ unsigned bitwidth = value.getType().getIntOrFloatBitWidth(); - auto dataTy = rewriter.getIntegerType(bitwidth); - Value data = bitcast(value, dataTy); - rewriter.create(loc, addr, ctaId, data, pred); -@@ -115,8 +113,10 @@ void createStoreDSmem(Location loc, Patt - "addr must be a pointer type"); - auto ptrTy = addr.getType().cast(); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); -- auto elemTy = ptrTy.getElementType(); -- unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); -+ unsigned bitwidth = 0; -+ if (!values.empty()) { -+ bitwidth = values.back().getType().getIntOrFloatBitWidth(); -+ } - auto dataTy = rewriter.getIntegerType(bitwidth); - SmallVector data; - for (unsigned i = 0; i < values.size(); ++i) -@@ -253,11 +253,10 @@ Value storeShared(ConversionPatternRewri - } - - Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, -- Value pred) { -+ Type elemTy, Value pred) { - MLIRContext *ctx = rewriter.getContext(); - auto ptrTy = ptr.getType().cast(); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for loadShared"); -- auto elemTy = ptrTy.getElementType(); - unsigned bitwidth = std::max(8u, elemTy.getIntOrFloatBitWidth()); - - const char *c = bitwidth == 64 ? "=l" : (bitwidth == 16 ? 
"=h" : "=r"); -@@ -363,12 +362,11 @@ Value addStringToModule(Location loc, Co - } - - Value zero = i32_val(0); -- Type globalPtrType = -- LLVM::LLVMPointerType::get(globalType, global.getAddrSpace()); -+ Type globalPtrType = LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()); - Value globalPtr = rewriter.create( - UnknownLoc::get(ctx), globalPtrType, global.getSymName()); - Value stringStart = -- rewriter.create(UnknownLoc::get(ctx), ptr_ty(i8_ty), -+ rewriter.create(UnknownLoc::get(ctx), ptr_ty(ctx), i8_ty, - globalPtr, SmallVector({zero, zero})); - return stringStart; - } -diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h ---- a/lib/Conversion/TritonGPUToLLVM/Utility.h -+++ b/lib/Conversion/TritonGPUToLLVM/Utility.h -@@ -209,9 +209,10 @@ Value createLLVMIntegerConstant(OpBuilde - /// (1) load_dsmem(addr, ctaId) - /// (2) load_dsmem(addr, ctaId, vec) - Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, -- Value ctaId); -+ Value ctaId, Type elemTy); - SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, -- Value addr, Value ctaId, unsigned vec); -+ Value addr, Value ctaId, unsigned vec, -+ Type elemTy); - - /// Usage of macro store_dsmem - /// (1) store_dsmem(addr, ctaId, value, pred) -@@ -257,17 +258,12 @@ struct SharedMemoryObject { - : base(base), - baseElemType(baseElemType), - strides(strides.begin(), strides.end()), -- offsets(offsets.begin(), offsets.end()) { -- assert(baseElemType == -- base.getType().cast().getElementType()); -- } -+ offsets(offsets.begin(), offsets.end()) {} - - SharedMemoryObject(Value base, Type baseElemType, ArrayRef shape, - ArrayRef order, Location loc, - ConversionPatternRewriter &rewriter) - : base(base), baseElemType(baseElemType) { -- assert(baseElemType == -- base.getType().cast().getElementType()); - strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter); - offsets.append(order.size(), i32_val(0)); - } -@@ -332,7 +328,7 @@ Value storeShared(ConversionPatternRewri - Value val, Value pred); - - Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, -- Value pred); -+ Type elemTy, Value pred); - - Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -diff --git a/lib/Dialect/NVGPU/IR/Dialect.cpp b/lib/Dialect/NVGPU/IR/Dialect.cpp ---- a/lib/Dialect/NVGPU/IR/Dialect.cpp -+++ b/lib/Dialect/NVGPU/IR/Dialect.cpp -@@ -73,7 +73,8 @@ void StoreDSmemOp::build(OpBuilder &buil - unsigned StoreDSmemOp::getBitwidth() { - auto addrTy = getAddr().getType(); - assert(addrTy.isa() && "addr must be a pointer type"); -- auto elemTy = addrTy.cast().getElementType(); -+ if (getValues().empty()) return 0; -+ auto elemTy = getValues().back().getType(); - return elemTy.getIntOrFloatBitWidth(); - } - -diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir ---- a/test/Conversion/tritongpu_to_llvm.mlir -+++ b/test/Conversion/tritongpu_to_llvm.mlir -@@ -1,7 +1,7 @@ - // RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" | FileCheck %s - - module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { -- // CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr) -+ // CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>) - // Here the 128 comes from the 4 in module attribute multiples 32 - // CHECK: nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32] - tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr) { -@@ 
-560,9 +560,9 @@ module attributes {"triton_gpu.num-ctas" - %index = arith.constant 1 : i32 - - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr, #AL> -> tensor<2x16x64xf16, #A> - tt.return - } -@@ -752,38 +752,38 @@ module attributes {"triton_gpu.num-ctas" - tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { - // CHECK: llvm.mlir.addressof @global_smem - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1> - tt.return - } -@@ -799,14 +799,14 @@ module attributes {"triton_gpu.num-ctas" - tt.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { - // CHECK: llvm.mlir.addressof @global_smem - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1> - tt.return - } -@@ -822,20 +822,20 @@ module attributes {"triton_gpu.num-ctas" - tt.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { - // CHECK: llvm.mlir.addressof @global_smem - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: 
!llvm.ptr<3> - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1> - tt.return - } -@@ -889,12 +889,12 @@ module attributes {"triton_gpu.num-ctas" - // CHECK-LABEL: convert_layout_mmav2_block - tt.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) { - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0> - tt.return - } -@@ -909,16 +909,16 @@ module attributes {"triton_gpu.num-ctas" - // CHECK-LABEL: convert_layout_mmav1_block - tt.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) { - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.convert_layout %arg0 : (tensor<32x64xf32, #mma>) -> tensor<32x64xf32, #blocked> - tt.return - } -@@ -932,9 +932,9 @@ module attributes {"triton_gpu.num-ctas" - // CHECK-LABEL: convert_layout_blocked_shared - tt.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) { - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store -- // CHECK-SAME: !llvm.ptr, 3> -+ // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0> - tt.return - } -@@ -947,7 +947,7 @@ module attributes {"triton_gpu.num-ctas" - module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { - // CHECK-LABEL: convert_blocked1d_to_slice0 - tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { -- // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr, 3> -+ // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<3> - %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - tt.return - } -@@ -960,7 +960,7 @@ module attributes {"triton_gpu.num-ctas" - module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { - // CHECK-LABEL: convert_blocked1d_to_slice1 - tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { -- // CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr, 3> -+ // CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr<3> - %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - tt.return - } -diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir ---- a/test/Conversion/tritongpu_to_llvm_hopper.mlir -+++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir -@@ -12,7 +12,7 @@ module attributes {"triton_gpu.num-ctas" - %dst = triton_gpu.alloc_tensor : tensor<1x64x64xf16, #shared> - %c0 = arith.constant 0 : i32 - %src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, 
%coord1] {order = array} : !tt.ptr, 1> -- // CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operandSegmentSizes = array} : !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32 -+ // CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operandSegmentSizes = array} : !llvm.ptr<3>, !llvm.ptr<3>, !llvm.ptr<1>, i64, i1, i32, i32 - %res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr -> tensor<1x64x64xf16, #shared> - tt.return - } -@@ -73,7 +73,7 @@ module attributes {"triton_gpu.num-ctas" - %src = triton_gpu.alloc_tensor : tensor<64x64xf32, #shared> - %c0 = arith.constant 0 : i32 - %dst = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array} : !tt.ptr, 1> -- // CHECK: nvgpu.tma_store_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, !llvm.ptr, i1, i32, i32 -+ // CHECK: nvgpu.tma_store_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr<1>, !llvm.ptr<3>, i1, i32, i32 - triton_nvidia_gpu.store_async %dst, %src {cache = 1 : i32} : !tt.ptr, 1>, tensor<64x64xf32, #shared> - tt.return - } -diff --git a/test/NVGPU/test_cga.mlir b/test/NVGPU/test_cga.mlir ---- a/test/NVGPU/test_cga.mlir -+++ b/test/NVGPU/test_cga.mlir -@@ -14,11 +14,11 @@ module attributes {"triton_gpu.num-warps - nvgpu.cga_barrier_arrive - nvgpu.cga_barrier_wait - -- %ptr = llvm.mlir.zero : !llvm.ptr -+ %ptr = llvm.mlir.zero : !llvm.ptr<3> - - // CHECK: llvm.inline_asm - %v = nvgpu.cluster_id -- llvm.store %v, %ptr : !llvm.ptr -+ llvm.store %v, %ptr : i32, !llvm.ptr<3> - - tt.return - } -diff --git a/test/NVGPU/test_mbarrier.mlir b/test/NVGPU/test_mbarrier.mlir ---- a/test/NVGPU/test_mbarrier.mlir -+++ b/test/NVGPU/test_mbarrier.mlir -@@ -2,18 +2,18 @@ - #SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> - module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { - tt.func @test_mbarrier() { -- %mbarrier = llvm.mlir.zero : !llvm.ptr -+ %mbarrier = llvm.mlir.zero : !llvm.ptr<3> - %pred = arith.constant 1 : i1 - // CHECK: llvm.inline_asm -- nvgpu.mbarrier_init %mbarrier, %pred { count = 32 : i32 } : !llvm.ptr -+ nvgpu.mbarrier_init %mbarrier, %pred { count = 32 : i32 } : !llvm.ptr<3> - // CHECK: llvm.inline_asm -- nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 1 : i32}: !llvm.ptr -+ nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 1 : i32}: !llvm.ptr<3> - // CHECK: llvm.inline_asm -- nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 0 : i32}: !llvm.ptr -+ nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 0 : i32}: !llvm.ptr<3> - // CHECK: llvm.inline_asm -- nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 2 : i32, txCount = 128 : i32}: !llvm.ptr -+ nvgpu.mbarrier_arrive %mbarrier, %pred {arriveType = 2 : i32, txCount = 128 : i32}: !llvm.ptr<3> - // CHECK: llvm.inline_asm -- nvgpu.mbarrier_wait %mbarrier, %pred : !llvm.ptr, i1 -+ nvgpu.mbarrier_wait %mbarrier, %pred : !llvm.ptr<3>, i1 - tt.return - } - } // end module -diff --git a/test/NVGPU/test_tma.mlir b/test/NVGPU/test_tma.mlir ---- a/test/NVGPU/test_tma.mlir -+++ b/test/NVGPU/test_tma.mlir -@@ -2,9 +2,9 @@ - #SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> - module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { - 
tt.func @test_tma(%im2colOffsets0 : !llvm.struct<(i16, i16)>, %im2colOffsets1 : !llvm.struct<(i16, i16, i16)>) { -- %mbarrier = llvm.mlir.zero : !llvm.ptr -- %tmaDesc = llvm.mlir.zero : !llvm.ptr -- %dst = llvm.mlir.zero : !llvm.ptr -+ %mbarrier = llvm.mlir.zero : !llvm.ptr<3> -+ %tmaDesc = llvm.mlir.zero : !llvm.ptr<1> -+ %dst = llvm.mlir.zero : !llvm.ptr<3> - %l2desc = arith.constant 0 : i64 - %c0 = arith.constant 0 : i32 - %c1 = arith.constant 1 : i32 -@@ -16,13 +16,13 @@ module attributes {"triton_gpu.num-warps - - // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint - // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint -- nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1 {operandSegmentSizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32 -- nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i32, i32 -+ nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1 {operandSegmentSizes = array}: !llvm.ptr<3>, !llvm.ptr<3>, !llvm.ptr<1>, i64, i1, i32, i32 -+ nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array}: !llvm.ptr<3>, !llvm.ptr<3>, !llvm.ptr<1>, i64, i1, i32, i32, i32, i32 - - // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint - // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint -- nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %mask {operandSegmentSizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i16 -- nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i32, i32 -+ nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %mask {operandSegmentSizes = array}: !llvm.ptr<3>, !llvm.ptr<3>, !llvm.ptr<1>, i64, i1, i32, i32, i16 -+ nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array}: !llvm.ptr<3>, !llvm.ptr<3>, !llvm.ptr<1>, i64, i1, i32, i32, i32, i32 - - tt.return - } diff --git a/third_party/xla/third_party/triton/cl580550344.patch b/third_party/xla/third_party/triton/cl580550344.patch deleted file mode 100644 index c1d598de65dc7e..00000000000000 --- a/third_party/xla/third_party/triton/cl580550344.patch +++ /dev/null @@ -1,304 +0,0 @@ -diff --git a/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/BarrierOpToLLVM.cpp -@@ -67,8 +67,10 @@ struct AllocMBarrierOpConversion : publi - op.getCount()); - } - if (resultTensorTy) { -- auto smemObj = SharedMemoryObject(smemBase, resultTensorTy.getShape(), -- {0}, loc, rewriter); -+ auto llvmElemTy = -+ getTypeConverter()->convertType(resultTensorTy.getElementType()); -+ auto smemObj = SharedMemoryObject( -+ smemBase, llvmElemTy, resultTensorTy.getShape(), {0}, loc, rewriter); - auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); - rewriter.replaceOp(op, retVal); - } else { -diff --git 
a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp -@@ -707,8 +707,9 @@ private: - auto dstLayout = dstTy.getEncoding(); - auto inOrd = getOrder(srcSharedLayout); - -- auto smemObj = -- getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter); -+ auto smemObj = getSharedMemoryObjectFromStruct( -+ loc, adaptor.getSrc(), -+ getTypeConverter()->convertType(srcTy.getElementType()), rewriter); - auto elemTy = getTypeConverter()->convertType(dstTy.getElementType()); - - auto srcStrides = -@@ -843,8 +844,8 @@ private: - storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, - dst, smemBase, elemTy, loc, rewriter); - } -- auto smemObj = -- SharedMemoryObject(smemBase, dstShapePerCTA, outOrd, loc, rewriter); -+ auto smemObj = SharedMemoryObject(smemBase, elemTy, dstShapePerCTA, outOrd, -+ loc, rewriter); - auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); - rewriter.replaceOp(op, retVal); - return success(); -@@ -1013,8 +1014,11 @@ private: - Value dst = op.getResult(); - bool isMMA = supportMMA(dst, mmaLayout.getVersionMajor()); - -- auto smemObj = -- getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter); -+ auto llvmElemTy = getTypeConverter()->convertType( -+ src.getType().cast().getElementType()); -+ -+ auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), -+ llvmElemTy, rewriter); - Value res; - if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2 - res = SharedToDotOperandMMAv2::convertLayout( -diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp -@@ -101,7 +101,9 @@ Value loadAFMA(Value A, Value llA, Block - - bool isARow = aOrder[0] == 1; - -- auto aSmem = getSharedMemoryObjectFromStruct(loc, llA, rewriter); -+ auto aSmem = getSharedMemoryObjectFromStruct( -+ loc, llA, typeConverter->convertType(aTensorTy.getElementType()), -+ rewriter); - Value strideAM = aSmem.strides[0]; - Value strideAK = aSmem.strides[1]; - Value strideA0 = isARow ? strideAK : strideAM; -@@ -166,7 +168,9 @@ Value loadBFMA(Value B, Value llB, Block - - bool isBRow = bOrder[0] == 1; - -- auto bSmem = getSharedMemoryObjectFromStruct(loc, llB, rewriter); -+ auto bSmem = getSharedMemoryObjectFromStruct( -+ loc, llB, typeConverter->convertType(bTensorTy.getElementType()), -+ rewriter); - Value strideBN = bSmem.strides[1]; - Value strideBK = bSmem.strides[0]; - Value strideB0 = isBRow ? 
strideBN : strideBK; -diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp ---- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp -@@ -332,8 +332,15 @@ LogicalResult convertDot(TritonGPUToLLVM - Value baseA; - Value baseB; - if (aSharedLayout) -- baseA = getSharedMemoryObjectFromStruct(loc, loadedA, rewriter).base; -- baseB = getSharedMemoryObjectFromStruct(loc, loadedB, rewriter).base; -+ baseA = -+ getSharedMemoryObjectFromStruct( -+ loc, loadedA, -+ typeConverter->convertType(aTensorTy.getElementType()), rewriter) -+ .base; -+ baseB = getSharedMemoryObjectFromStruct( -+ loc, loadedB, -+ typeConverter->convertType(bTensorTy.getElementType()), rewriter) -+ .base; - if (aSharedLayout) { - auto aOrd = aSharedLayout.getOrder(); - transA = aOrd[0] == 0; -diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp -@@ -551,7 +551,8 @@ struct StoreAsyncOpConversion - Value llDst = adaptor.getDst(); - Value llSrc = adaptor.getSrc(); - auto srcShape = srcTy.getShape(); -- auto smemObj = getSharedMemoryObjectFromStruct(loc, llSrc, rewriter); -+ auto smemObj = -+ getSharedMemoryObjectFromStruct(loc, llSrc, elemTy, rewriter); - - SmallVector offsetVals; - for (auto i = 0; i < srcShape.size(); ++i) { -@@ -1250,7 +1251,8 @@ struct InsertSliceOpConversion - - // newBase = base + offset - // Triton support either static and dynamic offsets -- auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter); -+ auto smemObj = getSharedMemoryObjectFromStruct( -+ loc, llDst, dstTy.getElementType(), rewriter); - SmallVector offsets; - SmallVector srcStrides; - auto mixedOffsets = op.getMixedOffsets(); -@@ -1339,7 +1341,8 @@ struct InsertSliceAsyncOpConversion - // %dst - auto dstTy = dst.getType().cast(); - auto dstShape = dstTy.getShape(); -- auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter); -+ auto smemObj = -+ getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter); - auto axis = op->getAttrOfType("axis").getInt(); - SmallVector offsetVals; - SmallVector srcStrides; -@@ -1601,7 +1604,9 @@ struct InsertSliceAsyncV2OpConversion - Value dst = op.getDst(); - auto dstTy = dst.getType().cast(); - auto dstShape = dstTy.getShape(); -- auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter); -+ auto smemObj = getSharedMemoryObjectFromStruct( -+ loc, llDst, typeConverter->convertType(dstTy.getElementType()), -+ rewriter); - - // the offset of coord considering multicast slicing - SmallVector mcastOffsetVals; -diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp -@@ -680,8 +680,8 @@ struct AllocTensorOpConversion - } - - auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); -- auto smemObj = -- SharedMemoryObject(smemBase, shapePerCTA, newOrder, loc, rewriter); -+ auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA, -+ newOrder, loc, rewriter); - auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); - rewriter.replaceOp(op, retVal); - return success(); -@@ -704,10 +704,12 @@ struct ExtractSliceOpConversion - assert(op.hasUnitStride() && 
- "Only unit stride supported by ExtractSliceOpConversion"); - -+ auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); -+ - // newBase = base + offset - // Triton supports either static and dynamic offsets -- auto smemObj = -- getSharedMemoryObjectFromStruct(loc, adaptor.getSource(), rewriter); -+ auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSource(), -+ llvmElemTy, rewriter); - SmallVector opOffsetVals; - SmallVector offsetVals; - auto mixedOffsets = op.getMixedOffsets(); -@@ -735,10 +737,9 @@ struct ExtractSliceOpConversion - } - } - -- auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); - auto elemPtrTy = ptr_ty(llvmElemTy, 3); - smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), -- strideVals, offsetVals); -+ llvmElemTy, strideVals, offsetVals); - auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); - rewriter.replaceOp(op, retVal); - return success(); -diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h ---- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h -+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h -@@ -525,7 +525,7 @@ public: - - SmallVector srcStrides = {dstStrides[0], dstStrides[1]}; - SmallVector offsetVals = {i32_val(0), i32_val(0)}; -- SharedMemoryObject smemObj(smemBase, srcStrides, offsetVals); -+ SharedMemoryObject smemObj(smemBase, elemTy, srcStrides, offsetVals); - - DenseMap sharedPtrs = - getSwizzledSharedPtrs(loc, inVec, srcTy, dstSharedLayout, dstElemTy, -diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp ---- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp -@@ -131,9 +131,9 @@ void createStoreDSmem(Location loc, Patt - createStoreDSmem(loc, rewriter, addr, ctaId, values, pred); - } - --SharedMemoryObject --getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, -- ConversionPatternRewriter &rewriter) { -+SharedMemoryObject getSharedMemoryObjectFromStruct( -+ Location loc, Value llvmStruct, Type elemTy, -+ ConversionPatternRewriter &rewriter) { - ArrayRef types = - llvmStruct.getType().cast().getBody(); - SmallVector elems(types.size()); -@@ -144,6 +144,7 @@ getSharedMemoryObjectFromStruct(Location - - auto rank = (elems.size() - 1) / 2; - return {/*base=*/elems[0], -+ /*baseElemType=*/elemTy, - /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank}, - /*offsets=*/{elems.begin() + 1 + rank, elems.end()}}; - } -diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h ---- a/lib/Conversion/TritonGPUToLLVM/Utility.h -+++ b/lib/Conversion/TritonGPUToLLVM/Utility.h -@@ -234,6 +234,7 @@ getStridesFromShapeAndOrder(ArrayRef strides, -+ SharedMemoryObject(Value base, Type baseElemType, ArrayRef strides, - ArrayRef offsets) -- : base(base), strides(strides.begin(), strides.end()), -- offsets(offsets.begin(), offsets.end()) {} -+ : base(base), -+ baseElemType(baseElemType), -+ strides(strides.begin(), strides.end()), -+ offsets(offsets.begin(), offsets.end()) { -+ assert(baseElemType == -+ base.getType().cast().getElementType()); -+ } - -- SharedMemoryObject(Value base, ArrayRef shape, -+ SharedMemoryObject(Value base, Type baseElemType, ArrayRef shape, - ArrayRef order, Location loc, - ConversionPatternRewriter &rewriter) -- : base(base) { -+ : base(base), baseElemType(baseElemType) { -+ assert(baseElemType == -+ 
base.getType().cast().getElementType()); - strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter); - offsets.append(order.size(), i32_val(0)); - } -@@ -290,13 +298,13 @@ struct SharedMemoryObject { - Value cSwizzleOffset = getCSwizzleOffset(order); - Value offset = sub(i32_val(0), cSwizzleOffset); - Type type = base.getType(); -- return gep(type, base, offset); -+ return gep(type, baseElemType, base, offset); - } - }; - --SharedMemoryObject --getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, -- ConversionPatternRewriter &rewriter); -+SharedMemoryObject getSharedMemoryObjectFromStruct( -+ Location loc, Value llvmStruct, Type elemTy, -+ ConversionPatternRewriter &rewriter); - - // Convert an \param index to a multi-dim coordinate given \param shape and - // \param order. -diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp ---- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp -@@ -211,14 +211,16 @@ struct TransOpConversion - matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); -- auto srcSmemObj = -- getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter); -+ auto llvmElemTy = getTypeConverter()->convertType( -+ op.getType().cast().getElementType()); -+ auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), -+ llvmElemTy, rewriter); - SmallVector dstStrides = {srcSmemObj.strides[1], - srcSmemObj.strides[0]}; - SmallVector dstOffsets = {srcSmemObj.offsets[1], - srcSmemObj.offsets[0]}; -- auto dstSmemObj = -- SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets); -+ auto dstSmemObj = SharedMemoryObject( -+ srcSmemObj.base, srcSmemObj.baseElemType, dstStrides, dstOffsets); - auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); - rewriter.replaceOp(op, retVal); - return success(); diff --git a/third_party/xla/third_party/triton/cl580852372.patch b/third_party/xla/third_party/triton/cl580852372.patch deleted file mode 100644 index d15ec833fc6f25..00000000000000 --- a/third_party/xla/third_party/triton/cl580852372.patch +++ /dev/null @@ -1,15 +0,0 @@ -==== triton/include/triton/Dialect/NVGPU/IR/NVGPUOps.td#3 - /google/src/cloud/shyshkov/mlir_4983432f17eb4b445e161c5f8278c6ea4d5d1241_1699531174/triton/include/triton/Dialect/NVGPU/IR/NVGPUOps.td ==== -# action=edit type=text ---- triton/include/triton/Dialect/NVGPU/IR/NVGPUOps.td 2023-11-09 03:52:05.000000000 -0800 -+++ triton/include/triton/Dialect/NVGPU/IR/NVGPUOps.td 2023-11-09 04:00:14.000000000 -0800 -@@ -29,8 +29,8 @@ - include "mlir/Dialect/LLVMIR/LLVMOpBase.td" - include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType - --def LLVM_PointerGlobal : LLVM_OpaquePointerInAddressSpace<1>; --def LLVM_PointerShared : LLVM_OpaquePointerInAddressSpace<3>; -+def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>; -+def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>; - - class NVGPU_Op traits = []> : - LLVM_OpBase; diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index 60f0c56799b5dc..3795c89bb75563 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -5,8 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl578837341" - TRITON_SHA256 = 
"0d8112bb31d48b5beadbfc2e13c52770a95d3759b312b15cf26dd72e71410568" + TRITON_COMMIT = "cl580208989" + TRITON_SHA256 = "bcf6e99a73c8797720325b0f2e48447cdae7f68c53c68bfe04c39104db542562" tf_http_archive( name = "triton", @@ -17,8 +17,5 @@ def repo(): patch_file = [ "//third_party/triton:cl568176943.patch", "//third_party/triton:b304456327.patch", - "//third_party/triton:cl580550344.patch", - "//third_party/triton:cl580481414.patch", - "//third_party/triton:cl580852372.patch", ], ) From 96dc6bf1a9444c596d2791c9eaf9b33d4b37e055 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 05:18:07 -0800 Subject: [PATCH 222/391] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/77cd0fb4225157e8326ddad1006d137ecced0aae. PiperOrigin-RevId: 583352228 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 4958ad4de72977..9754fd22233a75 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "a3959294a297645c34a6adbdd639d7df5c84d691" - TFRT_SHA256 = "3765b313e3da83774a50f01ba52ac155229359d0bfb170beb9714197ca64ec4a" + TFRT_COMMIT = "77cd0fb4225157e8326ddad1006d137ecced0aae" + TFRT_SHA256 = "c13c76ddb5a1f4646cfe7f5cb8f0a0b5789f5a2030086ac2796bf922c66eea6c" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 4958ad4de72977..9754fd22233a75 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "a3959294a297645c34a6adbdd639d7df5c84d691" - TFRT_SHA256 = "3765b313e3da83774a50f01ba52ac155229359d0bfb170beb9714197ca64ec4a" + TFRT_COMMIT = "77cd0fb4225157e8326ddad1006d137ecced0aae" + TFRT_SHA256 = "c13c76ddb5a1f4646cfe7f5cb8f0a0b5789f5a2030086ac2796bf922c66eea6c" tf_http_archive( name = "tf_runtime", From b023c38f5c3a78683e378a46a899c2bac378e670 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 06:51:28 -0800 Subject: [PATCH 223/391] Use `malloc` instead of `new` to allocate buffers to reduce overhead needed to ensure alignment. PiperOrigin-RevId: 583373189 --- tensorflow/lite/simple_memory_arena.cc | 22 +++++++++------------- tensorflow/lite/simple_memory_arena.h | 16 ++++++---------- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index 80f216072dada6..9c6a596ed82d10 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -16,12 +16,12 @@ limitations under the License. 
#include "tensorflow/lite/simple_memory_arena.h" #include -#include #include -#include #include #include +#include #include +#include #include #include "tensorflow/lite/core/c/common.h" @@ -54,20 +54,20 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), new_allocation_size); #endif - char* new_buffer = reinterpret_cast(std::malloc(new_allocation_size)); + auto new_buffer = std::unique_ptr(new char[new_allocation_size]); char* new_aligned_ptr = reinterpret_cast( - AlignTo(alignment_, reinterpret_cast(new_buffer))); + AlignTo(alignment_, reinterpret_cast(new_buffer.get()))); if (new_size > 0 && allocation_size_ > 0) { // Copy data when both old and new buffers are bigger than 0 bytes. - const size_t new_alloc_alignment_adjustment = new_aligned_ptr - new_buffer; - const size_t old_alloc_alignment_adjustment = aligned_ptr_ - buffer_; + const size_t new_alloc_alignment_adjustment = + new_aligned_ptr - new_buffer.get(); + const size_t old_alloc_alignment_adjustment = aligned_ptr_ - buffer_.get(); const size_t copy_amount = std::min(allocation_size_ - old_alloc_alignment_adjustment, new_allocation_size - new_alloc_alignment_adjustment); std::memcpy(new_aligned_ptr, aligned_ptr_, copy_amount); } - std::free(buffer_); - buffer_ = new_buffer; + buffer_ = std::move(new_buffer); aligned_ptr_ = new_aligned_ptr; #ifdef TF_LITE_TENSORFLOW_PROFILER if (allocation_size_ > 0) { @@ -84,15 +84,11 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { } void ResizableAlignedBuffer::Release() { - if (buffer_ == nullptr) { - return; - } #ifdef TF_LITE_TENSORFLOW_PROFILER OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), allocation_size_); #endif - std::free(buffer_); - buffer_ = nullptr; + buffer_.reset(); allocation_size_ = 0; aligned_ptr_ = nullptr; } diff --git a/tensorflow/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h index 87603a26c32e78..05bb52e6a225e4 100644 --- a/tensorflow/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -15,9 +15,10 @@ limitations under the License. #ifndef TENSORFLOW_LITE_SIMPLE_MEMORY_ARENA_H_ #define TENSORFLOW_LITE_SIMPLE_MEMORY_ARENA_H_ -#include -#include +#include + #include +#include #include #include @@ -57,8 +58,7 @@ struct ArenaAllocWithUsageInterval { class ResizableAlignedBuffer { public: explicit ResizableAlignedBuffer(size_t alignment, int subgraph_index) - : buffer_(nullptr), - allocation_size_(0), + : allocation_size_(0), alignment_(alignment), subgraph_index_(subgraph_index) { // To silence unused private member warning, only used with @@ -66,8 +66,6 @@ class ResizableAlignedBuffer { (void)subgraph_index_; } - ~ResizableAlignedBuffer() { Release(); } - // Resizes the buffer to make sure new_size bytes fit in the buffer. Keeps // alignment and any existing the data. Returns true when any external // pointers into the data array need to be adjusted (the buffer was moved). @@ -84,12 +82,10 @@ class ResizableAlignedBuffer { private: size_t RequiredAllocationSize(size_t data_array_size) const { - // malloc guarantees returned pointers are aligned to at least max_align_t. - return data_array_size + - std::max(std::size_t{0}, alignment_ - alignof(std::max_align_t)); + return data_array_size + alignment_ - 1; } - char* buffer_; + std::unique_ptr buffer_; size_t allocation_size_; size_t alignment_; char* aligned_ptr_; From ff4ccedf63ae0b70be17a7e3fe7666077de143f5 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 17 Nov 2023 06:59:17 -0800 Subject: [PATCH 224/391] Integrate LLVM at llvm/llvm-project@ec42d547eba5 Updates LLVM usage to match [ec42d547eba5](https://github.com/llvm/llvm-project/commit/ec42d547eba5) PiperOrigin-RevId: 583374953 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 1cd80656fe497f..50ee3694c2231a 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "de176d8c5496d6cf20e82aface98e102c593dbe2" - LLVM_SHA256 = "83239b51d91f9b07d110f66ddea740f028efb61b1bdcf0d0cd0f53ec859a000d" + LLVM_COMMIT = "ec42d547eba5c0ad0bddbecc8902d35383968e78" + LLVM_SHA256 = "c7ec22eb1026b8d09afe2a70b2e2f5cf09a1805c4d16b004e72bba5b4153e2cf" tf_http_archive( name = name, From 145d46d913536a131496d7152e49299172fd6c82 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 07:41:49 -0800 Subject: [PATCH 225/391] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/886cfe0e0fd894ba1beafbb80585c6f32de8a2e4. PiperOrigin-RevId: 583383656 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 9754fd22233a75..b5ea5e6c8a2bd2 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "77cd0fb4225157e8326ddad1006d137ecced0aae" - TFRT_SHA256 = "c13c76ddb5a1f4646cfe7f5cb8f0a0b5789f5a2030086ac2796bf922c66eea6c" + TFRT_COMMIT = "886cfe0e0fd894ba1beafbb80585c6f32de8a2e4" + TFRT_SHA256 = "d391129b09b90a343f4b948f8fda109a260cdfb0e1ea63c978cafcdf528a85e3" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 9754fd22233a75..b5ea5e6c8a2bd2 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "77cd0fb4225157e8326ddad1006d137ecced0aae" - TFRT_SHA256 = "c13c76ddb5a1f4646cfe7f5cb8f0a0b5789f5a2030086ac2796bf922c66eea6c" + TFRT_COMMIT = "886cfe0e0fd894ba1beafbb80585c6f32de8a2e4" + TFRT_SHA256 = "d391129b09b90a343f4b948f8fda109a260cdfb0e1ea63c978cafcdf528a85e3" tf_http_archive( name = "tf_runtime", From 36b005ff53e4ed4626e17927e2327b8edddf0541 Mon Sep 17 00:00:00 2001 From: Tristen Allen Date: Fri, 17 Nov 2023 08:00:35 -0800 Subject: [PATCH 226/391] Add type annotations to stack.py and traceable_stack.py. Changes `DefaultStack` into a parameterized class which allows users to specify the type of object stored on that stack. Also adds generic parameters to `TraceableStack` and `TraceableObject` to allow the same type-checking specialization (if desired). Notably does not change the session stack to use this new generic type as that would introduce a cyclic dependency between `stack.py` and `client/session.py`. 
PiperOrigin-RevId: 583388633 --- tensorflow/python/framework/BUILD | 4 +- tensorflow/python/framework/stack.py | 20 ++++---- .../python/framework/traceable_stack.py | 46 +++++++++++++------ 3 files changed, 46 insertions(+), 24 deletions(-) diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index f7d2e01d74fbe4..0d4369274e02ee 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -1950,7 +1950,7 @@ pytype_strict_library( ]), ) -py_strict_library( +pytype_strict_library( name = "stack", srcs = ["stack.py"], visibility = visibility + ["//tensorflow:internal"], @@ -2037,7 +2037,7 @@ py_strict_library( ], ) -py_strict_library( +pytype_strict_library( name = "traceable_stack", srcs = ["traceable_stack.py"], srcs_version = "PY3", diff --git a/tensorflow/python/framework/stack.py b/tensorflow/python/framework/stack.py index 5a1e8fbd1311fd..a91fc99be530e9 100644 --- a/tensorflow/python/framework/stack.py +++ b/tensorflow/python/framework/stack.py @@ -14,39 +14,43 @@ # ============================================================================== """Classes used to handle thread-local stacks.""" +from collections.abc import Iterator import threading +from typing import Generic, Optional, TypeVar from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export +T = TypeVar("T") -class DefaultStack(threading.local): + +class DefaultStack(threading.local, Generic[T]): """A thread-local stack of objects for providing implicit defaults.""" def __init__(self): super().__init__() self._enforce_nesting = True - self.stack = [] + self.stack: list[T] = [] - def get_default(self): + def get_default(self) -> Optional[T]: return self.stack[-1] if self.stack else None - def reset(self): + def reset(self) -> None: self.stack = [] - def is_cleared(self): + def is_cleared(self) -> bool: return not self.stack @property - def enforce_nesting(self): + def enforce_nesting(self) -> bool: return self._enforce_nesting @enforce_nesting.setter - def enforce_nesting(self, value): + def enforce_nesting(self, value: bool): self._enforce_nesting = value @tf_contextlib.contextmanager - def get_controller(self, default): + def get_controller(self, default: T) -> Iterator[T]: """A context manager for manipulating a default stack.""" self.stack.append(default) try: diff --git a/tensorflow/python/framework/traceable_stack.py b/tensorflow/python/framework/traceable_stack.py index bce16048a24983..8a1fde77e6d506 100644 --- a/tensorflow/python/framework/traceable_stack.py +++ b/tensorflow/python/framework/traceable_stack.py @@ -14,21 +14,32 @@ # ============================================================================== """A simple stack that associates filename and line numbers with each object.""" +from collections.abc import Iterator import inspect +import types +from typing import cast, Generic, Optional, TypeVar -class TraceableObject(object): +T = TypeVar("T") + + +class TraceableObject(Generic[T]): """Wrap an object together with its the code definition location.""" # Return codes for the set_filename_and_line_from_caller() method. 
SUCCESS, HEURISTIC_USED, FAILURE = (0, 1, 2) - def __init__(self, obj, filename=None, lineno=None): + def __init__( + self, + obj: T, + filename: Optional[str] = None, + lineno: Optional[int] = None, + ): self.obj = obj self.filename = filename self.lineno = lineno - def set_filename_and_line_from_caller(self, offset=0): + def set_filename_and_line_from_caller(self, offset: int = 0) -> int: """Set filename and line using the caller's stack frame. If the requested stack information is not available, a heuristic may @@ -49,6 +60,9 @@ def set_filename_and_line_from_caller(self, offset=0): """ retcode = self.SUCCESS frame = inspect.currentframe() + if not frame: + return self.FAILURE + frame = cast(types.FrameType, frame) # Offset is defined in "Args" as relative to the caller. We are one frame # beyond the caller. for _ in range(offset + 1): @@ -57,9 +71,10 @@ def set_filename_and_line_from_caller(self, offset=0): # If the offset is too large then we use the largest offset possible. retcode = self.HEURISTIC_USED break + parent = cast(types.FrameType, parent) frame = parent self.filename = frame.f_code.co_filename - self.lineno = frame.f_lineno + self.lineno = cast(int, frame.f_lineno) return retcode def copy_metadata(self): @@ -67,19 +82,22 @@ def copy_metadata(self): return self.__class__(None, filename=self.filename, lineno=self.lineno) -class TraceableStack(object): +class TraceableStack(Generic[T]): """A stack of TraceableObjects.""" - def __init__(self, existing_stack=None): + def __init__( + self, existing_stack: Optional[list[TraceableObject[T]]] = None, + ): """Constructor. Args: existing_stack: [TraceableObject, ...] If provided, this object will set its new stack to a SHALLOW COPY of existing_stack. """ - self._stack = existing_stack[:] if existing_stack else [] + self._stack: list[TraceableObject[T]] = (existing_stack[:] if existing_stack + else []) - def push_obj(self, obj, offset=0): + def push_obj(self, obj: T, offset: int = 0): """Add object to the stack and record its filename and line information. Args: @@ -98,27 +116,27 @@ def push_obj(self, obj, offset=0): # beyond the caller and need to compensate. return traceable_obj.set_filename_and_line_from_caller(offset + 1) - def pop_obj(self): + def pop_obj(self) -> T: """Remove last-inserted object and return it, without filename/line info.""" return self._stack.pop().obj - def peek_top_obj(self): + def peek_top_obj(self) -> T: """Return the most recent stored object.""" return self._stack[-1].obj - def peek_objs(self): + def peek_objs(self) -> Iterator[T]: """Return iterator over stored objects ordered newest to oldest.""" return (t_obj.obj for t_obj in reversed(self._stack)) - def peek_traceable_objs(self): + def peek_traceable_objs(self) -> Iterator[TraceableObject[T]]: """Return iterator over stored TraceableObjects ordered newest to oldest.""" return reversed(self._stack) - def __len__(self): + def __len__(self) -> int: """Return number of items on the stack, and used for truth-value testing.""" return len(self._stack) - def copy(self): + def copy(self) -> "TraceableStack[T]": """Return a copy of self referencing the same objects but in a new list. This method is implemented to support thread-local stacks. From b949002793735a881ed7105999f84d29d486fa71 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Fri, 17 Nov 2023 08:02:13 -0800 Subject: [PATCH 227/391] Remove overly sensitive random number test. 
PiperOrigin-RevId: 583388965 --- tensorflow/compiler/tests/stateless_random_ops_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 7c48f5e3ec6518..01142082ae24f5 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -284,8 +284,9 @@ def testRandomNormalIsFinite(self): @parameterized.named_parameters( (f'_{dtype.name}_{seed}', dtype, seed) # pylint: disable=g-complex-comprehension - for seed in ([1, 2], [12, 23], [123, 456], [25252, 314159]) - for dtype in _allowed_types()) + for seed in ([1, 2], [12, 23], [25252, 314159]) + for dtype in _allowed_types() + ) def testDistributionOfStatelessRandomNormal(self, dtype, seed): """Use Anderson-Darling test to test distribution appears normal.""" with self.session() as sess, self.test_scope(): From fd77ec8c9261fa85ee4c8e4410e4a35e20327c4e Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Fri, 17 Nov 2023 08:08:48 -0800 Subject: [PATCH 228/391] [PJRT C API] Add non_donatable_input_indices to ExecuteOption. PiperOrigin-RevId: 583390832 --- third_party/xla/xla/pjrt/c/BUILD | 1 + third_party/xla/xla/pjrt/c/CHANGELOG.md | 4 ++++ third_party/xla/xla/pjrt/c/pjrt_c_api.h | 12 +++++++++++- .../xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 7 +++++++ third_party/xla/xla/pjrt/pjrt_c_api_client.cc | 18 +++++++++++++++--- third_party/xla/xla/pjrt/pjrt_c_api_client.h | 3 ++- .../xla/xla/pjrt/pjrt_c_api_client_test.cc | 4 +++- 7 files changed, 43 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index 05140ca12e0d83..e7e6f6a9a12098 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -82,6 +82,7 @@ cc_library( "//xla/service:hlo_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/pjrt/c/CHANGELOG.md b/third_party/xla/xla/pjrt/c/CHANGELOG.md index 95c49a9129a98b..fa60ae26859c31 100644 --- a/third_party/xla/xla/pjrt/c/CHANGELOG.md +++ b/third_party/xla/xla/pjrt/c/CHANGELOG.md @@ -1,5 +1,9 @@ # PJRT C API changelog +## 0.39 (Nov 16, 2023) +* Add non_donatable_input_indices and num_non_donatable_input_indices to +PJRT_ExecuteOptions. + ## 0.38 (Oct 30, 2023) * Use `enum` to define STRUCT_SIZE constants in a header file. diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api.h b/third_party/xla/xla/pjrt/c/pjrt_c_api.h index 122f9414ba6bdd..3cc1a26bd4460e 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api.h @@ -53,7 +53,7 @@ extern "C" { // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 38 +#define PJRT_API_MINOR 39 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -1264,6 +1264,16 @@ struct PJRT_ExecuteOptions { // multi-host programs are launched in different orders on different hosts, // the launch IDs may be used by the runtime to detect the mismatch. int launch_id; + // A list of indices denoting the input buffers that should not be donated. 
+ // An input buffer may be non-donable, for example, if it is referenced more + // than once. Since such runtime information is not available at compile time, + // the compiler might mark the input as `may-alias`, which could lead PjRt to + // donate the input buffer when it should not. By defining this list of + // indices, a higher-level PJRT caller can instruct PJRT client not to donate + // specific input buffers. The caller needs to make sure to keep it alive + // during the call. + const int64_t* non_donatable_input_indices; + size_t num_non_donatable_input_indices; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_ExecuteOptions, launch_id); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 7d8d1cb1b8a2e2..e782cf61331902 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -28,6 +28,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -1337,6 +1338,12 @@ PJRT_Error* PJRT_LoadedExecutable_Execute( options.context = nullptr; options.multi_slice_config = nullptr; options.use_major_to_minor_data_layout_for_callbacks = true; + if (args->options->num_non_donatable_input_indices > 0) { + for (int i = 0; i < args->options->num_non_donatable_input_indices; ++i) { + options.non_donatable_input_indices.insert( + args->options->non_donatable_input_indices[i]); + } + } std::vector> cpp_argument_lists = Convert2DCBuffersToCppBuffers(args->argument_lists, args->num_devices, diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc index 3a7006f33b94a7..c7504baf2e3331 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc @@ -1426,7 +1426,8 @@ PjRtCApiLoadedExecutable::GetCommonExecuteArgs( std::vector>& c_output_lists_storage, std::vector& c_output_lists, std::optional>& device_complete_events, - SendRecvCallbackData& callback_data) { + SendRecvCallbackData& callback_data, + std::vector& non_donatable_input_indices_storage) { bool using_host_callbacks = !options.send_callbacks.empty() || !options.recv_callbacks.empty(); if (using_host_callbacks && @@ -1443,6 +1444,13 @@ PjRtCApiLoadedExecutable::GetCommonExecuteArgs( args.options = &c_options; args.options->struct_size = PJRT_ExecuteOptions_STRUCT_SIZE; args.options->launch_id = options.launch_id; + for (auto i : options.non_donatable_input_indices) { + non_donatable_input_indices_storage.push_back(i); + } + args.options->num_non_donatable_input_indices = + options.non_donatable_input_indices.size(); + args.options->non_donatable_input_indices = + non_donatable_input_indices_storage.data(); args.num_devices = argument_handles.size(); CHECK_GT(args.num_devices, 0); args.num_args = argument_handles[0].size(); @@ -1516,6 +1524,7 @@ PjRtCApiLoadedExecutable::Execute( std::vector> c_argument_lists_storage; std::vector> c_output_lists_storage; std::vector c_output_lists; + std::vector non_donatable_input_indices_storage; PJRT_ExecuteOptions c_options; c_options.num_send_ops = 0; c_options.num_recv_ops = 0; @@ -1531,7 +1540,8 @@ PjRtCApiLoadedExecutable::Execute( GetCommonExecuteArgs(argument_handles, options, c_options, c_argument_lists_storage, c_arguments, c_output_lists_storage, c_output_lists, - device_complete_events, *callback_data)); + 
device_complete_events, *callback_data, + non_donatable_input_indices_storage)); args.execute_device = nullptr; @@ -1581,6 +1591,7 @@ PjRtCApiLoadedExecutable::ExecuteWithSingleDevice( std::vector> c_argument_lists_storage; std::vector> c_output_lists_storage; std::vector c_output_lists; + std::vector non_donatable_input_indices_storage; PJRT_ExecuteOptions c_options; c_options.num_send_ops = 0; c_options.num_recv_ops = 0; @@ -1596,7 +1607,8 @@ PjRtCApiLoadedExecutable::ExecuteWithSingleDevice( GetCommonExecuteArgs(argument_handles_vec, options, c_options, c_argument_lists_storage, c_arguments, c_output_lists_storage, c_output_lists, - device_complete_events, *callback_data)); + device_complete_events, *callback_data, + non_donatable_input_indices_storage)); args.execute_device = tensorflow::down_cast(device)->c_device(); diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.h b/third_party/xla/xla/pjrt/pjrt_c_api_client.h index 0a62950c32859c..d147dee6ee2b03 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.h +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.h @@ -728,7 +728,8 @@ class PjRtCApiLoadedExecutable : public PjRtLoadedExecutable { std::vector>& c_output_lists_storage, std::vector& c_output_lists, std::optional>& device_complete_events, - SendRecvCallbackData& send_recv_callback_data); + SendRecvCallbackData& send_recv_callback_data, + std::vector& non_donatable_input_indices_storage); StatusOr>> ExecuteWithSingleDevice( absl::Span argument_handles, PjRtDevice* device, diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc index dd82746d2ae4cb..9e9cb71d226fb2 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc @@ -82,8 +82,10 @@ TEST(PjRtCApiClientTest, IsDynamicDimension) { auto computation = builder.Build(reshaped).value(); std::unique_ptr executable = client->Compile(computation, CompileOptions()).value(); + ExecuteOptions execute_options; + execute_options.non_donatable_input_indices = {0}; std::vector>> results = - executable->Execute({{param0.get(), param1.get()}}, ExecuteOptions()) + executable->Execute({{param0.get(), param1.get()}}, execute_options) .value(); ASSERT_EQ(results[0].size(), 1); auto* result_buffer = results[0][0].get(); From ce26a5fd967a3c82cb29b68ba9c83e8cc66c97b2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 17 Nov 2023 09:14:29 -0800 Subject: [PATCH 229/391] [PJRT] NFC: Cleanups to PJRT CPU client. * Move CPU client into cpu/ subdirectory. This change is in preparation to adding more CPU-specific code. * Since we're renaming the files anyway, rename tfrt_cpu_pjrt_client to cpu_client. There is only one kind of CPU client these days, so we can shorten the name. 
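For downstream code the change is mechanical; the old target remains as a
transitional forwarder, so a hypothetical user of the CPU client would migrate
roughly as follows (illustrative only):

    // Before (still compiles via the forwarding header):
    #include "xla/pjrt/tfrt_cpu_pjrt_client.h"
    // After:
    #include "xla/pjrt/cpu/cpu_client.h"

with the matching Bazel dependency moving from
"//xla/pjrt:tfrt_cpu_pjrt_client" to "//xla/pjrt/cpu:cpu_client".
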
PiperOrigin-RevId: 583407197 --- third_party/xla/xla/pjrt/BUILD | 186 +----- third_party/xla/xla/pjrt/c/BUILD | 2 +- .../xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc | 2 +- third_party/xla/xla/pjrt/cpu/BUILD | 197 ++++++ .../{ => cpu}/abstract_tfrt_cpu_buffer.cc | 4 +- .../pjrt/{ => cpu}/abstract_tfrt_cpu_buffer.h | 8 +- .../cpu_client.cc} | 6 +- third_party/xla/xla/pjrt/cpu/cpu_client.h | 564 ++++++++++++++++++ .../cpu_client_test.cc} | 2 +- .../pjrt/{ => cpu}/pjrt_client_test_cpu.cc | 2 +- .../tracked_tfrt_cpu_device_buffer.cc | 2 +- .../tracked_tfrt_cpu_device_buffer.h | 6 +- .../tracked_tfrt_cpu_device_buffer_test.cc | 2 +- .../xla/xla/pjrt/tf_pjrt_client_test.cc | 2 +- .../xla/xla/pjrt/tfrt_cpu_pjrt_client.h | 545 +---------------- third_party/xla/xla/python/BUILD | 4 +- .../xla/xla/python/outfeed_receiver_test.cc | 2 +- third_party/xla/xla/python/pjrt_ifrt/BUILD | 2 +- .../pjrt_ifrt/tfrt_cpu_client_test_lib.cc | 2 +- third_party/xla/xla/python/xla.cc | 2 +- third_party/xla/xla/tests/BUILD | 2 +- .../xla/xla/tests/pjrt_cpu_client_registry.cc | 2 +- 22 files changed, 793 insertions(+), 753 deletions(-) create mode 100644 third_party/xla/xla/pjrt/cpu/BUILD rename third_party/xla/xla/pjrt/{ => cpu}/abstract_tfrt_cpu_buffer.cc (99%) rename third_party/xla/xla/pjrt/{ => cpu}/abstract_tfrt_cpu_buffer.h (98%) rename third_party/xla/xla/pjrt/{tfrt_cpu_pjrt_client.cc => cpu/cpu_client.cc} (99%) create mode 100644 third_party/xla/xla/pjrt/cpu/cpu_client.h rename third_party/xla/xla/pjrt/{tfrt_cpu_pjrt_client_test.cc => cpu/cpu_client_test.cc} (99%) rename third_party/xla/xla/pjrt/{ => cpu}/pjrt_client_test_cpu.cc (96%) rename third_party/xla/xla/pjrt/{ => cpu}/tracked_tfrt_cpu_device_buffer.cc (98%) rename third_party/xla/xla/pjrt/{ => cpu}/tracked_tfrt_cpu_device_buffer.h (97%) rename third_party/xla/xla/pjrt/{ => cpu}/tracked_tfrt_cpu_device_buffer_test.cc (99%) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index abbb5886b81cdd..59dc48817a1783 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -223,16 +223,6 @@ cc_library( alwayslink = 1, ) -xla_cc_test( - name = "pjrt_client_test_cpu", - srcs = ["pjrt_client_test_cpu.cc"], - deps = [ - ":pjrt_client_test_common", - ":tfrt_cpu_pjrt_client", - "@local_tsl//tsl/platform:test_main", - ], -) - cc_library( name = "pjrt_executable", srcs = ["pjrt_executable.cc"], @@ -571,183 +561,13 @@ cc_library( ], ) -cc_library( - name = "tracked_tfrt_cpu_device_buffer", - srcs = ["tracked_tfrt_cpu_device_buffer.cc"], - hdrs = ["tracked_tfrt_cpu_device_buffer.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla:cpu_function_runtime", - "//xla:shape_util", - "//xla:util", - "//xla/runtime:cpu_event", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/concurrency:async_value", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:platform_port", - ], -) - -xla_cc_test( - name = "tracked_tfrt_cpu_device_buffer_test", - srcs = ["tracked_tfrt_cpu_device_buffer_test.cc"], - deps = [ - ":tracked_tfrt_cpu_device_buffer", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/concurrency:async_value", - "@local_tsl//tsl/platform:env", - ], -) - -cc_library( - name = "abstract_tfrt_cpu_buffer", - srcs = ["abstract_tfrt_cpu_buffer.cc"], - hdrs = ["abstract_tfrt_cpu_buffer.h"], - visibility = ["//visibility:public"], - 
deps = [ - ":pjrt_client", - ":pjrt_future", - ":tracked_tfrt_cpu_device_buffer", - ":transpose", - ":utils", - "//xla:cpu_function_runtime", - "//xla:literal", - "//xla:shape_tree", - "//xla:shape_util", - "//xla:status", - "//xla:statusor", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/runtime:cpu_event", - "//xla/service:shaped_buffer", - "//xla/service/cpu:cpu_executable", - "//xla/service/cpu:cpu_xfeed", - "//xla/stream_executor:device_memory", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/concurrency:async_value", - "@local_tsl//tsl/concurrency:ref_count", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:connected_traceme", - "@local_tsl//tsl/profiler/lib:traceme", - ], -) - +# Transitional forwarding target. Use cpu:cpu_client instead. cc_library( name = "tfrt_cpu_pjrt_client", - srcs = ["tfrt_cpu_pjrt_client.cc"], hdrs = ["tfrt_cpu_pjrt_client.h"], visibility = ["//visibility:public"], deps = [ - ":abstract_tfrt_cpu_buffer", - ":compile_options_proto_cc", - ":mlir_to_hlo", - ":pjrt_client", - ":pjrt_executable", - ":pjrt_future", - ":semaphore", - ":tracked_tfrt_cpu_device_buffer", - ":transpose", - ":utils", - "//xla:array", - "//xla:debug_options_flags", - "//xla:executable_run_options", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status", - "//xla:statusor", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/client:executable_build_options", - "//xla/client:xla_computation", - "//xla/hlo/ir:hlo", - "//xla/pjrt/distributed:topology_util", - "//xla/runtime:cpu_event", - "//xla/service:buffer_assignment", - "//xla/service:compiler", - "//xla/service:computation_placer_hdr", - "//xla/service:custom_call_status_public_headers", - "//xla/service:dump", - "//xla/service:executable", - "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_module_config", - "//xla/service:hlo_module_util", - "//xla/service:hlo_proto_cc", - "//xla/service:hlo_value", - "//xla/service/cpu:buffer_desc", - "//xla/service/cpu:cpu_compiler", - "//xla/service/cpu:cpu_executable", - "//xla/service/cpu:cpu_xfeed", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:dynamic_annotations", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "@eigen_archive//:eigen3", # TODO(zhangqiaorjc): Remove if use TFRT threadpool. 
- "@llvm-project//mlir:IR", - "@local_tsl//tsl/concurrency:async_value", - "@local_tsl//tsl/concurrency:ref_count", - "@local_tsl//tsl/platform:casts", - "@local_tsl//tsl/platform:denormal", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:fingerprint", - "@local_tsl//tsl/platform:setround", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:connected_traceme", - "@local_tsl//tsl/profiler/lib:context_types_hdrs", - "@local_tsl//tsl/profiler/lib:traceme", - ], -) - -xla_cc_test( - name = "tfrt_cpu_pjrt_client_test", - srcs = ["tfrt_cpu_pjrt_client_test.cc"], - deps = [ - ":tfrt_cpu_pjrt_client", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status", - "//xla:util", - "//xla/service:custom_call_status_public_headers", - "//xla/service:custom_call_target_registry", - "//xla/service:hlo_parser", - "//xla/tests:test_utils", - "@com_google_absl//absl/synchronization", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", + "//xla/pjrt/cpu:cpu_client", ], ) @@ -925,8 +745,8 @@ xla_cc_test( srcs = ["tf_pjrt_client_test.cc"], deps = [ ":tf_pjrt_client", - ":tfrt_cpu_pjrt_client", "//xla:literal_util", + "//xla/pjrt/cpu:cpu_client", "//xla/service:hlo_parser", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:env", diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index e7e6f6a9a12098..d74d8d92fb5eb9 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -136,7 +136,7 @@ cc_library( ":pjrt_c_api_helpers", ":pjrt_c_api_wrapper_impl", "//xla/pjrt:pjrt_client", - "//xla/pjrt:tfrt_cpu_pjrt_client", + "//xla/pjrt/cpu:cpu_client", ], ) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc index 89bff439bff6f7..b60296cee55556 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" namespace pjrt { namespace cpu_plugin { diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD new file mode 100644 index 00000000000000..fd615f18914f78 --- /dev/null +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -0,0 +1,197 @@ +load("//xla:xla.bzl", "xla_cc_test") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +xla_cc_test( + name = "pjrt_client_test_cpu", + srcs = ["pjrt_client_test_cpu.cc"], + deps = [ + ":cpu_client", + "//xla/pjrt:pjrt_client_test_common", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "tracked_tfrt_cpu_device_buffer", + srcs = ["tracked_tfrt_cpu_device_buffer.cc"], + hdrs = ["tracked_tfrt_cpu_device_buffer.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla:cpu_function_runtime", + "//xla:shape_util", + "//xla:util", + "//xla/runtime:cpu_event", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/concurrency:async_value", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:platform_port", + ], +) + +xla_cc_test( + name = "tracked_tfrt_cpu_device_buffer_test", + srcs = ["tracked_tfrt_cpu_device_buffer_test.cc"], + deps = [ + ":tracked_tfrt_cpu_device_buffer", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/concurrency:async_value", + "@local_tsl//tsl/platform:env", + ], +) + +cc_library( + name = "abstract_tfrt_cpu_buffer", + srcs = ["abstract_tfrt_cpu_buffer.cc"], + hdrs = ["abstract_tfrt_cpu_buffer.h"], + visibility = ["//visibility:public"], + deps = [ + ":tracked_tfrt_cpu_device_buffer", + "//xla:cpu_function_runtime", + "//xla:literal", + "//xla:shape_tree", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_future", + "//xla/pjrt:transpose", + "//xla/pjrt:utils", + "//xla/runtime:cpu_event", + "//xla/service:shaped_buffer", + "//xla/service/cpu:cpu_executable", + "//xla/service/cpu:cpu_xfeed", + "//xla/stream_executor:device_memory", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/concurrency:async_value", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:connected_traceme", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + +cc_library( + name = "cpu_client", + srcs = ["cpu_client.cc"], + hdrs = ["cpu_client.h"], + visibility = ["//visibility:public"], + deps = [ + ":abstract_tfrt_cpu_buffer", + ":tracked_tfrt_cpu_device_buffer", + "//xla:array", + "//xla:debug_options_flags", + 
"//xla:executable_run_options", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/client:executable_build_options", + "//xla/client:xla_computation", + "//xla/hlo/ir:hlo", + "//xla/pjrt:compile_options_proto_cc", + "//xla/pjrt:mlir_to_hlo", + "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_executable", + "//xla/pjrt:pjrt_future", + "//xla/pjrt:semaphore", + "//xla/pjrt:transpose", + "//xla/pjrt:utils", + "//xla/pjrt/distributed:topology_util", + "//xla/runtime:cpu_event", + "//xla/service:buffer_assignment", + "//xla/service:compiler", + "//xla/service:computation_placer_hdr", + "//xla/service:custom_call_status_public_headers", + "//xla/service:dump", + "//xla/service:executable", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_module_config", + "//xla/service:hlo_module_util", + "//xla/service:hlo_proto_cc", + "//xla/service:hlo_value", + "//xla/service/cpu:buffer_desc", + "//xla/service/cpu:cpu_compiler", + "//xla/service/cpu:cpu_executable", + "//xla/service/cpu:cpu_xfeed", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", # TODO(zhangqiaorjc): Remove if use TFRT threadpool. 
+ "@llvm-project//mlir:IR", + "@local_tsl//tsl/concurrency:async_value", + "@local_tsl//tsl/concurrency:ref_count", + "@local_tsl//tsl/platform:casts", + "@local_tsl//tsl/platform:denormal", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:fingerprint", + "@local_tsl//tsl/platform:setround", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:connected_traceme", + "@local_tsl//tsl/profiler/lib:context_types_hdrs", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + +xla_cc_test( + name = "cpu_client_test", + srcs = ["cpu_client_test.cc"], + deps = [ + ":cpu_client", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status", + "//xla:util", + "//xla/service:custom_call_status_public_headers", + "//xla/service:custom_call_target_registry", + "//xla/service:hlo_parser", + "//xla/tests:test_utils", + "@com_google_absl//absl/synchronization", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) diff --git a/third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.cc b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc similarity index 99% rename from third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.cc rename to third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc index 92de62a2f39476..3bf5805e75888f 100644 --- a/third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.cc +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/pjrt/abstract_tfrt_cpu_buffer.h" +#include "xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h" #include #include @@ -40,9 +40,9 @@ limitations under the License. #include "absl/types/span.h" #include "xla/cpu_function_runtime.h" #include "xla/literal.h" +#include "xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_future.h" -#include "xla/pjrt/tracked_tfrt_cpu_device_buffer.h" #include "xla/pjrt/transpose.h" #include "xla/pjrt/utils.h" #include "xla/primitive_util.h" diff --git a/third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.h b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h similarity index 98% rename from third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.h rename to third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h index 79be926129f8e8..f1cec9f5acb17c 100644 --- a/third_party/xla/xla/pjrt/abstract_tfrt_cpu_buffer.h +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PJRT_ABSTRACT_TFRT_CPU_BUFFER_H_ -#define XLA_PJRT_ABSTRACT_TFRT_CPU_BUFFER_H_ +#ifndef XLA_PJRT_CPU_ABSTRACT_TFRT_CPU_BUFFER_H_ +#define XLA_PJRT_CPU_ABSTRACT_TFRT_CPU_BUFFER_H_ #include #include @@ -34,9 +34,9 @@ limitations under the License. 
#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/literal.h" +#include "xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_future.h" -#include "xla/pjrt/tracked_tfrt_cpu_device_buffer.h" #include "xla/pjrt/transpose.h" #include "xla/runtime/cpu_event.h" #include "xla/shape.h" @@ -421,4 +421,4 @@ class AbstractAsyncHostToHostMemoryTransferManager } // namespace xla -#endif // XLA_PJRT_ABSTRACT_TFRT_CPU_BUFFER_H_ +#endif // XLA_PJRT_CPU_ABSTRACT_TFRT_CPU_BUFFER_H_ diff --git a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc similarity index 99% rename from third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc rename to third_party/xla/xla/pjrt/cpu/cpu_client.cc index ad4b5ea45711ec..a869f9fb8bbd3b 100644 --- a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "xla/pjrt/cpu/cpu_client.h" #include #include @@ -56,15 +56,15 @@ limitations under the License. #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/pjrt/abstract_tfrt_cpu_buffer.h" #include "xla/pjrt/compile_options.pb.h" +#include "xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h" +#include "xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h" #include "xla/pjrt/distributed/topology_util.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/semaphore.h" -#include "xla/pjrt/tracked_tfrt_cpu_device_buffer.h" #include "xla/pjrt/transpose.h" #include "xla/pjrt/utils.h" #include "xla/runtime/cpu_event.h" diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.h b/third_party/xla/xla/pjrt/cpu/cpu_client.h new file mode 100644 index 00000000000000..a350ce7d843b94 --- /dev/null +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.h @@ -0,0 +1,564 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_PJRT_CPU_CPU_CLIENT_H_ +#define XLA_PJRT_CPU_CPU_CLIENT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "xla/client/xla_computation.h" +#include "xla/executable_run_options.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/literal.h" +#include "xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h" +#include "xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/semaphore.h" +#include "xla/pjrt/transpose.h" +#include "xla/runtime/cpu_event.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/computation_placer.h" +#include "xla/service/executable.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/shape.h" +#include "xla/status.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/concurrency/async_value_ref.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/fingerprint.h" +#include "tsl/platform/threadpool.h" + +namespace xla { + +class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { + public: + TfrtCpuDeviceDescription(int id, int process_index, int local_hardware_id); + + int id() const override { return id_; } + + int process_index() const override { return process_index_; } + + int local_hardware_id() const { return local_hardware_id_; } + + absl::string_view device_kind() const override; + + absl::string_view DebugString() const override; + + absl::string_view ToString() const override; + + const absl::flat_hash_map& Attributes() + const override { + return attributes_; + } + + private: + int id_; + int process_index_; + int local_hardware_id_; + std::string debug_string_; + std::string to_string_; + absl::flat_hash_map attributes_ = {}; +}; + +class TfrtCpuDevice final : public PjRtDevice { + public: + explicit TfrtCpuDevice(int id, int process_index, int local_hardware_id, + int max_inflight_computations = 32); + + const TfrtCpuDeviceDescription& description() const override { + return description_; + } + + void SetClient(PjRtClient* client) { + CHECK(client_ == nullptr); + client_ = client; + } + + PjRtClient* client() const override { return client_; } + + bool IsAddressable() const override { + return process_index() == client()->process_index(); + } + + int local_hardware_id() const override { + return description_.local_hardware_id(); + } + + Status TransferToInfeed(const LiteralSlice& literal) override; + + Status TransferFromOutfeed(MutableBorrowingLiteral literal) override; + + absl::Span memory_spaces() const override; + + StatusOr default_memory_space() const override; + + // Returns a semaphore for admission control on inflight computations. 
+ Semaphore& max_inflight_computations_semaphore() { + return max_inflight_computations_semaphore_; + } + + std::unique_ptr CreateAsyncTrackingEvent( + absl::string_view description) const override { + return nullptr; + } + + private: + PjRtClient* client_ = nullptr; + TfrtCpuDeviceDescription description_; + + // TODO(zhangqiaorjc): Optimize semaphore related overhead. + // Semaphore used to limit how many programs can be enqueued by the host + // ahead of the device. + Semaphore max_inflight_computations_semaphore_; +}; + +class TfrtCpuClient final : public PjRtClient { + public: + TfrtCpuClient(int process_index, + std::vector> devices, + size_t num_threads); + ~TfrtCpuClient() override; + + int process_index() const override { return process_index_; } + + int device_count() const override { return devices_.size(); } + + int addressable_device_count() const override { + return addressable_devices_.size(); + } + + absl::Span devices() const override { return devices_; } + + absl::Span addressable_devices() const override { + return addressable_devices_; + } + + StatusOr LookupDevice(int device_id) const override; + + StatusOr LookupAddressableDevice( + int local_hardware_id) const override; + + absl::Span memory_spaces() const override; + + PjRtPlatformId platform_id() const override { + return tsl::Fingerprint64(CpuName()); + } + + absl::string_view platform_name() const override { return CpuName(); } + + absl::string_view platform_version() const override { return ""; } + + PjRtRuntimeType runtime_type() const override { return kTfrt; } + + StatusOr GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const override; + + StatusOr> GetHloCostAnalysis() + const override; + + StatusOr> Compile( + const XlaComputation& computation, CompileOptions options) override; + StatusOr> Compile( + mlir::ModuleOp module, CompileOptions options) override; + + // For TfrtCpuClient, `options` is mandatory. + // This function returns an InvalidArgument error if `std::nullopt` is passed. 
+ // TODO(b/237720161): make it actually optional + StatusOr> DeserializeExecutable( + absl::string_view serialized, + std::optional options) override; + + StatusOr> CreateErrorBuffer( + Status error, const Shape& shape, PjRtDevice* device) override; + + StatusOr> CreateUninitializedBuffer( + const Shape& shape, PjRtDevice* device) override; + + StatusOr> + CreateBuffersForAsyncHostToDevice(absl::Span shapes, + PjRtDevice* device) override; + + absl::StatusOr> + CreateBuffersForAsyncHostToDevice(absl::Span shapes, + PjRtMemorySpace* memory_space) override { + return Unimplemented( + "CreateBuffersForAsyncHostToDevice with memory_space not implemented."); + } + + StatusOr> BufferFromHostBuffer( + const void* data, PrimitiveType type, absl::Span dims, + std::optional> byte_strides, + HostBufferSemantics host_buffer_semantics, + std::function on_done_with_host_buffer, + PjRtDevice* device) override; + + StatusOr> BufferFromHostLiteral( + const LiteralSlice& literal, PjRtDevice* device) override; + + StatusOr>> + MakeCrossHostReceiveBuffers(absl::Span shapes, + PjRtDevice* device, + PjRtCrossHostRecvNotifier notifier) override { + return Unimplemented("MakeCrossHostReceiveBuffers not implemented."); + } + + StatusOr>> + MakeCrossHostReceiveBuffersForGather( + absl::Span shapes, std::vector gather_details, + PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) override { + return Unimplemented( + "MakeCrossHostReceiveBuffersForGather not implemented."); + } + + StatusOr> CreateViewOfDeviceBuffer( + void* device_ptr, const Shape& shape, PjRtDevice* device, + std::function on_delete_callback, + std::optional stream) override; + + StatusOr CreateChannelHandle() override { + return Unimplemented("CreateChannelHandle not implemented."); + } + StatusOr CreateDeviceToHostChannelHandle() override { + return Unimplemented("CreateDeviceToHostChannelHandle not implemented."); + } + StatusOr CreateHostToDeviceChannelHandle() override { + return Unimplemented("CreateHostToDeviceChannelHandle not implemented."); + } + + Status Defragment() override { + return Unimplemented("Defragment not implemented."); + } + + tsl::thread::ThreadPool* pjrt_client_thread_pool() const { + return pjrt_client_thread_pool_.get(); + } + + AsyncWorkRunner* async_work_runner() const { + return async_work_runner_.get(); + } + + Eigen::ThreadPoolDevice* eigen_intraop_device() const { + return eigen_intraop_device_.get(); + } + + tsl::AsyncValueRef GetLastCollectiveLaunchEvent() { + absl::MutexLock lock(&mu_); + return last_collective_launch_event_.CopyRef(); + } + + void SetLastCollectiveLaunchEvent( + tsl::AsyncValueRef event) { + absl::MutexLock lock(&mu_); + last_collective_launch_event_ = std::move(event); + } + + private: + int process_index_; + // Includes all devices, including non-addressable devices. + std::vector> owned_devices_; + // Pointers to `owned_devices_`. + std::vector devices_; + // Maps Device::id() to the corresponding Device. Includes all devices. + absl::flat_hash_map id_to_device_; + // Addressable devices indexed by core_id. + std::vector addressable_devices_; + std::unique_ptr computation_placer_; + + // Thread pool for running PjRtClient tasks. + std::unique_ptr pjrt_client_thread_pool_; + std::unique_ptr async_work_runner_; + + // TODO(zhangqiaorjc): Use tsl::compat::EigenHostContextThreadPool. 
+ std::unique_ptr eigen_intraop_pool_; + std::unique_ptr eigen_intraop_device_; + + // Launching collectives are prone to deadlock when we use fixed-sized + // threadpools since ExecuteHelper will block until all replicas reach the + // barrier. We ensure that + // 1. Threadpool size is at least as large as device_count so one collective + // launch over all devices can succeed. + // 2. Gang-schedule each collective by conservatively ensuring a total order + // of collectives and launching only one collective at a time to avoid + // having no active threads to make progress + // TODO(zhangqiaorjc): Explore alternatives that allow multiple concurrent + // collectives. + mutable absl::Mutex mu_; + tsl::AsyncValueRef last_collective_launch_event_ + ABSL_GUARDED_BY(mu_); + + // A cache for transpose plans. We use transposes to convert + // (possibly strided) buffers provided to BufferFromHostBuffer into dense + // major-to-minor layout. + absl::Mutex transpose_mu_; + TransposePlanCache transpose_cache_ ABSL_GUARDED_BY(transpose_mu_); +}; + +class TfrtCpuBuffer final : public AbstractTfrtCpuBuffer { + public: + TfrtCpuBuffer( + Shape on_device_shape, + std::unique_ptr tracked_device_buffer, + TfrtCpuClient* client, TfrtCpuDevice* device); + + TfrtCpuBuffer(const TfrtCpuBuffer&) = delete; + TfrtCpuBuffer(TfrtCpuBuffer&&) = delete; + TfrtCpuBuffer& operator=(const TfrtCpuBuffer&) = delete; + TfrtCpuBuffer& operator=(TfrtCpuBuffer&&) = delete; + + PjRtMemorySpace* memory_space() const override { return nullptr; } + TfrtCpuDevice* device() const override { return device_; } + TfrtCpuClient* client() const override { return client_; } + + using PjRtBuffer::ToLiteralSync; + PjRtFuture ToLiteral(MutableLiteralBase* literal) override; + + StatusOr> CopyToDevice( + PjRtDevice* dst_device) override; + + private: + absl::string_view buffer_name() const override { return "TfrtCpuBuffer"; } + + TfrtCpuClient* client_; + TfrtCpuDevice* const device_; +}; + +class TfrtCpuExecutable final : public PjRtLoadedExecutable { + public: + TfrtCpuExecutable( + int num_replicas, int num_partitions, + std::shared_ptr device_assignment, + bool parameter_is_tupled_arguments, CompileOptions compile_options, + std::unique_ptr cpu_executable, + BufferAllocation::Index result_buffer_index, + absl::InlinedVector result_buffer_indices, + std::vector addressable_device_logical_ids, + std::vector addressable_devices, TfrtCpuClient* client); + + ~TfrtCpuExecutable() override = default; + + TfrtCpuClient* client() const override { return client_; } + + absl::string_view name() const override { + return cpu_executable_->shared_module()->name(); + } + + int num_replicas() const override { return num_replicas_; } + + int num_partitions() const override { return num_partitions_; } + + int64_t SizeOfGeneratedCodeInBytes() const override { + return cpu_executable_->SizeOfGeneratedCodeInBytes(); + } + + const DeviceAssignment& device_assignment() const override { + return *device_assignment_; + } + + absl::Span addressable_device_logical_ids() + const override { + return addressable_device_logical_ids_; + } + + absl::Span addressable_devices() const override { + return addressable_devices_; + } + + StatusOr>> GetHloModules() + const override { + return std::vector>{ + cpu_executable_->shared_module()}; + } + + StatusOr>> GetOutputMemoryKinds() + const override { + return Unimplemented("GetOutputMemoryKinds is not supported."); + } + + StatusOr GetCompiledMemoryStats() const override { + CompiledMemoryStats memory_stats = 
CompiledMemoryStats(); + memory_stats.generated_code_size_in_bytes = SizeOfGeneratedCodeInBytes(); + const HloProto* proto = cpu_executable_->hlo_proto(); + if (!proto) { + return tsl::errors::FailedPrecondition( + "cpu_executable_ has no hlo_proto."); + } + memory_stats.serialized_hlo_proto = proto->SerializeAsString(); + return memory_stats; + } + + using PjRtLoadedExecutable::Execute; + StatusOr>>> Execute( + absl::Span> argument_handles, + const ExecuteOptions& options, + std::optional>>& returned_futures) + override; + + using PjRtLoadedExecutable::ExecuteSharded; + StatusOr>> ExecuteSharded( + absl::Span argument_handles, PjRtDevice* device, + const ExecuteOptions& options, + std::optional>& returned_future, + bool fill_future) override; + + using PjRtLoadedExecutable::ExecutePortable; + StatusOr>> ExecutePortable( + absl::Span argument_handles, PjRtDevice* device, + const ExecuteOptions& options, + std::optional>& returned_future, + bool fill_future) override; + + void Delete() override; + + bool IsDeleted() override; + + StatusOr SerializeExecutable() const override; + + bool IsReturnedFutureSupported() const override { return true; } + + StatusOr> Fingerprint() const; + + std::shared_ptr cpu_executable() const { return cpu_executable_; } + + StatusOr FingerprintExecutable() const override { + return Unimplemented("Fingerprinting executable is not supported."); + } + + private: + friend class TfrtCpuClient; + + Status SetUpDonation(bool tuple_inputs); + + // Checks that the input buffers passed in by the user have the correct size + // on device for the compiled program. + Status CheckBufferCompatibilities( + absl::Span const> + input_buffers) const; + + StatusOr ExecuteHelper( + absl::Span argument_handles, int replica, + int partition, const RunId& run_id, const ExecuteOptions& options, + tsl::AsyncValueRef last_collective_launch_event, + bool fill_future, TfrtCpuDevice* device = nullptr); + + TfrtCpuClient* client_; + + int num_replicas_; + int num_partitions_; + std::shared_ptr device_assignment_; + bool parameter_is_tupled_arguments_; + CompileOptions compile_options_; + + std::shared_ptr cpu_executable_; + + // Caching `result_buffer_index_` and `result_buffer_indices_` to avoid lookup + // HLO dataflow analysis data structures in program execution critical path. + + // Buffer allocation index corresponding to root buffer buffer. + BufferAllocation::Index result_buffer_index_; + // Buffer allocation indices corresponding to each result buffer leaf buffer. + absl::InlinedVector result_buffer_indices_; + + // Size on device of each leaf buffer of the compiled program, cached here + // for performance reasons. + std::vector input_buffer_sizes_in_bytes_; + + // A sorted vector of parameters that have any aliased buffers and thus must + // be donated when executing the computation. + std::vector parameters_that_must_be_donated_; + + // The replica and partition indices of device_assignment_ to be run by this + // client. On single-host platforms without partitioning, this is all + // replicas (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may + // not be the case on multi-host platforms. If there are 4 replicas and 2 + // partitions on a single host platform, size of + // addressable_device_logical_ids_ is 4*2 = 8. + std::vector addressable_device_logical_ids_; + + // addressable_devices_[i] is the Device to which + // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of + // unique_ptrs to play well with the Python bindings (see xla.cc). 
+ std::vector addressable_devices_; + + // Cached result of comparing HloCostAnalysis FLOP estimate for execute + // critical path. + bool cheap_computation_; +}; + +struct CpuClientOptions { + // Does nothing at the moment. Ignored. + bool asynchronous = true; + + // Number of CPU devices. If not provided, the value of + // --xla_force_host_platform_device_count is used. + std::optional cpu_device_count = std::nullopt; + + int max_inflight_computations_per_device = 32; + + // Number of distributed nodes. node_id, kv_get, and kv_put are ignored if + // this is set to 1. + int num_nodes = 1; + + // My node ID. + int node_id = 0; + + // KV store primitives for sharing topology information. + PjRtClient::KeyValueGetCallback kv_get = nullptr; + PjRtClient::KeyValuePutCallback kv_put = nullptr; +}; +StatusOr> GetTfrtCpuClient( + const CpuClientOptions& options); + +// Deprecated. Use the overload that takes 'options' instead. +inline StatusOr> GetTfrtCpuClient( + bool asynchronous) { + CpuClientOptions options; + options.asynchronous = asynchronous; + return GetTfrtCpuClient(options); +} + +// Deprecated. Use the overload that takes 'options' instead. +inline StatusOr> GetTfrtCpuClient( + bool asynchronous, int cpu_device_count, + int max_inflight_computations_per_device = 32) { + CpuClientOptions options; + options.asynchronous = asynchronous; + options.cpu_device_count = cpu_device_count; + options.max_inflight_computations_per_device = + max_inflight_computations_per_device; + return GetTfrtCpuClient(options); +} + +} // namespace xla + +#endif // XLA_PJRT_CPU_CPU_CLIENT_H_ diff --git a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc similarity index 99% rename from third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client_test.cc rename to third_party/xla/xla/pjrt/cpu/cpu_client_test.cc index 7503f5370e31db..f23aab9f808eeb 100644 --- a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "xla/pjrt/cpu/cpu_client.h" #include diff --git a/third_party/xla/xla/pjrt/pjrt_client_test_cpu.cc b/third_party/xla/xla/pjrt/cpu/pjrt_client_test_cpu.cc similarity index 96% rename from third_party/xla/xla/pjrt/pjrt_client_test_cpu.cc rename to third_party/xla/xla/pjrt/cpu/pjrt_client_test_cpu.cc index 59ef9ff1514472..ccc2ac8cc2575b 100644 --- a/third_party/xla/xla/pjrt/pjrt_client_test_cpu.cc +++ b/third_party/xla/xla/pjrt/cpu/pjrt_client_test_cpu.cc @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "xla/pjrt/pjrt_client_test.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "xla/pjrt/cpu/cpu_client.h" namespace xla { namespace { diff --git a/third_party/xla/xla/pjrt/tracked_tfrt_cpu_device_buffer.cc b/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.cc similarity index 98% rename from third_party/xla/xla/pjrt/tracked_tfrt_cpu_device_buffer.cc rename to third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.cc index 8895aa927f44dd..5d327e57cb4018 100644 --- a/third_party/xla/xla/pjrt/tracked_tfrt_cpu_device_buffer.cc +++ b/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/pjrt/tracked_tfrt_cpu_device_buffer.h" +#include "xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h" #include #include diff --git a/third_party/xla/xla/pjrt/tracked_tfrt_cpu_device_buffer.h b/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h similarity index 97% rename from third_party/xla/xla/pjrt/tracked_tfrt_cpu_device_buffer.h rename to third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h index de491729ecffe3..2d4b7589cbb4b4 100644 --- a/third_party/xla/xla/pjrt/tracked_tfrt_cpu_device_buffer.h +++ b/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_ -#define XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_ +#ifndef XLA_PJRT_CPU_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_ +#define XLA_PJRT_CPU_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_ #include #include @@ -148,4 +148,4 @@ class TrackedTfrtCpuDeviceBuffer { }; } // namespace xla -#endif // XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_ +#endif // XLA_PJRT_CPU_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_ diff --git a/third_party/xla/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc b/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer_test.cc similarity index 99% rename from third_party/xla/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc rename to third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer_test.cc index d064e96c629fc0..4ca8b79a2fd884 100644 --- a/third_party/xla/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc +++ b/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/pjrt/tracked_tfrt_cpu_device_buffer.h" +#include "xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h" #include #include diff --git a/third_party/xla/xla/pjrt/tf_pjrt_client_test.cc b/third_party/xla/xla/pjrt/tf_pjrt_client_test.cc index 49e1f219028706..9e3b785fc01853 100644 --- a/third_party/xla/xla/pjrt/tf_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/tf_pjrt_client_test.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include #include "xla/literal_util.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/service/hlo_parser.h" #include "tsl/platform/env.h" #include "tsl/platform/file_system.h" diff --git a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h index e9543ab92e93e7..7fa97e13118f0d 100644 --- a/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/tfrt_cpu_pjrt_client.h @@ -16,549 +16,8 @@ limitations under the License. #ifndef XLA_PJRT_TFRT_CPU_PJRT_CLIENT_H_ #define XLA_PJRT_TFRT_CPU_PJRT_CLIENT_H_ -#include -#include -#include -#include -#include -#include -#include -#include +// Transitional forwarding header. Please include cpu/cpu_client.h directly. -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/inlined_vector.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "xla/client/xla_computation.h" -#include "xla/executable_run_options.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/literal.h" -#include "xla/pjrt/abstract_tfrt_cpu_buffer.h" -#include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/pjrt_executable.h" -#include "xla/pjrt/pjrt_future.h" -#include "xla/pjrt/semaphore.h" -#include "xla/pjrt/tracked_tfrt_cpu_device_buffer.h" -#include "xla/pjrt/transpose.h" -#include "xla/runtime/cpu_event.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/computation_placer.h" -#include "xla/service/executable.h" -#include "xla/service/hlo.pb.h" -#include "xla/service/hlo_cost_analysis.h" -#include "xla/shape.h" -#include "xla/status.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/concurrency/async_value_ref.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/fingerprint.h" -#include "tsl/platform/threadpool.h" - -namespace xla { - -class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { - public: - TfrtCpuDeviceDescription(int id, int process_index, int local_hardware_id); - - int id() const override { return id_; } - - int process_index() const override { return process_index_; } - - int local_hardware_id() const { return local_hardware_id_; } - - absl::string_view device_kind() const override; - - absl::string_view DebugString() const override; - - absl::string_view ToString() const override; - - const absl::flat_hash_map& Attributes() - const override { - return attributes_; - } - - private: - int id_; - int process_index_; - int local_hardware_id_; - std::string debug_string_; - std::string to_string_; - absl::flat_hash_map attributes_ = {}; -}; - -class TfrtCpuDevice final : public PjRtDevice { - public: - explicit TfrtCpuDevice(int id, int process_index, int local_hardware_id, - int max_inflight_computations = 32); - - const TfrtCpuDeviceDescription& description() const override { - return description_; - } - - void SetClient(PjRtClient* client) { - CHECK(client_ == nullptr); - client_ = client; - } - - PjRtClient* client() const override { return client_; } - - bool IsAddressable() const override { - return process_index() == client()->process_index(); - } - - int local_hardware_id() const override { - return description_.local_hardware_id(); - } - - Status TransferToInfeed(const LiteralSlice& literal) override; - - Status 
TransferFromOutfeed(MutableBorrowingLiteral literal) override; - - absl::Span memory_spaces() const override; - - StatusOr default_memory_space() const override; - - // Returns a semaphore for admission control on inflight computations. - Semaphore& max_inflight_computations_semaphore() { - return max_inflight_computations_semaphore_; - } - - std::unique_ptr CreateAsyncTrackingEvent( - absl::string_view description) const override { - return nullptr; - } - - private: - PjRtClient* client_ = nullptr; - TfrtCpuDeviceDescription description_; - - // TODO(zhangqiaorjc): Optimize semaphore related overhead. - // Semaphore used to limit how many programs can be enqueued by the host - // ahead of the device. - Semaphore max_inflight_computations_semaphore_; -}; - -class TfrtCpuClient final : public PjRtClient { - public: - TfrtCpuClient(int process_index, - std::vector> devices, - size_t num_threads); - ~TfrtCpuClient() override; - - int process_index() const override { return process_index_; } - - int device_count() const override { return devices_.size(); } - - int addressable_device_count() const override { - return addressable_devices_.size(); - } - - absl::Span devices() const override { return devices_; } - - absl::Span addressable_devices() const override { - return addressable_devices_; - } - - StatusOr LookupDevice(int device_id) const override; - - StatusOr LookupAddressableDevice( - int local_hardware_id) const override; - - absl::Span memory_spaces() const override; - - PjRtPlatformId platform_id() const override { - return tsl::Fingerprint64(CpuName()); - } - - absl::string_view platform_name() const override { return CpuName(); } - - absl::string_view platform_version() const override { return ""; } - - PjRtRuntimeType runtime_type() const override { return kTfrt; } - - StatusOr GetDefaultDeviceAssignment( - int num_replicas, int num_partitions) const override; - - StatusOr> GetHloCostAnalysis() - const override; - - StatusOr> Compile( - const XlaComputation& computation, CompileOptions options) override; - StatusOr> Compile( - mlir::ModuleOp module, CompileOptions options) override; - - // For TfrtCpuClient, `options` is mandatory. - // This function returns an InvalidArgument error if `std::nullopt` is passed. 
- // TODO(b/237720161): make it actually optional - StatusOr> DeserializeExecutable( - absl::string_view serialized, - std::optional options) override; - - StatusOr> CreateErrorBuffer( - Status error, const Shape& shape, PjRtDevice* device) override; - - StatusOr> CreateUninitializedBuffer( - const Shape& shape, PjRtDevice* device) override; - - StatusOr> - CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtDevice* device) override; - - absl::StatusOr> - CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtMemorySpace* memory_space) override { - return Unimplemented( - "CreateBuffersForAsyncHostToDevice with memory_space not implemented."); - } - - StatusOr> BufferFromHostBuffer( - const void* data, PrimitiveType type, absl::Span dims, - std::optional> byte_strides, - HostBufferSemantics host_buffer_semantics, - std::function on_done_with_host_buffer, - PjRtDevice* device) override; - - StatusOr> BufferFromHostLiteral( - const LiteralSlice& literal, PjRtDevice* device) override; - - StatusOr>> - MakeCrossHostReceiveBuffers(absl::Span shapes, - PjRtDevice* device, - PjRtCrossHostRecvNotifier notifier) override { - return Unimplemented("MakeCrossHostReceiveBuffers not implemented."); - } - - StatusOr>> - MakeCrossHostReceiveBuffersForGather( - absl::Span shapes, std::vector gather_details, - PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) override { - return Unimplemented( - "MakeCrossHostReceiveBuffersForGather not implemented."); - } - - StatusOr> CreateViewOfDeviceBuffer( - void* device_ptr, const Shape& shape, PjRtDevice* device, - std::function on_delete_callback, - std::optional stream) override; - - StatusOr CreateChannelHandle() override { - return Unimplemented("CreateChannelHandle not implemented."); - } - StatusOr CreateDeviceToHostChannelHandle() override { - return Unimplemented("CreateDeviceToHostChannelHandle not implemented."); - } - StatusOr CreateHostToDeviceChannelHandle() override { - return Unimplemented("CreateHostToDeviceChannelHandle not implemented."); - } - - Status Defragment() override { - return Unimplemented("Defragment not implemented."); - } - - tsl::thread::ThreadPool* pjrt_client_thread_pool() const { - return pjrt_client_thread_pool_.get(); - } - - AsyncWorkRunner* async_work_runner() const { - return async_work_runner_.get(); - } - - Eigen::ThreadPoolDevice* eigen_intraop_device() const { - return eigen_intraop_device_.get(); - } - - tsl::AsyncValueRef GetLastCollectiveLaunchEvent() { - absl::MutexLock lock(&mu_); - return last_collective_launch_event_.CopyRef(); - } - - void SetLastCollectiveLaunchEvent( - tsl::AsyncValueRef event) { - absl::MutexLock lock(&mu_); - last_collective_launch_event_ = std::move(event); - } - - private: - int process_index_; - // Includes all devices, including non-addressable devices. - std::vector> owned_devices_; - // Pointers to `owned_devices_`. - std::vector devices_; - // Maps Device::id() to the corresponding Device. Includes all devices. - absl::flat_hash_map id_to_device_; - // Addressable devices indexed by core_id. - std::vector addressable_devices_; - std::unique_ptr computation_placer_; - - // Thread pool for running PjRtClient tasks. - std::unique_ptr pjrt_client_thread_pool_; - std::unique_ptr async_work_runner_; - - // TODO(zhangqiaorjc): Use tsl::compat::EigenHostContextThreadPool. 
- std::unique_ptr eigen_intraop_pool_; - std::unique_ptr eigen_intraop_device_; - - // Launching collectives are prone to deadlock when we use fixed-sized - // threadpools since ExecuteHelper will block until all replicas reach the - // barrier. We ensure that - // 1. Threadpool size is at least as large as device_count so one collective - // launch over all devices can succeed. - // 2. Gang-schedule each collective by conservatively ensuring a total order - // of collectives and launching only one collective at a time to avoid - // having no active threads to make progress - // TODO(zhangqiaorjc): Explore alternatives that allow multiple concurrent - // collectives. - mutable absl::Mutex mu_; - tsl::AsyncValueRef last_collective_launch_event_ - ABSL_GUARDED_BY(mu_); - - // A cache for transpose plans. We use transposes to convert - // (possibly strided) buffers provided to BufferFromHostBuffer into dense - // major-to-minor layout. - absl::Mutex transpose_mu_; - TransposePlanCache transpose_cache_ ABSL_GUARDED_BY(transpose_mu_); -}; - -class TfrtCpuBuffer final : public AbstractTfrtCpuBuffer { - public: - TfrtCpuBuffer( - Shape on_device_shape, - std::unique_ptr tracked_device_buffer, - TfrtCpuClient* client, TfrtCpuDevice* device); - - TfrtCpuBuffer(const TfrtCpuBuffer&) = delete; - TfrtCpuBuffer(TfrtCpuBuffer&&) = delete; - TfrtCpuBuffer& operator=(const TfrtCpuBuffer&) = delete; - TfrtCpuBuffer& operator=(TfrtCpuBuffer&&) = delete; - - PjRtMemorySpace* memory_space() const override { return nullptr; } - TfrtCpuDevice* device() const override { return device_; } - TfrtCpuClient* client() const override { return client_; } - - using PjRtBuffer::ToLiteralSync; - PjRtFuture ToLiteral(MutableLiteralBase* literal) override; - - StatusOr> CopyToDevice( - PjRtDevice* dst_device) override; - - private: - absl::string_view buffer_name() const override { return "TfrtCpuBuffer"; } - - TfrtCpuClient* client_; - TfrtCpuDevice* const device_; -}; - -class TfrtCpuExecutable final : public PjRtLoadedExecutable { - public: - TfrtCpuExecutable( - int num_replicas, int num_partitions, - std::shared_ptr device_assignment, - bool parameter_is_tupled_arguments, CompileOptions compile_options, - std::unique_ptr cpu_executable, - BufferAllocation::Index result_buffer_index, - absl::InlinedVector result_buffer_indices, - std::vector addressable_device_logical_ids, - std::vector addressable_devices, TfrtCpuClient* client); - - ~TfrtCpuExecutable() override = default; - - TfrtCpuClient* client() const override { return client_; } - - absl::string_view name() const override { - return cpu_executable_->shared_module()->name(); - } - - int num_replicas() const override { return num_replicas_; } - - int num_partitions() const override { return num_partitions_; } - - int64_t SizeOfGeneratedCodeInBytes() const override { - return cpu_executable_->SizeOfGeneratedCodeInBytes(); - } - - const DeviceAssignment& device_assignment() const override { - return *device_assignment_; - } - - absl::Span addressable_device_logical_ids() - const override { - return addressable_device_logical_ids_; - } - - absl::Span addressable_devices() const override { - return addressable_devices_; - } - - StatusOr>> GetHloModules() - const override { - return std::vector>{ - cpu_executable_->shared_module()}; - } - - StatusOr>> GetOutputMemoryKinds() - const override { - return Unimplemented("GetOutputMemoryKinds is not supported."); - } - - StatusOr GetCompiledMemoryStats() const override { - CompiledMemoryStats memory_stats = 
CompiledMemoryStats(); - memory_stats.generated_code_size_in_bytes = SizeOfGeneratedCodeInBytes(); - const HloProto* proto = cpu_executable_->hlo_proto(); - if (!proto) { - return tsl::errors::FailedPrecondition( - "cpu_executable_ has no hlo_proto."); - } - memory_stats.serialized_hlo_proto = proto->SerializeAsString(); - return memory_stats; - } - - using PjRtLoadedExecutable::Execute; - StatusOr>>> Execute( - absl::Span> argument_handles, - const ExecuteOptions& options, - std::optional>>& returned_futures) - override; - - using PjRtLoadedExecutable::ExecuteSharded; - StatusOr>> ExecuteSharded( - absl::Span argument_handles, PjRtDevice* device, - const ExecuteOptions& options, - std::optional>& returned_future, - bool fill_future) override; - - using PjRtLoadedExecutable::ExecutePortable; - StatusOr>> ExecutePortable( - absl::Span argument_handles, PjRtDevice* device, - const ExecuteOptions& options, - std::optional>& returned_future, - bool fill_future) override; - - void Delete() override; - - bool IsDeleted() override; - - StatusOr SerializeExecutable() const override; - - bool IsReturnedFutureSupported() const override { return true; } - - StatusOr> Fingerprint() const; - - std::shared_ptr cpu_executable() const { return cpu_executable_; } - - StatusOr FingerprintExecutable() const override { - return Unimplemented("Fingerprinting executable is not supported."); - } - - private: - friend class TfrtCpuClient; - - Status SetUpDonation(bool tuple_inputs); - - // Checks that the input buffers passed in by the user have the correct size - // on device for the compiled program. - Status CheckBufferCompatibilities( - absl::Span const> - input_buffers) const; - - StatusOr ExecuteHelper( - absl::Span argument_handles, int replica, - int partition, const RunId& run_id, const ExecuteOptions& options, - tsl::AsyncValueRef last_collective_launch_event, - bool fill_future, TfrtCpuDevice* device = nullptr); - - TfrtCpuClient* client_; - - int num_replicas_; - int num_partitions_; - std::shared_ptr device_assignment_; - bool parameter_is_tupled_arguments_; - CompileOptions compile_options_; - - std::shared_ptr cpu_executable_; - - // Caching `result_buffer_index_` and `result_buffer_indices_` to avoid lookup - // HLO dataflow analysis data structures in program execution critical path. - - // Buffer allocation index corresponding to root buffer buffer. - BufferAllocation::Index result_buffer_index_; - // Buffer allocation indices corresponding to each result buffer leaf buffer. - absl::InlinedVector result_buffer_indices_; - - // Size on device of each leaf buffer of the compiled program, cached here - // for performance reasons. - std::vector input_buffer_sizes_in_bytes_; - - // A sorted vector of parameters that have any aliased buffers and thus must - // be donated when executing the computation. - std::vector parameters_that_must_be_donated_; - - // The replica and partition indices of device_assignment_ to be run by this - // client. On single-host platforms without partitioning, this is all - // replicas (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may - // not be the case on multi-host platforms. If there are 4 replicas and 2 - // partitions on a single host platform, size of - // addressable_device_logical_ids_ is 4*2 = 8. - std::vector addressable_device_logical_ids_; - - // addressable_devices_[i] is the Device to which - // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of - // unique_ptrs to play well with the Python bindings (see xla.cc). 
- std::vector addressable_devices_; - - // Cached result of comparing HloCostAnalysis FLOP estimate for execute - // critical path. - bool cheap_computation_; -}; - -struct CpuClientOptions { - // Does nothing at the moment. Ignored. - bool asynchronous = true; - - // Number of CPU devices. If not provided, the value of - // --xla_force_host_platform_device_count is used. - std::optional cpu_device_count = std::nullopt; - - int max_inflight_computations_per_device = 32; - - // Number of distributed nodes. node_id, kv_get, and kv_put are ignored if - // this is set to 1. - int num_nodes = 1; - - // My node ID. - int node_id = 0; - - // KV store primitives for sharing topology information. - PjRtClient::KeyValueGetCallback kv_get = nullptr; - PjRtClient::KeyValuePutCallback kv_put = nullptr; -}; -StatusOr> GetTfrtCpuClient( - const CpuClientOptions& options); - -// Deprecated. Use the overload that takes 'options' instead. -inline StatusOr> GetTfrtCpuClient( - bool asynchronous) { - CpuClientOptions options; - options.asynchronous = asynchronous; - return GetTfrtCpuClient(options); -} - -// Deprecated. Use the overload that takes 'options' instead. -inline StatusOr> GetTfrtCpuClient( - bool asynchronous, int cpu_device_count, - int max_inflight_computations_per_device = 32) { - CpuClientOptions options; - options.asynchronous = asynchronous; - options.cpu_device_count = cpu_device_count; - options.max_inflight_computations_per_device = - max_inflight_computations_per_device; - return GetTfrtCpuClient(options); -} - -} // namespace xla +#include "xla/pjrt/cpu/cpu_client.h" #endif // XLA_PJRT_TFRT_CPU_PJRT_CLIENT_H_ diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index c7367a7f35ba59..96d225404bc605 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -719,7 +719,7 @@ xla_cc_test( "//xla/client:xla_builder", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_stream_executor_client", - "//xla/pjrt:tfrt_cpu_pjrt_client", + "//xla/pjrt/cpu:cpu_client", "//xla/service:platform_util", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:test_main", @@ -1168,8 +1168,8 @@ cc_library( "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_compiler", - "//xla/pjrt:tfrt_cpu_pjrt_client", "//xla/pjrt/c:pjrt_c_api_hdrs", + "//xla/pjrt/cpu:cpu_client", "//xla/pjrt/distributed", "//xla/pjrt/distributed:client", "//xla/pjrt/distributed:protocol_proto_cc", diff --git a/third_party/xla/xla/python/outfeed_receiver_test.cc b/third_party/xla/xla/python/outfeed_receiver_test.cc index 62539fa6079bea..0f67839f4cdfc2 100644 --- a/third_party/xla/xla/python/outfeed_receiver_test.cc +++ b/third_party/xla/xla/python/outfeed_receiver_test.cc @@ -23,9 +23,9 @@ limitations under the License. 
#include "xla/client/client_library.h" #include "xla/client/executable_build_options.h" #include "xla/client/xla_builder.h" +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_stream_executor_client.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/service/platform_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index e256468cdcea82..b64176848b9f2d 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -259,7 +259,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":pjrt_ifrt", - "//xla/pjrt:tfrt_cpu_pjrt_client", + "//xla/pjrt/cpu:cpu_client", "//xla/python/ifrt:test_util", ], alwayslink = True, diff --git a/third_party/xla/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc b/third_party/xla/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc index 9486ec0156ca8a..b790c4c5606169 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/python/ifrt/test_util.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 5f2d4321aaf8c6..00a969753b4c29 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -60,10 +60,10 @@ limitations under the License. #ifdef XLA_PYTHON_ENABLE_GPU #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #endif // XLA_PYTHON_ENABLE_GPU +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/python/custom_call_sharding.h" #include "xla/python/dlpack.h" #include "xla/python/jax_jit.h" diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index e9df3e5ca4bb32..447e018675735c 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -180,7 +180,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":pjrt_client_registry", - "//xla/pjrt:tfrt_cpu_pjrt_client", + "//xla/pjrt/cpu:cpu_client", ], ) diff --git a/third_party/xla/xla/tests/pjrt_cpu_client_registry.cc b/third_party/xla/xla/tests/pjrt_cpu_client_registry.cc index 1f0eaa43d5819f..d65884c15bbdbf 100644 --- a/third_party/xla/xla/tests/pjrt_cpu_client_registry.cc +++ b/third_party/xla/xla/tests/pjrt_cpu_client_registry.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/tests/pjrt_client_registry.h" namespace xla { From 50a89c49c2d489f9acb39bbee61710de4485707e Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Fri, 17 Nov 2023 09:30:57 -0800 Subject: [PATCH 230/391] Represent int4 as the LLVM type i4 instead of i8. Before, int4 was represented as i8. But this caused issues when converting S8 to S4 then back to S8 within a single fusion, as the S4 value was represented as i8 and so could hold values not representable with a signed int4 value. 
This could have been solved by introducing a mask when converting to S4, but I think it's cleaner to directly use i4, as this avoids the unusual case where the XLA type did not match the corresponding LLVM type. PiperOrigin-RevId: 583410950 --- .../xla/xla/service/llvm_ir/ir_array.cc | 33 +++++++------------ .../xla/xla/service/llvm_ir/ir_array.h | 27 +++++++-------- .../xla/xla/service/llvm_ir/ir_array_test.cc | 24 +++++++------- .../xla/xla/service/llvm_ir/llvm_util.cc | 4 +-- third_party/xla/xla/tests/convert_test.cc | 28 ++++++++++++++++ 5 files changed, 66 insertions(+), 50 deletions(-) diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc index 503aeddf450e7e..1a0284ee5c9a49 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "xla/layout_util.h" #include "xla/permutation_util.h" @@ -555,7 +556,6 @@ llvm::Value* IrArray::EmitLinearArrayElementAddress( // Handle int4 case by dividing index by 2. Int4 arrays are represented in // LLVM IR as an array of i8 value where each i8 value stores two int4 // numbers. - CHECK(type->isIntegerTy(8)); llvm::Type* index_type = index.linear()->getType(); llvm::Value* zero = llvm::ConstantInt::get(index_type, 0); llvm::Value* two = llvm::ConstantInt::get(index_type, 2); @@ -564,7 +564,7 @@ llvm::Value* IrArray::EmitLinearArrayElementAddress( // is_high_order_bits must be set for int4 arrays. CHECK_NE(is_high_order_bits, nullptr); *is_high_order_bits = b->CreateICmpEQ(remainder, zero); - return b->CreateInBoundsGEP(type, base_ptr_, byte_offset, + return b->CreateInBoundsGEP(b->getInt8Ty(), base_ptr_, byte_offset, llvm_ir::AsStringRef(name)); } @@ -587,29 +587,18 @@ llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::Value* is_high_order_bits = nullptr; llvm::Value* element_address = EmitArrayElementAddress( index, b, name, use_linear_index, &is_high_order_bits); + llvm::Type* load_type = primitive_util::Is4BitType(shape_.element_type()) + ? b->getInt8Ty() + : element_type_; llvm::LoadInst* load = - b->CreateLoad(element_type_, element_address, llvm_ir::AsStringRef(name)); + b->CreateLoad(load_type, element_address, llvm_ir::AsStringRef(name)); AnnotateLoadStoreInstructionWithMetadata(load); llvm::Value* elem = load; if (primitive_util::Is4BitType(shape_.element_type())) { llvm::Type* type = load->getType(); - llvm::Value* high_order_bits; - llvm::Value* low_order_bits; - if (shape_.element_type() == U4) { - high_order_bits = b->CreateLShr(load, llvm::ConstantInt::get(type, 4)); - low_order_bits = b->CreateAnd(load, llvm::ConstantInt::get(type, 0x0F)); - } else { - CHECK_EQ(shape_.element_type(), S4); - high_order_bits = b->CreateAShr(load, llvm::ConstantInt::get(type, 4)); - // To compute low_order_bits, cast to i4 then back to i8, which fills the - // left 4 bits with ones for negative numbers and zeros for positive - // numbers. 
- low_order_bits = - b->CreateIntCast(load, b->getIntNTy(4), /*isSigned=*/true); - low_order_bits = - b->CreateIntCast(low_order_bits, b->getInt8Ty(), /*isSigned=*/true); - } - elem = b->CreateSelect(is_high_order_bits, high_order_bits, low_order_bits); + llvm::Value* shifted = b->CreateLShr(load, llvm::ConstantInt::get(type, 4)); + elem = b->CreateSelect(is_high_order_bits, shifted, load); + elem = b->CreateTrunc(elem, b->getIntNTy(4)); } return elem; } @@ -623,9 +612,11 @@ void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value, if (primitive_util::Is4BitType(shape_.element_type())) { // Read a byte, replace the high-order or low-order bits with 'value', // and write it back. - llvm::LoadInst* load = b->CreateLoad(element_type_, element_address); + llvm::LoadInst* load = b->CreateLoad(b->getInt8Ty(), element_address); AnnotateLoadStoreInstructionWithMetadata(load); llvm::Type* type = load->getType(); + value = b->CreateIntCast(value, b->getInt8Ty(), + /*isSigned=*/shape_.element_type() == S4); llvm::Value* high_order_value = b->CreateShl(value, llvm::ConstantInt::get(type, 4)); diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.h b/third_party/xla/xla/service/llvm_ir/ir_array.h index cfd55362215cc9..1ee0995b523a6b 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.h +++ b/third_party/xla/xla/service/llvm_ir/ir_array.h @@ -223,12 +223,14 @@ class IrArray { // base_ptr is a pointer type pointing to the first element(lowest address) // of the array. // - // For int4 arrays, pointee_type should be i8, not i4, as int4 - // IrArrays are represented as i8 arrays where each i8 value stores two 4-bit - // values. Additionally, reads and write return or take in i8 values which - // hold a value representable by i4, instead of directly returning or taking - // in i4 values. Specifically, the i8 values returned or passed in are between - // 0 and 15 for U4 arrays and between -8 and 7 for S4 arrays. + // For int4 arrays, base_ptr should have half the number of bytes as array + // elements (rounded up), as two int4 values are packed into a byte. + // pointee_type should be an i4 array in this case, and reads and writes will + // return or take in i4 values. IrArray internally reads or writes i8 values, + // by treating base_ptr as an i8 array and masking out the high- or low-order + // 4 bits of the byte. IrArray does not directly read/write i4 values, since + // arrays of i4 values in LLVM are not packed (every element of an LLVM IR + // array must have unique address). IrArray(llvm::Value* base_ptr, llvm::Type* pointee_type, Shape shape); // Default implementations of copying and moving. @@ -272,10 +274,6 @@ class IrArray { // the emitted LLVM IR. // 'use_linear_index' can be used to specify whether the linear index (if // available) or the multi-dimensional index should be used. - // - // For int4 arrays, returns an i8 value that is representable by i4. The - // returned i8 value will be between 0 and 15 for U4 arrays and between -8 and - // 7 for S4 arrays. llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, absl::string_view name = "", bool use_linear_index = true) const; @@ -284,11 +282,10 @@ class IrArray { // 'use_linear_index' can be used to specify whether the linear index (if // available) or the multi-dimensional index should be used. // - // For int4 arrays, the given value must be an i8 value representable by i4. - // Only 4 bits of a byte in the array are written. 
First the appropriate byte - // is read from the array, then 4 bits are modified and written back. To avoid - // race conditions, the caller must ensure that the two different 4-bit values - // within a byte are not written to in parallel. + // For int4 arrays, only 4 bits of a byte in the array are written. First the + // appropriate byte is read from the array, then 4 bits are modified and + // written back. To avoid race conditions, the caller must ensure that the two + // different 4-bit values within a byte are not written to in parallel. void EmitWriteArrayElement(const Index& index, llvm::Value* value, llvm::IRBuilder<>* b, bool use_linear_index = true) const; diff --git a/third_party/xla/xla/service/llvm_ir/ir_array_test.cc b/third_party/xla/xla/service/llvm_ir/ir_array_test.cc index ec92a223b421d3..ea1717ce4371fb 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array_test.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array_test.cc @@ -151,7 +151,7 @@ TEST_F(IrArrayTest, EmitArrayElementAddressInt4) { /*is_high_order_bits=*/&is_high_order_bits); std::string ir_str = DumpToString(&module_); - // The index is divided by 2 and used as an index to the i8 array. A + // The index is divided by 2 and used as an index to the i8 array. A remainder // is also computed to calculate is_high_order_bits. const char* filecheck_pattern = R"( CHECK: define void @test_function(ptr %[[ptr:[0-9]+]], i32 %[[idx:[0-9]+]]) { @@ -186,7 +186,7 @@ TEST_F(IrArrayTest, EmitArrayElementAddressInt4NonLinear) { std::string ir_str = DumpToString(&module_); // The index is linearized despite use_linear_index=false being passed because - // non-linaer indices are not supported with int4 + // non-linear indices are not supported with int4 const char* filecheck_pattern = R"( CHECK: define void @test_function(ptr %[[ptr:[0-9]+]], i32 %[[idx0:[0-9]+]], i32 %[[idx1:[0-9]+]]) { CHECK: %[[mul1:[0-9]+]] = mul nuw nsw i32 %[[idx1]], 1 @@ -222,15 +222,14 @@ TEST_F(IrArrayTest, EmitReadArrayElementInt4) { COM: Calculate the address. CHECK: %[[srem:[0-9]+]] = srem i32 %[[idx0]], 2 CHECK: %[[addr:[0-9]+]] = udiv i32 %[[idx0]], 2 - CHECK: %[[isodd:[0-9]+]] = icmp eq i32 %[[srem]], 0 + CHECK: %[[iseven:[0-9]+]] = icmp eq i32 %[[srem]], 0 CHECK: %[[gep:[0-9]+]] = getelementptr inbounds i8, ptr %[[ptr]], i32 %[[addr]] - COM: Load the element and mask out 4 bits. + COM: Load the element, optionally shift, and truncate. 
CHECK: %[[load:[0-9]+]] = load i8, ptr %[[gep]], align 1 - CHECK: %[[shift:[0-9]+]] = ashr i8 %[[load]], 4 - CHECK: %[[trunc:[0-9]+]] = trunc i8 %[[load]] to i4 - CHECK: %[[sext:[0-9]+]] = sext i4 %[[trunc]] to i8 - CHECK: select i1 %[[isodd]], i8 %[[shift]], i8 %[[sext]] + CHECK: %[[shift:[0-9]+]] = lshr i8 %[[load]], 4 + CHECK: %[[select:[0-9]+]] = select i1 %[[iseven]], i8 %[[shift]], i8 %[[load]] + CHECK: trunc i8 %[[select]] to i4 )"; TF_ASSERT_OK_AND_ASSIGN(bool filecheck_match, @@ -240,7 +239,7 @@ TEST_F(IrArrayTest, EmitReadArrayElementInt4) { TEST_F(IrArrayTest, EmitWriteArrayElementInt4) { llvm::Function* function = EmitFunctionAndSetInsertPoint( - {builder_.getPtrTy(), builder_.getInt32Ty(), builder_.getInt8Ty()}); + {builder_.getPtrTy(), builder_.getInt32Ty(), builder_.getIntNTy(4)}); llvm::Argument* array_ptr = function->getArg(0); llvm::Argument* array_index = function->getArg(1); llvm::Argument* val_to_write = function->getArg(2); @@ -254,7 +253,7 @@ TEST_F(IrArrayTest, EmitWriteArrayElementInt4) { std::string ir_str = DumpToString(&module_); const char* filecheck_pattern = R"( - CHECK: define void @test_function(ptr %[[ptr:[0-9]+]], i32 %[[idx0:[0-9]+]], i8 %[[val:[0-9]+]]) { + CHECK: define void @test_function(ptr %[[ptr:[0-9]+]], i32 %[[idx0:[0-9]+]], i4 %[[val:[0-9]+]]) { COM: Calculate the address. CHECK: %[[srem:[0-9]+]] = srem i32 %[[idx0]], 2 @@ -264,10 +263,11 @@ TEST_F(IrArrayTest, EmitWriteArrayElementInt4) { COM: Load address, replace 4 bits with the value, and write to address. CHECK: %[[load:[0-9]+]] = load i8, ptr %[[gep]], align 1 - CHECK: %[[shl:[0-9]+]] = shl i8 %[[val]], 4 + CHECK: %[[sext:[0-9]+]] = sext i4 %[[val]] to i8 + CHECK: %[[shl:[0-9]+]] = shl i8 %[[sext]], 4 CHECK: %[[and1:[0-9]+]] = and i8 %[[load]], 15 CHECK: %[[or1:[0-9]+]] = or i8 %[[shl]], %[[and1]] - CHECK: %[[and2:[0-9]+]] = and i8 %[[val]], 15 + CHECK: %[[and2:[0-9]+]] = and i8 %[[sext]], 15 CHECK: %[[and3:[0-9]+]] = and i8 %[[load]], -16 CHECK: %[[or2:[0-9]+]] = or i8 %[[and2]], %[[and3]] CHECK: %[[towrite:[0-9]+]] = select i1 %[[isodd]], i8 %[[or1]], i8 %[[or2]] diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index 91f30c6bfc968f..59124f4e6a7875 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -216,10 +216,10 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Type* element_type, llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, llvm::Module* module) { switch (element_type) { - case PRED: - // i8 is used for S4/U4 as arrays of i4 values are not packed case S4: case U4: + return llvm::Type::getIntNTy(module->getContext(), 4); + case PRED: case S8: case U8: return llvm::Type::getInt8Ty(module->getContext()); diff --git a/third_party/xla/xla/tests/convert_test.cc b/third_party/xla/xla/tests/convert_test.cc index 4330d29c715dbf..3e1ae9c9a30f6a 100644 --- a/third_party/xla/xla/tests/convert_test.cc +++ b/third_party/xla/xla/tests/convert_test.cc @@ -601,6 +601,34 @@ TEST_F(ConvertTest, ConvertR1U8ToR1U4) { ComputeAndCompareR1(&builder, expected, {}); } +TEST_F(ConvertTest, ConvertR1S8ToR1S4Roundtrip) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {0, 8, -8, -9, 127, -128}); + auto b = ConvertElementType(a, S4); + ConvertElementType(b, S8); + + std::vector expected = {0, -8, -8, 7, -1, 0}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1F32ToR1S4) { + XlaBuilder 
builder(TestName());
+  auto a = ConstantR1(&builder, {0., 2.5, -2.5});
+  ConvertElementType(a, S4);
+
+  std::vector expected = {s4(0), s4(2), s4(-2)};
+  ComputeAndCompareR1(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1S4ToR1F32) {
+  XlaBuilder builder(TestName());
+  auto a = ConstantR1(&builder, {s4(0), s4(1), s4(2), s4(-8)});
+  ConvertElementType(a, F32);
+
+  std::vector expected = {0, 1, 2, -8};
+  ComputeAndCompareR1(&builder, expected, {});
+}
+
 XLA_TEST_F(ConvertTest, ConvertBF16F32) {
   XlaBuilder builder(TestName());

From 7500dfc73f2cfbeb0a3a76c005cbf0402b5cc0c7 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 17 Nov 2023 10:04:47 -0800
Subject: [PATCH 231/391] Fix flaky scatter3d test.

In the test we only constrain the memory budget such that the large tensors
are sharded 4 ways. Based on how we handle scatter ops, there is no reason to
prefer one 4-way sharding over others, despite the asymmetric mesh
communication costs, because:
1. We do not model any communication cost for the scatter op itself, and
2. Given any sharding for the scatter op, there always exist sharding
strategies (among the ones we consider) for all the operands such that
operand resharding costs are zero.

This fix therefore ensures that we check for all possible 4-way shardings.

PiperOrigin-RevId: 583420037
---
 .../xla/hlo/experimental/auto_sharding/auto_sharding_test.cc | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
index 9d11494eeb76ed..7b772c51bd0c34 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
@@ -720,7 +720,9 @@ ENTRY %Scatter {
   EXPECT_THAT(scatter, AnyOf(op::Sharding("{devices=[2,2,1]0,2,1,3}"),
                              op::Sharding("{devices=[2,2,1]0,1,2,3}"),
                              op::Sharding("{devices=[2,1,2]0,2,1,3}"),
-                             op::Sharding("{devices=[2,1,2]0,1,2,3}")));
+                             op::Sharding("{devices=[2,1,2]0,1,2,3}"),
+                             op::Sharding("{devices=[1,2,2]0,1,2,3}"),
+                             op::Sharding("{devices=[1,2,2]0,2,1,3}")));
   auto scatter_sharding = scatter->sharding();
   TF_EXPECT_OK(scatter_sharding.Validate(scatter->shape(), 4));
 }

From 5bb063a40ecac72997615603ad7b97b11c929046 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 17 Nov 2023 10:07:55 -0800
Subject: [PATCH 232/391] Replace dependency on xla/util.h with dependency on math_util.h.
PiperOrigin-RevId: 583420892 --- third_party/xla/xla/service/gpu/runtime/BUILD | 2 +- third_party/xla/xla/service/gpu/runtime/gpu_kernel_helper.h | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index b8c0c35f89e2a4..9dc7c482b520b1 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -32,8 +32,8 @@ gpu_kernel_library( "TENSORFLOW_USE_ROCM=1", ]), deps = [ - "//xla:util", "//xla/stream_executor/platform", + "@local_tsl//tsl/lib/math:math_util", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured([ diff --git a/third_party/xla/xla/service/gpu/runtime/gpu_kernel_helper.h b/third_party/xla/xla/service/gpu/runtime/gpu_kernel_helper.h index 3431eb4d000cf5..f3ee9b250249d3 100644 --- a/third_party/xla/xla/service/gpu/runtime/gpu_kernel_helper.h +++ b/third_party/xla/xla/service/gpu/runtime/gpu_kernel_helper.h @@ -20,7 +20,7 @@ limitations under the License. #include -#include "xla/util.h" +#include "tsl/lib/math/math_util.h" namespace xla { namespace gpu { @@ -108,7 +108,8 @@ enum class ShflType { Sync, Up, Down, Xor }; template __device__ FORCEINLINE NT GpuShuffle(NT val, uint32_t idx, uint32_t allmsk = 0xffffffffu) { - constexpr uint32_t SZ = CeilOfRatio(sizeof(NT), sizeof(uint32_t)); + constexpr uint32_t SZ = + tsl::MathUtil::CeilOfRatio(sizeof(NT), sizeof(uint32_t)); union S { NT v; uint32_t d[SZ]; From ee9442eba9f3848ff5836fd4b10379d61c10ef33 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Fri, 17 Nov 2023 11:00:57 -0800 Subject: [PATCH 233/391] [XLA:GPU] Expose buffer-assignment from hlo-opt PiperOrigin-RevId: 583436422 --- third_party/xla/xla/service/gpu/gpu_compiler.cc | 11 +++++------ .../xla/xla/service/gpu/gpu_executable.cc | 7 +++++-- .../xla/xla/service/gpu/gpu_executable.h | 13 +++++++++++-- .../xla/xla/tools/hlo_opt/gpu_hlo_buffers.hlo | 17 +++++++++++++++++ .../xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo | 2 +- third_party/xla/xla/tools/hlo_opt/gpu_opt.cc | 9 ++++++++- 6 files changed, 47 insertions(+), 12 deletions(-) create mode 100644 third_party/xla/xla/tools/hlo_opt/gpu_hlo_buffers.hlo diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 924b75ac0193b7..14b4e68e93ba50 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1669,15 +1669,12 @@ StatusOr> GpuCompiler::RunBackend( } std::shared_ptr buffer_assignment; - std::unique_ptr buffer_assignment_proto; std::function buffer_assignment_dumper = [] { return std::string(); }; if (!options.is_autotuning_compilation) { // Make it shared to be captured in the later lambda. 
buffer_assignment = std::move(res.compile_module_results.buffer_assignment); - buffer_assignment_proto = - std::make_unique(buffer_assignment->ToProto()); size_t max_buffers_to_show = module->config().debug_options().xla_debug_buffer_assignment_show_max(); buffer_assignment_dumper = [buffer_assignment, max_buffers_to_show] { @@ -1723,7 +1720,7 @@ StatusOr> GpuCompiler::RunBackend( module->config() .debug_options() .xla_gpu_enable_persistent_temp_buffers(), - /*debug_buffer_assignment=*/std::move(buffer_assignment_proto), + /*debug_buffer_assignment=*/std::move(buffer_assignment), /*verbose_buffer_assignment_string_dumper=*/ std::move(buffer_assignment_dumper), /*debug_module=*/options.is_autotuning_compilation @@ -1744,9 +1741,11 @@ StatusOr> GpuCompiler::RunBackend( // CompiledMemoryAnalysis. auto hlo_proto = std::make_unique(); *hlo_proto->mutable_hlo_module() = gpu_executable->module().ToProto(); - *hlo_proto->mutable_buffer_assignment() = buffer_assignment->ToProto(); + *hlo_proto->mutable_buffer_assignment() = + gpu_executable->BufferAssignment()->ToProto(); gpu_executable->set_hlo_proto(std::move(hlo_proto)); - gpu_executable->set_debug_info(buffer_assignment->GetStats().ToString()); + gpu_executable->set_debug_info( + gpu_executable->BufferAssignment()->GetStats().ToString()); } return static_cast>(std::move(gpu_executable)); diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index a576d8922534b3..6ed3e3a171aa74 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -144,8 +144,11 @@ GpuExecutable::GpuExecutable(GpuExecutable::Params params) *(uint64_t*)(&binary_[binary_.size() - 8]) = tsl::random::New64(); #endif if (has_module() && enable_debug_info_manager_) { + debug_buffer_assignment_proto_ = + std::make_shared( + debug_buffer_assignment_->ToProto()); XlaDebugInfoManager::Get()->RegisterModule(shared_module(), - debug_buffer_assignment_); + debug_buffer_assignment_proto_); } } @@ -952,7 +955,7 @@ GpuExecutable::GpuExecutable( enable_debug_info_manager_(true) { if (has_module()) { XlaDebugInfoManager::Get()->RegisterModule(shared_module(), - debug_buffer_assignment_); + debug_buffer_assignment_proto_); } } diff --git a/third_party/xla/xla/service/gpu/gpu_executable.h b/third_party/xla/xla/service/gpu/gpu_executable.h index bef637f595de01..6146bd546192e3 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.h +++ b/third_party/xla/xla/service/gpu/gpu_executable.h @@ -95,7 +95,7 @@ class GpuExecutable : public Executable { xla::Shape output_shape; std::vector allocations; bool enable_persistent_temp_buffers; - std::unique_ptr debug_buffer_assignment = nullptr; + std::shared_ptr debug_buffer_assignment; // A callable that dumps out a debug string upon device OOM. It's not the // string itself, as the string can be huge and increase peak host memory @@ -202,6 +202,14 @@ class GpuExecutable : public Executable { StatusOr GetObjFile() const; StatusOr GetMlirModule() const; + BufferAssignment* BufferAssignment() const { + return debug_buffer_assignment_.get(); + } + + BufferAssignmentProto* BufferAssignmentProto() const { + return debug_buffer_assignment_proto_.get(); + } + private: // Use GpuExecutable::Create() to create an instance. 
explicit GpuExecutable(Params params); @@ -308,7 +316,8 @@ class GpuExecutable : public Executable { BufferAllocToDeviceMemoryMap> persistent_temp_buffers_ ABSL_GUARDED_BY(persistent_temp_buffers_mu_); - std::shared_ptr debug_buffer_assignment_; + std::shared_ptr debug_buffer_assignment_; + std::shared_ptr debug_buffer_assignment_proto_; std::function verbose_buffer_assignment_string_dumper_; absl::Mutex module_handle_mutex_; diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_buffers.hlo b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_buffers.hlo new file mode 100644 index 00000000000000..e99d7146bebe78 --- /dev/null +++ b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_buffers.hlo @@ -0,0 +1,17 @@ +// RUN: hlo-opt %s --platform=CUDA --stage=buffer-assignment --xla_gpu_target_config_filename=%S/gpu_specs/a100.txtpb | FileCheck %s + +HloModule m, is_scheduled=true + +add { + a = f16[] parameter(0) + b = f16[] parameter(1) + ROOT out = f16[] add(a, b) +} + + +// CHECK: parameter allocation: 2.00MiB +ENTRY e { + p1 = f16[1048576] parameter(0) + i = f16[] constant(0) + ROOT out = f16[] reduce(p1, i), dimensions={0}, to_apply=add +} diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo index b9ba42564d8227..428bb07d95329c 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo +++ b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_ptx.hlo @@ -1,4 +1,4 @@ -// RUN: hlo-opt %s --platform=CUDA --stage=ptx --xla_gpu_target_config_filename=%S/gpu_specs/a100.txtpb | FileCheck %s --dump-input-filter=all +// RUN: hlo-opt %s --platform=CUDA --stage=ptx --xla_gpu_target_config_filename=%S/gpu_specs/a100.txtpb | FileCheck %s HloModule m diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc index 422bccdfdd81cf..3a467c5748d9f4 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc +++ b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc @@ -88,6 +88,13 @@ struct GpuOptProvider : public OptProvider { std::unique_ptr executable, ToGpuExecutable(std::move(module), compiler, executor, opts)); return static_cast(executable.get())->text(); + } else if (s == "buffer-assignment") { + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + ToGpuExecutable(std::move(module), compiler, executor, opts)); + return static_cast(executable.get()) + ->BufferAssignment() + ->ToVerboseString(9999); } // Unimplemented stage. @@ -95,7 +102,7 @@ struct GpuOptProvider : public OptProvider { } std::vector SupportedStages() override { - return {"hlo", "llvm", "ptx"}; + return {"hlo", "llvm", "ptx", "buffer-assignment"}; } }; From a3a1179e783dc08cfb0471060425dfa3753bbb62 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 11:29:31 -0800 Subject: [PATCH 234/391] Use Homebrew to install CMake. PiperOrigin-RevId: 583444037 --- ci/official/utilities/setup_macos.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/official/utilities/setup_macos.sh b/ci/official/utilities/setup_macos.sh index 7f00dea7838197..61378bc4e75204 100644 --- a/ci/official/utilities/setup_macos.sh +++ b/ci/official/utilities/setup_macos.sh @@ -60,5 +60,5 @@ if [[ "${TFCI_PYTHON_VERSION}" == "3.12" ]]; then # dm-tree (Keras v3 dependency) doesn't have pre-built wheels for 3.12 yet. # Having CMake allows building them. # Once the wheels are added, this should be removed - b/308399490. 
- sudo apt-get install -y --no-install-recommends cmake + brew install cmake fi From ade69b73dbf696b35bf370a7b6e5cb17eae974bb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 11:37:56 -0800 Subject: [PATCH 235/391] No op change. PiperOrigin-RevId: 583446207 --- tensorflow/core/data/BUILD | 12 ++++++ .../core/data/file_logger_client_interface.h | 41 +++++++++++++++++++ .../core/data/file_logger_client_no_op.h | 41 +++++++++++++++++++ 3 files changed, 94 insertions(+) create mode 100644 tensorflow/core/data/file_logger_client_interface.h create mode 100644 tensorflow/core/data/file_logger_client_no_op.h diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index a6debf6a378624..2c21b45cb6f2b9 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -636,3 +636,15 @@ cc_library( "@com_google_absl//absl/status:statusor", ], ) + +cc_library( + name = "file_logger_client_interface", + hdrs = ["file_logger_client_interface.h"], + # copybara:uncomment visibility = ["//learning/processing/tf_data_logger/client:__subpackages__"], +) + +cc_library( + name = "file_logger_client_no_op", + hdrs = ["file_logger_client_no_op.h"], + deps = [":file_logger_client_interface"], +) diff --git a/tensorflow/core/data/file_logger_client_interface.h b/tensorflow/core/data/file_logger_client_interface.h new file mode 100644 index 00000000000000..afa6cda0cf15f5 --- /dev/null +++ b/tensorflow/core/data/file_logger_client_interface.h @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_FILE_LOGGER_CLIENT_INTERFACE_H_ +#define TENSORFLOW_CORE_DATA_FILE_LOGGER_CLIENT_INTERFACE_H_ + +#include +#include + +namespace tensorflow::data { + +// An abstract class to provides an easy and thread safe api to make +// asynchronous calls to the TFDataLoggerService. +// LogFilesAsync is guaranteed to be non blocking. +// The destructor however might be blocking. +class FileLoggerClientInterface { + public: + // Default constructor + FileLoggerClientInterface() = default; + + // Sends file names in `files` to the TFDataLoggerService. Asynchronously. + virtual void LogFilesAsync(std::vector files) = 0; + + // Default destructor. May block depending on implementation of the derived + // class. + virtual ~FileLoggerClientInterface() = default; +}; +} // namespace tensorflow::data + +#endif // TENSORFLOW_CORE_DATA_FILE_LOGGER_CLIENT_INTERFACE_H_ diff --git a/tensorflow/core/data/file_logger_client_no_op.h b/tensorflow/core/data/file_logger_client_no_op.h new file mode 100644 index 00000000000000..65247844f741c4 --- /dev/null +++ b/tensorflow/core/data/file_logger_client_no_op.h @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_FILE_LOGGER_CLIENT_NO_OP_H_ +#define TENSORFLOW_CORE_DATA_FILE_LOGGER_CLIENT_NO_OP_H_ + +#include +#include + +#include "tensorflow/core/data/file_logger_client_interface.h" + +namespace tensorflow::data { + +// Implementation of the abstract class FileLoggerClientInterface, which does +// nothing. It does not allocate any resources and immediately returns in +// LogFilesAsync.3rd This is used in 3rd party version of the tf.data library. +class FileLoggerClientNoOp : public FileLoggerClientInterface { + public: + // Default constructor + FileLoggerClientNoOp() = default; + + // Does not do anything + void LogFilesAsync(std::vector files) override{}; + + // Default destructor + ~FileLoggerClientNoOp() override = default; +}; +} // namespace tensorflow::data + +#endif // TENSORFLOW_CORE_DATA_FILE_LOGGER_CLIENT_NO_OP_H_ From 69093a9ccd8d273fea9134dd1e160ee990f07117 Mon Sep 17 00:00:00 2001 From: Matt Callanan Date: Fri, 17 Nov 2023 11:40:13 -0800 Subject: [PATCH 236/391] #tf-data-service Don't apply general optimizations during compression map rewrite. PiperOrigin-RevId: 583446790 --- tensorflow/core/data/rewrite_utils.cc | 5 +++-- tensorflow/core/data/rewrite_utils.h | 5 ++++- tensorflow/core/data/service/graph_rewriters.cc | 10 +++++++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/data/rewrite_utils.cc b/tensorflow/core/data/rewrite_utils.cc index 707e4d8264118d..76c05e6e47f2fc 100644 --- a/tensorflow/core/data/rewrite_utils.cc +++ b/tensorflow/core/data/rewrite_utils.cc @@ -249,7 +249,8 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, } std::unique_ptr GetGrapplerItem( - GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks) { + GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks, + bool apply_optimizations) { // Add an identity node as the fetch node, otherwise we might get 'placeholder // is both fed and fetched' errors in some cases when using input list with // placeholder dataset nodes. @@ -285,7 +286,7 @@ std::unique_ptr GetGrapplerItem( // Create Grappler item. tensorflow::grappler::ItemConfig item_config; - item_config.apply_optimizations = true; + item_config.apply_optimizations = apply_optimizations; std::unique_ptr grappler_item = tensorflow::grappler::GrapplerItemFromMetaGraphDef( "graph", meta_graph_def, item_config); diff --git a/tensorflow/core/data/rewrite_utils.h b/tensorflow/core/data/rewrite_utils.h index 23ea965d67e105..44205dc83b24f5 100644 --- a/tensorflow/core/data/rewrite_utils.h +++ b/tensorflow/core/data/rewrite_utils.h @@ -57,10 +57,13 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, // `dataset_node` is the name of the node corresponding to the dataset. // If `add_fake_sinks` is true, it adds fake sink node to graph and functions to // allow rewriting the actual sink nodes. +// If `apply_optimizations` is true, general grappler optimizations at level +// `tensorflow::OptimizerOptions::L1` are applied to the graph. 
 // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals to be optimizable, we will no longer need to add fake nodes.
 std::unique_ptr GetGrapplerItem(
-    GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks);
+    GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks,
+    bool apply_optimizations = true);

 // Returns the name of the node corresponding to the dataset. It is indicated by the symbolic `_Retval` node.
diff --git a/tensorflow/core/data/service/graph_rewriters.cc b/tensorflow/core/data/service/graph_rewriters.cc
index af154691549f2a..114ae7c336cedb 100644
--- a/tensorflow/core/data/service/graph_rewriters.cc
+++ b/tensorflow/core/data/service/graph_rewriters.cc
@@ -95,10 +95,18 @@ RemoveCompressionMapRewriter::ApplyRemoveCompressionMapRewrite(
   tensorflow::RewriterConfig::CustomGraphOptimizer config = GetRewriteConfig();
   TF_RETURN_IF_ERROR(remove_compression_map.Init(&config));

+  // Don't apply general grappler optimizations. Sometimes there is a conflict
+  // between two applications of these optimizations to the same graph (see
+  // b/303524867). This conflict isn't worth resolving in the context of this
+  // rewrite: the point of this rewrite is to remove one node and change one
+  // reference to it, not to apply any general optimizations.
+  bool apply_general_grappler_optimizations = false;
+
   GraphDef input_graph = graph_def;
   TF_ASSIGN_OR_RETURN(std::string dataset_node, GetDatasetNode(input_graph));
   std::unique_ptr grappler_item =
-      GetGrapplerItem(&input_graph, &dataset_node, /*add_fake_sinks=*/false);
+      GetGrapplerItem(&input_graph, &dataset_node, /*add_fake_sinks=*/false,
+                      apply_general_grappler_optimizations);

   GraphDef rewritten_graph;
   std::unordered_map device_map;

From 3065bd3b2d2aceb8d9e2018d2187eb8145a86513 Mon Sep 17 00:00:00 2001
From: hmonishN <143435143+hmonishN@users.noreply.github.com>
Date: Fri, 17 Nov 2023 11:42:20 -0800
Subject: [PATCH 237/391] PR #7054: Increased the tolerance for Ampere

Imported from GitHub PR https://github.com/openxla/xla/pull/7054

TritonGemmLevel2Test.BinaryOperationWithSmallInputsIsFused fails on A100 GPUs
with CUDA 12.1; it passes on Volta. The relative error tolerance is increased
to 1e-2 (the default is 1e-3) so that the test passes on A100.
Copybara import of the project: -- a932a867d8318befef39e5c3222a2a6af7e45e19 by Harshit Monish : Incresed the tolerance for Ampere Merging this change closes #7054 PiperOrigin-RevId: 583447341 --- third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 1a6b11b77c0bd5..946bf2e19fa88e 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -1449,7 +1449,7 @@ ENTRY e { GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter()) .WithFusionKind(HloInstruction::FusionKind::kCustom))); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-3})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-1, /*arel=*/1e-2})); } TEST_F(TritonGemmLevel2Test, BinaryOperationWithLargeInputsIsNotFused) { From 50208708c707d1ded860f3da4c78a30dd2141284 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 17 Nov 2023 12:21:58 -0800 Subject: [PATCH 238/391] Create an experimental continuous cross-compile build for Linux Aarch64 PiperOrigin-RevId: 583457583 --- ci/official/envs/continuous_linux_arm64_cpu_cross_compile | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 ci/official/envs/continuous_linux_arm64_cpu_cross_compile diff --git a/ci/official/envs/continuous_linux_arm64_cpu_cross_compile b/ci/official/envs/continuous_linux_arm64_cpu_cross_compile new file mode 100644 index 00000000000000..d506aca9441b98 --- /dev/null +++ b/ci/official/envs/continuous_linux_arm64_cpu_cross_compile @@ -0,0 +1,6 @@ +# This envrionment is experimental and should not yet be used for production jobs +TFCI_BAZEL_COMMON_ARGS="--config rbe_cross_compile_linux_arm64 --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=cross_compile_linux_arm64 +TFCI_DOCKER_IMAGE=gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python +TFCI_DOCKER_REBUILD_ARGS="--target=tf ci/official/containers/linux_arm64" +TFCI_PYTHON_VERSION=3.11 From 88d4759d723d2b8c8fd16bd22055c410f5c4d0eb Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 17 Nov 2023 12:39:09 -0800 Subject: [PATCH 239/391] [xla:gpu] Rename buffer assignment getters to fix compilation error PiperOrigin-RevId: 583461801 --- third_party/xla/xla/service/gpu/gpu_compiler.cc | 4 ++-- third_party/xla/xla/service/gpu/gpu_executable.h | 4 ++-- third_party/xla/xla/tools/hlo_opt/gpu_opt.cc | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 14b4e68e93ba50..d0a7a1807118da 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1742,10 +1742,10 @@ StatusOr> GpuCompiler::RunBackend( auto hlo_proto = std::make_unique(); *hlo_proto->mutable_hlo_module() = gpu_executable->module().ToProto(); *hlo_proto->mutable_buffer_assignment() = - gpu_executable->BufferAssignment()->ToProto(); + gpu_executable->buffer_assignment()->ToProto(); gpu_executable->set_hlo_proto(std::move(hlo_proto)); gpu_executable->set_debug_info( - gpu_executable->BufferAssignment()->GetStats().ToString()); + gpu_executable->buffer_assignment()->GetStats().ToString()); } return static_cast>(std::move(gpu_executable)); diff --git a/third_party/xla/xla/service/gpu/gpu_executable.h 
b/third_party/xla/xla/service/gpu/gpu_executable.h index 6146bd546192e3..064b535a0cff42 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.h +++ b/third_party/xla/xla/service/gpu/gpu_executable.h @@ -202,11 +202,11 @@ class GpuExecutable : public Executable { StatusOr GetObjFile() const; StatusOr GetMlirModule() const; - BufferAssignment* BufferAssignment() const { + BufferAssignment* buffer_assignment() const { return debug_buffer_assignment_.get(); } - BufferAssignmentProto* BufferAssignmentProto() const { + BufferAssignmentProto* buffer_assignment_proto() const { return debug_buffer_assignment_proto_.get(); } diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc index 3a467c5748d9f4..4498538056c51b 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc +++ b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc @@ -93,7 +93,7 @@ struct GpuOptProvider : public OptProvider { std::unique_ptr executable, ToGpuExecutable(std::move(module), compiler, executor, opts)); return static_cast(executable.get()) - ->BufferAssignment() + ->buffer_assignment() ->ToVerboseString(9999); } From 139ea31daf14c49cc9ec6ef83240f112fec28ec6 Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Fri, 17 Nov 2023 12:44:32 -0800 Subject: [PATCH 240/391] Fix mistaken path/python-ver in default docker rebuild args PiperOrigin-RevId: 583463047 --- ci/official/envs/continuous_linux_x86_cpu_py310 | 2 +- ci/official/envs/continuous_linux_x86_cpu_py311 | 2 +- ci/official/envs/continuous_linux_x86_cpu_py39 | 2 +- ci/official/envs/continuous_linux_x86_cuda_py310 | 2 +- ci/official/envs/continuous_linux_x86_cuda_py311 | 2 +- ci/official/envs/continuous_linux_x86_cuda_py39 | 2 +- ci/official/envs/nightly_libtensorflow_linux_x86_cpu | 2 +- ci/official/envs/nightly_libtensorflow_linux_x86_cuda | 2 +- ci/official/envs/nightly_linux_x86_cpu_py310 | 2 +- ci/official/envs/nightly_linux_x86_cpu_py311 | 2 +- ci/official/envs/nightly_linux_x86_cpu_py312 | 2 +- ci/official/envs/nightly_linux_x86_cpu_py39 | 2 +- ci/official/envs/nightly_linux_x86_cuda_py310 | 2 +- ci/official/envs/nightly_linux_x86_cuda_py311 | 2 +- ci/official/envs/nightly_linux_x86_cuda_py312 | 2 +- ci/official/envs/nightly_linux_x86_cuda_py39 | 2 +- ci/official/envs/nightly_linux_x86_tpu_py310 | 2 +- ci/official/envs/nightly_linux_x86_tpu_py311 | 2 +- ci/official/envs/nightly_linux_x86_tpu_py312 | 2 +- ci/official/envs/nightly_linux_x86_tpu_py39 | 2 +- 20 files changed, 20 insertions(+), 20 deletions(-) diff --git a/ci/official/envs/continuous_linux_x86_cpu_py310 b/ci/official/envs/continuous_linux_x86_cpu_py310 index edadd16cbe3f56..81ad6455963c6e 100644 --- a/ci/official/envs/continuous_linux_x86_cpu_py310 +++ b/ci/official/envs/continuous_linux_x86_cpu_py310 @@ -1,5 +1,5 @@ TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu TFCI_DOCKER_IMAGE=tensorflow/build:latest-pythonlatest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_PYTHON_VERSION=3.10 diff --git a/ci/official/envs/continuous_linux_x86_cpu_py311 b/ci/official/envs/continuous_linux_x86_cpu_py311 index 3f0a688e543e28..4a306e19f97258 100644 --- 
a/ci/official/envs/continuous_linux_x86_cpu_py311 +++ b/ci/official/envs/continuous_linux_x86_cpu_py311 @@ -1,5 +1,5 @@ TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_PYTHON_VERSION=3.11 diff --git a/ci/official/envs/continuous_linux_x86_cpu_py39 b/ci/official/envs/continuous_linux_x86_cpu_py39 index 1fc06c40ff61d5..6b225c4e8f3170 100644 --- a/ci/official/envs/continuous_linux_x86_cpu_py39 +++ b/ci/official/envs/continuous_linux_x86_cpu_py39 @@ -1,5 +1,5 @@ TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config rbe_linux_cpu --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_PYTHON_VERSION=3.9 diff --git a/ci/official/envs/continuous_linux_x86_cuda_py310 b/ci/official/envs/continuous_linux_x86_cuda_py310 index ca287fcfe6cbb7..95e30867ced0ed 100644 --- a/ci/official/envs/continuous_linux_x86_cuda_py310 +++ b/ci/official/envs/continuous_linux_x86_cuda_py310 @@ -2,6 +2,6 @@ TFCI_BAZEL_COMMON_ARGS="--config release_gpu_linux --config rbe_linux_cuda --rep TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NVIDIA_SMI_ENABLE=1 TFCI_PYTHON_VERSION=3.10 diff --git a/ci/official/envs/continuous_linux_x86_cuda_py311 b/ci/official/envs/continuous_linux_x86_cuda_py311 index dca7656df10913..8bc69dc0ed514c 100644 --- a/ci/official/envs/continuous_linux_x86_cuda_py311 +++ b/ci/official/envs/continuous_linux_x86_cuda_py311 @@ -2,6 +2,6 @@ TFCI_BAZEL_COMMON_ARGS="--config release_gpu_linux --config rbe_linux_cuda --rep TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NVIDIA_SMI_ENABLE=1 TFCI_PYTHON_VERSION=3.11 diff --git a/ci/official/envs/continuous_linux_x86_cuda_py39 b/ci/official/envs/continuous_linux_x86_cuda_py39 index 0816384aac99d1..3899fed43065ba 100644 --- a/ci/official/envs/continuous_linux_x86_cuda_py39 +++ b/ci/official/envs/continuous_linux_x86_cuda_py39 @@ -2,6 +2,6 @@ TFCI_BAZEL_COMMON_ARGS="--config release_gpu_linux --config rbe_linux_cuda --rep TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda 
TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NVIDIA_SMI_ENABLE=1 TFCI_PYTHON_VERSION=3.9 diff --git a/ci/official/envs/nightly_libtensorflow_linux_x86_cpu b/ci/official/envs/nightly_libtensorflow_linux_x86_cpu index f140342e706fc2..d5e7b0b634f0ef 100644 --- a/ci/official/envs/nightly_libtensorflow_linux_x86_cpu +++ b/ci/official/envs/nightly_libtensorflow_linux_x86_cpu @@ -1,7 +1,7 @@ source ci/official/envs/ci_nightly_uploads TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.10 diff --git a/ci/official/envs/nightly_libtensorflow_linux_x86_cuda b/ci/official/envs/nightly_libtensorflow_linux_x86_cuda index 7bd4157119587f..adb557c7845196 100644 --- a/ci/official/envs/nightly_libtensorflow_linux_x86_cuda +++ b/ci/official/envs/nightly_libtensorflow_linux_x86_cuda @@ -2,7 +2,7 @@ source ci/official/envs/ci_nightly_uploads TFCI_BAZEL_COMMON_ARGS="--config release_gpu_linux --config tf_public_cache_push --config resultstore --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_LIB_SUFFIX="-gpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_NVIDIA_SMI_ENABLE=1 diff --git a/ci/official/envs/nightly_linux_x86_cpu_py310 b/ci/official/envs/nightly_linux_x86_cpu_py310 index 1c71aaedba4963..6576b8ab239593 100644 --- a/ci/official/envs/nightly_linux_x86_cpu_py310 +++ b/ci/official/envs/nightly_linux_x86_cpu_py310 @@ -3,7 +3,7 @@ TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config tf_public_cache_push TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.10 TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 diff --git a/ci/official/envs/nightly_linux_x86_cpu_py311 b/ci/official/envs/nightly_linux_x86_cpu_py311 index 924e8123ad44d6..544fff21a905fd 100644 --- a/ci/official/envs/nightly_linux_x86_cpu_py311 +++ b/ci/official/envs/nightly_linux_x86_cpu_py311 @@ -3,7 +3,7 @@ TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config tf_public_cache_push 
TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.11 TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 diff --git a/ci/official/envs/nightly_linux_x86_cpu_py312 b/ci/official/envs/nightly_linux_x86_cpu_py312 index 06d51849cd7501..b8442d9e03cb4a 100644 --- a/ci/official/envs/nightly_linux_x86_cpu_py312 +++ b/ci/official/envs/nightly_linux_x86_cpu_py312 @@ -3,7 +3,7 @@ TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config tf_public_cache_push TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.12 TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 diff --git a/ci/official/envs/nightly_linux_x86_cpu_py39 b/ci/official/envs/nightly_linux_x86_cpu_py39 index 703192cae008b2..69696ee814f77e 100644 --- a/ci/official/envs/nightly_linux_x86_cpu_py39 +++ b/ci/official/envs/nightly_linux_x86_cpu_py39 @@ -3,7 +3,7 @@ TFCI_BAZEL_COMMON_ARGS="--config release_cpu_linux --config tf_public_cache_push TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cpu TFCI_BUILD_PIP_PACKAGE_ARGS="--cpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.9 TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 diff --git a/ci/official/envs/nightly_linux_x86_cuda_py310 b/ci/official/envs/nightly_linux_x86_cuda_py310 index febf245f325c91..ec26fb1cb14905 100644 --- a/ci/official/envs/nightly_linux_x86_cuda_py310 +++ b/ci/official/envs/nightly_linux_x86_cuda_py310 @@ -4,7 +4,7 @@ TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda TFCI_BUILD_PIP_PACKAGE_ARGS="--nightly_flag" TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.10 TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 diff --git a/ci/official/envs/nightly_linux_x86_cuda_py311 b/ci/official/envs/nightly_linux_x86_cuda_py311 index 697a326362cd3f..e7101efa94cb57 100644 --- a/ci/official/envs/nightly_linux_x86_cuda_py311 +++ b/ci/official/envs/nightly_linux_x86_cuda_py311 @@ -4,7 +4,7 @@ TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda TFCI_BUILD_PIP_PACKAGE_ARGS="--nightly_flag" TFCI_DOCKER_ARGS="--gpus all" 
TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.11 TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 diff --git a/ci/official/envs/nightly_linux_x86_cuda_py312 b/ci/official/envs/nightly_linux_x86_cuda_py312 index f914e0547b87ac..4b9e371ae26ed3 100644 --- a/ci/official/envs/nightly_linux_x86_cuda_py312 +++ b/ci/official/envs/nightly_linux_x86_cuda_py312 @@ -4,7 +4,7 @@ TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda TFCI_BUILD_PIP_PACKAGE_ARGS="--nightly_flag" TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.12 TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 diff --git a/ci/official/envs/nightly_linux_x86_cuda_py39 b/ci/official/envs/nightly_linux_x86_cuda_py39 index 9615d10492eb17..63ee868a8db0b3 100644 --- a/ci/official/envs/nightly_linux_x86_cuda_py39 +++ b/ci/official/envs/nightly_linux_x86_cuda_py39 @@ -4,7 +4,7 @@ TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda TFCI_BUILD_PIP_PACKAGE_ARGS="--nightly_flag" TFCI_DOCKER_ARGS="--gpus all" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.9 TFCI_WHL_AUDIT_PLAT=manylinux2014_x86_64 diff --git a/ci/official/envs/nightly_linux_x86_tpu_py310 b/ci/official/envs/nightly_linux_x86_tpu_py310 index 8331f324b6517f..8367da6b55b456 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py310 +++ b/ci/official/envs/nightly_linux_x86_tpu_py310 @@ -4,7 +4,7 @@ TFCI_BAZEL_COMMON_ARGS="--config tf_public_cache_push --config resultstore --rep TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu TFCI_BUILD_PIP_PACKAGE_ARGS="--tpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-releases/index.html" TFCI_PYTHON_VERSION=3.10 diff --git a/ci/official/envs/nightly_linux_x86_tpu_py311 b/ci/official/envs/nightly_linux_x86_tpu_py311 index 9a93c1d5fda548..8a186aad7dcce0 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py311 +++ b/ci/official/envs/nightly_linux_x86_tpu_py311 @@ -4,7 +4,7 @@ TFCI_BAZEL_COMMON_ARGS="--config tf_public_cache_push --config resultstore --rep TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu TFCI_BUILD_PIP_PACKAGE_ARGS="--tpu --nightly_flag" 
TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-releases/index.html" TFCI_PYTHON_VERSION=3.11 diff --git a/ci/official/envs/nightly_linux_x86_tpu_py312 b/ci/official/envs/nightly_linux_x86_tpu_py312 index 086c53046c9fd6..0f8c73bd601e26 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py312 +++ b/ci/official/envs/nightly_linux_x86_tpu_py312 @@ -4,7 +4,7 @@ TFCI_BAZEL_COMMON_ARGS="--config tf_public_cache_push --config resultstore --rep TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu TFCI_BUILD_PIP_PACKAGE_ARGS="--tpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-releases/index.html" TFCI_PYTHON_VERSION=3.12 diff --git a/ci/official/envs/nightly_linux_x86_tpu_py39 b/ci/official/envs/nightly_linux_x86_tpu_py39 index 012206b18a7ca9..aa413f939ee5fd 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py39 +++ b/ci/official/envs/nightly_linux_x86_tpu_py39 @@ -4,7 +4,7 @@ TFCI_BAZEL_COMMON_ARGS="--config tf_public_cache_push --config resultstore --rep TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu TFCI_BUILD_PIP_PACKAGE_ARGS="--tpu --nightly_flag" TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} -TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles" +TFCI_DOCKER_REBUILD_ARGS="--build-arg PYTHON_VERSION=python$TFCI_PYTHON_VERSION --target=devel tensorflow/tools/tf_sig_build_dockerfiles" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="-f https://storage.googleapis.com/libtpu-releases/index.html" TFCI_PYTHON_VERSION=3.9 From c8d59c262bcbf48840b313e3035f83884b2d790d Mon Sep 17 00:00:00 2001 From: Yin Zhang Date: Fri, 17 Nov 2023 13:03:44 -0800 Subject: [PATCH 241/391] Move custom device trace up by reorder the pid allocation to be (a total of 1000 slots per host): - 0-500: reserved for device traces - 501-700: reserved for device custom traces (eg. 
megascale) - 701-999: other host related traces PiperOrigin-RevId: 583467926 --- .../profiler/convert/xplane_to_trace_container.cc | 2 +- .../third_party/tsl/tsl/profiler/utils/trace_utils.h | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc index a72b0968499a52..cfb4f2ec20cf4b 100644 --- a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc +++ b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc @@ -242,7 +242,7 @@ void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname, for (const XPlane* custom_plane : FindPlanesWithPrefix(space, tsl::profiler::kCustomPlanePrefix)) { ConvertXPlaneToTraceEventsContainer( - tsl::profiler::kCustomPlaneDeviceId + custom_plane->id(), hostname, + tsl::profiler::kFirstCustomPlaneDeviceId + custom_plane->id(), hostname, *custom_plane, container); } } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h index 90cee796fd95a7..6a7093b422c7d1 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h @@ -29,10 +29,15 @@ namespace profiler { // Support up to 500 accelerator devices. constexpr uint32 kFirstDeviceId = 1; constexpr uint32 kLastDeviceId = 500; -// Support Upto 200 custom planes. -constexpr uint32 kCustomPlaneDeviceId = kLastDeviceId + 1; +// Support Upto 200 custom planes as fake devices (i.e., planes with a +// "/custom:" prefix). See `::kCustomPlanePrefix` for more +// information +constexpr uint32 kFirstCustomPlaneDeviceId = kLastDeviceId + 1; +constexpr uint32 kMaxCustomPlaneDevicesPerHost = 200; +constexpr uint32 kLastCustomPlaneDeviceId = + kFirstCustomPlaneDeviceId + kMaxCustomPlaneDevicesPerHost - 1; // Host threads are shown as a single fake device. -constexpr uint32 kHostThreadsDeviceId = kCustomPlaneDeviceId + 200; +constexpr uint32 kHostThreadsDeviceId = kLastCustomPlaneDeviceId + 1; // Constants used as trace_viewer TID (resource_id in trace_events.proto). constexpr int kThreadIdDerivedMin = 0xdeadbeef; From d5811f113e24cacc36d7f37b8a5c9731cdbcd921 Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Fri, 17 Nov 2023 13:32:33 -0800 Subject: [PATCH 242/391] Fix crash with int4 scalar arrays. PiperOrigin-RevId: 583474466 --- third_party/xla/xla/service/llvm_ir/ir_array.cc | 4 ++++ third_party/xla/xla/tests/int4_test.cc | 13 +++++++++++++ 2 files changed, 17 insertions(+) diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc index 1a0284ee5c9a49..acffe65df84ef7 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc @@ -485,6 +485,10 @@ llvm::Value* IrArray::EmitArrayElementAddress( const IrArray::Index& index, llvm::IRBuilder<>* b, absl::string_view name, bool use_linear_index, llvm::Value** is_high_order_bits) const { if (ShapeUtil::IsScalar(shape_)) { + if (primitive_util::Is4BitType(shape_.element_type())) { + CHECK_NE(is_high_order_bits, nullptr); + *is_high_order_bits = b->getTrue(); + } // Special handling of scalars: a scalar pretends to have the same value for // every index, thus effectively implementing broadcasting of its value // over higher-rank arrays. 
diff --git a/third_party/xla/xla/tests/int4_test.cc b/third_party/xla/xla/tests/int4_test.cc index 4be51c41519527..84e2e0b98a91c2 100644 --- a/third_party/xla/xla/tests/int4_test.cc +++ b/third_party/xla/xla/tests/int4_test.cc @@ -108,5 +108,18 @@ XLA_TEST_F(HloTestBase, OddNumberOfElements) { EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt)); } +XLA_TEST_F(HloTestBase, Scalar) { + // Tests reading an int4 scalar value + const std::string hlo_text = R"( + HloModule Scalar + ENTRY main { + x = s4[] parameter(0) + y = s8[] convert(x) + ROOT z = s8[3, 3] broadcast(y), dimensions={} + } +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt)); +} + } // namespace } // namespace xla From 755803ed7227a4991ff6ece55fa6c5dcc7217a60 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Fri, 17 Nov 2023 13:44:49 -0800 Subject: [PATCH 243/391] [stream_executor] Add efficient command buffer updates PiperOrigin-RevId: 583477112 --- .../gpu/runtime3/command_buffer_cmd.cc | 53 +++++++++++-- .../service/gpu/runtime3/command_buffer_cmd.h | 21 +++++ .../gpu/runtime3/command_buffer_thunk_test.cc | 77 ++++++++++++++++++- 3 files changed, 145 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc index 5f5f4794134975..4e1dbfbcc24402 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/runtime3/command_buffer_cmd.h" +#include #include #include #include @@ -25,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/stream_executor_util.h" @@ -61,12 +63,38 @@ Status CommandBufferCmdSequence::Record( if (command_buffer->state() == se::CommandBuffer::State::kFinalized) { TF_RETURN_IF_ERROR(command_buffer->Update()); } + // Returns if no cmd requires update. + if (!ShouldUpdateCmd(params)) { + return OkStatus(); + } for (auto& cmd : commands_) { TF_RETURN_IF_ERROR(cmd->Record(params, command_buffer)); } return command_buffer->Finalize(); } +bool CommandBufferCmdSequence::ShouldUpdateCmd( + const CommandBufferCmd::RecordParams& params) { + bool should_update = false; + const BufferAllocations* allocs = params.buffer_allocations; + size_t size = allocs->size(); + if (prev_allocs_.size() < size) { + prev_allocs_.resize(size); + should_update = true; + } + // Traversing all allocations from `params` using the index alone (no need for + // offset) is enough because every time `BufferAllocation` remapped to a new + // physical memory location all commands reading from any slice from that + // allocation must be invalidated. 
+ for (unsigned i = 0; i < size; ++i) { + se::DeviceMemoryBase new_alloc = allocs->GetDeviceAddress(i); + se::DeviceMemoryBase& prev_alloc = prev_allocs_[i]; + should_update |= !new_alloc.IsSameAs(prev_alloc); + prev_alloc = new_alloc; + } + return should_update; +} + //===----------------------------------------------------------------------===// // LaunchCmd //===----------------------------------------------------------------------===// @@ -121,6 +149,10 @@ Status LaunchCmd::Record(const RecordParams& params, *kernel_args); } +CommandBufferCmd::Slices LaunchCmd::slices() { + return CommandBufferCmd::Slices(args_.begin(), args_.end()); +} + //===----------------------------------------------------------------------===// // MemcpyDeviceToDeviceCmd //===----------------------------------------------------------------------===// @@ -139,6 +171,10 @@ Status MemcpyDeviceToDeviceCmd::Record(const RecordParams& params, return command_buffer->MemcpyDeviceToDevice(&dst, src, num_bytes_); } +CommandBufferCmd::Slices MemcpyDeviceToDeviceCmd::slices() { + return {dst_, src_}; +} + //===----------------------------------------------------------------------===// // GemmCmd //===----------------------------------------------------------------------===// @@ -167,19 +203,26 @@ Status GemmCmd::Record(const RecordParams& params, << ", output=" << output_buffer_ << ", deterministic=" << deterministic_; - const BufferAllocations& allocs = *params.buffer_allocations; se::DeviceMemoryBase workspace(nullptr, 0); + se::DeviceMemoryBase lhs = + params.buffer_allocations->GetDeviceAddress(lhs_buffer_); + se::DeviceMemoryBase rhs = + params.buffer_allocations->GetDeviceAddress(rhs_buffer_); + se::DeviceMemoryBase out = + params.buffer_allocations->GetDeviceAddress(output_buffer_); TF_ASSIGN_OR_RETURN( auto nested_buffer, stream_executor::CommandBuffer::Trace( command_buffer->executor(), [&](stream_executor::Stream* stream) { - return RunGemm(config_, allocs.GetDeviceAddress(lhs_buffer_), - allocs.GetDeviceAddress(rhs_buffer_), - allocs.GetDeviceAddress(output_buffer_), workspace, - deterministic_, stream); + return RunGemm(config_, lhs, rhs, out, workspace, deterministic_, + stream); })); return command_buffer->AddNestedCommandBuffer(nested_buffer); } +CommandBufferCmd::Slices GemmCmd::slices() { + return {lhs_buffer_, rhs_buffer_, output_buffer_}; +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h index 75ab07bb2d5042..a80150bfd31589 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_cmd.h @@ -45,6 +45,7 @@ namespace xla::gpu { class CommandBufferCmd { public: using ExecutableSource = Thunk::ExecutableSource; + using Slices = absl::InlinedVector; // Run time parameters required for recording commands into the command // buffer. For example when we emit command buffer cmd sequence from an HLO @@ -66,6 +67,10 @@ class CommandBufferCmd { virtual Status Record(const RecordParams& params, se::CommandBuffer* command_buffer) = 0; + // Returns all buffer slices of the cmd. These will be used to track cmd + // updates, thus they need to be consistent across calls to the function. 
+ virtual Slices slices() = 0; + virtual ~CommandBufferCmd() = default; }; @@ -96,7 +101,17 @@ class CommandBufferCmdSequence { se::CommandBuffer* command_buffer); private: + // Traverse the list of commands and figures out if any of them requires an + // update. Also updates `prev_allocs_` with new allocations from `params`. + bool ShouldUpdateCmd(const CommandBufferCmd::RecordParams& params); + std::vector> commands_; + // Mapping from buffer slice index to device memory passed at that index via + // the `CommandBufferCmd::RecordParams` in previous invocation of `Record`. + // We can just use a vector instead of map because `BufferAllocation` has a + // unique identifier assigned contiguously and thus can be used as array + // index. + std::vector prev_allocs_; }; //===----------------------------------------------------------------------===// @@ -115,6 +130,8 @@ class LaunchCmd : public CommandBufferCmd { Status Record(const RecordParams& params, se::CommandBuffer* command_buffer) override; + Slices slices() override; + private: using OwnedKernel = std::unique_ptr; @@ -138,6 +155,8 @@ class MemcpyDeviceToDeviceCmd : public CommandBufferCmd { Status Record(const RecordParams& params, se::CommandBuffer* command_buffer) override; + Slices slices() override; + private: BufferAllocation::Slice dst_; BufferAllocation::Slice src_; @@ -160,6 +179,8 @@ class GemmCmd : public CommandBufferCmd { Status Record(const RecordParams& params, se::CommandBuffer* command_buffer) override; + Slices slices() override; + private: const GemmConfig config_; const BufferAllocation::Slice lhs_buffer_; diff --git a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc index 9ad683a640980e..7712ce1c8c7098 100644 --- a/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime3/command_buffer_thunk_test.cc @@ -89,6 +89,18 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { stream.ThenMemcpy(dst.data(), b, byte_length); ASSERT_EQ(dst, std::vector(4, 42)); + + // Try to update the command buffer with the same buffers. + stream.ThenMemZero(&b, byte_length); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + + // Copy `b` data back to host. + std::fill(dst.begin(), dst.end(), 0); + stream.ThenMemcpy(dst.data(), b, byte_length); + + ASSERT_EQ(dst, std::vector(4, 42)); } TEST(CommandBufferThunkTest, LaunchCmd) { @@ -157,6 +169,18 @@ TEST(CommandBufferThunkTest, LaunchCmd) { stream.ThenMemcpy(dst.data(), c, byte_length); ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Try to update the command buffer with the same buffers. + stream.ThenMemZero(&c, byte_length); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + + // Copy `c` data back to host. + std::fill(dst.begin(), dst.end(), 0); + stream.ThenMemcpy(dst.data(), c, byte_length); + + ASSERT_EQ(dst, std::vector(4, 42 + 42)); } TEST(CommandBufferThunkTest, GemmCmd) { @@ -244,6 +268,18 @@ TEST(CommandBufferThunkTest, GemmCmd) { stream.ThenMemcpy(dst.data(), updated_out, out_length); ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); + + // Try to update the command buffer with the same buffers. + stream.ThenMemZero(&updated_out, out_length); + + // Thunk execution should automatically update underlying command buffer. 
+ TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + + // Copy `updated_out` data back to host. + std::fill(dst.begin(), dst.end(), 0); + stream.ThenMemcpy(dst.data(), updated_out, out_length); + + ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); } TEST(CommandBufferThunkTest, MultipleLaunchCmd) { @@ -308,9 +344,48 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { ASSERT_EQ(dst, std::vector(4, 42 + 42)); // Copy `d` data back to host. - std::vector dst_1(4, 0); + std::fill(dst.begin(), dst.end(), 0); stream.ThenMemcpy(dst.data(), d, byte_length); ASSERT_EQ(dst, std::vector(4, 21 + 21)); + + BufferAllocation alloc_e(/*index=*/3, byte_length, /*color=*/0); + BufferAllocation::Slice slice_e(&alloc_e, 0, byte_length); + + // Prepare buffer allocation for updating command buffer: e=0 + se::DeviceMemory e = executor->AllocateArray(length, 0); + stream.ThenMemZero(&e, byte_length); + + // Update buffer allocation #1 to buffer `c`. + allocations = BufferAllocations({a, b, c, e}, 0, executor->GetAllocator()); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + + // Copy `b` data back to host. + std::fill(dst.begin(), dst.end(), 0); + stream.ThenMemcpy(dst.data(), b, byte_length); + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Copy `e` data back to host. + std::fill(dst.begin(), dst.end(), 0); + stream.ThenMemcpy(dst.data(), e, byte_length); + ASSERT_EQ(dst, std::vector(4, 21 + 21)); + + // Try to update the command buffer with the same buffers. + stream.ThenMemZero(&e, byte_length); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + + // Copy `b` data back to host. + std::fill(dst.begin(), dst.end(), 0); + stream.ThenMemcpy(dst.data(), b, byte_length); + ASSERT_EQ(dst, std::vector(4, 42 + 42)); + + // Copy `e` data back to host. + std::fill(dst.begin(), dst.end(), 0); + stream.ThenMemcpy(dst.data(), e, byte_length); + ASSERT_EQ(dst, std::vector(4, 21 + 21)); } } // namespace xla::gpu From 285b0b13c2e591ea393ae61860a8498be368c8cb Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Fri, 17 Nov 2023 14:08:33 -0800 Subject: [PATCH 244/391] Shrink MLRT buffer to reduce memory usage. 
PiperOrigin-RevId: 583482753 --- .../compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc | 2 ++ tensorflow/core/tfrt/mlrt/bytecode/bytecode.h | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc index 8c85f9f80ac912..1953ddd3d93997 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc @@ -464,6 +464,8 @@ absl::StatusOr EmitExecutable( return status; } + buffer.shrink_to_fit(); + return buffer; } diff --git a/tensorflow/core/tfrt/mlrt/bytecode/bytecode.h b/tensorflow/core/tfrt/mlrt/bytecode/bytecode.h index f6b8de5da15dcb..f82666f172a37d 100644 --- a/tensorflow/core/tfrt/mlrt/bytecode/bytecode.h +++ b/tensorflow/core/tfrt/mlrt/bytecode/bytecode.h @@ -109,6 +109,8 @@ class Buffer { size_t size() const { return buffer_.size(); } bool empty() const { return buffer_.empty(); } + void shrink_to_fit() { buffer_.shrink_to_fit(); } + private: static_assert(alignof(std::max_align_t) >= 8, "The bytecode buffer needs to be at least 8-byte aligned."); From e5a75b525be1f4c40931e2125c5f9f04bdaa40e8 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 17 Nov 2023 14:40:32 -0800 Subject: [PATCH 245/391] [xla:python] Defer evaluating `platform_name` in `PyClient`. PiperOrigin-RevId: 583490555 --- third_party/xla/xla/python/py_client.cc | 10 ---------- third_party/xla/xla/python/py_client.h | 14 ++++++++++++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/python/py_client.cc b/third_party/xla/xla/python/py_client.cc index 3f03c99d601e6e..9f7021be5761c7 100644 --- a/third_party/xla/xla/python/py_client.cc +++ b/third_party/xla/xla/python/py_client.cc @@ -60,16 +60,6 @@ PyClient::PyClient(std::shared_ptr ifrt_client) : ifrt_client_(std::move(ifrt_client)), client_attributes_(ifrt_client_->attributes()) { CHECK(ifrt_client_); - // TODO(phawkins): this is a temporary backwards compatibility shim. We - // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but - // we haven't yet updated JAX clients that expect "gpu". Migrate users and - // remove this code. - if (ifrt_client_->platform_name() == "cuda" || - ifrt_client_->platform_name() == "rocm") { - platform_name_ = "gpu"; - } else { - platform_name_ = ifrt_client_->platform_name(); - } } PyClient::~PyClient() { diff --git a/third_party/xla/xla/python/py_client.h b/third_party/xla/xla/python/py_client.h index 68b95c23affb4d..581801831ffec6 100644 --- a/third_party/xla/xla/python/py_client.h +++ b/third_party/xla/xla/python/py_client.h @@ -139,7 +139,18 @@ class PyClient : public std::enable_shared_from_this { return shared_ptr_pjrt_client(); } - absl::string_view platform_name() const { return platform_name_; } + absl::string_view platform_name() const { + // TODO(phawkins): this is a temporary backwards compatibility shim. We + // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but + // we haven't yet updated JAX clients that expect "gpu". Migrate users and + // remove this code. 
+ if (ifrt_client_->platform_name() == "cuda" || + ifrt_client_->platform_name() == "rocm") { + return "gpu"; + } else { + return ifrt_client_->platform_name(); + } + } absl::string_view platform_version() const { return ifrt_client_->platform_version(); } @@ -257,7 +268,6 @@ class PyClient : public std::enable_shared_from_this { friend struct PyArray_Storage; std::shared_ptr ifrt_client_; - std::string platform_name_; absl::flat_hash_map client_attributes_; // Pointers to intrusive doubly-linked lists of arrays and executables, used From 232a64df72146a2cf8747a7f6bf7f2a27213d7c8 Mon Sep 17 00:00:00 2001 From: Jian Cai Date: Fri, 17 Nov 2023 14:51:48 -0800 Subject: [PATCH 246/391] Add some TPU passes to MLIR CPU/GPU phase 1 pipeline This adds some TPU passes to CPU/GPU pass pipeline before XlaClusterFormation, in preparation for the unification of the clustering code. The passes should be no-ops for CPU/GPU graphs. PiperOrigin-RevId: 583493387 --- .../tf2xla/internal/clustering_bridge_passes.cc | 15 ++++++++++++--- .../internal/clustering_bridge_passes_test.cc | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc index b6b22dc2b40690..2628d9f17b59cb 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc @@ -176,11 +176,12 @@ void AddNonTPUBridgeClusteringPipelinePasses(OpPassManager& pm) { // The following ops must be preserved regardless of reachability. Ideally, // all graphs should have control dependencies to enforce this. VLOG(2) << "Create TF XLA Bridge pipeline"; + pm.addPass(mlir::TFDevice::CreateXlaValidateInputsPass()); pm.addNestedPass( mlir::TF::CreateCanonicalizeCompileAndReplicateAttributesPass()); - // This pass expectes unified compilation markers. - pm.addPass(mlir::TFDevice::CreateXlaValidateInputsPass()); - const llvm::SmallVector ops_to_preserve = {}; + const llvm::SmallVector ops_to_preserve = { + "tf.TPUReplicateMetadata", "tf.TPUCompilationResult", + "tf.TPUReplicatedOutput"}; pm.addNestedPass( mlir::tf_executor::CreateTFExecutorGraphPruningPass(ops_to_preserve)); // It is assumed at this stage there are no V1 control flow ops as Graph @@ -192,6 +193,14 @@ void AddNonTPUBridgeClusteringPipelinePasses(OpPassManager& pm) { // inference. pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + // The following passe are addded to match TPU pipeline and expected to be + // no-op. + pm.addNestedPass(mlir::TFTPU::CreateTPUPartitionedOpConversionPass()); + pm.addNestedPass( + mlir::TFTPU::CreateTPUReorderReplicateAndPartitionedInputsPass()); + pm.addNestedPass(mlir::TF::CreateDecomposeReduceDatasetPass()); + pm.addPass(mlir::TFDevice::CreateEmbeddingPipeliningPass()); + pm.addPass(mlir::TFDevice::CreateEmbeddingSequencingPass()); // Encapsulate PartitionedCall ops within a cluster so that the composite // resource ops can be decomposed. 
pm.addPass(tensorflow::tf2xla::internal::CreateXlaClusterFormationPass()); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc index 91b80fa485a83f..de9d97697ec5f0 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes_test.cc @@ -35,7 +35,7 @@ TEST(ClusteringBridgePassesTest, AddsNonTPUBridgePasses) { OpPassManager pass_manager; AddNonTPUBridgeClusteringPipelinePasses(pass_manager); - EXPECT_EQ(pass_manager.size(), 15); + EXPECT_EQ(pass_manager.size(), 20); } }; // namespace internal From 7b2d5e070ebd820e6fc35f65af98b4338e8ce15c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 15:51:16 -0800 Subject: [PATCH 247/391] Fix some compilation perf issues in copy_insertion.cc 1) pass by reference for instruction_ids in sorting lambda 2) use instruction unique_id for hashing PiperOrigin-RevId: 583506743 --- third_party/xla/xla/client/lib/BUILD | 6 +++--- third_party/xla/xla/service/copy_insertion.cc | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/client/lib/BUILD b/third_party/xla/xla/client/lib/BUILD index f6c2fe2d91ee9e..134c8c47d6e115 100644 --- a/third_party/xla/xla/client/lib/BUILD +++ b/third_party/xla/xla/client/lib/BUILD @@ -208,10 +208,10 @@ cc_library( xla_test( name = "math_test", srcs = ["math_test.cc"], - tags = [ + backend_tags = { # Times out. - "noasan", - ], + "ghostfish_iss": ["noasan"], + }, deps = [ ":constants", ":math", diff --git a/third_party/xla/xla/service/copy_insertion.cc b/third_party/xla/xla/service/copy_insertion.cc index 13f2bb2bdc5605..546f25e431eb2c 100644 --- a/third_party/xla/xla/service/copy_insertion.cc +++ b/third_party/xla/xla/service/copy_insertion.cc @@ -1196,12 +1196,12 @@ class CopyRemover { // Instruction indices based on post order traversal of computations and // instructions. Used as an enhancement for getting strict weak ordering // used for sorting below. - absl::flat_hash_map instruction_ids; + absl::flat_hash_map instruction_ids; int64_t id = 0; for (HloComputation* computation : module.MakeComputationPostOrder()) { for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { - instruction_ids[instruction] = id++; + instruction_ids[instruction->unique_id()] = id++; } } @@ -1255,8 +1255,8 @@ class CopyRemover { } std::vector values = buffer.values(); - absl::c_sort(values, [this, instruction_ids](const HloValue* a, - const HloValue* b) { + absl::c_sort(values, [this, &instruction_ids](const HloValue* a, + const HloValue* b) { // IsDefinedBefore() is generally not strict weak ordering required by // the sort algorithm, since a may not be comparable to b or c by // IsDefinedBefore(), but b and c can be comparable. Such as in: @@ -1273,8 +1273,8 @@ class CopyRemover { return false; } const bool a_has_smaller_id = - instruction_ids.at(a->defining_instruction()) < - instruction_ids.at(b->defining_instruction()); + instruction_ids.at(a->defining_instruction()->unique_id()) < + instruction_ids.at(b->defining_instruction()->unique_id()); // Use a_has_smaller_id as a hint for the order between a and b. In case // it's right, there is no need for two IsDefinedBefore() tests. 
if (a_has_smaller_id) { From af223be11863d5e1c3195c75ec85fd6935f6f5d0 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Fri, 17 Nov 2023 15:52:41 -0800 Subject: [PATCH 248/391] Reorganize `hlo_test` to only take one input HLO PiperOrigin-RevId: 583507102 --- third_party/xla/xla/tests/fuzz/BUILD | 31 +++++----- third_party/xla/xla/tests/fuzz/build_defs.bzl | 56 +++++++++++++++---- 2 files changed, 57 insertions(+), 30 deletions(-) diff --git a/third_party/xla/xla/tests/fuzz/BUILD b/third_party/xla/xla/tests/fuzz/BUILD index 0f450e028a662d..c1e4edddbc42f2 100644 --- a/third_party/xla/xla/tests/fuzz/BUILD +++ b/third_party/xla/xla/tests/fuzz/BUILD @@ -14,22 +14,17 @@ cc_library( ], ) -hlo_test( - name = "rand", - srcs = [], - hlo_files = glob( - include = ["rand_*.hlo"], - exclude = [ - "rand_000001.hlo", # fails on GPU - "rand_000004.hlo", # times out during coverage - # These fail on all platforms - "rand_000060.hlo", - "rand_000067.hlo", - "rand_000072.hlo", - ], - ), - deps = [ - ":hlo_test_lib", - "@local_tsl//tsl/platform:test_main", +[hlo_test( + name = hlo + "_test", + hlo = hlo, +) for hlo in glob( + include = ["rand_*.hlo"], + exclude = [ + "rand_000001.hlo", # fails on GPU + "rand_000004.hlo", # times out during coverage + # These fail on all platforms + "rand_000060.hlo", + "rand_000067.hlo", + "rand_000072.hlo", ], -) +)] diff --git a/third_party/xla/xla/tests/fuzz/build_defs.bzl b/third_party/xla/xla/tests/fuzz/build_defs.bzl index 4d19dba75f273a..08f5e31a332819 100644 --- a/third_party/xla/xla/tests/fuzz/build_defs.bzl +++ b/third_party/xla/xla/tests/fuzz/build_defs.bzl @@ -2,15 +2,47 @@ load("//xla/tests:build_defs.bzl", "xla_test") -def hlo_test(name, hlo_files, srcs, deps, **kwargs): - for hlo in hlo_files: - without_extension = hlo.split(".")[0] - xla_test( - name = without_extension, - srcs = srcs, - env = {"HLO_PATH": "$(location {})".format(hlo)}, - data = [hlo], - real_hardware_only = True, - deps = deps, - **kwargs - ) +def hlo_test(name, hlo, **kwargs): + """Wrapper around `xla_test` which runs an HLO through `hlo_test_lib`. + + `srcs = []` because `hlo_test_lib` linked with `tsl/platform:test_main` + makes usable test binary where the path to the HLO is given via `HLO_PATH` + environment variable. + + This has the following nice properties: + * adding an HLO to this directory with the appropriate prefix for a test + suite (e.g. rand) will have it automatically create the corresponding test + * `hlo_test_lib` only needs to be compiled once instead of for every + target + * automated tools can easily create reproducer CLs by appending one line + to the `xla/tests/fuzz` BUILD file like `hlo_test(name = ..., hlo = ...)`. + * plays nicely with `xla_test`, so we have easy testing against all + platforms and a `test_suite` generated for each HLO which includes tests + against all platforms. This is particularly useful for pruning the set of + HLOs, as we can prune against `test_suites` representing all the tests + associated with a particular HLO, rather than individual targets. + + In the future it may make sense to reformulate this to use `hlo-opt` and + `run_hlo_module` or similar to accomplish the same thing. + + Args: + name: + The name of the macro. This really could be generated from `hlo`, but + tools like build_cleaner assume that all macros have a name attribute. + hlo: + The hlo to test. + **kwargs: + Additional arguments passed to `xla_test`. 
+ """ + xla_test( + name = name, + srcs = [], + env = {"HLO_PATH": "$(location {})".format(hlo)}, + data = [hlo], + real_hardware_only = True, + deps = [ + "//xla/tests/fuzz:hlo_test_lib", + "@local_tsl//tsl/platform:test_main", + ], + **kwargs + ) From ef70a62b4803fafc962bdf5d3696c93a38221cfa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 15:56:51 -0800 Subject: [PATCH 249/391] Fix flatbuffer_exporter custom_option_vector alignment PiperOrigin-RevId: 583508027 --- tensorflow/compiler/mlir/lite/flatbuffer_export.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 44e81f58ad9686..4775016509f8c3 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -1311,7 +1311,7 @@ BufferOffset Translator::BuildCustomOperator( /*custom_options=*/0, tflite::CustomOptionsFormat_FLEXBUFFERS); } if (custom_option_alignment_.has_value()) { - builder_.ForceVectorAlignment(results.size(), sizeof(uint8_t), + builder_.ForceVectorAlignment(custom_option_vector.size(), sizeof(uint8_t), custom_option_alignment_.value()); } auto custom_option_fbs_vector = From 46b3e472d653ff0bc86cde6c9ca34fa26dae3df7 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Fri, 17 Nov 2023 16:05:22 -0800 Subject: [PATCH 250/391] [xla:gpu] Revise gpu-schedule-postprocessing pass to exclude fusion. This is because cuLaunchKernel is safe, even though it could technically block if too many unsubmitted commands are in the push buffer but this doesn't happen in practice. Also rename no_parallel_gpu_op to no_parallel_custom_call. Modify the added routines to use a const parameter when it is appropriate. This is to address the follow up comments for the PR that adds the pass. PiperOrigin-RevId: 583510144 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/xla/service/gpu/backend_configs.proto | 11 +-- .../xla/service/gpu/backend_configs_test.cc | 2 +- .../gpu/gpu_schedule_postprocessing.cc | 81 ++++++++++--------- .../service/gpu/gpu_schedule_postprocessing.h | 14 ++-- .../gpu/gpu_schedule_postprocessing_test.cc | 6 +- 6 files changed, 59 insertions(+), 56 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 8445dec5c78840..a9a6d3d540287b 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3357,6 +3357,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings:string_view", diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index 867f7f8fe2af77..46f51ebe5857bc 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -111,13 +111,14 @@ message BitcastBackendConfig { // Backend config for async collective operations. Note that for is_sync will // be false by default, so even if a backend config is not explicitly attached // to the HLOInstruction, getting the backend_config will yield a default valued -// proto which will have is_sync = false. 
Attribute no_parallel_gpu_op asserts
-// that an asynchronous collective operation does not execute in parallel with
-// other operations in GPU. This attribute will also be false by default, which
-// should lead to conversative runtime behavior.
+// proto which will have is_sync = false. Attribute no_parallel_custom_call
+// asserts that an asynchronous collective operation does not execute in
+// parallel with custom-calls, which can trigger device synchronization. This
+// attribute will also be false by default and should lead to conservative
+// runtime behavior.
 message CollectiveBackendConfig {
   bool is_sync = 1;
-  bool no_parallel_gpu_op = 2;
+  bool no_parallel_custom_call = 2;
 }

 message ReificationCost {
diff --git a/third_party/xla/xla/service/gpu/backend_configs_test.cc b/third_party/xla/xla/service/gpu/backend_configs_test.cc
index d1b32e1abda201..99c86bd4e08d52 100644
--- a/third_party/xla/xla/service/gpu/backend_configs_test.cc
+++ b/third_party/xla/xla/service/gpu/backend_configs_test.cc
@@ -50,7 +50,7 @@ TEST_F(BackendConfigsTest, DefaultCollectiveBackendConfig) {
       ags->backend_config();
   EXPECT_THAT(collective_backend_config.status(), IsOk());
   EXPECT_THAT(collective_backend_config->is_sync(), IsFalse());
-  EXPECT_THAT(collective_backend_config->no_parallel_gpu_op(), IsFalse());
+  EXPECT_THAT(collective_backend_config->no_parallel_custom_call(), IsFalse());
 }

 }  // namespace
diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc
index 603dfe28b94be7..8c66874b378eef 100644
--- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc
+++ b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include
 #include
+#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/string_view.h"
@@ -34,31 +35,32 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 namespace {
-// Maps a computation to a boolean that indicates whether the computation
-// invokes gpu ops directly or indirectly.
-using GpuOpInComputation = absl::flat_hash_map;
-
-// Returns whether the hlo may invoke gpu-ops which are operations that call
-// into CUDA, directly or indirectly. Currently, we only check for custom-calls
-// and fusion, because they are the only gpu-ops that can be parallel with
-// asynchronous collectives operations.
-bool MayInvokeGpuOp(HloInstruction* hlo,
-                    GpuOpInComputation& gpu_op_in_computation) {
-  if (hlo->opcode() == HloOpcode::kCustomCall ||
-      hlo->opcode() == HloOpcode::kFusion) {
+// Maps a computation to a boolean that indicates whether the computation may
+// invoke custom-calls directly or indirectly, which can eventually trigger gpu
+// synchronization.
+using CustomCallInComputation =
+    absl::flat_hash_map;
+
+// Returns whether the hlo may invoke custom-calls which may trigger gpu
+// synchronization. Currently, we only check for custom-calls, because they are
+// the only operations that can be parallel with asynchronous collective
+// operations in an hlo-schedule and may trigger gpu synchronization.
+bool MayInvokeCustomCall( + const HloInstruction* hlo, + const CustomCallInComputation& custom_call_in_computation) { + if (hlo->opcode() == HloOpcode::kCustomCall) { return true; } - return std::any_of(hlo->called_computations().begin(), - hlo->called_computations().end(), - [&](const HloComputation* callee) { - return gpu_op_in_computation.find(callee)->second; - }); + return absl::c_any_of( + hlo->called_computations(), [&](const HloComputation* callee) { + return custom_call_in_computation.find(callee)->second; + }); } // Returns true if this is an asynchronous collective start operation, excluding // P2P operations. -StatusOr IsRelevantAsynchronousStart(HloInstruction* hlo) { +StatusOr IsRelevantAsynchronousStart(const HloInstruction* hlo) { HloOpcode opcode = hlo->opcode(); if (!hlo_query::IsAsyncCollectiveStartOp(opcode, /*include_send_recv=*/false)) { @@ -71,37 +73,35 @@ StatusOr IsRelevantAsynchronousStart(HloInstruction* hlo) { // Returns true if this is a collective done operation, excluding P2P // operations. -StatusOr IsRelevantAsynchronousDone(HloInstruction* hlo) { +StatusOr IsRelevantAsynchronousDone(const HloInstruction* hlo) { HloOpcode opcode = hlo->opcode(); return hlo_query::IsAsyncCollectiveDoneOp(opcode, /*include_send_recv=*/false); } // For a given computation, finds all the asynchronous collective operations -// that aren't parallel with other gpu-op-invoking instructions and sets its -// no_parallel_gpu_op attribute to true. Also records whether the given -// computation may invoke gpu-ops. -StatusOr ProcessComputation(HloSchedule& schedule, - HloComputation* computation, - GpuOpInComputation& gpu_op_in_computation) { +// that aren't parallel with custom-calls and sets its no_parallel_custom_call +// attribute to true. Also records whether the given computation may invoke +// custom-calls. +StatusOr ProcessComputation( + const HloSchedule& schedule, HloComputation* computation, + CustomCallInComputation& custom_call_in_computation) { bool changed = false; - bool has_gpu_op = false; + bool has_custom_call = false; absl::flat_hash_set async_starts; const HloInstructionSequence& sequence = schedule.sequence(computation); // Visit instructions in the sequence. Collect relevant asynchronous // collective start ops. When we see a relevant asynchronous collective done // op, remove the corresponding start op from the collection and set its - // attribute no_parallel_gpu_op to true. When we see a gpu-op, clear the start - // ops from the collection and keep their attribute no_parallel_gpu_op as - // false. + // attribute no_parallel_custom_call to true. When we see a custom-call, clear + // the start ops from the collection and keep their attribute + // no_parallel_custom_call as false. 
   const std::vector all_instructions = sequence.instructions();
-  for (auto instr_it = all_instructions.begin();
-       instr_it != all_instructions.end(); ++instr_it) {
-    HloInstruction* hlo = *instr_it;
-    if (MayInvokeGpuOp(hlo, gpu_op_in_computation)) {
+  for (HloInstruction* hlo : all_instructions) {
+    if (MayInvokeCustomCall(hlo, custom_call_in_computation)) {
       async_starts.clear();
-      has_gpu_op = true;
+      has_custom_call = true;
       continue;
     }
     TF_ASSIGN_OR_RETURN(bool is_async_start, IsRelevantAsynchronousStart(hlo));
@@ -118,7 +118,7 @@ StatusOr ProcessComputation(HloSchedule& schedule,
       TF_ASSIGN_OR_RETURN(
           CollectiveBackendConfig collective_backend_config,
           async_start->backend_config());
-      collective_backend_config.set_no_parallel_gpu_op(true);
+      collective_backend_config.set_no_parallel_custom_call(true);
       TF_RETURN_IF_ERROR(
          async_start->set_backend_config(collective_backend_config));
       async_starts.erase(async_start);
@@ -126,7 +126,7 @@ StatusOr ProcessComputation(HloSchedule& schedule,
     }
   }

-  gpu_op_in_computation[computation] = has_gpu_op;
+  custom_call_in_computation[computation] = has_custom_call;
   return changed;
 }

@@ -138,7 +138,7 @@ StatusOr GpuSchedulePostprocessing::Run(
   if (!module->has_schedule()) return false;
   HloSchedule& schedule = module->schedule();
   bool changed = false;
-  GpuOpInComputation gpu_op_in_computation;
+  CustomCallInComputation custom_call_in_computation;

   // We visit computations in the order of callees to callers, as information is
   // propagated from calles to callers.
@@ -148,12 +148,13 @@ StatusOr GpuSchedulePostprocessing::Run(
        ++iter) {
     HloComputation* computation = *iter;
     if (computation->IsFusionComputation()) {
-      gpu_op_in_computation[computation] = false;
+      custom_call_in_computation[computation] = false;
       continue;
     }

-    TF_ASSIGN_OR_RETURN(bool result, ProcessComputation(schedule, computation,
-                                                        gpu_op_in_computation));
+    TF_ASSIGN_OR_RETURN(
+        bool result,
+        ProcessComputation(schedule, computation, custom_call_in_computation));
     changed |= result;
   }
diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h
index 578dffabc146f2..521d74e617d100 100644
--- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h
+++ b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h
@@ -22,13 +22,13 @@ namespace xla {
 namespace gpu {

 // Amends a schedule result with the needed information to support a runtime
-// implementation. Currently, this pass refines attribute no_parallel_gpu_op
-// for asynchronous collective operations to support runtime optimization, such
-// as skipping rendezvous of all participating threads for NCCL collective
-// operations. In particular, it sets the attribute value for Collective-start
-// operations with is_sync=false; it also keeps the attribute value untouch for
-// the operations with is_sync=true and for P2P operations, assumming the
-// runtime won't use those values.
+// implementation. Currently, this pass refines attribute
+// no_parallel_custom_call for asynchronous collective operations to support
+// runtime optimization, such as skipping rendezvous of all participating
+// threads for NCCL collective operations. In particular, it sets the attribute
+// value for Collective-start operations with is_sync=false; it also keeps the
+// attribute value untouched for the operations with is_sync=true and for P2P
+// operations, assuming the runtime won't use those values.
// class GpuSchedulePostprocessing : public HloModulePass { public: diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc index e803e7843eb007..b9ef17de14c825 100644 --- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc @@ -96,7 +96,7 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsChanged) { HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); TF_ASSERT_OK_AND_ASSIGN(CollectiveBackendConfig collective_backend_config, start->backend_config()); - EXPECT_TRUE(collective_backend_config.no_parallel_gpu_op()); + EXPECT_TRUE(collective_backend_config.no_parallel_custom_call()); } TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { @@ -120,7 +120,7 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); TF_ASSERT_OK_AND_ASSIGN(CollectiveBackendConfig collective_backend_config, start->backend_config()); - EXPECT_FALSE(collective_backend_config.no_parallel_gpu_op()); + EXPECT_FALSE(collective_backend_config.no_parallel_custom_call()); } TEST_F(GpuSchedulePostprocessingTest, @@ -149,7 +149,7 @@ TEST_F(GpuSchedulePostprocessingTest, HloInstruction* start = FindInstruction(module.get(), "all-gather-start"); TF_ASSERT_OK_AND_ASSIGN(CollectiveBackendConfig collective_backend_config, start->backend_config()); - EXPECT_FALSE(collective_backend_config.no_parallel_gpu_op()); + EXPECT_FALSE(collective_backend_config.no_parallel_custom_call()); } } // namespace From d10078685a7f0c330f8394ad9b55721f6894506c Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Fri, 17 Nov 2023 16:08:07 -0800 Subject: [PATCH 251/391] Removes `QuantizationOptions` dependencies from `AddQuantizePtqPostCalibrationStablehloPasses`. 
PiperOrigin-RevId: 583510761 --- .../quantization/tensorflow/python/quantize_model.cc | 5 ++--- .../mlir/quantization/tensorflow/quantize_passes.cc | 9 +++------ .../mlir/quantization/tensorflow/quantize_passes.h | 2 +- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index 10b9bb9ad7296d..72cf30186d3fb3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -636,10 +636,9 @@ absl::StatusOr QuantizePtqModelPostCalibration( TF_QUANT_RETURN_IF_ERROR(RunPasses( /*name=*/kTfQuantPtqPostCalibrationStepStableHloName, /*add_passes_func=*/ - [&quantization_options](mlir::PassManager &pm) { + [](mlir::PassManager &pm) { AddQuantizePtqPostCalibrationStablehloPasses( - pm, quantization_options, - kTfQuantPtqPostCalibrationStepStableHloName); + pm, kTfQuantPtqPostCalibrationStepStableHloName); }, context, *module_ref)); } else { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc index 497ae346bfb5eb..36ebd0d9ccc297 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc @@ -50,7 +50,7 @@ void AddStablehloQuantToIntPasses(mlir::PassManager &pm) { } void AddStaticRangeQuantizationPass( - mlir::PassManager &pm, const QuantizationOptions &quantization_options, + mlir::PassManager &pm, std::optional mlir_dump_file_prefix) { pm.addPass(mlir::quant::stablehlo::createQuantizeCompositeFunctionsPass()); } @@ -231,7 +231,6 @@ void AddQuantizePtqPostCalibrationPasses( } // StableHLO Quantization passes that are ran if StableHLO opset is selected. -// TODO: b/298581932 - Add tests for passes below once migration is complete. void AddQuantizePtqPreCalibrationStablehloPasses( mlir::PassManager &pm, const QuantizationOptions &quantization_options) { pm.addPass( @@ -246,9 +245,8 @@ void AddQuantizePtqPreCalibrationStablehloPasses( AddCallModuleSerializationPasses(pm); } -// TODO: b/298581932 - Migrate and add passes below. void AddQuantizePtqPostCalibrationStablehloPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options, + mlir::PassManager &pm, std::optional mlir_dump_file_prefix) { // Deserializes the StableHLO module embedded in tf.XlaCallModule and lifts // the StableHLO functions to the top level module. 
This is needed for @@ -257,8 +255,7 @@ void AddQuantizePtqPostCalibrationStablehloPasses( pm.addPass(mlir::quant::stablehlo::createRestoreFunctionNamePass()); pm.addNestedPass( mlir::quant::CreateConvertCustomAggregationOpToQuantStatsPass()); - AddStaticRangeQuantizationPass(pm, quantization_options, - mlir_dump_file_prefix); + AddStaticRangeQuantizationPass(pm, mlir_dump_file_prefix); AddStablehloQuantToIntPasses(pm); AddCallModuleSerializationPasses(pm); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h index 3aef23b5667d51..2ef01587da0f09 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h @@ -54,7 +54,7 @@ void AddQuantizePtqPreCalibrationStablehloPasses( mlir::PassManager &pm, const QuantizationOptions &quantization_options); void AddQuantizePtqPostCalibrationStablehloPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options, + mlir::PassManager &pm, std::optional mlir_dump_file_prefix = std::nullopt); From 8b544872ad83dbc49cde9b25c3d06d478a6dd568 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 16:17:49 -0800 Subject: [PATCH 252/391] Add UnTileShape() and UnTileLeafShape() in hlo_sharding_util. PiperOrigin-RevId: 583512870 --- .../xla/xla/hlo/utils/hlo_sharding_util.cc | 32 +++++++++++++++++++ .../xla/xla/hlo/utils/hlo_sharding_util.h | 8 +++++ 2 files changed, 40 insertions(+) diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index 7cdbf2f16c03ac..bee53152f046d8 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -46,6 +46,7 @@ limitations under the License. #include "xla/protobuf_util.h" #include "xla/service/call_graph.h" #include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -2758,5 +2759,36 @@ std::optional GetOutputSharding( return instruction->sharding(); } +Shape UntileShape(const HloSharding& sharding, const Shape& shape) { + if (!sharding.IsTuple()) { + return UntileLeafShape(sharding, shape); + } + Shape result_shape = shape; + ShapeUtil::ForEachMutableSubshape( + &result_shape, + [&shape, &sharding](Shape* subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(shape, index)) { + return; + } + const HloSharding& subshape_sharding = + sharding.GetSubSharding(shape, index); + *subshape = UntileLeafShape(subshape_sharding, *subshape); + }); + + return result_shape; +} + +Shape UntileLeafShape(const HloSharding& sharding, const Shape& shape) { + if (sharding.IsTileMaximal() || sharding.IsManual() || sharding.IsUnknown()) { + return shape; + } + Shape result_shape = shape; + for (int64_t i = 0; i < sharding.TiledDataRank(); ++i) { + result_shape.set_dimensions( + i, shape.dimensions(i) * sharding.tile_assignment().dim(i)); + } + return result_shape; +} + } // namespace hlo_sharding_util } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h index d953f0cc532334..c158f15891d8d8 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h @@ -31,6 +31,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/service/call_graph.h" +#include "xla/shape.h" namespace xla { namespace hlo_sharding_util { @@ -467,6 +468,13 @@ std::optional GetGatherScatterBatchParallelDims( // special handling like Outfeed and this function takes care of those. std::optional GetOutputSharding(const HloInstruction* instruction); +// Returns the un-tiled shape. +Shape UntileShape(const HloSharding& sharding, const Shape& shape); + +// Returns the un-tiled shape. +// REQUIRES: !sharding.IsTuple() +Shape UntileLeafShape(const HloSharding& sharding, const Shape& shape); + } // namespace hlo_sharding_util } // namespace xla From aad5343c6504be8cf3d5deba3f39cfe03dde1ebe Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Fri, 17 Nov 2023 16:26:21 -0800 Subject: [PATCH 253/391] Make IsCallToPureFunction analysis recursive, and aware of "If" and "Assert". PiperOrigin-RevId: 583514777 --- .../analysis/side_effect_analysis.cc | 30 ++++++- .../tests/side-effect-analysis-test.mlir | 80 +++++++++++++++++++ 2 files changed, 108 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index 744fa37a914de0..b0d730898316d5 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -346,6 +347,7 @@ class OpSideEffectCollector { } bool IsCallToPureFunction(Operation* callOp) const; + bool IsPureFunction(func::FuncOp func_op) const; private: // Adds op-based side effects from all ops in `region` to `op` side effects. @@ -510,18 +512,42 @@ bool OpSideEffectCollector::IsCallToPureFunction(Operation* callOp) const { return false; // not a call func::FuncOp func_op = dyn_cast(call.resolveCallable( &symbol_table_collection_)); + return IsPureFunction(func_op); +} + +bool OpSideEffectCollector::IsPureFunction(func::FuncOp func_op) const { auto it = is_pure_function_.find(func_op); if (it == is_pure_function_.end()) { bool is_pure = true; + is_pure_function_[func_op] = is_pure; // prevent infinite recursion func_op->walk([&](Operation* op) { - if (op == func_op) return WalkResult::advance(); + if (op == func_op) { + return WalkResult::advance(); + } + // AssertOp is not, technically, pure. However, we treat functions + // that contain an assert as pure, so that graphs with and without + // assert don't have different side effect semantics. Also see + // b/309824992 for the challenges associated with improving the side + // effect modelling of Assert on the op level. 
+ if (llvm::isa(op)) { + return WalkResult::advance(); + } + if (auto if_op = llvm::dyn_cast(op)) { + if (IsPureFunction(if_op.then_function()) && + IsPureFunction(if_op.else_function())) { + return WalkResult::advance(); + } + } + if (IsCallToPureFunction(op)) { + return WalkResult::advance(); + } if (TensorFlowDialect::CanHaveSideEffects(op)) { is_pure = false; return WalkResult::interrupt(); } return WalkResult::advance(); }); - is_pure_function_.insert({func_op, is_pure}); + is_pure_function_[func_op] = is_pure; } return is_pure_function_[func_op]; } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir index c57e07b5e3f74e..7246cdb4513280 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -2964,3 +2964,83 @@ func.func @global_iter_id_effect() -> () { // expected-remark@above {{ID: 6}} // expected-remark@above {{Sinks: {}}} } + +// ----- + +func.func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // expected-remark@above {{ID: 2}} + %sum = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // expected-remark@above {{ID: 0}} + func.return %sum : tensor<1xf32> + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Sinks: {}}} +} + +func.func @intermediary(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // expected-remark@above {{ID: 2}} + %result = "tf.StatefulPartitionedCall"(%arg0, %arg1) {config="", config_proto="", executor_type="", f=@add} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // expected-remark@above {{ID: 0}} + func.return %result : tensor<1xf32> + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Sinks: {}}} +} + +// CHECK-LABEL: func @call_pure_function +func.func @call_pure_function(%arg0: tensor) -> tensor { + // expected-remark@above {{ID: 5}} + %one = "tf.Const"() { value = dense<1.0> : tensor<1xf32> } : () -> tensor<1xf32> + // expected-remark@above {{ID: 0}} + %r1 = "tf.ReadVariableOp"(%arg0) : (tensor) -> tensor<1xf32> + // expected-remark@above {{ID: 1}} + %two = "tf.StatefulPartitionedCall"(%one, %one) {config="", config_proto="", executor_type="", f=@intermediary} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // expected-remark@above {{ID: 2}} + %r2 = "tf.ReadVariableOp"(%arg0) : (tensor) -> tensor<1xf32> + // expected-remark@above {{ID: 3}} + func.return %arg0 : tensor + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Sinks: {1,3}}} +} + +// ----- + +func.func @assert(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor { + // expected-remark@above {{ID: 3}} + %cond = builtin.unrealized_conversion_cast to tensor + // expected-remark@above {{ID: 0}} + "tf.Assert"(%cond, %arg1) {device = "/job:localhost/replica:0/task:0/device:CPU:0", summarize = 3 : i64} : (tensor, tensor<1xf32>) -> () + // expected-remark@above {{ID: 1}} + func.return %cond : tensor + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Sinks: {1}}} +} + +func.func @intermediary(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // expected-remark@above {{ID: 3}} + %cond = builtin.unrealized_conversion_cast to tensor + // expected-remark@above {{ID: 0}} + %sum = "tf.If"(%cond, %arg0, %arg1) { + then_branch = @assert, + else_branch = @assert, + is_stateless = false + } : (tensor, tensor<1xf32>, tensor<1xf32>) -> tensor + // 
expected-remark@-5 {{ID: 1}} + func.return %arg0 : tensor<1xf32> + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Sinks: {1}}} +} + +// CHECK-LABEL: func @assert_within_if +func.func @assert_within_if(%arg0: tensor) -> tensor { + // expected-remark@above {{ID: 5}} + %one = "tf.Const"() { value = dense<1.0> : tensor<1xf32> } : () -> tensor<1xf32> + // expected-remark@above {{ID: 0}} + %r1 = "tf.ReadVariableOp"(%arg0) : (tensor) -> tensor<1xf32> + // expected-remark@above {{ID: 1}} + %result = "tf.StatefulPartitionedCall"(%one, %one) {config="", config_proto="", executor_type="", f=@intermediary} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // expected-remark@above {{ID: 2}} + %r2 = "tf.ReadVariableOp"(%arg0) : (tensor) -> tensor<1xf32> + // expected-remark@above {{ID: 3}} + func.return %arg0 : tensor + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Sinks: {1,3}}} +} From 2b73c0203f83bad5878e0d225642eeec21ee45b0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 16:29:04 -0800 Subject: [PATCH 254/391] Updates the test_config_proto visibility so that tensorflow_serving can have access to test_config_proto_cc_impl. PiperOrigin-RevId: 583515351 --- tensorflow/core/tfrt/graph_executor/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index 38653fe09f1c0a..e301153261735d 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -205,6 +205,7 @@ tf_proto_library( name = "test_config_proto", testonly = True, srcs = ["test_config.proto"], + visibility = ["//visibility:public"], ) tf_cc_test( From 8c0b7d02793b5b5917dede442da4e1d3eb1e349f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 17:06:23 -0800 Subject: [PATCH 255/391] Narrow down the internal visibility of test_config_proto while keep the OSS visibility to public. PiperOrigin-RevId: 583522726 --- tensorflow/core/tfrt/graph_executor/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index e301153261735d..d99f3519d05206 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -205,7 +205,10 @@ tf_proto_library( name = "test_config_proto", testonly = True, srcs = ["test_config.proto"], - visibility = ["//visibility:public"], + visibility = if_google( + [":friends"], + ["//visibility:public"], + ), ) tf_cc_test( From dd0af620170ee30b8b1051773a103027d6967a3f Mon Sep 17 00:00:00 2001 From: Marcello Maggioni Date: Fri, 17 Nov 2023 17:25:43 -0800 Subject: [PATCH 256/391] [XLA] Fix algebraic simplifier to not mess with async DUS/DS. PiperOrigin-RevId: 583526034 --- .../xla/xla/service/algebraic_simplifier.cc | 8 ++++ .../xla/service/algebraic_simplifier_test.cc | 45 +++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index 166d8ec7de046f..57f46a51316f7b 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -6102,6 +6102,10 @@ Status AlgebraicSimplifierVisitor::HandleRsqrt(HloInstruction* rsqrt) { Status AlgebraicSimplifierVisitor::HandleDynamicSlice( HloInstruction* dynamic_slice) { + // Skip optimizations for async dynamic-slices. 
+ if (dynamic_slice->parent()->IsAsyncComputation()) { + return OkStatus(); + } auto operand = dynamic_slice->mutable_operand(0); if (ShapeUtil::IsScalar(dynamic_slice->shape())) { return ReplaceInstruction(dynamic_slice, operand); @@ -6358,6 +6362,10 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { + // Skip optimizations for async dynamic update slices + if (dynamic_update_slice->parent()->IsAsyncComputation()) { + return OkStatus(); + } // Rewriting DynamicUpdateSlice when it matches // dynamic_update_slice(broadcast(constant),data,constant_index0,...) // to a Pad(x, constant) diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index fbbfff2a584049..a156077655ab1e 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -9985,6 +9985,51 @@ TEST_F(AlgebraicSimplifierTest, TransposeOfBroadcastSkipped) { EXPECT_FALSE(changed); } +TEST_F(AlgebraicSimplifierTest, DontSinkInstructionsInDUSAsyncComputation) { + const char* kModuleStr = R"( + HloModule m + test { + %param_0 = f32[1]{0} parameter(0) + %param_1 = f32[10]{0} parameter(1) + %constant_1 = s32[] constant(0) + %dynamic-update-slice-start = ((f32[10]{0}, f32[1]{0}, s32[]), + f32[10]{0}, u32[]) dynamic-update-slice-start(f32[10]{0} %param_1, + f32[1]{0} %param_0, s32[] %constant_1) + ROOT %dynamic-update-slice-done = + f32[10]{0} dynamic-update-slice-done(((f32[10]{0}, f32[1]{0}, s32[]), + f32[10]{0}, u32[]) %dynamic-update-slice-start) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + bool changed = + RunHloPass(AlgebraicSimplifier(default_options_), m.get()).value(); + SCOPED_TRACE(m->ToString()); + EXPECT_FALSE(changed); +} + +TEST_F(AlgebraicSimplifierTest, DontSinkInstructionsInDSAsyncComputation) { + const char* kModuleStr = R"( + HloModule m + test { + %param_0 = f32[10]{0} parameter(0) + %constant_1 = s32[] constant(0) + %dynamic-slice-start = ((f32[10]{0}, s32[]), f32[1]{0}, u32[]) + dynamic-slice-start(f32[10]{0} %param_0, s32[] %constant_1), + dynamic_slice_sizes={1} + ROOT %dynamic-slice-done = f32[1]{0} + dynamic-slice-done(((f32[10]{0}, s32[]), f32[1]{0}, u32[]) + %dynamic-slice-start), dynamic_slice_sizes={1} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + bool changed = + RunHloPass(AlgebraicSimplifier(default_options_), m.get()).value(); + SCOPED_TRACE(m->ToString()); + EXPECT_FALSE(changed); +} + class AlgebraicSimplifierUpcastDowncastTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface< From 2b89e5b50b88c1485e19506c9e45b009ce15150c Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Fri, 17 Nov 2023 18:09:38 -0800 Subject: [PATCH 257/391] [xla:gpu] Use xla_gpu_enable_command_buffer flag to control command buffer scheduling pass #6528 PiperOrigin-RevId: 583533980 --- .../service/gpu/command_buffer_scheduling.cc | 35 +++++++++++++------ .../service/gpu/command_buffer_scheduling.h | 4 ++- .../gpu/command_buffer_scheduling_test.cc | 5 ++- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc index 7cf4d4038bd9fa..8c6036c70feae7 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc +++ 
b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/command_buffer_scheduling.h" #include +#include #include #include @@ -44,11 +45,6 @@ namespace { // category. // 2. Intermediates: Instructions that produce intermediate values that are // used by commands. -bool IsCommand(const HloInstruction* inst) { - // TODO(anlunx): Add support for conditionals and while loops. - return inst->opcode() == HloOpcode::kFusion; -} - bool IsIntermediate(const HloInstruction* inst) { switch (inst->opcode()) { case HloOpcode::kConstant: @@ -79,7 +75,8 @@ constexpr int kMinNumCommands = 2; // subsequences that will be extracted as command buffers. std::vector CommandBufferScheduling::CollectCommandBufferSequences( - const HloInstructionSequence inst_sequence) { + const HloInstructionSequence inst_sequence, + std::function is_command) { struct Accumulator { std::vector sequences; HloInstructionSequence current_seq; @@ -96,10 +93,10 @@ CommandBufferScheduling::CollectCommandBufferSequences( return acc; }; - auto process_instruction = [&start_new_sequence]( + auto process_instruction = [&start_new_sequence, &is_command]( Accumulator* acc, HloInstruction* inst) -> Accumulator* { - if (IsCommand(inst)) { + if (is_command(inst)) { acc->current_seq.push_back(inst); acc->num_commands_in_current_seq += 1; return acc; @@ -239,8 +236,26 @@ StatusOr CommandBufferScheduling::Run( } HloComputation* entry = module->entry_computation(); MoveParametersToFront(entry); - std::vector sequences = - CollectCommandBufferSequences(module->schedule().sequence(entry)); + + absl::flat_hash_set command_types; + for (auto cmd_type_num : + module->config().debug_options().xla_gpu_enable_command_buffer()) { + DebugOptions::CommandBufferCmdType cmd_type = + static_cast(cmd_type_num); + command_types.insert(cmd_type); + } + + std::function is_command = + [&command_types = + std::as_const(command_types)](const HloInstruction* inst) { + if (inst->opcode() == HloOpcode::kFusion) { + if (command_types.contains(DebugOptions::FUSION)) return true; + } + return false; + }; + + std::vector sequences = CollectCommandBufferSequences( + module->schedule().sequence(entry), is_command); for (const HloInstructionSequence& seq : sequences) { TF_ASSIGN_OR_RETURN(BuildCommandBufferResult result, diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h b/third_party/xla/xla/service/gpu/command_buffer_scheduling.h index 601e72c9f984bf..ad0844a207c3c7 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h +++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling.h @@ -16,6 +16,7 @@ limitations under the License. 
#define XLA_SERVICE_GPU_COMMAND_BUFFER_SCHEDULING_H_ #include +#include #include #include @@ -79,7 +80,8 @@ class CommandBufferScheduling : public HloModulePass { const absl::flat_hash_set& execution_threads) override; static std::vector CollectCommandBufferSequences( - HloInstructionSequence inst_sequence); + HloInstructionSequence inst_sequence, + std::function is_command); static void MoveParametersToFront(HloComputation* computation); struct BuildCommandBufferResult { diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc index 2a556b6f0ef2bf..aa63b7e40d2c25 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc @@ -222,7 +222,10 @@ TEST_F(CommandBufferSchedulingTest, CollectCommandBufferSequence) { EXPECT_EQ(seq.size(), 10); std::vector command_buffer_sequences = - CommandBufferScheduling::CollectCommandBufferSequences(seq); + CommandBufferScheduling::CollectCommandBufferSequences( + seq, [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kFusion; + }); EXPECT_EQ(command_buffer_sequences.size(), 2); std::vector seq_0 = From 7ad824fa6fd956aee5fa4261d97392251b61e6a0 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Fri, 17 Nov 2023 21:20:18 -0800 Subject: [PATCH 258/391] [xla:gpu] Add no_parallel_custom_call to LMHLO_GPU async collective operations. Also extend XLA HLO to LMHLO conversion to set this flag based on the backend config attached to async collectives in HLO. PiperOrigin-RevId: 583558972 --- .../xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td | 6 ++++-- .../tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir | 6 ++++-- .../mhlo_to_lhlo_with_xla.cc | 16 ++++++++++++++++ .../tests/hlo_text_to_lhlo_no_opt.hlotxt | 4 +++- 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td index e56d2964d767ac..b2e9b4dff8f450 100644 --- a/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td +++ b/third_party/xla/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td @@ -244,7 +244,8 @@ class LHLOGPU_AsyncCollectiveCommunicationOp traits = [ DefaultValuedOptionalAttr:$constrain_layout, OptionalAttr:$channel_id, DefaultValuedOptionalAttr:$use_global_device_ids, - BoolAttr:$is_sync + BoolAttr:$is_sync, + BoolAttr:$no_parallel_custom_call ); } @@ -271,7 +272,8 @@ def LHLOGPU_CollectivePermuteStartOp : Arg:$output, I64ElementsAttr:$source_target_pairs, OptionalAttr:$channel_id, - BoolAttr:$is_sync + BoolAttr:$is_sync, + BoolAttr:$no_parallel_custom_call ); } diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir index ca4b2c0be2a117..2370f762db4398 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir @@ -228,7 +228,8 @@ func.func @ag_start(%arg : memref<10x10xf32>, %out: memref<20x10xf32>) { { replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, all_gather_dimension = 0, - is_sync = false + is_sync = false, + no_parallel_custom_call = false } : (memref<10x10xf32>, memref<20x10xf32>) -> (!mhlo.token) func.return @@ -241,7 +242,8 @@ func.func @ag_start_mixed(%arg0 : memref<10x10xf32>, %arg1 : memref<10x10xf16>, { replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, all_gather_dimension 
= 0, - is_sync = true + is_sync = true, + no_parallel_custom_call = true } : (memref<10x10xf32>, memref<10x10xf16>, memref<20x10xf32>, memref<20x10xf16>) -> (!mhlo.token) func.return diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc index 1a340c51cb326e..86b741ca76b16b 100644 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc @@ -119,6 +119,12 @@ bool IsSyncCollective(const HloInstruction* instr) { return backend_config.is_sync(); } +bool NoParallelCustomCallCollective(const HloInstruction* instr) { + auto backend_config = + instr->backend_config().value(); + return backend_config.no_parallel_custom_call(); +} + // Convert the MLIR `module` from HLO dialect to LHLO dialect using XLA for the // given platform. tsl::Status ConvertHloToLmhlo(std::unique_ptr hlo_module, @@ -1693,6 +1699,8 @@ LhloDialectEmitter::EmitAllToAllStartOp(const xla::HloInstruction* instr) { builder_.getI64IntegerAttr(*all_to_all->split_dimension())); } all_to_all_start_op.setIsSync(IsSyncCollective(instr)); + all_to_all_start_op.setNoParallelCustomCall( + NoParallelCustomCallCollective(instr)); auto [_, was_inserted] = ret_tokens_.insert({instr, all_to_all_start_op.getToken()}); @@ -1728,6 +1736,8 @@ LhloDialectEmitter::EmitAllGatherStartOp(const HloInstruction* instr) { all_gather_start_op.setAllGatherDimensionAttr( builder_.getI64IntegerAttr(all_gather->all_gather_dimension())); all_gather_start_op.setIsSync(IsSyncCollective(instr)); + all_gather_start_op.setNoParallelCustomCall( + NoParallelCustomCallCollective(instr)); auto [_, was_inserted] = ret_tokens_.insert({instr, all_gather_start_op.getToken()}); TF_RET_CHECK(was_inserted) << "all-gather-start already lowered"; @@ -1759,6 +1769,8 @@ LhloDialectEmitter::EmitAllReduceStartOp(const HloInstruction* instr) { all_reduce_start_op.setUseGlobalDeviceIdsAttr( builder_.getBoolAttr(all_reduce->use_global_device_ids())); all_reduce_start_op.setIsSync(IsSyncCollective(instr)); + all_reduce_start_op.setNoParallelCustomCall( + NoParallelCustomCallCollective(instr)); TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( *instr->called_computations()[0], symbol_table_, @@ -1833,6 +1845,8 @@ LhloDialectEmitter::EmitReduceScatterStartOp(const xla::HloInstruction* instr) { reduce_scatter_start_op.setScatterDimensionAttr( builder_.getI64IntegerAttr(reduce_scatter->scatter_dimension())); reduce_scatter_start_op.setIsSync(IsSyncCollective(instr)); + reduce_scatter_start_op.setNoParallelCustomCall( + NoParallelCustomCallCollective(instr)); TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsRegion( *reduce_scatter->to_apply(), symbol_table_, &reduce_scatter_start_op.getComputation(), &builder_)); @@ -1871,6 +1885,8 @@ LhloDialectEmitter::EmitCollectivePermuteStartOp(const HloInstruction* instr) { permute_start_op->setAttr(source_target_pairs_attr.getName(), source_target_pairs_attr.getValue()); permute_start_op.setIsSync(IsSyncCollective(instr)); + permute_start_op.setNoParallelCustomCall( + NoParallelCustomCallCollective(instr)); auto [_, was_inserted] = ret_tokens_.insert({instr, permute_start_op.getToken()}); diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt index a1722f6d03326f..6af4c9df2f4da0 100644 
--- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt +++ b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/hlo_text_to_lhlo_no_opt.hlotxt @@ -705,6 +705,7 @@ HloModule TestAllGatherAsyncWithSyncFlagFalse // CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) < // CHECK-SAME: all_gather_dimension = 1 : i64 // CHECK-SAME: is_sync = false +// CHECK-SAME: no_parallel_custom_call = false // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> // CHECK-SAME: use_global_device_ids = false // CHECK ""lmhlo_gpu.all_gather_done"(%[[TOKEN]]) @@ -723,12 +724,13 @@ HloModule TestAllGatherAsyncWithSyncFlagTrue // CHECK: %[[TOKEN:.*]] = "lmhlo_gpu.all_gather_start"(%{{.*}}, %{{.*}}) < // CHECK-SAME: all_gather_dimension = 1 : i64 // CHECK-SAME: is_sync = true +// CHECK-SAME: no_parallel_custom_call = true // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> // CHECK-SAME: use_global_device_ids = false // CHECK ""lmhlo_gpu.all_gather_done"(%[[TOKEN]]) ENTRY main { param0 = f32[10,20] parameter(0) ags = (f32[10,20], f32[10,80]) all-gather-start(param0), replica_groups={{0,1,2,3}}, - dimensions={1}, backend_config="{\"is_sync\":true}" + dimensions={1}, backend_config="{\"is_sync\":true, \"no_parallel_custom_call\":true}" ROOT ag = f32[10,80] all-gather-done(ags) } From 428feb029229fe7f3f1f872d69c36dafb131e623 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 Nov 2023 22:44:00 -0800 Subject: [PATCH 259/391] Internal Code Change PiperOrigin-RevId: 583568447 --- .../transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc index ab9a56e3b6db8e..3e3e8db504f1da 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc @@ -50,7 +50,6 @@ namespace { using mlir::DialectRegistry; using mlir::MLIRContext; using mlir::ModuleOp; -using mlir::OpPassManager; using mlir::OwningOpRef; using mlir::func::FuncOp; using ::tensorflow::monitoring::testing::CellReader; From 7cb20d0012529ff222042ac69d3cf05b48109846 Mon Sep 17 00:00:00 2001 From: Robert David Date: Fri, 17 Nov 2023 23:21:26 -0800 Subject: [PATCH 260/391] Store the data size instead of the allocation size. This allows a number of small improvements: - Avoid allocating memory when the required size is 0 bytes. - When resizing, copy only the "useful" data, not the whole allocation (also simplify the copy logic). 
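A minimal sketch of the buffer bookkeeping this change moves to, assuming a simplified stand-in class (no alignment adjustment and no profiler hooks; the name SimpleResizableBuffer is invented for illustration — the real logic lives in tflite::ResizableAlignedBuffer in the diff below):

// Illustrative sketch only: track the requested data size rather than the
// allocation size, skip allocation for zero-byte requests, and copy only the
// live data when growing.
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <memory>

class SimpleResizableBuffer {
 public:
  // Returns true if the underlying storage was reallocated.
  bool Resize(size_t new_size) {
    if (new_size <= data_size_) return false;  // Resizing down never reallocates.
    auto new_buffer = std::make_unique<char[]>(new_size);
    if (data_size_ > 0) {
      // Copy only the bytes actually in use, not the whole previous allocation.
      std::memcpy(new_buffer.get(), buffer_.get(),
                  std::min(new_size, data_size_));
    }
    buffer_ = std::move(new_buffer);
    data_size_ = new_size;
    return true;
  }

  size_t data_size() const { return data_size_; }
  char* data() const { return buffer_.get(); }

 private:
  std::unique_ptr<char[]> buffer_;
  size_t data_size_ = 0;  // Size of the data array, NOT of the allocation.
};

The public size reflects what callers asked for; any alignment-padded allocation size can be recomputed on demand, which is why a zero-byte commit no longer needs to allocate.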
PiperOrigin-RevId: 583572792 --- tensorflow/lite/BUILD | 6 ++--- tensorflow/lite/arena_planner_test.cc | 8 +++--- tensorflow/lite/simple_memory_arena.cc | 27 +++++++++------------ tensorflow/lite/simple_memory_arena.h | 12 +++++---- tensorflow/lite/simple_memory_arena_test.cc | 6 +++-- 5 files changed, 30 insertions(+), 29 deletions(-) diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index b70167cce5a671..d1d3d6bb123958 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -254,9 +254,10 @@ cc_test( ":arena_planner_with_profiler", ":builtin_ops", ":graph_info", - "//tensorflow/core:tflite_portable_logging", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core/c:common", - "//tensorflow/lite/testing:util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest_main", ], ) @@ -1051,7 +1052,6 @@ cc_test( deps = [ ":simple_memory_arena", "//tensorflow/lite/core/c:common", - "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/arena_planner_test.cc b/tensorflow/lite/arena_planner_test.cc index 2a434d734f0ec9..7303e24b1636a8 100644 --- a/tensorflow/lite/arena_planner_test.cc +++ b/tensorflow/lite/arena_planner_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -25,11 +26,12 @@ limitations under the License. #include #include -#include "tensorflow/core/platform/logging.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/graph_info.h" -#include "tensorflow/lite/testing/util.h" namespace tflite { @@ -1079,7 +1081,7 @@ TEST_F(ArenaPlannerTest, SimpleProfilerTest) { SetGraph(&graph); Execute(0, graph.nodes().size() - 1); - EXPECT_EQ(gNumAlloc, 2); + EXPECT_EQ(gNumAlloc, 1); EXPECT_EQ(gNumDealloc, 0); Destroy(); EXPECT_EQ(gNumDealloc, 2); diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index 9c6a596ed82d10..7541788b4f2b01 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -44,11 +44,11 @@ T AlignTo(size_t alignment, T offset) { namespace tflite { bool ResizableAlignedBuffer::Resize(size_t new_size) { - const size_t new_allocation_size = RequiredAllocationSize(new_size); - if (new_allocation_size <= allocation_size_) { + if (new_size <= data_size_) { // Skip reallocation when resizing down. return false; } + const size_t new_allocation_size = RequiredAllocationSize(new_size); #ifdef TF_LITE_TENSORFLOW_PROFILER PauseHeapMonitoring(/*pause=*/true); OnTfLiteArenaAlloc(subgraph_index_, reinterpret_cast(this), @@ -57,26 +57,21 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { auto new_buffer = std::unique_ptr(new char[new_allocation_size]); char* new_aligned_ptr = reinterpret_cast( AlignTo(alignment_, reinterpret_cast(new_buffer.get()))); - if (new_size > 0 && allocation_size_ > 0) { + if (new_size > 0 && data_size_ > 0) { // Copy data when both old and new buffers are bigger than 0 bytes. 
- const size_t new_alloc_alignment_adjustment = - new_aligned_ptr - new_buffer.get(); - const size_t old_alloc_alignment_adjustment = aligned_ptr_ - buffer_.get(); - const size_t copy_amount = - std::min(allocation_size_ - old_alloc_alignment_adjustment, - new_allocation_size - new_alloc_alignment_adjustment); + const size_t copy_amount = std::min(new_size, data_size_); std::memcpy(new_aligned_ptr, aligned_ptr_, copy_amount); } buffer_ = std::move(new_buffer); aligned_ptr_ = new_aligned_ptr; #ifdef TF_LITE_TENSORFLOW_PROFILER - if (allocation_size_ > 0) { + if (data_size_ > 0) { OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), - allocation_size_); + RequiredAllocationSize(data_size_)); } #endif - allocation_size_ = new_allocation_size; + data_size_ = new_size; #ifdef TF_LITE_TENSORFLOW_PROFILER PauseHeapMonitoring(/*pause=*/false); #endif @@ -86,10 +81,10 @@ bool ResizableAlignedBuffer::Resize(size_t new_size) { void ResizableAlignedBuffer::Release() { #ifdef TF_LITE_TENSORFLOW_PROFILER OnTfLiteArenaDealloc(subgraph_index_, reinterpret_cast(this), - allocation_size_); + RequiredAllocationSize(data_size_)); #endif buffer_.reset(); - allocation_size_ = 0; + data_size_ = 0; aligned_ptr_ = nullptr; } @@ -235,8 +230,8 @@ TFLITE_ATTRIBUTE_WEAK void DumpArenaInfo( void SimpleMemoryArena::DumpDebugInfo( const std::string& name, const std::vector& execution_plan) const { - tflite::DumpArenaInfo(name, execution_plan, - underlying_buffer_.GetAllocationSize(), active_allocs_); + tflite::DumpArenaInfo(name, execution_plan, underlying_buffer_.GetDataSize(), + active_allocs_); } } // namespace tflite diff --git a/tensorflow/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h index 05bb52e6a225e4..71622530402c11 100644 --- a/tensorflow/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -58,9 +58,7 @@ struct ArenaAllocWithUsageInterval { class ResizableAlignedBuffer { public: explicit ResizableAlignedBuffer(size_t alignment, int subgraph_index) - : allocation_size_(0), - alignment_(alignment), - subgraph_index_(subgraph_index) { + : data_size_(0), alignment_(alignment), subgraph_index_(subgraph_index) { // To silence unused private member warning, only used with // TF_LITE_TENSORFLOW_PROFILER (void)subgraph_index_; @@ -75,8 +73,12 @@ class ResizableAlignedBuffer { // Pointer to the data array. char* GetPtr() const { return aligned_ptr_; } + // Size of the data array (NOT of the allocation). + size_t GetDataSize() const { return data_size_; } // Size of the allocation (NOT of the data array). - size_t GetAllocationSize() const { return allocation_size_; } + size_t GetAllocationSize() const { + return RequiredAllocationSize(data_size_); + } // Alignment of the data array. size_t GetAlignment() const { return alignment_; } @@ -86,7 +88,7 @@ class ResizableAlignedBuffer { } std::unique_ptr buffer_; - size_t allocation_size_; + size_t data_size_; size_t alignment_; char* aligned_ptr_; diff --git a/tensorflow/lite/simple_memory_arena_test.cc b/tensorflow/lite/simple_memory_arena_test.cc index 5763c087b1a319..af5a4d8ed668ea 100644 --- a/tensorflow/lite/simple_memory_arena_test.cc +++ b/tensorflow/lite/simple_memory_arena_test.cc @@ -56,7 +56,7 @@ TEST(SimpleMemoryArenaTest, BasicZeroAlloc) { char* resolved_ptr = nullptr; bool reallocated = false; ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); - ASSERT_TRUE(reallocated); + EXPECT_FALSE(reallocated); // Don't allocate when zero bytes are needed. 
EXPECT_EQ(resolved_ptr, nullptr); } @@ -349,7 +349,9 @@ TEST_P(BufferAndPlanClearingTest, TestClearBufferAndClearPlan) { // Just committing won't work, allocations need to be made again. ASSERT_EQ(arena.Commit(&reallocated), kTfLiteOk); - ASSERT_TRUE(reallocated); + // There was no allocation, the buffer has 0 bytes (was released) and the high + // water mark is 0 (plan was cleared). + EXPECT_FALSE(reallocated); char* resolved_ptr = nullptr; ASSERT_NE(arena.ResolveAlloc(&context, allocs[0], &resolved_ptr), kTfLiteOk); From a34a5ed71e505c43521558b6d3834d7f4f4a7e20 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 17 Nov 2023 23:41:18 -0800 Subject: [PATCH 261/391] [xla:ffi] Add support for int64_t attributes to XLA FFI PiperOrigin-RevId: 583574834 --- third_party/xla/xla/ffi/api/api.h | 1 + third_party/xla/xla/ffi/api/c_api.h | 5 +++-- third_party/xla/xla/ffi/call_frame.cc | 6 ++++++ third_party/xla/xla/ffi/call_frame.h | 5 +++-- third_party/xla/xla/ffi/ffi_test.cc | 8 ++++++-- third_party/xla/xla/service/gpu/ir_emitter_unnested.cc | 3 +++ .../xla/xla/service/gpu/runtime3/custom_call_thunk.h | 2 +- 7 files changed, 23 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index 46105cf58dd602..22c8178b772e5c 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -556,6 +556,7 @@ class Handler : public Ffi { } XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(int32_t, XLA_FFI_AttrType_I32); +XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(int64_t, XLA_FFI_AttrType_I64); XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(float, XLA_FFI_AttrType_F32); #undef XLA_FFI_REGISTER_SCALAR_ATTR_DECODING diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h index 3e96b8035296a0..f1945a5756ac8a 100644 --- a/third_party/xla/xla/ffi/api/c_api.h +++ b/third_party/xla/xla/ffi/api/c_api.h @@ -153,8 +153,9 @@ typedef enum { typedef enum { XLA_FFI_AttrType_I32 = 1, - XLA_FFI_AttrType_F32 = 2, - XLA_FFI_AttrType_STRING = 3, + XLA_FFI_AttrType_I64 = 2, + XLA_FFI_AttrType_F32 = 3, + XLA_FFI_AttrType_STRING = 4, } XLA_FFI_AttrType; //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/ffi/call_frame.cc b/third_party/xla/xla/ffi/call_frame.cc index e828ab202ade1c..4c79a31e109e50 100644 --- a/third_party/xla/xla/ffi/call_frame.cc +++ b/third_party/xla/xla/ffi/call_frame.cc @@ -47,6 +47,10 @@ void CallFrameBuilder::AddI32Attr(std::string name, int32_t value) { attrs_.try_emplace(std::move(name), value); } +void CallFrameBuilder::AddI64Attr(std::string name, int64_t value) { + attrs_.try_emplace(std::move(name), value); +} + void CallFrameBuilder::AddF32Attr(std::string name, float value) { attrs_.try_emplace(std::move(name), value); } @@ -229,6 +233,8 @@ struct CallFrame::FixupAttribute { struct CallFrame::AttributeType { XLA_FFI_AttrType operator()(int32_t&) { return XLA_FFI_AttrType_I32; } + XLA_FFI_AttrType operator()(int64_t&) { return XLA_FFI_AttrType_I64; } + XLA_FFI_AttrType operator()(float&) { return XLA_FFI_AttrType_F32; } XLA_FFI_AttrType operator()(CallFrame::String&) { diff --git a/third_party/xla/xla/ffi/call_frame.h b/third_party/xla/xla/ffi/call_frame.h index 2f5327f40bca57..a66676c4d191a1 100644 --- a/third_party/xla/xla/ffi/call_frame.h +++ b/third_party/xla/xla/ffi/call_frame.h @@ -42,7 +42,7 @@ class CallFrame; // forward declare class CallFrameBuilder { public: - using Attribute = std::variant; + using Attribute = std::variant; using 
AttributesMap = absl::flat_hash_map; CallFrame Build(); @@ -51,6 +51,7 @@ class CallFrameBuilder { absl::Span dims); void AddI32Attr(std::string name, int32_t value); + void AddI64Attr(std::string name, int64_t value); void AddF32Attr(std::string name, float value); void AddStringAttr(std::string name, std::string value); @@ -91,7 +92,7 @@ class CallFrame { struct NamedAttribute; struct String; - using Attribute = std::variant; + using Attribute = std::variant; CallFrame(absl::Span args, const CallFrameBuilder::AttributesMap& attrs); diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index cd93c3900ca324..1e598dd503a4d6 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -103,14 +103,18 @@ TEST(FfiTest, BuiltinAttributes) { TEST(FfiTest, DecodingErrors) { CallFrameBuilder builder; builder.AddI32Attr("i32", 42); + builder.AddI64Attr("i64", 42); builder.AddF32Attr("f32", 42.0f); builder.AddStringAttr("str", "foo"); auto call_frame = builder.Build(); - auto fn = [](int32_t, float, std::string_view) { return absl::OkStatus(); }; + auto fn = [](int32_t, int64_t, float, std::string_view) { + return absl::OkStatus(); + }; auto handler = Ffi::Bind() .Attr("not_i32_should_fail") + .Attr("not_i64_should_fail") .Attr("f32") .Attr("not_str_should_fail") .To(fn); @@ -119,7 +123,7 @@ TEST(FfiTest, DecodingErrors) { ASSERT_EQ( status.message(), - "Failed to decode all FFI handler operands (bad operands at: 0, 2)"); + "Failed to decode all FFI handler operands (bad operands at: 0, 1, 3)"); } TEST(FfiTest, BufferArgument) { diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 9921870f1b6853..732f5d0fb9bcdd 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -1623,6 +1623,9 @@ static StatusOr BuildAttributesMap( case 32: attributes[name] = static_cast(integer.getInt()); return OkStatus(); + case 64: + attributes[name] = static_cast(integer.getInt()); + return OkStatus(); default: return absl::InvalidArgumentError(absl::StrCat( "Unsupported integer attribute bit width for attribute: ", name)); diff --git a/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.h b/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.h index f544485edd8f87..a00cedcccf8547 100644 --- a/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.h @@ -68,7 +68,7 @@ class CustomCallThunk : public Thunk { Shape shape; }; - using Attribute = std::variant; + using Attribute = std::variant; using AttributesMap = absl::flat_hash_map; CustomCallThunk(ThunkInfo thunk_info, CustomCallTarget call_target, From df0708126668a894cc7f8f90b59e32a05a7190d6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 18 Nov 2023 01:01:59 -0800 Subject: [PATCH 262/391] compat: Update forward compatibility horizon to 2023-11-18 PiperOrigin-RevId: 583586101 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index d2c52044b7901b..f346a0c39e9617 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. 
It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 17) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 18) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 68330c70b25d1a38c575e145a8f9fb5b2811e53e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 18 Nov 2023 01:02:01 -0800 Subject: [PATCH 263/391] Update GraphDef version to 1684. PiperOrigin-RevId: 583586108 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 9c5b17b70cee85..cff03d82d9340e 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1683 // Updated: 2023/11/17 +#define TF_GRAPH_DEF_VERSION 1684 // Updated: 2023/11/18 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 294dff19ff7cb7b52a99a1adab1a8f4e6b2cb032 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Sat, 18 Nov 2023 01:38:07 -0800 Subject: [PATCH 264/391] [xla:ffi] Add API for accessing a variable number of arguments in an XLA FFI handler PiperOrigin-RevId: 583591842 --- third_party/xla/xla/ffi/api/api.h | 69 ++++++++++++++++++++++++++--- third_party/xla/xla/ffi/ffi_test.cc | 21 +++++++++ 2 files changed, 85 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index 22c8178b772e5c..f5f638ce5bdc60 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -147,6 +148,9 @@ XLA_FFI_Error* Ffi::CheckStructSize(const XLA_FFI_Api* api, // Type tags for distinguishing handler argument types //===----------------------------------------------------------------------===// +// Forward declare class defined below. +class RemainingArgs; + namespace internal { // A type tag to distinguish arguments tied to the attributes in the @@ -158,6 +162,10 @@ struct AttrTag {}; template struct CtxTag {}; +// Checks if remaining arguments are in the parameter pack. +template +using HasRemainingArgs = std::disjunction...>; + } // namespace internal //===----------------------------------------------------------------------===// @@ -172,6 +180,12 @@ class Binding { return {std::move(*this)}; } + Binding RemainingArgs() && { + static_assert(!internal::HasRemainingArgs::value, + "remaining arguments can be passed just once"); + return {std::move(*this)}; + } + template Binding> Ctx() && { return {std::move(*this)}; @@ -348,6 +362,42 @@ struct Decode> { } // namespace internal +//===----------------------------------------------------------------------===// +// Type-safe wrapper for accessing a variable number of arguments. 
+//===----------------------------------------------------------------------===// + +class RemainingArgs { + public: + RemainingArgs(const XLA_FFI_Args* args, size_t offset) + : args_(args), offset_(offset) { + assert(offset <= args_->num_args && "illegal remaining args offset"); + } + + size_t size() const { return args_->num_args - offset_; } + bool empty() const { return size() == 0; } + + template + std::optional get(size_t index) const { + size_t idx = offset_ + index; + if (idx >= args_->num_args) { + return std::nullopt; + } + return ArgDecoding::Decode(args_->types[idx], args_->args[idx]); + } + + private: + const XLA_FFI_Args* args_; // not owned + size_t offset_; +}; + +template <> +struct internal::Decode { + static std::optional call(DecodingOffsets& offsets, + DecodingContext& ctx) { + return RemainingArgs(&ctx.call_frame->args, offsets.args); + } +}; + //===----------------------------------------------------------------------===// // Template metaprogramming for decoding handler signature //===----------------------------------------------------------------------===// @@ -445,11 +495,20 @@ class Handler : public Ffi { // Check that the number of passed arguments matches the signature. Each // individual argument decoding will check the actual type. - if (call_frame->args.num_args != kNumArgs) { - return InvalidArgument( - call_frame->api, - StrCat("Wrong number of arguments: expected ", kNumArgs, " but got ", - call_frame->args.num_args)); + if (internal::HasRemainingArgs::value) { + if (call_frame->args.num_args < kNumArgs - 1) { + return InvalidArgument( + call_frame->api, + StrCat("Wrong number of arguments: expected at least ", + kNumArgs - 1, " but got ", call_frame->args.num_args)); + } + } else { + if (call_frame->args.num_args != kNumArgs) { + return InvalidArgument( + call_frame->api, + StrCat("Wrong number of arguments: expected ", kNumArgs, + " but got ", call_frame->args.num_args)); + } } // Check that the number of passed attributes matches the signature. Each diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index 1e598dd503a4d6..ce1cbe4289c9ee 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -147,6 +147,27 @@ TEST(FfiTest, BufferArgument) { TF_ASSERT_OK(status); } +TEST(FfiTest, RemainingArgs) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder; + builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + auto fn = [&](RemainingArgs args) { + EXPECT_EQ(args.size(), 1); + EXPECT_TRUE(args.get(0).has_value()); + EXPECT_FALSE(args.get(1).has_value()); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().Arg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + TEST(FfiTest, RunOptionsCtx) { auto call_frame = CallFrameBuilder().Build(); auto* expected = reinterpret_cast(0x01234567); From 0ac768e43c17837d6deb685d2a47baf6b4db5857 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 18 Nov 2023 02:39:19 -0800 Subject: [PATCH 265/391] Lazily compute `xla::Executable::hlo_proto()` to reduce memory consumption. 
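A hedged sketch of the lazy materialization pattern used here, with placeholder types (FakeProto and LazyProtoHolder are invented for illustration; the real accessor is Executable::hlo_proto(), which fills in HloProto::hlo_module from the owned module on first access, as shown in the diff below):

// Illustrative sketch only: the expensive part of the proto is computed and
// cached on first access instead of eagerly at compile time, so executables
// that are never dumped never pay the memory cost.
#include <memory>
#include <string>
#include <utility>

struct FakeProto {
  bool has_module = false;
  std::string module;  // Stands in for the serialized HLO module.
};

class LazyProtoHolder {
 public:
  explicit LazyProtoHolder(std::string module_source)
      : module_source_(std::move(module_source)),
        proto_(std::make_unique<FakeProto>()) {}

  // Cheap parts of the proto can be stored up front; the expensive part is
  // filled in lazily behind a const accessor (mutating through the pointer,
  // not reseating it, keeps this legal in a const method).
  const FakeProto* proto() const {
    if (!proto_->has_module) {
      proto_->module = "serialized:" + module_source_;
      proto_->has_module = true;
    }
    return proto_.get();
  }

 private:
  std::string module_source_;
  std::unique_ptr<FakeProto> proto_;
};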
PiperOrigin-RevId: 583601536 --- third_party/xla/xla/service/executable.h | 25 ++++++++++++++++--- .../xla/xla/service/gpu/gpu_compiler.cc | 1 - third_party/xla/xla/service/service.cc | 8 +++--- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/executable.h b/third_party/xla/xla/service/executable.h index dc52a64bb81465..dcda9507fc0e3b 100644 --- a/third_party/xla/xla/service/executable.h +++ b/third_party/xla/xla/service/executable.h @@ -379,7 +379,19 @@ class Executable { ? module_config().debug_options().xla_dump_hlo_snapshots() : false; } - HloProto const* hlo_proto() const { return hlo_proto_.get(); } + + HloProto const* hlo_proto() const { + if (!hlo_proto_->has_hlo_module()) { + *hlo_proto_->mutable_hlo_module() = module().ToProto(); + } + return hlo_proto_.get(); + } + + const BufferAssignmentProto* buffer_assignment_proto() const { + return hlo_proto_ != nullptr && hlo_proto_->has_buffer_assignment() + ? &hlo_proto_->buffer_assignment() + : nullptr; + } std::string& debug_info() { return debug_info_; } void set_debug_info(const std::string& debug_info) { @@ -403,9 +415,6 @@ class Executable { // for execution. const std::shared_ptr hlo_module_; - // The serialized HLO proto. Non-null only if dumping snapshots is enabled. - std::unique_ptr hlo_proto_; - // Execution count, used to generate a unique filename for each dumped // execution. int64_t execution_count_ = 0; @@ -415,6 +424,14 @@ class Executable { // Generic debug information as a string. std::string debug_info_; + + private: + // The serialized HLO proto. Non-null only if dumping snapshots is enabled. + // This field may also be only partially set: if only + // hlo_proto_->buffer_assignment is set and hlo_proto_->hlo_module isn't, the + // hlo_module proto will be computed on the fly when requested with + // hlo_proto(). This avoids wasting CPU and memory if the proto isn't needed. + std::unique_ptr hlo_proto_; }; } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index d0a7a1807118da..0df21d7a56cfd3 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1740,7 +1740,6 @@ StatusOr> GpuCompiler::RunBackend( // Dump computation proto state and buffer assignment for // CompiledMemoryAnalysis. auto hlo_proto = std::make_unique(); - *hlo_proto->mutable_hlo_module() = gpu_executable->module().ToProto(); *hlo_proto->mutable_buffer_assignment() = gpu_executable->buffer_assignment()->ToProto(); gpu_executable->set_hlo_proto(std::move(hlo_proto)); diff --git a/third_party/xla/xla/service/service.cc b/third_party/xla/xla/service/service.cc index db251a4cd535d8..48b98e197190d5 100644 --- a/third_party/xla/xla/service/service.cc +++ b/third_party/xla/xla/service/service.cc @@ -774,17 +774,19 @@ StatusOr> Service::BuildExecutable( std::unique_ptr executable, backend->compiler()->RunBackend(std::move(module), executor, options)); - const HloProto* hlo_proto_after_opt = executable->hlo_proto(); + const BufferAssignmentProto* buffer_assignment_proto_after_opt = + executable->buffer_assignment_proto(); // If dumping is enabled RunBackend(...) will emit a hlo_proto in the // executable. This contains the buffer_assignment that is only available // after RunBackend(). If hlo_proto_before_opt is not null, then we replace // its buffer_assignment with the one from after_opt and then store it into // the executable. 
- if (hlo_proto_before_opt != nullptr && hlo_proto_after_opt != nullptr) { + if (hlo_proto_before_opt != nullptr && + buffer_assignment_proto_after_opt != nullptr) { CHECK(DumpingEnabledForHloModule(executable->module())); *hlo_proto_before_opt->mutable_buffer_assignment() = - hlo_proto_after_opt->buffer_assignment(); + std::move(*buffer_assignment_proto_after_opt); executable->set_hlo_proto(std::move(hlo_proto_before_opt)); } return std::move(executable); From 7befb8072ea121546d5f6b4e9837bda84fe43cad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 18 Nov 2023 05:07:54 -0800 Subject: [PATCH 266/391] [XLA] Move `CopyDenseElementsDataToXlaFormat` closer to the ir_emitter - the only place that function is used. This is a pure refactoring in preparation for some memory optimizations. PiperOrigin-RevId: 583623116 --- third_party/xla/xla/service/gpu/BUILD | 23 ++- .../xla/xla/service/gpu/ir_emission_utils.cc | 140 ++++++++++++++++++ .../xla/xla/service/gpu/ir_emission_utils.h | 5 + .../xla/service/gpu/ir_emission_utils_test.cc | 52 +++++++ .../xla/translate/hlo_to_mhlo/hlo_utils.cc | 99 ------------- .../xla/xla/translate/hlo_to_mhlo/hlo_utils.h | 3 - .../translate/hlo_to_mhlo/hlo_utils_test.cc | 33 ----- 7 files changed, 219 insertions(+), 136 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index a9a6d3d540287b..027bab5e98735f 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1069,6 +1069,10 @@ cc_library( ":hlo_traversal", ":target_util", "//xla:shape_util", + "//xla:status", + "//xla:status_macros", + "//xla:statusor", + "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -1081,13 +1085,22 @@ cc_library( "//xla/service/llvm_ir:llvm_util", "//xla/translate/mhlo_to_hlo:location_exporter", "//xla/translate/mhlo_to_hlo:type_to_shape", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:ml_dtypes", ], ) @@ -1095,15 +1108,23 @@ xla_cc_test( name = "ir_emission_utils_test", srcs = ["ir_emission_utils_test.cc"], deps = [ - ":hlo_traversal", ":ir_emission_utils", + "//xla:literal", + "//xla:literal_util", + "//xla:types", "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", "//xla/mlir_hlo:lhlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/translate/hlo_to_mhlo:hlo_utils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index 27f7a2bc6e4e04..3f94385d270e52 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -16,35 +16,75 @@ limitations under the License. 
#include "xla/service/gpu/ir_emission_utils.h" #include +#include +#include #include +#include +#include #include #include #include #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/FPEnv.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsNVPTX.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" #include "llvm/IR/Verifier.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TargetParser/Triple.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/primitive_util.h" +#include "xla/service/buffer_assignment.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/target_util.h" #include "xla/service/hlo_parser.h" #include "xla/service/llvm_ir/buffer_assignment_util.h" #include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/status_macros.h" +#include "xla/statusor.h" #include "xla/translate/mhlo_to_hlo/location_exporter.h" #include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/ml_dtypes.h" namespace xla { namespace gpu { @@ -965,5 +1005,105 @@ bool IsAMDGPU(const llvm::Module* module) { return llvm::Triple(module->getTargetTriple()).isAMDGPU(); } +namespace { +template +void CopyDenseElementsBy(mlir::DenseElementsAttr data, + std::vector* output) { + output->resize(data.getNumElements() * sizeof(T)); + int64_t i = 0; + for (T element : data.getValues()) { + std::memcpy(&(*output)[i], &element, sizeof(T)); + i += sizeof(T); + } +} + +template <> +void CopyDenseElementsBy(mlir::DenseElementsAttr data, + std::vector* output) { + output->resize(CeilOfRatio(data.getNumElements(), int64_t{2})); + absl::Span output_span = + absl::MakeSpan(reinterpret_cast(output->data()), output->size()); + PackInt4(data.getRawData(), output_span); +} +} // namespace + +Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data, + std::vector* output) { + mlir::Type element_type = data.getType().getElementType(); + + 
// TODO(hinsu): Support remaining XLA primitive types. + if (element_type.isInteger(1)) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isInteger(4)) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isInteger(8)) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isInteger(16)) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isInteger(32)) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isInteger(64)) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isFloat8E5M2()) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isFloat8E4M3FN()) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isFloat8E4M3B11FNUZ()) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isFloat8E5M2FNUZ()) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isFloat8E4M3FNUZ()) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isBF16()) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isF16()) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isF32()) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (element_type.isF64()) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (auto complex_type = element_type.dyn_cast()) { + if (complex_type.getElementType().isF32()) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + if (complex_type.getElementType().isF64()) { + CopyDenseElementsBy(data, output); + return OkStatus(); + } + } + return Internal("Unsupported type in CopyDenseElementsDataToXlaFormat"); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index 6ae8a2a61df7f3..80e7b45501ce02 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_ #define XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_ +#include #include #include #include @@ -23,6 +24,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/service/buffer_assignment.h" @@ -229,6 +231,9 @@ std::string GetIrNameFromLoc(mlir::Location loc); // Whether the module's target is an AMD GPU. bool IsAMDGPU(const llvm::Module* module); +Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data, + std::vector* output); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc index 1eec6e59991d10..4aa4a539e858d9 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc @@ -15,12 +15,31 @@ limitations under the License. 
#include "xla/service/gpu/ir_emission_utils.h" +#include +#include +#include +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/tests/hlo_test_base.h" +#include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/types.h" #include "xla/util.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -458,5 +477,38 @@ ENTRY entry { EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); } +TEST_F(IrEmissionUtilsTest, LiteralToAttrToXlaFormat) { + mlir::MLIRContext context; + context.loadDialect(); + mlir::Builder builder(&context); + + // int16 + { + Literal x = LiteralUtil::CreateR2({{0, 1, 2}, {3, 4, 5}}); + TF_ASSERT_OK_AND_ASSIGN(mlir::DenseElementsAttr attr, + CreateDenseElementsAttrFromLiteral(x, builder)); + + std::vector data; + TF_ASSERT_OK(CopyDenseElementsDataToXlaFormat(attr, &data)); + for (int i = 0; i < 6; i++) { + int16_t x; + memcpy(&x, &data[i * 2], 2); + EXPECT_EQ(x, i); + } + } + + // int4 + { + Literal x = LiteralUtil::CreateR2( + {{s4(0), s4(1), s4(2)}, {s4(3), s4(4), s4(5)}}); + TF_ASSERT_OK_AND_ASSIGN(mlir::DenseElementsAttr attr, + CreateDenseElementsAttrFromLiteral(x, builder)); + + std::vector data; + TF_ASSERT_OK(CopyDenseElementsDataToXlaFormat(attr, &data)); + EXPECT_EQ(data, std::vector({0x01, 0x23, 0x45})); + } +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc index a9c0164bf4b6f7..7c2d2c7fd24345 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc @@ -93,27 +93,6 @@ StatusOr GetPermutationIfAvailable(const Shape& shape, return makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext()); } - -template -void CopyDenseElementsBy(mlir::DenseElementsAttr data, - std::vector* output) { - output->resize(data.getNumElements() * sizeof(T)); - int64_t i = 0; - for (T element : data.getValues()) { - std::memcpy(&(*output)[i], &element, sizeof(T)); - i += sizeof(T); - } -} - -template <> -void CopyDenseElementsBy(mlir::DenseElementsAttr data, - std::vector* output) { - output->resize(CeilOfRatio(data.getNumElements(), int64_t{2})); - absl::Span output_span = - absl::MakeSpan(reinterpret_cast(output->data()), output->size()); - PackInt4(data.getRawData(), output_span); -} - } // namespace StatusOr ConvertTensorShapeToMemRefType( @@ -152,84 +131,6 @@ StatusOr CreateDenseElementsAttrFromLiteral( element_type); } -Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data, - std::vector* output) { - mlir::Type element_type = data.getType().getElementType(); - - // TODO(hinsu): Support remaining XLA primitive types. 
- if (element_type.isInteger(1)) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isInteger(4)) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isInteger(8)) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isInteger(16)) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isInteger(32)) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isInteger(64)) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isFloat8E5M2()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isFloat8E4M3FN()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isFloat8E4M3B11FNUZ()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isFloat8E5M2FNUZ()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isFloat8E4M3FNUZ()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isBF16()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isF16()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isF32()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isF64()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (auto complex_type = element_type.dyn_cast()) { - if (complex_type.getElementType().isF32()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (complex_type.getElementType().isF64()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - } - return Internal("Unsupported type in CopyDenseElementsDataToXlaFormat"); -} - StatusOr GetElementTypeBytes(mlir::Type type) { if (type.isInteger(1)) { return 1; diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h index 275f59a43ebdba..8dd290e0c2d307 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h @@ -34,9 +34,6 @@ namespace xla { StatusOr CreateDenseElementsAttrFromLiteral( const LiteralBase& literal, mlir::Builder builder); -Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data, - std::vector* output); - StatusOr GetElementTypeBytes(mlir::Type type); // Creates an DenseIntElementsAttr using the elements of the vector and the diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc index ef6ae79b2b5da3..295d656275d883 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc @@ -69,38 +69,5 @@ TEST(ConvertTensorShapeToType, Simple) { } } -TEST(LiteralToAttrToXlaFormat, Simple) { - mlir::MLIRContext context; - context.loadDialect(); - mlir::Builder builder(&context); - - // int16 - { - Literal x = LiteralUtil::CreateR2({{0, 1, 2}, {3, 4, 5}}); - TF_ASSERT_OK_AND_ASSIGN(mlir::DenseElementsAttr attr, - CreateDenseElementsAttrFromLiteral(x, builder)); - - std::vector data; - TF_ASSERT_OK(CopyDenseElementsDataToXlaFormat(attr, &data)); - for (int i = 0; i < 6; i++) { - int16_t x; - memcpy(&x, &data[i * 2], 2); - EXPECT_EQ(x, i); - } - } - - // int4 - { - Literal x = LiteralUtil::CreateR2( - {{s4(0), s4(1), s4(2)}, {s4(3), s4(4), s4(5)}}); - 
TF_ASSERT_OK_AND_ASSIGN(mlir::DenseElementsAttr attr, - CreateDenseElementsAttrFromLiteral(x, builder)); - - std::vector data; - TF_ASSERT_OK(CopyDenseElementsDataToXlaFormat(attr, &data)); - EXPECT_EQ(data, std::vector({0x01, 0x23, 0x45})); - } -} - } // namespace } // namespace xla From 0e8478ffbc60b6e5670cd8bac42be0f854a464b6 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Sat, 18 Nov 2023 14:13:43 -0800 Subject: [PATCH 267/391] [xla:ffi] Add basic public FFI API and a test PiperOrigin-RevId: 583678161 --- third_party/xla/xla/ffi/BUILD | 31 +++- third_party/xla/xla/ffi/api/BUILD | 40 +++++ third_party/xla/xla/ffi/api/c_api.h | 1 + third_party/xla/xla/ffi/api/ffi.h | 160 ++++++++++++++++-- third_party/xla/xla/ffi/api/ffi_test.cc | 75 ++++++++ third_party/xla/xla/ffi/ffi.h | 35 ---- .../xla/xla/ffi/{ffi.cc => ffi_api.cc} | 12 +- third_party/xla/xla/ffi/ffi_api.h | 73 ++++++++ third_party/xla/xla/ffi/ffi_test.cc | 1 + third_party/xla/xla/service/gpu/BUILD | 3 +- .../xla/xla/service/gpu/custom_call_test.cc | 1 + .../xla/service/gpu/ir_emitter_unnested.cc | 2 +- .../xla/xla/service/gpu/runtime3/BUILD | 2 +- .../service/gpu/runtime3/custom_call_thunk.cc | 2 +- 14 files changed, 385 insertions(+), 53 deletions(-) create mode 100644 third_party/xla/xla/ffi/api/ffi_test.cc rename third_party/xla/xla/ffi/{ffi.cc => ffi_api.cc} (95%) create mode 100644 third_party/xla/xla/ffi/ffi_api.h diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index af726018ff3db7..abfa265e432b2e 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -18,10 +18,13 @@ cc_library( hdrs = ["call_frame.h"], visibility = ["//visibility:public"], deps = [ + ":api", + "//xla:status", "//xla:types", "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", + "//xla/service:executable", "//xla/stream_executor:device_memory", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -32,7 +35,6 @@ cc_library( cc_library( name = "ffi", - srcs = ["ffi.cc"], hdrs = ["ffi.h"], visibility = ["//visibility:public"], deps = [ @@ -56,6 +58,32 @@ cc_library( ], ) +cc_library( + name = "ffi_api", + srcs = ["ffi_api.cc"], + hdrs = ["ffi_api.h"], + visibility = ["//visibility:public"], + deps = [ + ":api", + ":call_frame", + "//xla:status", + "//xla:statusor", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/ffi/api:c_api", + "//xla/ffi/api:c_api_internal", + "//xla/runtime:memref_view", + "//xla/service:executable", + "//xla/stream_executor:device_memory", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + xla_cc_test( name = "ffi_test", srcs = ["ffi_test.cc"], @@ -63,6 +91,7 @@ xla_cc_test( ":api", ":call_frame", ":ffi", + ":ffi_api", "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", "//xla/service:executable", diff --git a/third_party/xla/xla/ffi/api/BUILD b/third_party/xla/xla/ffi/api/BUILD index d18f89ed489303..9f9eea78f4082f 100644 --- a/third_party/xla/xla/ffi/api/BUILD +++ b/third_party/xla/xla/ffi/api/BUILD @@ -1,3 +1,4 @@ +load("//xla:xla.bzl", "xla_cc_test") load("@local_tsl//tsl:tsl.default.bzl", "filegroup") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") @@ -5,6 +6,22 @@ package( default_visibility = ["//visibility:public"], ) 
+#===-------------------------------------------------------------------------------------------===// +# Public XLA FFI API +#===-------------------------------------------------------------------------------------------===// + +# XLA FFI is a header only library that does not have any dependencies on XLA. The intent is that +# users that do want to register custom FFI handlers with XLA should copy these headers to their +# project, build a shared object with an XLA FFI handler implementation, and load it at run time. +# +# `api.h` and `ffi.h` headers provide a C++ library for decoding XLA FFI C API structs into a more +# user friendly C++ types. Shared objects defining XLA FFI handlers should be built with private +# symbol visibility to avoid potential ODR violations coming from template instantiations of +# different XLA FFI versions. +# +# `ffi.h` defines builtin decoding for canonical XLA types, but users can add their own decodings +# with template specializations. + filegroup( name = "api_headers", srcs = ["api.h"], @@ -46,3 +63,26 @@ cc_library( ":c_api", ], ) + +#===-------------------------------------------------------------------------------------------===// +# Internal tests for XLA FFI API +#===-------------------------------------------------------------------------------------------===// + +xla_cc_test( + name = "ffi_test", + srcs = ["ffi_test.cc"], + deps = [ + ":api", + ":ffi", + "//xla:xla_data_proto_cc", + "//xla/ffi:call_frame", + "//xla/ffi:ffi_api", + "//xla/ffi/api:c_api", + "//xla/service:executable", + "//xla/stream_executor:device_memory", + "@com_google_absl//absl/status", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h index f1945a5756ac8a..9c6ac426affadc 100644 --- a/third_party/xla/xla/ffi/api/c_api.h +++ b/third_party/xla/xla/ffi/api/c_api.h @@ -84,6 +84,7 @@ typedef struct XLA_FFI_Error XLA_FFI_Error; // Codes are based on https://abseil.io/docs/cpp/guides/status-codes typedef enum { + XLA_FFI_Error_Code_OK = 0, XLA_FFI_Error_Code_CANCELLED = 1, XLA_FFI_Error_Code_UNKNOWN = 2, XLA_FFI_Error_Code_INVALID_ARGUMENT = 3, diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index 9a4882e656c22b..b56c8ae8a2c6e8 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -20,8 +20,13 @@ limitations under the License. #error Two different XLA FFI implementations cannot be included together #endif // XLA_FFI_API_H_ +#include +#include #include +#include #include +#include +#include #include "xla/ffi/api/c_api.h" @@ -31,17 +36,138 @@ limitations under the License. namespace xla::ffi { -namespace internal { -// TODO(ezhulenev): We need to log error message somewhere, currently we -// silently destroy it. -inline void DestroyError(const XLA_FFI_Api* api, XLA_FFI_Error* error) { - XLA_FFI_Error_Destroy_Args destroy_args; - destroy_args.struct_size = XLA_FFI_Error_Destroy_Args_STRUCT_SIZE; - destroy_args.priv = nullptr; - destroy_args.error = error; - api->XLA_FFI_Error_Destroy(&destroy_args); -} -} // namespace internal +// Because we can't depend on any of the XLA libraries (any libraries at all +// really) in public XLA FFI API, we have to duplicate some of the enums/types +// widely used in XLA code base, and some of the basic types available in ABSL. 
+ +//===----------------------------------------------------------------------===// +// XLA types +//===----------------------------------------------------------------------===// + +// This enum corresponds to xla::PrimitiveType enum defined in `hlo.proto`. +enum class DataType : uint8_t { + // Invalid primitive type to serve as default. + PRIMITIVE_TYPE_INVALID = 0, + + // Predicates are two-state booleans. + PRED = 1, + + // Signed integral values of fixed width. + S8 = 2, + S16 = 3, + S32 = 4, + S64 = 5, + + // Unsigned integral values of fixed width. + U8 = 6, + U16 = 7, + U32 = 8, + U64 = 9, + + // Floating-point values of fixed width. + // + // Note: if f16s are not natively supported on the device, they will be + // converted to f16 from f32 at arbitrary points in the computation. + F16 = 10, + F32 = 11, + + // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent + // and 7 bits for the mantissa. + BF16 = 16, + + F64 = 12, +}; + +//===----------------------------------------------------------------------===// +// Span is non-owning view into contiguous values of type `T`. +//===----------------------------------------------------------------------===// + +// TODO(ezhulenev): Replace with `std::span` when C++20 is available. +template +class Span { + public: + Span(T* data, size_t size) : data_(data), size_(size) {} + Span(const std::vector>& vec) // NOLINT + : Span(vec.data(), vec.size()) {} + + T& operator[](size_t index) const { return data_[index]; } + + size_t size() const { return size_; } + + T* begin() const { return data_; } + T* end() const { return data_ + size_; } + + private: + T* data_; + size_t size_; +}; + +//===----------------------------------------------------------------------===// +// Error +//===----------------------------------------------------------------------===// + +class Error { + public: + Error() = default; + Error(XLA_FFI_Error_Code errc, std::string message) + : errc_(errc), message_(std::move(message)) {} + + static Error Success() { return Error(); } + + bool success() const { return errc_ == XLA_FFI_Error_Code_OK; } + bool failure() const { return !success(); } + + std::optional errc() const { return errc_; } + const std::string& message() const { return message_; } + + private: + XLA_FFI_Error_Code errc_; + std::string message_; +}; + +//===----------------------------------------------------------------------===// +// Arguments +//===----------------------------------------------------------------------===// + +struct BufferBase { + DataType primitive_type; + void* data; + Span dimensions; +}; + +//===----------------------------------------------------------------------===// +// Arguments decoding +//===----------------------------------------------------------------------===// + +template <> +struct ArgDecoding { + static std::optional Decode(XLA_FFI_ArgType type, void* arg) { + if (type != XLA_FFI_ArgType_BUFFER) return std::nullopt; + auto* buf = reinterpret_cast(arg); + + return BufferBase{static_cast(buf->primitive_type), buf->data, + Span(buf->dims, buf->rank)}; + } +}; + +//===----------------------------------------------------------------------===// +// Result encoding +//===----------------------------------------------------------------------===// + +template <> +struct ResultEncoding { + static XLA_FFI_Error* Encode(XLA_FFI_Api* api, Error error) { + if (error.success()) return nullptr; + + XLA_FFI_Error_Create_Args args; + 
args.struct_size = XLA_FFI_Error_Create_Args_STRUCT_SIZE; + args.priv = nullptr; + args.errc = *error.errc(); + args.message = error.message().c_str(); + return api->XLA_FFI_Error_Create(&args); + } +}; //===----------------------------------------------------------------------===// // PlatformStream @@ -65,12 +191,22 @@ struct CtxDecoding> { args.stream = nullptr; if (XLA_FFI_Error* error = api->XLA_FFI_Stream_Get(&args); error) { - internal::DestroyError(api, error); + DestroyError(api, error); return std::nullopt; } return reinterpret_cast(args.stream); } + + // TODO(ezhulenev): We need to log error message somewhere, currently we + // silently destroy it. + static void DestroyError(const XLA_FFI_Api* api, XLA_FFI_Error* error) { + XLA_FFI_Error_Destroy_Args destroy_args; + destroy_args.struct_size = XLA_FFI_Error_Destroy_Args_STRUCT_SIZE; + destroy_args.priv = nullptr; + destroy_args.error = error; + api->XLA_FFI_Error_Destroy(&destroy_args); + } }; } // namespace xla::ffi diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc new file mode 100644 index 00000000000000..cbe4719292d6c0 --- /dev/null +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -0,0 +1,75 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/ffi/api/ffi.h" + +#include +#include + +#include "xla/ffi/call_frame.h" +#include "xla/ffi/ffi_api.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/xla_data.pb.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/test.h" + +namespace xla::ffi { + +TEST(FfiTest, DataTypeEnumValue) { + // C API passes primitive type as `uint8_t`, and we need to guarantee that + // PrimitiveType and DataType use same values for all supported data types. 
+ auto encoded = [](auto value) { return static_cast(value); }; + + EXPECT_EQ(encoded(PrimitiveType::PRED), encoded(DataType::PRED)); + + EXPECT_EQ(encoded(PrimitiveType::S8), encoded(DataType::S8)); + EXPECT_EQ(encoded(PrimitiveType::S16), encoded(DataType::S16)); + EXPECT_EQ(encoded(PrimitiveType::S32), encoded(DataType::S32)); + EXPECT_EQ(encoded(PrimitiveType::S64), encoded(DataType::S64)); + + EXPECT_EQ(encoded(PrimitiveType::U8), encoded(DataType::U8)); + EXPECT_EQ(encoded(PrimitiveType::U16), encoded(DataType::U16)); + EXPECT_EQ(encoded(PrimitiveType::U32), encoded(DataType::U32)); + EXPECT_EQ(encoded(PrimitiveType::U64), encoded(DataType::U64)); + + EXPECT_EQ(encoded(PrimitiveType::F16), encoded(DataType::F16)); + EXPECT_EQ(encoded(PrimitiveType::F32), encoded(DataType::F32)); + EXPECT_EQ(encoded(PrimitiveType::F64), encoded(DataType::F64)); + + EXPECT_EQ(encoded(PrimitiveType::BF16), encoded(DataType::BF16)); +} + +TEST(FfiTest, BufferArgument) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder; + builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + auto fn = [&](BufferBase buffer) { + EXPECT_EQ(buffer.data, storage.data()); + EXPECT_EQ(buffer.primitive_type, DataType::F32); + EXPECT_EQ(buffer.dimensions.size(), 2); + return Error::Success(); + }; + + auto handler = Ffi::Bind().Arg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + +} // namespace xla::ffi diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h index 18c9ca83268bc7..b335c3d61b025f 100644 --- a/third_party/xla/xla/ffi/ffi.h +++ b/third_party/xla/xla/ffi/ffi.h @@ -22,7 +22,6 @@ limitations under the License. #include #include -#include // IWYU pragma: begin_exports #include "xla/ffi/api/api.h" @@ -31,11 +30,9 @@ limitations under the License. #include "absl/types/span.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep -#include "xla/ffi/call_frame.h" #include "xla/runtime/memref_view.h" #include "xla/service/service_executable_run_options.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/types.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" @@ -102,38 +99,6 @@ struct ResultEncoding { } }; -//===----------------------------------------------------------------------===// -// Result encoding -//===----------------------------------------------------------------------===// - -// Takes ownership of the XLA FFI error and returns underlying status. Frees -// `error` if it's not nullptr; returns OK status otherwise. -Status TakeStatus(XLA_FFI_Error* error); - -struct CallOptions { - const ServiceExecutableRunOptions* run_options = nullptr; -}; - -Status Call(Ffi& handler, CallFrame& call_frame, - const CallOptions& options = {}); - -Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame, - const CallOptions& options = {}); - -//===----------------------------------------------------------------------===// -// XLA FFI registry -//===----------------------------------------------------------------------===// - -// Returns registered FFI handler for a given name, or an error if it's not -// found in the static registry. 
-StatusOr FindHandler(std::string_view name); - -//===----------------------------------------------------------------------===// -// XLA FFI Api Implementation -//===----------------------------------------------------------------------===// - -XLA_FFI_Api* GetXlaFfiApi(); - } // namespace xla::ffi #endif // XLA_FFI_FFI_H_ diff --git a/third_party/xla/xla/ffi/ffi.cc b/third_party/xla/xla/ffi/ffi_api.cc similarity index 95% rename from third_party/xla/xla/ffi/ffi.cc rename to third_party/xla/xla/ffi/ffi_api.cc index 6f2b02cda3867f..303ba5ea6cdf05 100644 --- a/third_party/xla/xla/ffi/ffi.cc +++ b/third_party/xla/xla/ffi/ffi_api.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" #include #include @@ -23,6 +23,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "xla/ffi/api/api.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep #include "xla/ffi/call_frame.h" @@ -47,6 +48,13 @@ struct XLA_FFI_ExecutionContext { namespace xla::ffi { +//===----------------------------------------------------------------------===// +// Calling XLA FFI handlers +//===----------------------------------------------------------------------===// + +// WARNING: These functions defined in `call_frame.h` as we need to make them +// available without having to depend on `ffi.h` header. + Status TakeStatus(XLA_FFI_Error* error) { if (error == nullptr) return absl::OkStatus(); Status status = std::move(error->status); @@ -121,6 +129,8 @@ static Status ActualStructSizeIsGreaterOrEqual(std::string_view struct_name, static absl::StatusCode ToStatusCode(XLA_FFI_Error_Code errc) { switch (errc) { + case XLA_FFI_Error_Code_OK: + return absl::StatusCode::kOk; case XLA_FFI_Error_Code_CANCELLED: return absl::StatusCode::kCancelled; case XLA_FFI_Error_Code_UNKNOWN: diff --git a/third_party/xla/xla/ffi/ffi_api.h b/third_party/xla/xla/ffi/ffi_api.h new file mode 100644 index 00000000000000..b127ed0b2b5b65 --- /dev/null +++ b/third_party/xla/xla/ffi/ffi_api.h @@ -0,0 +1,73 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_FFI_FFI_API_H_ +#define XLA_FFI_FFI_API_H_ + +#include + +#include "xla/ffi/api/api.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep +#include "xla/ffi/call_frame.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/status.h" +#include "xla/statusor.h" + +namespace xla::ffi { + +// This is an implementation of XLA FFI API defined in `api/c_api.h` header. 
It +// should be linked statically into the "main" XLA binary, and third party FFI +// handlers can be linked and registered dynamically. +// +// FFI handlers registered statically (and built from the same XLA commit with +// the same toolchain) can also use `api/c_api_internal.h` to get access to +// various internal data structures. + +//===----------------------------------------------------------------------===// +// Calling XLA FFI handlers +//===----------------------------------------------------------------------===// + +struct CallOptions { + const ServiceExecutableRunOptions* run_options = nullptr; +}; + +// Takes ownership of the XLA FFI error and returns underlying status. Frees +// `error` if it's not nullptr; returns OK status otherwise. +Status TakeStatus(XLA_FFI_Error* error); + +Status Call(Ffi& handler, CallFrame& call_frame, + const CallOptions& options = {}); + +Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame, + const CallOptions& options = {}); + +//===----------------------------------------------------------------------===// +// XLA FFI registry +//===----------------------------------------------------------------------===// + +// Returns registered FFI handler for a given name, or an error if it's not +// found in the static registry. +StatusOr FindHandler(std::string_view name); + +//===----------------------------------------------------------------------===// +// XLA FFI Api Implementation +//===----------------------------------------------------------------------===// + +XLA_FFI_Api* GetXlaFfiApi(); + +} // namespace xla::ffi + +#endif // XLA_FFI_FFI_API_H_ diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index ce1cbe4289c9ee..1d19fee975f220 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/ffi/call_frame.h" +#include "xla/ffi/ffi_api.h" #include "xla/service/service_executable_run_options.h" #include "xla/stream_executor/device_memory.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 027bab5e98735f..ef176d7051ed6e 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -136,6 +136,7 @@ xla_cc_test( "//xla/client:xla_builder", "//xla/client/lib:constants", "//xla/ffi", + "//xla/ffi:ffi_api", "//xla/runtime:custom_call", "//xla/runtime:custom_call_registry", "//xla/runtime:executable", @@ -303,7 +304,7 @@ cc_library( "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/ffi", + "//xla/ffi:ffi_api", "//xla/ffi/api:c_api", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", diff --git a/third_party/xla/xla/service/gpu/custom_call_test.cc b/third_party/xla/xla/service/gpu/custom_call_test.cc index d1ed30504594b9..e6a29c01b9b594 100644 --- a/third_party/xla/xla/service/gpu/custom_call_test.cc +++ b/third_party/xla/xla/service/gpu/custom_call_test.cc @@ -33,6 +33,7 @@ limitations under the License. 
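// [Editorial sketch - not part of this patch.] The new `xla/ffi/ffi_api.h`
// header above splits the internal entry points (TakeStatus, Call, FindHandler,
// GetXlaFfiApi) out of the public `ffi.h` builder API. A minimal sketch of
// resolving and invoking a registered handler through those entry points could
// look like the following; the handler name "my_custom_call" and the helper
// function are hypothetical and shown only to illustrate the call pattern:
//
// #include "xla/ffi/call_frame.h"
// #include "xla/ffi/ffi_api.h"
// #include "xla/status.h"
//
// namespace {
// xla::Status InvokeRegisteredHandler(xla::ffi::CallFrame& call_frame) {
//   // Look up a handler that some FFI library registered under this name.
//   auto handler = xla::ffi::FindHandler("my_custom_call");
//   if (!handler.ok()) return handler.status();
//   // CallOptions can carry the ServiceExecutableRunOptions; left unset here.
//   xla::ffi::CallOptions options;
//   return xla::ffi::Call(*handler, call_frame, options);
// }
// }  // namespace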
#include "xla/client/lib/constants.h" #include "xla/client/xla_builder.h" #include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/runtime/custom_call.h" #include "xla/runtime/custom_call_registry.h" #include "xla/runtime/executable.h" diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 732f5d0fb9bcdd..489a367a0c7405 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -75,7 +75,7 @@ limitations under the License. #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project #include "xla/ffi/api/c_api.h" -#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/service/gpu/runtime3/BUILD b/third_party/xla/xla/service/gpu/runtime3/BUILD index 71616707501cc7..c5d7e682ba52ff 100644 --- a/third_party/xla/xla/service/gpu/runtime3/BUILD +++ b/third_party/xla/xla/service/gpu/runtime3/BUILD @@ -153,8 +153,8 @@ cc_library( "//xla:shape_util", "//xla:status", "//xla:util", - "//xla/ffi", "//xla/ffi:call_frame", + "//xla/ffi:ffi_api", "//xla/ffi/api:c_api", "//xla/service:buffer_assignment", "//xla/service:custom_call_status", diff --git a/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.cc b/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.cc index b53a258a591372..8edbca55309afe 100644 --- a/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime3/custom_call_thunk.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/executable_run_options.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/call_frame.h" -#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/service/buffer_assignment.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" From e447637210a529c9cd903455118d200c4996cdce Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sat, 18 Nov 2023 15:16:31 -0800 Subject: [PATCH 268/391] Make the SpecifiedLayout class opaque. Also need to enabling pickling to xc.Layout so that AOT serialization continues to work. PiperOrigin-RevId: 583684299 --- third_party/xla/xla/python/xla_client.py | 2 +- third_party/xla/xla/python/xla_compiler.cc | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index 9e13fe5e4b0d66..4a476977f18118 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -48,7 +48,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 216 +_version = 217 # Version number for MLIR:Python components. mlir_api_version = 54 diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc index 2342ace4ca1c7c..ac0b0933eae6a1 100644 --- a/third_party/xla/xla/python/xla_compiler.cc +++ b/third_party/xla/xla/python/xla_compiler.cc @@ -43,6 +43,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/python/exceptions.h" #include "xla/python/py_client.h" @@ -299,7 +300,24 @@ void BuildXlaCompilerSubmodule(py::module& m) { const Layout& other) { return layout != other; }) .def("__hash__", [](const Layout& layout) { return absl::HashOf(layout); }) - .def("to_string", &Layout::ToString); + .def("to_string", &Layout::ToString) + .def(py::pickle( + [](const Layout& self) -> py::tuple { + auto proto = self.ToProto(); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("Layout.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return py::make_tuple(py::bytes(result)); + }, + [](py::tuple t) { + LayoutProto result; + result.ParseFromString(t[0].cast()); + return Layout::CreateFromProto(result); + })); py::class_ shape_class(m, "Shape"); shape_class From a438cdd38c2df6427f7e3e4432475cefcd47b476 Mon Sep 17 00:00:00 2001 From: looi Date: Sun, 19 Nov 2023 08:54:26 +0000 Subject: [PATCH 269/391] TFLite GPU: fix certain tests for OSS users After 3dc509f31848c7778dc68fabc59ab39c2e0d1e4a, cannot build for OSS users. Also, EXPECT_OK needs to be replaced with ASSERT_OK. --- tensorflow/lite/delegates/gpu/build_defs.bzl | 8 ++ .../lite/delegates/gpu/cl/kernels/BUILD | 123 ++++++------------ .../lite/delegates/gpu/cl/testing/BUILD | 5 +- .../delegates/gpu/common/tasks/special/BUILD | 7 +- .../tasks/special/conv_pointwise_test.cc | 8 +- 5 files changed, 61 insertions(+), 90 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/build_defs.bzl b/tensorflow/lite/delegates/gpu/build_defs.bzl index d98a201551176a..cdea91aec86507 100644 --- a/tensorflow/lite/delegates/gpu/build_defs.bzl +++ b/tensorflow/lite/delegates/gpu/build_defs.bzl @@ -28,3 +28,11 @@ def tflite_angle_heapcheck_deps(): # copybara:comment_begin(oss-only) return ["@com_google_googletest//:gtest_main"] # copybara:comment_end + +def gtest_main_no_heapcheck_deps(): + # copybara:uncomment_begin(google-only) + # return ["@com_google_googletest//:gtest_main_no_heapcheck"] + # copybara:uncomment_end + # copybara:comment_begin(oss-only) + return ["@com_google_googletest//:gtest_main"] + # copybara:comment_end diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD index 889b178463e5f1..23ac928214a42d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD @@ -2,6 +2,7 @@ load( "//tensorflow/core/platform:build_config_root.bzl", "tf_gpu_tests_tags", ) +load("//tensorflow/lite/delegates/gpu:build_defs.bzl", "gtest_main_no_heapcheck_deps") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,14 +19,13 @@ cc_test( "notsan", "requires-gpu-nvidia", ], + # TODO(b/279977471) Once b/279347631 is resolved, check for heap again deps = [ ":cl_test", - # TODO(b/279977471) Once b/279347631 is resolved, check for heap again - "@com_google_googletest//:gtest_main_no_heapcheck", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:add_test_util", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -42,8 +42,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", 
"//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:cast_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_library( @@ -76,8 +75,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:concat_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -94,8 +92,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:conv_constants_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", # constant buffers leak on nvidia - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -112,8 +109,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:conv_generic_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -130,8 +126,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:conv_weights_converter_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_library( @@ -173,8 +168,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:convolution_transposed_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -191,8 +185,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:convolution_transposed_3x3_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -209,8 +202,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:convolution_transposed_3x3_thin_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", # constant buffers leak on nvidia - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -227,8 +219,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:convolution_transposed_4x4_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -245,8 +236,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:convolution_transposed_thin_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", # constant buffers leak on nvidia - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -263,8 +253,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:cumsum_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -281,8 +270,7 @@ cc_test( 
"//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:depthwise_conv_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -300,8 +288,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:depthwise_conv_3x3_stride_h2_test_util", "//tensorflow/lite/delegates/gpu/common/tasks:depthwise_conv_3x3_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -318,8 +305,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:elementwise_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -341,8 +327,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/tasks:fully_connected", "//tensorflow/lite/delegates/gpu/common/tasks:fully_connected_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -359,8 +344,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:gather_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -377,8 +361,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:lstm_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -414,8 +397,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:max_unpooling_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -432,8 +414,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:mean_stddev_normalization_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -450,8 +431,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:one_hot_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -468,8 +448,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:padding_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -486,8 +465,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:pooling_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -504,8 +482,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", 
"//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:prelu_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -522,8 +499,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:quantize_and_dequantize_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -540,8 +516,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:reduce_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -558,8 +533,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:relu_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -576,8 +550,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:resampler_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -594,8 +567,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:reshape_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -612,8 +584,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:reshape_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -630,8 +601,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:select_v2_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -648,8 +618,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:softmax_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -666,8 +635,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:softmax_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -684,8 +652,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:space_to_depth_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -702,8 +669,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:split_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + 
gtest_main_no_heapcheck_deps(), ) cc_test( @@ -720,8 +686,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:strided_slice_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -738,8 +703,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:tile_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -756,8 +720,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:transpose_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -774,8 +737,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:resize_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_test( @@ -792,8 +754,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common/tasks:winograd_test_util", - "@com_google_googletest//:gtest_main_no_heapcheck", - ], + ] + gtest_main_no_heapcheck_deps(), ) test_suite( diff --git a/tensorflow/lite/delegates/gpu/cl/testing/BUILD b/tensorflow/lite/delegates/gpu/cl/testing/BUILD index 75e36c0c9ca877..e333bb6daf5628 100644 --- a/tensorflow/lite/delegates/gpu/cl/testing/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/testing/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/lite/delegates/gpu:build_defs.bzl", "gtest_main_no_heapcheck_deps") + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], @@ -35,8 +37,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/cl/kernels:cl_test", "//tensorflow/lite/delegates/gpu/common:gpu_model_test_util", "//tensorflow/lite/delegates/gpu/common:status", - "@com_google_googletest//:gtest_main_no_heapcheck", # constant buffers leak on nvidia - ], + ] + gtest_main_no_heapcheck_deps(), # constant buffers leak on nvidia ) cc_binary( diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/BUILD b/tensorflow/lite/delegates/gpu/common/tasks/special/BUILD index 74af3502abc116..864893c195702b 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/special/BUILD +++ b/tensorflow/lite/delegates/gpu/common/tasks/special/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/lite/delegates/gpu:build_defs.bzl", "gtest_main_no_heapcheck_deps") + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], @@ -26,10 +28,9 @@ cc_test( "notsan", "requires-gpu-nvidia", ], + # TODO(b/279977471) Once b/279347631 is resolved, check for heap again deps = [ ":conv_pointwise", - # TODO(b/279977471) Once b/279347631 is resolved, check for heap again - "@com_google_googletest//:gtest_main_no_heapcheck", "//tensorflow/lite/delegates/gpu/cl/kernels:cl_test", "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/task:testing_util", @@ -37,7 +38,7 @@ cc_test( "//tensorflow/lite/delegates/gpu/common:shape", 
"//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common:types", - ], + ] + gtest_main_no_heapcheck_deps(), ) cc_library( diff --git a/tensorflow/lite/delegates/gpu/common/tasks/special/conv_pointwise_test.cc b/tensorflow/lite/delegates/gpu/common/tasks/special/conv_pointwise_test.cc index e77a488587df78..0af40dfaf8f07a 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/special/conv_pointwise_test.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/special/conv_pointwise_test.cc @@ -59,11 +59,11 @@ TEST_F(OpenCLOperationTest, SliceMulMeanConcat) { op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; GPUOperation operation = CreateConvPointwise(op_def, op_attr); - EXPECT_OK(env->ExecuteGPUOperation( + ASSERT_OK(env->ExecuteGPUOperation( {src_tensor, weights_tensor}, std::make_unique(std::move(operation)), BHWC(1, 2, 1, 2), &dst_tensor)); - EXPECT_OK(PointWiseNear({5.5f, 5.5f, 8.5f, 8.5f}, dst_tensor.data, eps)); + ASSERT_OK(PointWiseNear({5.5f, 5.5f, 8.5f, 8.5f}, dst_tensor.data, eps)); } } } @@ -93,11 +93,11 @@ TEST_F(OpenCLOperationTest, SliceMulSumConcat) { op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; GPUOperation operation = CreateConvPointwise(op_def, op_attr); - EXPECT_OK(env->ExecuteGPUOperation( + ASSERT_OK(env->ExecuteGPUOperation( {src_tensor, weights_tensor}, std::make_unique(std::move(operation)), BHWC(1, 2, 1, 2), &dst_tensor)); - EXPECT_OK( + ASSERT_OK( PointWiseNear({11.0f, 11.0f, 17.0f, 17.0f}, dst_tensor.data, eps)); } } From 085b3a101dcb9c973097fb02636254f64ddbf3da Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 19 Nov 2023 01:02:04 -0800 Subject: [PATCH 270/391] Update GraphDef version to 1685. PiperOrigin-RevId: 583747205 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index cff03d82d9340e..199a727c62be14 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1684 // Updated: 2023/11/18 +#define TF_GRAPH_DEF_VERSION 1685 // Updated: 2023/11/19 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From f926abe9a4822ec35825a19fe95162e9ebb92196 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 19 Nov 2023 01:02:14 -0800 Subject: [PATCH 271/391] compat: Update forward compatibility horizon to 2023-11-19 PiperOrigin-RevId: 583747253 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index f346a0c39e9617..91d105f6c88132 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. 
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 18) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 19) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From af0aa775051585fc62c8dc2bf245fa938bc54ad3 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Sun, 19 Nov 2023 10:47:24 -0800 Subject: [PATCH 272/391] [xla:ffi] Add DataType enum to C API PiperOrigin-RevId: 583827999 --- third_party/xla/xla/ffi/BUILD | 1 + third_party/xla/xla/ffi/api/c_api.h | 26 ++++++++++- third_party/xla/xla/ffi/api/ffi.h | 58 +++++++------------------ third_party/xla/xla/ffi/api/ffi_test.cc | 6 +-- third_party/xla/xla/ffi/call_frame.cc | 28 +++++++++++- third_party/xla/xla/ffi/ffi.h | 6 +-- third_party/xla/xla/ffi/ffi_test.cc | 2 +- 7 files changed, 76 insertions(+), 51 deletions(-) diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index abfa265e432b2e..9b1b2cc67b442e 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -28,6 +28,7 @@ cc_library( "//xla/stream_executor:device_memory", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", ], diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h index 9c6ac426affadc..61de6c9e61f741 100644 --- a/third_party/xla/xla/ffi/api/c_api.h +++ b/third_party/xla/xla/ffi/api/c_api.h @@ -128,6 +128,30 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Error_Destroy_Args, error); typedef void XLA_FFI_Error_Destroy(XLA_FFI_Error_Destroy_Args* args); +//===----------------------------------------------------------------------===// +// DataType +//===----------------------------------------------------------------------===// + +// This enum corresponds to xla::PrimitiveType enum defined in `xla_data.proto`. +// LINT.IfChange +typedef enum { + XLA_FFI_DataType_INVALID = 0, + XLA_FFI_DataType_PRED = 1, + XLA_FFI_DataType_S8 = 2, + XLA_FFI_DataType_S16 = 3, + XLA_FFI_DataType_S32 = 4, + XLA_FFI_DataType_S64 = 5, + XLA_FFI_DataType_U8 = 6, + XLA_FFI_DataType_U16 = 7, + XLA_FFI_DataType_U32 = 8, + XLA_FFI_DataType_U64 = 9, + XLA_FFI_DataType_F16 = 10, + XLA_FFI_DataType_F32 = 11, + XLA_FFI_DataType_F64 = 12, + XLA_FFI_DataType_BF16 = 16, +} XLA_FFI_DataType; +// LINT.ThenChange(ffi_test.cc) + //===----------------------------------------------------------------------===// // Builtin argument types //===----------------------------------------------------------------------===// @@ -136,8 +160,8 @@ struct XLA_FFI_Buffer { size_t struct_size; void* priv; + XLA_FFI_DataType dtype; void* data; - uint8_t primitive_type; int64_t rank; int64_t* dims; // length == rank }; diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index b56c8ae8a2c6e8..3c0f3f5f50f1dc 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -36,47 +36,21 @@ limitations under the License. namespace xla::ffi { -// Because we can't depend on any of the XLA libraries (any libraries at all -// really) in public XLA FFI API, we have to duplicate some of the enums/types -// widely used in XLA code base, and some of the basic types available in ABSL. 
- -//===----------------------------------------------------------------------===// -// XLA types -//===----------------------------------------------------------------------===// - -// This enum corresponds to xla::PrimitiveType enum defined in `hlo.proto`. enum class DataType : uint8_t { - // Invalid primitive type to serve as default. - PRIMITIVE_TYPE_INVALID = 0, - - // Predicates are two-state booleans. - PRED = 1, - - // Signed integral values of fixed width. - S8 = 2, - S16 = 3, - S32 = 4, - S64 = 5, - - // Unsigned integral values of fixed width. - U8 = 6, - U16 = 7, - U32 = 8, - U64 = 9, - - // Floating-point values of fixed width. - // - // Note: if f16s are not natively supported on the device, they will be - // converted to f16 from f32 at arbitrary points in the computation. - F16 = 10, - F32 = 11, - - // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit - // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent - // and 7 bits for the mantissa. - BF16 = 16, - - F64 = 12, + INVALID = XLA_FFI_DataType_INVALID, + PRED = XLA_FFI_DataType_PRED, + S8 = XLA_FFI_DataType_S8, + S16 = XLA_FFI_DataType_S16, + S32 = XLA_FFI_DataType_S32, + S64 = XLA_FFI_DataType_S64, + U8 = XLA_FFI_DataType_U8, + U16 = XLA_FFI_DataType_U16, + U32 = XLA_FFI_DataType_U32, + U64 = XLA_FFI_DataType_U64, + F16 = XLA_FFI_DataType_F16, + F32 = XLA_FFI_DataType_F32, + F64 = XLA_FFI_DataType_F64, + BF16 = XLA_FFI_DataType_BF16, }; //===----------------------------------------------------------------------===// @@ -131,7 +105,7 @@ class Error { //===----------------------------------------------------------------------===// struct BufferBase { - DataType primitive_type; + DataType dtype; void* data; Span dimensions; }; @@ -146,7 +120,7 @@ struct ArgDecoding { if (type != XLA_FFI_ArgType_BUFFER) return std::nullopt; auto* buf = reinterpret_cast(arg); - return BufferBase{static_cast(buf->primitive_type), buf->data, + return BufferBase{static_cast(buf->dtype), buf->data, Span(buf->dims, buf->rank)}; } }; diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index cbe4719292d6c0..5d347130e8cf79 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -28,8 +28,8 @@ limitations under the License. namespace xla::ffi { TEST(FfiTest, DataTypeEnumValue) { - // C API passes primitive type as `uint8_t`, and we need to guarantee that - // PrimitiveType and DataType use same values for all supported data types. + // Verify that xla::PrimitiveType and xla::ffi::DataType use the same + // integer value for encoding data types. auto encoded = [](auto value) { return static_cast(value); }; EXPECT_EQ(encoded(PrimitiveType::PRED), encoded(DataType::PRED)); @@ -61,7 +61,7 @@ TEST(FfiTest, BufferArgument) { auto fn = [&](BufferBase buffer) { EXPECT_EQ(buffer.data, storage.data()); - EXPECT_EQ(buffer.primitive_type, DataType::F32); + EXPECT_EQ(buffer.dtype, DataType::F32); EXPECT_EQ(buffer.dimensions.size(), 2); return Error::Success(); }; diff --git a/third_party/xla/xla/ffi/call_frame.cc b/third_party/xla/xla/ffi/call_frame.cc index 4c79a31e109e50..3a698f7988fb22 100644 --- a/third_party/xla/xla/ffi/call_frame.cc +++ b/third_party/xla/xla/ffi/call_frame.cc @@ -25,6 +25,7 @@ limitations under the License. 
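// [Editorial sketch - not part of this patch.] The PrimitiveType ->
// XLA_FFI_DataType conversion added to call_frame.cc below, and the
// xla::ffi::DataType enum defined in api/ffi.h above, both rely on the FFI
// data types reusing xla::PrimitiveType's integer encoding. ffi_test.cc
// checks this at run time; the same invariant could also be stated at compile
// time, for example:
//
// static_assert(static_cast<int>(xla::ffi::DataType::F32) ==
//                   static_cast<int>(xla::PrimitiveType::F32),
//               "XLA FFI DataType must reuse PrimitiveType's encoding");
// static_assert(static_cast<int>(xla::ffi::DataType::BF16) ==
//                   static_cast<int>(xla::PrimitiveType::BF16),
//               "XLA FFI DataType must reuse PrimitiveType's encoding");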
#include #include "absl/algorithm/container.h" +#include "absl/log/check.h" #include "absl/types/span.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep @@ -170,12 +171,37 @@ CallFrame::~CallFrame() = default; absl::Span bargs) { auto res = std::make_unique(bargs.size()); + // We rely on casting to and from underlying integral type to convert from + // PrimitiveType to XLA FFI DataType, and for safety convert all unknown types + // to invalid type, otherwise we can accidentally cause UB. + auto to_data_type = [](PrimitiveType primitive_type) { + switch (primitive_type) { + case PrimitiveType::PRIMITIVE_TYPE_INVALID: + case PrimitiveType::S8: + case PrimitiveType::S16: + case PrimitiveType::S32: + case PrimitiveType::S64: + case PrimitiveType::U8: + case PrimitiveType::U16: + case PrimitiveType::U32: + case PrimitiveType::U64: + case PrimitiveType::F16: + case PrimitiveType::F32: + case PrimitiveType::F64: + case PrimitiveType::BF16: + return static_cast(primitive_type); + default: + DCHECK(false) << "Unsupported primitive type" << primitive_type; + return XLA_FFI_DataType_INVALID; + } + }; + // Convert call frame builder arguments to call frame arguments. for (const CallFrameBuilder::Buffer& barg : bargs) { Buffer buffer; buffer.dims = barg.dims; buffer.buffer.data = const_cast(barg.memory.opaque()); - buffer.buffer.primitive_type = static_cast(barg.type); + buffer.buffer.dtype = to_data_type(barg.type); buffer.buffer.rank = buffer.dims.size(); res->arguments.push_back(std::move(buffer)); } diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h index b335c3d61b025f..b1502ce1b306fb 100644 --- a/third_party/xla/xla/ffi/ffi.h +++ b/third_party/xla/xla/ffi/ffi.h @@ -44,14 +44,14 @@ namespace xla::ffi { //===----------------------------------------------------------------------===// struct Buffer { - PrimitiveType primitive_type; + PrimitiveType dtype; se::DeviceMemoryBase data; absl::Span dimensions; // TODO(ezhulenev): Remove this implicit conversion once we'll migrate to FFI // handlers from runtime custom calls. 
operator runtime::MemrefView() { // NOLINT - return runtime::MemrefView{primitive_type, data.opaque(), dimensions}; + return runtime::MemrefView{dtype, data.opaque(), dimensions}; } }; @@ -66,7 +66,7 @@ struct ArgDecoding { auto* buf = reinterpret_cast(arg); Buffer buffer; - buffer.primitive_type = PrimitiveType(buf->primitive_type); + buffer.dtype = PrimitiveType(buf->dtype); buffer.data = se::DeviceMemoryBase(buf->data); buffer.dimensions = absl::MakeConstSpan(buf->dims, buf->rank); return buffer; diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index 1d19fee975f220..4b6e475eb42b80 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -136,8 +136,8 @@ TEST(FfiTest, BufferArgument) { auto call_frame = builder.Build(); auto fn = [&](Buffer buffer) { + EXPECT_EQ(buffer.dtype, PrimitiveType::F32); EXPECT_EQ(buffer.data.opaque(), storage.data()); - EXPECT_EQ(buffer.primitive_type, PrimitiveType::F32); EXPECT_EQ(buffer.dimensions.size(), 2); return absl::OkStatus(); }; From f2a2cec8c33f9e7b44edbf86a72323cda2e259f2 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Sun, 19 Nov 2023 11:30:26 -0800 Subject: [PATCH 273/391] [XLA:GPU] [NFC] Better error messages when failing Triton autotuning PiperOrigin-RevId: 583832961 --- third_party/xla/xla/service/gpu/ir_emitter_triton.cc | 4 +++- .../xla/xla/service/gpu/ir_emitter_triton_test.cc | 5 +++-- third_party/xla/xla/service/gpu/triton_autotuner.cc | 5 +++-- third_party/xla/xla/service/gpu/triton_autotuner_test.cc | 9 +++++---- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index dc19d7d572e18a..caef52d7d8977e 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -1923,7 +1923,9 @@ StatusOr TritonWrapper( .getInt(); VLOG(2) << "Shared memory usage: " << shared_mem_bytes << " B"; if (shared_mem_bytes > device_info.shared_memory_per_block_optin()) { - return ResourceExhausted("Shared memory size limit exceeded."); + return absl::ResourceExhaustedError(absl::StrFormat( + "Shared memory size limit exceeded: requested %d, available: %d", + shared_mem_bytes, device_info.shared_memory_per_block_optin())); } TF_ASSIGN_OR_RETURN( diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 946bf2e19fa88e..992bc3e6ff937e 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -775,8 +775,9 @@ ENTRY entry { se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, /*minor=*/0}, dev_info, config, &llvm_module, &EmitMatMul, mlir_context), - tsl::testing::StatusIs(tsl::error::RESOURCE_EXHAUSTED, - "Shared memory size limit exceeded.")); + tsl::testing::StatusIs( + tsl::error::RESOURCE_EXHAUSTED, + ::testing::HasSubstr("Shared memory size limit exceeded"))); config.block_m = 64; config.block_n = 128; diff --git a/third_party/xla/xla/service/gpu/triton_autotuner.cc b/third_party/xla/xla/service/gpu/triton_autotuner.cc index 8e91a17bc9d07c..4c10204b3a4466 100644 --- a/third_party/xla/xla/service/gpu/triton_autotuner.cc +++ b/third_party/xla/xla/service/gpu/triton_autotuner.cc @@ -121,9 +121,10 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { AutotunerUtil::Autotune( hlo, config_, [&]() -> StatusOr { if 
(config_.IsDeviceless()) { - return InternalError( + return absl::InternalError(absl::StrCat( "Expect autotune result cache hit for deviceless " - "compilation."); + "compilation (HLO: ", + hlo->ToString())); } return InternalError("Expect autotune result cache hit."); })); diff --git a/third_party/xla/xla/service/gpu/triton_autotuner_test.cc b/third_party/xla/xla/service/gpu/triton_autotuner_test.cc index 83a001277f8f2e..3bdc62a3691e9a 100644 --- a/third_party/xla/xla/service/gpu/triton_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/triton_autotuner_test.cc @@ -220,10 +220,11 @@ class TritonAutotunerTest : public HloTestBase { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo)); - auto status_or = HloTestBase::RunHloPass(&pipeline, module.get()); - EXPECT_TRUE(tsl::errors::IsInternal(status_or.status())); - EXPECT_EQ("Expect autotune result cache hit for deviceless compilation.", - status_or.status().message()); + EXPECT_THAT(HloTestBase::RunHloPass(&pipeline, module.get()), + tsl::testing::StatusIs( + tsl::error::INTERNAL, + ::testing::HasSubstr( + "Expect autotune result cache hit for deviceless"))); } }; From c511a099b294095611e05a88daf54ca60da62f17 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Sun, 19 Nov 2023 17:24:33 -0800 Subject: [PATCH 274/391] Remove direct `QuantizationOptions` dependency from `AddQuantizePtqPreCalibrationStablehloPasses`. Instead pass on the `CalibrationOptions` directly which is the only use case of quantization options. PiperOrigin-RevId: 583869927 --- .../mlir/quantization/tensorflow/python/quantize_model.cc | 3 ++- .../compiler/mlir/quantization/tensorflow/quantize_passes.cc | 5 ++--- .../compiler/mlir/quantization/tensorflow/quantize_passes.h | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index 72cf30186d3fb3..886486b87b82d4 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -550,7 +550,8 @@ absl::StatusOr QuantizePtqModelPreCalibration( /*name=*/kTfQuantPtqPreCalibrationStepStableHloName, /*add_passes_func=*/ [&quantization_options](mlir::PassManager &pm) { - AddQuantizePtqPreCalibrationStablehloPasses(pm, quantization_options); + AddQuantizePtqPreCalibrationStablehloPasses( + pm, quantization_options.calibration_options()); }, context, *module_ref)); } else { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc index 36ebd0d9ccc297..f865866622300c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc @@ -232,12 +232,11 @@ void AddQuantizePtqPostCalibrationPasses( // StableHLO Quantization passes that are ran if StableHLO opset is selected. 
void AddQuantizePtqPreCalibrationStablehloPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options) { + mlir::PassManager &pm, const CalibrationOptions &calibration_options) { pm.addPass( mlir::quant::stablehlo::createLiftQuantizableSpotsAsFunctionsPass()); pm.addNestedPass( - mlir::quant::CreateInsertCustomAggregationOpsPass( - quantization_options.calibration_options())); + mlir::quant::CreateInsertCustomAggregationOpsPass(calibration_options)); pm.addPass(mlir::quant::CreateIssueIDsOfCustomAggregationOpsPass()); // NOMUTANTS -- Add tests after all passes in function below are migrated. // StableHLO Quantizer currently uses TF's calibration passes. Serialize diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h index 2ef01587da0f09..52edb9e8ff0c3c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h @@ -51,7 +51,7 @@ void AddQuantizePtqPostCalibrationPasses( // StableHLO Quantization passes that are ran if StableHLO opset is selected. void AddQuantizePtqPreCalibrationStablehloPasses( - mlir::PassManager &pm, const QuantizationOptions &quantization_options); + mlir::PassManager &pm, const CalibrationOptions &quantization_options); void AddQuantizePtqPostCalibrationStablehloPasses( mlir::PassManager &pm, From cea6d6d4592a4b00eea54e5d3a7e0cd785b3b4f6 Mon Sep 17 00:00:00 2001 From: TJ Xu Date: Mon, 20 Nov 2023 00:04:44 -0800 Subject: [PATCH 275/391] PR #6905: [NVIDIA XLA GPU] Expose collective matmul trigger in debug options Imported from GitHub PR https://github.com/openxla/xla/pull/6905 expose threshold_for_windowed_einsum_mib through debug options as the first step of enabling collective matmul for gpu. Copybara import of the project: -- cab214a996384705b0550d764bae5c22aa2b2074 by TJ Xu : Expose collective matmul trigger in debug options Merging this change closes #6905 PiperOrigin-RevId: 583925092 --- third_party/xla/xla/debug_options_flags.cc | 9 ++++++++- third_party/xla/xla/service/gpu/gpu_compiler.cc | 3 ++- .../service/spmd/stateful_rng_spmd_partitioner.h | 16 +++++++++------- .../spmd/stateful_rng_spmd_partitioner_test.cc | 10 ++++++++++ third_party/xla/xla/xla.proto | 5 ++++- 5 files changed, 33 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index d46444f7910594..b8673b2a43dd67 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -212,7 +212,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_target_config_filename(""); opts.set_xla_gpu_enable_cub_radix_sort(true); opts.set_xla_gpu_enable_cudnn_layer_norm(false); - + opts.set_xla_gpu_threshold_for_windowed_einsum_mib(100000); return opts; } @@ -1428,6 +1428,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_cub_radix_sort), debug_options->xla_gpu_enable_cub_radix_sort(), "Enable radix sort using CUB for simple shapes")); + flag_list->push_back(tsl::Flag( + "xla_gpu_threshold_for_windowed_einsum_mib", + int64_setter_for( + &DebugOptions::set_xla_gpu_threshold_for_windowed_einsum_mib), + debug_options->xla_gpu_threshold_for_windowed_einsum_mib(), + "Threshold to enable windowed einsum (collective matmul) in MB." 
+ "Default is 100000")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 0df21d7a56cfd3..ff0294cafa5baf 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -521,7 +521,8 @@ Status GpuCompiler::OptimizeHloModule(HloModule* hlo_module, /*is_spmd=*/true, /*propagate_metadata=*/false, hlo_module->config().allow_spmd_sharding_propagation_to_output()); spmd_pipeline.AddPass( - num_partitions, hlo_module->config().replica_count()); + num_partitions, hlo_module->config().replica_count(), + debug_options.xla_gpu_threshold_for_windowed_einsum_mib()); spmd_pipeline.AddPass(); TF_RETURN_IF_ERROR(spmd_pipeline.Run(hlo_module).status()); } else { diff --git a/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner.h b/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner.h index 76ef60a3c3178a..8c51f549bc087e 100644 --- a/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner.h +++ b/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner.h @@ -45,9 +45,11 @@ class StatefulRngSpmdPartitioningVisitor class StatefulRngSpmdPartitioner : public spmd::SpmdPartitioner { public: - StatefulRngSpmdPartitioner(int64_t num_partitions, int64_t num_replicas) - : spmd::SpmdPartitioner(num_partitions, num_replicas, - GetSpmdPartitionerOptions()) {} + StatefulRngSpmdPartitioner(int64_t num_partitions, int64_t num_replicas, + int64_t threshold_for_windowed_einsum_mib = 100000) + : spmd::SpmdPartitioner( + num_partitions, num_replicas, + GetSpmdPartitionerOptions(threshold_for_windowed_einsum_mib)) {} protected: std::unique_ptr CreateVisitor( @@ -64,12 +66,12 @@ class StatefulRngSpmdPartitioner : public spmd::SpmdPartitioner { const HloInstruction* hlo) override; private: - static spmd::SpmdPartitionerOptions GetSpmdPartitionerOptions() { + static spmd::SpmdPartitionerOptions GetSpmdPartitionerOptions( + int64_t threshold_for_windowed_einsum_mib) { spmd::SpmdPartitionerOptions options; options.allow_module_signature_change = true; - // Setting windowed einsum threshold to be large to disable it for GPU by - // default. 
- options.threshold_for_windowed_einsum_mib = 100000; + options.threshold_for_windowed_einsum_mib = + threshold_for_windowed_einsum_mib; return options; } }; diff --git a/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner_test.cc index 56fc6f9d0553b4..7907e533b78618 100644 --- a/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner_test.cc +++ b/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner_test.cc @@ -116,6 +116,16 @@ ENTRY entry { VerifyNoAllReduce(module.get()); } +TEST_F(StatefulRngSpmdPartitionerTest, VerifyThresholdSetCorrectly) { + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + int64_t threshold = 400; + debug_options.set_xla_gpu_threshold_for_windowed_einsum_mib(threshold); + StatefulRngSpmdPartitioner rng_spmd_partitioner( + /*num_partitions=*/2, /*num_replicas*/ 1, + debug_options.xla_gpu_threshold_for_windowed_einsum_mib()); + EXPECT_EQ(rng_spmd_partitioner.options().threshold_for_windowed_einsum_mib, + threshold); +} } // namespace } // namespace spmd } // namespace xla diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index ec3cc81380d4ff..4a9b819ef1f90a 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -672,7 +672,10 @@ message DebugOptions { // Enable radix sort using CUB. bool xla_gpu_enable_cub_radix_sort = 259; - // Next id: 265 + // Threshold to enable windowed einsum (collective matmul) in MB. + int64 xla_gpu_threshold_for_windowed_einsum_mib = 265; + + // Next id: 266 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 6a077215cff10fc5f71c1d5f3a8d1f051a1355c5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 20 Nov 2023 00:10:09 -0800 Subject: [PATCH 276/391] Integrate LLVM at llvm/llvm-project@506c47df00bb Updates LLVM usage to match [506c47df00bb](https://github.com/llvm/llvm-project/commit/506c47df00bb) PiperOrigin-RevId: 583926235 --- third_party/llvm/generated.patch | 4571 +++++++++++++++++ third_party/llvm/workspace.bzl | 4 +- .../Dialect/gml_st/cpu_tiling/matmul.mlir | 2 +- .../tests/Dialect/gml_st/lower_vectors.mlir | 3 +- 4 files changed, 4576 insertions(+), 4 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..6220f1e8bd02f2 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,4572 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/compiler-rt/test/msan/vararg_shadow.cpp b/compiler-rt/test/msan/vararg_shadow.cpp +--- a/compiler-rt/test/msan/vararg_shadow.cpp ++++ b/compiler-rt/test/msan/vararg_shadow.cpp +@@ -3,8 +3,8 @@ + // Without -fno-sanitize-memory-param-retval we can't even pass poisoned values. + // RUN: %clangxx_msan -fno-sanitize-memory-param-retval -fsanitize-memory-track-origins=0 -O3 %s -o %t + +-// The most of targets fail the test. +-// XFAIL: target={{(x86|aarch64|loongarch64|mips|powerpc64).*}} ++// Nothing works yet. 
++// XFAIL: target={{(aarch64|loongarch64|mips|powerpc64).*}} + + #include + #include +diff -ruN --strip-trailing-cr a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp +--- a/llvm/lib/Target/X86/X86ISelLowering.cpp ++++ b/llvm/lib/Target/X86/X86ISelLowering.cpp +@@ -49796,8 +49796,8 @@ + } + } + +- // If we also load/broadcast this to a wider type, then just extract the +- // lowest subvector. ++ // If we also broadcast this to a wider type, then just extract the lowest ++ // subvector. + if (Ext == ISD::NON_EXTLOAD && Subtarget.hasAVX() && Ld->isSimple() && + (RegVT.is128BitVector() || RegVT.is256BitVector())) { + SDValue Ptr = Ld->getBasePtr(); +@@ -49805,9 +49805,8 @@ + for (SDNode *User : Chain->uses()) { + if (User != N && + (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD || +- User->getOpcode() == X86ISD::VBROADCAST_LOAD || +- ISD::isNormalLoad(User)) && +- cast(User)->getChain() == Chain && ++ User->getOpcode() == X86ISD::VBROADCAST_LOAD) && ++ cast(User)->getChain() == Chain && + !User->hasAnyUseOfValue(1) && + User->getValueSizeInBits(0).getFixedValue() > + RegVT.getFixedSizeInBits()) { +@@ -49820,13 +49819,9 @@ + Extract = DAG.getBitcast(RegVT, Extract); + return DCI.CombineTo(N, Extract, SDValue(User, 1)); + } +- if ((User->getOpcode() == X86ISD::VBROADCAST_LOAD || +- (ISD::isNormalLoad(User) && +- cast(User)->getBasePtr() != Ptr)) && ++ if (User->getOpcode() == X86ISD::VBROADCAST_LOAD && + getTargetConstantFromBasePtr(Ptr)) { +- // See if we are loading a constant that has also been broadcast or +- // we are loading a constant that also matches in the lower +- // bits of a longer constant (but from a different constant pool ptr). ++ // See if we are loading a constant that has also been broadcast. + APInt Undefs, UserUndefs; + SmallVector Bits, UserBits; + if (getTargetConstantBitsFromNode(SDValue(N, 0), 8, Undefs, Bits) && +diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp ++++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +@@ -4669,16 +4669,22 @@ + + /// Compute the shadow address for a given va_arg. + Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, +- unsigned ArgOffset, unsigned ArgSize) { +- // Make sure we don't overflow __msan_va_arg_tls. +- if (ArgOffset + ArgSize > kParamTLSSize) +- return nullptr; ++ unsigned ArgOffset) { + Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); + Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); + return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0), + "_msarg_va_s"); + } + ++ /// Compute the shadow address for a given va_arg. ++ Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, ++ unsigned ArgOffset, unsigned ArgSize) { ++ // Make sure we don't overflow __msan_va_arg_tls. ++ if (ArgOffset + ArgSize > kParamTLSSize) ++ return nullptr; ++ return getShadowPtrForVAArgument(Ty, IRB, ArgOffset); ++ } ++ + /// Compute the origin address for a given va_arg. 
+ Value *getOriginPtrForVAArgument(IRBuilder<> &IRB, int ArgOffset) { + Value *Base = IRB.CreatePointerCast(MS.VAArgOriginTLS, MS.IntptrTy); +@@ -4772,6 +4778,24 @@ + unsigned FpOffset = AMD64GpEndOffset; + unsigned OverflowOffset = AMD64FpEndOffset; + const DataLayout &DL = F.getParent()->getDataLayout(); ++ ++ auto CleanUnusedTLS = [&](Value *ShadowBase, unsigned BaseOffset) { ++ // Make sure we don't overflow __msan_va_arg_tls. ++ if (OverflowOffset <= kParamTLSSize) ++ return false; // Not needed, end is not reacheed. ++ ++ // The tails of __msan_va_arg_tls is not large enough to fit full ++ // value shadow, but it will be copied to backup anyway. Make it ++ // clean. ++ if (BaseOffset < kParamTLSSize) { ++ Value *TailSize = ConstantInt::getSigned(IRB.getInt32Ty(), ++ kParamTLSSize - BaseOffset); ++ IRB.CreateMemSet(ShadowBase, ConstantInt::getNullValue(IRB.getInt8Ty()), ++ TailSize, Align(8)); ++ } ++ return true; // Incomplete ++ }; ++ + for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) { + bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams(); + bool IsByVal = CB.paramHasAttr(ArgNo, Attribute::ByVal); +@@ -4784,19 +4808,22 @@ + assert(A->getType()->isPointerTy()); + Type *RealTy = CB.getParamByValType(ArgNo); + uint64_t ArgSize = DL.getTypeAllocSize(RealTy); +- Value *ShadowBase = getShadowPtrForVAArgument( +- RealTy, IRB, OverflowOffset, alignTo(ArgSize, 8)); ++ uint64_t AlignedSize = alignTo(ArgSize, 8); ++ unsigned BaseOffset = OverflowOffset; ++ Value *ShadowBase = ++ getShadowPtrForVAArgument(RealTy, IRB, OverflowOffset); + Value *OriginBase = nullptr; + if (MS.TrackOrigins) + OriginBase = getOriginPtrForVAArgument(IRB, OverflowOffset); +- OverflowOffset += alignTo(ArgSize, 8); +- if (!ShadowBase) +- continue; ++ OverflowOffset += AlignedSize; ++ ++ if (CleanUnusedTLS(ShadowBase, BaseOffset)) ++ continue; // We have no space to copy shadow there. ++ + Value *ShadowPtr, *OriginPtr; + std::tie(ShadowPtr, OriginPtr) = + MSV.getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), kShadowTLSAlignment, + /*isStore*/ false); +- + IRB.CreateMemCpy(ShadowBase, kShadowTLSAlignment, ShadowPtr, + kShadowTLSAlignment, ArgSize); + if (MS.TrackOrigins) +@@ -4811,36 +4838,39 @@ + Value *ShadowBase, *OriginBase = nullptr; + switch (AK) { + case AK_GeneralPurpose: +- ShadowBase = +- getShadowPtrForVAArgument(A->getType(), IRB, GpOffset, 8); ++ ShadowBase = getShadowPtrForVAArgument(A->getType(), IRB, GpOffset); + if (MS.TrackOrigins) + OriginBase = getOriginPtrForVAArgument(IRB, GpOffset); + GpOffset += 8; ++ assert(GpOffset <= kParamTLSSize); + break; + case AK_FloatingPoint: +- ShadowBase = +- getShadowPtrForVAArgument(A->getType(), IRB, FpOffset, 16); ++ ShadowBase = getShadowPtrForVAArgument(A->getType(), IRB, FpOffset); + if (MS.TrackOrigins) + OriginBase = getOriginPtrForVAArgument(IRB, FpOffset); + FpOffset += 16; ++ assert(FpOffset <= kParamTLSSize); + break; + case AK_Memory: + if (IsFixed) + continue; + uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); ++ uint64_t AlignedSize = alignTo(ArgSize, 8); ++ unsigned BaseOffset = OverflowOffset; + ShadowBase = +- getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset, 8); +- if (MS.TrackOrigins) ++ getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset); ++ if (MS.TrackOrigins) { + OriginBase = getOriginPtrForVAArgument(IRB, OverflowOffset); +- OverflowOffset += alignTo(ArgSize, 8); ++ } ++ OverflowOffset += AlignedSize; ++ if (CleanUnusedTLS(ShadowBase, BaseOffset)) ++ continue; // We have no space to copy shadow there. 
+ } + // Take fixed arguments into account for GpOffset and FpOffset, + // but don't actually store shadows for them. + // TODO(glider): don't call get*PtrForVAArgument() for them. + if (IsFixed) + continue; +- if (!ShadowBase) +- continue; + Value *Shadow = MSV.getShadow(A); + IRB.CreateAlignedStore(Shadow, ShadowBase, kShadowTLSAlignment); + if (MS.TrackOrigins) { +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/broadcast-elm-cross-splat-vec.ll b/llvm/test/CodeGen/X86/broadcast-elm-cross-splat-vec.ll +--- a/llvm/test/CodeGen/X86/broadcast-elm-cross-splat-vec.ll ++++ b/llvm/test/CodeGen/X86/broadcast-elm-cross-splat-vec.ll +@@ -1400,7 +1400,7 @@ + ; AVX-64-LABEL: f4xi64_i128: + ; AVX-64: # %bb.0: + ; AVX-64-NEXT: vextractf128 $1, %ymm0, %xmm1 +-; AVX-64-NEXT: vmovdqa {{.*#+}} xmm2 = [0,1] ++; AVX-64-NEXT: vmovdqa {{.*#+}} xmm2 = [0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0] + ; AVX-64-NEXT: vpaddq %xmm2, %xmm1, %xmm1 + ; AVX-64-NEXT: vpaddq %xmm2, %xmm0, %xmm0 + ; AVX-64-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0 +@@ -1535,7 +1535,7 @@ + ; AVX-64-NEXT: vextractf128 $1, %ymm1, %xmm2 + ; AVX-64-NEXT: vmovdqa {{.*#+}} xmm3 = [2,3] + ; AVX-64-NEXT: vpaddq %xmm3, %xmm2, %xmm2 +-; AVX-64-NEXT: vmovdqa {{.*#+}} xmm4 = [0,1] ++; AVX-64-NEXT: vmovdqa {{.*#+}} xmm4 = [0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0] + ; AVX-64-NEXT: vpaddq %xmm4, %xmm1, %xmm1 + ; AVX-64-NEXT: vinsertf128 $1, %xmm2, %ymm1, %ymm1 + ; AVX-64-NEXT: vextractf128 $1, %ymm0, %xmm2 +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-7.ll b/llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-7.ll +--- a/llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-7.ll ++++ b/llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-7.ll +@@ -2157,7 +2157,7 @@ + ; AVX2-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm10[2,3,0,1] + ; AVX2-SLOW-NEXT: vpblendw {{.*#+}} ymm10 = ymm10[0,1,2],ymm11[3],ymm10[4,5,6,7,8,9,10],ymm11[11],ymm10[12,13,14,15] + ; AVX2-SLOW-NEXT: vpshufb {{.*#+}} ymm11 = ymm10[2,3,2,3,2,3,2,3,8,9,8,9,6,7,4,5,18,19,18,19,18,19,18,19,24,25,24,25,22,23,20,21] +-; AVX2-SLOW-NEXT: vmovdqa {{.*#+}} xmm10 = <255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0> ++; AVX2-SLOW-NEXT: vmovdqa {{.*#+}} xmm10 = [255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0] + ; AVX2-SLOW-NEXT: vpblendvb %ymm10, %ymm8, %ymm11, %ymm8 + ; AVX2-SLOW-NEXT: vpblendd {{.*#+}} ymm11 = ymm5[0,1],ymm6[2],ymm5[3,4,5],ymm6[6],ymm5[7] + ; AVX2-SLOW-NEXT: vextracti128 $1, %ymm11, %xmm12 +@@ -2329,7 +2329,7 @@ + ; AVX2-FAST-NEXT: vmovdqa {{.*#+}} ymm12 = <2,5,1,u,4,u,u,u> + ; AVX2-FAST-NEXT: vpermd %ymm11, %ymm12, %ymm11 + ; AVX2-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm11[2,3,2,3,2,3,2,3,8,9,0,1,6,7,8,9,18,19,18,19,18,19,18,19,24,25,16,17,22,23,24,25] +-; AVX2-FAST-NEXT: vmovdqa {{.*#+}} xmm11 = <255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0> ++; AVX2-FAST-NEXT: vmovdqa {{.*#+}} xmm11 = [255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0] + ; AVX2-FAST-NEXT: vpblendvb %ymm11, %ymm10, %ymm12, %ymm10 + ; AVX2-FAST-NEXT: vpblendd {{.*#+}} ymm12 = ymm4[0,1],ymm6[2],ymm4[3,4,5],ymm6[6],ymm4[7] + ; AVX2-FAST-NEXT: vextracti128 $1, %ymm12, %xmm13 +@@ -2496,7 +2496,7 @@ + ; AVX2-FAST-PERLANE-NEXT: vpermq {{.*#+}} ymm12 = ymm11[2,3,0,1] + ; AVX2-FAST-PERLANE-NEXT: vpblendw {{.*#+}} ymm11 = ymm11[0,1,2],ymm12[3],ymm11[4,5,6,7,8,9,10],ymm12[11],ymm11[12,13,14,15] + ; AVX2-FAST-PERLANE-NEXT: vpshufb {{.*#+}} ymm12 = ymm11[2,3,2,3,2,3,2,3,8,9,8,9,6,7,4,5,18,19,18,19,18,19,18,19,24,25,24,25,22,23,20,21] +-; AVX2-FAST-PERLANE-NEXT: vmovdqa {{.*#+}} 
xmm11 = <255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0> ++; AVX2-FAST-PERLANE-NEXT: vmovdqa {{.*#+}} xmm11 = [255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0] + ; AVX2-FAST-PERLANE-NEXT: vpblendvb %ymm11, %ymm8, %ymm12, %ymm8 + ; AVX2-FAST-PERLANE-NEXT: vpblendd {{.*#+}} ymm12 = ymm5[0,1],ymm6[2],ymm5[3,4,5],ymm6[6],ymm5[7] + ; AVX2-FAST-PERLANE-NEXT: vextracti128 $1, %ymm12, %xmm13 +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-load-i8-stride-5.ll b/llvm/test/CodeGen/X86/vector-interleaved-load-i8-stride-5.ll +--- a/llvm/test/CodeGen/X86/vector-interleaved-load-i8-stride-5.ll ++++ b/llvm/test/CodeGen/X86/vector-interleaved-load-i8-stride-5.ll +@@ -1685,7 +1685,7 @@ + ; AVX2-ONLY-NEXT: # ymm10 = mem[0,1,0,1] + ; AVX2-ONLY-NEXT: vpblendvb %ymm10, %ymm7, %ymm8, %ymm7 + ; AVX2-ONLY-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[u,u,u,u,u,u,u,u,u,u,u,u,u,1,6,11,16,21,26,31,20,25,30,19,24,29,u,u,u,u,u,u] +-; AVX2-ONLY-NEXT: vmovdqa {{.*#+}} xmm10 = <255,255,255,255,255,255,255,255,255,255,255,255,255,0,0,0> ++; AVX2-ONLY-NEXT: vmovdqa {{.*#+}} xmm10 = [255,255,255,255,255,255,255,255,255,255,255,255,255,0,0,0] + ; AVX2-ONLY-NEXT: vpblendvb %ymm10, %ymm6, %ymm7, %ymm6 + ; AVX2-ONLY-NEXT: vmovdqa 144(%rdi), %xmm7 + ; AVX2-ONLY-NEXT: vpshufb {{.*#+}} xmm11 = xmm7[u,u,u,u,u,u,u,u,u,u],zero,zero,zero,xmm7[1,6,11] +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-3.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-3.ll +--- a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-3.ll ++++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-3.ll +@@ -1238,12 +1238,13 @@ + ; AVX512F-NEXT: vshufi64x2 {{.*#+}} zmm3 = zmm3[0,1,2,3],zmm6[4,5,6,7] + ; AVX512F-NEXT: vmovdqa (%rdx), %ymm6 + ; AVX512F-NEXT: vmovdqa 32(%rdx), %ymm7 +-; AVX512F-NEXT: vmovdqa {{.*#+}} ymm9 = [128,128,10,11,128,128,128,128,12,13,128,128,128,128,14,15,128,128,128,128,16,17,128,128,128,128,18,19,128,128,128,128] +-; AVX512F-NEXT: vpshufb %ymm9, %ymm7, %ymm10 +-; AVX512F-NEXT: vmovdqa {{.*#+}} ymm11 = <5,5,u,6,6,u,7,7> +-; AVX512F-NEXT: vpermd %ymm7, %ymm11, %ymm7 +-; AVX512F-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm7, %ymm7 +-; AVX512F-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm7 ++; AVX512F-NEXT: vmovdqa {{.*#+}} ymm9 = <5,5,u,6,6,u,7,7> ++; AVX512F-NEXT: vpermd %ymm7, %ymm9, %ymm9 ++; AVX512F-NEXT: vmovdqa {{.*#+}} ymm10 = [0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0] ++; AVX512F-NEXT: vpandn %ymm9, %ymm10, %ymm9 ++; AVX512F-NEXT: vmovdqa {{.*#+}} ymm10 = [128,128,10,11,128,128,128,128,12,13,128,128,128,128,14,15,128,128,128,128,16,17,128,128,128,128,18,19,128,128,128,128] ++; AVX512F-NEXT: vpshufb %ymm10, %ymm7, %ymm7 ++; AVX512F-NEXT: vinserti64x4 $1, %ymm9, %zmm7, %zmm7 + ; AVX512F-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm7 + ; AVX512F-NEXT: vmovdqa (%rdi), %ymm3 + ; AVX512F-NEXT: vpshufb %ymm5, %ymm3, %ymm3 +@@ -1258,7 +1259,7 @@ + ; AVX512F-NEXT: vpshufb %xmm2, %xmm0, %xmm0 + ; AVX512F-NEXT: vinserti128 $1, %xmm5, %ymm0, %ymm0 + ; AVX512F-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm0[0,1,2,3],zmm3[4,5,6,7] +-; AVX512F-NEXT: vpshufb %ymm9, %ymm6, %ymm1 ++; AVX512F-NEXT: vpshufb %ymm10, %ymm6, %ymm1 + ; AVX512F-NEXT: vmovdqa {{.*#+}} ymm2 = + ; AVX512F-NEXT: vpermd %ymm6, %ymm2, %ymm2 + ; AVX512F-NEXT: vmovdqa {{.*#+}} ymm3 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] +diff -ruN --strip-trailing-cr 
a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-5.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-5.ll +--- a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-5.ll ++++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-5.ll +@@ -2831,15 +2831,15 @@ + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,1,0,0] + ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm9, %zmm2 + ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm11, %zmm2 +-; AVX512F-SLOW-NEXT: vmovdqa (%r8), %ymm10 +-; AVX512F-SLOW-NEXT: vmovdqa 32(%r8), %ymm9 ++; AVX512F-SLOW-NEXT: vmovdqa (%r8), %ymm9 ++; AVX512F-SLOW-NEXT: vmovdqa 32(%r8), %ymm10 + ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm11 = [128,128,128,128,12,13,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128] +-; AVX512F-SLOW-NEXT: vpshufb %ymm11, %ymm9, %ymm4 +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,1,1,1] ++; AVX512F-SLOW-NEXT: vpshufb %ymm11, %ymm10, %ymm4 ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,1,1,1] + ; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} ymm21 = [65535,65535,65535,0,65535,65535,65535,65535,0,65535,65535,65535,65535,0,65535,65535] +-; AVX512F-SLOW-NEXT: vpandnq %ymm9, %ymm21, %ymm9 +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm9, %zmm9 +-; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm9 ++; AVX512F-SLOW-NEXT: vpandnq %ymm10, %ymm21, %ymm10 ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm10, %zmm10 ++; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm10 + ; AVX512F-SLOW-NEXT: vmovdqa (%rdx), %xmm2 + ; AVX512F-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm4 = xmm12[0],xmm2[0],xmm12[1],xmm2[1],xmm12[2],xmm2[2],xmm12[3],xmm2[3] + ; AVX512F-SLOW-NEXT: vpshufb %xmm13, %xmm4, %xmm4 +@@ -2860,7 +2860,7 @@ + ; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm7 = [65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535] + ; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm2, %zmm7, %zmm4 + ; AVX512F-SLOW-NEXT: vpbroadcastq (%r8), %ymm2 +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm10[0,1,1,1] ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm9[0,1,1,1] + ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm2, %zmm2 + ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm2 + ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm18[0,1,2,1,4,5,6,5] +@@ -2909,15 +2909,16 @@ + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] + ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm1, %zmm5, %zmm1 + ; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm16, %zmm1 +-; AVX512F-SLOW-NEXT: vpshufb %ymm11, %ymm10, %ymm0 +-; AVX512F-SLOW-NEXT: vpbroadcastq 16(%r8), %ymm3 +-; AVX512F-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm3, %ymm3 +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm0, %zmm0 ++; AVX512F-SLOW-NEXT: vpbroadcastq 16(%r8), %ymm0 ++; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm3 = [65535,0,65535,65535,65535,65535,0,65535,65535,65535,65535,0,65535,65535,65535,65535] ++; AVX512F-SLOW-NEXT: vpandn %ymm0, %ymm3, %ymm0 ++; AVX512F-SLOW-NEXT: vpshufb %ymm11, %ymm9, %ymm3 ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm3, %zmm0 + ; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm0 + ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm0, 64(%r9) + ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm4, 256(%r9) + ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm2, (%r9) +-; AVX512F-SLOW-NEXT: vmovdqa64 %zmm9, 
192(%r9) ++; AVX512F-SLOW-NEXT: vmovdqa64 %zmm10, 192(%r9) + ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm19, 128(%r9) + ; AVX512F-SLOW-NEXT: vzeroupper + ; AVX512F-SLOW-NEXT: retq +@@ -3018,10 +3019,11 @@ + ; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm7, %zmm13, %zmm7 + ; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} zmm20 = [65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535] + ; AVX512F-FAST-NEXT: vpternlogq $226, %zmm3, %zmm20, %zmm7 +-; AVX512F-FAST-NEXT: vmovdqa64 %ymm24, %ymm3 +-; AVX512F-FAST-NEXT: vpshufb %ymm3, %ymm0, %ymm0 + ; AVX512F-FAST-NEXT: vpbroadcastq 16(%r8), %ymm3 +-; AVX512F-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm3, %ymm3 ++; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm13 = [65535,0,65535,65535,65535,65535,0,65535,65535,65535,65535,0,65535,65535,65535,65535] ++; AVX512F-FAST-NEXT: vpandn %ymm3, %ymm13, %ymm3 ++; AVX512F-FAST-NEXT: vmovdqa64 %ymm24, %ymm11 ++; AVX512F-FAST-NEXT: vpshufb %ymm11, %ymm0, %ymm0 + ; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm0, %zmm0 + ; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm0 + ; AVX512F-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm13 = [30,31,28,29,26,27,30,31,30,31,28,29,30,31,28,29,30,31,28,29,26,27,30,31,30,31,28,29,30,31,28,29] +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll +--- a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll ++++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll +@@ -2522,7 +2522,8 @@ + ; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm12, %xmm10, %xmm12 + ; AVX512F-ONLY-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax + ; AVX512F-ONLY-FAST-NEXT: vpbroadcastd 8(%rax), %ymm10 +-; AVX512F-ONLY-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm10, %ymm10 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} ymm20 = [65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535] ++; AVX512F-ONLY-FAST-NEXT: vpandnq %ymm10, %ymm20, %ymm10 + ; AVX512F-ONLY-FAST-NEXT: vpunpcklwd {{.*#+}} xmm8 = xmm8[0],xmm14[0],xmm8[1],xmm14[1],xmm8[2],xmm14[2],xmm8[3],xmm14[3] + ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm14 = xmm8[0,1,2,3,8,9,10,11,14,15,12,13,14,15,12,13] + ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm8 = xmm8[0,1,2,3,6,7,4,5,6,7,4,5,12,13,14,15] +@@ -2788,7 +2789,8 @@ + ; AVX512DQ-FAST-NEXT: vpshufb %xmm12, %xmm10, %xmm12 + ; AVX512DQ-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax + ; AVX512DQ-FAST-NEXT: vpbroadcastd 8(%rax), %ymm10 +-; AVX512DQ-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm10, %ymm10 ++; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} ymm20 = [65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535] ++; AVX512DQ-FAST-NEXT: vpandnq %ymm10, %ymm20, %ymm10 + ; AVX512DQ-FAST-NEXT: vpunpcklwd {{.*#+}} xmm8 = xmm8[0],xmm14[0],xmm8[1],xmm14[1],xmm8[2],xmm14[2],xmm8[3],xmm14[3] + ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm14 = xmm8[0,1,2,3,8,9,10,11,14,15,12,13,14,15,12,13] + ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm8 = xmm8[0,1,2,3,6,7,4,5,6,7,4,5,12,13,14,15] +@@ -5250,336 +5252,330 @@ + ; + ; AVX512F-ONLY-SLOW-LABEL: store_i16_stride7_vf32: + ; AVX512F-ONLY-SLOW: # %bb.0: +-; AVX512F-ONLY-SLOW-NEXT: subq $632, %rsp # imm = 0x278 ++; AVX512F-ONLY-SLOW-NEXT: subq $648, %rsp # imm = 0x288 + ; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rcx), %ymm1 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm9 = 
[128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128] +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm9, %ymm1, %ymm0 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm1, %ymm16 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdx), %ymm2 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm13 = +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm13, %ymm2, %ymm1 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm29 +-; AVX512F-ONLY-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rsi), %ymm1 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm12 = [128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128,128,128,128,128] +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm12, %ymm1, %ymm0 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa %ymm1, %ymm15 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdi), %ymm2 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm14 = <12,13,14,15,128,128,u,u,u,u,u,u,u,u,u,u,16,17,128,128,u,u,u,u,u,u,u,u,16,17,18,19> +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm14, %ymm2, %ymm1 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm17 +-; AVX512F-ONLY-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r9), %ymm10 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm1 = [128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128] +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm1, %ymm10, %ymm2 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r8), %ymm11 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm0 = +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm0, %ymm11, %ymm3 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm1, (%rsp) # 32-byte Spill ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm0 = [128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128] ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm0, %ymm1, %ymm2 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdx), %ymm4 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm1 = ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm1, %ymm4, %ymm3 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm4, %ymm30 + ; AVX512F-ONLY-SLOW-NEXT: vpor %ymm2, %ymm3, %ymm2 + ; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r9), %xmm3 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r8), %xmm6 +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm2 = ymm11[3,3,3,3,7,7,7,7] +-; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm4 = ymm10[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[2,2,2,3,6,6,6,7] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm4 = ymm4[0,1],ymm2[2],ymm4[3,4],ymm2[5],ymm4[6,7,8,9],ymm2[10],ymm4[11,12],ymm2[13],ymm4[14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm2 = xmm6[0],xmm3[0],xmm6[1],xmm3[1],xmm6[2],xmm3[2],xmm6[3],xmm3[3] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm7 = [4,5,4,5,4,5,6,7,16,17,16,17,16,17,17,19] +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm5 = xmm2[0,1,3,2,4,5,6,7] +-; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm5, %zmm4, %zmm7 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm7, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rcx), %xmm8 +-; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm3 = xmm6[4],xmm3[4],xmm6[5],xmm3[5],xmm6[6],xmm3[6],xmm6[7],xmm3[7] +-; 
AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,4,5,7,6] +-; AVX512F-ONLY-SLOW-NEXT: vpshufb {{.*#+}} xmm3 = xmm3[0,1,2,3,4,5,4,5,6,7,10,11,8,9,10,11] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm4 = <16,18,19,19,19,19,u,u,0,1,0,1,2,3,2,3> +-; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm2, %zmm3, %zmm4 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rsi), %ymm2 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm4 = [128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128,128,128,128,128] ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm4, %ymm2, %ymm2 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdi), %ymm5 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm9 = <12,13,14,15,128,128,u,u,u,u,u,u,u,u,u,u,16,17,128,128,u,u,u,u,u,u,u,u,16,17,18,19> ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm9, %ymm5, %ymm3 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm5, %ymm31 ++; AVX512F-ONLY-SLOW-NEXT: vpor %ymm2, %ymm3, %ymm2 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r9), %ymm15 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm2 = [128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128] ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm2, %ymm15, %ymm5 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r8), %ymm6 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm3 = ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm3, %ymm6, %ymm7 ++; AVX512F-ONLY-SLOW-NEXT: vpor %ymm5, %ymm7, %ymm5 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm5, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r9), %xmm7 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r8), %xmm8 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rsi), %ymm5 ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm4, %ymm5, %ymm4 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdi), %ymm11 ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm9, %ymm11, %ymm9 ++; AVX512F-ONLY-SLOW-NEXT: vpor %ymm4, %ymm9, %ymm4 + ; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdi), %ymm3 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rsi), %ymm6 +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm12, %ymm6, %ymm2 +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm14, %ymm3, %ymm4 +-; AVX512F-ONLY-SLOW-NEXT: vpor %ymm2, %ymm4, %ymm2 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rcx), %ymm12 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdx), %ymm7 +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm9, %ymm12, %ymm2 +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm13, %ymm7, %ymm4 +-; AVX512F-ONLY-SLOW-NEXT: vpor %ymm2, %ymm4, %ymm2 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r9), %ymm13 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r8), %ymm14 +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm1, %ymm13, %ymm1 +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm0, %ymm14, %ymm0 +-; AVX512F-ONLY-SLOW-NEXT: vpor %ymm1, %ymm0, %ymm0 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rcx), %ymm13 ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm0, %ymm13, %ymm0 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdx), %ymm14 ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm1, %ymm14, %ymm1 ++; AVX512F-ONLY-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 + ; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdx), %xmm0 +-; 
AVX512F-ONLY-SLOW-NEXT: vpbroadcastq {{.*#+}} xmm2 = [6,7,4,5,0,0,8,9,6,7,4,5,0,0,8,9] +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm2, %xmm8, %xmm1 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm2, %xmm20 +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm2 = xmm0[1,1,2,2] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm1 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] +-; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm2 = xmm0[0],xmm8[0],xmm0[1],xmm8[1],xmm0[2],xmm8[2],xmm0[3],xmm8[3] +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm2 = xmm2[0,1,3,2,4,5,6,7] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm4 = +-; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm2, %zmm1, %zmm4 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r9), %xmm1 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r8), %xmm2 +-; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm4 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3] +-; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm5 = xmm4[0,1,2,3,4,5,7,6] +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm4 = xmm4[0,1,3,2,4,5,6,7] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm9 = <0,1,0,1,0,1,1,3,16,18,19,19,19,19,u,u> +-; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm5, %zmm4, %zmm9 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r9), %ymm4 ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm2, %ymm4, %ymm1 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r8), %ymm0 ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm3, %ymm0, %ymm2 ++; AVX512F-ONLY-SLOW-NEXT: vpor %ymm1, %ymm2, %ymm1 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm6[3,3,3,3,7,7,7,7] ++; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm2 = ymm15[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm2 = ymm2[2,2,2,3,6,6,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm1[2],ymm2[3,4],ymm1[5],ymm2[6,7,8,9],ymm1[10],ymm2[11,12],ymm1[13],ymm2[14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm8[0],xmm7[0],xmm8[1],xmm7[1],xmm8[2],xmm7[2],xmm8[3],xmm7[3] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm9 = [4,5,4,5,4,5,6,7,16,17,16,17,16,17,17,19] ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm3 = xmm1[0,1,3,2,4,5,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm3, %zmm2, %zmm9 + ; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm9, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vprold $16, %ymm10, %ymm4 +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm11[1,2,2,3,5,6,6,7] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm4 = ymm5[0,1],ymm4[2],ymm5[3,4],ymm4[5],ymm5[6,7,8,9],ymm4[10],ymm5[11,12],ymm4[13],ymm5[14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm11[0,0,2,1,4,4,6,5] +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm9 = ymm10[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm9 = ymm9[0,0,0,0,4,4,4,4] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm9[0,1,2],ymm5[3],ymm9[4,5],ymm5[6],ymm9[7,8,9,10],ymm5[11],ymm9[12,13],ymm5[14],ymm9[15] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm27 = [2,2,3,3,10,9,11,10] +-; AVX512F-ONLY-SLOW-NEXT: vpermi2q %zmm4, %zmm5, %zmm27 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rcx), %xmm9 ++; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm2 = xmm8[4],xmm7[4],xmm8[5],xmm7[5],xmm8[6],xmm7[6],xmm8[7],xmm7[7] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} xmm10 = [0,1,2,3,4,5,4,5,6,7,10,11,8,9,10,11] ++; AVX512F-ONLY-SLOW-NEXT: vpshufhw 
{{.*#+}} xmm1 = xmm1[0,1,2,3,4,5,7,6] ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm10, %xmm2, %xmm2 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm3 = <16,18,19,19,19,19,u,u,0,1,0,1,2,3,2,3> ++; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm1, %zmm2, %zmm3 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdx), %xmm2 ++; AVX512F-ONLY-SLOW-NEXT: vpbroadcastq {{.*#+}} xmm7 = [6,7,4,5,0,0,8,9,6,7,4,5,0,0,8,9] ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm7, %xmm9, %xmm1 ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm3 = xmm2[1,1,2,2] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm1 = xmm3[0],xmm1[1],xmm3[2,3],xmm1[4],xmm3[5,6],xmm1[7] ++; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm3 = xmm2[0],xmm9[0],xmm2[1],xmm9[1],xmm2[2],xmm9[2],xmm2[3],xmm9[3] ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm3 = xmm3[0,1,3,2,4,5,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm8 = ++; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm3, %zmm1, %zmm8 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm8, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r9), %xmm3 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r8), %xmm12 ++; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm12[0],xmm3[0],xmm12[1],xmm3[1],xmm12[2],xmm3[2],xmm12[3],xmm3[3] ++; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm8 = xmm1[0,1,2,3,4,5,7,6] ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm1 = xmm1[0,1,3,2,4,5,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm26 = <0,1,0,1,0,1,1,3,16,18,19,19,19,19,u,u> ++; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm8, %zmm1, %zmm26 ++; AVX512F-ONLY-SLOW-NEXT: vprold $16, %ymm15, %ymm1 ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm8 = ymm6[1,2,2,3,5,6,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm1 = ymm8[0,1],ymm1[2],ymm8[3,4],ymm1[5],ymm8[6,7,8,9],ymm1[10],ymm8[11,12],ymm1[13],ymm8[14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,2,1,4,4,6,5] ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm8 = ymm15[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm8 = ymm8[0,0,0,0,4,4,4,4] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm6 = ymm8[0,1,2],ymm6[3],ymm8[4,5],ymm6[6],ymm8[7,8,9,10],ymm6[11],ymm8[12,13],ymm6[14],ymm8[15] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm28 = [2,2,3,3,10,9,11,10] ++; AVX512F-ONLY-SLOW-NEXT: vpermi2q %zmm1, %zmm6, %zmm28 + ; AVX512F-ONLY-SLOW-NEXT: movq {{[0-9]+}}(%rsp), %rax +-; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm0[4],xmm8[4],xmm0[5],xmm8[5],xmm0[6],xmm8[6],xmm0[7],xmm8[7] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm0, %xmm31 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rax), %ymm4 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm4, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 8(%rax), %ymm0 +-; AVX512F-ONLY-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm5 = [12,13,128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128] +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm5, %ymm4, %ymm4 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm5, %ymm28 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm0, %zmm0 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm2 = xmm2[4],xmm1[4],xmm2[5],xmm1[5],xmm2[6],xmm1[6],xmm2[7],xmm1[7] +-; AVX512F-ONLY-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = 
ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,u,u,24,25,26,27,u,u,26,27,26,27] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm7[2,2,2,2,6,6,6,6] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm0[0,1],ymm1[2],ymm0[3,4],ymm1[5],ymm0[6,7,8,9],ymm1[10],ymm0[11,12],ymm1[13],ymm0[14,15] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm0, %ymm30 +-; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm12[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm7[3,3,3,3,7,7,7,7] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm1 = xmm2[4],xmm9[4],xmm2[5],xmm9[5],xmm2[6],xmm9[6],xmm2[7],xmm9[7] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm1, %xmm25 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rax), %ymm8 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm8, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 8(%rax), %ymm1 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm2 = [65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535] ++; AVX512F-ONLY-SLOW-NEXT: vpandn %ymm1, %ymm2, %ymm1 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm6 = [12,13,128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128] ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm6, %ymm8, %ymm2 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm1, %zmm1 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm12 = xmm12[4],xmm3[4],xmm12[5],xmm3[5],xmm12[6],xmm3[6],xmm12[7],xmm3[7] ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm1 = ymm5[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[0,0,2,1,4,4,6,5] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm11[1,1,1,1,5,5,5,5] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm3[2],ymm1[3,4],ymm3[5],ymm1[6,7,8,9],ymm3[10],ymm1[11,12],ymm3[13],ymm1[14,15] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm1 = ymm13[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[0,0,0,0,4,4,4,4] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[0,1,1,3,4,5,5,7] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm1 = ymm3[0,1],ymm1[2],ymm3[3,4],ymm1[5],ymm3[6,7,8,9],ymm1[10],ymm3[11,12],ymm1[13],ymm3[14,15] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rax), %ymm1 ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm1[0,1,1,3,4,5,5,7] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm3 = ymm3[0,2,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm8 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] ++; AVX512F-ONLY-SLOW-NEXT: vpandn %ymm3, %ymm8, %ymm3 ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm6, %ymm1, %ymm6 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm6, %zmm2 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm3 = ymm4[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[0,0,0,0,4,4,4,4] ++; 
AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm0[0,0,2,1,4,4,6,5] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0,1,2],ymm6[3],ymm3[4,5],ymm6[6],ymm3[7,8,9,10],ymm6[11],ymm3[12,13],ymm6[14],ymm3[15] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm27 ++; AVX512F-ONLY-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm15 = [22,23,26,27,0,0,24,25,26,27,0,0,26,27,26,27,22,23,26,27,0,0,24,25,26,27,0,0,26,27,26,27] ++; AVX512F-ONLY-SLOW-NEXT: # ymm15 = mem[0,1,0,1] ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm15, %ymm13, %ymm3 ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm14[2,2,2,2,6,6,6,6] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0,1],ymm6[2],ymm3[3,4],ymm6[5],ymm3[6,7,8,9],ymm6[10],ymm3[11,12],ymm6[13],ymm3[14,15] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm22 ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[3,3,3,3,7,7,7,7] ++; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm6 = ymm13[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[2,2,2,2,6,6,6,6] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0],ymm6[1],ymm3[2,3],ymm6[4],ymm3[5,6,7,8],ymm6[9],ymm3[10,11],ymm6[12],ymm3[13,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm17 ++; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm3 = ymm5[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[2,2,2,2,6,6,6,6] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm11[2,2,2,2,6,6,6,6] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm6[0],ymm3[1],ymm6[2,3],ymm3[4],ymm6[5,6,7,8],ymm3[9],ymm6[10,11],ymm3[12],ymm6[13,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm16 ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm11[3,3,3,3,7,7,7,7] ++; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm5 = ymm5[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm5[2,2,2,2,6,6,6,6] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm13 = ymm3[0,1,2],ymm5[3],ymm3[4,5],ymm5[6],ymm3[7,8,9,10],ymm5[11],ymm3[12,13],ymm5[14],ymm3[15] ++; AVX512F-ONLY-SLOW-NEXT: vprold $16, %ymm4, %ymm3 ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm0[1,2,2,3,5,6,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm5[0,1],ymm3[2],ymm5[3,4],ymm3[5],ymm5[6,7,8,9],ymm3[10],ymm5[11,12],ymm3[13],ymm5[14,15] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm24 ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[3,3,3,3,7,7,7,7] ++; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm3 = ymm4[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[2,2,2,3,6,6,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm3[0,1],ymm0[2],ymm3[3,4],ymm0[5],ymm3[6,7,8,9],ymm0[10],ymm3[11,12],ymm0[13],ymm3[14,15] + ; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm0, %ymm23 +-; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm6[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vbroadcasti64x4 {{.*#+}} zmm0 = [6,5,0,0,7,6,0,7,6,5,0,0,7,6,0,7] ++; AVX512F-ONLY-SLOW-NEXT: # zmm0 = mem[0,1,2,3,0,1,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vpermd %zmm1, %zmm0, %zmm29 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu (%rsp), %ymm11 # 32-byte Reload ++; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm11[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] + ; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm3[2,2,2,2,6,6,6,6] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm10 = 
ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm6[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm30[3,3,3,3,7,7,7,7] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm9 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm14 # 32-byte Reload ++; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm14[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] + ; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm3[3,3,3,3,7,7,7,7] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm1[0,1,2],ymm0[3],ymm1[4,5],ymm0[6],ymm1[7,8,9,10],ymm0[11],ymm1[12,13],ymm0[14],ymm1[15] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm0, %ymm19 +-; AVX512F-ONLY-SLOW-NEXT: vprold $16, %ymm13, %ymm0 +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm14[1,2,2,3,5,6,6,7] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm1[0,1],ymm0[2],ymm1[3,4],ymm0[5],ymm1[6,7,8,9],ymm0[10],ymm1[11,12],ymm0[13],ymm1[14,15] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm13[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,3,6,6,6,7] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm14[3,3,3,3,7,7,7,7] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm0[0,1],ymm1[2],ymm0[3,4],ymm1[5],ymm0[6,7,8,9],ymm1[10],ymm0[11,12],ymm1[13],ymm0[14,15] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm16, %ymm4 +-; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm4[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm29[3,3,3,3,7,7,7,7] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm5, %ymm18 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa %ymm15, %ymm8 +-; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm15[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm17[3,3,3,3,7,7,7,7] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm1[0,1,2],ymm0[3],ymm1[4,5],ymm0[6],ymm1[7,8,9,10],ymm0[11],ymm1[12,13],ymm0[14],ymm1[15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm31[3,3,3,3,7,7,7,7] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm8 = ymm1[0,1,2],ymm0[3],ymm1[4,5],ymm0[6],ymm1[7,8,9,10],ymm0[11],ymm1[12,13],ymm0[14],ymm1[15] + ; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdi), %xmm0 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rsi), %xmm9 +-; AVX512F-ONLY-SLOW-NEXT: vprold $16, %xmm9, %xmm1 +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm15 = xmm0[1,1,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm1 = xmm15[0,1],xmm1[2],xmm15[3,4],xmm1[5],xmm15[6,7] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm0[0],xmm9[0],xmm0[1],xmm9[1],xmm0[2],xmm9[2],xmm0[3],xmm9[3] +-; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = 
xmm9[4],xmm0[4],xmm9[5],xmm0[5],xmm9[6],xmm0[6],xmm9[7],xmm0[7] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rcx), %xmm9 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm20, %xmm0 +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm0, %xmm9, %xmm0 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdx), %xmm15 +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm11 = xmm15[1,1,2,2] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm0 = xmm11[0],xmm0[1],xmm11[2,3],xmm0[4],xmm11[5,6],xmm0[7] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm0, %ymm26 +-; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm0 = xmm15[0],xmm9[0],xmm15[1],xmm9[1],xmm15[2],xmm9[2],xmm15[3],xmm9[3] +-; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm9 = xmm15[4],xmm9[4],xmm15[5],xmm9[5],xmm15[6],xmm9[6],xmm15[7],xmm9[7] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm9, %xmm25 +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[1,1,1,1,5,5,5,5] +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm6 = ymm6[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,2,1,4,4,6,5] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm6[0,1],ymm3[2],ymm6[3,4],ymm3[5],ymm6[6,7,8,9],ymm3[10],ymm6[11,12],ymm3[13],ymm6[14,15] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm3, %ymm24 +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm7[0,1,1,3,4,5,5,7] +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm6 = ymm12[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,0,0,4,4,4,4] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1],ymm6[2],ymm3[3,4],ymm6[5],ymm3[6,7,8,9],ymm6[10],ymm3[11,12],ymm6[13],ymm3[14,15] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm3, %ymm22 +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[0,0,2,1,4,4,6,5] +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm6 = ymm13[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,0,0,4,4,4,4] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm6[0,1,2],ymm3[3],ymm6[4,5],ymm3[6],ymm6[7,8,9,10],ymm3[11],ymm6[12,13],ymm3[14],ymm6[15] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm3, %ymm21 +-; AVX512F-ONLY-SLOW-NEXT: vbroadcasti64x4 {{.*#+}} zmm3 = [6,5,0,0,7,6,0,7,6,5,0,0,7,6,0,7] +-; AVX512F-ONLY-SLOW-NEXT: # zmm3 = mem[0,1,2,3,0,1,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rax), %ymm6 +-; AVX512F-ONLY-SLOW-NEXT: vpermd %zmm6, %zmm3, %zmm3 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm28, %ymm7 +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm7, %ymm6, %ymm11 +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,1,1,3,4,5,5,7] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,2,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm6, %ymm6 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm6, %zmm11, %zmm6 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} xmm7 = [0,1,2,3,4,5,4,5,6,7,10,11,8,9,10,11] +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm7, %xmm2, %xmm11 +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm7, %xmm1, %xmm1 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdi), %xmm13 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rsi), %xmm14 +-; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm2 = xmm13[0],xmm14[0],xmm13[1],xmm14[1],xmm13[2],xmm14[2],xmm13[3],xmm14[3] +-; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm7, %xmm2, %xmm2 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm20 +-; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm15 = xmm14[4],xmm13[4],xmm14[5],xmm13[5],xmm14[6],xmm13[6],xmm14[7],xmm13[7] +-; 
AVX512F-ONLY-SLOW-NEXT: vprold $16, %xmm14, %xmm14 +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm13 = xmm13[1,1,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm2 = xmm13[0,1],xmm14[2],xmm13[3,4],xmm14[5],xmm13[6,7] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm16 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa %ymm4, %ymm2 +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm14 = ymm4[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm14 = ymm14[0,0,0,0,4,4,4,4] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm29, %ymm12 +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm9 = ymm29[0,1,1,3,4,5,5,7] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm14 = ymm9[0,1],ymm14[2],ymm9[3,4],ymm14[5],ymm9[6,7,8,9],ymm14[10],ymm9[11,12],ymm14[13],ymm9[14,15] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm31, %xmm4 +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm9 = xmm4[0,2,3,3,4,5,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rsi), %xmm1 ++; AVX512F-ONLY-SLOW-NEXT: vprold $16, %xmm1, %xmm3 ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm6 = xmm0[1,1,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm2 = xmm6[0,1],xmm3[2],xmm6[3,4],xmm3[5],xmm6[6,7] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm21 ++; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm3 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3] ++; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm1[4],xmm0[4],xmm1[5],xmm0[5],xmm1[6],xmm0[6],xmm1[7],xmm0[7] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm0, %xmm20 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rcx), %xmm1 ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm7, %xmm1, %xmm0 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdx), %xmm6 ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm7 = xmm6[1,1,2,2] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm0 = xmm7[0],xmm0[1],xmm7[2,3],xmm0[4],xmm7[5,6],xmm0[7] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm0, %ymm19 ++; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm5 = xmm6[0],xmm1[0],xmm6[1],xmm1[1],xmm6[2],xmm1[2],xmm6[3],xmm1[3] ++; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm6[4],xmm1[4],xmm6[5],xmm1[5],xmm6[6],xmm1[6],xmm6[7],xmm1[7] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm0, %xmm18 ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm10, %xmm12, %xmm2 ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm10, %xmm3, %xmm4 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdi), %xmm3 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rsi), %xmm6 ++; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm7 = xmm3[0],xmm6[0],xmm3[1],xmm6[1],xmm3[2],xmm6[2],xmm3[3],xmm6[3] ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm10, %xmm7, %xmm10 ++; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm6[4],xmm3[4],xmm6[5],xmm3[5],xmm6[6],xmm3[6],xmm6[7],xmm3[7] ++; AVX512F-ONLY-SLOW-NEXT: vprold $16, %xmm6, %xmm6 ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm3 = xmm3[1,1,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm7 = xmm3[0,1],xmm6[2],xmm3[3,4],xmm6[5],xmm3[6,7] ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm3 = ymm11[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[0,0,0,0,4,4,4,4] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm30[0,1,1,3,4,5,5,7] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm6 = ymm6[0,1],ymm3[2],ymm6[3,4],ymm3[5],ymm6[6,7,8,9],ymm3[10],ymm6[11,12],ymm3[13],ymm6[14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm15, %ymm11, %ymm3 ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm30[2,2,2,2,6,6,6,6] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm15 = 
ymm3[0,1],ymm12[2],ymm3[3,4],ymm12[5],ymm3[6,7,8,9],ymm12[10],ymm3[11,12],ymm12[13],ymm3[14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm3 = ymm14[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[0,0,2,1,4,4,6,5] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm31[1,1,1,1,5,5,5,5] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1],ymm12[2],ymm3[3,4],ymm12[5],ymm3[6,7,8,9],ymm12[10],ymm3[11,12],ymm12[13],ymm3[14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm31[2,2,2,2,6,6,6,6] ++; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm14 = ymm14[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm14 = ymm14[2,2,2,2,6,6,6,6] ++; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm12 = ymm12[0],ymm14[1],ymm12[2,3],ymm14[4],ymm12[5,6,7,8],ymm14[9],ymm12[10,11],ymm14[12],ymm12[13,14,15] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm25, %xmm1 ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm14 = xmm1[0,2,3,3,4,5,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm14[0,0,2,1] ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm0 = xmm0[2,1,2,3,4,5,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,5,5,4] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm1 = ymm0[0,0,1,3] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm2 = ymm2[0,0,1,1] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm22[2,2,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm31 = ymm17[0,2,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm16[0,2,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,1,3,3] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,2,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm5 = xmm5[0,1,3,2,4,5,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm5 = xmm5[0,0,1,1] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,1,3,3] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm4[0,0,1,1] ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm31, %zmm30, %zmm30 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm13, %zmm0, %zmm0 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm13 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $184, %zmm30, %zmm13, %zmm0 ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,1,1,3] ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm9, %zmm5 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm8, %zmm4 ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $226, %zmm5, %zmm13, %zmm4 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm14, %zmm5 # 32-byte Folded Reload ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm8 # 32-byte Folded Reload ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm31 = [65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535] ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $226, %zmm5, %zmm31, %zmm8 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm1 # 32-byte Folded Reload ++; AVX512F-ONLY-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm1 = zmm2[0,1,2,3],zmm1[4,5,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vpermq $234, {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Folded Reload ++; AVX512F-ONLY-SLOW-NEXT: # ymm2 = mem[2,2,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vpermq $182, 
{{[-0-9]+}}(%r{{[sb]}}p), %ymm5 # 32-byte Folded Reload ++; AVX512F-ONLY-SLOW-NEXT: # ymm5 = mem[2,1,3,2] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm27[2,2,3,3] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm24[2,1,3,2] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm16 = ymm23[2,2,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vpshufd $254, {{[-0-9]+}}(%r{{[sb]}}p), %ymm17 # 32-byte Folded Reload ++; AVX512F-ONLY-SLOW-NEXT: # ymm17 = mem[2,3,3,3,6,7,7,7] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm22 = ymm21[0,0,2,1] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm20, %xmm9 ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm11 = xmm9[2,1,2,3,4,5,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm11 = xmm11[0,1,2,3,4,5,5,4] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm11[0,0,1,3] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm19[0,0,1,1] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm18, %xmm9 ++; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm9 = xmm9[0,2,3,3,4,5,6,7] + ; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,0,2,1] +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm15 = xmm15[2,1,2,3,4,5,6,7] +-; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm15 = xmm15[0,1,2,3,4,5,5,4] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm15 = ymm15[0,0,1,3] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm11[0,0,1,1] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm28 = ymm30[2,2,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm29 = ymm23[0,2,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,2,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm19[2,1,3,3] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm18[0,2,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm0 = xmm0[0,1,3,2,4,5,6,7] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,1,1] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm31 = ymm5[2,1,3,3] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm1 = ymm1[0,0,1,1] +-; AVX512F-ONLY-SLOW-NEXT: vpshufb {{.*#+}} ymm2 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,u,u,24,25,26,27,u,u,26,27,26,27] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm12[2,2,2,2,6,6,6,6] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm4[2],ymm2[3,4],ymm4[5],ymm2[6,7,8,9],ymm4[10],ymm2[11,12],ymm4[13],ymm2[14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm4 = ymm8[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[0,0,2,1,4,4,6,5] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm17[1,1,1,1,5,5,5,5] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm4 = ymm4[0,1],ymm5[2],ymm4[3,4],ymm5[5],ymm4[6,7,8,9],ymm5[10],ymm4[11,12],ymm5[13],ymm4[14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm17[2,2,2,2,6,6,6,6] +-; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm12 = ymm8[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm12[2,2,2,2,6,6,6,6] +-; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm5[0],ymm12[1],ymm5[2,3],ymm12[4],ymm5[5,6,7,8],ymm12[9],ymm5[10,11],ymm12[12],ymm5[13,14,15] +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm29, %zmm28, %zmm12 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm7 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm10 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $184, %zmm12, %zmm10, %zmm7 +-; AVX512F-ONLY-SLOW-NEXT: vpermq 
{{.*#+}} ymm0 = ymm0[0,1,1,3] +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm30, %zmm0 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm1, %zmm31, %zmm1 +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm10, %zmm1 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm9, %zmm0 # 32-byte Folded Reload +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm15, %zmm9 # 32-byte Folded Reload +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm10 = [65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535] +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm10, %zmm9 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm0 # 32-byte Folded Reload +-; AVX512F-ONLY-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm11[0,1,2,3],zmm0[4,5,6,7] +-; AVX512F-ONLY-SLOW-NEXT: vpermq $182, {{[-0-9]+}}(%r{{[sb]}}p), %ymm11 # 32-byte Folded Reload +-; AVX512F-ONLY-SLOW-NEXT: # ymm11 = mem[2,1,3,2] +-; AVX512F-ONLY-SLOW-NEXT: vpermq $234, {{[-0-9]+}}(%r{{[sb]}}p), %ymm12 # 32-byte Folded Reload +-; AVX512F-ONLY-SLOW-NEXT: # ymm12 = mem[2,2,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vpshufd $254, {{[-0-9]+}}(%r{{[sb]}}p), %ymm15 # 32-byte Folded Reload +-; AVX512F-ONLY-SLOW-NEXT: # ymm15 = mem[2,3,3,3,6,7,7,7] +-; AVX512F-ONLY-SLOW-NEXT: vpermq $96, {{[-0-9]+}}(%r{{[sb]}}p), %ymm17 # 32-byte Folded Reload +-; AVX512F-ONLY-SLOW-NEXT: # ymm17 = mem[0,0,2,1] +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw $230, {{[-0-9]+}}(%r{{[sb]}}p), %xmm8 # 16-byte Folded Reload +-; AVX512F-ONLY-SLOW-NEXT: # xmm8 = mem[2,1,2,3,4,5,6,7] +-; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm8 = xmm8[0,1,2,3,4,5,5,4] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm8[0,0,1,3] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm18 = ymm26[0,0,1,1] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm25, %xmm13 +-; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm13 = xmm13[0,2,3,3,4,5,6,7] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm13[0,0,2,1] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm19 = ymm24[2,2,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm28 = ymm22[2,1,3,2] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm29 = ymm21[2,2,3,3] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm20[0,0,1,1] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm31 = ymm16[0,0,2,1] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,1,3,2] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,2,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,2,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,2,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm16 # 64-byte Reload +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm16, %zmm0 +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm0 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm12, %zmm11, %zmm9 +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm9 +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm3 +-; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm15[2,1,3,2] +-; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 32(%rax), %ymm9 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm9, %zmm7, %zmm7 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), 
%zmm9, %zmm7 +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm7 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm17, %zmm1 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm13, %zmm18, %zmm8 +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $226, %zmm1, %zmm10, %zmm8 +-; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 36(%rax), %ymm1 +-; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 40(%rax), %ymm9 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm9, %zmm1, %zmm1 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm1 ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,1,1] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm7[0,0,2,1] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm6 = ymm6[2,1,3,2] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm15 = ymm15[2,2,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm3 = ymm3[2,2,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm12 = ymm12[0,2,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm18 # 64-byte Reload ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm18, %zmm1 + ; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm1 + ; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm19, %zmm8, %zmm8 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm28, %zmm9, %zmm9 +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm9 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm29, %zmm0, %zmm8 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm10 # 64-byte Reload +-; AVX512F-ONLY-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm8 = zmm10[0,1,2,3],zmm8[4,5,6,7] +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm6 +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm6 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm31, %zmm30, %zmm8 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm8 +-; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd (%rax), %ymm9 +-; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 4(%rax), %ymm10 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm10, %zmm9, %zmm9 +-; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm10 # 64-byte Reload +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm10, %zmm9 +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm9 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm14, %zmm2 +-; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm4, %zmm4 +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm4 +-; AVX512F-ONLY-SLOW-NEXT: vbroadcasti64x4 {{.*#+}} zmm2 = [0,5,4,0,0,6,5,0,0,5,4,0,0,6,5,0] +-; AVX512F-ONLY-SLOW-NEXT: # zmm2 = mem[0,1,2,3,0,1,2,3] +-; AVX512F-ONLY-SLOW-NEXT: vpermd (%rax), %zmm2, %zmm2 +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm27, %zmm2 +-; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm2 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm8, %zmm2 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 
{{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm8, %zmm5 ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm5 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm13, %zmm0, %zmm2 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload ++; AVX512F-ONLY-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm2 = zmm8[0,1,2,3],zmm2[4,5,6,7] ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm2 ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm2 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm16, %zmm14, %zmm5 ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm5 ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm29 ++; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm17[2,1,3,2] ++; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 32(%rax), %ymm5 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm0, %zmm0 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Reload ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm0 ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm0 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm11, %zmm22, %zmm4 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm9, %zmm30, %zmm5 ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $226, %zmm4, %zmm31, %zmm5 ++; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 36(%rax), %ymm4 ++; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 40(%rax), %ymm8 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm4, %zmm4 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm4 ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm4 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm5 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Reload ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm5 ++; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd (%rax), %ymm7 ++; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 4(%rax), %ymm8 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm7, %zmm7 ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm26, %zmm7 ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm7 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm15, %zmm6, %zmm5 ++; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm12, %zmm3, %zmm3 ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm3 ++; AVX512F-ONLY-SLOW-NEXT: vbroadcasti64x4 {{.*#+}} zmm5 = [0,5,4,0,0,6,5,0,0,5,4,0,0,6,5,0] ++; AVX512F-ONLY-SLOW-NEXT: # zmm5 = mem[0,1,2,3,0,1,2,3] ++; AVX512F-ONLY-SLOW-NEXT: vpermd (%rax), %zmm5, %zmm5 ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm28, %zmm5 ++; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm5 + ; AVX512F-ONLY-SLOW-NEXT: movq {{[0-9]+}}(%rsp), %rax +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm2, 128(%rax) +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm9, (%rax) +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm6, 320(%rax) +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm1, 256(%rax) +-; 
AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm7, 192(%rax) +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm0, 64(%rax) +-; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm3, 384(%rax) +-; AVX512F-ONLY-SLOW-NEXT: addq $632, %rsp # imm = 0x278 ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm5, 128(%rax) ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm7, (%rax) ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm2, 320(%rax) ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm4, 256(%rax) ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm0, 192(%rax) ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm1, 64(%rax) ++; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm29, 384(%rax) ++; AVX512F-ONLY-SLOW-NEXT: addq $648, %rsp # imm = 0x288 + ; AVX512F-ONLY-SLOW-NEXT: vzeroupper + ; AVX512F-ONLY-SLOW-NEXT: retq + ; +@@ -5613,9 +5609,9 @@ + ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r9), %ymm13 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128] + ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm13, %ymm6 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r8), %ymm14 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r8), %ymm15 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm3 = +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm3, %ymm14, %ymm7 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm3, %ymm15, %ymm7 + ; AVX512F-ONLY-FAST-NEXT: vporq %ymm6, %ymm7, %ymm25 + ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm4, %ymm10, %ymm4 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdi), %ymm6 +@@ -5629,8 +5625,9 @@ + ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm12, %ymm1 + ; AVX512F-ONLY-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r9), %ymm15 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm15, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r9), %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm1, %ymm14 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r8), %ymm4 + ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm3, %ymm4, %ymm1 + ; AVX512F-ONLY-FAST-NEXT: vporq %ymm0, %ymm1, %ymm21 +@@ -5661,11 +5658,11 @@ + ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rsi), %xmm1 + ; AVX512F-ONLY-FAST-NEXT: vpunpcklwd {{.*#+}} xmm3 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3] + ; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm3, %xmm3 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm5, %xmm30 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm5, %xmm8 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm19 = [2,1,3,3,8,8,9,9] + ; AVX512F-ONLY-FAST-NEXT: vpermi2q %zmm3, %zmm2, %zmm19 + ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm13[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,26,27,28,29,26,27,28,29,26,27,28,29,30,31,30,31] +-; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[3,3,3,3,7,7,7,7] ++; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm15[3,3,3,3,7,7,7,7] + ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm3[2],ymm2[3,4],ymm3[5],ymm2[6,7,8,9],ymm3[10],ymm2[11,12],ymm3[13],ymm2[14,15] + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm20 = [2,2,2,3,8,8,8,9] + ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r9), %xmm3 +@@ -5693,52 +5690,54 @@ + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm27 = [0,0,0,1,8,9,9,11] + ; AVX512F-ONLY-FAST-NEXT: vpermi2q %zmm0, %zmm3, %zmm27 + ; AVX512F-ONLY-FAST-NEXT: vprold $16, %ymm13, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[1,2,2,3,5,6,6,7] ++; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm15[1,2,2,3,5,6,6,7] + ; AVX512F-ONLY-FAST-NEXT: 
vpblendw {{.*#+}} ymm0 = ymm3[0,1],ymm0[2],ymm3[3,4],ymm0[5],ymm3[6,7,8,9],ymm0[10],ymm3[11,12],ymm0[13],ymm3[14,15] + ; AVX512F-ONLY-FAST-NEXT: vpbroadcastd {{.*#+}} ymm5 = [18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21] + ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm13, %ymm3 +-; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm7 = ymm14[0,0,2,1,4,4,6,5] ++; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm7 = ymm15[0,0,2,1,4,4,6,5] + ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1,2],ymm7[3],ymm3[4,5],ymm7[6],ymm3[7,8,9,10],ymm7[11],ymm3[12,13],ymm7[14],ymm3[15] + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm28 = [2,2,3,3,10,9,11,10] + ; AVX512F-ONLY-FAST-NEXT: vpermi2q %zmm0, %zmm3, %zmm28 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdi), %xmm8 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdi), %xmm15 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rsi), %xmm0 +-; AVX512F-ONLY-FAST-NEXT: vpunpckhwd {{.*#+}} xmm3 = xmm0[4],xmm8[4],xmm0[5],xmm8[5],xmm0[6],xmm8[6],xmm0[7],xmm8[7] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm3, %xmm3 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm3, %ymm17 ++; AVX512F-ONLY-FAST-NEXT: vpunpckhwd {{.*#+}} xmm3 = xmm0[4],xmm15[4],xmm0[5],xmm15[5],xmm0[6],xmm15[6],xmm0[7],xmm15[7] ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm3, %xmm9 + ; AVX512F-ONLY-FAST-NEXT: vpunpckhwd {{.*#+}} xmm1 = xmm1[4],xmm2[4],xmm1[5],xmm2[5],xmm1[6],xmm2[6],xmm1[7],xmm2[7] + ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm25, %zmm0, %zmm2 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm30, %xmm9 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm1, %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm1, %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm8, %xmm18 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm25 = <0,0,1,1,12,13,u,15> + ; AVX512F-ONLY-FAST-NEXT: vpermi2q %zmm2, %zmm1, %zmm25 + ; AVX512F-ONLY-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax + ; AVX512F-ONLY-FAST-NEXT: vpbroadcastd 8(%rax), %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm1, %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535] ++; AVX512F-ONLY-FAST-NEXT: vpandn %ymm1, %ymm2, %ymm1 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rax), %ymm3 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [12,13,128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128] + ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm3, %ymm7 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm3, %ymm16 + ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm7, %zmm1, %zmm30 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm29, %ymm1 + ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,18,19,20,21,18,19,20,21,24,25,26,27,22,23,22,23] + ; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[1,1,1,1,5,5,5,5] +-; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm6[2],ymm1[3,4],ymm6[5],ymm1[6,7,8,9],ymm6[10],ymm1[11,12],ymm6[13],ymm1[14,15] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm16 ++; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm8 = ymm1[0,1],ymm6[2],ymm1[3,4],ymm6[5],ymm1[6,7,8,9],ymm6[10],ymm1[11,12],ymm6[13],ymm1[14,15] + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm31, %ymm13 + ; AVX512F-ONLY-FAST-NEXT: vpunpcklwd {{.*#+}} ymm1 = ymm13[0,0,1,1,2,2,3,3,8,8,9,9,10,10,11,11] + ; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm6 = ymm12[0,1,1,3,4,5,5,7] + ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm7 = 
ymm6[0,1],ymm1[2],ymm6[3,4],ymm1[5],ymm6[6,7,8,9],ymm1[10],ymm6[11,12],ymm1[13],ymm6[14,15] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm15, %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm14, %ymm3 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm14, %ymm1 + ; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm5 = ymm4[0,0,2,1,4,4,6,5] + ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1,2],ymm5[3],ymm1[4,5],ymm5[6],ymm1[7,8,9,10],ymm5[11],ymm1[12,13],ymm5[14],ymm1[15] + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm29 = <0,1,u,3,10,10,11,11> + ; AVX512F-ONLY-FAST-NEXT: vpermi2q %zmm1, %zmm21, %zmm29 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = + ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rax), %ymm6 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm6, %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = +-; AVX512F-ONLY-FAST-NEXT: vpermd %ymm6, %ymm2, %ymm2 +-; AVX512F-ONLY-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm2, %ymm2 +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm1, %zmm14 ++; AVX512F-ONLY-FAST-NEXT: vpermd %ymm6, %ymm1, %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm5 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] ++; AVX512F-ONLY-FAST-NEXT: vpandn %ymm1, %ymm5, %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm6, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm14 + ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm13[14,15,12,13,u,u,u,u,u,u,u,u,u,u,u,u,30,31,28,29,u,u,u,u,30,31,28,29,u,u,u,u] + ; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm5 = ymm12[3,3,3,3,7,7,7,7] + ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm2 = ymm5[0],ymm2[1],ymm5[2,3],ymm2[4],ymm5[5,6,7,8],ymm2[9],ymm5[10,11],ymm2[12],ymm5[13,14,15] +@@ -5749,21 +5748,21 @@ + ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm13 = ymm13[0,1],ymm12[2],ymm13[3,4],ymm12[5],ymm13[6,7,8,9],ymm12[10],ymm13[11,12],ymm12[13],ymm13[14,15] + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm21 = [2,2,2,3,8,10,10,11] + ; AVX512F-ONLY-FAST-NEXT: vpermt2q %zmm2, %zmm21, %zmm13 +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,26,27,28,29,26,27,28,29,26,27,28,29,30,31,30,31] ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,26,27,28,29,26,27,28,29,26,27,28,29,30,31,30,31] + ; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm12 = ymm4[3,3,3,3,7,7,7,7] + ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm12[2],ymm2[3,4],ymm12[5],ymm2[6,7,8,9],ymm12[10],ymm2[11,12],ymm12[13],ymm2[14,15] +-; AVX512F-ONLY-FAST-NEXT: vprold $16, %ymm15, %ymm12 ++; AVX512F-ONLY-FAST-NEXT: vprold $16, %ymm3, %ymm12 + ; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[1,2,2,3,5,6,6,7] + ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm12 = ymm4[0,1],ymm12[2],ymm4[3,4],ymm12[5],ymm4[6,7,8,9],ymm12[10],ymm4[11,12],ymm12[13],ymm4[14,15] + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm31 = [2,1,3,2,10,10,10,11] + ; AVX512F-ONLY-FAST-NEXT: vpermt2q %zmm2, %zmm31, %zmm12 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm18 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm22, %zmm18, %zmm13 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm17 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] ++; 
AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm22, %zmm17, %zmm13 + ; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm13, %zmm12 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 (%rax), %zmm15 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 (%rax), %zmm3 + ; AVX512F-ONLY-FAST-NEXT: vbroadcasti64x4 {{.*#+}} zmm4 = [30,5,0,0,31,6,0,31,30,5,0,0,31,6,0,31] + ; AVX512F-ONLY-FAST-NEXT: # zmm4 = mem[0,1,2,3,0,1,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpermi2d %zmm15, %zmm6, %zmm4 ++; AVX512F-ONLY-FAST-NEXT: vpermi2d %zmm3, %zmm6, %zmm4 + ; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm12, %zmm4 + ; AVX512F-ONLY-FAST-NEXT: vpunpckhwd {{.*#+}} xmm6 = xmm11[4],xmm10[4],xmm11[5],xmm10[5],xmm11[6],xmm10[6],xmm11[7],xmm10[7] + ; AVX512F-ONLY-FAST-NEXT: vpbroadcastq {{.*#+}} xmm12 = [6,7,4,5,0,0,8,9,6,7,4,5,0,0,8,9] +@@ -5778,14 +5777,15 @@ + ; AVX512F-ONLY-FAST-NEXT: # xmm6 = xmm1[4],mem[4],xmm1[5],mem[5],xmm1[6],mem[6],xmm1[7],mem[7] + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm24, %xmm1 + ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm13 = xmm1[0,1,2,3,8,9,10,11,14,15,12,13,14,15,12,13] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm6, %xmm6 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm18, %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm6, %xmm6 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm22 = [0,1,1,3,8,8,9,9] + ; AVX512F-ONLY-FAST-NEXT: vpermt2q %zmm6, %zmm22, %zmm13 + ; AVX512F-ONLY-FAST-NEXT: vprold $16, %xmm0, %xmm6 +-; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} xmm2 = xmm8[1,1,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} xmm2 = xmm15[1,1,2,3] + ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} xmm2 = xmm2[0,1],xmm6[2],xmm2[3,4],xmm6[5],xmm2[6,7] +-; AVX512F-ONLY-FAST-NEXT: vpunpcklwd {{.*#+}} xmm0 = xmm8[0],xmm0[0],xmm8[1],xmm0[1],xmm8[2],xmm0[2],xmm8[3],xmm0[3] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm0, %xmm0 ++; AVX512F-ONLY-FAST-NEXT: vpunpcklwd {{.*#+}} xmm0 = xmm15[0],xmm0[0],xmm15[1],xmm0[1],xmm15[2],xmm0[2],xmm15[3],xmm0[3] ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm0, %xmm0 + ; AVX512F-ONLY-FAST-NEXT: vpermt2q %zmm2, %zmm11, %zmm0 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rcx), %xmm2 + ; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm12, %xmm2, %xmm6 +@@ -5814,8 +5814,8 @@ + ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm5 = ymm10[0,1],ymm5[2],ymm10[3,4],ymm5[5],ymm10[6,7,8,9],ymm5[10],ymm10[11,12],ymm5[13],ymm10[14,15] + ; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} xmm10 = xmm12[0,2,3,3,4,5,6,7] + ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,2,1] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm17[0,0,1,3] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm16[2,2,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,0,1,3] ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,2,2,3] + ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm7[2,1,3,2] + ; AVX512F-ONLY-FAST-NEXT: vpermt2q %zmm0, %zmm31, %zmm5 + ; AVX512F-ONLY-FAST-NEXT: vpbroadcastd (%rax), %ymm0 +@@ -5834,8 +5834,8 @@ + ; AVX512F-ONLY-FAST-NEXT: vpermt2q %zmm2, %zmm21, %zmm12 + ; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm12 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = <6,u,u,u,7,u,u,7> +-; AVX512F-ONLY-FAST-NEXT: vpermd %ymm3, %ymm2, %ymm2 +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm15, %zmm3 ++; AVX512F-ONLY-FAST-NEXT: vpermd %ymm16, %ymm2, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm16, %zmm3, %zmm3 + ; AVX512F-ONLY-FAST-NEXT: vbroadcasti64x4 {{.*#+}} zmm5 = 
[0,13,4,0,0,14,5,0,0,13,4,0,0,14,5,0] + ; AVX512F-ONLY-FAST-NEXT: # zmm5 = mem[0,1,2,3,0,1,2,3] + ; AVX512F-ONLY-FAST-NEXT: vpermd %zmm3, %zmm5, %zmm3 +@@ -5844,7 +5844,7 @@ + ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm5 # 32-byte Folded Reload + ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm9, %zmm9 # 32-byte Folded Reload + ; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm5, %zmm11, %zmm9 +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm23, %zmm18, %zmm19 ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm23, %zmm17, %zmm19 + ; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm25, %zmm30 + ; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm30 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Reload +@@ -5872,336 +5872,330 @@ + ; + ; AVX512DQ-SLOW-LABEL: store_i16_stride7_vf32: + ; AVX512DQ-SLOW: # %bb.0: +-; AVX512DQ-SLOW-NEXT: subq $632, %rsp # imm = 0x278 ++; AVX512DQ-SLOW-NEXT: subq $648, %rsp # imm = 0x288 + ; AVX512DQ-SLOW-NEXT: vmovdqa (%rcx), %ymm1 +-; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm9 = [128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128] +-; AVX512DQ-SLOW-NEXT: vpshufb %ymm9, %ymm1, %ymm0 +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm1, %ymm16 +-; AVX512DQ-SLOW-NEXT: vmovdqa (%rdx), %ymm2 +-; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm13 = +-; AVX512DQ-SLOW-NEXT: vpshufb %ymm13, %ymm2, %ymm1 +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm29 +-; AVX512DQ-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 +-; AVX512DQ-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-SLOW-NEXT: vmovdqa (%rsi), %ymm1 +-; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm12 = [128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128,128,128,128,128] +-; AVX512DQ-SLOW-NEXT: vpshufb %ymm12, %ymm1, %ymm0 +-; AVX512DQ-SLOW-NEXT: vmovdqa %ymm1, %ymm15 +-; AVX512DQ-SLOW-NEXT: vmovdqa (%rdi), %ymm2 +-; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm14 = <12,13,14,15,128,128,u,u,u,u,u,u,u,u,u,u,16,17,128,128,u,u,u,u,u,u,u,u,16,17,18,19> +-; AVX512DQ-SLOW-NEXT: vpshufb %ymm14, %ymm2, %ymm1 +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm17 +-; AVX512DQ-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 +-; AVX512DQ-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-SLOW-NEXT: vmovdqa (%r9), %ymm10 +-; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm1 = [128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128] +-; AVX512DQ-SLOW-NEXT: vpshufb %ymm1, %ymm10, %ymm2 +-; AVX512DQ-SLOW-NEXT: vmovdqa (%r8), %ymm11 +-; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm0 = +-; AVX512DQ-SLOW-NEXT: vpshufb %ymm0, %ymm11, %ymm3 ++; AVX512DQ-SLOW-NEXT: vmovdqu %ymm1, (%rsp) # 32-byte Spill ++; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm0 = [128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128] ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm0, %ymm1, %ymm2 ++; AVX512DQ-SLOW-NEXT: vmovdqa (%rdx), %ymm4 ++; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm1 = ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm1, %ymm4, %ymm3 ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm4, %ymm30 + ; AVX512DQ-SLOW-NEXT: vpor %ymm2, %ymm3, %ymm2 + ; AVX512DQ-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r9), %xmm3 +-; AVX512DQ-SLOW-NEXT: 
vmovdqa 32(%r8), %xmm6 +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm2 = ymm11[3,3,3,3,7,7,7,7] +-; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm4 = ymm10[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[2,2,2,3,6,6,6,7] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm4 = ymm4[0,1],ymm2[2],ymm4[3,4],ymm2[5],ymm4[6,7,8,9],ymm2[10],ymm4[11,12],ymm2[13],ymm4[14,15] +-; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm2 = xmm6[0],xmm3[0],xmm6[1],xmm3[1],xmm6[2],xmm3[2],xmm6[3],xmm3[3] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm7 = [4,5,4,5,4,5,6,7,16,17,16,17,16,17,17,19] +-; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm5 = xmm2[0,1,3,2,4,5,6,7] +-; AVX512DQ-SLOW-NEXT: vpermi2d %zmm5, %zmm4, %zmm7 +-; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm7, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-SLOW-NEXT: vmovdqa (%rcx), %xmm8 +-; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm3 = xmm6[4],xmm3[4],xmm6[5],xmm3[5],xmm6[6],xmm3[6],xmm6[7],xmm3[7] +-; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,4,5,7,6] +-; AVX512DQ-SLOW-NEXT: vpshufb {{.*#+}} xmm3 = xmm3[0,1,2,3,4,5,4,5,6,7,10,11,8,9,10,11] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm4 = <16,18,19,19,19,19,u,u,0,1,0,1,2,3,2,3> +-; AVX512DQ-SLOW-NEXT: vpermi2d %zmm2, %zmm3, %zmm4 ++; AVX512DQ-SLOW-NEXT: vmovdqa (%rsi), %ymm2 ++; AVX512DQ-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm4 = [128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128,128,128,128,128] ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm4, %ymm2, %ymm2 ++; AVX512DQ-SLOW-NEXT: vmovdqa (%rdi), %ymm5 ++; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm9 = <12,13,14,15,128,128,u,u,u,u,u,u,u,u,u,u,16,17,128,128,u,u,u,u,u,u,u,u,16,17,18,19> ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm9, %ymm5, %ymm3 ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm5, %ymm31 ++; AVX512DQ-SLOW-NEXT: vpor %ymm2, %ymm3, %ymm2 ++; AVX512DQ-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-SLOW-NEXT: vmovdqa (%r9), %ymm15 ++; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm2 = [128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128] ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm2, %ymm15, %ymm5 ++; AVX512DQ-SLOW-NEXT: vmovdqa (%r8), %ymm6 ++; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm3 = ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm3, %ymm6, %ymm7 ++; AVX512DQ-SLOW-NEXT: vpor %ymm5, %ymm7, %ymm5 ++; AVX512DQ-SLOW-NEXT: vmovdqu %ymm5, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r9), %xmm7 ++; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r8), %xmm8 ++; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rsi), %ymm5 ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm4, %ymm5, %ymm4 ++; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdi), %ymm11 ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm9, %ymm11, %ymm9 ++; AVX512DQ-SLOW-NEXT: vpor %ymm4, %ymm9, %ymm4 + ; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdi), %ymm3 +-; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rsi), %ymm6 +-; AVX512DQ-SLOW-NEXT: vpshufb %ymm12, %ymm6, %ymm2 +-; AVX512DQ-SLOW-NEXT: vpshufb %ymm14, %ymm3, %ymm4 +-; AVX512DQ-SLOW-NEXT: vpor %ymm2, %ymm4, %ymm2 +-; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rcx), %ymm12 +-; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdx), %ymm7 +-; AVX512DQ-SLOW-NEXT: vpshufb %ymm9, %ymm12, %ymm2 +-; AVX512DQ-SLOW-NEXT: vpshufb %ymm13, %ymm7, %ymm4 +-; 
AVX512DQ-SLOW-NEXT: vpor %ymm2, %ymm4, %ymm2 +-; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r9), %ymm13 +-; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r8), %ymm14 +-; AVX512DQ-SLOW-NEXT: vpshufb %ymm1, %ymm13, %ymm1 +-; AVX512DQ-SLOW-NEXT: vpshufb %ymm0, %ymm14, %ymm0 +-; AVX512DQ-SLOW-NEXT: vpor %ymm1, %ymm0, %ymm0 ++; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rcx), %ymm13 ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm0, %ymm13, %ymm0 ++; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdx), %ymm14 ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm1, %ymm14, %ymm1 ++; AVX512DQ-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 + ; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-SLOW-NEXT: vmovdqa (%rdx), %xmm0 +-; AVX512DQ-SLOW-NEXT: vpbroadcastq {{.*#+}} xmm2 = [6,7,4,5,0,0,8,9,6,7,4,5,0,0,8,9] +-; AVX512DQ-SLOW-NEXT: vpshufb %xmm2, %xmm8, %xmm1 +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm2, %xmm20 +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm2 = xmm0[1,1,2,2] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm1 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] +-; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm2 = xmm0[0],xmm8[0],xmm0[1],xmm8[1],xmm0[2],xmm8[2],xmm0[3],xmm8[3] +-; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm2 = xmm2[0,1,3,2,4,5,6,7] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm4 = +-; AVX512DQ-SLOW-NEXT: vpermi2d %zmm2, %zmm1, %zmm4 +-; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-SLOW-NEXT: vmovdqa (%r9), %xmm1 +-; AVX512DQ-SLOW-NEXT: vmovdqa (%r8), %xmm2 +-; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm4 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3] +-; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm5 = xmm4[0,1,2,3,4,5,7,6] +-; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm4 = xmm4[0,1,3,2,4,5,6,7] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm9 = <0,1,0,1,0,1,1,3,16,18,19,19,19,19,u,u> +-; AVX512DQ-SLOW-NEXT: vpermi2d %zmm5, %zmm4, %zmm9 ++; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r9), %ymm4 ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm2, %ymm4, %ymm1 ++; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r8), %ymm0 ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm3, %ymm0, %ymm2 ++; AVX512DQ-SLOW-NEXT: vpor %ymm1, %ymm2, %ymm1 ++; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm6[3,3,3,3,7,7,7,7] ++; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm2 = ymm15[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm2 = ymm2[2,2,2,3,6,6,6,7] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm1[2],ymm2[3,4],ymm1[5],ymm2[6,7,8,9],ymm1[10],ymm2[11,12],ymm1[13],ymm2[14,15] ++; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm8[0],xmm7[0],xmm8[1],xmm7[1],xmm8[2],xmm7[2],xmm8[3],xmm7[3] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm9 = [4,5,4,5,4,5,6,7,16,17,16,17,16,17,17,19] ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm3 = xmm1[0,1,3,2,4,5,6,7] ++; AVX512DQ-SLOW-NEXT: vpermi2d %zmm3, %zmm2, %zmm9 + ; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm9, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-SLOW-NEXT: vprold $16, %ymm10, %ymm4 +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm11[1,2,2,3,5,6,6,7] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm4 = ymm5[0,1],ymm4[2],ymm5[3,4],ymm4[5],ymm5[6,7,8,9],ymm4[10],ymm5[11,12],ymm4[13],ymm5[14,15] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm11[0,0,2,1,4,4,6,5] +-; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm9 = ymm10[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] +-; AVX512DQ-SLOW-NEXT: vpshufd 
{{.*#+}} ymm9 = ymm9[0,0,0,0,4,4,4,4] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm9[0,1,2],ymm5[3],ymm9[4,5],ymm5[6],ymm9[7,8,9,10],ymm5[11],ymm9[12,13],ymm5[14],ymm9[15] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm27 = [2,2,3,3,10,9,11,10] +-; AVX512DQ-SLOW-NEXT: vpermi2q %zmm4, %zmm5, %zmm27 ++; AVX512DQ-SLOW-NEXT: vmovdqa (%rcx), %xmm9 ++; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm2 = xmm8[4],xmm7[4],xmm8[5],xmm7[5],xmm8[6],xmm7[6],xmm8[7],xmm7[7] ++; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} xmm10 = [0,1,2,3,4,5,4,5,6,7,10,11,8,9,10,11] ++; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm1 = xmm1[0,1,2,3,4,5,7,6] ++; AVX512DQ-SLOW-NEXT: vpshufb %xmm10, %xmm2, %xmm2 ++; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm3 = <16,18,19,19,19,19,u,u,0,1,0,1,2,3,2,3> ++; AVX512DQ-SLOW-NEXT: vpermi2d %zmm1, %zmm2, %zmm3 ++; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512DQ-SLOW-NEXT: vmovdqa (%rdx), %xmm2 ++; AVX512DQ-SLOW-NEXT: vpbroadcastq {{.*#+}} xmm7 = [6,7,4,5,0,0,8,9,6,7,4,5,0,0,8,9] ++; AVX512DQ-SLOW-NEXT: vpshufb %xmm7, %xmm9, %xmm1 ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm3 = xmm2[1,1,2,2] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm1 = xmm3[0],xmm1[1],xmm3[2,3],xmm1[4],xmm3[5,6],xmm1[7] ++; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm3 = xmm2[0],xmm9[0],xmm2[1],xmm9[1],xmm2[2],xmm9[2],xmm2[3],xmm9[3] ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm3 = xmm3[0,1,3,2,4,5,6,7] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm8 = ++; AVX512DQ-SLOW-NEXT: vpermi2d %zmm3, %zmm1, %zmm8 ++; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm8, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512DQ-SLOW-NEXT: vmovdqa (%r9), %xmm3 ++; AVX512DQ-SLOW-NEXT: vmovdqa (%r8), %xmm12 ++; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm12[0],xmm3[0],xmm12[1],xmm3[1],xmm12[2],xmm3[2],xmm12[3],xmm3[3] ++; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm8 = xmm1[0,1,2,3,4,5,7,6] ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm1 = xmm1[0,1,3,2,4,5,6,7] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm26 = <0,1,0,1,0,1,1,3,16,18,19,19,19,19,u,u> ++; AVX512DQ-SLOW-NEXT: vpermi2d %zmm8, %zmm1, %zmm26 ++; AVX512DQ-SLOW-NEXT: vprold $16, %ymm15, %ymm1 ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm8 = ymm6[1,2,2,3,5,6,6,7] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm1 = ymm8[0,1],ymm1[2],ymm8[3,4],ymm1[5],ymm8[6,7,8,9],ymm1[10],ymm8[11,12],ymm1[13],ymm8[14,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,2,1,4,4,6,5] ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm8 = ymm15[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm8 = ymm8[0,0,0,0,4,4,4,4] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm6 = ymm8[0,1,2],ymm6[3],ymm8[4,5],ymm6[6],ymm8[7,8,9,10],ymm6[11],ymm8[12,13],ymm6[14],ymm8[15] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm28 = [2,2,3,3,10,9,11,10] ++; AVX512DQ-SLOW-NEXT: vpermi2q %zmm1, %zmm6, %zmm28 + ; AVX512DQ-SLOW-NEXT: movq {{[0-9]+}}(%rsp), %rax +-; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm0[4],xmm8[4],xmm0[5],xmm8[5],xmm0[6],xmm8[6],xmm0[7],xmm8[7] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm0, %xmm31 +-; AVX512DQ-SLOW-NEXT: vmovdqa (%rax), %ymm4 +-; AVX512DQ-SLOW-NEXT: vmovdqu %ymm4, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-SLOW-NEXT: vpbroadcastd 8(%rax), %ymm0 +-; AVX512DQ-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 +-; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm5 = [12,13,128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128] +-; AVX512DQ-SLOW-NEXT: vpshufb 
%ymm5, %ymm4, %ymm4 +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm5, %ymm28 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm0, %zmm0 +-; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm2 = xmm2[4],xmm1[4],xmm2[5],xmm1[5],xmm2[6],xmm1[6],xmm2[7],xmm1[7] +-; AVX512DQ-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,u,u,24,25,26,27,u,u,26,27,26,27] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm7[2,2,2,2,6,6,6,6] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm0[0,1],ymm1[2],ymm0[3,4],ymm1[5],ymm0[6,7,8,9],ymm1[10],ymm0[11,12],ymm1[13],ymm0[14,15] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm0, %ymm30 +-; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm12[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm7[3,3,3,3,7,7,7,7] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] ++; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm1 = xmm2[4],xmm9[4],xmm2[5],xmm9[5],xmm2[6],xmm9[6],xmm2[7],xmm9[7] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm1, %xmm25 ++; AVX512DQ-SLOW-NEXT: vmovdqa (%rax), %ymm8 ++; AVX512DQ-SLOW-NEXT: vmovdqu %ymm8, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-SLOW-NEXT: vpbroadcastd 8(%rax), %ymm1 ++; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm2 = [65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535] ++; AVX512DQ-SLOW-NEXT: vpandn %ymm1, %ymm2, %ymm1 ++; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm6 = [12,13,128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128] ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm6, %ymm8, %ymm2 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm1, %zmm1 ++; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm12 = xmm12[4],xmm3[4],xmm12[5],xmm3[5],xmm12[6],xmm3[6],xmm12[7],xmm3[7] ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm1 = ymm5[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[0,0,2,1,4,4,6,5] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm11[1,1,1,1,5,5,5,5] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm3[2],ymm1[3,4],ymm3[5],ymm1[6,7,8,9],ymm3[10],ymm1[11,12],ymm3[13],ymm1[14,15] ++; AVX512DQ-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm1 = ymm13[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[0,0,0,0,4,4,4,4] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[0,1,1,3,4,5,5,7] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm1 = ymm3[0,1],ymm1[2],ymm3[3,4],ymm1[5],ymm3[6,7,8,9],ymm1[10],ymm3[11,12],ymm1[13],ymm3[14,15] ++; AVX512DQ-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rax), %ymm1 ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm1[0,1,1,3,4,5,5,7] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm3 = ymm3[0,2,2,3] ++; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm8 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] ++; AVX512DQ-SLOW-NEXT: vpandn %ymm3, %ymm8, %ymm3 ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm6, %ymm1, %ymm6 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm6, %zmm2 ++; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm2, 
{{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm3 = ymm4[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[0,0,0,0,4,4,4,4] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm0[0,0,2,1,4,4,6,5] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0,1,2],ymm6[3],ymm3[4,5],ymm6[6],ymm3[7,8,9,10],ymm6[11],ymm3[12,13],ymm6[14],ymm3[15] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm27 ++; AVX512DQ-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm15 = [22,23,26,27,0,0,24,25,26,27,0,0,26,27,26,27,22,23,26,27,0,0,24,25,26,27,0,0,26,27,26,27] ++; AVX512DQ-SLOW-NEXT: # ymm15 = mem[0,1,0,1] ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm15, %ymm13, %ymm3 ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm14[2,2,2,2,6,6,6,6] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0,1],ymm6[2],ymm3[3,4],ymm6[5],ymm3[6,7,8,9],ymm6[10],ymm3[11,12],ymm6[13],ymm3[14,15] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm22 ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[3,3,3,3,7,7,7,7] ++; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm6 = ymm13[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[2,2,2,2,6,6,6,6] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0],ymm6[1],ymm3[2,3],ymm6[4],ymm3[5,6,7,8],ymm6[9],ymm3[10,11],ymm6[12],ymm3[13,14,15] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm17 ++; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm3 = ymm5[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[2,2,2,2,6,6,6,6] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm11[2,2,2,2,6,6,6,6] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm6[0],ymm3[1],ymm6[2,3],ymm3[4],ymm6[5,6,7,8],ymm3[9],ymm6[10,11],ymm3[12],ymm6[13,14,15] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm16 ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm11[3,3,3,3,7,7,7,7] ++; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm5 = ymm5[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm5[2,2,2,2,6,6,6,6] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm13 = ymm3[0,1,2],ymm5[3],ymm3[4,5],ymm5[6],ymm3[7,8,9,10],ymm5[11],ymm3[12,13],ymm5[14],ymm3[15] ++; AVX512DQ-SLOW-NEXT: vprold $16, %ymm4, %ymm3 ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm0[1,2,2,3,5,6,6,7] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm5[0,1],ymm3[2],ymm5[3,4],ymm3[5],ymm5[6,7,8,9],ymm3[10],ymm5[11,12],ymm3[13],ymm5[14,15] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm24 ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[3,3,3,3,7,7,7,7] ++; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm3 = ymm4[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[2,2,2,3,6,6,6,7] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm3[0,1],ymm0[2],ymm3[3,4],ymm0[5],ymm3[6,7,8,9],ymm0[10],ymm3[11,12],ymm0[13],ymm3[14,15] + ; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm0, %ymm23 +-; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm6[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm3[2,2,2,2,6,6,6,6] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm10 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] +-; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm6[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm3[3,3,3,3,7,7,7,7] +-; 
AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm1[0,1,2],ymm0[3],ymm1[4,5],ymm0[6],ymm1[7,8,9,10],ymm0[11],ymm1[12,13],ymm0[14],ymm1[15] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm0, %ymm19 +-; AVX512DQ-SLOW-NEXT: vprold $16, %ymm13, %ymm0 +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm14[1,2,2,3,5,6,6,7] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm1[0,1],ymm0[2],ymm1[3,4],ymm0[5],ymm1[6,7,8,9],ymm0[10],ymm1[11,12],ymm0[13],ymm1[14,15] +-; AVX512DQ-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm13[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,3,6,6,6,7] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm14[3,3,3,3,7,7,7,7] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm0[0,1],ymm1[2],ymm0[3,4],ymm1[5],ymm0[6,7,8,9],ymm1[10],ymm0[11,12],ymm1[13],ymm0[14,15] +-; AVX512DQ-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm16, %ymm4 +-; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm4[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] ++; AVX512DQ-SLOW-NEXT: vbroadcasti32x8 {{.*#+}} zmm0 = [6,5,0,0,7,6,0,7,6,5,0,0,7,6,0,7] ++; AVX512DQ-SLOW-NEXT: # zmm0 = mem[0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7] ++; AVX512DQ-SLOW-NEXT: vpermd %zmm1, %zmm0, %zmm29 ++; AVX512DQ-SLOW-NEXT: vmovdqu (%rsp), %ymm11 # 32-byte Reload ++; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm11[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] + ; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm29[3,3,3,3,7,7,7,7] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm5, %ymm18 +-; AVX512DQ-SLOW-NEXT: vmovdqa %ymm15, %ymm8 +-; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm15[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm30[3,3,3,3,7,7,7,7] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm9 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] ++; AVX512DQ-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm14 # 32-byte Reload ++; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm14[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] + ; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm17[3,3,3,3,7,7,7,7] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm1[0,1,2],ymm0[3],ymm1[4,5],ymm0[6],ymm1[7,8,9,10],ymm0[11],ymm1[12,13],ymm0[14],ymm1[15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm31[3,3,3,3,7,7,7,7] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm8 = ymm1[0,1,2],ymm0[3],ymm1[4,5],ymm0[6],ymm1[7,8,9,10],ymm0[11],ymm1[12,13],ymm0[14],ymm1[15] + ; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdi), %xmm0 +-; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rsi), %xmm9 +-; AVX512DQ-SLOW-NEXT: vprold $16, %xmm9, %xmm1 +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm15 = xmm0[1,1,2,3] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm1 = xmm15[0,1],xmm1[2],xmm15[3,4],xmm1[5],xmm15[6,7] +-; AVX512DQ-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm0[0],xmm9[0],xmm0[1],xmm9[1],xmm0[2],xmm9[2],xmm0[3],xmm9[3] +-; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm9[4],xmm0[4],xmm9[5],xmm0[5],xmm9[6],xmm0[6],xmm9[7],xmm0[7] +-; AVX512DQ-SLOW-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill +-; 
AVX512DQ-SLOW-NEXT: vmovdqa 32(%rcx), %xmm9 +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm20, %xmm0 +-; AVX512DQ-SLOW-NEXT: vpshufb %xmm0, %xmm9, %xmm0 +-; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdx), %xmm15 +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm11 = xmm15[1,1,2,2] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm0 = xmm11[0],xmm0[1],xmm11[2,3],xmm0[4],xmm11[5,6],xmm0[7] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm0, %ymm26 +-; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm0 = xmm15[0],xmm9[0],xmm15[1],xmm9[1],xmm15[2],xmm9[2],xmm15[3],xmm9[3] +-; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm9 = xmm15[4],xmm9[4],xmm15[5],xmm9[5],xmm15[6],xmm9[6],xmm15[7],xmm9[7] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm9, %xmm25 +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[1,1,1,1,5,5,5,5] +-; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm6 = ymm6[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,2,1,4,4,6,5] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm6[0,1],ymm3[2],ymm6[3,4],ymm3[5],ymm6[6,7,8,9],ymm3[10],ymm6[11,12],ymm3[13],ymm6[14,15] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm3, %ymm24 +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm7[0,1,1,3,4,5,5,7] +-; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm6 = ymm12[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,0,0,4,4,4,4] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1],ymm6[2],ymm3[3,4],ymm6[5],ymm3[6,7,8,9],ymm6[10],ymm3[11,12],ymm6[13],ymm3[14,15] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm3, %ymm22 +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[0,0,2,1,4,4,6,5] +-; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm6 = ymm13[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,0,0,4,4,4,4] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm6[0,1,2],ymm3[3],ymm6[4,5],ymm3[6],ymm6[7,8,9,10],ymm3[11],ymm6[12,13],ymm3[14],ymm6[15] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm3, %ymm21 +-; AVX512DQ-SLOW-NEXT: vbroadcasti32x8 {{.*#+}} zmm3 = [6,5,0,0,7,6,0,7,6,5,0,0,7,6,0,7] +-; AVX512DQ-SLOW-NEXT: # zmm3 = mem[0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7] +-; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rax), %ymm6 +-; AVX512DQ-SLOW-NEXT: vpermd %zmm6, %zmm3, %zmm3 +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm28, %ymm7 +-; AVX512DQ-SLOW-NEXT: vpshufb %ymm7, %ymm6, %ymm11 +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,1,1,3,4,5,5,7] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,2,2,3] +-; AVX512DQ-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm6, %ymm6 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm6, %zmm11, %zmm6 +-; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} xmm7 = [0,1,2,3,4,5,4,5,6,7,10,11,8,9,10,11] +-; AVX512DQ-SLOW-NEXT: vpshufb %xmm7, %xmm2, %xmm11 +-; AVX512DQ-SLOW-NEXT: vpshufb %xmm7, %xmm1, %xmm1 +-; AVX512DQ-SLOW-NEXT: vmovdqa (%rdi), %xmm13 +-; AVX512DQ-SLOW-NEXT: vmovdqa (%rsi), %xmm14 +-; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm2 = xmm13[0],xmm14[0],xmm13[1],xmm14[1],xmm13[2],xmm14[2],xmm13[3],xmm14[3] +-; AVX512DQ-SLOW-NEXT: vpshufb %xmm7, %xmm2, %xmm2 +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm20 +-; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm15 = xmm14[4],xmm13[4],xmm14[5],xmm13[5],xmm14[6],xmm13[6],xmm14[7],xmm13[7] +-; AVX512DQ-SLOW-NEXT: vprold $16, %xmm14, %xmm14 +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm13 = xmm13[1,1,2,3] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm2 = xmm13[0,1],xmm14[2],xmm13[3,4],xmm14[5],xmm13[6,7] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm16 +-; AVX512DQ-SLOW-NEXT: vmovdqa %ymm4, %ymm2 +-; 
AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm14 = ymm4[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm14 = ymm14[0,0,0,0,4,4,4,4] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm29, %ymm12 +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm9 = ymm29[0,1,1,3,4,5,5,7] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm14 = ymm9[0,1],ymm14[2],ymm9[3,4],ymm14[5],ymm9[6,7,8,9],ymm14[10],ymm9[11,12],ymm14[13],ymm9[14,15] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm31, %xmm4 +-; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm9 = xmm4[0,2,3,3,4,5,6,7] ++; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rsi), %xmm1 ++; AVX512DQ-SLOW-NEXT: vprold $16, %xmm1, %xmm3 ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm6 = xmm0[1,1,2,3] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm2 = xmm6[0,1],xmm3[2],xmm6[3,4],xmm3[5],xmm6[6,7] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm21 ++; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm3 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3] ++; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm1[4],xmm0[4],xmm1[5],xmm0[5],xmm1[6],xmm0[6],xmm1[7],xmm0[7] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm0, %xmm20 ++; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rcx), %xmm1 ++; AVX512DQ-SLOW-NEXT: vpshufb %xmm7, %xmm1, %xmm0 ++; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdx), %xmm6 ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm7 = xmm6[1,1,2,2] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm0 = xmm7[0],xmm0[1],xmm7[2,3],xmm0[4],xmm7[5,6],xmm0[7] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm0, %ymm19 ++; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm5 = xmm6[0],xmm1[0],xmm6[1],xmm1[1],xmm6[2],xmm1[2],xmm6[3],xmm1[3] ++; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm6[4],xmm1[4],xmm6[5],xmm1[5],xmm6[6],xmm1[6],xmm6[7],xmm1[7] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm0, %xmm18 ++; AVX512DQ-SLOW-NEXT: vpshufb %xmm10, %xmm12, %xmm2 ++; AVX512DQ-SLOW-NEXT: vpshufb %xmm10, %xmm3, %xmm4 ++; AVX512DQ-SLOW-NEXT: vmovdqa (%rdi), %xmm3 ++; AVX512DQ-SLOW-NEXT: vmovdqa (%rsi), %xmm6 ++; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm7 = xmm3[0],xmm6[0],xmm3[1],xmm6[1],xmm3[2],xmm6[2],xmm3[3],xmm6[3] ++; AVX512DQ-SLOW-NEXT: vpshufb %xmm10, %xmm7, %xmm10 ++; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm6[4],xmm3[4],xmm6[5],xmm3[5],xmm6[6],xmm3[6],xmm6[7],xmm3[7] ++; AVX512DQ-SLOW-NEXT: vprold $16, %xmm6, %xmm6 ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm3 = xmm3[1,1,2,3] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm7 = xmm3[0,1],xmm6[2],xmm3[3,4],xmm6[5],xmm3[6,7] ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm3 = ymm11[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[0,0,0,0,4,4,4,4] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm30[0,1,1,3,4,5,5,7] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm6 = ymm6[0,1],ymm3[2],ymm6[3,4],ymm3[5],ymm6[6,7,8,9],ymm3[10],ymm6[11,12],ymm3[13],ymm6[14,15] ++; AVX512DQ-SLOW-NEXT: vpshufb %ymm15, %ymm11, %ymm3 ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm30[2,2,2,2,6,6,6,6] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm15 = ymm3[0,1],ymm12[2],ymm3[3,4],ymm12[5],ymm3[6,7,8,9],ymm12[10],ymm3[11,12],ymm12[13],ymm3[14,15] ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm3 = ymm14[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[0,0,2,1,4,4,6,5] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm31[1,1,1,1,5,5,5,5] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1],ymm12[2],ymm3[3,4],ymm12[5],ymm3[6,7,8,9],ymm12[10],ymm3[11,12],ymm12[13],ymm3[14,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = 
ymm31[2,2,2,2,6,6,6,6] ++; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm14 = ymm14[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm14 = ymm14[2,2,2,2,6,6,6,6] ++; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm12 = ymm12[0],ymm14[1],ymm12[2,3],ymm14[4],ymm12[5,6,7,8],ymm14[9],ymm12[10,11],ymm14[12],ymm12[13,14,15] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm25, %xmm1 ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm14 = xmm1[0,2,3,3,4,5,6,7] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm14[0,0,2,1] ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm0 = xmm0[2,1,2,3,4,5,6,7] ++; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,5,5,4] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm1 = ymm0[0,0,1,3] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm2 = ymm2[0,0,1,1] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm22[2,2,2,3] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm31 = ymm17[0,2,2,3] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm16[0,2,2,3] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,1,3,3] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,2,2,3] ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm5 = xmm5[0,1,3,2,4,5,6,7] ++; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm5 = xmm5[0,0,1,1] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,1,3,3] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm4[0,0,1,1] ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm31, %zmm30, %zmm30 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm13, %zmm0, %zmm0 ++; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm13 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] ++; AVX512DQ-SLOW-NEXT: vpternlogq $184, %zmm30, %zmm13, %zmm0 ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,1,1,3] ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm9, %zmm5 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm8, %zmm4 ++; AVX512DQ-SLOW-NEXT: vpternlogq $226, %zmm5, %zmm13, %zmm4 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm14, %zmm5 # 32-byte Folded Reload ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm8 # 32-byte Folded Reload ++; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm31 = [65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535] ++; AVX512DQ-SLOW-NEXT: vpternlogq $226, %zmm5, %zmm31, %zmm8 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm1 # 32-byte Folded Reload ++; AVX512DQ-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm1 = zmm2[0,1,2,3],zmm1[4,5,6,7] ++; AVX512DQ-SLOW-NEXT: vpermq $234, {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Folded Reload ++; AVX512DQ-SLOW-NEXT: # ymm2 = mem[2,2,2,3] ++; AVX512DQ-SLOW-NEXT: vpermq $182, {{[-0-9]+}}(%r{{[sb]}}p), %ymm5 # 32-byte Folded Reload ++; AVX512DQ-SLOW-NEXT: # ymm5 = mem[2,1,3,2] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm27[2,2,3,3] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm24[2,1,3,2] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm16 = ymm23[2,2,2,3] ++; AVX512DQ-SLOW-NEXT: vpshufd $254, {{[-0-9]+}}(%r{{[sb]}}p), %ymm17 # 32-byte Folded Reload ++; AVX512DQ-SLOW-NEXT: # ymm17 = mem[2,3,3,3,6,7,7,7] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm22 = ymm21[0,0,2,1] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm20, %xmm9 ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm11 = xmm9[2,1,2,3,4,5,6,7] ++; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm11 = xmm11[0,1,2,3,4,5,5,4] ++; 
AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm11[0,0,1,3] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm19[0,0,1,1] ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm18, %xmm9 ++; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm9 = xmm9[0,2,3,3,4,5,6,7] + ; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,0,2,1] +-; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm15 = xmm15[2,1,2,3,4,5,6,7] +-; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm15 = xmm15[0,1,2,3,4,5,5,4] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm15 = ymm15[0,0,1,3] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm11[0,0,1,1] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm28 = ymm30[2,2,2,3] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm29 = ymm23[0,2,2,3] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,2,2,3] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm19[2,1,3,3] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm18[0,2,2,3] +-; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm0 = xmm0[0,1,3,2,4,5,6,7] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,1,1] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm31 = ymm5[2,1,3,3] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm1 = ymm1[0,0,1,1] +-; AVX512DQ-SLOW-NEXT: vpshufb {{.*#+}} ymm2 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,u,u,24,25,26,27,u,u,26,27,26,27] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm12[2,2,2,2,6,6,6,6] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm4[2],ymm2[3,4],ymm4[5],ymm2[6,7,8,9],ymm4[10],ymm2[11,12],ymm4[13],ymm2[14,15] +-; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm4 = ymm8[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[0,0,2,1,4,4,6,5] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm17[1,1,1,1,5,5,5,5] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm4 = ymm4[0,1],ymm5[2],ymm4[3,4],ymm5[5],ymm4[6,7,8,9],ymm5[10],ymm4[11,12],ymm5[13],ymm4[14,15] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm17[2,2,2,2,6,6,6,6] +-; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm12 = ymm8[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] +-; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm12[2,2,2,2,6,6,6,6] +-; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm5[0],ymm12[1],ymm5[2,3],ymm12[4],ymm5[5,6,7,8],ymm12[9],ymm5[10,11],ymm12[12],ymm5[13,14,15] +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm29, %zmm28, %zmm12 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm7 +-; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm10 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] +-; AVX512DQ-SLOW-NEXT: vpternlogq $184, %zmm12, %zmm10, %zmm7 +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm0[0,1,1,3] +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm30, %zmm0 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm1, %zmm31, %zmm1 +-; AVX512DQ-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm10, %zmm1 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm9, %zmm0 # 32-byte Folded Reload +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm15, %zmm9 # 32-byte Folded Reload +-; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm10 = [65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535] +-; AVX512DQ-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm10, %zmm9 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm0 # 32-byte Folded Reload +-; AVX512DQ-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm0 = 
zmm11[0,1,2,3],zmm0[4,5,6,7] +-; AVX512DQ-SLOW-NEXT: vpermq $182, {{[-0-9]+}}(%r{{[sb]}}p), %ymm11 # 32-byte Folded Reload +-; AVX512DQ-SLOW-NEXT: # ymm11 = mem[2,1,3,2] +-; AVX512DQ-SLOW-NEXT: vpermq $234, {{[-0-9]+}}(%r{{[sb]}}p), %ymm12 # 32-byte Folded Reload +-; AVX512DQ-SLOW-NEXT: # ymm12 = mem[2,2,2,3] +-; AVX512DQ-SLOW-NEXT: vpshufd $254, {{[-0-9]+}}(%r{{[sb]}}p), %ymm15 # 32-byte Folded Reload +-; AVX512DQ-SLOW-NEXT: # ymm15 = mem[2,3,3,3,6,7,7,7] +-; AVX512DQ-SLOW-NEXT: vpermq $96, {{[-0-9]+}}(%r{{[sb]}}p), %ymm17 # 32-byte Folded Reload +-; AVX512DQ-SLOW-NEXT: # ymm17 = mem[0,0,2,1] +-; AVX512DQ-SLOW-NEXT: vpshuflw $230, {{[-0-9]+}}(%r{{[sb]}}p), %xmm8 # 16-byte Folded Reload +-; AVX512DQ-SLOW-NEXT: # xmm8 = mem[2,1,2,3,4,5,6,7] +-; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm8 = xmm8[0,1,2,3,4,5,5,4] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm8[0,0,1,3] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm18 = ymm26[0,0,1,1] +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm25, %xmm13 +-; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm13 = xmm13[0,2,3,3,4,5,6,7] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm13[0,0,2,1] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm19 = ymm24[2,2,2,3] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm28 = ymm22[2,1,3,2] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm29 = ymm21[2,2,3,3] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm20[0,0,1,1] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm31 = ymm16[0,0,2,1] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,1,3,2] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,2,2,3] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,2,2,3] +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,2,2,3] +-; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm16 # 64-byte Reload +-; AVX512DQ-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm16, %zmm0 +-; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm0 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm12, %zmm11, %zmm9 +-; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm9 +-; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm3 +-; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm15[2,1,3,2] +-; AVX512DQ-SLOW-NEXT: vpbroadcastd 32(%rax), %ymm9 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm9, %zmm7, %zmm7 +-; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload +-; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm7 +-; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm7 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm17, %zmm1 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm13, %zmm18, %zmm8 +-; AVX512DQ-SLOW-NEXT: vpternlogq $226, %zmm1, %zmm10, %zmm8 +-; AVX512DQ-SLOW-NEXT: vpbroadcastd 36(%rax), %ymm1 +-; AVX512DQ-SLOW-NEXT: vpbroadcastd 40(%rax), %ymm9 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm9, %zmm1, %zmm1 +-; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload +-; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm1 ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,1,1] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm7[0,0,2,1] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm6 = ymm6[2,1,3,2] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm15 = ymm15[2,2,2,3] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm3 = ymm3[2,2,2,3] ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm12 = ymm12[0,2,2,3] ++; AVX512DQ-SLOW-NEXT: vmovdqu64 
{{[-0-9]+}}(%r{{[sb]}}p), %zmm18 # 64-byte Reload ++; AVX512DQ-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm18, %zmm1 + ; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm1 + ; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm19, %zmm8, %zmm8 +-; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm28, %zmm9, %zmm9 +-; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm9 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm29, %zmm0, %zmm8 +-; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm10 # 64-byte Reload +-; AVX512DQ-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm8 = zmm10[0,1,2,3],zmm8[4,5,6,7] +-; AVX512DQ-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm6 +-; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm6 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm31, %zmm30, %zmm8 +-; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload +-; AVX512DQ-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm8 +-; AVX512DQ-SLOW-NEXT: vpbroadcastd (%rax), %ymm9 +-; AVX512DQ-SLOW-NEXT: vpbroadcastd 4(%rax), %ymm10 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm10, %zmm9, %zmm9 +-; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm10 # 64-byte Reload +-; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm10, %zmm9 +-; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm9 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm14, %zmm2 +-; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm4, %zmm4 +-; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm4 +-; AVX512DQ-SLOW-NEXT: vbroadcasti32x8 {{.*#+}} zmm2 = [0,5,4,0,0,6,5,0,0,5,4,0,0,6,5,0] +-; AVX512DQ-SLOW-NEXT: # zmm2 = mem[0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7] +-; AVX512DQ-SLOW-NEXT: vpermd (%rax), %zmm2, %zmm2 +-; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm27, %zmm2 +-; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm2 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm8, %zmm2 ++; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm8, %zmm5 ++; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm5 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm13, %zmm0, %zmm2 ++; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload ++; AVX512DQ-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm2 = zmm8[0,1,2,3],zmm2[4,5,6,7] ++; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload ++; AVX512DQ-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm2 ++; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm2 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm16, %zmm14, %zmm5 ++; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm5 ++; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm29 ++; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm17[2,1,3,2] ++; AVX512DQ-SLOW-NEXT: vpbroadcastd 32(%rax), %ymm5 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm0, %zmm0 ++; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Reload ++; AVX512DQ-SLOW-NEXT: vpternlogq $216, 
{{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm0 ++; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm0 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm11, %zmm22, %zmm4 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm9, %zmm30, %zmm5 ++; AVX512DQ-SLOW-NEXT: vpternlogq $226, %zmm4, %zmm31, %zmm5 ++; AVX512DQ-SLOW-NEXT: vpbroadcastd 36(%rax), %ymm4 ++; AVX512DQ-SLOW-NEXT: vpbroadcastd 40(%rax), %ymm8 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm4, %zmm4 ++; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload ++; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm4 ++; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm4 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm5 ++; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Reload ++; AVX512DQ-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm5 ++; AVX512DQ-SLOW-NEXT: vpbroadcastd (%rax), %ymm7 ++; AVX512DQ-SLOW-NEXT: vpbroadcastd 4(%rax), %ymm8 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm7, %zmm7 ++; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm26, %zmm7 ++; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm7 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm15, %zmm6, %zmm5 ++; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm12, %zmm3, %zmm3 ++; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm3 ++; AVX512DQ-SLOW-NEXT: vbroadcasti32x8 {{.*#+}} zmm5 = [0,5,4,0,0,6,5,0,0,5,4,0,0,6,5,0] ++; AVX512DQ-SLOW-NEXT: # zmm5 = mem[0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7] ++; AVX512DQ-SLOW-NEXT: vpermd (%rax), %zmm5, %zmm5 ++; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm28, %zmm5 ++; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm5 + ; AVX512DQ-SLOW-NEXT: movq {{[0-9]+}}(%rsp), %rax +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm2, 128(%rax) +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm9, (%rax) +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm6, 320(%rax) +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm1, 256(%rax) +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm7, 192(%rax) +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm0, 64(%rax) +-; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm3, 384(%rax) +-; AVX512DQ-SLOW-NEXT: addq $632, %rsp # imm = 0x278 ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm5, 128(%rax) ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm7, (%rax) ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm2, 320(%rax) ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm4, 256(%rax) ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm0, 192(%rax) ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm1, 64(%rax) ++; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm29, 384(%rax) ++; AVX512DQ-SLOW-NEXT: addq $648, %rsp # imm = 0x288 + ; AVX512DQ-SLOW-NEXT: vzeroupper + ; AVX512DQ-SLOW-NEXT: retq + ; +@@ -6235,9 +6229,9 @@ + ; AVX512DQ-FAST-NEXT: vmovdqa (%r9), %ymm13 + ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128] + ; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm13, %ymm6 +-; AVX512DQ-FAST-NEXT: vmovdqa (%r8), %ymm14 ++; AVX512DQ-FAST-NEXT: vmovdqa (%r8), %ymm15 + ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm3 = +-; AVX512DQ-FAST-NEXT: vpshufb %ymm3, %ymm14, %ymm7 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm3, %ymm15, %ymm7 + ; AVX512DQ-FAST-NEXT: vporq %ymm6, %ymm7, %ymm25 + ; AVX512DQ-FAST-NEXT: vpshufb %ymm4, %ymm10, %ymm4 + ; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdi), %ymm6 +@@ -6251,8 +6245,9 @@ 
+ ; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm12, %ymm1 + ; AVX512DQ-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 + ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-FAST-NEXT: vmovdqa 32(%r9), %ymm15 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm15, %ymm0 ++; AVX512DQ-FAST-NEXT: vmovdqa 32(%r9), %ymm1 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm0 ++; AVX512DQ-FAST-NEXT: vmovdqa %ymm1, %ymm14 + ; AVX512DQ-FAST-NEXT: vmovdqa 32(%r8), %ymm4 + ; AVX512DQ-FAST-NEXT: vpshufb %ymm3, %ymm4, %ymm1 + ; AVX512DQ-FAST-NEXT: vporq %ymm0, %ymm1, %ymm21 +@@ -6283,11 +6278,11 @@ + ; AVX512DQ-FAST-NEXT: vmovdqa 32(%rsi), %xmm1 + ; AVX512DQ-FAST-NEXT: vpunpcklwd {{.*#+}} xmm3 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3] + ; AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm3, %xmm3 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm5, %xmm30 ++; AVX512DQ-FAST-NEXT: vmovdqa %xmm5, %xmm8 + ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm19 = [2,1,3,3,8,8,9,9] + ; AVX512DQ-FAST-NEXT: vpermi2q %zmm3, %zmm2, %zmm19 + ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm13[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,26,27,28,29,26,27,28,29,26,27,28,29,30,31,30,31] +-; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[3,3,3,3,7,7,7,7] ++; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm15[3,3,3,3,7,7,7,7] + ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm3[2],ymm2[3,4],ymm3[5],ymm2[6,7,8,9],ymm3[10],ymm2[11,12],ymm3[13],ymm2[14,15] + ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm20 = [2,2,2,3,8,8,8,9] + ; AVX512DQ-FAST-NEXT: vmovdqa 32(%r9), %xmm3 +@@ -6315,52 +6310,54 @@ + ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm27 = [0,0,0,1,8,9,9,11] + ; AVX512DQ-FAST-NEXT: vpermi2q %zmm0, %zmm3, %zmm27 + ; AVX512DQ-FAST-NEXT: vprold $16, %ymm13, %ymm0 +-; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[1,2,2,3,5,6,6,7] ++; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm15[1,2,2,3,5,6,6,7] + ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm0 = ymm3[0,1],ymm0[2],ymm3[3,4],ymm0[5],ymm3[6,7,8,9],ymm0[10],ymm3[11,12],ymm0[13],ymm3[14,15] + ; AVX512DQ-FAST-NEXT: vpbroadcastd {{.*#+}} ymm5 = [18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21] + ; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm13, %ymm3 +-; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm7 = ymm14[0,0,2,1,4,4,6,5] ++; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm7 = ymm15[0,0,2,1,4,4,6,5] + ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1,2],ymm7[3],ymm3[4,5],ymm7[6],ymm3[7,8,9,10],ymm7[11],ymm3[12,13],ymm7[14],ymm3[15] + ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm28 = [2,2,3,3,10,9,11,10] + ; AVX512DQ-FAST-NEXT: vpermi2q %zmm0, %zmm3, %zmm28 +-; AVX512DQ-FAST-NEXT: vmovdqa (%rdi), %xmm8 ++; AVX512DQ-FAST-NEXT: vmovdqa (%rdi), %xmm15 + ; AVX512DQ-FAST-NEXT: vmovdqa (%rsi), %xmm0 +-; AVX512DQ-FAST-NEXT: vpunpckhwd {{.*#+}} xmm3 = xmm0[4],xmm8[4],xmm0[5],xmm8[5],xmm0[6],xmm8[6],xmm0[7],xmm8[7] +-; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm3, %xmm3 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm3, %ymm17 ++; AVX512DQ-FAST-NEXT: vpunpckhwd {{.*#+}} xmm3 = xmm0[4],xmm15[4],xmm0[5],xmm15[5],xmm0[6],xmm15[6],xmm0[7],xmm15[7] ++; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm3, %xmm9 + ; AVX512DQ-FAST-NEXT: vpunpckhwd {{.*#+}} xmm1 = xmm1[4],xmm2[4],xmm1[5],xmm2[5],xmm1[6],xmm2[6],xmm1[7],xmm2[7] + ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm25, %zmm0, %zmm2 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm30, %xmm9 +-; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm1, %xmm1 ++; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm1, %xmm1 ++; AVX512DQ-FAST-NEXT: vmovdqa64 
%xmm8, %xmm18 + ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm25 = <0,0,1,1,12,13,u,15> + ; AVX512DQ-FAST-NEXT: vpermi2q %zmm2, %zmm1, %zmm25 + ; AVX512DQ-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax + ; AVX512DQ-FAST-NEXT: vpbroadcastd 8(%rax), %ymm1 +-; AVX512DQ-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm1, %ymm1 ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535] ++; AVX512DQ-FAST-NEXT: vpandn %ymm1, %ymm2, %ymm1 + ; AVX512DQ-FAST-NEXT: vmovdqa (%rax), %ymm3 + ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [12,13,128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128] + ; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm3, %ymm7 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm3, %ymm16 + ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm7, %zmm1, %zmm30 + ; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm29, %ymm1 + ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,18,19,20,21,18,19,20,21,24,25,26,27,22,23,22,23] + ; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[1,1,1,1,5,5,5,5] +-; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm6[2],ymm1[3,4],ymm6[5],ymm1[6,7,8,9],ymm6[10],ymm1[11,12],ymm6[13],ymm1[14,15] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm16 ++; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm8 = ymm1[0,1],ymm6[2],ymm1[3,4],ymm6[5],ymm1[6,7,8,9],ymm6[10],ymm1[11,12],ymm6[13],ymm1[14,15] + ; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm31, %ymm13 + ; AVX512DQ-FAST-NEXT: vpunpcklwd {{.*#+}} ymm1 = ymm13[0,0,1,1,2,2,3,3,8,8,9,9,10,10,11,11] + ; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm6 = ymm12[0,1,1,3,4,5,5,7] + ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm7 = ymm6[0,1],ymm1[2],ymm6[3,4],ymm1[5],ymm6[6,7,8,9],ymm1[10],ymm6[11,12],ymm1[13],ymm6[14,15] +-; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm15, %ymm1 ++; AVX512DQ-FAST-NEXT: vmovdqa %ymm14, %ymm3 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm14, %ymm1 + ; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm5 = ymm4[0,0,2,1,4,4,6,5] + ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1,2],ymm5[3],ymm1[4,5],ymm5[6],ymm1[7,8,9,10],ymm5[11],ymm1[12,13],ymm5[14],ymm1[15] + ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm29 = <0,1,u,3,10,10,11,11> + ; AVX512DQ-FAST-NEXT: vpermi2q %zmm1, %zmm21, %zmm29 ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = + ; AVX512DQ-FAST-NEXT: vmovdqa 32(%rax), %ymm6 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm6, %ymm1 +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = +-; AVX512DQ-FAST-NEXT: vpermd %ymm6, %ymm2, %ymm2 +-; AVX512DQ-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm2, %ymm2 +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm1, %zmm14 ++; AVX512DQ-FAST-NEXT: vpermd %ymm6, %ymm1, %ymm1 ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm5 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] ++; AVX512DQ-FAST-NEXT: vpandn %ymm1, %ymm5, %ymm1 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm6, %ymm2 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm14 + ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm13[14,15,12,13,u,u,u,u,u,u,u,u,u,u,u,u,30,31,28,29,u,u,u,u,30,31,28,29,u,u,u,u] + ; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm5 = ymm12[3,3,3,3,7,7,7,7] + ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm2 = ymm5[0],ymm2[1],ymm5[2,3],ymm2[4],ymm5[5,6,7,8],ymm2[9],ymm5[10,11],ymm2[12],ymm5[13,14,15] +@@ -6371,21 +6368,21 @@ + ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm13 = 
ymm13[0,1],ymm12[2],ymm13[3,4],ymm12[5],ymm13[6,7,8,9],ymm12[10],ymm13[11,12],ymm12[13],ymm13[14,15] + ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm21 = [2,2,2,3,8,10,10,11] + ; AVX512DQ-FAST-NEXT: vpermt2q %zmm2, %zmm21, %zmm13 +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,26,27,28,29,26,27,28,29,26,27,28,29,30,31,30,31] ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,26,27,28,29,26,27,28,29,26,27,28,29,30,31,30,31] + ; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm12 = ymm4[3,3,3,3,7,7,7,7] + ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm12[2],ymm2[3,4],ymm12[5],ymm2[6,7,8,9],ymm12[10],ymm2[11,12],ymm12[13],ymm2[14,15] +-; AVX512DQ-FAST-NEXT: vprold $16, %ymm15, %ymm12 ++; AVX512DQ-FAST-NEXT: vprold $16, %ymm3, %ymm12 + ; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[1,2,2,3,5,6,6,7] + ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm12 = ymm4[0,1],ymm12[2],ymm4[3,4],ymm12[5],ymm4[6,7,8,9],ymm12[10],ymm4[11,12],ymm12[13],ymm4[14,15] + ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm31 = [2,1,3,2,10,10,10,11] + ; AVX512DQ-FAST-NEXT: vpermt2q %zmm2, %zmm31, %zmm12 +-; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm18 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] +-; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm22, %zmm18, %zmm13 ++; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm17 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] ++; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm22, %zmm17, %zmm13 + ; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm13, %zmm12 +-; AVX512DQ-FAST-NEXT: vmovdqa64 (%rax), %zmm15 ++; AVX512DQ-FAST-NEXT: vmovdqa64 (%rax), %zmm3 + ; AVX512DQ-FAST-NEXT: vbroadcasti32x8 {{.*#+}} zmm4 = [30,5,0,0,31,6,0,31,30,5,0,0,31,6,0,31] + ; AVX512DQ-FAST-NEXT: # zmm4 = mem[0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7] +-; AVX512DQ-FAST-NEXT: vpermi2d %zmm15, %zmm6, %zmm4 ++; AVX512DQ-FAST-NEXT: vpermi2d %zmm3, %zmm6, %zmm4 + ; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm12, %zmm4 + ; AVX512DQ-FAST-NEXT: vpunpckhwd {{.*#+}} xmm6 = xmm11[4],xmm10[4],xmm11[5],xmm10[5],xmm11[6],xmm10[6],xmm11[7],xmm10[7] + ; AVX512DQ-FAST-NEXT: vpbroadcastq {{.*#+}} xmm12 = [6,7,4,5,0,0,8,9,6,7,4,5,0,0,8,9] +@@ -6400,14 +6397,15 @@ + ; AVX512DQ-FAST-NEXT: # xmm6 = xmm1[4],mem[4],xmm1[5],mem[5],xmm1[6],mem[6],xmm1[7],mem[7] + ; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm24, %xmm1 + ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm13 = xmm1[0,1,2,3,8,9,10,11,14,15,12,13,14,15,12,13] +-; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm6, %xmm6 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm18, %xmm1 ++; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm6, %xmm6 + ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm22 = [0,1,1,3,8,8,9,9] + ; AVX512DQ-FAST-NEXT: vpermt2q %zmm6, %zmm22, %zmm13 + ; AVX512DQ-FAST-NEXT: vprold $16, %xmm0, %xmm6 +-; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} xmm2 = xmm8[1,1,2,3] ++; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} xmm2 = xmm15[1,1,2,3] + ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} xmm2 = xmm2[0,1],xmm6[2],xmm2[3,4],xmm6[5],xmm2[6,7] +-; AVX512DQ-FAST-NEXT: vpunpcklwd {{.*#+}} xmm0 = xmm8[0],xmm0[0],xmm8[1],xmm0[1],xmm8[2],xmm0[2],xmm8[3],xmm0[3] +-; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm0, %xmm0 ++; AVX512DQ-FAST-NEXT: vpunpcklwd {{.*#+}} xmm0 = 
xmm15[0],xmm0[0],xmm15[1],xmm0[1],xmm15[2],xmm0[2],xmm15[3],xmm0[3] ++; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm0, %xmm0 + ; AVX512DQ-FAST-NEXT: vpermt2q %zmm2, %zmm11, %zmm0 + ; AVX512DQ-FAST-NEXT: vmovdqa (%rcx), %xmm2 + ; AVX512DQ-FAST-NEXT: vpshufb %xmm12, %xmm2, %xmm6 +@@ -6436,8 +6434,8 @@ + ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm5 = ymm10[0,1],ymm5[2],ymm10[3,4],ymm5[5],ymm10[6,7,8,9],ymm5[10],ymm10[11,12],ymm5[13],ymm10[14,15] + ; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} xmm10 = xmm12[0,2,3,3,4,5,6,7] + ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,2,1] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm17[0,0,1,3] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm16[2,2,2,3] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,0,1,3] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,2,2,3] + ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm7[2,1,3,2] + ; AVX512DQ-FAST-NEXT: vpermt2q %zmm0, %zmm31, %zmm5 + ; AVX512DQ-FAST-NEXT: vpbroadcastd (%rax), %ymm0 +@@ -6456,8 +6454,8 @@ + ; AVX512DQ-FAST-NEXT: vpermt2q %zmm2, %zmm21, %zmm12 + ; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm12 + ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = <6,u,u,u,7,u,u,7> +-; AVX512DQ-FAST-NEXT: vpermd %ymm3, %ymm2, %ymm2 +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm15, %zmm3 ++; AVX512DQ-FAST-NEXT: vpermd %ymm16, %ymm2, %ymm2 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm16, %zmm3, %zmm3 + ; AVX512DQ-FAST-NEXT: vbroadcasti32x8 {{.*#+}} zmm5 = [0,13,4,0,0,14,5,0,0,13,4,0,0,14,5,0] + ; AVX512DQ-FAST-NEXT: # zmm5 = mem[0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7] + ; AVX512DQ-FAST-NEXT: vpermd %zmm3, %zmm5, %zmm3 +@@ -6466,7 +6464,7 @@ + ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm5 # 32-byte Folded Reload + ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm9, %zmm9 # 32-byte Folded Reload + ; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm5, %zmm11, %zmm9 +-; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm23, %zmm18, %zmm19 ++; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm23, %zmm17, %zmm19 + ; AVX512DQ-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm25, %zmm30 + ; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm30 + ; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Reload +@@ -11708,7 +11706,7 @@ + ; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm9 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] + ; AVX512F-ONLY-SLOW-NEXT: vmovdqa 96(%r8), %ymm12 + ; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm12, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-SLOW-NEXT: vpshufb {{.*#+}} ymm6 = ymm12[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[14,15,u,u,u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[16,17,u,u,u,u],zero,zero ++; AVX512F-ONLY-SLOW-NEXT: vpshufb {{.*#+}} ymm6 = ymm12[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[14,15],zero,zero,ymm12[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[16,17],zero,zero,ymm12[u,u],zero,zero + ; AVX512F-ONLY-SLOW-NEXT: vpternlogq $248, %ymm9, %ymm7, %ymm6 + ; AVX512F-ONLY-SLOW-NEXT: vmovdqa 96(%r9), %ymm15 + ; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm15, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +@@ -12367,7 +12365,7 @@ + ; AVX512F-ONLY-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm9 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm13 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] + ; AVX512F-ONLY-FAST-NEXT: 
vmovdqa 96(%r8), %ymm8 +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm8[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[14,15,u,u,u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[16,17,u,u,u,u],zero,zero ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm8[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[14,15],zero,zero,ymm8[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[16,17],zero,zero,ymm8[u,u],zero,zero + ; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, %ymm13, %ymm9, %ymm12 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm13, %ymm14 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa 96(%r9), %ymm13 +@@ -13031,7 +13029,7 @@ + ; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm9 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] + ; AVX512DQ-SLOW-NEXT: vmovdqa 96(%r8), %ymm12 + ; AVX512DQ-SLOW-NEXT: vmovdqu %ymm12, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-SLOW-NEXT: vpshufb {{.*#+}} ymm6 = ymm12[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[14,15,u,u,u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[16,17,u,u,u,u],zero,zero ++; AVX512DQ-SLOW-NEXT: vpshufb {{.*#+}} ymm6 = ymm12[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[14,15],zero,zero,ymm12[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[16,17],zero,zero,ymm12[u,u],zero,zero + ; AVX512DQ-SLOW-NEXT: vpternlogq $248, %ymm9, %ymm7, %ymm6 + ; AVX512DQ-SLOW-NEXT: vmovdqa 96(%r9), %ymm15 + ; AVX512DQ-SLOW-NEXT: vmovdqu %ymm15, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +@@ -13690,7 +13688,7 @@ + ; AVX512DQ-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm9 + ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm13 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] + ; AVX512DQ-FAST-NEXT: vmovdqa 96(%r8), %ymm8 +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm8[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[14,15,u,u,u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[16,17,u,u,u,u],zero,zero ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm8[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[14,15],zero,zero,ymm8[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[16,17],zero,zero,ymm8[u,u],zero,zero + ; AVX512DQ-FAST-NEXT: vpternlogq $248, %ymm13, %ymm9, %ymm12 + ; AVX512DQ-FAST-NEXT: vmovdqa %ymm13, %ymm14 + ; AVX512DQ-FAST-NEXT: vmovdqa 96(%r9), %ymm13 +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-5.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-5.ll +--- a/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-5.ll ++++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-5.ll +@@ -3967,7 +3967,8 @@ + ; AVX512F-SLOW-NEXT: vpshufb %ymm1, %ymm0, %ymm1 + ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[0,2,1,1,4,6,5,5] + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm0[2,3,3,2] +-; AVX512F-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 ++; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm14 = [255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255] ++; AVX512F-SLOW-NEXT: vpandn %ymm0, %ymm14, %ymm0 + ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm25 + ; AVX512F-SLOW-NEXT: vpbroadcastq {{.*#+}} ymm0 = [9,14,11,0,13,10,15,12,9,14,11,0,13,10,15,12,9,14,11,0,13,10,15,12,9,14,11,0,13,10,15,12] + ; AVX512F-SLOW-NEXT: vpshufb %ymm0, %ymm9, %ymm9 +@@ -4063,63 +4064,63 @@ + ; + ; AVX512F-FAST-LABEL: store_i8_stride5_vf64: + ; AVX512F-FAST: # %bb.0: +-; AVX512F-FAST-NEXT: vmovdqa 32(%rsi), %ymm7 ++; 
AVX512F-FAST-NEXT: vmovdqa 32(%rsi), %ymm6 + ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm13 = [128,128,13,128,128,128,128,14,128,128,128,128,15,128,128,128,128,16,128,128,128,128,17,128,128,128,128,18,128,128,128,128] +-; AVX512F-FAST-NEXT: vpshufb %ymm13, %ymm7, %ymm0 +-; AVX512F-FAST-NEXT: vmovdqa 32(%rdi), %ymm3 +-; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm6 = <12,13,128,15,12,13,14,128,12,13,14,15,128,u,u,u,16,128,18,19,16,17,128,19,16,17,18,128,16,17,18,19> +-; AVX512F-FAST-NEXT: vpshufb %ymm6, %ymm3, %ymm1 ++; AVX512F-FAST-NEXT: vpshufb %ymm13, %ymm6, %ymm0 ++; AVX512F-FAST-NEXT: vmovdqa 32(%rdi), %ymm2 ++; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm7 = <12,13,128,15,12,13,14,128,12,13,14,15,128,u,u,u,16,128,18,19,16,17,128,19,16,17,18,128,16,17,18,19> ++; AVX512F-FAST-NEXT: vpshufb %ymm7, %ymm2, %ymm1 + ; AVX512F-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 + ; AVX512F-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512F-FAST-NEXT: vmovdqa 32(%rdi), %xmm1 + ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} xmm15 = <8,128,u,7,128,9,128,u,128,u,10,128,12,128,u,11> + ; AVX512F-FAST-NEXT: vpshufb %xmm15, %xmm1, %xmm0 + ; AVX512F-FAST-NEXT: vmovdqa64 %xmm1, %xmm18 +-; AVX512F-FAST-NEXT: vmovdqa 32(%rsi), %xmm2 ++; AVX512F-FAST-NEXT: vmovdqa 32(%rsi), %xmm3 + ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} xmm4 = <128,8,u,128,7,128,9,u,11,u,128,10,128,12,u,128> +-; AVX512F-FAST-NEXT: vpshufb %xmm4, %xmm2, %xmm1 ++; AVX512F-FAST-NEXT: vpshufb %xmm4, %xmm3, %xmm1 + ; AVX512F-FAST-NEXT: vmovdqa64 %xmm4, %xmm25 +-; AVX512F-FAST-NEXT: vmovdqa64 %xmm2, %xmm17 ++; AVX512F-FAST-NEXT: vmovdqa64 %xmm3, %xmm17 + ; AVX512F-FAST-NEXT: vpor %xmm0, %xmm1, %xmm0 + ; AVX512F-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512F-FAST-NEXT: vmovdqa 32(%rcx), %ymm9 + ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm0 = [128,128,128,128,13,128,128,128,128,14,128,128,128,128,15,128,128,128,128,16,128,128,128,128,17,128,128,128,128,18,128,128] +-; AVX512F-FAST-NEXT: vpshufb %ymm0, %ymm9, %ymm2 ++; AVX512F-FAST-NEXT: vpshufb %ymm0, %ymm9, %ymm3 + ; AVX512F-FAST-NEXT: vmovdqa 32(%rdx), %ymm8 + ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = + ; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm8, %ymm4 +-; AVX512F-FAST-NEXT: vpor %ymm2, %ymm4, %ymm2 +-; AVX512F-FAST-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-FAST-NEXT: vpor %ymm3, %ymm4, %ymm3 ++; AVX512F-FAST-NEXT: vmovdqu %ymm3, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512F-FAST-NEXT: vmovdqa 32(%rcx), %xmm10 + ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} xmm4 = <128,6,128,8,u,128,7,128,9,128,11,u,128,10,128,12> +-; AVX512F-FAST-NEXT: vpshufb %xmm4, %xmm10, %xmm2 ++; AVX512F-FAST-NEXT: vpshufb %xmm4, %xmm10, %xmm3 + ; AVX512F-FAST-NEXT: vmovdqa64 %xmm4, %xmm26 + ; AVX512F-FAST-NEXT: vmovdqa 32(%rdx), %xmm11 + ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} xmm14 = <6,128,8,128,u,7,128,9,128,11,128,u,10,128,12,128> + ; AVX512F-FAST-NEXT: vpshufb %xmm14, %xmm11, %xmm4 +-; AVX512F-FAST-NEXT: vporq %xmm2, %xmm4, %xmm21 +-; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm7[11,u,u,10,u,12,u,u,u,u,13,u,15,u,u,14,27,u,u,26,u,28,u,u,u,u,29,u,31,u,u,30] +-; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm4 = ymm7[3,u,5,u,u,4,u,6,u,8,u,u,7,u,9,u,19,u,21,u,u,20,u,22,u,24,u,u,23,u,25,u] +-; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm4, %zmm22 +-; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm3[26],zero,ymm3[28],zero,zero,ymm3[27],zero,ymm3[29],zero,ymm3[31],zero,zero,ymm3[30],zero +-; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm3 
= ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm3[21],zero,zero,ymm3[20],zero,ymm3[22],zero,ymm3[24],zero,zero,ymm3[23],zero,ymm3[25],zero,zero +-; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm23 ++; AVX512F-FAST-NEXT: vporq %xmm3, %xmm4, %xmm19 ++; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm3 = ymm6[11,u,u,10,u,12,u,u,u,u,13,u,15,u,u,14,27,u,u,26,u,28,u,u,u,u,29,u,31,u,u,30] ++; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm4 = ymm6[3,u,5,u,u,4,u,6,u,8,u,u,7,u,9,u,19,u,21,u,u,20,u,22,u,24,u,u,23,u,25,u] ++; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm4, %zmm22 ++; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm3 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm2[26],zero,ymm2[28],zero,zero,ymm2[27],zero,ymm2[29],zero,ymm2[31],zero,zero,ymm2[30],zero ++; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm2[21],zero,zero,ymm2[20],zero,ymm2[22],zero,ymm2[24],zero,zero,ymm2[23],zero,ymm2[25],zero,zero ++; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm2, %zmm23 + ; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm8[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm8[27],zero,zero,ymm8[26],zero,ymm8[28],zero,ymm8[30],zero,zero,ymm8[29],zero,ymm8[31],zero,zero + ; AVX512F-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm4 = [128,128,19,128,21,128,128,20,128,22,128,24,128,128,23,128,128,128,19,128,21,128,128,20,128,22,128,24,128,128,23,128] + ; AVX512F-FAST-NEXT: # ymm4 = mem[0,1,0,1] + ; AVX512F-FAST-NEXT: vpshufb %ymm4, %ymm9, %ymm3 + ; AVX512F-FAST-NEXT: vmovdqa64 %ymm4, %ymm30 + ; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm24 +-; AVX512F-FAST-NEXT: vmovdqa (%rcx), %ymm7 +-; AVX512F-FAST-NEXT: vpshufb %ymm0, %ymm7, %ymm0 +-; AVX512F-FAST-NEXT: vmovdqa (%rdx), %ymm12 +-; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm12, %ymm1 +-; AVX512F-FAST-NEXT: vporq %ymm0, %ymm1, %ymm19 ++; AVX512F-FAST-NEXT: vmovdqa (%rcx), %ymm12 ++; AVX512F-FAST-NEXT: vpshufb %ymm0, %ymm12, %ymm0 ++; AVX512F-FAST-NEXT: vmovdqa (%rdx), %ymm6 ++; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm6, %ymm1 ++; AVX512F-FAST-NEXT: vporq %ymm0, %ymm1, %ymm20 + ; AVX512F-FAST-NEXT: vmovdqa (%rsi), %ymm5 + ; AVX512F-FAST-NEXT: vpshufb %ymm13, %ymm5, %ymm0 + ; AVX512F-FAST-NEXT: vmovdqa (%rdi), %ymm4 +-; AVX512F-FAST-NEXT: vpshufb %ymm6, %ymm4, %ymm1 +-; AVX512F-FAST-NEXT: vporq %ymm0, %ymm1, %ymm20 ++; AVX512F-FAST-NEXT: vpshufb %ymm7, %ymm4, %ymm1 ++; AVX512F-FAST-NEXT: vporq %ymm0, %ymm1, %ymm21 + ; AVX512F-FAST-NEXT: vmovdqa (%rdi), %xmm1 + ; AVX512F-FAST-NEXT: vpshufb %xmm15, %xmm1, %xmm0 + ; AVX512F-FAST-NEXT: vmovdqa64 %xmm1, %xmm16 +@@ -4127,9 +4128,9 @@ + ; AVX512F-FAST-NEXT: vmovdqa64 %xmm25, %xmm1 + ; AVX512F-FAST-NEXT: vpshufb %xmm1, %xmm3, %xmm2 + ; AVX512F-FAST-NEXT: vporq %xmm0, %xmm2, %xmm28 +-; AVX512F-FAST-NEXT: vmovdqa (%rcx), %xmm6 ++; AVX512F-FAST-NEXT: vmovdqa (%rcx), %xmm7 + ; AVX512F-FAST-NEXT: vmovdqa64 %xmm26, %xmm0 +-; AVX512F-FAST-NEXT: vpshufb %xmm0, %xmm6, %xmm0 ++; AVX512F-FAST-NEXT: vpshufb %xmm0, %xmm7, %xmm0 + ; AVX512F-FAST-NEXT: vmovdqa (%rdx), %xmm2 + ; AVX512F-FAST-NEXT: vpshufb %xmm14, %xmm2, %xmm14 + ; AVX512F-FAST-NEXT: vporq %xmm0, %xmm14, %xmm29 +@@ -4141,25 +4142,26 @@ + ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = [12,128,128,128,128,13,128,128,128,128,14,128,128,128,128,15,128,128,128,128,16,128,128,128,128,17,128,128,128,128,18,128] + ; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm15 + ; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm15, %zmm14, %zmm27 +-; AVX512F-FAST-NEXT: vmovdqa64 (%r8), %zmm26 ++; AVX512F-FAST-NEXT: vmovdqa64 (%r8), %zmm25 + ; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} 
zmm31 = <4,u,5,5,5,5,u,6,30,30,30,u,31,31,31,31> +-; AVX512F-FAST-NEXT: vpermi2d %zmm26, %zmm0, %zmm31 +-; AVX512F-FAST-NEXT: vmovdqa (%r8), %ymm0 +-; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm1 ++; AVX512F-FAST-NEXT: vpermi2d %zmm25, %zmm0, %zmm31 + ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm15 = <4,u,5,5,5,5,u,6> ++; AVX512F-FAST-NEXT: vmovdqa (%r8), %ymm0 + ; AVX512F-FAST-NEXT: vpermd %ymm0, %ymm15, %ymm15 +-; AVX512F-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm15, %ymm15 +-; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm15, %zmm1, %zmm25 ++; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} ymm26 = [255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255] ++; AVX512F-FAST-NEXT: vpandnq %ymm15, %ymm26, %ymm15 ++; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm1 ++; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm15, %zmm1, %zmm26 + ; AVX512F-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [9,14,11,0,13,10,15,12,9,14,11,0,13,10,15,12,9,14,11,0,13,10,15,12,9,14,11,0,13,10,15,12] + ; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm9, %ymm9 + ; AVX512F-FAST-NEXT: vmovdqa64 %ymm30, %ymm13 +-; AVX512F-FAST-NEXT: vpshufb %ymm13, %ymm7, %ymm15 +-; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm7, %ymm1 +-; AVX512F-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm7 = [18,19,128,21,128,21,20,128,22,128,24,128,22,23,128,25,18,19,128,21,128,21,20,128,22,128,24,128,22,23,128,25] +-; AVX512F-FAST-NEXT: # ymm7 = mem[0,1,0,1] +-; AVX512F-FAST-NEXT: vpshufb %ymm7, %ymm8, %ymm8 +-; AVX512F-FAST-NEXT: vpshufb %ymm7, %ymm12, %ymm7 +-; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm12[27],zero,zero,ymm12[26],zero,ymm12[28],zero,ymm12[30],zero,zero,ymm12[29],zero,ymm12[31],zero,zero ++; AVX512F-FAST-NEXT: vpshufb %ymm13, %ymm12, %ymm15 ++; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm12, %ymm1 ++; AVX512F-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm12 = [18,19,128,21,128,21,20,128,22,128,24,128,22,23,128,25,18,19,128,21,128,21,20,128,22,128,24,128,22,23,128,25] ++; AVX512F-FAST-NEXT: # ymm12 = mem[0,1,0,1] ++; AVX512F-FAST-NEXT: vpshufb %ymm12, %ymm8, %ymm8 ++; AVX512F-FAST-NEXT: vpshufb %ymm12, %ymm6, %ymm12 ++; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm6 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm6[27],zero,zero,ymm6[26],zero,ymm6[28],zero,ymm6[30],zero,zero,ymm6[29],zero,ymm6[31],zero,zero + ; AVX512F-FAST-NEXT: vpunpcklbw {{.*#+}} xmm10 = xmm10[0],xmm11[0],xmm10[1],xmm11[1],xmm10[2],xmm11[2],xmm10[3],xmm11[3],xmm10[4],xmm11[4],xmm10[5],xmm11[5],xmm10[6],xmm11[6],xmm10[7],xmm11[7] + ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm30 = ymm9[2,2,3,3] + ; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm11 = ymm5[3,u,5,u,u,4,u,6,u,8,u,u,7,u,9,u,19,u,21,u,u,20,u,22,u,24,u,u,23,u,25,u] +@@ -4171,11 +4173,11 @@ + ; AVX512F-FAST-NEXT: vpunpcklbw {{.*#+}} xmm14 = xmm14[0],xmm13[0],xmm14[1],xmm13[1],xmm14[2],xmm13[2],xmm14[3],xmm13[3],xmm14[4],xmm13[4],xmm14[5],xmm13[5],xmm14[6],xmm13[6],xmm14[7],xmm13[7] + ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,2,3,3] + ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm15 = ymm15[2,2,3,3] +-; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm7[2,2,3,3] ++; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,2,3,3] + ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm11 = ymm11[2,2,3,3] + ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,2,3,3] + ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,2,3,3] +-; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,2,3,3] ++; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm6[2,2,3,3] + ; AVX512F-FAST-NEXT: vmovdqa64 %xmm16, %xmm13 + ; 
AVX512F-FAST-NEXT: vpunpcklbw {{.*#+}} xmm3 = xmm13[0],xmm3[0],xmm13[1],xmm3[1],xmm13[2],xmm3[2],xmm13[3],xmm3[3],xmm13[4],xmm3[4],xmm13[5],xmm3[5],xmm13[6],xmm3[6],xmm13[7],xmm3[7] + ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} xmm13 = <0,1,4,5,u,2,3,6,7,10,11,u,8,9,12,13> +@@ -4188,42 +4190,42 @@ + ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,2,3,3] + ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm14[0,0,1,1] + ; AVX512F-FAST-NEXT: vinserti32x4 $2, %xmm28, %zmm3, %zmm3 +-; AVX512F-FAST-NEXT: vpunpcklbw {{.*#+}} xmm2 = xmm6[0],xmm2[0],xmm6[1],xmm2[1],xmm6[2],xmm2[2],xmm6[3],xmm2[3],xmm6[4],xmm2[4],xmm6[5],xmm2[5],xmm6[6],xmm2[6],xmm6[7],xmm2[7] ++; AVX512F-FAST-NEXT: vpunpcklbw {{.*#+}} xmm2 = xmm7[0],xmm2[0],xmm7[1],xmm2[1],xmm7[2],xmm2[2],xmm7[3],xmm2[3],xmm7[4],xmm2[4],xmm7[5],xmm2[5],xmm7[6],xmm2[6],xmm7[7],xmm2[7] + ; AVX512F-FAST-NEXT: vpshufb %xmm13, %xmm2, %xmm2 + ; AVX512F-FAST-NEXT: vinserti32x4 $2, %xmm29, %zmm2, %zmm2 +-; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm26, %zmm0 +-; AVX512F-FAST-NEXT: vpermq $80, {{[-0-9]+}}(%r{{[sb]}}p), %ymm6 # 32-byte Folded Reload +-; AVX512F-FAST-NEXT: # ymm6 = mem[0,0,1,1] +-; AVX512F-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm6, %zmm6 # 32-byte Folded Reload +-; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm21[0,0,1,1] ++; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm25, %zmm0 ++; AVX512F-FAST-NEXT: vpermq $80, {{[-0-9]+}}(%r{{[sb]}}p), %ymm7 # 32-byte Folded Reload ++; AVX512F-FAST-NEXT: # ymm7 = mem[0,0,1,1] ++; AVX512F-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm7, %zmm7 # 32-byte Folded Reload ++; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm19[0,0,1,1] + ; AVX512F-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm13, %zmm13 # 32-byte Folded Reload + ; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} zmm16 = [255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0] +-; AVX512F-FAST-NEXT: vpternlogq $226, %zmm6, %zmm16, %zmm13 +-; AVX512F-FAST-NEXT: vpor %ymm7, %ymm15, %ymm6 +-; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm6, %zmm19, %zmm6 +-; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm7 = [18374966859431608575,18374966859431608575,18446463693966278400,18446463693966278400] +-; AVX512F-FAST-NEXT: vpternlogq $248, %ymm7, %ymm11, %ymm9 +-; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm9, %zmm20, %zmm9 +-; AVX512F-FAST-NEXT: vpternlogq $226, %zmm6, %zmm16, %zmm9 +-; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm6 = zmm22[2,2,3,3,6,6,7,7] ++; AVX512F-FAST-NEXT: vpternlogq $226, %zmm7, %zmm16, %zmm13 ++; AVX512F-FAST-NEXT: vpor %ymm15, %ymm12, %ymm7 ++; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm7, %zmm20, %zmm7 ++; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm12 = [18374966859431608575,18374966859431608575,18446463693966278400,18446463693966278400] ++; AVX512F-FAST-NEXT: vpternlogq $248, %ymm12, %ymm11, %ymm9 ++; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm9, %zmm21, %zmm9 ++; AVX512F-FAST-NEXT: vpternlogq $226, %zmm7, %zmm16, %zmm9 ++; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm7 = zmm22[2,2,3,3,6,6,7,7] + ; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm11 = zmm23[2,2,3,3,6,6,7,7] +-; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm11 +-; AVX512F-FAST-NEXT: vpternlogq $248, %ymm7, %ymm1, %ymm12 +-; AVX512F-FAST-NEXT: vpandq %ymm7, %ymm30, %ymm1 ++; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm11 ++; AVX512F-FAST-NEXT: vpternlogq $248, %ymm12, 
%ymm1, %ymm6 ++; AVX512F-FAST-NEXT: vpandq %ymm12, %ymm30, %ymm1 + ; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm8, %zmm1 +-; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm6 = zmm24[2,2,3,3,6,6,7,7] +-; AVX512F-FAST-NEXT: vporq %zmm6, %zmm1, %zmm1 +-; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} zmm6 = [0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255] +-; AVX512F-FAST-NEXT: vpternlogq $226, %zmm11, %zmm6, %zmm1 +-; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm10, %zmm12, %zmm7 ++; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm7 = zmm24[2,2,3,3,6,6,7,7] ++; AVX512F-FAST-NEXT: vporq %zmm7, %zmm1, %zmm1 ++; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} zmm7 = [0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255] ++; AVX512F-FAST-NEXT: vpternlogq $226, %zmm11, %zmm7, %zmm1 ++; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm10, %zmm6, %zmm6 + ; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm5, %ymm4 + ; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm14, %zmm4, %zmm4 +-; AVX512F-FAST-NEXT: vpternlogq $226, %zmm7, %zmm6, %zmm4 ++; AVX512F-FAST-NEXT: vpternlogq $226, %zmm6, %zmm7, %zmm4 + ; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm13, %zmm27 + ; AVX512F-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm31 +-; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm25 ++; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm26 + ; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} zmm1 = <6,6,6,u,7,7,7,7,u,8,8,8,8,u,9,9> +-; AVX512F-FAST-NEXT: vpermd %zmm26, %zmm1, %zmm1 ++; AVX512F-FAST-NEXT: vpermd %zmm25, %zmm1, %zmm1 + ; AVX512F-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm1 + ; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm3 = zmm3[0,0,1,1,4,4,5,5] + ; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm2 = zmm2[0,0,1,1,4,4,5,5] +@@ -4231,7 +4233,7 @@ + ; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} zmm3 = + ; AVX512F-FAST-NEXT: vpermd %zmm0, %zmm3, %zmm0 + ; AVX512F-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm0 +-; AVX512F-FAST-NEXT: vmovdqa64 %zmm25, 64(%r9) ++; AVX512F-FAST-NEXT: vmovdqa64 %zmm26, 64(%r9) + ; AVX512F-FAST-NEXT: vmovdqa64 %zmm0, (%r9) + ; AVX512F-FAST-NEXT: vmovdqa64 %zmm1, 128(%r9) + ; AVX512F-FAST-NEXT: vmovdqa64 %zmm31, 256(%r9) +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-7.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-7.ll +--- a/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-7.ll ++++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-7.ll +@@ -7382,14 +7382,14 @@ + ; + ; AVX512F-SLOW-LABEL: store_i8_stride7_vf64: + ; AVX512F-SLOW: # %bb.0: +-; AVX512F-SLOW-NEXT: subq $1464, %rsp # imm = 0x5B8 ++; AVX512F-SLOW-NEXT: subq $1416, %rsp # imm = 0x588 + ; AVX512F-SLOW-NEXT: vmovdqa (%rsi), %ymm1 + ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,ymm1[14],zero,zero,zero,zero,zero,zero,ymm1[15],zero,zero,zero,zero,zero,zero,ymm1[16],zero,zero,zero,zero,zero,zero,ymm1[17],zero,zero,zero,zero,zero,zero,ymm1[18] +-; AVX512F-SLOW-NEXT: vmovdqa %ymm1, %ymm9 ++; AVX512F-SLOW-NEXT: vmovdqa %ymm1, %ymm12 + ; AVX512F-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; 
AVX512F-SLOW-NEXT: vmovdqa (%rdi), %ymm2 + ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[0,1,14],zero,ymm2[12,13,0,1,14,15],zero,ymm2[3,12,13,2,3,16],zero,ymm2[30,31,28,29,16,17],zero,ymm2[31,18,19,28,29,18],zero +-; AVX512F-SLOW-NEXT: vmovdqa %ymm2, %ymm10 ++; AVX512F-SLOW-NEXT: vmovdqa %ymm2, %ymm9 + ; AVX512F-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512F-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 + ; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +@@ -7400,46 +7400,44 @@ + ; AVX512F-SLOW-NEXT: vmovdqa (%rdx), %ymm8 + ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm2 = [0,1,0,1,14,128,14,15,0,1,14,15,128,13,14,15,16,17,16,128,30,31,30,31,16,17,128,31,28,29,30,31] + ; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm8, %ymm1 +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm2, %ymm17 ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm2, %ymm28 + ; AVX512F-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 + ; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512F-SLOW-NEXT: vmovdqa (%r8), %ymm0 + ; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm2 = [128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128] +-; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm0, %ymm0 +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm2, %ymm20 +-; AVX512F-SLOW-NEXT: vmovdqa (%r9), %ymm1 +-; AVX512F-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,zero,zero,zero,ymm0[14],zero,zero,zero,zero,zero,zero,ymm0[15],zero,zero,zero,zero,zero,zero,ymm0[16],zero,zero,zero,zero,zero,zero,ymm0[17],zero,zero,zero,zero ++; AVX512F-SLOW-NEXT: vmovdqa (%r9), %ymm2 + ; AVX512F-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm3 = [13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0,13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0] + ; AVX512F-SLOW-NEXT: # ymm3 = mem[0,1,0,1] +-; AVX512F-SLOW-NEXT: vpshufb %ymm3, %ymm1, %ymm1 +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm3, %ymm21 +-; AVX512F-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 +-; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-SLOW-NEXT: vmovdqa 32(%r9), %ymm13 +-; AVX512F-SLOW-NEXT: vmovdqa 32(%r8), %ymm14 +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm14[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm14[27],zero,zero,zero,zero,ymm14[30],zero,ymm14[28],zero,zero,zero,zero,ymm14[31],zero,ymm14[29] +-; AVX512F-SLOW-NEXT: vmovdqu %ymm14, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm1 = ymm13[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm13[25],zero,ymm13[23],zero,zero,zero,zero,ymm13[26],zero,ymm13[24],zero,zero +-; AVX512F-SLOW-NEXT: vmovdqu %ymm13, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-SLOW-NEXT: vpshufb %ymm3, %ymm2, %ymm1 ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm3, %ymm17 ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm2, %ymm16 ++; AVX512F-SLOW-NEXT: vporq %ymm0, %ymm1, %ymm23 ++; AVX512F-SLOW-NEXT: vmovdqa 32(%r9), %ymm10 ++; AVX512F-SLOW-NEXT: vmovdqa 32(%r8), %ymm11 ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm11[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm11[27],zero,zero,zero,zero,ymm11[30],zero,ymm11[28],zero,zero,zero,zero,ymm11[31],zero,ymm11[29] ++; AVX512F-SLOW-NEXT: vmovdqu %ymm11, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm1 = 
ymm10[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm10[25],zero,ymm10[23],zero,zero,zero,zero,ymm10[26],zero,ymm10[24],zero,zero + ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-SLOW-NEXT: vmovdqa 32(%rcx), %ymm6 +-; AVX512F-SLOW-NEXT: vmovdqa 32(%rdx), %ymm11 +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm11[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm11[30],zero,ymm11[28],zero,zero,zero,zero,ymm11[31],zero,ymm11[29],zero,zero +-; AVX512F-SLOW-NEXT: vmovdqu %ymm11, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-SLOW-NEXT: vmovdqa 32(%rcx), %ymm5 ++; AVX512F-SLOW-NEXT: vmovdqa 32(%rdx), %ymm6 ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm6[30],zero,ymm6[28],zero,zero,zero,zero,ymm6[31],zero,ymm6[29],zero,zero ++; AVX512F-SLOW-NEXT: vmovdqu %ymm6, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512F-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm2 = [128,128,25,128,23,128,128,128,128,26,128,24,128,128,128,128,128,128,25,128,23,128,128,128,128,26,128,24,128,128,128,128] + ; AVX512F-SLOW-NEXT: # ymm2 = mem[0,1,0,1] +-; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm6, %ymm1 ++; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm5, %ymm1 ++; AVX512F-SLOW-NEXT: vmovdqu %ymm5, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-SLOW-NEXT: vmovdqa 32(%rsi), %ymm5 +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm5[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm5[30],zero,ymm5[28],zero,zero,zero,zero,ymm5[31],zero,ymm5[29],zero,zero,zero ++; AVX512F-SLOW-NEXT: vmovdqa 32(%rsi), %ymm1 ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero,zero ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm1, %ymm21 + ; AVX512F-SLOW-NEXT: vmovdqa 32(%rdi), %ymm4 + ; AVX512F-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm1 = [128,23,128,128,128,128,26,128,24,128,128,128,128,27,128,25,128,23,128,128,128,128,26,128,24,128,128,128,128,27,128,25] + ; AVX512F-SLOW-NEXT: # ymm1 = mem[0,1,0,1] + ; AVX512F-SLOW-NEXT: vpshufb %ymm1, %ymm4, %ymm3 +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm4, %ymm19 ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm4, %ymm20 + ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm3, %zmm0 + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill + ; AVX512F-SLOW-NEXT: movq {{[0-9]+}}(%rsp), %rax +@@ -7461,179 +7459,182 @@ + ; AVX512F-SLOW-NEXT: vmovdqa64 %ymm8, %ymm18 + ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-SLOW-NEXT: vpshufb %ymm1, %ymm10, %ymm1 +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm2 = ymm9[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm9[21],zero,ymm9[19],zero,zero,zero,zero,ymm9[22],zero,ymm9[20],zero,zero ++; AVX512F-SLOW-NEXT: vpshufb %ymm1, %ymm9, %ymm1 ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm2 = ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm12[21],zero,ymm12[19],zero,zero,zero,zero,ymm12[22],zero,ymm12[20],zero,zero + ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm1 + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-SLOW-NEXT: vmovdqa (%rax), %ymm3 +-; AVX512F-SLOW-NEXT: vpshufb %ymm0, 
%ymm3, %ymm0 ++; AVX512F-SLOW-NEXT: vmovdqa (%rax), %ymm1 ++; AVX512F-SLOW-NEXT: vpshufb %ymm0, %ymm1, %ymm0 + ; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm25 = +-; AVX512F-SLOW-NEXT: vpshuflw {{.*#+}} ymm1 = ymm3[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] +-; AVX512F-SLOW-NEXT: vpermi2d %zmm0, %zmm1, %zmm25 +-; AVX512F-SLOW-NEXT: vmovdqa 32(%rdx), %xmm2 ++; AVX512F-SLOW-NEXT: vpshuflw {{.*#+}} ymm2 = ymm1[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] ++; AVX512F-SLOW-NEXT: vpermi2d %zmm0, %zmm2, %zmm25 ++; AVX512F-SLOW-NEXT: vmovdqa 32(%rdx), %xmm3 + ; AVX512F-SLOW-NEXT: vmovdqa 32(%rcx), %xmm15 +-; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm1 = +-; AVX512F-SLOW-NEXT: vpshufb %xmm1, %xmm15, %xmm0 +-; AVX512F-SLOW-NEXT: vmovdqa64 %xmm1, %xmm16 ++; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm2 = ++; AVX512F-SLOW-NEXT: vpshufb %xmm2, %xmm15, %xmm0 ++; AVX512F-SLOW-NEXT: vmovdqa64 %xmm2, %xmm19 + ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm4 = +-; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm2, %xmm1 +-; AVX512F-SLOW-NEXT: vmovdqa64 %xmm4, %xmm23 +-; AVX512F-SLOW-NEXT: vmovdqa64 %xmm2, %xmm30 +-; AVX512F-SLOW-NEXT: vpor %xmm0, %xmm1, %xmm0 ++; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm3, %xmm2 ++; AVX512F-SLOW-NEXT: vmovdqa64 %xmm4, %xmm29 ++; AVX512F-SLOW-NEXT: vmovdqa64 %xmm3, %xmm30 ++; AVX512F-SLOW-NEXT: vpor %xmm0, %xmm2, %xmm0 + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-SLOW-NEXT: vmovdqa 32(%rdi), %xmm7 +-; AVX512F-SLOW-NEXT: vmovdqa 32(%rsi), %xmm1 +-; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm8 = +-; AVX512F-SLOW-NEXT: vpshufb %xmm8, %xmm1, %xmm0 ++; AVX512F-SLOW-NEXT: vmovdqa 32(%rdi), %xmm8 ++; AVX512F-SLOW-NEXT: vmovdqa 32(%rsi), %xmm0 ++; AVX512F-SLOW-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill ++; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm7 = ++; AVX512F-SLOW-NEXT: vpshufb %xmm7, %xmm0, %xmm0 + ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm4 = +-; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm7, %xmm2 +-; AVX512F-SLOW-NEXT: vmovdqa64 %xmm7, %xmm22 ++; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm8, %xmm2 ++; AVX512F-SLOW-NEXT: vmovdqa64 %xmm8, %xmm22 + ; AVX512F-SLOW-NEXT: vpor %xmm0, %xmm2, %xmm0 + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm7 = <0,u,0,u,2,3,u,1,u,18,u,19,18,u,19,u> ++; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm8 = <0,u,0,u,2,3,u,1,u,18,u,19,18,u,19,u> + ; AVX512F-SLOW-NEXT: vmovdqa 32(%rax), %xmm2 +-; AVX512F-SLOW-NEXT: vmovdqa %xmm2, (%rsp) # 16-byte Spill ++; AVX512F-SLOW-NEXT: vmovdqa %xmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill + ; AVX512F-SLOW-NEXT: vpshufhw {{.*#+}} xmm0 = xmm2[0,1,2,3,4,5,5,6] + ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} xmm2 = xmm2[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] +-; AVX512F-SLOW-NEXT: vpermi2d %zmm0, %zmm2, %zmm7 +-; AVX512F-SLOW-NEXT: vmovdqu64 %zmm7, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-SLOW-NEXT: vpermi2d %zmm0, %zmm2, %zmm8 ++; AVX512F-SLOW-NEXT: vmovdqu64 %zmm8, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill + ; AVX512F-SLOW-NEXT: vmovdqa 32(%r9), %xmm0 +-; AVX512F-SLOW-NEXT: vmovdqa 32(%r8), %xmm2 +-; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm7 = <128,4,u,u,u,128,7,128,5,u,u,u,128,8,128,6> +-; AVX512F-SLOW-NEXT: vpshufb %xmm7, %xmm0, %xmm9 +-; AVX512F-SLOW-NEXT: vmovdqa64 %xmm0, %xmm31 +-; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm12 = <4,128,u,u,u,7,128,5,128,u,u,u,8,128,6,128> +-; AVX512F-SLOW-NEXT: vpshufb %xmm12, %xmm2, %xmm10 +-; AVX512F-SLOW-NEXT: vporq %xmm9, %xmm10, %xmm24 ++; AVX512F-SLOW-NEXT: vmovdqa 32(%r8), %xmm13 ++; 
AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm12 = <128,4,u,u,u,128,7,128,5,u,u,u,128,8,128,6> ++; AVX512F-SLOW-NEXT: vpshufb %xmm12, %xmm0, %xmm8 ++; AVX512F-SLOW-NEXT: vmovdqa64 %xmm0, %xmm26 ++; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm14 = <4,128,u,u,u,7,128,5,128,u,u,u,8,128,6,128> ++; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm13, %xmm9 ++; AVX512F-SLOW-NEXT: vporq %xmm8, %xmm9, %xmm24 + ; AVX512F-SLOW-NEXT: vmovdqa64 %ymm27, %ymm0 ++; AVX512F-SLOW-NEXT: vpshufb %ymm0, %ymm5, %ymm8 ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm28, %ymm0 + ; AVX512F-SLOW-NEXT: vpshufb %ymm0, %ymm6, %ymm9 +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm6, %ymm26 +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm17, %ymm0 +-; AVX512F-SLOW-NEXT: vpshufb %ymm0, %ymm11, %ymm10 +-; AVX512F-SLOW-NEXT: vpor %ymm9, %ymm10, %ymm0 ++; AVX512F-SLOW-NEXT: vpor %ymm8, %ymm9, %ymm0 + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm9 = zero,zero,zero,ymm5[14],zero,zero,zero,zero,zero,zero,ymm5[15],zero,zero,zero,zero,zero,zero,ymm5[16],zero,zero,zero,zero,zero,zero,ymm5[17],zero,zero,zero,zero,zero,zero,ymm5[18] +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm5, %ymm27 +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm19, %ymm0 +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm10 = ymm0[0,1,14],zero,ymm0[12,13,0,1,14,15],zero,ymm0[3,12,13,2,3,16],zero,ymm0[30,31,28,29,16,17],zero,ymm0[31,18,19,28,29,18],zero +-; AVX512F-SLOW-NEXT: vpor %ymm9, %ymm10, %ymm5 ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm21, %ymm3 ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm8 = zero,zero,zero,ymm3[14],zero,zero,zero,zero,zero,zero,ymm3[15],zero,zero,zero,zero,zero,zero,ymm3[16],zero,zero,zero,zero,zero,zero,ymm3[17],zero,zero,zero,zero,zero,zero,ymm3[18] ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm20, %ymm0 ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm9 = ymm0[0,1,14],zero,ymm0[12,13,0,1,14,15],zero,ymm0[3,12,13,2,3,16],zero,ymm0[30,31,28,29,16,17],zero,ymm0[31,18,19,28,29,18],zero ++; AVX512F-SLOW-NEXT: vpor %ymm8, %ymm9, %ymm5 + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm5, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm20, %ymm9 +-; AVX512F-SLOW-NEXT: vpshufb %ymm9, %ymm14, %ymm9 +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm21, %ymm6 +-; AVX512F-SLOW-NEXT: vpshufb %ymm6, %ymm13, %ymm10 +-; AVX512F-SLOW-NEXT: vpor %ymm9, %ymm10, %ymm5 ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm8 = zero,zero,zero,zero,zero,zero,ymm11[14],zero,zero,zero,zero,zero,zero,ymm11[15],zero,zero,zero,zero,zero,zero,ymm11[16],zero,zero,zero,zero,zero,zero,ymm11[17],zero,zero,zero,zero ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm17, %ymm6 ++; AVX512F-SLOW-NEXT: vpshufb %ymm6, %ymm10, %ymm9 ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm10, %ymm28 ++; AVX512F-SLOW-NEXT: vpor %ymm8, %ymm9, %ymm5 + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm5, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-SLOW-NEXT: vmovdqa (%rsi), %xmm9 +-; AVX512F-SLOW-NEXT: vpshufb %xmm8, %xmm9, %xmm8 +-; AVX512F-SLOW-NEXT: vmovdqa (%rdi), %xmm10 +-; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm10, %xmm4 +-; AVX512F-SLOW-NEXT: vporq %xmm8, %xmm4, %xmm21 +-; AVX512F-SLOW-NEXT: vmovdqa (%rcx), %xmm5 +-; AVX512F-SLOW-NEXT: vmovdqa64 %xmm16, %xmm4 +-; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm5, %xmm4 +-; AVX512F-SLOW-NEXT: vmovdqa %xmm5, %xmm11 +-; AVX512F-SLOW-NEXT: vmovdqa (%rdx), %xmm6 +-; AVX512F-SLOW-NEXT: vmovdqa64 %xmm23, %xmm5 +-; AVX512F-SLOW-NEXT: vpshufb %xmm5, %xmm6, %xmm8 +-; AVX512F-SLOW-NEXT: vporq %xmm4, %xmm8, %xmm19 ++; AVX512F-SLOW-NEXT: vmovdqa (%rsi), %xmm6 ++; AVX512F-SLOW-NEXT: vpshufb %xmm7, %xmm6, %xmm5 ++; 
AVX512F-SLOW-NEXT: vmovdqa64 %xmm6, %xmm20 ++; AVX512F-SLOW-NEXT: vmovdqa (%rdi), %xmm9 ++; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm9, %xmm4 ++; AVX512F-SLOW-NEXT: vporq %xmm5, %xmm4, %xmm21 ++; AVX512F-SLOW-NEXT: vmovdqa (%rcx), %xmm2 ++; AVX512F-SLOW-NEXT: vmovdqa64 %xmm19, %xmm4 ++; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm2, %xmm4 ++; AVX512F-SLOW-NEXT: vmovdqa (%rdx), %xmm10 ++; AVX512F-SLOW-NEXT: vmovdqa64 %xmm29, %xmm5 ++; AVX512F-SLOW-NEXT: vpshufb %xmm5, %xmm10, %xmm7 ++; AVX512F-SLOW-NEXT: vporq %xmm4, %xmm7, %xmm19 + ; AVX512F-SLOW-NEXT: vmovdqa (%r9), %xmm5 + ; AVX512F-SLOW-NEXT: vmovdqa %xmm5, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill +-; AVX512F-SLOW-NEXT: vpshufb %xmm7, %xmm5, %xmm4 +-; AVX512F-SLOW-NEXT: vmovdqa (%r8), %xmm8 +-; AVX512F-SLOW-NEXT: vmovdqa %xmm8, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill +-; AVX512F-SLOW-NEXT: vpshufb %xmm12, %xmm8, %xmm7 +-; AVX512F-SLOW-NEXT: vpor %xmm4, %xmm7, %xmm4 ++; AVX512F-SLOW-NEXT: vpshufb %xmm12, %xmm5, %xmm4 ++; AVX512F-SLOW-NEXT: vmovdqa (%r8), %xmm7 ++; AVX512F-SLOW-NEXT: vmovdqa %xmm7, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill ++; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm7, %xmm6 ++; AVX512F-SLOW-NEXT: vpor %xmm4, %xmm6, %xmm4 + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm4 # 32-byte Folded Reload +-; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm7 = xmm8[8],xmm5[8],xmm8[9],xmm5[9],xmm8[10],xmm5[10],xmm8[11],xmm5[11],xmm8[12],xmm5[12],xmm8[13],xmm5[13],xmm8[14],xmm5[14],xmm8[15],xmm5[15] ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm23, %zmm0, %zmm4 ++; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm6 = xmm7[8],xmm5[8],xmm7[9],xmm5[9],xmm7[10],xmm5[10],xmm7[11],xmm5[11],xmm7[12],xmm5[12],xmm7[13],xmm5[13],xmm7[14],xmm5[14],xmm7[15],xmm5[15] + ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm5 = +-; AVX512F-SLOW-NEXT: vpshufb %xmm5, %xmm7, %xmm7 +-; AVX512F-SLOW-NEXT: vmovdqa64 %xmm5, %xmm29 +-; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm4 = zmm7[0,1,0,1],zmm4[4,5,6,7] ++; AVX512F-SLOW-NEXT: vpshufb %xmm5, %xmm6, %xmm6 ++; AVX512F-SLOW-NEXT: vmovdqa64 %xmm5, %xmm27 ++; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm4 = zmm6[0,1,0,1],zmm4[4,5,6,7] + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-SLOW-NEXT: vmovdqa (%rax), %xmm13 +-; AVX512F-SLOW-NEXT: vpshufhw {{.*#+}} xmm4 = xmm13[0,1,2,3,4,5,5,6] ++; AVX512F-SLOW-NEXT: vmovdqa (%rax), %xmm12 ++; AVX512F-SLOW-NEXT: vpshufhw {{.*#+}} xmm4 = xmm12[0,1,2,3,4,5,5,6] + ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} xmm4 = xmm4[2,2,3,3] + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm4[0,1,0,1] +-; AVX512F-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm4, %ymm4 +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm12 = zero,ymm3[13],zero,zero,zero,zero,zero,zero,ymm3[14],zero,zero,zero,zero,zero,zero,ymm3[15],zero,zero,zero,zero,zero,zero,ymm3[16],zero,zero,zero,zero,zero,zero,ymm3[17],zero,zero +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm12, %zmm4, %zmm23 ++; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm11 = [255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255] ++; AVX512F-SLOW-NEXT: vpandn %ymm4, %ymm11, %ymm4 ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm11 = zero,ymm1[13],zero,zero,zero,zero,zero,zero,ymm1[14],zero,zero,zero,zero,zero,zero,ymm1[15],zero,zero,zero,zero,zero,zero,ymm1[16],zero,zero,zero,zero,zero,zero,ymm1[17],zero,zero ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm11, %zmm4, %zmm23 + ; 
AVX512F-SLOW-NEXT: vmovdqa64 %ymm18, %ymm4 + ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm4 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm4[30],zero,ymm4[28],zero,zero,zero,zero,ymm4[31],zero,ymm4[29],zero,zero + ; AVX512F-SLOW-NEXT: vmovdqa64 %ymm4, %ymm18 + ; AVX512F-SLOW-NEXT: vpbroadcastq {{.*#+}} ymm5 = [13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14] +-; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm14 # 32-byte Reload ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm16, %ymm14 + ; AVX512F-SLOW-NEXT: vpshufb %ymm5, %ymm14, %ymm4 +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm5, %ymm28 ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm5, %ymm29 + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,3,2,3] +-; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm5 # 32-byte Reload +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm12 = ymm5[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm5[27],zero,zero,zero,zero,ymm5[30],zero,ymm5[28],zero,zero,zero,zero,ymm5[31],zero,ymm5[29] +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm4, %ymm12 +-; AVX512F-SLOW-NEXT: vmovdqa64 %xmm31, %xmm7 +-; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm4 = xmm2[0],xmm7[0],xmm2[1],xmm7[1],xmm2[2],xmm7[2],xmm2[3],xmm7[3],xmm2[4],xmm7[4],xmm2[5],xmm7[5],xmm2[6],xmm7[6],xmm2[7],xmm7[7] ++; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm8 # 32-byte Reload ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm11 = ymm8[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm8[27],zero,zero,zero,zero,ymm8[30],zero,ymm8[28],zero,zero,zero,zero,ymm8[31],zero,ymm8[29] ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm11[2,3,2,3] ++; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm4, %ymm11 ++; AVX512F-SLOW-NEXT: vmovdqa64 %xmm26, %xmm7 ++; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm4 = xmm13[0],xmm7[0],xmm13[1],xmm7[1],xmm13[2],xmm7[2],xmm13[3],xmm7[3],xmm13[4],xmm7[4],xmm13[5],xmm7[5],xmm13[6],xmm7[6],xmm13[7],xmm7[7] + ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} xmm4 = xmm4[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] +-; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm4 = zmm12[0,1,2,3],zmm4[0,1,0,1] ++; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm4 = zmm11[0,1,2,3],zmm4[0,1,0,1] + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm3 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,28,29,26,27,28,29,30,31,30,31,28,29,28,29,30,31] +-; AVX512F-SLOW-NEXT: vmovdqu %ymm3, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-SLOW-NEXT: vpshufhw {{.*#+}} ymm3 = ymm0[0,1,2,3,6,7,7,6,8,9,10,11,14,15,15,14] ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,28,29,26,27,28,29,30,31,30,31,28,29,28,29,30,31] ++; AVX512F-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-SLOW-NEXT: vpshufhw {{.*#+}} ymm1 = ymm0[0,1,2,3,6,7,7,6,8,9,10,11,14,15,15,14] + ; AVX512F-SLOW-NEXT: vmovdqa64 %ymm0, %ymm17 +-; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[2,2,3,3,6,6,7,7] ++; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[2,2,3,3,6,6,7,7] + ; AVX512F-SLOW-NEXT: vpbroadcastq {{.*#+}} ymm4 = [9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10] +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm27, %ymm8 +-; AVX512F-SLOW-NEXT: vpshufb %ymm4, %ymm8, %ymm12 +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm12, %zmm20 +-; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Reload +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} 
ymm0 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm3[30],zero,ymm3[28],zero,zero,zero,zero,ymm3[31],zero,ymm3[29],zero,zero,zero ++; AVX512F-SLOW-NEXT: vmovdqa %ymm3, %ymm6 ++; AVX512F-SLOW-NEXT: vpshufb %ymm4, %ymm3, %ymm11 ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm1, %zmm11, %zmm26 ++; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero,zero + ; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-SLOW-NEXT: vpshufb %ymm4, %ymm3, %ymm3 ++; AVX512F-SLOW-NEXT: vpshufb %ymm4, %ymm1, %ymm1 + ; AVX512F-SLOW-NEXT: vpshuflw $233, {{[-0-9]+}}(%r{{[sb]}}p), %ymm4 # 32-byte Folded Reload + ; AVX512F-SLOW-NEXT: # ymm4 = mem[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] + ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[0,0,1,1,4,4,5,5] +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm4, %zmm0 ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm1, %zmm4, %zmm0 + ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill + ; AVX512F-SLOW-NEXT: vmovdqa64 %xmm30, %xmm0 +-; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm3 = xmm0[0],xmm15[0],xmm0[1],xmm15[1],xmm0[2],xmm15[2],xmm0[3],xmm15[3],xmm0[4],xmm15[4],xmm0[5],xmm15[5],xmm0[6],xmm15[6],xmm0[7],xmm15[7] +-; AVX512F-SLOW-NEXT: vmovdqa64 %xmm3, %xmm16 +-; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm3 = xmm15[8],xmm0[8],xmm15[9],xmm0[9],xmm15[10],xmm0[10],xmm15[11],xmm0[11],xmm15[12],xmm0[12],xmm15[13],xmm0[13],xmm15[14],xmm0[14],xmm15[15],xmm0[15] +-; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm4 = xmm11[8],xmm6[8],xmm11[9],xmm6[9],xmm11[10],xmm6[10],xmm11[11],xmm6[11],xmm11[12],xmm6[12],xmm11[13],xmm6[13],xmm11[14],xmm6[14],xmm11[15],xmm6[15] +-; AVX512F-SLOW-NEXT: vmovdqa %xmm11, %xmm12 ++; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm1 = xmm0[0],xmm15[0],xmm0[1],xmm15[1],xmm0[2],xmm15[2],xmm0[3],xmm15[3],xmm0[4],xmm15[4],xmm0[5],xmm15[5],xmm0[6],xmm15[6],xmm0[7],xmm15[7] ++; AVX512F-SLOW-NEXT: vmovdqa64 %xmm1, %xmm16 ++; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm15[8],xmm0[8],xmm15[9],xmm0[9],xmm15[10],xmm0[10],xmm15[11],xmm0[11],xmm15[12],xmm0[12],xmm15[13],xmm0[13],xmm15[14],xmm0[14],xmm15[15],xmm0[15] ++; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm4 = xmm2[8],xmm10[8],xmm2[9],xmm10[9],xmm2[10],xmm10[10],xmm2[11],xmm10[11],xmm2[12],xmm10[12],xmm2[13],xmm10[13],xmm2[14],xmm10[14],xmm2[15],xmm10[15] ++; AVX512F-SLOW-NEXT: vmovdqa %xmm2, %xmm11 + ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm15 = <6,3,2,u,u,u,9,8,5,4,u,u,u,11,10,7> + ; AVX512F-SLOW-NEXT: vpshufb %xmm15, %xmm4, %xmm0 + ; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-SLOW-NEXT: vpshufb %xmm15, %xmm3, %xmm3 ++; AVX512F-SLOW-NEXT: vpshufb %xmm15, %xmm1, %xmm1 + ; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Reload +-; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm3, %zmm0, %zmm30 ++; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm1, %zmm0, %zmm30 + ; AVX512F-SLOW-NEXT: vmovdqa64 %xmm22, %xmm0 ++; AVX512F-SLOW-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload + ; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm15 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3],xmm0[4],xmm1[4],xmm0[5],xmm1[5],xmm0[6],xmm1[6],xmm0[7],xmm1[7] + ; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm1 = 
xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] +-; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm3 = xmm9[8],xmm10[8],xmm9[9],xmm10[9],xmm9[10],xmm10[10],xmm9[11],xmm10[11],xmm9[12],xmm10[12],xmm9[13],xmm10[13],xmm9[14],xmm10[14],xmm9[15],xmm10[15] ++; AVX512F-SLOW-NEXT: vmovdqa64 %xmm20, %xmm5 ++; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm3 = xmm5[8],xmm9[8],xmm5[9],xmm9[9],xmm5[10],xmm9[10],xmm5[11],xmm9[11],xmm5[12],xmm9[12],xmm5[13],xmm9[13],xmm5[14],xmm9[14],xmm5[15],xmm9[15] + ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm4 = <2,u,u,u,9,8,5,4,u,u,u,11,10,7,6,u> + ; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm3, %xmm0 + ; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm1, %xmm1 + ; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Reload + ; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm1, %zmm0, %zmm22 +-; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm2[8],xmm7[8],xmm2[9],xmm7[9],xmm2[10],xmm7[10],xmm2[11],xmm7[11],xmm2[12],xmm7[12],xmm2[13],xmm7[13],xmm2[14],xmm7[14],xmm2[15],xmm7[15] +-; AVX512F-SLOW-NEXT: vmovdqa64 %xmm29, %xmm1 ++; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm13[8],xmm7[8],xmm13[9],xmm7[9],xmm13[10],xmm7[10],xmm13[11],xmm7[11],xmm13[12],xmm7[12],xmm13[13],xmm7[13],xmm13[14],xmm7[14],xmm13[15],xmm7[15] ++; AVX512F-SLOW-NEXT: vmovdqa64 %xmm27, %xmm1 + ; AVX512F-SLOW-NEXT: vpshufb %xmm1, %xmm0, %xmm0 + ; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm24 = zmm24[0,1,0,1],zmm0[0,1,0,1] + ; AVX512F-SLOW-NEXT: vpbroadcastq {{.*#+}} ymm2 = [11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12] + ; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload + ; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm0, %ymm4 +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm26, %ymm0 +-; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm0, %ymm7 ++; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload ++; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm0, %ymm13 + ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[18],zero,zero,zero,zero,ymm0[21],zero,ymm0[19],zero,zero,zero,zero,ymm0[22],zero,ymm0[20] +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm0, %ymm29 ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm0, %ymm31 + ; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm11 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25],zero,ymm0[23],zero,ymm0[21,22,23,26],zero,ymm0[24],zero,ymm0[28,29,26,27] ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm7 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25],zero,ymm0[23],zero,ymm0[21,22,23,26],zero,ymm0[24],zero,ymm0[28,29,26,27] + ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,18],zero,ymm0[18,19,20,21],zero,ymm0[19],zero,ymm0[25,26,27,22],zero,ymm0[20],zero +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm0, %ymm26 +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm8 = ymm8[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm8[21],zero,ymm8[19],zero,zero,zero,zero,ymm8[22],zero,ymm8[20],zero,zero +-; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm28, %ymm0 ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm0, %ymm20 ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm6 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm6[21],zero,ymm6[19],zero,zero,zero,zero,ymm6[22],zero,ymm6[20],zero,zero ++; 
AVX512F-SLOW-NEXT: vmovdqa64 %ymm28, %ymm1 ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm29, %ymm0 + ; AVX512F-SLOW-NEXT: vpshufb %ymm0, %ymm1, %ymm3 + ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm2 = ymm14[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm14[25],zero,ymm14[23],zero,zero,zero,zero,ymm14[26],zero,ymm14[24],zero,zero + ; AVX512F-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +@@ -7647,40 +7648,40 @@ + ; AVX512F-SLOW-NEXT: # ymm2 = mem[0,1,0,1] + ; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload + ; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm0, %ymm1 +-; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm5, %ymm2 +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm2, %ymm27 ++; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm8, %ymm2 ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm2, %ymm29 + ; AVX512F-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm2 = [20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128,20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128] + ; AVX512F-SLOW-NEXT: # ymm2 = mem[0,1,0,1] +-; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm5, %ymm5 +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm5, %ymm28 ++; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm8, %ymm8 ++; AVX512F-SLOW-NEXT: vmovdqa64 %ymm8, %ymm27 + ; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm0, %ymm0 +-; AVX512F-SLOW-NEXT: vmovdqa64 %ymm0, %ymm31 ++; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload + ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm2 = zero,ymm0[13],zero,zero,zero,zero,zero,zero,ymm0[14],zero,zero,zero,zero,zero,zero,ymm0[15],zero,zero,zero,zero,zero,zero,ymm0[16],zero,zero,zero,zero,zero,zero,ymm0[17],zero,zero + ; AVX512F-SLOW-NEXT: vpshuflw {{.*#+}} ymm14 = ymm0[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] + ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm14 = ymm14[0,1,1,3,4,5,5,7] + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,2,3,2] +-; AVX512F-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm14, %ymm14 ++; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} ymm28 = [255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255] ++; AVX512F-SLOW-NEXT: vpandnq %ymm14, %ymm28, %ymm14 + ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm14, %zmm2, %zmm2 +-; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm10 = xmm10[0],xmm9[0],xmm10[1],xmm9[1],xmm10[2],xmm9[2],xmm10[3],xmm9[3],xmm10[4],xmm9[4],xmm10[5],xmm9[5],xmm10[6],xmm9[6],xmm10[7],xmm9[7] ++; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm9 = xmm9[0],xmm5[0],xmm9[1],xmm5[1],xmm9[2],xmm5[2],xmm9[3],xmm5[3],xmm9[4],xmm5[4],xmm9[5],xmm5[5],xmm9[6],xmm5[6],xmm9[7],xmm5[7] + ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm14 = <0,1,u,u,u,6,7,2,3,u,u,u,8,9,4,5> + ; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm15, %xmm15 +-; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm10, %xmm10 +-; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm21, %zmm10, %zmm0 +-; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm6 = xmm6[0],xmm12[0],xmm6[1],xmm12[1],xmm6[2],xmm12[2],xmm6[3],xmm12[3],xmm6[4],xmm12[4],xmm6[5],xmm12[5],xmm6[6],xmm12[6],xmm6[7],xmm12[7] +-; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm14 = <4,5,0,1,u,u,u,6,7,2,3,u,u,u,8,9> ++; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm9, %xmm9 ++; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm21, %zmm9, %zmm14 ++; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm9 = xmm10[0],xmm11[0],xmm10[1],xmm11[1],xmm10[2],xmm11[2],xmm10[3],xmm11[3],xmm10[4],xmm11[4],xmm10[5],xmm11[5],xmm10[6],xmm11[6],xmm10[7],xmm11[7] 
++; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm10 = <4,5,0,1,u,u,u,6,7,2,3,u,u,u,8,9> + ; AVX512F-SLOW-NEXT: vmovdqa64 %xmm16, %xmm0 +-; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm0, %xmm10 +-; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm6, %xmm6 +-; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm19, %zmm6, %zmm6 +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm12 = ymm4[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm18[2,3,2,3] ++; AVX512F-SLOW-NEXT: vpshufb %xmm10, %xmm0, %xmm8 ++; AVX512F-SLOW-NEXT: vpshufb %xmm10, %xmm9, %xmm9 ++; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm19, %zmm9, %zmm9 ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm10 = ymm4[2,3,2,3] ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm18[2,3,2,3] + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm19 = ymm3[2,3,2,3] + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm1[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm7[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm11[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm8[2,3,2,3] ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,3,2,3] ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm7[2,3,2,3] ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm6[2,3,2,3] + ; AVX512F-SLOW-NEXT: vmovdqa64 %ymm17, %ymm1 + ; AVX512F-SLOW-NEXT: vpshuflw {{.*#+}} ymm1 = ymm1[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] + ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm21 = ymm1[0,0,1,1,4,4,5,5] +@@ -7690,28 +7691,29 @@ + ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] + ; AVX512F-SLOW-NEXT: vshufi64x2 $0, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm1 # 64-byte Folded Reload + ; AVX512F-SLOW-NEXT: # zmm1 = zmm1[0,1,0,1],mem[0,1,0,1] +-; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} xmm3 = xmm13[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] +-; AVX512F-SLOW-NEXT: vpshuflw {{.*#+}} xmm8 = xmm13[1,1,0,0,4,5,6,7] +-; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} xmm8 = xmm8[0,1,2,0] +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm8, %zmm3 +-; AVX512F-SLOW-NEXT: vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm19, %ymm8 +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm0, %zmm0 +-; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Folded Reload +-; AVX512F-SLOW-NEXT: # zmm8 = mem[2,3,2,3,6,7,6,7] +-; AVX512F-SLOW-NEXT: vporq %zmm8, %zmm0, %zmm0 +-; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm8 = [18374966859431673855,18446463693966278655,18374966859431673855,18446463693966278655] +-; AVX512F-SLOW-NEXT: vpand %ymm7, %ymm8, %ymm7 ++; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} xmm3 = xmm12[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] ++; AVX512F-SLOW-NEXT: vpshuflw {{.*#+}} xmm6 = xmm12[1,1,0,0,4,5,6,7] ++; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} xmm6 = xmm6[0,1,2,0] ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm6, %zmm3 ++; AVX512F-SLOW-NEXT: vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm19, %ymm6 ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm6, %zmm0, %zmm0 ++; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm6 # 64-byte Folded Reload ++; AVX512F-SLOW-NEXT: # zmm6 = mem[2,3,2,3,6,7,6,7] ++; AVX512F-SLOW-NEXT: vporq %zmm6, %zmm0, %zmm0 ++; AVX512F-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm6 = [18374966859431673855,18446463693966278655,18374966859431673855,18446463693966278655] ++; AVX512F-SLOW-NEXT: # ymm6 = mem[0,1,0,1] ++; AVX512F-SLOW-NEXT: vpand %ymm6, %ymm13, %ymm7 + ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm7, %zmm5, %zmm5 + ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Folded Reload + ; AVX512F-SLOW-NEXT: # zmm7 = mem[2,3,2,3,6,7,6,7] + ; AVX512F-SLOW-NEXT: vporq %zmm7, %zmm5, 
%zmm5 + ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Folded Reload + ; AVX512F-SLOW-NEXT: # zmm7 = mem[2,3,2,3,6,7,6,7] +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm9 = zmm20[2,3,2,3,6,7,6,7] +-; AVX512F-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm9 +-; AVX512F-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm9 ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm12 = zmm26[2,3,2,3,6,7,6,7] ++; AVX512F-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm12 ++; AVX512F-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm12 + ; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm5 = [255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255] +-; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm5, %zmm9 ++; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm5, %zmm12 + ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Folded Reload + ; AVX512F-SLOW-NEXT: # zmm0 = mem[2,3,2,3,6,7,6,7] + ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Folded Reload +@@ -7726,80 +7728,80 @@ + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm0 = zmm30[0,1,0,1,4,5,4,5] + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm7 = zmm22[0,1,0,1,4,5,4,5] + ; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm5, %zmm7 +-; AVX512F-SLOW-NEXT: vpternlogq $248, %ymm8, %ymm12, %ymm14 ++; AVX512F-SLOW-NEXT: vpternlogq $248, %ymm6, %ymm10, %ymm11 + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm21[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpternlogq $236, %ymm8, %ymm4, %ymm0 +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm10[0,1,0,1] +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm14, %zmm4 ++; AVX512F-SLOW-NEXT: vpternlogq $236, %ymm6, %ymm4, %ymm0 ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm8[0,1,0,1] ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm11, %zmm4 + ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm5 # 32-byte Folded Reload + ; AVX512F-SLOW-NEXT: # ymm5 = mem[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpshufhw $190, {{[-0-9]+}}(%r{{[sb]}}p), %ymm8 # 32-byte Folded Reload +-; AVX512F-SLOW-NEXT: # ymm8 = mem[0,1,2,3,6,7,7,6,8,9,10,11,14,15,15,14] +-; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm8 = ymm8[2,2,3,3,6,6,7,7] +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm5, %ymm8 ++; AVX512F-SLOW-NEXT: vpshufhw $190, {{[-0-9]+}}(%r{{[sb]}}p), %ymm6 # 32-byte Folded Reload ++; AVX512F-SLOW-NEXT: # ymm6 = mem[0,1,2,3,6,7,7,6,8,9,10,11,14,15,15,14] ++; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[2,2,3,3,6,6,7,7] ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm6 = ymm6[2,3,2,3] ++; AVX512F-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm5, %ymm6 + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm15[0,1,0,1] +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm8, %zmm5 +-; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm8 = [255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255] +-; AVX512F-SLOW-NEXT: vpternlogq $184, %zmm4, %zmm8, %zmm5 +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm29[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm26[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpor %ymm4, %ymm11, 
%ymm4 +-; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm10 # 64-byte Reload +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm10, %zmm4 +-; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm10 # 64-byte Reload +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm10, %zmm0 +-; AVX512F-SLOW-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm11 # 32-byte Folded Reload +-; AVX512F-SLOW-NEXT: # ymm11 = mem[0,1,0,1] +-; AVX512F-SLOW-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm12 # 32-byte Folded Reload +-; AVX512F-SLOW-NEXT: # ymm12 = mem[0,1,0,1] +-; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm14 # 32-byte Folded Reload +-; AVX512F-SLOW-NEXT: # ymm14 = mem[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpshuflw $5, (%rsp), %xmm15 # 16-byte Folded Reload ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm6, %zmm5 ++; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm6 = [255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255] ++; AVX512F-SLOW-NEXT: vpternlogq $184, %zmm4, %zmm6, %zmm5 ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm31[2,3,2,3] ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm20[2,3,2,3] ++; AVX512F-SLOW-NEXT: vpor %ymm4, %ymm8, %ymm4 ++; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm8, %zmm4 ++; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm8, %zmm0 ++; AVX512F-SLOW-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm8 # 32-byte Folded Reload ++; AVX512F-SLOW-NEXT: # ymm8 = mem[0,1,0,1] ++; AVX512F-SLOW-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm10 # 32-byte Folded Reload ++; AVX512F-SLOW-NEXT: # ymm10 = mem[0,1,0,1] ++; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm11 # 32-byte Folded Reload ++; AVX512F-SLOW-NEXT: # ymm11 = mem[2,3,2,3] ++; AVX512F-SLOW-NEXT: vpshuflw $5, {{[-0-9]+}}(%r{{[sb]}}p), %xmm15 # 16-byte Folded Reload + ; AVX512F-SLOW-NEXT: # xmm15 = mem[1,1,0,0,4,5,6,7] + ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} xmm15 = xmm15[0,1,2,0] + ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm17 # 32-byte Folded Reload + ; AVX512F-SLOW-NEXT: # ymm17 = mem[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm18 = ymm27[2,3,2,3] ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm18 = ymm29[2,3,2,3] + ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm19 # 32-byte Folded Reload + ; AVX512F-SLOW-NEXT: # ymm19 = mem[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm20 = ymm28[2,3,2,3] ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm20 = ymm27[2,3,2,3] + ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm21 # 32-byte Folded Reload + ; AVX512F-SLOW-NEXT: # ymm21 = mem[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm22 = ymm31[2,3,2,3] +-; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm4, %zmm8, %zmm0 +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm11, %zmm4 # 32-byte Folded Reload +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm12, %zmm8 # 32-byte Folded Reload +-; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm8 ++; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm22 # 32-byte Folded Reload ++; AVX512F-SLOW-NEXT: # ymm22 = mem[2,3,2,3] ++; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm4, %zmm6, %zmm0 ++; AVX512F-SLOW-NEXT: 
vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm8, %zmm4 # 32-byte Folded Reload ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm6 # 32-byte Folded Reload ++; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm6 + ; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload + ; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm23 +-; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm23 ++; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm23 + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm15[0,0,1,0] +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm14, %zmm4 +-; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload +-; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm4 ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm11, %zmm4 ++; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm6 # 64-byte Reload ++; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm4 + ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm4 + ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Folded Reload + ; AVX512F-SLOW-NEXT: # zmm5 = mem[2,3,2,3,6,7,6,7] +-; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm5 +-; AVX512F-SLOW-NEXT: vporq %ymm17, %ymm18, %ymm8 +-; AVX512F-SLOW-NEXT: vporq %ymm19, %ymm20, %ymm9 +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm0, %zmm8 +-; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm8 = zmm9[0,1,2,3],zmm8[4,5,6,7] +-; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm25 ++; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm12, %zmm5 ++; AVX512F-SLOW-NEXT: vporq %ymm17, %ymm18, %ymm6 ++; AVX512F-SLOW-NEXT: vporq %ymm19, %ymm20, %ymm8 ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm6, %zmm0, %zmm6 ++; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm6 = zmm8[0,1,2,3],zmm6[4,5,6,7] ++; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm25 + ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm13, %zmm25 +-; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload +-; AVX512F-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm24 ++; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm6 # 64-byte Reload ++; AVX512F-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm24 + ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm24 +-; AVX512F-SLOW-NEXT: vporq %ymm21, %ymm22, %ymm7 +-; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm7, %zmm0, %zmm7 +-; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload +-; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm7 = zmm8[0,1,2,3],zmm7[4,5,6,7] +-; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm2 ++; AVX512F-SLOW-NEXT: vporq %ymm21, %ymm22, %ymm6 ++; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm6, %zmm0, %zmm6 ++; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Reload ++; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm6 = zmm7[0,1,2,3],zmm6[4,5,6,7] ++; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm2 + ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm2 +-; AVX512F-SLOW-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 
64-byte Folded Reload +-; AVX512F-SLOW-NEXT: # zmm0 = mem[0,1,0,1,4,5,4,5] +-; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm6 = zmm6[0,1,0,1,4,5,4,5] ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm0 = zmm14[0,1,0,1,4,5,4,5] ++; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm6 = zmm9[0,1,0,1,4,5,4,5] + ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm6 + ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm0 = zmm3[0,0,1,0,4,4,5,4] + ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm0 +@@ -7812,35 +7814,33 @@ + ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm5, 384(%rax) + ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm4, 192(%rax) + ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm23, 64(%rax) +-; AVX512F-SLOW-NEXT: addq $1464, %rsp # imm = 0x5B8 ++; AVX512F-SLOW-NEXT: addq $1416, %rsp # imm = 0x588 + ; AVX512F-SLOW-NEXT: vzeroupper + ; AVX512F-SLOW-NEXT: retq + ; + ; AVX512F-ONLY-FAST-LABEL: store_i8_stride7_vf64: + ; AVX512F-ONLY-FAST: # %bb.0: +-; AVX512F-ONLY-FAST-NEXT: subq $1256, %rsp # imm = 0x4E8 ++; AVX512F-ONLY-FAST-NEXT: subq $1496, %rsp # imm = 0x5D8 + ; AVX512F-ONLY-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax +-; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rcx), %ymm2 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdx), %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero +-; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm1, %ymm14 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm2[25],zero,ymm2[23],zero,zero,zero,zero,ymm2[26],zero,ymm2[24],zero,zero,zero,zero +-; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm2, %ymm13 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rcx), %ymm7 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdx), %ymm15 ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm15[30],zero,ymm15[28],zero,zero,zero,zero,ymm15[31],zero,ymm15[29],zero,zero ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm15, %ymm17 ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm7[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm7[25],zero,ymm7[23],zero,zero,zero,zero,ymm7[26],zero,ymm7[24],zero,zero,zero,zero + ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rsi), %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdi), %ymm2 +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero,zero +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm16 +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm2[23],zero,zero,zero,zero,ymm2[26],zero,ymm2[24],zero,zero,zero,zero,ymm2[27],zero,ymm2[25] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm2, %ymm17 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rsi), %ymm15 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdi), %ymm3 ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm15[30],zero,ymm15[28],zero,zero,zero,zero,ymm15[31],zero,ymm15[29],zero,zero,zero ++; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm15, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = 
ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm3[23],zero,zero,zero,zero,ymm3[26],zero,ymm3[24],zero,zero,zero,zero,ymm3[27],zero,ymm3[25] ++; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm3, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill + ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r9), %ymm4 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r8), %ymm1 + ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm1[27],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm18 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm19 + ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm4[25],zero,ymm4[23],zero,zero,zero,zero,ymm4[26],zero,ymm4[24],zero,zero + ; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm4, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 +@@ -7853,431 +7853,446 @@ + ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill + ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rsi), %ymm1 + ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,ymm1[14],zero,zero,zero,zero,zero,zero,ymm1[15],zero,zero,zero,zero,zero,zero,ymm1[16],zero,zero,zero,zero,zero,zero,ymm1[17],zero,zero,zero,zero,zero,zero,ymm1[18] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm23 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm26 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdi), %ymm1 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[0,1,14],zero,ymm1[12,13,0,1,14,15],zero,ymm1[3,12,13,2,3,16],zero,ymm1[30,31,28,29,16,17],zero,ymm1[31,18,19,28,29,18],zero + ; AVX512F-ONLY-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rcx), %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm7 = [128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128,128] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm7, %ymm1, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm26 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdx), %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm6 = [128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128,128] ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm6, %ymm1, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm30 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdx), %ymm10 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [0,1,0,1,14,128,14,15,0,1,14,15,128,13,14,15,16,17,16,128,30,31,30,31,16,17,128,31,28,29,30,31] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm3 +-; AVX512F-ONLY-FAST-NEXT: vpor %ymm0, %ymm3, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r8), %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm10, %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm2, %ymm25 ++; AVX512F-ONLY-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm5 = 
[128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm0, %ymm3 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm5, %ymm25 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r9), %ymm5 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm5, (%rsp) # 32-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0,13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0] +-; AVX512F-ONLY-FAST-NEXT: # ymm0 = mem[0,1,0,1] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm5, %ymm5 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm0, %ymm30 +-; AVX512F-ONLY-FAST-NEXT: vporq %ymm3, %ymm5, %ymm24 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdx), %xmm3 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rcx), %xmm6 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm8 = +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm6, %xmm5 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm6, %xmm28 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm9 = +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm3, %xmm6 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm9, %xmm19 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm3, %xmm29 +-; AVX512F-ONLY-FAST-NEXT: vpor %xmm5, %xmm6, %xmm3 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdi), %xmm10 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rsi), %xmm6 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm11 = +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm11, %xmm6, %xmm5 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm3 = +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm3, %xmm10, %xmm9 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm10, %xmm27 +-; AVX512F-ONLY-FAST-NEXT: vpor %xmm5, %xmm9, %xmm5 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm5, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r9), %xmm15 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r8), %xmm10 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm5 = <128,4,u,u,u,128,7,128,5,u,u,u,128,8,128,6> +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm15, %xmm9 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r8), %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,zero,zero,zero,ymm1[14],zero,zero,zero,zero,zero,zero,ymm1[15],zero,zero,zero,zero,zero,zero,ymm1[16],zero,zero,zero,zero,zero,zero,ymm1[17],zero,zero,zero,zero ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm22 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r9), %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0,13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0] ++; AVX512F-ONLY-FAST-NEXT: # ymm5 = mem[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm2, %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm5, %ymm29 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm2, %ymm31 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vporq %ymm0, %ymm1, %ymm23 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdx), %xmm5 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rcx), %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm2 = ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm2, %xmm1, %xmm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm2, %xmm18 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm1, %xmm20 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm2 = ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm2, %xmm5, %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm5, %xmm21 ++; 
AVX512F-ONLY-FAST-NEXT: vpor %xmm0, %xmm1, %xmm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdi), %xmm11 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rsi), %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm14 = ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm14, %xmm1, %xmm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm1, %xmm28 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm1 = ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm11, %xmm5 ++; AVX512F-ONLY-FAST-NEXT: vpor %xmm0, %xmm5, %xmm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r9), %xmm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r8), %xmm9 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm8 = <128,4,u,u,u,128,7,128,5,u,u,u,128,8,128,6> ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm0, %xmm12 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm0, %xmm16 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm0 = <4,128,u,u,u,7,128,5,128,u,u,u,8,128,6,128> +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm0, %xmm10, %xmm12 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm0, %xmm21 +-; AVX512F-ONLY-FAST-NEXT: vporq %xmm9, %xmm12, %xmm22 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm7, %ymm13, %ymm7 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm13, %ymm20 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm14, %ymm2 +-; AVX512F-ONLY-FAST-NEXT: vpor %ymm7, %ymm2, %ymm2 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm16, %ymm7 +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = zero,zero,zero,ymm7[14],zero,zero,zero,zero,zero,zero,ymm7[15],zero,zero,zero,zero,zero,zero,ymm7[16],zero,zero,zero,zero,zero,zero,ymm7[17],zero,zero,zero,zero,zero,zero,ymm7[18] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm17, %ymm7 +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[0,1,14],zero,ymm7[12,13,0,1,14,15],zero,ymm7[3,12,13,2,3,16],zero,ymm7[30,31,28,29,16,17],zero,ymm7[31,18,19,28,29,18],zero +-; AVX512F-ONLY-FAST-NEXT: vpor %ymm2, %ymm7, %ymm2 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm18, %ymm2 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %ymm18, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm0, %xmm9, %xmm13 ++; AVX512F-ONLY-FAST-NEXT: vpor %xmm12, %xmm13, %xmm12 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm12, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm6, %ymm7, %ymm6 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm7, %ymm24 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm17, %ymm13 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm25, %ymm7 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm7, %ymm2, %ymm2 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm4, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vpor %ymm2, %ymm0, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rsi), %xmm13 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm11, %xmm13, %xmm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdi), %xmm9 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm3, %xmm9, %xmm2 +-; AVX512F-ONLY-FAST-NEXT: vporq %xmm0, %xmm2, %xmm31 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rcx), %xmm14 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm14, %xmm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdx), %xmm8 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm19, %xmm2 +-; AVX512F-ONLY-FAST-NEXT: vpshufb 
%xmm2, %xmm8, %xmm2 +-; AVX512F-ONLY-FAST-NEXT: vpor %xmm0, %xmm2, %xmm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm7, %ymm13, %ymm7 ++; AVX512F-ONLY-FAST-NEXT: vpor %ymm6, %ymm7, %ymm6 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm6, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm6 = zero,zero,zero,ymm15[14],zero,zero,zero,zero,zero,zero,ymm15[15],zero,zero,zero,zero,zero,zero,ymm15[16],zero,zero,zero,zero,zero,zero,ymm15[17],zero,zero,zero,zero,zero,zero,ymm15[18] ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm3[0,1,14],zero,ymm3[12,13,0,1,14,15],zero,ymm3[3,12,13,2,3,16],zero,ymm3[30,31,28,29,16,17],zero,ymm3[31,18,19,28,29,18],zero ++; AVX512F-ONLY-FAST-NEXT: vpor %ymm6, %ymm7, %ymm3 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm19, %ymm3 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %ymm19, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm3 = zero,zero,zero,zero,zero,zero,ymm3[14],zero,zero,zero,zero,zero,zero,ymm3[15],zero,zero,zero,zero,zero,zero,ymm3[16],zero,zero,zero,zero,zero,zero,ymm3[17],zero,zero,zero,zero ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm29, %ymm6 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm6, %ymm4, %ymm6 ++; AVX512F-ONLY-FAST-NEXT: vpor %ymm3, %ymm6, %ymm3 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rsi), %xmm4 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm14, %xmm4, %xmm3 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm4, %xmm17 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdi), %xmm7 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm7, %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vpor %xmm3, %xmm1, %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rcx), %xmm3 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm18, %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm3, %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm3, %xmm12 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdx), %xmm5 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm2, %xmm5, %xmm2 ++; AVX512F-ONLY-FAST-NEXT: vpor %xmm1, %xmm2, %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill + ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r9), %xmm2 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm2, %xmm0 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm2, %xmm1 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm2, %xmm3 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill + ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r8), %xmm4 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm21, %xmm2 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm2, %xmm4, %xmm2 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm0, %xmm4, %xmm2 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vpor %xmm0, %xmm2, %xmm0 ++; AVX512F-ONLY-FAST-NEXT: vpor %xmm1, %xmm2, %xmm0 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill + ; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[23],zero,zero,zero,zero,ymm0[26],zero,ymm0[24],zero,zero,zero,zero,ymm0[27],zero,ymm0[25] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm23, %ymm12 +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = 
ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm12[21],zero,ymm12[19],zero,zero,zero,zero,ymm12[22],zero,ymm12[20],zero,zero +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[23],zero,zero,zero,zero,ymm0[26],zero,ymm0[24],zero,zero,zero,zero,ymm0[27],zero,ymm0[25] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm26, %ymm6 ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm6[21],zero,ymm6[19],zero,zero,zero,zero,ymm6[22],zero,ymm6[20],zero,zero ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm0 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm26, %ymm11 +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm11[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm11[25],zero,ymm11[23],zero,zero,zero,zero,ymm11[26],zero,ymm11[24],zero,zero,zero,zero +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm11[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm11[18],zero,zero,zero,zero,ymm11[21],zero,ymm11[19],zero,zero,zero,zero,ymm11[22],zero,ymm11[20] +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm15 ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm15[25],zero,ymm15[23],zero,zero,zero,zero,ymm15[26],zero,ymm15[24],zero,zero,zero,zero ++; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20] ++; AVX512F-ONLY-FAST-NEXT: # ymm0 = mem[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm15, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm0, %ymm25 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm8 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm0 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm2 = [24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27,24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27] +-; AVX512F-ONLY-FAST-NEXT: # ymm2 = mem[0,1,0,1] +-; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128,18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128] +-; AVX512F-ONLY-FAST-NEXT: # ymm5 = mem[0,1,0,1] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm2, %ymm19 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm1, %ymm2 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm5, %ymm30 +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 ++; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm14 = [24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27,24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27] ++; AVX512F-ONLY-FAST-NEXT: # ymm14 = mem[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128,18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128] ++; AVX512F-ONLY-FAST-NEXT: # ymm0 = mem[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm14, %ymm10, %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm10, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm0, %ymm29 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm0 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; 
AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm24, %zmm0, %zmm0 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm23, %zmm0, %zmm1 + ; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm2 = xmm4[8],xmm3[8],xmm4[9],xmm3[9],xmm4[10],xmm3[10],xmm4[11],xmm3[11],xmm4[12],xmm3[12],xmm4[13],xmm3[13],xmm4[14],xmm3[14],xmm4[15],xmm3[15] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm3 = +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm3, %xmm2, %xmm2 +-; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm2[0,1,0,1],zmm0[4,5,6,7] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm4 = ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm4, %xmm2, %xmm2 ++; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm2[0,1,0,1],zmm1[4,5,6,7] + ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm2 = [2,2,3,3,2,2,3,3] +-; AVX512F-ONLY-FAST-NEXT: # ymm2 = mem[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vbroadcasti32x4 {{.*#+}} ymm23 = [2,2,3,3,2,2,3,3] ++; AVX512F-ONLY-FAST-NEXT: # ymm23 = mem[0,1,2,3,0,1,2,3] + ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rax), %xmm0 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vpshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,5,5,6] +-; AVX512F-ONLY-FAST-NEXT: vpermd %ymm0, %ymm2, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rax), %ymm4 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm4, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm5 = [128,13,128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm4, %ymm4 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm5, %ymm18 +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm4, %zmm0, %zmm24 +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm0, %ymm23 ++; AVX512F-ONLY-FAST-NEXT: vpshufhw {{.*#+}} xmm2 = xmm0[0,1,2,3,4,5,5,6] ++; AVX512F-ONLY-FAST-NEXT: vpermd %ymm2, %ymm23, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm3 = [255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255] ++; AVX512F-ONLY-FAST-NEXT: vpandn %ymm2, %ymm3, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rax), %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = [128,13,128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128] ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm3 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm19 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm2, %zmm18 ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm10[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm10[30],zero,ymm10[28],zero,zero,zero,zero,ymm10[31],zero,ymm10[29],zero,zero ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm0, %ymm26 + ; AVX512F-ONLY-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14] +-; AVX512F-ONLY-FAST-NEXT: vmovdqu (%rsp), %ymm0 # 32-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm25 +-; 
AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm0[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm1[27],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vbroadcasti32x4 {{.*#+}} ymm26 = [18374967954648269055,71777218572844800,18374967954648269055,71777218572844800] +-; AVX512F-ONLY-FAST-NEXT: # ymm26 = mem[0,1,2,3,0,1,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, %ymm26, %ymm0, %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm0 = xmm10[0],xmm15[0],xmm10[1],xmm15[1],xmm10[2],xmm15[2],xmm10[3],xmm15[3],xmm10[4],xmm15[4],xmm10[5],xmm15[5],xmm10[6],xmm15[6],xmm10[7],xmm15[7] +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] +-; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm21 = zmm1[0,1,2,3],zmm0[0,1,0,1] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm29, %xmm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm28, %xmm1 +-; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm7 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3],xmm0[4],xmm1[4],xmm0[5],xmm1[5],xmm0[6],xmm1[6],xmm0[7],xmm1[7] +-; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] +-; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm14[8],xmm8[8],xmm14[9],xmm8[9],xmm14[10],xmm8[10],xmm14[11],xmm8[11],xmm14[12],xmm8[12],xmm14[13],xmm8[13],xmm14[14],xmm8[14],xmm14[15],xmm8[15] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm4 = <6,3,2,u,u,u,9,8,5,4,u,u,u,11,10,7> +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm4, %xmm1, %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm31, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm30 ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm22, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm3 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[27],zero,zero,zero,zero,ymm0[30],zero,ymm0[28],zero,zero,zero,zero,ymm0[31],zero,ymm0[29] ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm3 = ymm3[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vbroadcasti32x4 {{.*#+}} ymm31 = [18374967954648269055,71777218572844800,18374967954648269055,71777218572844800] ++; AVX512F-ONLY-FAST-NEXT: # ymm31 = mem[0,1,2,3,0,1,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, %ymm31, %ymm2, %ymm3 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm16, %xmm10 ++; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm2 = xmm9[0],xmm10[0],xmm9[1],xmm10[1],xmm9[2],xmm10[2],xmm9[3],xmm10[3],xmm9[4],xmm10[4],xmm9[5],xmm10[5],xmm9[6],xmm10[6],xmm9[7],xmm10[7] ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm2 = xmm2[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] ++; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm16 = zmm3[0,1,2,3],zmm2[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm21, %xmm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm20, %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm2 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3],xmm0[4],xmm1[4],xmm0[5],xmm1[5],xmm0[6],xmm1[6],xmm0[7],xmm1[7] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm2, %xmm22 ++; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm2 = 
xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] ++; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm3 = xmm12[8],xmm5[8],xmm12[9],xmm5[9],xmm12[10],xmm5[10],xmm12[11],xmm5[11],xmm12[12],xmm5[12],xmm12[13],xmm5[13],xmm12[14],xmm5[14],xmm12[15],xmm5[15] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm5, %xmm20 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm12, %xmm21 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm0 = <6,3,2,u,u,u,9,8,5,4,u,u,u,11,10,7> ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm0, %xmm3, %xmm1 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm0, %xmm2, %xmm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm1 # 64-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm28, %xmm0 ++; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm5 = xmm11[0],xmm0[0],xmm11[1],xmm0[1],xmm11[2],xmm0[2],xmm11[3],xmm0[3],xmm11[4],xmm0[4],xmm11[5],xmm0[5],xmm11[6],xmm0[6],xmm11[7],xmm0[7] ++; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm0[8],xmm11[8],xmm0[9],xmm11[9],xmm0[10],xmm11[10],xmm0[11],xmm11[11],xmm0[12],xmm11[12],xmm0[13],xmm11[13],xmm0[14],xmm11[14],xmm0[15],xmm11[15] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm17, %xmm3 ++; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm2 = xmm3[8],xmm7[8],xmm3[9],xmm7[9],xmm3[10],xmm7[10],xmm3[11],xmm7[11],xmm3[12],xmm7[12],xmm3[13],xmm7[13],xmm3[14],xmm7[14],xmm3[15],xmm7[15] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm1 = <2,u,u,u,9,8,5,4,u,u,u,11,10,7,6,u> ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm2, %xmm2 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm0, %xmm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm1 # 64-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm9[8],xmm10[8],xmm9[9],xmm10[9],xmm9[10],xmm10[10],xmm9[11],xmm10[11],xmm9[12],xmm10[12],xmm9[13],xmm10[13],xmm9[14],xmm10[14],xmm9[15],xmm10[15] ++; AVX512F-ONLY-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12] ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm8, %ymm12 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm24, %ymm11 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm11, %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm27 = ymm1[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm14, %ymm13, %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm28 = ymm1[2,3,2,3] + ; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm4, %xmm0, %xmm0 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm1 # 64-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm28 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm27, %xmm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm6, %xmm1 +-; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm6 = xmm0[0],xmm6[0],xmm0[1],xmm6[1],xmm0[2],xmm6[2],xmm0[3],xmm6[3],xmm0[4],xmm6[4],xmm0[5],xmm6[5],xmm0[6],xmm6[6],xmm0[7],xmm6[7] +-; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm1 = 
xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] +-; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm4 = xmm13[8],xmm9[8],xmm13[9],xmm9[9],xmm13[10],xmm9[10],xmm13[11],xmm9[11],xmm13[12],xmm9[12],xmm13[13],xmm9[13],xmm13[14],xmm9[14],xmm13[15],xmm9[15] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm5 = <2,u,u,u,9,8,5,4,u,u,u,11,10,7,6,u> +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm4, %xmm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm1, %xmm1 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, %xmm1, %zmm0, %zmm27 +-; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm10[8],xmm15[8],xmm10[9],xmm15[9],xmm10[10],xmm15[10],xmm10[11],xmm15[11],xmm10[12],xmm15[12],xmm10[13],xmm15[13],xmm10[14],xmm15[14],xmm10[15],xmm15[15] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm3, %xmm1, %xmm1 +-; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm22[0,1,0,1],zmm1[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm1[0,1,0,1],zmm0[0,1,0,1] + ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill + ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rax), %xmm0 +-; AVX512F-ONLY-FAST-NEXT: vpshufhw {{.*#+}} xmm1 = xmm0[0,1,2,3,4,5,5,6] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm0, %xmm29 +-; AVX512F-ONLY-FAST-NEXT: vpermd %ymm1, %ymm2, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vpshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,5,5,6] ++; AVX512F-ONLY-FAST-NEXT: vpermd %ymm0, %ymm23, %ymm0 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm11, %ymm5 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm20, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm22 = ymm1[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm19, %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm2, %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm20 = ymm1[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm10 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[18],zero,zero,zero,zero,ymm0[21],zero,ymm0[19],zero,zero,zero,zero,ymm0[22],zero,ymm0[20] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm2, %ymm4 +-; AVX512F-ONLY-FAST-NEXT: vpbroadcastq {{.*#+}} ymm2 = [9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm16, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm0, %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm2, %ymm30 +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm19 = ymm1[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm11 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm0[21],zero,ymm0[19],zero,zero,zero,zero,ymm0[22],zero,ymm0[20],zero,zero ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm25, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm11, %ymm10 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm15 # 32-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 
= ymm15[9,u,7,u,u,u,u,10,u,8,u,u,u,u,11,u,25,u,23,u,u,u,u,26,u,24,u,u,u,u,27,u] ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm24 = ymm0[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm29, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm13, %ymm8 ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm11 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm15[21],zero,ymm15[19],zero,zero,zero,zero,ymm15[22],zero,ymm15[20],zero,zero + ; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm18, %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm3 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm19, %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm4 + ; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} ymm0 = ymm0[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] + ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = [4,5,4,5,5,7,4,5] + ; AVX512F-ONLY-FAST-NEXT: vpermd %ymm0, %ymm1, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm3, %zmm16 +-; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm0 = xmm9[0],xmm13[0],xmm9[1],xmm13[1],xmm9[2],xmm13[2],xmm9[3],xmm13[3],xmm9[4],xmm13[4],xmm9[5],xmm13[5],xmm9[6],xmm13[6],xmm9[7],xmm13[7] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm9 = <0,1,u,u,u,6,7,2,3,u,u,u,8,9,4,5> +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm6, %xmm6 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm0, %xmm0 +-; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm9 = [28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29,28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29] +-; AVX512F-ONLY-FAST-NEXT: # ymm9 = mem[0,1,0,1] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm17, %ymm3 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm9, %ymm3, %ymm15 +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm17 = ymm15[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} ymm25 = [255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255] ++; AVX512F-ONLY-FAST-NEXT: vpandnq %ymm0, %ymm25, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm4, %zmm23 ++; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29,28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29] ++; AVX512F-ONLY-FAST-NEXT: # ymm0 = mem[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm15 # 32-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm15, %ymm9 ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm17 = ymm9[2,3,2,3] + ; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm25, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm2 + ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm13 +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm18 = ymm13[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, %xmm31, %zmm0, %zmm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm0 = xmm8[0],xmm14[0],xmm8[1],xmm14[1],xmm8[2],xmm14[2],xmm8[3],xmm14[3],xmm8[4],xmm14[4],xmm8[5],xmm14[5],xmm8[6],xmm14[6],xmm8[7],xmm14[7] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm8 = <4,5,0,1,u,u,u,6,7,2,3,u,u,u,8,9> +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm7, %xmm7 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm0, %xmm13 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte 
Reload +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm8 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm0[23],zero,ymm0[23,24,25,26],zero,ymm0[24],zero,ymm0[30,31] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm0, %ymm25 +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm15 = ymm8[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm12, %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm12[30],zero,ymm12[28],zero,zero,zero,zero,ymm12[31],zero,ymm12[29],zero,zero,zero +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm9, %ymm2, %ymm9 +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,1,0,1] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm5 = ymm5[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm23[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm7[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm29 ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm25 = ymm13[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm7 = xmm7[0],xmm3[0],xmm7[1],xmm3[1],xmm7[2],xmm3[2],xmm7[3],xmm3[3],xmm7[4],xmm3[4],xmm7[5],xmm3[5],xmm7[6],xmm3[6],xmm7[7],xmm3[7] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm13 = <0,1,u,u,u,6,7,2,3,u,u,u,8,9,4,5> ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm13, %xmm5, %xmm5 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm13, %xmm7, %xmm7 ++; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, {{[-0-9]+}}(%r{{[sb]}}p), %zmm7, %zmm1 # 16-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm1[23],zero,ymm1[23,24,25,26],zero,ymm1[24],zero,ymm1[30,31] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm30 ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm19 = ymm7[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm6, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm6[30],zero,ymm6[28],zero,zero,zero,zero,ymm6[31],zero,ymm6[29],zero,zero,zero ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm7[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm3, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm0[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm12[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm26[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm20, %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm21, %xmm4 ++; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm12 = xmm1[0],xmm4[0],xmm1[1],xmm4[1],xmm1[2],xmm4[2],xmm1[3],xmm4[3],xmm1[4],xmm4[4],xmm1[5],xmm4[5],xmm1[6],xmm4[6],xmm1[7],xmm4[7] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm1 = <4,5,0,1,u,u,u,6,7,2,3,u,u,u,8,9> ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm22, %xmm4 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm4, %xmm13 ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm13[0,1,0,1] + ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: 
vpermq {{.*#+}} ymm8 = ymm11[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, {{[-0-9]+}}(%r{{[sb]}}p), %zmm13, %zmm31 # 16-byte Folded Reload +-; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm11 = [18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23,18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23] +-; AVX512F-ONLY-FAST-NEXT: # ymm11 = mem[0,1,0,1] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm11, %ymm3, %ymm14 +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm3 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm3, %ymm1, %ymm13 +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm11, %ymm2, %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm11[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm12, %xmm1 ++; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm12 = [18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23,18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23] ++; AVX512F-ONLY-FAST-NEXT: # ymm12 = mem[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm12, %ymm15, %ymm9 ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm11 # 16-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[9,u,7,u,u,u,u,10,u,8,u,u,u,u,11,u,25,u,23,u,u,u,u,26,u,24,u,u,u,u,27,u] + ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpor %ymm12, %ymm9, %ymm9 +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm6, %zmm9, %zmm6 +-; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm9 = [18374966859431673855,18446463693966278655,18374966859431673855,18446463693966278655] +-; AVX512F-ONLY-FAST-NEXT: # ymm9 = mem[0,1,0,1] +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, %ymm9, %ymm5, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm7, %zmm0, %zmm7 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm12, %ymm3, %ymm12 ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpor %ymm6, %ymm14, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm5, %zmm2, %zmm2 ++; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [18374966859431673855,18446463693966278655,18374966859431673855,18446463693966278655] ++; AVX512F-ONLY-FAST-NEXT: # ymm5 = mem[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, %ymm5, %ymm7, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm13, %zmm0, %zmm7 + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm0 = [255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255] +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm6, %zmm0, %zmm7 +-; AVX512F-ONLY-FAST-NEXT: vpor %ymm4, %ymm10, %ymm4 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm4, %zmm2, %zmm4 +-; AVX512F-ONLY-FAST-NEXT: vpor %ymm8, %ymm14, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm2, %zmm0, %zmm7 ++; AVX512F-ONLY-FAST-NEXT: vpor %ymm10, %ymm8, %ymm2 + ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm5 +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm4, %zmm0, %zmm5 +-; AVX512F-ONLY-FAST-NEXT: vpandq %ymm9, %ymm22, 
%ymm0 +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm20, %zmm0 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 ++; AVX512F-ONLY-FAST-NEXT: vpor %ymm4, %ymm9, %ymm3 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm4, %zmm6 ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm2, %zmm0, %zmm6 ++; AVX512F-ONLY-FAST-NEXT: vpandq %ymm5, %ymm27, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm28, %zmm0 + ; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload + ; AVX512F-ONLY-FAST-NEXT: # zmm2 = mem[2,3,2,3,6,7,6,7] + ; AVX512F-ONLY-FAST-NEXT: vporq %zmm2, %zmm0, %zmm0 +-; AVX512F-ONLY-FAST-NEXT: vpandq %ymm26, %ymm19, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vpandq %ymm31, %ymm24, %ymm2 + ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm17, %zmm2, %zmm2 +-; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Folded Reload +-; AVX512F-ONLY-FAST-NEXT: # zmm4 = mem[2,3,2,3,6,7,6,7] +-; AVX512F-ONLY-FAST-NEXT: vporq %zmm4, %zmm2, %zmm2 ++; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: # zmm3 = mem[2,3,2,3,6,7,6,7] ++; AVX512F-ONLY-FAST-NEXT: vporq %zmm3, %zmm2, %zmm2 + ; AVX512F-ONLY-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm2 +-; AVX512F-ONLY-FAST-NEXT: vpandq %ymm26, %ymm18, %ymm0 +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm15, %zmm0 +-; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Folded Reload +-; AVX512F-ONLY-FAST-NEXT: # zmm4 = mem[2,3,2,3,6,7,6,7] +-; AVX512F-ONLY-FAST-NEXT: vporq %zmm4, %zmm0, %zmm0 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm4 = [255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255] +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $184, %zmm2, %zmm4, %zmm0 +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} zmm2 = zmm28[0,1,0,1,4,5,4,5] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} zmm8 = zmm27[0,1,0,1,4,5,4,5] +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm2, %zmm4, %zmm8 +-; AVX512F-ONLY-FAST-NEXT: vpandq %ymm26, %ymm13, %ymm2 +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm1, %zmm1 ++; AVX512F-ONLY-FAST-NEXT: vpandq %ymm31, %ymm25, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm19, %zmm0 ++; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: # zmm3 = mem[2,3,2,3,6,7,6,7] ++; AVX512F-ONLY-FAST-NEXT: vporq %zmm3, %zmm0, %zmm3 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm0 = [255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255] ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $184, %zmm2, %zmm0, %zmm3 ++; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: # zmm2 = mem[0,1,0,1,4,5,4,5] ++; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: # zmm8 = mem[0,1,0,1,4,5,4,5] ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm2, %zmm0, %zmm8 ++; AVX512F-ONLY-FAST-NEXT: vpandq %ymm31, %ymm1, %ymm1 ++; 
AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm12, %zmm1 + ; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload + ; AVX512F-ONLY-FAST-NEXT: # zmm2 = mem[2,3,2,3,6,7,6,7] + ; AVX512F-ONLY-FAST-NEXT: vporq %zmm2, %zmm1, %zmm1 + ; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload + ; AVX512F-ONLY-FAST-NEXT: # zmm2 = mem[2,3,2,3,6,7,6,7] +-; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm6 # 64-byte Folded Reload +-; AVX512F-ONLY-FAST-NEXT: # zmm6 = mem[2,3,2,3,6,7,6,7] +-; AVX512F-ONLY-FAST-NEXT: vporq %zmm2, %zmm6, %zmm9 +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm1, %zmm4, %zmm9 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{[-0-9]+}}(%r{{[sb]}}p), %xmm1, %xmm1 # 16-byte Folded Reload +-; AVX512F-ONLY-FAST-NEXT: # xmm1 = xmm1[0],mem[0],xmm1[1],mem[1],xmm1[2],mem[2],xmm1[3],mem[3],xmm1[4],mem[4],xmm1[5],mem[5],xmm1[6],mem[6],xmm1[7],mem[7] +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] +-; AVX512F-ONLY-FAST-NEXT: vshufi64x2 $0, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm18 # 64-byte Folded Reload +-; AVX512F-ONLY-FAST-NEXT: # zmm18 = zmm1[0,1,0,1],mem[0,1,0,1] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm29, %xmm3 +-; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} xmm2 = xmm3[1,1,0,0,4,5,6,7] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm4 = [0,1,0,1,2,0,0,1] +-; AVX512F-ONLY-FAST-NEXT: vpermd %ymm2, %ymm4, %ymm19 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} xmm6 = xmm1[1,1,0,0,4,5,6,7] +-; AVX512F-ONLY-FAST-NEXT: vpermd %ymm6, %ymm4, %ymm17 +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm6 = [4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm6, %xmm3, %xmm10 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm6, %xmm1, %xmm6 +-; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm11 = [128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22] +-; AVX512F-ONLY-FAST-NEXT: # ymm11 = mem[0,1,0,1] +-; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm11, %ymm1, %ymm12 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu (%rsp), %ymm1 # 32-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm13 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm1[25],zero,ymm1[23],zero,zero,zero,zero,ymm1[26],zero,ymm1[24],zero,zero +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm11, %ymm1, %ymm11 +-; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm14 = [20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128,20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128] +-; AVX512F-ONLY-FAST-NEXT: # ymm14 = mem[0,1,0,1] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm25, %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm14, %ymm1, %ymm15 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm2[23],zero,ymm2[23,24,25,26],zero,ymm2[24],zero,ymm2[30,31] +-; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm14, %ymm2, %ymm14 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} ymm4 = ymm3[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] +-; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [4,5,4,5,5,7,4,5] 
+-; AVX512F-ONLY-FAST-NEXT: vpermd %ymm4, %ymm2, %ymm20 +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,28,29,26,27,28,29,30,31,30,31,28,29,28,29,30,31] +-; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm4 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,24,25,22,23,24,25,26,27,26,27,24,25] +-; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm22 # 64-byte Folded Reload +-; AVX512F-ONLY-FAST-NEXT: # zmm22 = mem[2,3,2,3,6,7,6,7] +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm22 +-; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: # zmm5 = mem[2,3,2,3,6,7,6,7] ++; AVX512F-ONLY-FAST-NEXT: vporq %zmm2, %zmm5, %zmm22 ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm1, %zmm0, %zmm22 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{[-0-9]+}}(%r{{[sb]}}p), %xmm0, %xmm0 # 16-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: # xmm0 = xmm0[0],mem[0],xmm0[1],mem[1],xmm0[2],mem[2],xmm0[3],mem[3],xmm0[4],mem[4],xmm0[5],mem[5],xmm0[6],mem[6],xmm0[7],mem[7] ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] ++; AVX512F-ONLY-FAST-NEXT: vshufi64x2 $0, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm26 # 64-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: # zmm26 = zmm0[0,1,0,1],mem[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm4 # 16-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} xmm1 = xmm4[1,1,0,0,4,5,6,7] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [0,1,0,1,2,0,0,1] ++; AVX512F-ONLY-FAST-NEXT: vpermd %ymm1, %ymm2, %ymm19 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} xmm5 = xmm0[1,1,0,0,4,5,6,7] ++; AVX512F-ONLY-FAST-NEXT: vpermd %ymm5, %ymm2, %ymm17 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm5 = [4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm4, %xmm10 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm0, %xmm5 ++; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm12 = [128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22] ++; AVX512F-ONLY-FAST-NEXT: # ymm12 = mem[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm29, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm12, %ymm0, %ymm13 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm14 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm0[25],zero,ymm0[23],zero,zero,zero,zero,ymm0[26],zero,ymm0[24],zero,zero ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm12, %ymm0, %ymm12 ++; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128,20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128] + ; AVX512F-ONLY-FAST-NEXT: # ymm0 = mem[0,1,0,1] +-; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm23 # 32-byte Folded Reload +-; AVX512F-ONLY-FAST-NEXT: # ymm23 = mem[0,1,0,1] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,1,0] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm1 ++; 
AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm1, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm9 # 32-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm9[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm9[23],zero,ymm9[23,24,25,26],zero,ymm9[24],zero,ymm9[30,31] ++; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm9, %ymm0 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm4 # 32-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} ymm15 = ymm4[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] ++; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm9 = [4,5,4,5,5,7,4,5] ++; AVX512F-ONLY-FAST-NEXT: vpermd %ymm15, %ymm9, %ymm20 ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm15 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,28,29,26,27,28,29,30,31,30,31,28,29,28,29,30,31] ++; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm9 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,24,25,22,23,24,25,26,27,26,27,24,25] ++; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm24 # 64-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: # zmm24 = mem[2,3,2,3,6,7,6,7] ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm24 ++; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: # ymm3 = mem[0,1,0,1] ++; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm25 # 32-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: # ymm25 = mem[0,1,0,1] + ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm15 = ymm15[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,0,1,0] ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,1,0] + ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm11 = ymm11[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,0,1,0] + ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,3,2,3] +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm0 # 32-byte Folded Reload +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm23, %zmm23 # 32-byte Folded Reload +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm23 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm24 +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm23, %zmm24 +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm19, %zmm2, %zmm0 +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm21 +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm21 +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm0 # 32-byte Folded Reload +-; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm0 +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm0 +-; AVX512F-ONLY-FAST-NEXT: vpor %ymm12, %ymm15, %ymm2 ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] ++; 
AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm0[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,3,2,3] ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm3, %zmm3 # 32-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm25, %zmm25 # 32-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm25 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm18 ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm25, %zmm18 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm19, %zmm15, %zmm3 ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm16 ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm16 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm3 # 32-byte Folded Reload ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm3 ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm3 ++; AVX512F-ONLY-FAST-NEXT: vpor %ymm2, %ymm13, %ymm2 + ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm0, %zmm2 +-; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Reload +-; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm2 = zmm7[0,1,2,3],zmm2[4,5,6,7] +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm16 +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm16 ++; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload ++; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm2 = zmm4[0,1,2,3],zmm2[4,5,6,7] ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm23 ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm23 + ; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload + ; AVX512F-ONLY-FAST-NEXT: # zmm2 = mem[0,1,0,1,4,5,4,5] +-; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} zmm5 = zmm31[0,1,0,1,4,5,4,5] +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm5 +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm6, %zmm17, %zmm2 +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm18, %zmm2 +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm2 +-; AVX512F-ONLY-FAST-NEXT: vpor %ymm1, %ymm13, %ymm1 +-; AVX512F-ONLY-FAST-NEXT: vpor %ymm11, %ymm14, %ymm5 ++; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} zmm6 = zmm11[0,1,0,1,4,5,4,5] ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm6 ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm5, %zmm17, %zmm2 ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm26, %zmm2 ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm2 ++; AVX512F-ONLY-FAST-NEXT: vpor %ymm1, %ymm14, %ymm1 ++; AVX512F-ONLY-FAST-NEXT: vpor %ymm0, %ymm12, %ymm0 + ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm1 +-; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm1 = zmm5[0,1,2,3],zmm1[4,5,6,7] +-; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm4, %zmm20, 
%zmm4 +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm4 +-; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm4 ++; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm0[0,1,2,3],zmm1[4,5,6,7] ++; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm9, %zmm20, %zmm1 ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm1 ++; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm22, %zmm1 + ; AVX512F-ONLY-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm4, 128(%rax) ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm1, 128(%rax) + ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm2, (%rax) +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm16, 320(%rax) +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm0, 256(%rax) +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm21, 192(%rax) +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm24, 64(%rax) +-; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm22, 384(%rax) +-; AVX512F-ONLY-FAST-NEXT: addq $1256, %rsp # imm = 0x4E8 ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm23, 320(%rax) ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm3, 256(%rax) ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm16, 192(%rax) ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm18, 64(%rax) ++; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm24, 384(%rax) ++; AVX512F-ONLY-FAST-NEXT: addq $1496, %rsp # imm = 0x5D8 + ; AVX512F-ONLY-FAST-NEXT: vzeroupper + ; AVX512F-ONLY-FAST-NEXT: retq + ; + ; AVX512DQ-FAST-LABEL: store_i8_stride7_vf64: + ; AVX512DQ-FAST: # %bb.0: +-; AVX512DQ-FAST-NEXT: subq $1256, %rsp # imm = 0x4E8 ++; AVX512DQ-FAST-NEXT: subq $1496, %rsp # imm = 0x5D8 + ; AVX512DQ-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax +-; AVX512DQ-FAST-NEXT: vmovdqa 32(%rcx), %ymm2 +-; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdx), %ymm1 +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero +-; AVX512DQ-FAST-NEXT: vmovdqa %ymm1, %ymm14 +-; AVX512DQ-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm2[25],zero,ymm2[23],zero,zero,zero,zero,ymm2[26],zero,ymm2[24],zero,zero,zero,zero +-; AVX512DQ-FAST-NEXT: vmovdqa %ymm2, %ymm13 ++; AVX512DQ-FAST-NEXT: vmovdqa 32(%rcx), %ymm7 ++; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdx), %ymm15 ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm15[30],zero,ymm15[28],zero,zero,zero,zero,ymm15[31],zero,ymm15[29],zero,zero ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm15, %ymm17 ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm7[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm7[25],zero,ymm7[23],zero,zero,zero,zero,ymm7[26],zero,ymm7[24],zero,zero,zero,zero + ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 + ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-FAST-NEXT: vmovdqa 32(%rsi), %ymm1 +-; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdi), %ymm2 +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero,zero +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm16 +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm2[23],zero,zero,zero,zero,ymm2[26],zero,ymm2[24],zero,zero,zero,zero,ymm2[27],zero,ymm2[25] +-; AVX512DQ-FAST-NEXT: 
vmovdqa64 %ymm2, %ymm17 ++; AVX512DQ-FAST-NEXT: vmovdqa 32(%rsi), %ymm15 ++; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdi), %ymm3 ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm15[30],zero,ymm15[28],zero,zero,zero,zero,ymm15[31],zero,ymm15[29],zero,zero,zero ++; AVX512DQ-FAST-NEXT: vmovdqu %ymm15, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm3[23],zero,zero,zero,zero,ymm3[26],zero,ymm3[24],zero,zero,zero,zero,ymm3[27],zero,ymm3[25] ++; AVX512DQ-FAST-NEXT: vmovdqu %ymm3, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 + ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill + ; AVX512DQ-FAST-NEXT: vmovdqa 32(%r9), %ymm4 + ; AVX512DQ-FAST-NEXT: vmovdqa 32(%r8), %ymm1 + ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm1[27],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm18 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm19 + ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm4[25],zero,ymm4[23],zero,zero,zero,zero,ymm4[26],zero,ymm4[24],zero,zero + ; AVX512DQ-FAST-NEXT: vmovdqu %ymm4, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 +@@ -8290,403 +8305,420 @@ + ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill + ; AVX512DQ-FAST-NEXT: vmovdqa (%rsi), %ymm1 + ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,ymm1[14],zero,zero,zero,zero,zero,zero,ymm1[15],zero,zero,zero,zero,zero,zero,ymm1[16],zero,zero,zero,zero,zero,zero,ymm1[17],zero,zero,zero,zero,zero,zero,ymm1[18] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm23 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm26 + ; AVX512DQ-FAST-NEXT: vmovdqa (%rdi), %ymm1 + ; AVX512DQ-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[0,1,14],zero,ymm1[12,13,0,1,14,15],zero,ymm1[3,12,13,2,3,16],zero,ymm1[30,31,28,29,16,17],zero,ymm1[31,18,19,28,29,18],zero + ; AVX512DQ-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 + ; AVX512DQ-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill + ; AVX512DQ-FAST-NEXT: vmovdqa (%rcx), %ymm1 +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm7 = [128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128,128] +-; AVX512DQ-FAST-NEXT: vpshufb %ymm7, %ymm1, %ymm0 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm26 +-; AVX512DQ-FAST-NEXT: vmovdqa (%rdx), %ymm1 ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm6 = [128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128,128] ++; AVX512DQ-FAST-NEXT: vpshufb %ymm6, %ymm1, %ymm0 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm30 ++; AVX512DQ-FAST-NEXT: vmovdqa (%rdx), %ymm10 + ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [0,1,0,1,14,128,14,15,0,1,14,15,128,13,14,15,16,17,16,128,30,31,30,31,16,17,128,31,28,29,30,31] +-; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm3 +-; AVX512DQ-FAST-NEXT: vpor %ymm0, %ymm3, %ymm0 +-; AVX512DQ-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-FAST-NEXT: vmovdqa (%r8), %ymm0 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm10, %ymm1 ++; AVX512DQ-FAST-NEXT: vmovdqa64 
%ymm2, %ymm25 ++; AVX512DQ-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 + ; AVX512DQ-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm5 = [128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128] +-; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm0, %ymm3 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm5, %ymm25 +-; AVX512DQ-FAST-NEXT: vmovdqa (%r9), %ymm5 +-; AVX512DQ-FAST-NEXT: vmovdqu %ymm5, (%rsp) # 32-byte Spill +-; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0,13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0] +-; AVX512DQ-FAST-NEXT: # ymm0 = mem[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm5, %ymm5 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm0, %ymm30 +-; AVX512DQ-FAST-NEXT: vporq %ymm3, %ymm5, %ymm24 +-; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdx), %xmm3 +-; AVX512DQ-FAST-NEXT: vmovdqa 32(%rcx), %xmm6 +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm8 = +-; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm6, %xmm5 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm6, %xmm28 +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm9 = +-; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm3, %xmm6 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm9, %xmm19 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm3, %xmm29 +-; AVX512DQ-FAST-NEXT: vpor %xmm5, %xmm6, %xmm3 +-; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdi), %xmm10 +-; AVX512DQ-FAST-NEXT: vmovdqa 32(%rsi), %xmm6 +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm11 = +-; AVX512DQ-FAST-NEXT: vpshufb %xmm11, %xmm6, %xmm5 +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm3 = +-; AVX512DQ-FAST-NEXT: vpshufb %xmm3, %xmm10, %xmm9 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm10, %xmm27 +-; AVX512DQ-FAST-NEXT: vpor %xmm5, %xmm9, %xmm5 +-; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm5, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-FAST-NEXT: vmovdqa 32(%r9), %xmm15 +-; AVX512DQ-FAST-NEXT: vmovdqa 32(%r8), %xmm10 +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm5 = <128,4,u,u,u,128,7,128,5,u,u,u,128,8,128,6> +-; AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm15, %xmm9 ++; AVX512DQ-FAST-NEXT: vmovdqa (%r8), %ymm1 ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,zero,zero,zero,ymm1[14],zero,zero,zero,zero,zero,zero,ymm1[15],zero,zero,zero,zero,zero,zero,ymm1[16],zero,zero,zero,zero,zero,zero,ymm1[17],zero,zero,zero,zero ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm22 ++; AVX512DQ-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-FAST-NEXT: vmovdqa (%r9), %ymm2 ++; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0,13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0] ++; AVX512DQ-FAST-NEXT: # ymm5 = mem[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm2, %ymm1 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm5, %ymm29 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm2, %ymm31 ++; AVX512DQ-FAST-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-FAST-NEXT: vporq %ymm0, %ymm1, %ymm23 ++; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdx), %xmm5 ++; AVX512DQ-FAST-NEXT: vmovdqa 32(%rcx), %xmm1 ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm2 = ++; AVX512DQ-FAST-NEXT: vpshufb %xmm2, %xmm1, %xmm0 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm2, %xmm18 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm1, %xmm20 ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm2 = ++; AVX512DQ-FAST-NEXT: vpshufb %xmm2, %xmm5, %xmm1 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm5, %xmm21 ++; AVX512DQ-FAST-NEXT: vpor %xmm0, 
%xmm1, %xmm0 ++; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdi), %xmm11 ++; AVX512DQ-FAST-NEXT: vmovdqa 32(%rsi), %xmm1 ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm14 = ++; AVX512DQ-FAST-NEXT: vpshufb %xmm14, %xmm1, %xmm0 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm1, %xmm28 ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm1 = ++; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm11, %xmm5 ++; AVX512DQ-FAST-NEXT: vpor %xmm0, %xmm5, %xmm0 ++; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512DQ-FAST-NEXT: vmovdqa 32(%r9), %xmm0 ++; AVX512DQ-FAST-NEXT: vmovdqa 32(%r8), %xmm9 ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm8 = <128,4,u,u,u,128,7,128,5,u,u,u,128,8,128,6> ++; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm0, %xmm12 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm0, %xmm16 + ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm0 = <4,128,u,u,u,7,128,5,128,u,u,u,8,128,6,128> +-; AVX512DQ-FAST-NEXT: vpshufb %xmm0, %xmm10, %xmm12 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm0, %xmm21 +-; AVX512DQ-FAST-NEXT: vporq %xmm9, %xmm12, %xmm22 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm7, %ymm13, %ymm7 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm13, %ymm20 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm14, %ymm2 +-; AVX512DQ-FAST-NEXT: vpor %ymm7, %ymm2, %ymm2 +-; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm16, %ymm7 +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = zero,zero,zero,ymm7[14],zero,zero,zero,zero,zero,zero,ymm7[15],zero,zero,zero,zero,zero,zero,ymm7[16],zero,zero,zero,zero,zero,zero,ymm7[17],zero,zero,zero,zero,zero,zero,ymm7[18] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm17, %ymm7 +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[0,1,14],zero,ymm7[12,13,0,1,14,15],zero,ymm7[3,12,13,2,3,16],zero,ymm7[30,31,28,29,16,17],zero,ymm7[31,18,19,28,29,18],zero +-; AVX512DQ-FAST-NEXT: vpor %ymm2, %ymm7, %ymm2 +-; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm18, %ymm2 +-; AVX512DQ-FAST-NEXT: vmovdqu64 %ymm18, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-FAST-NEXT: vpshufb %xmm0, %xmm9, %xmm13 ++; AVX512DQ-FAST-NEXT: vpor %xmm12, %xmm13, %xmm12 ++; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm12, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512DQ-FAST-NEXT: vpshufb %ymm6, %ymm7, %ymm6 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm7, %ymm24 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm17, %ymm13 + ; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm25, %ymm7 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm7, %ymm2, %ymm2 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm0 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm4, %ymm0 +-; AVX512DQ-FAST-NEXT: vpor %ymm2, %ymm0, %ymm0 +-; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-FAST-NEXT: vmovdqa (%rsi), %xmm13 +-; AVX512DQ-FAST-NEXT: vpshufb %xmm11, %xmm13, %xmm0 +-; AVX512DQ-FAST-NEXT: vmovdqa (%rdi), %xmm9 +-; AVX512DQ-FAST-NEXT: vpshufb %xmm3, %xmm9, %xmm2 +-; AVX512DQ-FAST-NEXT: vporq %xmm0, %xmm2, %xmm31 +-; AVX512DQ-FAST-NEXT: vmovdqa (%rcx), %xmm14 +-; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm14, %xmm0 +-; AVX512DQ-FAST-NEXT: vmovdqa (%rdx), %xmm8 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm19, %xmm2 +-; AVX512DQ-FAST-NEXT: vpshufb %xmm2, %xmm8, %xmm2 +-; AVX512DQ-FAST-NEXT: vpor %xmm0, %xmm2, %xmm0 +-; AVX512DQ-FAST-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill ++; AVX512DQ-FAST-NEXT: vpshufb %ymm7, %ymm13, %ymm7 ++; AVX512DQ-FAST-NEXT: vpor %ymm6, %ymm7, %ymm6 ++; 
AVX512DQ-FAST-NEXT: vmovdqu64 %zmm6, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm6 = zero,zero,zero,ymm15[14],zero,zero,zero,zero,zero,zero,ymm15[15],zero,zero,zero,zero,zero,zero,ymm15[16],zero,zero,zero,zero,zero,zero,ymm15[17],zero,zero,zero,zero,zero,zero,ymm15[18] ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm3[0,1,14],zero,ymm3[12,13,0,1,14,15],zero,ymm3[3,12,13,2,3,16],zero,ymm3[30,31,28,29,16,17],zero,ymm3[31,18,19,28,29,18],zero ++; AVX512DQ-FAST-NEXT: vpor %ymm6, %ymm7, %ymm3 ++; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm19, %ymm3 ++; AVX512DQ-FAST-NEXT: vmovdqu64 %ymm19, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm3 = zero,zero,zero,zero,zero,zero,ymm3[14],zero,zero,zero,zero,zero,zero,ymm3[15],zero,zero,zero,zero,zero,zero,ymm3[16],zero,zero,zero,zero,zero,zero,ymm3[17],zero,zero,zero,zero ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm29, %ymm6 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm6, %ymm4, %ymm6 ++; AVX512DQ-FAST-NEXT: vpor %ymm3, %ymm6, %ymm3 ++; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512DQ-FAST-NEXT: vmovdqa (%rsi), %xmm4 ++; AVX512DQ-FAST-NEXT: vpshufb %xmm14, %xmm4, %xmm3 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm4, %xmm17 ++; AVX512DQ-FAST-NEXT: vmovdqa (%rdi), %xmm7 ++; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm7, %xmm1 ++; AVX512DQ-FAST-NEXT: vpor %xmm3, %xmm1, %xmm1 ++; AVX512DQ-FAST-NEXT: vmovdqa %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill ++; AVX512DQ-FAST-NEXT: vmovdqa (%rcx), %xmm3 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm18, %xmm1 ++; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm3, %xmm1 ++; AVX512DQ-FAST-NEXT: vmovdqa %xmm3, %xmm12 ++; AVX512DQ-FAST-NEXT: vmovdqa (%rdx), %xmm5 ++; AVX512DQ-FAST-NEXT: vpshufb %xmm2, %xmm5, %xmm2 ++; AVX512DQ-FAST-NEXT: vpor %xmm1, %xmm2, %xmm1 ++; AVX512DQ-FAST-NEXT: vmovdqa %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill + ; AVX512DQ-FAST-NEXT: vmovdqa (%r9), %xmm2 +-; AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm2, %xmm0 ++; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm2, %xmm1 + ; AVX512DQ-FAST-NEXT: vmovdqa %xmm2, %xmm3 + ; AVX512DQ-FAST-NEXT: vmovdqa %xmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill + ; AVX512DQ-FAST-NEXT: vmovdqa (%r8), %xmm4 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm21, %xmm2 +-; AVX512DQ-FAST-NEXT: vpshufb %xmm2, %xmm4, %xmm2 ++; AVX512DQ-FAST-NEXT: vpshufb %xmm0, %xmm4, %xmm2 + ; AVX512DQ-FAST-NEXT: vmovdqa %xmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill +-; AVX512DQ-FAST-NEXT: vpor %xmm0, %xmm2, %xmm0 ++; AVX512DQ-FAST-NEXT: vpor %xmm1, %xmm2, %xmm0 + ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill + ; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[23],zero,zero,zero,zero,ymm0[26],zero,ymm0[24],zero,zero,zero,zero,ymm0[27],zero,ymm0[25] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm23, %ymm12 +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm12[21],zero,ymm12[19],zero,zero,zero,zero,ymm12[22],zero,ymm12[20],zero,zero +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[23],zero,zero,zero,zero,ymm0[26],zero,ymm0[24],zero,zero,zero,zero,ymm0[27],zero,ymm0[25] ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm26, %ymm6 ++; AVX512DQ-FAST-NEXT: 
vpshufb {{.*#+}} ymm2 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm6[21],zero,ymm6[19],zero,zero,zero,zero,ymm6[22],zero,ymm6[20],zero,zero ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm0 + ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm26, %ymm11 +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm11[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm11[25],zero,ymm11[23],zero,zero,zero,zero,ymm11[26],zero,ymm11[24],zero,zero,zero,zero +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm11[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm11[18],zero,zero,zero,zero,ymm11[21],zero,ymm11[19],zero,zero,zero,zero,ymm11[22],zero,ymm11[20] +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm15 ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm15[25],zero,ymm15[23],zero,zero,zero,zero,ymm15[26],zero,ymm15[24],zero,zero,zero,zero ++; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20] ++; AVX512DQ-FAST-NEXT: # ymm0 = mem[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm15, %ymm2 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm0, %ymm25 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm8 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm0 + ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm2 = [24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27,24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27] +-; AVX512DQ-FAST-NEXT: # ymm2 = mem[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128,18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128] +-; AVX512DQ-FAST-NEXT: # ymm5 = mem[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm0 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm2, %ymm19 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm1, %ymm2 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm5, %ymm30 +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 ++; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm14 = [24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27,24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27] ++; AVX512DQ-FAST-NEXT: # ymm14 = mem[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128,18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128] ++; AVX512DQ-FAST-NEXT: # ymm0 = mem[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vpshufb %ymm14, %ymm10, %ymm1 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm10, %ymm2 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm0, %ymm29 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm0 + ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm24, %zmm0, %zmm0 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm23, %zmm0, %zmm1 + ; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm2 = xmm4[8],xmm3[8],xmm4[9],xmm3[9],xmm4[10],xmm3[10],xmm4[11],xmm3[11],xmm4[12],xmm3[12],xmm4[13],xmm3[13],xmm4[14],xmm3[14],xmm4[15],xmm3[15] +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm3 = +-; AVX512DQ-FAST-NEXT: vpshufb %xmm3, %xmm2, %xmm2 +-; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm2[0,1,0,1],zmm0[4,5,6,7] ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm4 = ++; AVX512DQ-FAST-NEXT: vpshufb %xmm4, %xmm2, %xmm2 ++; AVX512DQ-FAST-NEXT: 
vshufi64x2 {{.*#+}} zmm0 = zmm2[0,1,0,1],zmm1[4,5,6,7] + ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm2 = [2,2,3,3,2,2,3,3] +-; AVX512DQ-FAST-NEXT: # ymm2 = mem[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vbroadcasti32x4 {{.*#+}} ymm23 = [2,2,3,3,2,2,3,3] ++; AVX512DQ-FAST-NEXT: # ymm23 = mem[0,1,2,3,0,1,2,3] + ; AVX512DQ-FAST-NEXT: vmovdqa (%rax), %xmm0 + ; AVX512DQ-FAST-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill +-; AVX512DQ-FAST-NEXT: vpshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,5,5,6] +-; AVX512DQ-FAST-NEXT: vpermd %ymm0, %ymm2, %ymm0 +-; AVX512DQ-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 +-; AVX512DQ-FAST-NEXT: vmovdqa (%rax), %ymm4 +-; AVX512DQ-FAST-NEXT: vmovdqu %ymm4, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm5 = [128,13,128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128] +-; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm4, %ymm4 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm5, %ymm18 +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm4, %zmm0, %zmm24 +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm0, %ymm23 ++; AVX512DQ-FAST-NEXT: vpshufhw {{.*#+}} xmm2 = xmm0[0,1,2,3,4,5,5,6] ++; AVX512DQ-FAST-NEXT: vpermd %ymm2, %ymm23, %ymm2 ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm3 = [255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255] ++; AVX512DQ-FAST-NEXT: vpandn %ymm2, %ymm3, %ymm2 ++; AVX512DQ-FAST-NEXT: vmovdqa (%rax), %ymm0 ++; AVX512DQ-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = [128,13,128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128] ++; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm3 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm19 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm2, %zmm18 ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm10[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm10[30],zero,ymm10[28],zero,zero,zero,zero,ymm10[31],zero,ymm10[29],zero,zero ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm0, %ymm26 + ; AVX512DQ-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14] +-; AVX512DQ-FAST-NEXT: vmovdqu (%rsp), %ymm0 # 32-byte Reload +-; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm0 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm25 +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm0[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm1[27],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vbroadcasti64x2 {{.*#+}} ymm26 = [18374967954648269055,71777218572844800,18374967954648269055,71777218572844800] +-; AVX512DQ-FAST-NEXT: # ymm26 = mem[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vpternlogq $248, %ymm26, %ymm0, %ymm1 +-; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm0 = 
xmm10[0],xmm15[0],xmm10[1],xmm15[1],xmm10[2],xmm15[2],xmm10[3],xmm15[3],xmm10[4],xmm15[4],xmm10[5],xmm15[5],xmm10[6],xmm15[6],xmm10[7],xmm15[7] +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] +-; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm21 = zmm1[0,1,2,3],zmm0[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm29, %xmm0 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm28, %xmm1 +-; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm7 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3],xmm0[4],xmm1[4],xmm0[5],xmm1[5],xmm0[6],xmm1[6],xmm0[7],xmm1[7] +-; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] +-; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm14[8],xmm8[8],xmm14[9],xmm8[9],xmm14[10],xmm8[10],xmm14[11],xmm8[11],xmm14[12],xmm8[12],xmm14[13],xmm8[13],xmm14[14],xmm8[14],xmm14[15],xmm8[15] +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm4 = <6,3,2,u,u,u,9,8,5,4,u,u,u,11,10,7> +-; AVX512DQ-FAST-NEXT: vpshufb %xmm4, %xmm1, %xmm1 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm31, %ymm0 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm2 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm30 ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm22, %ymm0 ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm3 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[27],zero,zero,zero,zero,ymm0[30],zero,ymm0[28],zero,zero,zero,zero,ymm0[31],zero,ymm0[29] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm3 = ymm3[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vbroadcasti64x2 {{.*#+}} ymm31 = [18374967954648269055,71777218572844800,18374967954648269055,71777218572844800] ++; AVX512DQ-FAST-NEXT: # ymm31 = mem[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vpternlogq $248, %ymm31, %ymm2, %ymm3 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm16, %xmm10 ++; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm2 = xmm9[0],xmm10[0],xmm9[1],xmm10[1],xmm9[2],xmm10[2],xmm9[3],xmm10[3],xmm9[4],xmm10[4],xmm9[5],xmm10[5],xmm9[6],xmm10[6],xmm9[7],xmm10[7] ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm2 = xmm2[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] ++; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm16 = zmm3[0,1,2,3],zmm2[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm21, %xmm0 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm20, %xmm1 ++; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm2 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3],xmm0[4],xmm1[4],xmm0[5],xmm1[5],xmm0[6],xmm1[6],xmm0[7],xmm1[7] ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm2, %xmm22 ++; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm2 = xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] ++; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm3 = xmm12[8],xmm5[8],xmm12[9],xmm5[9],xmm12[10],xmm5[10],xmm12[11],xmm5[11],xmm12[12],xmm5[12],xmm12[13],xmm5[13],xmm12[14],xmm5[14],xmm12[15],xmm5[15] ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm5, %xmm20 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm12, %xmm21 ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm0 = <6,3,2,u,u,u,9,8,5,4,u,u,u,11,10,7> ++; AVX512DQ-FAST-NEXT: vpshufb %xmm0, %xmm3, %xmm1 + ; AVX512DQ-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-FAST-NEXT: vpshufb %xmm0, %xmm2, %xmm0 ++; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm1 # 64-byte Reload ++; AVX512DQ-FAST-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm0 ++; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 
64-byte Spill ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm28, %xmm0 ++; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm5 = xmm11[0],xmm0[0],xmm11[1],xmm0[1],xmm11[2],xmm0[2],xmm11[3],xmm0[3],xmm11[4],xmm0[4],xmm11[5],xmm0[5],xmm11[6],xmm0[6],xmm11[7],xmm0[7] ++; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm0[8],xmm11[8],xmm0[9],xmm11[9],xmm0[10],xmm11[10],xmm0[11],xmm11[11],xmm0[12],xmm11[12],xmm0[13],xmm11[13],xmm0[14],xmm11[14],xmm0[15],xmm11[15] ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm17, %xmm3 ++; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm2 = xmm3[8],xmm7[8],xmm3[9],xmm7[9],xmm3[10],xmm7[10],xmm3[11],xmm7[11],xmm3[12],xmm7[12],xmm3[13],xmm7[13],xmm3[14],xmm7[14],xmm3[15],xmm7[15] ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm1 = <2,u,u,u,9,8,5,4,u,u,u,11,10,7,6,u> ++; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm2, %xmm2 ++; AVX512DQ-FAST-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill ++; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm0, %xmm0 ++; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm1 # 64-byte Reload ++; AVX512DQ-FAST-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm0 ++; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm9[8],xmm10[8],xmm9[9],xmm10[9],xmm9[10],xmm10[10],xmm9[11],xmm10[11],xmm9[12],xmm10[12],xmm9[13],xmm10[13],xmm9[14],xmm10[14],xmm9[15],xmm10[15] ++; AVX512DQ-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12] ++; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm8, %ymm12 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm24, %ymm11 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm11, %ymm1 ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm27 = ymm1[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vpshufb %ymm14, %ymm13, %ymm1 ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm28 = ymm1[2,3,2,3] + ; AVX512DQ-FAST-NEXT: vpshufb %xmm4, %xmm0, %xmm0 + ; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm1 # 64-byte Reload +-; AVX512DQ-FAST-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm28 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm27, %xmm0 +-; AVX512DQ-FAST-NEXT: vmovdqa %xmm6, %xmm1 +-; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm6 = xmm0[0],xmm6[0],xmm0[1],xmm6[1],xmm0[2],xmm6[2],xmm0[3],xmm6[3],xmm0[4],xmm6[4],xmm0[5],xmm6[5],xmm0[6],xmm6[6],xmm0[7],xmm6[7] +-; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] +-; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm4 = xmm13[8],xmm9[8],xmm13[9],xmm9[9],xmm13[10],xmm9[10],xmm13[11],xmm9[11],xmm13[12],xmm9[12],xmm13[13],xmm9[13],xmm13[14],xmm9[14],xmm13[15],xmm9[15] +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm5 = <2,u,u,u,9,8,5,4,u,u,u,11,10,7,6,u> +-; AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm4, %xmm0 +-; AVX512DQ-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm1, %xmm1 +-; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Reload +-; AVX512DQ-FAST-NEXT: vinserti32x4 $2, %xmm1, %zmm0, %zmm27 +-; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm10[8],xmm15[8],xmm10[9],xmm15[9],xmm10[10],xmm15[10],xmm10[11],xmm15[11],xmm10[12],xmm15[12],xmm10[13],xmm15[13],xmm10[14],xmm15[14],xmm10[15],xmm15[15] +-; AVX512DQ-FAST-NEXT: vpshufb %xmm3, %xmm1, %xmm1 +-; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm22[0,1,0,1],zmm1[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm1[0,1,0,1],zmm0[0,1,0,1] + ; 
AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill + ; AVX512DQ-FAST-NEXT: vmovdqa 32(%rax), %xmm0 +-; AVX512DQ-FAST-NEXT: vpshufhw {{.*#+}} xmm1 = xmm0[0,1,2,3,4,5,5,6] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm0, %xmm29 +-; AVX512DQ-FAST-NEXT: vpermd %ymm1, %ymm2, %ymm0 ++; AVX512DQ-FAST-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill ++; AVX512DQ-FAST-NEXT: vpshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,5,5,6] ++; AVX512DQ-FAST-NEXT: vpermd %ymm0, %ymm23, %ymm0 + ; AVX512DQ-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill +-; AVX512DQ-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12] +-; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm11, %ymm5 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm20, %ymm0 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm1 +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm22 = ymm1[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Reload +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm19, %ymm1 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm2, %ymm1 +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm20 = ymm1[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm10 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[18],zero,zero,zero,zero,ymm0[21],zero,ymm0[19],zero,zero,zero,zero,ymm0[22],zero,ymm0[20] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm0 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm2, %ymm4 +-; AVX512DQ-FAST-NEXT: vpbroadcastq {{.*#+}} ymm2 = [9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm16, %ymm0 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm0, %ymm1 +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm2, %ymm30 +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm19 = ymm1[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm11 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm0[21],zero,ymm0[19],zero,zero,zero,zero,ymm0[22],zero,ymm0[20],zero,zero ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm25, %ymm0 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm11, %ymm10 ++; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm15 # 32-byte Reload ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm15[9,u,7,u,u,u,u,10,u,8,u,u,u,u,11,u,25,u,23,u,u,u,u,26,u,24,u,u,u,u,27,u] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm24 = ymm0[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm29, %ymm0 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm13, %ymm8 ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm11 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm15[21],zero,ymm15[19],zero,zero,zero,zero,ymm15[22],zero,ymm15[20],zero,zero + ; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm18, %ymm1 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm3 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm19, %ymm1 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm4 + ; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} ymm0 = ymm0[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] + ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = [4,5,4,5,5,7,4,5] + ; AVX512DQ-FAST-NEXT: vpermd %ymm0, %ymm1, %ymm0 +-; AVX512DQ-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm3, %zmm16 +-; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm0 = xmm9[0],xmm13[0],xmm9[1],xmm13[1],xmm9[2],xmm13[2],xmm9[3],xmm13[3],xmm9[4],xmm13[4],xmm9[5],xmm13[5],xmm9[6],xmm13[6],xmm9[7],xmm13[7] +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm9 = <0,1,u,u,u,6,7,2,3,u,u,u,8,9,4,5> +-; 
AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm6, %xmm6 +-; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm0, %xmm0 +-; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm9 = [28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29,28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29] +-; AVX512DQ-FAST-NEXT: # ymm9 = mem[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm17, %ymm3 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm9, %ymm3, %ymm15 +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm17 = ymm15[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} ymm25 = [255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255] ++; AVX512DQ-FAST-NEXT: vpandnq %ymm0, %ymm25, %ymm0 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm4, %zmm23 ++; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29,28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29] ++; AVX512DQ-FAST-NEXT: # ymm0 = mem[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm15 # 32-byte Reload ++; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm15, %ymm9 ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm17 = ymm9[2,3,2,3] + ; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm25, %ymm2 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm2 + ; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm13 +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm18 = ymm13[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vinserti32x4 $2, %xmm31, %zmm0, %zmm0 +-; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill +-; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm0 = xmm8[0],xmm14[0],xmm8[1],xmm14[1],xmm8[2],xmm14[2],xmm8[3],xmm14[3],xmm8[4],xmm14[4],xmm8[5],xmm14[5],xmm8[6],xmm14[6],xmm8[7],xmm14[7] +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm8 = <4,5,0,1,u,u,u,6,7,2,3,u,u,u,8,9> +-; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm7, %xmm7 +-; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm0, %xmm13 +-; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm8 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm0[23],zero,ymm0[23,24,25,26],zero,ymm0[24],zero,ymm0[30,31] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm0, %ymm25 +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm15 = ymm8[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vmovdqa %ymm12, %ymm1 +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm12[30],zero,ymm12[28],zero,zero,zero,zero,ymm12[31],zero,ymm12[29],zero,zero,zero +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Reload +-; AVX512DQ-FAST-NEXT: vpshufb %ymm9, %ymm2, %ymm9 +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm5 = ymm5[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm23[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm7[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm29 ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm25 = ymm13[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm7 = xmm7[0],xmm3[0],xmm7[1],xmm3[1],xmm7[2],xmm3[2],xmm7[3],xmm3[3],xmm7[4],xmm3[4],xmm7[5],xmm3[5],xmm7[6],xmm3[6],xmm7[7],xmm3[7] ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm13 = <0,1,u,u,u,6,7,2,3,u,u,u,8,9,4,5> ++; AVX512DQ-FAST-NEXT: vpshufb %xmm13, %xmm5, %xmm5 ++; AVX512DQ-FAST-NEXT: vpshufb %xmm13, %xmm7, %xmm7 ++; 
AVX512DQ-FAST-NEXT: vinserti32x4 $2, {{[-0-9]+}}(%r{{[sb]}}p), %zmm7, %zmm1 # 16-byte Folded Reload ++; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill ++; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm1[23],zero,ymm1[23,24,25,26],zero,ymm1[24],zero,ymm1[30,31] ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm30 ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm19 = ymm7[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vmovdqa %ymm6, %ymm2 ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm6[30],zero,ymm6[28],zero,zero,zero,zero,ymm6[31],zero,ymm6[29],zero,zero,zero ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm7[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Reload ++; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm3, %ymm0 ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm0[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm12[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm26[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm20, %xmm1 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm21, %xmm4 ++; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm12 = xmm1[0],xmm4[0],xmm1[1],xmm4[1],xmm1[2],xmm4[2],xmm1[3],xmm4[3],xmm1[4],xmm4[4],xmm1[5],xmm4[5],xmm1[6],xmm4[6],xmm1[7],xmm4[7] ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm1 = <4,5,0,1,u,u,u,6,7,2,3,u,u,u,8,9> ++; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm22, %xmm4 ++; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm4, %xmm13 ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm13[0,1,0,1] + ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm11[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vinserti32x4 $2, {{[-0-9]+}}(%r{{[sb]}}p), %zmm13, %zmm31 # 16-byte Folded Reload +-; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm11 = [18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23,18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23] +-; AVX512DQ-FAST-NEXT: # ymm11 = mem[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vpshufb %ymm11, %ymm3, %ymm14 +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm3 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm3, %ymm1, %ymm13 +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpshufb %ymm11, %ymm2, %ymm1 ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm11[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm12, %xmm1 ++; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm12 = [18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23,18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23] ++; AVX512DQ-FAST-NEXT: # ymm12 = mem[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vpshufb %ymm12, %ymm15, %ymm9 ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vinserti32x4 $2, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm11 # 16-byte Folded Reload ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[9,u,7,u,u,u,u,10,u,8,u,u,u,u,11,u,25,u,23,u,u,u,u,26,u,24,u,u,u,u,27,u] + ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpor %ymm12, %ymm9, %ymm9 +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm6, %zmm9, %zmm6 +-; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm9 = 
[18374966859431673855,18446463693966278655,18374966859431673855,18446463693966278655] +-; AVX512DQ-FAST-NEXT: # ymm9 = mem[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vpternlogq $248, %ymm9, %ymm5, %ymm0 +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm7, %zmm0, %zmm7 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm12, %ymm3, %ymm12 ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vpor %ymm6, %ymm14, %ymm2 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm5, %zmm2, %zmm2 ++; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [18374966859431673855,18446463693966278655,18374966859431673855,18446463693966278655] ++; AVX512DQ-FAST-NEXT: # ymm5 = mem[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vpternlogq $248, %ymm5, %ymm7, %ymm0 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm13, %zmm0, %zmm7 + ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm0 = [255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255] +-; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm6, %zmm0, %zmm7 +-; AVX512DQ-FAST-NEXT: vpor %ymm4, %ymm10, %ymm4 +-; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Reload +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm4, %zmm2, %zmm4 +-; AVX512DQ-FAST-NEXT: vpor %ymm8, %ymm14, %ymm2 ++; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm2, %zmm0, %zmm7 ++; AVX512DQ-FAST-NEXT: vpor %ymm10, %ymm8, %ymm2 + ; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Reload +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm5 +-; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm4, %zmm0, %zmm5 +-; AVX512DQ-FAST-NEXT: vpandq %ymm9, %ymm22, %ymm0 +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm20, %zmm0 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 ++; AVX512DQ-FAST-NEXT: vpor %ymm4, %ymm9, %ymm3 ++; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm4, %zmm6 ++; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm2, %zmm0, %zmm6 ++; AVX512DQ-FAST-NEXT: vpandq %ymm5, %ymm27, %ymm0 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm28, %zmm0 + ; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload + ; AVX512DQ-FAST-NEXT: # zmm2 = mem[2,3,2,3,6,7,6,7] + ; AVX512DQ-FAST-NEXT: vporq %zmm2, %zmm0, %zmm0 +-; AVX512DQ-FAST-NEXT: vpandq %ymm26, %ymm19, %ymm2 ++; AVX512DQ-FAST-NEXT: vpandq %ymm31, %ymm24, %ymm2 + ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm17, %zmm2, %zmm2 +-; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Folded Reload +-; AVX512DQ-FAST-NEXT: # zmm4 = mem[2,3,2,3,6,7,6,7] +-; AVX512DQ-FAST-NEXT: vporq %zmm4, %zmm2, %zmm2 ++; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Folded Reload ++; AVX512DQ-FAST-NEXT: # zmm3 = mem[2,3,2,3,6,7,6,7] ++; AVX512DQ-FAST-NEXT: vporq %zmm3, %zmm2, %zmm2 + ; AVX512DQ-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm2 +-; AVX512DQ-FAST-NEXT: vpandq %ymm26, %ymm18, %ymm0 +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm15, %zmm0 +-; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Folded Reload +-; AVX512DQ-FAST-NEXT: # zmm4 = mem[2,3,2,3,6,7,6,7] +-; AVX512DQ-FAST-NEXT: vporq %zmm4, %zmm0, %zmm0 +-; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm4 = 
[255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255] +-; AVX512DQ-FAST-NEXT: vpternlogq $184, %zmm2, %zmm4, %zmm0 +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} zmm2 = zmm28[0,1,0,1,4,5,4,5] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} zmm8 = zmm27[0,1,0,1,4,5,4,5] +-; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm2, %zmm4, %zmm8 +-; AVX512DQ-FAST-NEXT: vpandq %ymm26, %ymm13, %ymm2 +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm1, %zmm1 ++; AVX512DQ-FAST-NEXT: vpandq %ymm31, %ymm25, %ymm0 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm19, %zmm0 ++; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Folded Reload ++; AVX512DQ-FAST-NEXT: # zmm3 = mem[2,3,2,3,6,7,6,7] ++; AVX512DQ-FAST-NEXT: vporq %zmm3, %zmm0, %zmm3 ++; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm0 = [255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255] ++; AVX512DQ-FAST-NEXT: vpternlogq $184, %zmm2, %zmm0, %zmm3 ++; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload ++; AVX512DQ-FAST-NEXT: # zmm2 = mem[0,1,0,1,4,5,4,5] ++; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Folded Reload ++; AVX512DQ-FAST-NEXT: # zmm8 = mem[0,1,0,1,4,5,4,5] ++; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm2, %zmm0, %zmm8 ++; AVX512DQ-FAST-NEXT: vpandq %ymm31, %ymm1, %ymm1 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm12, %zmm1 + ; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload + ; AVX512DQ-FAST-NEXT: # zmm2 = mem[2,3,2,3,6,7,6,7] + ; AVX512DQ-FAST-NEXT: vporq %zmm2, %zmm1, %zmm1 + ; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload + ; AVX512DQ-FAST-NEXT: # zmm2 = mem[2,3,2,3,6,7,6,7] +-; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm6 # 64-byte Folded Reload +-; AVX512DQ-FAST-NEXT: # zmm6 = mem[2,3,2,3,6,7,6,7] +-; AVX512DQ-FAST-NEXT: vporq %zmm2, %zmm6, %zmm9 +-; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm1, %zmm4, %zmm9 +-; AVX512DQ-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload +-; AVX512DQ-FAST-NEXT: vpunpcklbw {{[-0-9]+}}(%r{{[sb]}}p), %xmm1, %xmm1 # 16-byte Folded Reload +-; AVX512DQ-FAST-NEXT: # xmm1 = xmm1[0],mem[0],xmm1[1],mem[1],xmm1[2],mem[2],xmm1[3],mem[3],xmm1[4],mem[4],xmm1[5],mem[5],xmm1[6],mem[6],xmm1[7],mem[7] +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] +-; AVX512DQ-FAST-NEXT: vshufi64x2 $0, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm18 # 64-byte Folded Reload +-; AVX512DQ-FAST-NEXT: # zmm18 = zmm1[0,1,0,1],mem[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm29, %xmm3 +-; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} xmm2 = xmm3[1,1,0,0,4,5,6,7] +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm4 = [0,1,0,1,2,0,0,1] +-; AVX512DQ-FAST-NEXT: vpermd %ymm2, %ymm4, %ymm19 +-; AVX512DQ-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload +-; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} xmm6 = xmm1[1,1,0,0,4,5,6,7] +-; AVX512DQ-FAST-NEXT: vpermd %ymm6, %ymm4, %ymm17 +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm6 = [4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] +-; AVX512DQ-FAST-NEXT: vpshufb %xmm6, %xmm3, %xmm10 +-; AVX512DQ-FAST-NEXT: vpshufb %xmm6, %xmm1, %xmm6 +-; AVX512DQ-FAST-NEXT: 
vbroadcasti128 {{.*#+}} ymm11 = [128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22] +-; AVX512DQ-FAST-NEXT: # ymm11 = mem[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload +-; AVX512DQ-FAST-NEXT: vpshufb %ymm11, %ymm1, %ymm12 +-; AVX512DQ-FAST-NEXT: vmovdqu (%rsp), %ymm1 # 32-byte Reload +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm13 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm1[25],zero,ymm1[23],zero,zero,zero,zero,ymm1[26],zero,ymm1[24],zero,zero +-; AVX512DQ-FAST-NEXT: vpshufb %ymm11, %ymm1, %ymm11 +-; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm14 = [20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128,20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128] +-; AVX512DQ-FAST-NEXT: # ymm14 = mem[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm25, %ymm1 +-; AVX512DQ-FAST-NEXT: vpshufb %ymm14, %ymm1, %ymm15 +-; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Reload +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm2[23],zero,ymm2[23,24,25,26],zero,ymm2[24],zero,ymm2[30,31] +-; AVX512DQ-FAST-NEXT: vpshufb %ymm14, %ymm2, %ymm14 +-; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Reload +-; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} ymm4 = ymm3[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] +-; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [4,5,4,5,5,7,4,5] +-; AVX512DQ-FAST-NEXT: vpermd %ymm4, %ymm2, %ymm20 +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,28,29,26,27,28,29,30,31,30,31,28,29,28,29,30,31] +-; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm4 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,24,25,22,23,24,25,26,27,26,27,24,25] +-; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm22 # 64-byte Folded Reload +-; AVX512DQ-FAST-NEXT: # zmm22 = mem[2,3,2,3,6,7,6,7] +-; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm22 +-; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Folded Reload ++; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Folded Reload ++; AVX512DQ-FAST-NEXT: # zmm5 = mem[2,3,2,3,6,7,6,7] ++; AVX512DQ-FAST-NEXT: vporq %zmm2, %zmm5, %zmm22 ++; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm1, %zmm0, %zmm22 ++; AVX512DQ-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload ++; AVX512DQ-FAST-NEXT: vpunpcklbw {{[-0-9]+}}(%r{{[sb]}}p), %xmm0, %xmm0 # 16-byte Folded Reload ++; AVX512DQ-FAST-NEXT: # xmm0 = xmm0[0],mem[0],xmm0[1],mem[1],xmm0[2],mem[2],xmm0[3],mem[3],xmm0[4],mem[4],xmm0[5],mem[5],xmm0[6],mem[6],xmm0[7],mem[7] ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] ++; AVX512DQ-FAST-NEXT: vshufi64x2 $0, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm26 # 64-byte Folded Reload ++; AVX512DQ-FAST-NEXT: # zmm26 = zmm0[0,1,0,1],mem[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm4 # 16-byte Reload ++; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} xmm1 = xmm4[1,1,0,0,4,5,6,7] ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [0,1,0,1,2,0,0,1] ++; AVX512DQ-FAST-NEXT: vpermd %ymm1, %ymm2, %ymm19 ++; AVX512DQ-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload ++; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} xmm5 = xmm0[1,1,0,0,4,5,6,7] ++; AVX512DQ-FAST-NEXT: vpermd %ymm5, %ymm2, %ymm17 ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm5 = [4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] ++; 
AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm4, %xmm10 ++; AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm0, %xmm5 ++; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm12 = [128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22] ++; AVX512DQ-FAST-NEXT: # ymm12 = mem[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm29, %ymm0 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm12, %ymm0, %ymm13 ++; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm14 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm0[25],zero,ymm0[23],zero,zero,zero,zero,ymm0[26],zero,ymm0[24],zero,zero ++; AVX512DQ-FAST-NEXT: vpshufb %ymm12, %ymm0, %ymm12 ++; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128,20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128] + ; AVX512DQ-FAST-NEXT: # ymm0 = mem[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm23 # 32-byte Folded Reload +-; AVX512DQ-FAST-NEXT: # ymm23 = mem[0,1,0,1] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,1,0] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm1 ++; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm1, %ymm2 ++; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm9 # 32-byte Reload ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm9[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm9[23],zero,ymm9[23,24,25,26],zero,ymm9[24],zero,ymm9[30,31] ++; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm9, %ymm0 ++; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm4 # 32-byte Reload ++; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} ymm15 = ymm4[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] ++; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm9 = [4,5,4,5,5,7,4,5] ++; AVX512DQ-FAST-NEXT: vpermd %ymm15, %ymm9, %ymm20 ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm15 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,28,29,26,27,28,29,30,31,30,31,28,29,28,29,30,31] ++; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm9 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,24,25,22,23,24,25,26,27,26,27,24,25] ++; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm24 # 64-byte Folded Reload ++; AVX512DQ-FAST-NEXT: # zmm24 = mem[2,3,2,3,6,7,6,7] ++; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm24 ++; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Folded Reload ++; AVX512DQ-FAST-NEXT: # ymm3 = mem[0,1,0,1] ++; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm25 # 32-byte Folded Reload ++; AVX512DQ-FAST-NEXT: # ymm25 = mem[0,1,0,1] + ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm15 = ymm15[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,0,1,0] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,1,0] + ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm11 = ymm11[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,0,1,0] + ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,3,2,3] +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm0 # 32-byte Folded Reload +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm23, 
%zmm23 # 32-byte Folded Reload +-; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm23 +-; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Reload +-; AVX512DQ-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm24 +-; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm23, %zmm24 +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm19, %zmm2, %zmm0 +-; AVX512DQ-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm21 +-; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm21 +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm0 # 32-byte Folded Reload +-; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Reload +-; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm0 +-; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm0 +-; AVX512DQ-FAST-NEXT: vpor %ymm12, %ymm15, %ymm2 ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm0[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,3,2,3] ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm3, %zmm3 # 32-byte Folded Reload ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm25, %zmm25 # 32-byte Folded Reload ++; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm25 ++; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Reload ++; AVX512DQ-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm18 ++; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm25, %zmm18 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm19, %zmm15, %zmm3 ++; AVX512DQ-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm16 ++; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm16 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm3 # 32-byte Folded Reload ++; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload ++; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm3 ++; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm3 ++; AVX512DQ-FAST-NEXT: vpor %ymm2, %ymm13, %ymm2 + ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm0, %zmm2 +-; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Reload +-; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm2 = zmm7[0,1,2,3],zmm2[4,5,6,7] +-; AVX512DQ-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm16 +-; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm16 ++; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload ++; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm2 = zmm4[0,1,2,3],zmm2[4,5,6,7] ++; AVX512DQ-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm23 ++; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm23 + ; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload + ; AVX512DQ-FAST-NEXT: # zmm2 = mem[0,1,0,1,4,5,4,5] +-; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} zmm5 = zmm31[0,1,0,1,4,5,4,5] +-; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm5 +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm6, 
%zmm17, %zmm2 +-; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm18, %zmm2 +-; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm2 +-; AVX512DQ-FAST-NEXT: vpor %ymm1, %ymm13, %ymm1 +-; AVX512DQ-FAST-NEXT: vpor %ymm11, %ymm14, %ymm5 ++; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} zmm6 = zmm11[0,1,0,1,4,5,4,5] ++; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm6 ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm5, %zmm17, %zmm2 ++; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm26, %zmm2 ++; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm2 ++; AVX512DQ-FAST-NEXT: vpor %ymm1, %ymm14, %ymm1 ++; AVX512DQ-FAST-NEXT: vpor %ymm0, %ymm12, %ymm0 + ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm1 +-; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm1 = zmm5[0,1,2,3],zmm1[4,5,6,7] +-; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm4, %zmm20, %zmm4 +-; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm4 +-; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm4 ++; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm0[0,1,2,3],zmm1[4,5,6,7] ++; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm9, %zmm20, %zmm1 ++; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm1 ++; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm22, %zmm1 + ; AVX512DQ-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax +-; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm4, 128(%rax) ++; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm1, 128(%rax) + ; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm2, (%rax) +-; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm16, 320(%rax) +-; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm0, 256(%rax) +-; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm21, 192(%rax) +-; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm24, 64(%rax) +-; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm22, 384(%rax) +-; AVX512DQ-FAST-NEXT: addq $1256, %rsp # imm = 0x4E8 ++; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm23, 320(%rax) ++; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm3, 256(%rax) ++; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm16, 192(%rax) ++; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm18, 64(%rax) ++; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm24, 384(%rax) ++; AVX512DQ-FAST-NEXT: addq $1496, %rsp # imm = 0x5D8 + ; AVX512DQ-FAST-NEXT: vzeroupper + ; AVX512DQ-FAST-NEXT: retq + ; +diff -ruN --strip-trailing-cr a/llvm/test/Instrumentation/MemorySanitizer/X86/vararg_shadow.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/vararg_shadow.ll +--- a/llvm/test/Instrumentation/MemorySanitizer/X86/vararg_shadow.ll ++++ b/llvm/test/Instrumentation/MemorySanitizer/X86/vararg_shadow.ll +@@ -1303,6 +1303,7 @@ + ; CHECK-NEXT: [[TMP64:%.*]] = xor i64 [[TMP63]], 87960930222080 + ; CHECK-NEXT: [[TMP65:%.*]] = inttoptr i64 [[TMP64]] to ptr + ; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 8 inttoptr (i64 add (i64 ptrtoint (ptr @__msan_va_arg_tls to i64), i64 688) to ptr), ptr align 8 [[TMP65]], i64 64, i1 false) ++; CHECK-NEXT: call void @llvm.memset.p0.i32(ptr align 8 inttoptr (i64 add (i64 ptrtoint (ptr @__msan_va_arg_tls to i64), i64 752) to ptr), i8 0, i32 48, i1 false) + ; CHECK-NEXT: store i64 1280, ptr @__msan_va_arg_overflow_size_tls, align 8 + ; CHECK-NEXT: call void (ptr, i32, ...) 
@_Z5test2I11LongDouble4EvT_iz(ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], i32 noundef 20, ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]]) + ; CHECK-NEXT: ret void diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 50ee3694c2231a..1a88322162b80d 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "ec42d547eba5c0ad0bddbecc8902d35383968e78" - LLVM_SHA256 = "c7ec22eb1026b8d09afe2a70b2e2f5cf09a1805c4d16b004e72bba5b4153e2cf" + LLVM_COMMIT = "506c47df00bbd9e527ecc5ac6e192b5fe5daa2c5" + LLVM_SHA256 = "5db67a5293810e6aebd2f757d660dbe2271fc86016c56cfcc56a89705dc22d80" tf_http_archive( name = name, diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/matmul.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/matmul.mlir index 7a9863ea96c8ed..d207c0780098a6 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/matmul.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/matmul.mlir @@ -129,7 +129,7 @@ func.func @matmul_narrow_static(%lhs: tensor<2x16xf32>, %rhs: tensor<16x64xf32>, // PACKED: tensor.empty() : tensor<8x16x8x1xf32> // PACKED-COUNT: scf.for -// PACKED: vector.transpose +// PACKED: vector.shape_cast // PACKED: scf.yield %{{.*}} : tensor<8x16x8x1xf32> // PACKED: scf.yield %{{.*}} : tensor<8x16x8x1xf32> diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/lower_vectors.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/lower_vectors.mlir index 78fd35133dc6c0..16fc1f1837b147 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/lower_vectors.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/gml_st/lower_vectors.mlir @@ -161,7 +161,8 @@ func.func @optimize_pack_with_transpose(%arg0: memref<1024x1024xf32>) -> // FLATTEN-NOT: vector.transpose // FLATTEN: %[[COLLAPSE:.*]] = memref.collapse_shape %[[ALLOC]] // FLATTEN-SAME: memref<128x1024x8x1xf32> into memref<128x1024x8xf32> -// FLATTEN: vector.transfer_write %[[READ]], %[[COLLAPSE]] +// FLATTEN: %[[CAST:.*]] = vector.shape_cast +// FLATTEN: 
vector.transfer_write %[[CAST]], %[[COLLAPSE]] // ----- From b8c5e00edc6c29393102dfa3c691b78aaa9d4af7 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Mon, 20 Nov 2023 00:22:17 -0800 Subject: [PATCH 277/391] PR #7088: Fix incorrect device memory limit on non root processes. Imported from GitHub PR https://github.com/openxla/xla/pull/7088 @hawkinsp In a multiprocess config, this code was attempting to get the memory limit from device 0 which is only local to process 0. For all other processes, the memory limit was being set to 0. This was causing the memory limit in the latency hiding scheduler to be different between process 0 and the other processes, leading to a different scheduled order of collectives and causing deadlocks. To fix this, use addressable_devices instead of all global device and index using the device_ordinal. Currently, jax doesn't appear to set `options.executable_build_options.device_ordinal()` so we will default to 0 in that case. Copybara import of the project: -- 1bad102e374c7aa4c3acfefdad66477813432a82 by Trevor Morris : Use addressable devices so that correct mem limit is read. Also use device_ordinal to index into addressable devices - however device_ordinal is currently unset so will default to 0 Merging this change closes #7088 PiperOrigin-RevId: 583928913 --- third_party/xla/xla/python/py_client.cc | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/python/py_client.cc b/third_party/xla/xla/python/py_client.cc index 9f7021be5761c7..2caed4eb9562fc 100644 --- a/third_party/xla/xla/python/py_client.cc +++ b/third_party/xla/xla/python/py_client.cc @@ -365,9 +365,15 @@ StatusOr> PyClient::Compile( auto* pjrt_compatible_client = llvm::dyn_cast_or_null(ifrt_client_.get()); if (pjrt_compatible_client != nullptr) { - auto devices = pjrt_compatible_client->pjrt_client()->devices(); - if (!devices.empty()) { - auto stats = devices[0]->GetAllocatorStats(); + auto addressable_devices = + pjrt_compatible_client->pjrt_client()->addressable_devices(); + if (!addressable_devices.empty()) { + int device_ordinal = options.executable_build_options.device_ordinal(); + if (device_ordinal < 0) { + device_ordinal = 0; + } + CHECK_LT(device_ordinal, addressable_devices.size()); + auto stats = addressable_devices[device_ordinal]->GetAllocatorStats(); if (stats.ok() && stats->bytes_limit) { options.executable_build_options.set_device_memory_size( *stats->bytes_limit); From 053919542484165fe3c3e88e73f6acdd339e3497 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 20 Nov 2023 00:47:36 -0800 Subject: [PATCH 278/391] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/383e601d6140ee499349c4bb53085eb4a891f500. PiperOrigin-RevId: 583934201 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index b5ea5e6c8a2bd2..617614dba84506 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. 
- TFRT_COMMIT = "886cfe0e0fd894ba1beafbb80585c6f32de8a2e4" - TFRT_SHA256 = "d391129b09b90a343f4b948f8fda109a260cdfb0e1ea63c978cafcdf528a85e3" + TFRT_COMMIT = "383e601d6140ee499349c4bb53085eb4a891f500" + TFRT_SHA256 = "edacc0434ee28a2203f6699b0cac3feed1bf2fe8b011f59e5bcfcb48a74e4bcb" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index b5ea5e6c8a2bd2..617614dba84506 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "886cfe0e0fd894ba1beafbb80585c6f32de8a2e4" - TFRT_SHA256 = "d391129b09b90a343f4b948f8fda109a260cdfb0e1ea63c978cafcdf528a85e3" + TFRT_COMMIT = "383e601d6140ee499349c4bb53085eb4a891f500" + TFRT_SHA256 = "edacc0434ee28a2203f6699b0cac3feed1bf2fe8b011f59e5bcfcb48a74e4bcb" tf_http_archive( name = "tf_runtime", From a127f3ba83a0ac065cff2a20e82eb1433d0840bb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 20 Nov 2023 01:01:56 -0800 Subject: [PATCH 279/391] compat: Update forward compatibility horizon to 2023-11-20 PiperOrigin-RevId: 583937663 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 91d105f6c88132..9b2118409ae917 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 19) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 11, 20) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 0641f01c9217baa749fc5eae8bf85ae42de3ce0c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 20 Nov 2023 01:02:02 -0800 Subject: [PATCH 280/391] Update GraphDef version to 1686. PiperOrigin-RevId: 583937693 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 199a727c62be14..1dc8d61fea9bcd 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1685 // Updated: 2023/11/19 +#define TF_GRAPH_DEF_VERSION 1686 // Updated: 2023/11/20 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 6589436d0000cdb5885e360e2151c0d30a1044a1 Mon Sep 17 00:00:00 2001 From: Chao Date: Mon, 20 Nov 2023 01:41:55 -0800 Subject: [PATCH 281/391] PR #7029: [ROCm] add fpus for rocm and enable xla_compile_lib_test Imported from GitHub PR https://github.com/openxla/xla/pull/7029 1. we add `fpus_per_core` for cost model and fusion analysis. 2. 
enable ROCm for new xla_compile_lib_test https://github.com/openxla/xla/commit/a129c9bfe093743a93a9c2986d187a066940c21d#diff-19845b853b7cf43ac606aca22262c15fdf2d56e5eb196a98f8d982e947123f1d please check @xla-rotation thanks in advance! Copybara import of the project: -- e9bcf2439d1384dc11e57f8e045b2e959b67dc62 by Chao Chen : add fpus for rocm and enable xla_compile_lib_test Merging this change closes #7029 PiperOrigin-RevId: 583947247 --- .../xla/service/gpu/gpu_device_info_for_tests.cc | 2 +- .../xla/xla/service/gpu/gpu_device_info_test.cc | 4 ++-- .../xla/stream_executor/rocm/rocm_gpu_executor.cc | 13 +++++++++++++ third_party/xla/xla/tools/BUILD | 4 +++- third_party/xla/xla/tools/xla_compile_lib_test.cc | 7 ++++++- 5 files changed, 25 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_device_info_for_tests.cc b/third_party/xla/xla/service/gpu/gpu_device_info_for_tests.cc index 5f88847b3352a4..8eb44d45aa56eb 100644 --- a/third_party/xla/xla/service/gpu/gpu_device_info_for_tests.cc +++ b/third_party/xla/xla/service/gpu/gpu_device_info_for_tests.cc @@ -53,7 +53,7 @@ stream_executor::DeviceDescription TestGpuDeviceInfo::AMDMI210DeviceInfo() { b.set_shared_memory_per_core(64 * 1024); b.set_threads_per_core_limit(2048); b.set_core_count(104); - b.set_fpus_per_core(0); + b.set_fpus_per_core(128); b.set_block_dim_limit_x(2'147'483'647); b.set_block_dim_limit_y(2'147'483'647); b.set_block_dim_limit_z(2'147'483'647); diff --git a/third_party/xla/xla/service/gpu/gpu_device_info_test.cc b/third_party/xla/xla/service/gpu/gpu_device_info_test.cc index cb8d69f4c62592..c9e2ae245a7535 100644 --- a/third_party/xla/xla/service/gpu/gpu_device_info_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_device_info_test.cc @@ -119,7 +119,7 @@ TEST(DeviceInfoTest, DeviceInfoIsCorrect) { /*shared_memory_per_block_optin=*/0, /*shared_memory_per_core=*/64 * 1024, /*threads_per_core_limit=*/2560, /*core_count=*/120, - /*fpus_per_core=*/0, /*block_dim_limit_x=*/2'147'483'647, + /*fpus_per_core=*/128, /*block_dim_limit_x=*/2'147'483'647, /*block_dim_limit_y=*/2'147'483'647, /*block_dim_limit_z=*/2'147'483'647, /*memory_bandwidth=*/1228800000000, @@ -136,7 +136,7 @@ TEST(DeviceInfoTest, DeviceInfoIsCorrect) { /*shared_memory_per_block_optin=*/0, /*shared_memory_per_core=*/64 * 1024, /*threads_per_core_limit=*/2560, /*core_count=*/60, - /*fpus_per_core=*/0, /*block_dim_limit_x=*/2'147'483'647, + /*fpus_per_core=*/64, /*block_dim_limit_x=*/2'147'483'647, /*block_dim_limit_y=*/2'147'483'647, /*block_dim_limit_z=*/2'147'483'647, /*memory_bandwidth=*/256000000000, diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc index 3b7e6dd0a178c6..bd596857070335 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc @@ -107,6 +107,18 @@ bool GpuExecutor::UnloadModule(ModuleHandle module_handle) { return UnloadGpuBinary(gpu_binary); } +namespace { +int fpus_per_core(std::string gcn_arch_name) { + // Source: + // https://www.amd.com/content/dam/amd/en/documents/instinct-business-docs/white-papers/amd-cdna2-white-paper.pdf + int n = 128; // gfx90a and gfx908 -> 128 + if (gcn_arch_name.substr(0, 6) == "gfx906") { + n = 64; + } + return n; +} +} // namespace + tsl::StatusOr> GpuExecutor::CreateOrShareConstant(Stream* stream, const std::vector& content) { @@ -893,6 +905,7 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { 
GpuDriver::GetMaxSharedMemoryPerBlock(device).value()); int core_count = GpuDriver::GetMultiprocessorCount(device).value(); builder.set_core_count(core_count); + builder.set_fpus_per_core(fpus_per_core(gcn_arch_name)); builder.set_threads_per_core_limit( GpuDriver::GetMaxThreadsPerMultiprocessor(device).value()); builder.set_registers_per_block_limit( diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index 2f23d3b68b8717..99b1848b934800 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -689,7 +689,9 @@ xla_test( ":data/add.hlo", "//xla/service:xla_aot_compile_test_gpu_target_config.prototxt", ], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = [ ":xla_compile_lib", "//xla:util", diff --git a/third_party/xla/xla/tools/xla_compile_lib_test.cc b/third_party/xla/xla/tools/xla_compile_lib_test.cc index 05d86570ecee0e..dacd2ecec26820 100644 --- a/third_party/xla/xla/tools/xla_compile_lib_test.cc +++ b/third_party/xla/xla/tools/xla_compile_lib_test.cc @@ -53,7 +53,12 @@ using ::tsl::testing::StatusIs; #if XLA_TEST_BACKEND_CPU static constexpr absl::string_view kPlatformName = "Host"; #elif XLA_TEST_BACKEND_GPU -static constexpr absl::string_view kPlatformName = "CUDA"; +static constexpr absl::string_view kPlatformName = +#if TENSORFLOW_USE_ROCM + "ROCM"; +#else + "CUDA"; +#endif #endif // XLA_TEST_BACKEND_CPU class XlaCompileLibTest : public HloTestBase { From 386e57664f941a8ecba5460c9c5a4d5d9f1d1831 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 20 Nov 2023 02:21:54 -0800 Subject: [PATCH 282/391] Remove is_scheduled from .hlo file. hlo-opt runs the HLO passes, and if we have is_scheduled=true, the scheduling pass will be skipped which currently makes the test fail. PiperOrigin-RevId: 583957372 --- third_party/xla/xla/tools/hlo_opt/gpu_hlo_buffers.hlo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_buffers.hlo b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_buffers.hlo index e99d7146bebe78..434d6876f62c87 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_buffers.hlo +++ b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_buffers.hlo @@ -1,6 +1,6 @@ // RUN: hlo-opt %s --platform=CUDA --stage=buffer-assignment --xla_gpu_target_config_filename=%S/gpu_specs/a100.txtpb | FileCheck %s -HloModule m, is_scheduled=true +HloModule m add { a = f16[] parameter(0) From 020171654eee781714aad7a53bf1d73ec562963d Mon Sep 17 00:00:00 2001 From: hmonishN <143435143+hmonishN@users.noreply.github.com> Date: Mon, 20 Nov 2023 03:48:04 -0800 Subject: [PATCH 283/391] PR #7132: adding fix for local_client_aot_test target build failure Imported from GitHub PR https://github.com/openxla/xla/pull/7132 Fixing build failure for local_client_aot_test_target. 
Command: bazel test --config=cuda --jobs=150 --test_timeout=3600 --nocheck_visibility --test_output=streamed //xla/tests:local_client_aot_test ERROR: /xla/xla/tests/BUILD:230:14: Linking xla/tests/local_client_aot_test_helper [for tool] failed: (Exit 1): /usr/bin/ld: bazel-out/k8-opt-exec-50AE0418/bin/xla/stream_executor/cuda/libcuda_executor.lo(cuda_executor.o): in function `stream_executor::gpu::GpuExecutor::GetKernel(stream_executor::MultiKernelLoaderSpec const&, stream_executor::Kernel*)': cuda_executor.cc:(.text._ZN15stream_executor3gpu11GpuExecutor9GetKernelERKNS_21MultiKernelLoaderSpecEPNS_6KernelE+0x2b5): undefined reference to `stream_executor::CudaPtxInMemory::text(int, int) const' /usr/bin/ld: cuda_executor.cc:(.text._ZN15stream_executor3gpu11GpuExecutor9GetKernelERKNS_21MultiKernelLoaderSpecEPNS_6KernelE+0x4df): undefined reference to `stream_executor::CudaPtxInMemory::default_text() const' collect2: error: ld returned 1 exit status [9,413 / 9,416] 1 / 1 tests, 1 failed; checking cached actions Target //xla/tests:local_client_aot_test failed to build Fix: replaced stream_executor_headers with stream_executor in cuda_executor dependencies to create a dependency path from cuda_blas to kernel_spec. In addition, adding command line argument --define=framework_shared_object=false in bazel command will fix the error. Copybara import of the project: -- 44e16997435238f0155798206a476b2bc32563f4 by Harshit Monish : adding fix for local_client_aot_test_target Merging this change closes #7132 PiperOrigin-RevId: 583977122 --- third_party/xla/xla/stream_executor/cuda/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 6c187ce1781d42..6796364bd8005e 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -558,7 +558,7 @@ cc_library( "//xla/stream_executor:command_buffer", "//xla/stream_executor:kernel", "//xla/stream_executor:plugin_registry", - "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor:stream_executor", "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor/gpu:asm_compiler", "//xla/stream_executor/gpu:gpu_command_buffer", From 3a34c23733ea575b8c37ba43656de5ffb5f4c8e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Mon, 20 Nov 2023 03:59:17 -0800 Subject: [PATCH 284/391] [XLA:GPU] Fix type of pred parameters in Triton fusion The i8 storage type needs to be casted to the i1 type. I also had to fix the conversion of i1 to other int types. It is already covered by ir_emitter_triton_parametrized_test's ConvertFusionExecutesCorrectly/pred_s32. I also had to fix the comparison of i1 values. It is already covered by ir_emitter_triton_parametrized_test's CompareFusionExecutesCorrectly. 
PiperOrigin-RevId: 583979096 --- .../xla/xla/service/gpu/ir_emitter_triton.cc | 28 +++++--- .../xla/service/gpu/ir_emitter_triton_test.cc | 67 +++++++++++++++++++ 2 files changed, 87 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index caef52d7d8977e..b7ee8986d622d9 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -266,6 +266,9 @@ Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { dst_element_ty.isa()) { if (src_element_ty.getIntOrFloatBitWidth() < dst_element_ty.getIntOrFloatBitWidth()) { + if (src_element_ty.isInteger(1)) { + return b.create(dst_ty, value); + } return b.create(dst_ty, value); } return b.create(dst_ty, value); @@ -303,10 +306,12 @@ Value Subtract(ImplicitLocOpBuilder& b, ValueRange values) { Value Compare(ImplicitLocOpBuilder& b, ValueRange values, mlir::mhlo::ComparisonDirection direction) { - if (mlir::getElementTypeOrSelf(values[0]).isa()) { + const Type type = mlir::getElementTypeOrSelf(values[0]); + if (type.isa()) { return b.create( - mlir::mhlo::impl::getCmpPredicate(direction, - /*isSigned=*/true) + mlir::mhlo::impl::getCmpPredicate( + direction, + /*isSigned=*/!type.isInteger(1)) .value(), values[0], values[1]); } @@ -1394,11 +1399,18 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path, i < analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).size(); Side& side = is_lhs ? lhs : rhs; auto& values = is_lhs ? values_lhs : values_rhs; - CHECK(values - .insert({iter_args_to_parameters[i], - EmitParameterLoad(b, iter_args[i], - iter_args_to_boundary_checks[i])}) - .second); + + const HloInstruction* param_hlo = iter_args_to_parameters[i]; + Type param_ty = TritonType(b, param_hlo->shape().element_type()); + Type param_storage_ty = StorageType(b, param_ty); + Value param_value = + EmitParameterLoad(b, iter_args[i], iter_args_to_boundary_checks[i]); + if (param_ty != param_storage_ty) { + // For example cast i8 to i1. 
+ param_value = Cast(b, param_value, param_ty); + } + + CHECK(values.insert({param_hlo, param_value}).second); SmallVector increments; for (const DimProperties& dim : side.tiled_dims) { const TensorIterationSpec::DimIterationSpec* spec = analysis.IterSpec( diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 992bc3e6ff937e..f2f2db794bb266 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -647,6 +647,47 @@ CHECK: } tsl::testing::IsOkAndHolds(true)); } +TEST_F(TritonFilecheckTest, PredParametersAreTruncatedToI1) { + const std::string kHloText = R"( +HloModule m + +triton_gemm_computation { + p = pred[2,2]{1,0} parameter(0) + a = f32[2,2]{1,0} parameter(1) + b = f32[2,2]{1,0} parameter(2) + c = f32[2,2]{1,0} parameter(3) + compare = pred[2,2]{1,0} compare(a, b), direction=LT + and = pred[2,2]{1,0} and(p, compare) + convert = f32[2,2]{1,0} convert(and) + ROOT r = f32[2,2]{1,0} dot(convert, c), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + p = pred[2,2]{1,0} parameter(0) + a = f32[2,2]{1,0} parameter(1) + b = f32[2,2]{1,0} parameter(2) + c = f32[2,2]{1,0} parameter(3) + ROOT triton_gemm = f32[2,2]{1,0} fusion(p, a, b, c), kind=kCustom, + calls=triton_gemm_computation, + backend_config={kind: "__triton_gemm", + triton_gemm_config: { + "block_m":16,"block_n":16,"block_k":16, + "split_k":1,"num_stages":1,"num_warps":1 + } + } +} +)"; + TritonGemmConfig config(16, 16, 16, 1, 1, 1); + ASSERT_THAT(CreateTritonIrAndFileCheck(kHloText, config, EmitMatMul, + "triton_gemm_computation", R"( +CHECK: %[[LOAD:.*]] = tt.load %{{.*}} {{.*}} : !tt.ptr, 1> -> tensor<16x16xi8> +CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LOAD]] : tensor<16x16xi8> to tensor<16x16xi1> +CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1> +)"), + tsl::testing::IsOkAndHolds(true)); +} + TEST_F(TritonGemmTest, DoNotUseTensorCoresWithNonDefaultPrecision) { const std::string kHloText = R"( triton_gemm_r { @@ -2143,6 +2184,32 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } +TEST_F(TritonGemmLevel2Test, SupportPredParametersUsedInExpressions) { + const std::string kHloText = R"( +ENTRY e { + p = pred[2,2]{1,0} parameter(0) + a = f32[2,2]{1,0} parameter(1) + b = f32[2,2]{1,0} parameter(2) + c = f32[2,2]{1,0} parameter(3) + compare = pred[2,2]{1,0} compare(a, b), direction=LT + and = pred[2,2]{1,0} and(p, compare) + convert = f32[2,2]{1,0} convert(and) + ROOT r = f32[2,2]{1,0} dot(convert, c), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter(), + m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom))); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-5, /*arel=*/1e-5})); +} + TEST_F(TritonGemmTest, Naming) { const char* hlo_text = R"( HloModule t From c8abc1e0fd6b84c970dd3772fa9851eb968daef6 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 20 Nov 2023 04:23:42 -0800 Subject: [PATCH 285/391] Deduplicate algorithms in CreateOpRunners() We can have duplicates because we fetch algorithms using heuristics_mode_a and heuristics_mode_b. 
PiperOrigin-RevId: 583984086 --- third_party/xla/xla/stream_executor/cuda/BUILD | 3 ++- third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc | 9 +++++++++ third_party/xla/xla/stream_executor/dnn.h | 8 ++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 6796364bd8005e..b719a603d089af 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -1,8 +1,8 @@ # Description: # CUDA-platform specific StreamExecutor support code. -load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") load("//xla/tests:build_defs.bzl", "xla_test") +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") load( "//xla:xla.bzl", "xla_cc_test", @@ -390,6 +390,7 @@ cc_library( ":cuda_stream", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 4808f13cdde379..41a5440cc46f7d 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/base/optimization.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -8245,6 +8246,7 @@ tsl::Status CreateOpRunners( auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime(); out_runners->clear(); + absl::flat_hash_set algorithm_deduplication; for (int i = 0; i < filtered_configs.size(); i++) { auto plan = cudnn_frontend::ExecutionPlanBuilder() .setHandle(cudnn.handle()) @@ -8280,6 +8282,13 @@ tsl::Status CreateOpRunners( << runner_or.status(); continue; } + // We currently collect a list of algorithms using heuristics_mode_a and + // heuristics_mode_b, so we can potentially have duplicates. But we should + // not actually autotune the same algorithm twice! + if (!algorithm_deduplication.insert(runner_or->ToAlgorithmDesc().value()) + .second) { + continue; + } out_runners->push_back(std::make_unique>( std::move(runner_or).value())); diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h index 20d2534941bbbd..a181516e9b26e6 100644 --- a/third_party/xla/xla/stream_executor/dnn.h +++ b/third_party/xla/xla/stream_executor/dnn.h @@ -882,6 +882,9 @@ class AlgorithmDesc { uint64_t hash() const; + template + friend H AbslHashValue(H h, const AlgorithmDesc& algo_desc); + AlgorithmProto ToProto() const { return proto_; } std::string ToString() const; @@ -890,6 +893,11 @@ class AlgorithmDesc { AlgorithmProto proto_; }; +template +H AbslHashValue(H h, const AlgorithmDesc& algo_desc) { + return H::combine(std::move(h), algo_desc.hash()); +} + // Describes the result from a perf experiment. // // Arguments: From 2fec9872e1a32c3b2bc5ba351b21f7fc404fb2a4 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 20 Nov 2023 04:42:19 -0800 Subject: [PATCH 286/391] Priority fusion: cache HloFusionAnalyses. 
PiperOrigin-RevId: 583987557 --- third_party/xla/xla/service/gpu/BUILD | 1 - .../xla/service/gpu/hlo_fusion_analysis.cc | 123 +++++++++--------- .../xla/xla/service/gpu/hlo_fusion_analysis.h | 34 ++--- .../xla/service/gpu/kernel_mapping_scheme.h | 24 ++-- third_party/xla/xla/service/gpu/model/BUILD | 29 ----- .../gpu/model/fusion_analysis_cache.cc | 93 ------------- .../service/gpu/model/fusion_analysis_cache.h | 69 ---------- .../gpu/model/fusion_analysis_cache_test.cc | 115 ---------------- .../gpu/model/gpu_performance_model.cc | 36 +---- .../service/gpu/model/gpu_performance_model.h | 12 +- .../gpu/model/gpu_performance_model_test.cc | 2 +- .../xla/xla/service/gpu/priority_fusion.cc | 21 +-- .../xla/xla/service/gpu/priority_fusion.h | 9 +- .../xla/service/gpu/priority_fusion_test.cc | 8 +- 14 files changed, 110 insertions(+), 466 deletions(-) delete mode 100644 third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc delete mode 100644 third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h delete mode 100644 third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index ef176d7051ed6e..042acefef80e23 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2093,7 +2093,6 @@ cc_library( "//xla/service:hlo_cost_analysis", "//xla/service:hlo_pass", "//xla/service:instruction_fusion", - "//xla/service/gpu/model:fusion_analysis_cache", "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/service/gpu/model:gpu_performance_model", "//xla/stream_executor:device_description", diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index c064c8c9c1770c..bb2fe734a63055 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -58,35 +58,6 @@ const auto kDimX = TilingScheme::DimX; const auto kLinearIndexingX = TilingScheme::LinearIndexingX; const auto kStridedIndexingX = TilingScheme::StridedIndexingX; -std::optional ComputeTransposeTilingScheme( - const std::optional& tiled_transpose) { - if (!tiled_transpose) { - return std::nullopt; - } - - constexpr int kNumRows = 4; - static_assert(WarpSize() % kNumRows == 0); - - // 3D view over the input shape. - Vector3 dims = tiled_transpose->dimensions; - Vector3 order = tiled_transpose->permutation; - - Vector3 permuted_dims = {dims[order[0]], dims[order[1]], dims[order[2]]}; - Vector3 tile_sizes{1, 1, 1}; - tile_sizes[order[2]] = WarpSize() / kNumRows; - Vector3 num_threads{1, 1, WarpSize()}; - num_threads[order[2]] = kNumRows; - - return TilingScheme( - /*permuted_dims*/ permuted_dims, - /*tile_sizes=*/tile_sizes, - /*num_threads=*/num_threads, - /*indexing_order=*/kLinearIndexingX, - /*vector_size=*/1, - /*scaling_factor=*/1, - /*tiling_dimensions=*/{order[2], 2}); -} - // Returns true if `instr` is a non-strided slice. 
bool IsSliceWithUnitStrides(const HloInstruction* instr) { auto slice = DynCast(instr); @@ -286,28 +257,6 @@ std::optional FindConsistentTransposeHero( } // namespace -HloFusionAnalysis::HloFusionAnalysis( - FusionBackendConfig fusion_backend_config, - std::vector fusion_roots, - FusionBoundaryFn fusion_boundary_fn, - std::vector fusion_arguments, - std::vector fusion_heroes, - const se::DeviceDescription* device_info, - std::optional tiled_transpose, bool has_4_bit_input, - bool has_4_bit_output) - : fusion_backend_config_(std::move(fusion_backend_config)), - fusion_roots_(std::move(fusion_roots)), - fusion_boundary_fn_(std::move(fusion_boundary_fn)), - fusion_arguments_(std::move(fusion_arguments)), - fusion_heroes_(std::move(fusion_heroes)), - device_info_(device_info), - tiled_transpose_(tiled_transpose), - has_4_bit_input_(has_4_bit_input), - has_4_bit_output_(has_4_bit_output), - reduction_codegen_info_(ComputeReductionCodegenInfo(FindHeroReduction())), - transpose_tiling_scheme_(ComputeTransposeTilingScheme(tiled_transpose_)), - loop_fusion_config_(ComputeLoopFusionConfig()) {} - // static StatusOr HloFusionAnalysis::Create( FusionBackendConfig backend_config, @@ -404,7 +353,7 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() return EmitterFusionKind::kLoop; } -StatusOr HloFusionAnalysis::GetLaunchDimensions() const { +StatusOr HloFusionAnalysis::GetLaunchDimensions() { auto emitter_fusion_kind = GetEmitterFusionKind(); switch (emitter_fusion_kind) { case EmitterFusionKind::kLoop: { @@ -454,9 +403,7 @@ StatusOr HloFusionAnalysis::GetLaunchDimensions() const { } const HloInstruction* HloFusionAnalysis::FindHeroReduction() const { - if (GetEmitterFusionKind() != EmitterFusionKind::kReduction) { - return nullptr; - } + CHECK(GetEmitterFusionKind() == EmitterFusionKind::kReduction); auto roots = fusion_roots(); CHECK(!roots.empty()); // We always use the first reduce root that triggers unnested reduction @@ -471,8 +418,57 @@ const HloInstruction* HloFusionAnalysis::FindHeroReduction() const { LOG(FATAL) << "Did not find a hero reduction"; } -std::optional -HloFusionAnalysis::ComputeLoopFusionConfig() const { +const ReductionCodegenInfo* HloFusionAnalysis::GetReductionCodegenInfo() { + if (reduction_codegen_info_.has_value()) { + return &reduction_codegen_info_.value(); + } + + const HloInstruction* hero_reduction = FindHeroReduction(); + + auto reduction_codegen_info = ComputeReductionCodegenInfo(hero_reduction); + reduction_codegen_info_.emplace(std::move(reduction_codegen_info)); + return &reduction_codegen_info_.value(); +} + +const TilingScheme* HloFusionAnalysis::GetTransposeTilingScheme() { + if (transpose_tiling_scheme_.has_value()) { + return &transpose_tiling_scheme_.value(); + } + + if (!tiled_transpose_) { + return nullptr; + } + + constexpr int kNumRows = 4; + static_assert(WarpSize() % kNumRows == 0); + + // 3D view over the input shape. 
+ Vector3 dims = tiled_transpose_->dimensions; + Vector3 order = tiled_transpose_->permutation; + + Vector3 permuted_dims = {dims[order[0]], dims[order[1]], dims[order[2]]}; + Vector3 tile_sizes{1, 1, 1}; + tile_sizes[order[2]] = WarpSize() / kNumRows; + Vector3 num_threads{1, 1, WarpSize()}; + num_threads[order[2]] = kNumRows; + + TilingScheme tiling_scheme( + /*permuted_dims*/ permuted_dims, + /*tile_sizes=*/tile_sizes, + /*num_threads=*/num_threads, + /*indexing_order=*/kLinearIndexingX, + /*vector_size=*/1, + /*scaling_factor=*/1, + /*tiling_dimensions=*/{order[2], 2}); + transpose_tiling_scheme_.emplace(std::move(tiling_scheme)); + return &transpose_tiling_scheme_.value(); +} + +const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() { + if (loop_fusion_config_.has_value()) { + return &loop_fusion_config_.value(); + } + int unroll_factor = 1; // Unrolling is good to read large inputs with small elements // due to vector loads, but increases the register pressure when one @@ -505,7 +501,8 @@ HloFusionAnalysis::ComputeLoopFusionConfig() const { if (GetEmitterFusionKind() == EmitterFusionKind::kScatter) { // Only the unroll factor is used for scatter. - return LaunchDimensionsConfig{unroll_factor}; + loop_fusion_config_.emplace(LaunchDimensionsConfig{unroll_factor}); + return &loop_fusion_config_.value(); } bool row_vectorized; @@ -540,7 +537,8 @@ HloFusionAnalysis::ComputeLoopFusionConfig() const { launch_config.row_vectorized = false; launch_config.few_waves = false; } - return launch_config; + loop_fusion_config_.emplace(std::move(launch_config)); + return &loop_fusion_config_.value(); } const Shape& HloFusionAnalysis::GetElementShape() const { @@ -811,13 +809,8 @@ int HloFusionAnalysis::CalculateVirtualThreadScalingFactorForReduction( return 1; } -std::optional -HloFusionAnalysis::ComputeReductionCodegenInfo( +ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo( const HloInstruction* hero_reduction) const { - if (!hero_reduction) { - return std::nullopt; - } - Shape input_shape = hero_reduction->operand(0)->shape(); ReductionDimensions reduction_dimensions = GetReductionKindAndContiguousComponents(*hero_reduction); diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h index 1bec5ca650be47..c07819db2d3a15 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h @@ -68,27 +68,19 @@ class HloFusionAnalysis { // Determines the launch dimensions for the fusion. The fusion kind must not // be `kTriton`. - StatusOr GetLaunchDimensions() const; + StatusOr GetLaunchDimensions(); // Calculates the reduction information. Returns `nullptr` if the fusion is // not a reduction. - const ReductionCodegenInfo* GetReductionCodegenInfo() const { - return reduction_codegen_info_.has_value() ? &*reduction_codegen_info_ - : nullptr; - } + const ReductionCodegenInfo* GetReductionCodegenInfo(); // Calculates the transpose tiling information. Returns `nullptr` if the // fusion is not a transpose. - const TilingScheme* GetTransposeTilingScheme() const { - return transpose_tiling_scheme_.has_value() ? &*transpose_tiling_scheme_ - : nullptr; - } + const TilingScheme* GetTransposeTilingScheme(); // Calculates the loop fusion config. Returns `nullptr` if the fusion is not a // loop. - const LaunchDimensionsConfig* GetLoopFusionConfig() const { - return loop_fusion_config_.has_value() ? 
&*loop_fusion_config_ : nullptr; - } + const LaunchDimensionsConfig* GetLoopFusionConfig(); // Returns the hero reduction of the computation. const HloInstruction* FindHeroReduction() const; @@ -101,7 +93,16 @@ class HloFusionAnalysis { std::vector fusion_heroes, const se::DeviceDescription* device_info, std::optional tiled_transpose, - bool has_4_bit_input, bool has_4_bit_output); + bool has_4_bit_input, bool has_4_bit_output) + : fusion_backend_config_(std::move(fusion_backend_config)), + fusion_roots_(std::move(fusion_roots)), + fusion_boundary_fn_(std::move(fusion_boundary_fn)), + fusion_arguments_(std::move(fusion_arguments)), + fusion_heroes_(std::move(fusion_heroes)), + device_info_(device_info), + tiled_transpose_(tiled_transpose), + has_4_bit_input_(has_4_bit_input), + has_4_bit_output_(has_4_bit_output) {} const Shape& GetElementShape() const; int SmallestInputDtypeBits() const; @@ -117,9 +118,8 @@ class HloFusionAnalysis { bool reduction_is_race_free) const; int CalculateVirtualThreadScalingFactorForReduction( const ReductionDimensions& reduction_dimensions) const; - std::optional ComputeReductionCodegenInfo( + ReductionCodegenInfo ComputeReductionCodegenInfo( const HloInstruction* hero_reduction) const; - std::optional ComputeLoopFusionConfig() const; bool HasConsistentTransposeHeros() const; FusionBackendConfig fusion_backend_config_; @@ -131,8 +131,8 @@ class HloFusionAnalysis { std::vector fusion_heroes_; const se::DeviceDescription* device_info_; std::optional tiled_transpose_; - bool has_4_bit_input_ = false; - bool has_4_bit_output_ = false; + const bool has_4_bit_input_ = false; + const bool has_4_bit_output_ = false; std::optional reduction_codegen_info_; std::optional transpose_tiling_scheme_; diff --git a/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h b/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h index f7b51c42c6beaf..4a6f0f7ae3c6fa 100644 --- a/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h +++ b/third_party/xla/xla/service/gpu/kernel_mapping_scheme.h @@ -146,34 +146,34 @@ class TilingScheme { private: // The number of elements in each dimension. - Vector3 dims_in_elems_; + const Vector3 dims_in_elems_; // The number of elements for each dimension of a tile. - Vector3 tile_sizes_; + const Vector3 tile_sizes_; // The dimensions which are used for the shared memory tile. - Vector2 tiling_dimensions_; + const Vector2 tiling_dimensions_; // Number of threads implicitly assigned to each dimension. - Vector3 num_threads_; + const Vector3 num_threads_; - IndexingOrder indexing_order_; + const IndexingOrder indexing_order_; // Vector size for dimension X. - int vector_size_; + const int vector_size_; // Scaling apply to transform physical threadIdx into logical. 
- int64_t thread_id_virtual_scaling_ = 1; + const int64_t thread_id_virtual_scaling_ = 1; }; class ReductionCodegenInfo { public: using IndexGroups = std::vector>; - ReductionCodegenInfo(TilingScheme mapping_scheme, int num_partial_results, - bool is_row_reduction, bool is_race_free, - IndexGroups index_groups, - const HloInstruction* first_reduce) + explicit ReductionCodegenInfo(TilingScheme mapping_scheme, + int num_partial_results, bool is_row_reduction, + bool is_race_free, IndexGroups index_groups, + const HloInstruction* first_reduce) : tiling_scheme_(mapping_scheme), num_partial_results_(num_partial_results), is_row_reduction_(is_row_reduction), @@ -198,7 +198,7 @@ class ReductionCodegenInfo { private: friend class ReductionCodegenState; - TilingScheme tiling_scheme_; + const TilingScheme tiling_scheme_; int num_partial_results_; bool is_row_reduction_; bool is_race_free_; diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 034f2a2f20d2ff..1b542e7dca447a 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -66,34 +66,6 @@ xla_test( ], ) -cc_library( - name = "fusion_analysis_cache", - srcs = ["fusion_analysis_cache.cc"], - hdrs = ["fusion_analysis_cache.h"], - visibility = ["//visibility:public"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/synchronization", - ], -) - -xla_cc_test( - name = "fusion_analysis_cache_test", - srcs = ["fusion_analysis_cache_test.cc"], - deps = [ - ":fusion_analysis_cache", - "//xla/service:hlo_parser", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], -) - cc_library( name = "gpu_cost_model_stats_collection", srcs = ["gpu_cost_model_stats_collection.cc"], @@ -177,7 +149,6 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), visibility = ["//visibility:public"], deps = [ - ":fusion_analysis_cache", ":gpu_hlo_cost_analysis", "//xla:shape_util", "//xla/hlo/ir:hlo", diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc deleted file mode 100644 index 00a294413506ac..00000000000000 --- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xla/service/gpu/model/fusion_analysis_cache.h" - -#include "xla/hlo/ir/hlo_instruction.h" - -namespace xla::gpu { - -const std::optional& HloFusionAnalysisCache::Get( - const HloInstruction& instruction) { - { - absl::ReaderMutexLock lock(&mutex_); - auto it = analyses_.find(instruction.unique_id()); - if (it != analyses_.end()) { - return it->second; - } - } - - std::optional analysis = - AnalyzeFusion(instruction, device_info_); - absl::MutexLock lock(&mutex_); - - // If some other thread created an entry for this key concurrently, return - // that instead (the other thread is likely using the instance). - auto it = analyses_.find(instruction.unique_id()); - if (it != analyses_.end()) { - return it->second; - } - - return analyses_[instruction.unique_id()] = std::move(analysis); -} - -const std::optional& HloFusionAnalysisCache::Get( - const HloInstruction& producer, const HloInstruction& consumer) { - std::pair key{producer.unique_id(), consumer.unique_id()}; - { - absl::ReaderMutexLock lock(&mutex_); - auto it = producer_consumer_analyses_.find(key); - if (it != producer_consumer_analyses_.end()) { - return it->second; - } - } - - std::optional analysis = - AnalyzeProducerConsumerFusion(producer, consumer, device_info_); - absl::MutexLock lock(&mutex_); - - // If some other thread created an entry for this key concurrently, return - // that instead (the other thread is likely using the instance). - auto it = producer_consumer_analyses_.find(key); - if (it != producer_consumer_analyses_.end()) { - return it->second; - } - - producers_for_consumers_[consumer.unique_id()].push_back( - producer.unique_id()); - consumers_for_producers_[producer.unique_id()].push_back( - consumer.unique_id()); - return producer_consumer_analyses_[key] = std::move(analysis); -} - -void HloFusionAnalysisCache::Invalidate(const HloInstruction& instruction) { - absl::MutexLock lock(&mutex_); - analyses_.erase(instruction.unique_id()); - - if (auto consumers = - consumers_for_producers_.extract(instruction.unique_id())) { - for (const auto consumer : consumers.mapped()) { - producer_consumer_analyses_.erase({instruction.unique_id(), consumer}); - } - } - if (auto producers = - producers_for_consumers_.extract(instruction.unique_id())) { - for (const auto producer : producers.mapped()) { - producer_consumer_analyses_.erase({producer, instruction.unique_id()}); - } - } -} - -} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h deleted file mode 100644 index b13c0a102f3704..00000000000000 --- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ -#define XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ - -#include "absl/container/node_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/stream_executor/device_description.h" - -namespace xla::gpu { - -// Caches HloFusionAnalyses. Thread-compatible, if no threads concurrently `Get` -// and `Invalidate` the same key. Analyses are cached based on unique_ids, no -// checking or tracking of changes is done. -class HloFusionAnalysisCache { - public: - explicit HloFusionAnalysisCache( - const stream_executor::DeviceDescription& device_info) - : device_info_(device_info) {} - - // Returns the analysis for the given instruction, creating it if it doesn't - // exist yet. Do not call concurrently with `Invalidate` for the same key. - const std::optional& Get( - const HloInstruction& instruction); - - // Returns the analysis for the given producer/consumer pair. - const std::optional& Get(const HloInstruction& producer, - const HloInstruction& consumer); - - // Removes the cache entry for the given instruction, if it exists. Also - // removes all producer-consumer fusions that involve this instruction. - void Invalidate(const HloInstruction& instruction); - - private: - const stream_executor::DeviceDescription& device_info_; - - absl::Mutex mutex_; - -// All `int` keys and values here are unique instruction IDs. - absl::node_hash_map> analyses_; - absl::node_hash_map, std::optional> - producer_consumer_analyses_; - - // For each instruction `producer`, contains the `consumer`s for which we have - // entries {`producer`, `consumer`} in `producer_consumer_analyses_`. - absl::flat_hash_map> consumers_for_producers_; - // For each instruction `consumer`, contains the `producer`s for which we have - // entries {`producer`, `consumer`} in `producer_consumer_analyses_`. - absl::flat_hash_map> producers_for_consumers_; -}; - -} // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_MODEL_FUSION_ANALYSIS_CACHE_H_ diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc deleted file mode 100644 index edacd6a7c8666b..00000000000000 --- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xla/service/gpu/model/fusion_analysis_cache.h" - -#include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/hlo_parser.h" -#include "xla/stream_executor/device_description.h" -#include "xla/tests/hlo_test_base.h" - -namespace xla::gpu { -namespace { - -class FusionAnalysisCacheTest : public HloTestBase { - public: - stream_executor::DeviceDescription device_{ - TestGpuDeviceInfo::RTXA6000DeviceInfo()}; - HloFusionAnalysisCache cache_{device_}; -}; - -TEST_F(FusionAnalysisCacheTest, CachesAndInvalidates) { - absl::string_view hlo_string = R"( - HloModule m - - f { - c0 = f32[] constant(0) - b0 = f32[1000] broadcast(c0) - ROOT n0 = f32[1000] negate(b0) - } - - ENTRY e { - ROOT r.1 = f32[1000] fusion(), kind=kLoop, calls=f - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); - - auto* computation = module->GetComputationWithName("f"); - auto* broadcast = computation->GetInstructionWithName("b0"); - auto* negate = computation->GetInstructionWithName("n0"); - auto* fusion = module->entry_computation()->root_instruction(); - - EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), - ::testing::ElementsAre(negate)); - - computation->set_root_instruction(broadcast); - - EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), - ::testing::ElementsAre(negate)) - << "Analysis should be cached."; - - cache_.Invalidate(*fusion); - EXPECT_THAT(cache_.Get(*fusion)->fusion_roots(), - ::testing::ElementsAre(broadcast)) - << "Analysis should have been recomputed"; -} - -TEST_F(FusionAnalysisCacheTest, CachesAndInvalidatesProducerConsumerFusions) { - absl::string_view hlo_string = R"( - HloModule m - - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - - f { - c0 = f32[] constant(0) - b0 = f32[1000] broadcast(c0) - ROOT r0 = f32[] reduce(b0, c0), dimensions={0}, to_apply=add - } - - ENTRY e { - f0 = f32[] fusion(), kind=kInput, calls=f - ROOT n0 = f32[] negate(f0) - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); - - auto* fusion = module->entry_computation()->GetInstructionWithName("f0"); - auto* neg = module->entry_computation()->GetInstructionWithName("n0"); - - auto* computation = module->GetComputationWithName("f"); - auto* constant = computation->GetInstructionWithName("c0"); - - EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), - HloFusionAnalysis::EmitterFusionKind::kReduction); - - computation->set_root_instruction(constant); - - EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), - HloFusionAnalysis::EmitterFusionKind::kReduction) - << "Analysis should be cached."; - - cache_.Invalidate(*fusion); - EXPECT_EQ(cache_.Get(*fusion, *neg)->GetEmitterFusionKind(), - HloFusionAnalysis::EmitterFusionKind::kLoop) - << "Analysis should have been recomputed"; -} - -} // namespace -} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index c9acce7f65c3ca..64e9132d5ce507 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -224,7 +224,7 @@ float GetMaxSysBwFromGpu(const se::CudaComputeCapability cc, // that the IR emitter will use. 
LaunchDimensions EstimateFusionLaunchDimensions( int64_t estimated_num_threads, - const std::optional& fusion_analysis, + std::optional& fusion_analysis, const se::DeviceDescription& device_info) { if (fusion_analysis) { // TODO(jreiffers): This is the wrong place for this DUS analysis. @@ -269,15 +269,7 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( int64_t bytes_written = cost_analysis->output_bytes_accessed(*instr); int64_t bytes_read = cost_analysis->bytes_accessed(*instr) - bytes_written; - // Use the analysis cache if present. - // TODO(jreiffers): Remove this once all callers use a cache. - std::optional local_analysis = - config.fusion_analysis_cache - ? std::nullopt - : AnalyzeFusion(*instr, *cost_analysis->device_info_); - const auto& fusion_analysis = config.fusion_analysis_cache - ? config.fusion_analysis_cache->Get(*instr) - : local_analysis; + auto fusion_analysis = AnalyzeFusion(*instr, *cost_analysis->device_info_); LaunchDimensions launch_dimensions = EstimateFusionLaunchDimensions( ShapeUtil::ElementsInRecursive(instr->shape()), fusion_analysis, *device_info); @@ -349,7 +341,7 @@ float GetCommonUtilization( const GpuHloCostAnalysis* cost_analysis, const se::DeviceDescription& gpu_device_info, int64_t num_blocks, const HloInstruction* producer, - const std::optional& fusion_analysis, + std::optional& fusion_analysis, const GpuPerformanceModelOptions& config, const HloInstruction* fused_consumer) { absl::Duration ret = absl::ZeroDuration(); @@ -438,16 +430,7 @@ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( float utilization_by_this_consumer = cost_analysis->operand_utilization( *fused_consumer, fused_consumer->operand_index(producer)); - // Use the analysis cache if present. - // TODO(jreiffers): Remove this once all callers use a cache. - std::optional local_analysis = - config.fusion_analysis_cache - ? std::nullopt - : AnalyzeFusion(*fused_consumer, *device_info); - const auto& analysis_unfused = - config.fusion_analysis_cache - ? config.fusion_analysis_cache->Get(*fused_consumer) - : local_analysis; + auto analysis_unfused = AnalyzeFusion(*fused_consumer, *device_info); LaunchDimensions launch_dimensions_unfused = EstimateFusionLaunchDimensions( ShapeUtil::ElementsInRecursive(fused_consumer->shape()), @@ -496,15 +479,8 @@ absl::Duration GpuPerformanceModel::EstimateFusedExecTime( // // TODO(shyshkov): Add calculations for consumer epilogue in the formula to // make it complete. - std::optional local_analysis_fused = - config.fusion_analysis_cache - ? std::nullopt - : AnalyzeProducerConsumerFusion(*producer, *fused_consumer, - *device_info); - const auto& analysis_fused = - config.fusion_analysis_cache - ? config.fusion_analysis_cache->Get(*producer, *fused_consumer) - : local_analysis_fused; + auto analysis_fused = + AnalyzeProducerConsumerFusion(*producer, *fused_consumer, *device_info); LaunchDimensions launch_dimensions_fused = EstimateFusionLaunchDimensions( producer_data.num_threads * utilization_by_this_consumer, diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h index 0fcc8cfcb2abf2..b7b28fff1eeda7 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.h @@ -22,7 +22,6 @@ limitations under the License. 
#include "absl/time/time.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/stream_executor/device_description.h" @@ -63,25 +62,20 @@ struct GpuPerformanceModelOptions { // re-reads can happen from cache. bool first_read_from_dram = false; - // If present, use this to retrieve fusion analyses. - HloFusionAnalysisCache* fusion_analysis_cache = nullptr; - static GpuPerformanceModelOptions Default() { return GpuPerformanceModelOptions(); } - static GpuPerformanceModelOptions PriorityFusion( - HloFusionAnalysisCache* fusion_analysis_cache) { + static GpuPerformanceModelOptions PriorityFusion() { GpuPerformanceModelOptions config; config.consider_coalescing = true; config.first_read_from_dram = true; - config.fusion_analysis_cache = fusion_analysis_cache; return config; } static GpuPerformanceModelOptions ForModule(const HloModule* module) { return module->config().debug_options().xla_gpu_enable_priority_fusion() - ? PriorityFusion(nullptr) // Only cache within priority fusion. + ? PriorityFusion() : Default(); } }; @@ -127,7 +121,7 @@ class GpuPerformanceModel { const GpuHloCostAnalysis* cost_analysis, const se::DeviceDescription& gpu_device_info, int64_t num_blocks, const HloInstruction* producer, - const std::optional& fusion_analysis, + std::optional& fusion_analysis, const GpuPerformanceModelOptions& config, const HloInstruction* fused_consumer = nullptr); }; diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc index d768bce08c55ef..68bde4b9010382 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc @@ -360,7 +360,7 @@ ENTRY fusion { std::vector consumers{ module->entry_computation()->GetInstructionWithName("reduce.1")}; GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - producer, &analysis_, GpuPerformanceModelOptions::PriorityFusion(nullptr), + producer, &analysis_, GpuPerformanceModelOptions::PriorityFusion(), consumers); EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 105, 10); diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 1c2e7c93970e30..8df808c1bfad82 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -42,7 +42,6 @@ limitations under the License. 
#include "xla/service/fusion_queue.h" #include "xla/service/gpu/fusion_process_dump.pb.h" #include "xla/service/gpu/gpu_fusible.h" -#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model.h" #include "xla/service/instruction_fusion.h" @@ -80,14 +79,12 @@ class GpuPriorityFusionQueue : public FusionQueue { const GpuHloCostAnalysis::Options& cost_analysis_options, const se::DeviceDescription* device_info, const CanFuseCallback& can_fuse, FusionProcessDumpProto* fusion_process_dump, - tsl::thread::ThreadPool* thread_pool, - HloFusionAnalysisCache& fusion_analysis_cache) + tsl::thread::ThreadPool* thread_pool) : computation_(computation), cost_analysis_(cost_analysis_options, device_info), can_fuse_(can_fuse), fusion_process_dump_(fusion_process_dump), - thread_pool_(thread_pool), - fusion_analysis_cache_(fusion_analysis_cache) { + thread_pool_(thread_pool) { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); TF_CHECK_OK(computation_->Accept(&cost_analysis_)); @@ -184,9 +181,6 @@ class GpuPriorityFusionQueue : public FusionQueue { fusion_step->set_consumer_name(std::string(original_consumer->name())); } - fusion_analysis_cache_.Invalidate(*fusion); - fusion_analysis_cache_.Invalidate(*original_producer); - // The original consumer was replaced with the fusion, but it's pointer can // still be referenced somewhere, for example, in to_update_priority_. // Priority recomputation is called before DCE. Remove all references to @@ -264,7 +258,6 @@ class GpuPriorityFusionQueue : public FusionQueue { void RemoveInstruction(HloInstruction* instruction) override { to_update_priority_.erase(instruction); producer_user_count_.erase(instruction); - fusion_analysis_cache_.Invalidate(*instruction); auto reverse_it = reverse_map_.find(instruction); if (reverse_it == reverse_map_.end()) { @@ -296,8 +289,7 @@ class GpuPriorityFusionQueue : public FusionQueue { GpuPerformanceModel::RunTimes run_times = GpuPerformanceModel::EstimateRunTimes( producer, &cost_analysis_, - GpuPerformanceModelOptions::PriorityFusion(&fusion_analysis_cache_), - producer->users()); + GpuPerformanceModelOptions::PriorityFusion(), producer->users()); if (fusion_process_dump_) { absl::MutexLock lock(&fusion_process_dump_mutex_); auto* step = @@ -373,8 +365,6 @@ class GpuPriorityFusionQueue : public FusionQueue { absl::Mutex fusion_process_dump_mutex_; tsl::thread::ThreadPool* thread_pool_; - - HloFusionAnalysisCache& fusion_analysis_cache_; }; } // namespace @@ -512,7 +502,8 @@ HloInstruction::FusionKind GpuPriorityFusion::ChooseKind( // matter but some passes downstream still query these instead of fusion // analysis. // TODO: Don't recompute this all the time. 
- const auto& analysis = fusion_analysis_cache_.Get(*producer, *consumer); + auto analysis = + AnalyzeProducerConsumerFusion(*producer, *consumer, device_info_); if (!analysis) return HloInstruction::FusionKind::kLoop; switch (analysis->GetEmitterFusionKind()) { case HloFusionAnalysis::EmitterFusionKind::kLoop: @@ -553,7 +544,7 @@ std::unique_ptr GpuPriorityFusion::GetFusionQueue( [this](HloInstruction* consumer, int64_t operand_index) { return ShouldFuse(consumer, operand_index); }, - fusion_process_dump_.get(), thread_pool_, fusion_analysis_cache_)); + fusion_process_dump_.get(), thread_pool_)); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/priority_fusion.h b/third_party/xla/xla/service/gpu/priority_fusion.h index afc5e8f99003d4..1723766d7784c8 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/priority_fusion.h @@ -29,7 +29,6 @@ limitations under the License. #include "xla/service/fusion_node_indexing_evaluation.h" #include "xla/service/fusion_queue.h" #include "xla/service/gpu/fusion_process_dump.pb.h" -#include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_interface.h" @@ -43,13 +42,12 @@ namespace gpu { class GpuPriorityFusion : public InstructionFusion { public: GpuPriorityFusion(tsl::thread::ThreadPool* thread_pool, - const se::DeviceDescription& device, + const se::DeviceDescription& d, GpuHloCostAnalysis::Options cost_analysis_options) : InstructionFusion(GpuPriorityFusion::IsExpensive), thread_pool_(thread_pool), - device_info_(device), - cost_analysis_options_(std::move(cost_analysis_options)), - fusion_analysis_cache_(device_info_) {} + device_info_(d), + cost_analysis_options_(std::move(cost_analysis_options)) {} absl::string_view name() const override { return "priority-fusion"; } @@ -88,7 +86,6 @@ class GpuPriorityFusion : public InstructionFusion { absl::Mutex fusion_node_evaluations_mutex_; absl::flat_hash_map fusion_node_evaluations_; - HloFusionAnalysisCache fusion_analysis_cache_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/priority_fusion_test.cc index 310340af91d391..5574173a75ef11 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion_test.cc @@ -237,10 +237,10 @@ TEST_F(PriorityFusionTest, ReductionEpilogueFusionRegressionTest) { } )"; - EXPECT_THAT(RunAndGetFusionKinds(kHlo), - ::testing::UnorderedElementsAre( - HloFusionAnalysis::EmitterFusionKind::kLoop, - HloFusionAnalysis::EmitterFusionKind::kReduction)); + EXPECT_THAT( + RunAndGetFusionKinds(kHlo), + ::testing::ElementsAre(HloFusionAnalysis::EmitterFusionKind::kLoop, + HloFusionAnalysis::EmitterFusionKind::kReduction)); RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_), R"( CHECK: ENTRY From 50f2ad9373419bbf638682078dc16906f60571e1 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 20 Nov 2023 05:09:46 -0800 Subject: [PATCH 287/391] [XLA:GPU][NFC] Rename kParameterPerScope to kParameterPerDotScope. Adds a comment explaining how the constant was set, and makes it clear that it is not necessarily applicable for fusions that do not involve GEMMs. 
PiperOrigin-RevId: 583993985 --- third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc | 7 ++++--- third_party/xla/xla/service/gpu/gemm_rewriter_triton.h | 5 ++++- .../xla/xla/service/gpu/gemm_rewriter_triton_test.cc | 4 ++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc index 5477e0f0fd8bd0..5cb2e99e5dedc1 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc @@ -1239,7 +1239,7 @@ void FusionContext::TryToFuseWithInputsRecursively( to_visit.pop(); // Watch the total number of fusion parameters. if (inputs.size() + NumAddedParameters(*hlo) > - TritonFusionAnalysis::kMaxParameterPerScope) { + TritonFusionAnalysis::kMaxParameterPerDotScope) { // Re-queue: the number of parameters may go down when other instructions // are processed. to_visit.push(hlo); @@ -1323,9 +1323,10 @@ StatusOr FuseDot(HloInstruction& dot, gpu_version, old_to_new_mapping, fusion_inputs, builder); const int new_parameters = fusion_inputs.size() - operand_count_before; - TF_RET_CHECK(new_parameters <= TritonFusionAnalysis::kMaxParameterPerScope) + TF_RET_CHECK(new_parameters <= + TritonFusionAnalysis::kMaxParameterPerDotScope) << "Too many new parameters: " << new_parameters << " > " - << TritonFusionAnalysis::kMaxParameterPerScope; + << TritonFusionAnalysis::kMaxParameterPerDotScope; return context; }; diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.h b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.h index db52fa54aadb31..2f15dadaa883a0 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.h +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.h @@ -131,7 +131,10 @@ class TritonFusionAnalysis { // Every parameter requires a separate piece of shared memory for asynchronous // loads. Multiple parameters are approximately equivalent to multiple // pipeline stages. - static constexpr int kMaxParameterPerScope = 4; + // Note: this has been tuned specifically for GEMMs, where pipelining with + // more than 4 stages has been shown to rarely be practical. This limitation + // is not necessarily applicable to other operations. + static constexpr int kMaxParameterPerDotScope = 4; // Scope -> HLO -> dot dimension number -> iteration spec at the HLO's output. 
const TensorIterationSpec::DimIterationSpec* IterSpec(Scope scope, diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc index 7a78d9523db1b2..dc9c8032880819 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc @@ -1121,7 +1121,7 @@ ENTRY e { EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), HloInstruction::FusionKind::kCustom); EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), - TritonFusionAnalysis::kMaxParameterPerScope * 2); + TritonFusionAnalysis::kMaxParameterPerDotScope * 2); } TEST_F(GemmRewriterTritonLevel2Test, @@ -1149,7 +1149,7 @@ ENTRY e { EXPECT_EQ(module->entry_computation()->root_instruction()->fusion_kind(), HloInstruction::FusionKind::kCustom); EXPECT_LE(module->entry_computation()->root_instruction()->operand_count(), - TritonFusionAnalysis::kMaxParameterPerScope + 1); + TritonFusionAnalysis::kMaxParameterPerDotScope + 1); } TEST_F(GemmRewriterTritonLevel2Test, From 3cfc3f3c98003c9cd57f9b49a9c2dac7b48f5fbc Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Mon, 20 Nov 2023 05:20:18 -0800 Subject: [PATCH 288/391] [XLA:GPU] Use the same coalescing heuristics for fused and unfused case. PiperOrigin-RevId: 583995859 --- .../gpu/model/gpu_performance_model.cc | 54 ++++++++++++------- .../xla/service/gpu/priority_fusion_test.cc | 4 +- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index 64e9132d5ce507..4c2947f094288a 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -257,6 +257,35 @@ LaunchDimensions EstimateFusionLaunchDimensions( return LaunchDimensions(num_blocks, block_size); } +// Returns true if all input reads are coalesced. If consumer is not nullptr, +// producer and consumer are considered as one fusion, otherwise it's only the +// producer. +// +// This is a crude heuristic until we get proper tile analysis. +bool IsReadCoalesced(const std::optional& fusion_analysis, + const GpuPerformanceModelOptions& config, + const HloInstruction* producer, + const HloInstruction* consumer = nullptr) { + if (!config.consider_coalescing) return true; + + bool coalesced = (fusion_analysis && + fusion_analysis->GetEmitterFusionKind() == + HloFusionAnalysis::EmitterFusionKind::kTranspose) || + (!TransposesMinorDimension(producer) && + !(consumer && TransposesMinorDimension(consumer))); + + if (consumer) { + // Fusing two row reductions breaks coalescing. 
+ coalesced &= (fusion_analysis && + fusion_analysis->GetEmitterFusionKind() != + HloFusionAnalysis::EmitterFusionKind::kReduction) || + !IsInputFusibleReduction(*producer) || + !IsInputFusibleReduction(*consumer); + } + + return coalesced; +} + } // namespace /*static*/ EstimateRunTimeData @@ -347,9 +376,7 @@ float GetCommonUtilization( absl::Duration ret = absl::ZeroDuration(); float producer_output_utilization = 1.f; ConstHloInstructionMap consumer_operands; - bool consumer_transposes = false; if (fused_consumer) { - consumer_transposes = TransposesMinorDimension(fused_consumer); producer_output_utilization = cost_analysis->operand_utilization( *fused_consumer, fused_consumer->operand_index(producer)); for (int64_t i = 0; i < fused_consumer->operand_count(); ++i) { @@ -357,7 +384,9 @@ float GetCommonUtilization( } } - bool producer_transposes = TransposesMinorDimension(producer); + // TODO(jreiffers): We should be checking each operand. + bool coalesced = + IsReadCoalesced(fusion_analysis, config, producer, fused_consumer); for (int i = 0; i < producer->operand_count(); ++i) { // Information about data read taking into account utilization. // If `operand_utilization` is 0, `operand_bytes_accessed` should be also 0. @@ -381,25 +410,13 @@ float GetCommonUtilization( /*producer_idx_of_operand=*/i, fused_consumer, consumer_operands, cost_analysis); - // TODO(jreiffers): We should be checking each operand here. - bool coalesced = (fusion_analysis && - fusion_analysis->GetEmitterFusionKind() == - HloFusionAnalysis::EmitterFusionKind::kTranspose) || - (!producer_transposes && !consumer_transposes); - // Fusing two row reductions breaks coalescing. - coalesced &= ((fusion_analysis && - fusion_analysis->GetEmitterFusionKind() != - HloFusionAnalysis::EmitterFusionKind::kReduction) || - !fused_consumer || !IsInputFusibleReduction(*producer) || - !IsInputFusibleReduction(*fused_consumer)); const auto& operand_shape = producer->operand(i)->shape(); CHECK_LE(common_utilization, producer_output_utilization); float n_bytes_total = operand_bytes_accessed * (producer_output_utilization - common_utilization); ret += ReadTime(gpu_device_info, num_blocks, /*n_bytes_net=*/n_bytes_net, - n_bytes_total, operand_shape.element_type(), - coalesced || !config.consider_coalescing, + n_bytes_total, operand_shape.element_type(), coalesced, config.first_read_from_dram); } return ret; @@ -440,10 +457,11 @@ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( utilization_by_this_consumer); int64_t n_bytes_net = std::min(producer_data.bytes_written, n_bytes_total); + bool coalesced = + IsReadCoalesced(analysis_unfused, config, /*producer=*/fused_consumer); auto read_time_unfused = ReadTime( *device_info, launch_dimensions_unfused.num_blocks(), n_bytes_net, - n_bytes_total, fused_consumer->shape().element_type(), - /*coalesced=*/!TransposesMinorDimension(fused_consumer), + n_bytes_total, fused_consumer->shape().element_type(), coalesced, config.first_read_from_dram); VLOG(10) << " Read time unfused: " << read_time_unfused; diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/priority_fusion_test.cc index 5574173a75ef11..f8b6b1b1486cc9 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion_test.cc @@ -318,8 +318,8 @@ TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) { using Kind = HloFusionAnalysis::EmitterFusionKind; EXPECT_THAT(RunAndGetFusionKinds(kHlo), ::testing::UnorderedElementsAre( - 
Kind::kReduction, Kind::kReduction, Kind::kTranspose, - Kind::kTranspose, Kind::kTranspose)); + Kind::kLoop, Kind::kReduction, Kind::kReduction, + Kind::kTranspose, Kind::kTranspose, Kind::kTranspose)); } TEST_F(PriorityFusionTest, DoNotFuseReduceIntoReduce) { From 04603093f7f023c255eefd135aa3788110363400 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Mon, 20 Nov 2023 05:53:02 -0800 Subject: [PATCH 289/391] [XLA:GPU] Remove unnecessary local variable. PiperOrigin-RevId: 584001557 --- third_party/xla/xla/service/gpu/gpu_compiler.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index ff0294cafa5baf..8f78780c107da9 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1686,13 +1686,10 @@ StatusOr> GpuCompiler::RunBackend( std::vector allocations; if (res.compile_module_results.use_original_allocations) { if (!options.is_autotuning_compilation) { - std::vector original_allocations = - buffer_assignment->ReleaseAllocations(); - allocations = std::move(original_allocations); + allocations = buffer_assignment->ReleaseAllocations(); } else { - std::vector original_allocations = + allocations = res.compile_module_results.buffer_assignment->ReleaseAllocations(); - allocations = std::move(original_allocations); } } else { allocations = std::move(res.compile_module_results.allocations); From f0b420f37a82e61aa33255287933f364fb02d7c2 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Mon, 20 Nov 2023 06:47:11 -0800 Subject: [PATCH 290/391] [XLA:GPU] Add GetOperandUtilization helper. (NFC) Having separate ConstHloInstructionMap is unnecessary and likely inefficient. PiperOrigin-RevId: 584012601 --- .../gpu/model/gpu_performance_model.cc | 56 ++++++++++--------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index 4c2947f094288a..8df7598fb8483a 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -325,6 +325,19 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( return {flops, bytes_written, num_threads, write_time, exec_time}; } +// Returns utilization of operand by instruction. Returns 0, if the operand is +// not used by the instruction. +float GetOperandUtilization(const GpuHloCostAnalysis* cost_analysis, + const HloInstruction* instr, + const HloInstruction* operand) { + if (!instr->IsUserOf(operand)) { + return 0.f; + } + + return cost_analysis->operand_utilization(*instr, + instr->operand_index(operand)); +} + // Returns utilization `overlap` between a common operand of producer and // consumer on merge. `utilization > 0` means that the operand will be accessed // more efficiently after fusion. @@ -334,14 +347,13 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( // a fusion or just be an elementwise instruction. // 2) Consumer has to have common elementwise roots for the producer and the // common operand if it is a fusion or just be an elementwise instruction. 
-float GetCommonUtilization( - const HloInstruction* producer, int64_t producer_idx_of_operand, - const HloInstruction* consumer, - const ConstHloInstructionMap& consumer_operands, - const GpuHloCostAnalysis* cost_analysis) { - auto consumer_idx_of_operand = - consumer_operands.find(producer->operand(producer_idx_of_operand)); - if (consumer_idx_of_operand == consumer_operands.end()) { +float GetCommonUtilization(const GpuHloCostAnalysis* cost_analysis, + const HloInstruction* producer, + int64_t producer_idx_of_operand, + const HloInstruction* consumer) { + const auto* operand = producer->operand(producer_idx_of_operand); + + if (!consumer || !consumer->IsUserOf(operand)) { return 0.f; } @@ -350,9 +362,10 @@ float GetCommonUtilization( FusionUsesParameterElementwiseFromRoot(producer, producer_idx_of_operand, cost_analysis))) { if (consumer->opcode() == HloOpcode::kFusion) { - int64_t consumer_idx_of_producer = consumer_operands.at(producer); + int64_t consumer_idx_of_common_operand = consumer->operand_index(operand); + int64_t consumer_idx_of_producer = consumer->operand_index(producer); return cost_analysis->CommonElementwiseUtilization( - consumer->fused_parameter(consumer_idx_of_operand->second), + consumer->fused_parameter(consumer_idx_of_common_operand), consumer->fused_parameter(consumer_idx_of_producer)); } else { if (consumer->IsElementwise()) { @@ -374,15 +387,10 @@ float GetCommonUtilization( const GpuPerformanceModelOptions& config, const HloInstruction* fused_consumer) { absl::Duration ret = absl::ZeroDuration(); - float producer_output_utilization = 1.f; - ConstHloInstructionMap consumer_operands; - if (fused_consumer) { - producer_output_utilization = cost_analysis->operand_utilization( - *fused_consumer, fused_consumer->operand_index(producer)); - for (int64_t i = 0; i < fused_consumer->operand_count(); ++i) { - consumer_operands[fused_consumer->operand(i)] = i; - } - } + float producer_output_utilization = + fused_consumer + ? GetOperandUtilization(cost_analysis, fused_consumer, producer) + : 1.f; // TODO(jreiffers): We should be checking each operand. bool coalesced = @@ -405,10 +413,8 @@ float GetCommonUtilization( // Look if common operand of producer and consumer will be accessed more // efficiently on merge. - float common_utilization = - GetCommonUtilization(producer, - /*producer_idx_of_operand=*/i, fused_consumer, - consumer_operands, cost_analysis); + float common_utilization = GetCommonUtilization( + cost_analysis, producer, /*producer_idx_of_operand=*/i, fused_consumer); const auto& operand_shape = producer->operand(i)->shape(); @@ -444,8 +450,8 @@ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( for (const HloInstruction* fused_consumer : fused_consumers) { VLOG(8) << "Unfused consumer: " << fused_consumer->name(); - float utilization_by_this_consumer = cost_analysis->operand_utilization( - *fused_consumer, fused_consumer->operand_index(producer)); + float utilization_by_this_consumer = + GetOperandUtilization(cost_analysis, fused_consumer, producer); auto analysis_unfused = AnalyzeFusion(*fused_consumer, *device_info); From 14717f4b6110d9e4a1289404faa638b9a46ce2f0 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Mon, 20 Nov 2023 07:59:36 -0800 Subject: [PATCH 291/391] [xla:gpu] Extend LMHLO to GPU runtime lowering to the no_parallel_custom_call attribute. Also extend the GPU runtime for NCCL to include this attribute. 
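Ops that do not define the attribute are handled with a pair of enable_if overloads that fall back to false, and the result is attached to the lowered custom call as a boolean attribute. Below is a minimal, self-contained sketch of that overload dispatch; SendOp and AllReduceOp are placeholder structs made up for this illustration, not the actual LMHLO op classes, and in the real pass the condition ranges over the set of collective ops that carry the attribute rather than a single type.

#include <iostream>
#include <type_traits>

// Placeholder op types, made up for illustration (not the real LMHLO ops).
struct SendOp {
  bool getNoParallelCustomCall() const { return true; }
};
struct AllReduceOp {};  // Carries no such attribute.

// Ops without the attribute: default to false.
template <typename OpT>
std::enable_if_t<std::is_same_v<OpT, AllReduceOp>, bool> noParallelCustomCall(
    OpT) {
  return false;
}

// Ops with the attribute: read it from the op.
template <typename OpT>
std::enable_if_t<std::is_same_v<OpT, SendOp>, bool> noParallelCustomCall(
    OpT op) {
  return op.getNoParallelCustomCall();
}

int main() {
  std::cout << noParallelCustomCall(SendOp{}) << "\n";       // prints 1
  std::cout << noParallelCustomCall(AllReduceOp{}) << "\n";  // prints 0
  return 0;
}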
PiperOrigin-RevId: 584027010 --- .../gpu/transforms/lmhlo_to_gpu_runtime.cc | 15 ++++ .../xla/service/gpu/runtime/collectives.cc | 87 ++++++++++++------- 2 files changed, 72 insertions(+), 30 deletions(-) diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc index 266370e34c7f25..593b81ac2dcc61 100644 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc +++ b/third_party/xla/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc @@ -791,6 +791,18 @@ class CollectiveOpLowering : public OpRewritePattern { return op.getIsSync(); } + template + static typename std::enable_if_t, bool> + noParallelCustomCall(OpT) { + return false; + } + + template + static typename std::enable_if_t, bool> + noParallelCustomCall(OpT op) { + return op.getNoParallelCustomCall(); + } + // For async collective erase all corresponding done operations. template void eraseDoneOp(PatternRewriter& rewriter, CollectiveOp op) const { @@ -913,6 +925,9 @@ class CollectiveOpLowering : public OpRewritePattern { bool is_async = !getIsSync(op); call->setAttr(b.getStringAttr("is_async"), b.getBoolAttr(is_async)); + call->setAttr(b.getStringAttr("no_parallel_custom_call"), + b.getBoolAttr(noParallelCustomCall(op))); + // If the collective will not execute asynchronously, erase the associated // done op. if (!is_async) { diff --git a/third_party/xla/xla/service/gpu/runtime/collectives.cc b/third_party/xla/xla/service/gpu/runtime/collectives.cc index fd5ffa470b671e..809da0d431c63f 100644 --- a/third_party/xla/xla/service/gpu/runtime/collectives.cc +++ b/third_party/xla/xla/service/gpu/runtime/collectives.cc @@ -232,6 +232,7 @@ absl::Status P2PImplCommon(const ServiceExecutableRunOptions* run_options, const DebugOptions* debug_options, se::Stream* stream, CustomCall::RemainingArgs args, int64_t group_mode, int64_t op_id, + bool no_parallel_custom_call, absl::Span replica_group_offsets, absl::Span replica_group_values, absl::Span source_peers, @@ -239,6 +240,7 @@ absl::Status P2PImplCommon(const ServiceExecutableRunOptions* run_options, NcclP2PRunner runner, DeviceBuffersGetter device_buffers_getter, uint64_t stream_id) { + (void)no_parallel_custom_call; NcclExecuteParams params(*run_options, stream->parent()); const std::string device_string = @@ -287,12 +289,14 @@ absl::Status CollectivePermuteImpl( const DebugOptions* debug_options, CollectivesSupport* collectives, AsyncCollectivesSupport* async_collectives, CustomCall::RemainingArgs args, int32_t uid, int64_t group_mode, int64_t op_id, bool is_async, + bool no_parallel_custom_call, absl::Span replica_group_offsets, absl::Span replica_group_values, absl::Span source_peers, absl::Span target_peers) { #if XLA_ENABLE_XCCL - VLOG(3) << "Running CollectivePermute " << (is_async ? "(Async)" : "(Sync)"); + VLOG(3) << "Running CollectivePermute " << (is_async ? 
"(Async) " : "(Sync) ") + << no_parallel_custom_call; return RunSyncOrAsync( run_options, collectives, async_collectives, uid, is_async, [&](se::Stream* stream) { @@ -302,10 +306,10 @@ absl::Status CollectivePermuteImpl( return NcclMockImplCommon(stream); } return P2PImplCommon(run_options, debug_options, stream, args, - group_mode, op_id, replica_group_offsets, - replica_group_values, source_peers, target_peers, - RunCollectivePermute, GetDeviceBufferPairs, - GetStreamId(is_async)); + group_mode, op_id, no_parallel_custom_call, + replica_group_offsets, replica_group_values, + source_peers, target_peers, RunCollectivePermute, + GetDeviceBufferPairs, GetStreamId(is_async)); }); #else // XLA_ENABLE_XCCL return absl::InternalError("NCCL disabled"); @@ -324,6 +328,7 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( .Attr("group_mode") // CollectiveOpGroupMode .Attr("op_id") .Attr("is_async") + .Attr("no_parallel_custom_call") .Attr>("replica_group_offsets") .Attr>("replica_group_values") .Attr>("source_peers") @@ -339,7 +344,7 @@ static absl::Status P2PSendImpl(const ServiceExecutableRunOptions* run_options, AsyncCollectivesSupport* async_collectives, CustomCall::RemainingArgs args, int32_t uid, int64_t group_mode, int64_t op_id, - bool is_async, + bool is_async, bool no_parallel_custom_call, absl::Span replica_group_offsets, absl::Span replica_group_values, absl::Span source_peers, @@ -351,9 +356,10 @@ static absl::Status P2PSendImpl(const ServiceExecutableRunOptions* run_options, run_options, collectives, async_collectives, uid, is_async, [&](se::Stream* stream) { return P2PImplCommon(run_options, debug_options, stream, args, - group_mode, op_id, replica_group_offsets, - replica_group_values, source_peers, target_peers, - RunSend, GetSingleArgAsDeviceBufferPair, + group_mode, op_id, no_parallel_custom_call, + replica_group_offsets, replica_group_values, + source_peers, target_peers, RunSend, + GetSingleArgAsDeviceBufferPair, GetStreamId(is_async, kAsyncStreamP2P)); }, kAsyncStreamP2P); @@ -374,6 +380,7 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( .Attr("group_mode") // CollectiveOpGroupMode .Attr("op_id") .Attr("is_async") + .Attr("no_parallel_custom_call") .Attr>("replica_group_offsets") .Attr>("replica_group_values") .Attr>("source_peers") @@ -389,7 +396,7 @@ static absl::Status P2PRecvImpl(const ServiceExecutableRunOptions* run_options, AsyncCollectivesSupport* async_collectives, CustomCall::RemainingArgs args, int32_t uid, int64_t group_mode, int64_t op_id, - bool is_async, + bool is_async, bool no_parallel_custom_call, absl::Span replica_group_offsets, absl::Span replica_group_values, absl::Span source_peers, @@ -401,9 +408,10 @@ static absl::Status P2PRecvImpl(const ServiceExecutableRunOptions* run_options, run_options, collectives, async_collectives, uid, is_async, [&](se::Stream* stream) { return P2PImplCommon(run_options, debug_options, stream, args, - group_mode, op_id, replica_group_offsets, - replica_group_values, source_peers, target_peers, - RunRecv, GetSingleArgAsDeviceBufferPair, + group_mode, op_id, no_parallel_custom_call, + replica_group_offsets, replica_group_values, + source_peers, target_peers, RunRecv, + GetSingleArgAsDeviceBufferPair, GetStreamId(is_async, kAsyncStreamP2P)); }, kAsyncStreamP2P); @@ -424,6 +432,7 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( .Attr("group_mode") // CollectiveOpGroupMode .Attr("op_id") .Attr("is_async") + .Attr("no_parallel_custom_call") .Attr>("replica_group_offsets") .Attr>("replica_group_values") .Attr>("source_peers") @@ -439,9 +448,10 @@ absl::Status 
AllGatherImplCommon( const DebugOptions* debug_options, se::Stream* stream, CustomCall::RemainingArgs args, int64_t group_mode, int64_t op_id, absl::Span replica_group_offsets, - absl::Span replica_group_values, bool is_async) { + absl::Span replica_group_values, bool is_async, + bool no_parallel_custom_call) { NcclExecuteParams params(*run_options, stream->parent()); - + (void)no_parallel_custom_call; TF_ASSIGN_OR_RETURN( auto comm, GetNcclComm(params, group_mode, op_id, replica_group_offsets, @@ -462,10 +472,12 @@ absl::Status AllGatherImpl(const ServiceExecutableRunOptions* run_options, AsyncCollectivesSupport* async_collectives, CustomCall::RemainingArgs args, int32_t uid, int64_t group_mode, int64_t op_id, bool is_async, + bool no_parallel_custom_call, absl::Span replica_group_offsets, absl::Span replica_group_values) { #if XLA_ENABLE_XCCL - VLOG(3) << "Running AllGather " << (is_async ? "(Async)" : "(Sync)"); + VLOG(3) << "Running AllGather " << (is_async ? "(Async) " : "(Sync) ") + << no_parallel_custom_call; return RunSyncOrAsync( run_options, collectives, async_collectives, uid, is_async, [&](se::Stream* stream) { @@ -476,7 +488,8 @@ absl::Status AllGatherImpl(const ServiceExecutableRunOptions* run_options, } return AllGatherImplCommon(run_options, debug_options, stream, args, group_mode, op_id, replica_group_offsets, - replica_group_values, is_async); + replica_group_values, is_async, + no_parallel_custom_call); }); #else // XLA_ENABLE_XCCL return absl::InternalError("NCCL diasbled"); @@ -495,6 +508,7 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( .Attr("group_mode") // CollectiveOpGroupMode .Attr("op_id") .Attr("is_async") + .Attr("no_parallel_custom_call") .Attr>("replica_group_offsets") .Attr>("replica_group_values")); @@ -508,7 +522,9 @@ absl::Status AllReduceImplCommon( const DebugOptions* debug_options, se::Stream* stream, CustomCall::RemainingArgs args, int64_t group_mode, int64_t op_id, int64_t reduction_kind, absl::Span replica_group_offsets, - absl::Span replica_group_values, bool is_async) { + absl::Span replica_group_values, bool is_async, + bool no_parallel_custom_call) { + (void)no_parallel_custom_call; NcclExecuteParams params(*run_options, stream->parent()); TF_ASSIGN_OR_RETURN( @@ -533,11 +549,12 @@ absl::Status AllReduceImpl(const ServiceExecutableRunOptions* run_options, AsyncCollectivesSupport* async_collectives, CustomCall::RemainingArgs args, int32_t uid, int64_t group_mode, int64_t op_id, bool is_async, - int64_t reduction_kind, + bool no_parallel_custom_call, int64_t reduction_kind, absl::Span replica_group_offsets, absl::Span replica_group_values) { #if XLA_ENABLE_XCCL - VLOG(3) << "Running AllReduce " << (is_async ? "(Async)" : "(Sync)"); + VLOG(3) << "Running AllReduce " << (is_async ? "(Async) " : "(Sync) ") + << no_parallel_custom_call; return RunSyncOrAsync( run_options, collectives, async_collectives, uid, is_async, [&](se::Stream* stream) { @@ -549,7 +566,7 @@ absl::Status AllReduceImpl(const ServiceExecutableRunOptions* run_options, return AllReduceImplCommon(run_options, debug_options, stream, args, group_mode, op_id, reduction_kind, replica_group_offsets, replica_group_values, - is_async); + is_async, no_parallel_custom_call); }); #else // XLA_ENABLE_XCCL // NCCL disabled. 
@@ -569,6 +586,7 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( .Attr("group_mode") // CollectiveOpGroupMode .Attr("op_id") .Attr("is_async") + .Attr("no_parallel_custom_call") .Attr("reduction_kind") // ReductionKind .Attr>("replica_group_offsets") .Attr>("replica_group_values")); @@ -586,7 +604,8 @@ absl::Status AllToAllImplCommon(const ServiceExecutableRunOptions* run_options, int64_t op_id, absl::Span replica_group_offsets, absl::Span replica_group_values, - bool is_async) { + bool is_async, bool no_parallel_custom_call) { + (void)no_parallel_custom_call; NcclExecuteParams params(*run_options, stream->parent()); TF_ASSIGN_OR_RETURN( @@ -611,10 +630,12 @@ absl::Status AllToAllImpl(const ServiceExecutableRunOptions* run_options, CustomCall::RemainingArgs args, int32_t uid, int64_t group_mode, bool has_split_dimension, int64_t op_id, bool is_async, + bool no_parallel_custom_call, absl::Span replica_group_offsets, absl::Span replica_group_values) { #if XLA_ENABLE_XCCL - VLOG(3) << "Running AllToAll " << (is_async ? "(Async)" : "(Sync)"); + VLOG(3) << "Running AllToAll " << (is_async ? "(Async) " : "(Sync) ") + << no_parallel_custom_call; return RunSyncOrAsync( run_options, collectives, async_collectives, uid, is_async, [&](se::Stream* stream) { @@ -626,7 +647,7 @@ absl::Status AllToAllImpl(const ServiceExecutableRunOptions* run_options, return AllToAllImplCommon(run_options, debug_options, stream, args, group_mode, has_split_dimension, op_id, replica_group_offsets, replica_group_values, - is_async); + is_async, no_parallel_custom_call); }); #else // XLA_ENABLE_XCCL return absl::InternalError("NCCL disabled"); @@ -646,6 +667,7 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( .Attr("has_split_dimension") .Attr("op_id") .Attr("is_async") + .Attr("no_parallel_custom_call") .Attr>("replica_group_offsets") .Attr>("replica_group_values")); @@ -659,7 +681,9 @@ absl::Status ReduceScatterImplCommon( const DebugOptions* debug_options, se::Stream* stream, CustomCall::RemainingArgs args, int64_t group_mode, int64_t op_id, int64_t reduction_kind, absl::Span replica_group_offsets, - absl::Span replica_group_values, bool is_async) { + absl::Span replica_group_values, bool is_async, + bool no_parallel_custom_call) { + (void)no_parallel_custom_call; NcclExecuteParams params(*run_options, stream->parent()); TF_ASSIGN_OR_RETURN( @@ -684,11 +708,13 @@ absl::Status ReduceScatterImpl(const ServiceExecutableRunOptions* run_options, AsyncCollectivesSupport* async_collectives, CustomCall::RemainingArgs args, int32_t uid, int64_t group_mode, int64_t op_id, bool is_async, + bool no_parallel_custom_call, int64_t reduction_kind, absl::Span replica_group_offsets, absl::Span replica_group_values) { #if XLA_ENABLE_XCCL - VLOG(3) << "Running ReduceScatter " << (is_async ? "(Async)" : "(Sync)"); + VLOG(3) << "Running ReduceScatter " << (is_async ? 
"(Async) " : "(Sync) ") + << no_parallel_custom_call; return RunSyncOrAsync( run_options, collectives, async_collectives, uid, is_async, [&](se::Stream* stream) { @@ -697,10 +723,10 @@ absl::Status ReduceScatterImpl(const ServiceExecutableRunOptions* run_options, if (gpu_opts && gpu_opts->enable_mock_nccl_collectives()) { return NcclMockImplCommon(stream); } - return ReduceScatterImplCommon(run_options, debug_options, stream, args, - group_mode, op_id, reduction_kind, - replica_group_offsets, - replica_group_values, is_async); + return ReduceScatterImplCommon( + run_options, debug_options, stream, args, group_mode, op_id, + reduction_kind, replica_group_offsets, replica_group_values, + is_async, no_parallel_custom_call); }); #else // XLA_ENABLE_XCCL return absl::InternalError("NCCL disabled"); @@ -719,6 +745,7 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( .Attr("group_mode") // CollectiveOpGroupMode .Attr("op_id") .Attr("is_async") + .Attr("no_parallel_custom_call") .Attr("reduction_kind") // ReductionKind .Attr>("replica_group_offsets") .Attr>("replica_group_values")); From 4f0fe34879a53dc300dc61471b99898e9cfb8466 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 20 Nov 2023 07:59:37 -0800 Subject: [PATCH 292/391] Disable MMA_V3 in triton by default PiperOrigin-RevId: 584027016 --- third_party/triton/b311157761.patch | 64 +++++++++++++++++++ third_party/triton/workspace.bzl | 1 + .../xla/third_party/triton/b311157761.patch | 64 +++++++++++++++++++ .../xla/third_party/triton/workspace.bzl | 1 + 4 files changed, 130 insertions(+) create mode 100644 third_party/triton/b311157761.patch create mode 100644 third_party/xla/third_party/triton/b311157761.patch diff --git a/third_party/triton/b311157761.patch b/third_party/triton/b311157761.patch new file mode 100644 index 00000000000000..b03fa04c142a42 --- /dev/null +++ b/third_party/triton/b311157761.patch @@ -0,0 +1,64 @@ +diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp +--- a/include/triton/Tools/Sys/GetEnv.hpp ++++ b/include/triton/Tools/Sys/GetEnv.hpp +@@ -30,6 +30,7 @@ + namespace triton { + + const std::set ENV_VARS = { ++ "ENABLE_MMA_V3", + "DISABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", + "ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP", + "AMDGCN_ENABLE_DUMP"}; +diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp +--- a/lib/Analysis/Utility.cpp ++++ b/lib/Analysis/Utility.cpp +@@ -394,7 +394,8 @@ bool supportMMA(triton::DotOp op, int version) { + auto aElemTy = op.getA().getType().cast().getElementType(); + auto bElemTy = op.getB().getType().cast().getElementType(); + if (version == 3) { +- if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) ++ // TODO(b/311157761): enable mma_v3 ++ if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) + return false; + auto retType = op.getResult().getType().cast(); + auto retShapePerCTA = triton::gpu::getShapePerCTA(retType); +diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +--- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp ++++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +@@ -40,7 +40,8 @@ public: + // Only insert fences for compute capability 9.0 + if (computeCapability < 90) + return; +- if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) ++ // TODO(b/311157761): enable mma_v3 ++ if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) + return; + ModuleOp mod = getOperation(); + mod.walk([&](Operation *op) { 
+diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir +--- a/test/Conversion/tritongpu_to_llvm_hopper.mlir ++++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir +@@ -1,4 +1,4 @@ +-// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s ++// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s + + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1]}> + #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1], hasLeadingOffset = true}> +diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir +--- a/test/TritonGPU/accelerate-matmul.mlir ++++ b/test/TritonGPU/accelerate-matmul.mlir +@@ -1,4 +1,4 @@ +-// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s ++// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s + + // CHECK: #[[MMA:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}> + // CHECK: #[[MMA1:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir +--- a/test/TritonGPU/fence-inserstion.mlir ++++ b/test/TritonGPU/fence-inserstion.mlir +@@ -1,4 +1,4 @@ +-// RUN: triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s ++// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> + #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> + #mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> \ No newline at end of file diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 3795c89bb75563..bd4c52b02d48a5 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -15,6 +15,7 @@ def repo(): urls = tf_mirror_urls("https://github.com/openxla/triton/archive/{commit}.tar.gz".format(commit = TRITON_COMMIT)), # For temporary changes which haven't landed upstream yet. 
patch_file = [ + "//third_party/triton:b311157761.patch", "//third_party/triton:cl568176943.patch", "//third_party/triton:b304456327.patch", ], diff --git a/third_party/xla/third_party/triton/b311157761.patch b/third_party/xla/third_party/triton/b311157761.patch new file mode 100644 index 00000000000000..b03fa04c142a42 --- /dev/null +++ b/third_party/xla/third_party/triton/b311157761.patch @@ -0,0 +1,64 @@ +diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp +--- a/include/triton/Tools/Sys/GetEnv.hpp ++++ b/include/triton/Tools/Sys/GetEnv.hpp +@@ -30,6 +30,7 @@ + namespace triton { + + const std::set ENV_VARS = { ++ "ENABLE_MMA_V3", + "DISABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", + "ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP", + "AMDGCN_ENABLE_DUMP"}; +diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp +--- a/lib/Analysis/Utility.cpp ++++ b/lib/Analysis/Utility.cpp +@@ -394,7 +394,8 @@ bool supportMMA(triton::DotOp op, int version) { + auto aElemTy = op.getA().getType().cast().getElementType(); + auto bElemTy = op.getB().getType().cast().getElementType(); + if (version == 3) { +- if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) ++ // TODO(b/311157761): enable mma_v3 ++ if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) + return false; + auto retType = op.getResult().getType().cast(); + auto retShapePerCTA = triton::gpu::getShapePerCTA(retType); +diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +--- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp ++++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +@@ -40,7 +40,8 @@ public: + // Only insert fences for compute capability 9.0 + if (computeCapability < 90) + return; +- if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) ++ // TODO(b/311157761): enable mma_v3 ++ if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) + return; + ModuleOp mod = getOperation(); + mod.walk([&](Operation *op) { +diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir +--- a/test/Conversion/tritongpu_to_llvm_hopper.mlir ++++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir +@@ -1,4 +1,4 @@ +-// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s ++// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s + + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1]}> + #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1], hasLeadingOffset = true}> +diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir +--- a/test/TritonGPU/accelerate-matmul.mlir ++++ b/test/TritonGPU/accelerate-matmul.mlir +@@ -1,4 +1,4 @@ +-// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s ++// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s + + // CHECK: #[[MMA:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}> + // CHECK: #[[MMA1:.+]] = 
#triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir +--- a/test/TritonGPU/fence-inserstion.mlir ++++ b/test/TritonGPU/fence-inserstion.mlir +@@ -1,4 +1,4 @@ +-// RUN: triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s ++// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> + #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> + #mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> \ No newline at end of file diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index 3795c89bb75563..bd4c52b02d48a5 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -15,6 +15,7 @@ def repo(): urls = tf_mirror_urls("https://github.com/openxla/triton/archive/{commit}.tar.gz".format(commit = TRITON_COMMIT)), # For temporary changes which haven't landed upstream yet. patch_file = [ + "//third_party/triton:b311157761.patch", "//third_party/triton:cl568176943.patch", "//third_party/triton:b304456327.patch", ], From b9893b23307c57f83843a57dd5ce6a4b5306b4ab Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Mon, 20 Nov 2023 08:06:46 -0800 Subject: [PATCH 293/391] Integrate StableHLO at openxla/stablehlo@444c72d6 PiperOrigin-RevId: 584028946 --- third_party/stablehlo/temporary.patch | 322 +++++++++++------- third_party/stablehlo/workspace.bzl | 4 +- .../xla/third_party/stablehlo/temporary.patch | 322 +++++++++++------- .../xla/third_party/stablehlo/workspace.bzl | 4 +- 4 files changed, 396 insertions(+), 256 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index c90f940606d5f1..814417d1966d90 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,40 +1,30 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel --- stablehlo/BUILD.bazel +++ stablehlo/BUILD.bazel -@@ -289,10 +289,10 @@ - strip_include_prefix = ".", - deps = [ - ":interpreter_ops_inc_gen", -- ":reference_value", - ":reference_numpy", - ":reference_ops", - ":reference_process_grid", -+ ":reference_value", - "@llvm-project//llvm:Support", +@@ -375,12 +375,23 @@ + ":linalg_pass_inc_gen", + ":stablehlo_ops", + ":chlo_ops", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:ArithDialect", ++ "@llvm-project//mlir:BufferizationDialect", ++ "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", -@@ -465,10 +465,10 @@ - strip_include_prefix = ".", - deps = [ - ":reference_axes", -+ ":reference_configuration", - ":reference_element", - ":reference_errors", - ":reference_index", -- ":reference_configuration", - ":reference_process", - ":reference_process_grid", - ":reference_scope", -@@ -965,8 +965,8 @@ + "@llvm-project//mlir:LinalgDialect", + 
"@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:LinalgUtils", ++ "@llvm-project//mlir:MathDialect", ++ "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", ++ "@llvm-project//mlir:SCFDialect", ++ "@llvm-project//mlir:ShapeDialect", ++ "@llvm-project//mlir:SparseTensorDialect", ++ "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorDialect", ], - deps = [ - ":interpreter_ops", -+ ":reference_api", - ":reference_errors", -- ":reference_api", - ":reference_ops", - ":reference_process_grid", - ":reference_scope", diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt --- stablehlo/CMakeLists.txt +++ stablehlo/CMakeLists.txt @@ -193,15 +183,127 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists add_subdirectory(integrations) add_subdirectory(reference) add_subdirectory(tests) -diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Base.h ---- stablehlo/stablehlo/dialect/Base.h -+++ stablehlo/stablehlo/dialect/Base.h -@@ -371,4 +371,4 @@ - } // namespace hlo - } // namespace mlir +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/CMakeLists.txt b/stablehlo/stablehlo/conversions/linalg/transforms/CMakeLists.txt +--- stablehlo/stablehlo/conversions/linalg/transforms/CMakeLists.txt ++++ stablehlo/stablehlo/conversions/linalg/transforms/CMakeLists.txt +@@ -21,7 +21,25 @@ + Core + + LINK_LIBS PUBLIC ++ LLVMSupport ++ MLIRArithDialect ++ MLIRBufferizationDialect ++ MLIRComplexDialect ++ MLIRFuncDialect + MLIRIR ++ MLIRLinalgDialect ++ MLIRLinalgTransforms ++ MLIRLinalgUtils ++ MLIRMathDialect ++ MLIRMemRefDialect + MLIRPass ++ MLIRPass ++ MLIRSCFDialect ++ MLIRShapeDialect ++ MLIRSparseTensorDialect ++ MLIRSupport ++ MLIRTensorDialect + MLIRTransforms ++ MLIRTransforms ++ MLIRVectorDialect + ) +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/PassDetail.h b/stablehlo/stablehlo/conversions/linalg/transforms/PassDetail.h +--- stablehlo/stablehlo/conversions/linalg/transforms/PassDetail.h ++++ stablehlo/stablehlo/conversions/linalg/transforms/PassDetail.h +@@ -1,3 +1,19 @@ ++/* Copyright 2022 The IREE Authors ++ Copyright 2023 OpenXLA Authors. All Rights Reserved. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ + #ifndef STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSDETAIL_H + #define STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSDETAIL_H + +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/Passes.h b/stablehlo/stablehlo/conversions/linalg/transforms/Passes.h +--- stablehlo/stablehlo/conversions/linalg/transforms/Passes.h ++++ stablehlo/stablehlo/conversions/linalg/transforms/Passes.h +@@ -1,3 +1,19 @@ ++/* Copyright 2022 The IREE Authors ++ Copyright 2023 OpenXLA Authors. All Rights Reserved. 
++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ + #ifndef STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSES_H + #define STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSES_H + +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/Passes.td b/stablehlo/stablehlo/conversions/linalg/transforms/Passes.td +--- stablehlo/stablehlo/conversions/linalg/transforms/Passes.td ++++ stablehlo/stablehlo/conversions/linalg/transforms/Passes.td +@@ -1,3 +1,19 @@ ++/* Copyright 2022 The IREE Authors ++ Copyright 2023 OpenXLA Authors. All Rights Reserved. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ + #ifndef STABLEHLO_TO_LINALG_PASSES + #define STABLEHLO_TO_LINALG_PASSES --#endif -+#endif // STABLEHLO_DIALECT_BASE_H +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +--- stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp ++++ stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +@@ -1,3 +1,19 @@ ++/* Copyright 2022 The IREE Authors ++ Copyright 2023 OpenXLA Authors. All Rights Reserved. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. 
++==============================================================================*/ ++ + #include "PassDetail.h" + #include "mlir/Dialect/Arith/IR/Arith.h" + #include "mlir/Dialect/Bufferization/IR/Bufferization.h" diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp --- stablehlo/stablehlo/dialect/TypeInference.cpp +++ stablehlo/stablehlo/dialect/TypeInference.cpp @@ -3888,41 +3990,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir -diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/BUILD.bazel ---- stablehlo/stablehlo/tests/BUILD.bazel -+++ stablehlo/stablehlo/tests/BUILD.bazel -@@ -40,6 +40,7 @@ - - gentbl_cc_library( - name = "check_ops_inc_gen", -+ strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-op-decls"], -@@ -52,7 +53,6 @@ - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "CheckOps.td", -- strip_include_prefix = ".", - deps = [ - ":check_ops_td_files", - ], -@@ -140,6 +140,7 @@ - [ - lit_test( - name = "%s.test" % src, -+ size = "small", - srcs = [src], - data = [ - "lit.cfg.py", -@@ -149,7 +150,6 @@ - "@llvm-project//llvm:FileCheck", - "@llvm-project//llvm:not", - ] + glob(["%s.bc" % src]), -- size = "small", - tags = ["stablehlo_tests"], - ) - for src in glob(["**/*.mlir"]) diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/tests/ops_stablehlo.mlir --- stablehlo/stablehlo/tests/ops_stablehlo.mlir +++ stablehlo/stablehlo/tests/ops_stablehlo.mlir @@ -3962,69 +4029,72 @@ diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/ diff --ruN a/stablehlo/stablehlo/tests/print_stablehlo.mlir b/stablehlo/stablehlo/tests/print_stablehlo.mlir --- stablehlo/stablehlo/tests/print_stablehlo.mlir +++ stablehlo/stablehlo/tests/print_stablehlo.mlir -@@ -1,5 +1,5 @@ --// RUN: stablehlo-opt %s | FileCheck %s --// RUN: stablehlo-opt %s | stablehlo-opt | FileCheck %s -+// RUN: stablehlo-opt %s --split-input-file | FileCheck %s -+// RUN: stablehlo-opt %s --split-input-file | stablehlo-opt --split-input-file | FileCheck %s +@@ -1,5 +1,34 @@ + // RUN: stablehlo-opt %s | FileCheck %s + // RUN: stablehlo-opt %s | stablehlo-opt | FileCheck %s ++ ++// Test encodings first since aliases are printed at top of file. 
++#CSR = #sparse_tensor.encoding<{ ++ map = (d0, d1) -> (d0 : dense, d1 : compressed) ++}> ++ ++#DCSR = #sparse_tensor.encoding<{ ++ map = (d0, d1) -> (d0 : compressed, d1 : compressed) ++}> ++ ++// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> ++// CHECK: #[[$DCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> ++// CHECK-LABEL: func @encodings ++func.func @encodings(%arg0: tensor<10x20xf32, #CSR>, ++ %arg1: tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> { ++ // CHECK: %0 = stablehlo.add %arg0, %arg1 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$DCSR]]>) -> tensor<10x20xf32> ++ // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<10x20xf32, #[[$DCSR]]> ++ // CHECK-NEXT: %2 = stablehlo.abs %arg0 : (tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xf32> ++ // CHECK-NEXT: %3 = stablehlo.abs %arg0 : tensor<10x20xf32, #[[$CSR]]> ++ // CHECK-NEXT: %4 = stablehlo.complex %arg0, %arg0 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xcomplex> ++ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<10x20xf32, #CSR>, ++ tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> ++ %1 = "stablehlo.add"(%arg1, %arg1) : (tensor<10x20xf32, #DCSR>, ++ tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32, #DCSR> ++ %2 = "stablehlo.abs"(%arg0) : (tensor<10x20xf32, #CSR>) -> tensor<10x20xf32> ++ %3 = "stablehlo.abs"(%arg0) : (tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #CSR> ++ %4 = "stablehlo.complex"(%arg0, %arg0) : (tensor<10x20xf32, #CSR>, tensor<10x20xf32, #CSR>) -> tensor<10x20xcomplex> ++ func.return %0 : tensor<10x20xf32> ++} // CHECK-LABEL: func @zero_input func.func @zero_input() -> !stablehlo.token { -@@ -291,6 +291,8 @@ +@@ -291,32 +320,6 @@ "stablehlo.return"() : () -> () } -+// ----- -+ - #CSR = #sparse_tensor.encoding<{ - map = (d0, d1) -> (d0 : dense, d1 : compressed) - }> -@@ -299,14 +301,16 @@ - map = (d0, d1) -> (d0 : compressed, d1 : compressed) - }> - -+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> -+// CHECK: #[[$DCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> - // CHECK-LABEL: func @encodings - func.func @encodings(%arg0: tensor<10x20xf32, #CSR>, - %arg1: tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> { +-#CSR = #sparse_tensor.encoding<{ +- map = (d0, d1) -> (d0 : dense, d1 : compressed) +-}> +- +-#DCSR = #sparse_tensor.encoding<{ +- map = (d0, d1) -> (d0 : compressed, d1 : compressed) +-}> +- +-// CHECK-LABEL: func @encodings +-func.func @encodings(%arg0: tensor<10x20xf32, #CSR>, +- %arg1: tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> { - // CHECK: %0 = stablehlo.add %arg0, %arg1 : (tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>, tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>>) -> tensor<10x20xf32> - // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>> - // CHECK-NEXT: %2 = stablehlo.abs %arg0 : (tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) -> tensor<10x20xf32> - // CHECK-NEXT: %3 = stablehlo.abs %arg0 : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>> - // CHECK-NEXT: %4 = stablehlo.complex %arg0, %arg0 : (tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : 
compressed) }>>, tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) -> tensor<10x20xcomplex> -+ // CHECK: %0 = stablehlo.add %arg0, %arg1 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$DCSR]]>) -> tensor<10x20xf32> -+ // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<10x20xf32, #[[$DCSR]]> -+ // CHECK-NEXT: %2 = stablehlo.abs %arg0 : (tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xf32> -+ // CHECK-NEXT: %3 = stablehlo.abs %arg0 : tensor<10x20xf32, #[[$CSR]]> -+ // CHECK-NEXT: %4 = stablehlo.complex %arg0, %arg0 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xcomplex> - %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<10x20xf32, #CSR>, - tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> - %1 = "stablehlo.add"(%arg1, %arg1) : (tensor<10x20xf32, #DCSR>, -@@ -316,6 +320,8 @@ - %4 = "stablehlo.complex"(%arg0, %arg0) : (tensor<10x20xf32, #CSR>, tensor<10x20xf32, #CSR>) -> tensor<10x20xcomplex> - func.return %0 : tensor<10x20xf32> - } -+ -+// ----- - +- %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<10x20xf32, #CSR>, +- tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> +- %1 = "stablehlo.add"(%arg1, %arg1) : (tensor<10x20xf32, #DCSR>, +- tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32, #DCSR> +- %2 = "stablehlo.abs"(%arg0) : (tensor<10x20xf32, #CSR>) -> tensor<10x20xf32> +- %3 = "stablehlo.abs"(%arg0) : (tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #CSR> +- %4 = "stablehlo.complex"(%arg0, %arg0) : (tensor<10x20xf32, #CSR>, tensor<10x20xf32, #CSR>) -> tensor<10x20xcomplex> +- func.return %0 : tensor<10x20xf32> +-} +- func.func @dot_general(%arg0: tensor<2x2x2xi8>, %arg1: tensor<2x2x3xi8>, %arg2: tensor<2x2xi8>, %arg3: tensor<2x3xi8>) -> tensor<2x2x3xi32> { // CHECK: {{%.*}} = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<2x2x2xi8>, tensor<2x2x3xi8>) -> tensor<2x2x3xi32> -diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h ---- stablehlo/stablehlo/transforms/Passes.h -+++ stablehlo/stablehlo/transforms/Passes.h -@@ -25,6 +25,7 @@ - - namespace mlir { - namespace stablehlo { -+ - #define GEN_PASS_DECL_STABLEHLOCANONICALIZEDYNAMISMPASS - #define GEN_PASS_DECL_STABLEHLOLEGALIZETOVHLOPASS - #define GEN_PASS_DECL_STABLEHLOREFINESHAPESPASS -@@ -66,4 +67,4 @@ - } // namespace stablehlo - } // namespace mlir - --#endif // STABLEHLO_DIALECT_VHLO_OPS_H -+#endif // STABLEHLO_TRANSFORMS_PASSES_H + // CHECK-NEXT: {{%.*}} = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x2x2xi8>, tensor<2x2x3xi8>) -> tensor<2x2x3xi32> diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 1383c2f5634f7b..df25c4b7ca88f7 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "c3e1ab5af5f4f3ab749072ffa960348f90d5e832" - STABLEHLO_SHA256 = "9a421322b2176d1226e95472823cffb7752d1948513e18af28d5c6a27d0f35ca" + STABLEHLO_COMMIT = "444c72d6de4e2b3c7738457708ea863d5fb7f0e4" + STABLEHLO_SHA256 = "9f8623138f30212d57e047cc29b6d892c8fc5884c5bd4a919441c24ca3785c42" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index ceb3e74472fc74..d1125545a22f9b 
100644 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1,40 +1,30 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel --- stablehlo/BUILD.bazel +++ stablehlo/BUILD.bazel -@@ -289,10 +289,10 @@ - strip_include_prefix = ".", - deps = [ - ":interpreter_ops_inc_gen", -- ":reference_value", - ":reference_numpy", - ":reference_ops", - ":reference_process_grid", -+ ":reference_value", - "@llvm-project//llvm:Support", +@@ -375,12 +375,23 @@ + ":linalg_pass_inc_gen", + ":stablehlo_ops", + ":chlo_ops", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:ArithDialect", ++ "@llvm-project//mlir:BufferizationDialect", ++ "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", -@@ -465,10 +465,10 @@ - strip_include_prefix = ".", - deps = [ - ":reference_axes", -+ ":reference_configuration", - ":reference_element", - ":reference_errors", - ":reference_index", -- ":reference_configuration", - ":reference_process", - ":reference_process_grid", - ":reference_scope", -@@ -965,8 +965,8 @@ + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:LinalgUtils", ++ "@llvm-project//mlir:MathDialect", ++ "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", ++ "@llvm-project//mlir:SCFDialect", ++ "@llvm-project//mlir:ShapeDialect", ++ "@llvm-project//mlir:SparseTensorDialect", ++ "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorDialect", ], - deps = [ - ":interpreter_ops", -+ ":reference_api", - ":reference_errors", -- ":reference_api", - ":reference_ops", - ":reference_process_grid", - ":reference_scope", diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt --- stablehlo/CMakeLists.txt +++ stablehlo/CMakeLists.txt @@ -193,15 +183,127 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists add_subdirectory(integrations) add_subdirectory(reference) add_subdirectory(tests) -diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Base.h ---- stablehlo/stablehlo/dialect/Base.h -+++ stablehlo/stablehlo/dialect/Base.h -@@ -371,4 +371,4 @@ - } // namespace hlo - } // namespace mlir +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/CMakeLists.txt b/stablehlo/stablehlo/conversions/linalg/transforms/CMakeLists.txt +--- stablehlo/stablehlo/conversions/linalg/transforms/CMakeLists.txt ++++ stablehlo/stablehlo/conversions/linalg/transforms/CMakeLists.txt +@@ -21,7 +21,25 @@ + Core + + LINK_LIBS PUBLIC ++ LLVMSupport ++ MLIRArithDialect ++ MLIRBufferizationDialect ++ MLIRComplexDialect ++ MLIRFuncDialect + MLIRIR ++ MLIRLinalgDialect ++ MLIRLinalgTransforms ++ MLIRLinalgUtils ++ MLIRMathDialect ++ MLIRMemRefDialect + MLIRPass ++ MLIRPass ++ MLIRSCFDialect ++ MLIRShapeDialect ++ MLIRSparseTensorDialect ++ MLIRSupport ++ MLIRTensorDialect + MLIRTransforms ++ MLIRTransforms ++ MLIRVectorDialect + ) +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/PassDetail.h b/stablehlo/stablehlo/conversions/linalg/transforms/PassDetail.h +--- stablehlo/stablehlo/conversions/linalg/transforms/PassDetail.h ++++ stablehlo/stablehlo/conversions/linalg/transforms/PassDetail.h +@@ -1,3 +1,19 @@ ++/* Copyright 2022 The IREE Authors ++ Copyright 2023 OpenXLA Authors. All Rights Reserved. 
++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ + #ifndef STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSDETAIL_H + #define STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSDETAIL_H + +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/Passes.h b/stablehlo/stablehlo/conversions/linalg/transforms/Passes.h +--- stablehlo/stablehlo/conversions/linalg/transforms/Passes.h ++++ stablehlo/stablehlo/conversions/linalg/transforms/Passes.h +@@ -1,3 +1,19 @@ ++/* Copyright 2022 The IREE Authors ++ Copyright 2023 OpenXLA Authors. All Rights Reserved. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ + #ifndef STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSES_H + #define STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSES_H + +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/Passes.td b/stablehlo/stablehlo/conversions/linalg/transforms/Passes.td +--- stablehlo/stablehlo/conversions/linalg/transforms/Passes.td ++++ stablehlo/stablehlo/conversions/linalg/transforms/Passes.td +@@ -1,3 +1,19 @@ ++/* Copyright 2022 The IREE Authors ++ Copyright 2023 OpenXLA Authors. All Rights Reserved. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ + #ifndef STABLEHLO_TO_LINALG_PASSES + #define STABLEHLO_TO_LINALG_PASSES --#endif -+#endif // STABLEHLO_DIALECT_BASE_H +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +--- stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp ++++ stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +@@ -1,3 +1,19 @@ ++/* Copyright 2022 The IREE Authors ++ Copyright 2023 OpenXLA Authors. All Rights Reserved. 
++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++ + #include "PassDetail.h" + #include "mlir/Dialect/Arith/IR/Arith.h" + #include "mlir/Dialect/Bufferization/IR/Bufferization.h" diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp --- stablehlo/stablehlo/dialect/TypeInference.cpp +++ stablehlo/stablehlo/dialect/TypeInference.cpp @@ -3888,41 +3990,6 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir -diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/BUILD.bazel ---- stablehlo/stablehlo/tests/BUILD.bazel -+++ stablehlo/stablehlo/tests/BUILD.bazel -@@ -40,6 +40,7 @@ - - gentbl_cc_library( - name = "check_ops_inc_gen", -+ strip_include_prefix = ".", - tbl_outs = [ - ( - ["-gen-op-decls"], -@@ -52,7 +53,6 @@ - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "CheckOps.td", -- strip_include_prefix = ".", - deps = [ - ":check_ops_td_files", - ], -@@ -140,6 +140,7 @@ - [ - lit_test( - name = "%s.test" % src, -+ size = "small", - srcs = [src], - data = [ - "lit.cfg.py", -@@ -149,7 +150,6 @@ - "@llvm-project//llvm:FileCheck", - "@llvm-project//llvm:not", - ] + glob(["%s.bc" % src]), -- size = "small", - tags = ["stablehlo_tests"], - ) - for src in glob(["**/*.mlir"]) diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/tests/ops_stablehlo.mlir --- stablehlo/stablehlo/tests/ops_stablehlo.mlir +++ stablehlo/stablehlo/tests/ops_stablehlo.mlir @@ -3962,69 +4029,72 @@ diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/ diff --ruN a/stablehlo/stablehlo/tests/print_stablehlo.mlir b/stablehlo/stablehlo/tests/print_stablehlo.mlir --- stablehlo/stablehlo/tests/print_stablehlo.mlir +++ stablehlo/stablehlo/tests/print_stablehlo.mlir -@@ -1,5 +1,5 @@ --// RUN: stablehlo-opt %s | FileCheck %s --// RUN: stablehlo-opt %s | stablehlo-opt | FileCheck %s -+// RUN: stablehlo-opt %s --split-input-file | FileCheck %s -+// RUN: stablehlo-opt %s --split-input-file | stablehlo-opt --split-input-file | FileCheck %s +@@ -1,5 +1,34 @@ + // RUN: stablehlo-opt %s | FileCheck %s + // RUN: stablehlo-opt %s | stablehlo-opt | FileCheck %s ++ ++// Test encodings first since aliases are printed at top of file. 
++#CSR = #sparse_tensor.encoding<{ ++ map = (d0, d1) -> (d0 : dense, d1 : compressed) ++}> ++ ++#DCSR = #sparse_tensor.encoding<{ ++ map = (d0, d1) -> (d0 : compressed, d1 : compressed) ++}> ++ ++// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> ++// CHECK: #[[$DCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> ++// CHECK-LABEL: func @encodings ++func.func @encodings(%arg0: tensor<10x20xf32, #CSR>, ++ %arg1: tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> { ++ // CHECK: %0 = stablehlo.add %arg0, %arg1 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$DCSR]]>) -> tensor<10x20xf32> ++ // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<10x20xf32, #[[$DCSR]]> ++ // CHECK-NEXT: %2 = stablehlo.abs %arg0 : (tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xf32> ++ // CHECK-NEXT: %3 = stablehlo.abs %arg0 : tensor<10x20xf32, #[[$CSR]]> ++ // CHECK-NEXT: %4 = stablehlo.complex %arg0, %arg0 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xcomplex> ++ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<10x20xf32, #CSR>, ++ tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> ++ %1 = "stablehlo.add"(%arg1, %arg1) : (tensor<10x20xf32, #DCSR>, ++ tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32, #DCSR> ++ %2 = "stablehlo.abs"(%arg0) : (tensor<10x20xf32, #CSR>) -> tensor<10x20xf32> ++ %3 = "stablehlo.abs"(%arg0) : (tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #CSR> ++ %4 = "stablehlo.complex"(%arg0, %arg0) : (tensor<10x20xf32, #CSR>, tensor<10x20xf32, #CSR>) -> tensor<10x20xcomplex> ++ func.return %0 : tensor<10x20xf32> ++} // CHECK-LABEL: func @zero_input func.func @zero_input() -> !stablehlo.token { -@@ -291,6 +291,8 @@ +@@ -291,32 +320,6 @@ "stablehlo.return"() : () -> () } -+// ----- -+ - #CSR = #sparse_tensor.encoding<{ - map = (d0, d1) -> (d0 : dense, d1 : compressed) - }> -@@ -299,14 +301,16 @@ - map = (d0, d1) -> (d0 : compressed, d1 : compressed) - }> - -+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> -+// CHECK: #[[$DCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> - // CHECK-LABEL: func @encodings - func.func @encodings(%arg0: tensor<10x20xf32, #CSR>, - %arg1: tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> { +-#CSR = #sparse_tensor.encoding<{ +- map = (d0, d1) -> (d0 : dense, d1 : compressed) +-}> +- +-#DCSR = #sparse_tensor.encoding<{ +- map = (d0, d1) -> (d0 : compressed, d1 : compressed) +-}> +- +-// CHECK-LABEL: func @encodings +-func.func @encodings(%arg0: tensor<10x20xf32, #CSR>, +- %arg1: tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> { - // CHECK: %0 = stablehlo.add %arg0, %arg1 : (tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>, tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>>) -> tensor<10x20xf32> - // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>> - // CHECK-NEXT: %2 = stablehlo.abs %arg0 : (tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) -> tensor<10x20xf32> - // CHECK-NEXT: %3 = stablehlo.abs %arg0 : tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>> - // CHECK-NEXT: %4 = stablehlo.complex %arg0, %arg0 : (tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : 
compressed) }>>, tensor<10x20xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>) -> tensor<10x20xcomplex> -+ // CHECK: %0 = stablehlo.add %arg0, %arg1 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$DCSR]]>) -> tensor<10x20xf32> -+ // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<10x20xf32, #[[$DCSR]]> -+ // CHECK-NEXT: %2 = stablehlo.abs %arg0 : (tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xf32> -+ // CHECK-NEXT: %3 = stablehlo.abs %arg0 : tensor<10x20xf32, #[[$CSR]]> -+ // CHECK-NEXT: %4 = stablehlo.complex %arg0, %arg0 : (tensor<10x20xf32, #[[$CSR]]>, tensor<10x20xf32, #[[$CSR]]>) -> tensor<10x20xcomplex> - %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<10x20xf32, #CSR>, - tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> - %1 = "stablehlo.add"(%arg1, %arg1) : (tensor<10x20xf32, #DCSR>, -@@ -316,6 +320,8 @@ - %4 = "stablehlo.complex"(%arg0, %arg0) : (tensor<10x20xf32, #CSR>, tensor<10x20xf32, #CSR>) -> tensor<10x20xcomplex> - func.return %0 : tensor<10x20xf32> - } -+ -+// ----- - +- %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<10x20xf32, #CSR>, +- tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32> +- %1 = "stablehlo.add"(%arg1, %arg1) : (tensor<10x20xf32, #DCSR>, +- tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32, #DCSR> +- %2 = "stablehlo.abs"(%arg0) : (tensor<10x20xf32, #CSR>) -> tensor<10x20xf32> +- %3 = "stablehlo.abs"(%arg0) : (tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #CSR> +- %4 = "stablehlo.complex"(%arg0, %arg0) : (tensor<10x20xf32, #CSR>, tensor<10x20xf32, #CSR>) -> tensor<10x20xcomplex> +- func.return %0 : tensor<10x20xf32> +-} +- func.func @dot_general(%arg0: tensor<2x2x2xi8>, %arg1: tensor<2x2x3xi8>, %arg2: tensor<2x2xi8>, %arg3: tensor<2x3xi8>) -> tensor<2x2x3xi32> { // CHECK: {{%.*}} = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<2x2x2xi8>, tensor<2x2x3xi8>) -> tensor<2x2x3xi32> -diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h ---- stablehlo/stablehlo/transforms/Passes.h -+++ stablehlo/stablehlo/transforms/Passes.h -@@ -25,6 +25,7 @@ - - namespace mlir { - namespace stablehlo { -+ - #define GEN_PASS_DECL_STABLEHLOCANONICALIZEDYNAMISMPASS - #define GEN_PASS_DECL_STABLEHLOLEGALIZETOVHLOPASS - #define GEN_PASS_DECL_STABLEHLOREFINESHAPESPASS -@@ -66,4 +67,4 @@ - } // namespace stablehlo - } // namespace mlir - --#endif // STABLEHLO_DIALECT_VHLO_OPS_H -+#endif // STABLEHLO_TRANSFORMS_PASSES_H + // CHECK-NEXT: {{%.*}} = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x2x2xi8>, tensor<2x2x3xi8>) -> tensor<2x2x3xi32> diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index 1383c2f5634f7b..df25c4b7ca88f7 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "c3e1ab5af5f4f3ab749072ffa960348f90d5e832" - STABLEHLO_SHA256 = "9a421322b2176d1226e95472823cffb7752d1948513e18af28d5c6a27d0f35ca" + STABLEHLO_COMMIT = "444c72d6de4e2b3c7738457708ea863d5fb7f0e4" + STABLEHLO_SHA256 = "9f8623138f30212d57e047cc29b6d892c8fc5884c5bd4a919441c24ca3785c42" # LINT.ThenChange(Google-internal path) tf_http_archive( From c837ea6ddc22147ad53eba6ebec8ac0aff9f09e7 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: 
Mon, 20 Nov 2023 09:40:32 -0800 Subject: [PATCH 294/391] [XLA] [NFC] Expose RenderGraph in dump.h header for outside usage PiperOrigin-RevId: 584051718 --- third_party/xla/xla/service/dump.cc | 47 +++++++++++++++-------------- third_party/xla/xla/service/dump.h | 6 ++++ 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/third_party/xla/xla/service/dump.cc b/third_party/xla/xla/service/dump.cc index a490f9af998cc2..e1af0157b865ba 100644 --- a/third_party/xla/xla/service/dump.cc +++ b/third_party/xla/xla/service/dump.cc @@ -44,6 +44,21 @@ limitations under the License. namespace xla { +std::string RenderGraph(absl::string_view label, const HloModule& module, + RenderedGraphFormat format, + bool show_fusion_subcomputations) { + HloRenderOptions hlo_render_options; + hlo_render_options.show_fusion_subcomputations = show_fusion_subcomputations; + StatusOr rendered_graph = + RenderGraph(*module.entry_computation(), label, + module.config().debug_options(), format, hlo_render_options); + if (rendered_graph.ok()) { + return std::move(rendered_graph).value(); + } + return absl::StrFormat("Error rendering graph: %s", + rendered_graph.status().ToString()); +} + namespace { using absl::StrCat; @@ -428,36 +443,22 @@ static std::vector DumpHloModuleImpl( pb, opts, opts.dump_compress_protos)); } - auto render_graph = [&](RenderedGraphFormat format, - bool show_fusion_subcomputations = true) { - HloRenderOptions hlo_render_options; - hlo_render_options.show_fusion_subcomputations = - show_fusion_subcomputations; - StatusOr rendered_graph = - RenderGraph(*module.entry_computation(), - /*label=*/filename, module.config().debug_options(), format, - hlo_render_options); - if (rendered_graph.ok()) { - return std::move(rendered_graph).value(); - } - return StrFormat("Error rendering graph: %s", - rendered_graph.status().ToString()); - }; if (opts.dump_as_dot) { - file_paths.push_back( - DumpToFileInDirImpl(StrFormat("%s.dot", filename), - render_graph(RenderedGraphFormat::kDot), opts)); + file_paths.push_back(DumpToFileInDirImpl( + StrFormat("%s.dot", filename), + RenderGraph(filename, module, RenderedGraphFormat::kDot), opts)); } if (opts.dump_as_html) { - file_paths.push_back( - DumpToFileInDirImpl(StrFormat("%s.html", filename), - render_graph(RenderedGraphFormat::kHtml), opts)); + file_paths.push_back(DumpToFileInDirImpl( + StrFormat("%s.html", filename), + RenderGraph(filename, module, RenderedGraphFormat::kHtml), opts)); if (absl::StrContains(filename, kAfterOptimizationsDumpName)) { file_paths.push_back(DumpToFileInDirImpl( StrFormat("%s.top_level.html", filename), - render_graph(RenderedGraphFormat::kHtml, false), opts)); + RenderGraph(filename, module, RenderedGraphFormat::kHtml, false), + opts)); } } @@ -486,7 +487,7 @@ static std::vector DumpHloModuleImpl( // Special case for rendering graphs as URLs. We'll dump them to a file // because why not, but we always log them to stdout as well. if (opts.dump_as_url) { - std::string url = render_graph(RenderedGraphFormat::kUrl); + std::string url = RenderGraph(filename, module, RenderedGraphFormat::kUrl); std::cout << filename << " --> " << url << std::endl; if (!opts.dumping_to_stdout()) { file_paths.push_back( diff --git a/third_party/xla/xla/service/dump.h b/third_party/xla/xla/service/dump.h index 86244ba4b73edb..3dbacbd5341707 100644 --- a/third_party/xla/xla/service/dump.h +++ b/third_party/xla/xla/service/dump.h @@ -19,6 +19,7 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "mlir/IR/Operation.h" // from @llvm-project #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_graph_dumper.h" #include "xla/status.h" #include "xla/xla.pb.h" @@ -92,6 +93,11 @@ void DumpProtobufToFile(const tsl::protobuf::Message& proto, tsl::Env*, const tsl::protobuf::Message&)> text_formatter = nullptr); +// Render graph in a given format. +std::string RenderGraph(absl::string_view label, const HloModule& module, + RenderedGraphFormat format, + bool show_fusion_subcomputations = true); + // Similar to above, but the filename depends on module's information and the // given name. Also allows for the optional serialization function. void DumpPerModuleProtobufToFile( From 5af2cdf3666c07c962c15fdd7b6d1fbac133ac1e Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 20 Nov 2023 10:01:58 -0800 Subject: [PATCH 295/391] HloFusionAnalysis: Don't store references to ops outside the analyzed fusion. The lifetime of instructions outside the fusion is generally unpredicatable, so we shouldn't store references to them. PiperOrigin-RevId: 584057081 --- .../xla/service/gpu/hlo_fusion_analysis.cc | 75 +++++++++++-------- .../xla/xla/service/gpu/hlo_fusion_analysis.h | 22 +++--- 2 files changed, 53 insertions(+), 44 deletions(-) diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index bb2fe734a63055..104913b6ab163c 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -255,6 +255,16 @@ std::optional FindConsistentTransposeHero( return tiled_transpose_hero; } +int SmallestInputDtypeBits(const std::vector& args) { + int bits = std::numeric_limits::max(); + for (const HloInstruction* operand : args) { + if (!operand->shape().IsArray()) continue; + bits = std::min(bits, + primitive_util::BitWidth(operand->shape().element_type())); + } + return bits; +} + } // namespace // static @@ -277,16 +287,20 @@ StatusOr HloFusionAnalysis::Create( auto is_4bit = [](const HloInstruction* arg) { return primitive_util::Is4BitType(arg->shape().element_type()); }; - bool has_4_bit_input = absl::c_any_of(fusion_arguments, is_4bit); - bool has_4_bit_output = absl::c_any_of(hlo_roots, is_4bit); + + InputOutputInfo input_output_info{ + .has_4_bit_input = absl::c_any_of(fusion_arguments, is_4bit), + .has_4_bit_output = absl::c_any_of(hlo_roots, is_4bit), + .smallest_input_dtype_bits = SmallestInputDtypeBits(fusion_arguments), + }; std::optional tiled_transpose_hero = FindConsistentTransposeHero(hlo_roots, heroes); return HloFusionAnalysis(std::move(backend_config), std::move(hlo_roots), - std::move(boundary_fn), std::move(fusion_arguments), - std::move(heroes), device_info, tiled_transpose_hero, - has_4_bit_input, has_4_bit_output); + std::move(boundary_fn), std::move(heroes), + device_info, tiled_transpose_hero, + std::move(input_output_info)); } // static @@ -320,7 +334,8 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() } #endif - if (has_4_bit_input_ || has_4_bit_output_) { + if (input_output_info_.has_4_bit_input || + input_output_info_.has_4_bit_output) { // Only loop fusions currently can handle int4 inputs/outputs, due to the // special handling with IrArray needed to deal with two values occupying a // single byte. 
@@ -485,7 +500,7 @@ const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() { } // CHECK that unroll_factor is a power-of-2, as needed by the logic below. CHECK(absl::has_single_bit(static_cast(unroll_factor))); - if (has_4_bit_output_ && unroll_factor == 1) { + if (input_output_info_.has_4_bit_output && unroll_factor == 1) { // Ensure a single thread writes to a byte containing two int4 values by // setting unroll_factor to 2. unroll_factor is always a power of 2, so // setting it to 2 here ensures unroll_factor is even when there are 4-bit @@ -549,15 +564,6 @@ const Shape& HloFusionAnalysis::GetElementShape() const { return *shape; } -int HloFusionAnalysis::SmallestInputDtypeBits() const { - int bits = std::numeric_limits::max(); - for (const HloInstruction* operand : fusion_arguments_) { - bits = std::min(bits, - primitive_util::BitWidth(operand->shape().element_type())); - } - return bits; -} - int64_t HloFusionAnalysis::MaxBeneficialColumnReductionUnrollBasedOnBlockSize() const { int64_t num_reduce_output_elems = 0; @@ -735,23 +741,25 @@ bool HloFusionAnalysis::IsUnrollingColumnReductionBeneficial( return TraversalResult::kVisitOperands; }); - for (auto* argument : fusion_arguments_) { - if (!reachable_through_non_elementwise.contains(argument) && - ShapeUtil::SameDimensions(input_shape, argument->shape())) { - ++can_be_vectorized; - } - } - - // Fusion inputs with more elements than the reduce op input must participate - // in non-elementwise operations and we assume that they are not vectorizable - // for the purpose of estimating the benefit of unrolling. If the kernel is - // unrolled even with such an assumption, and the accesses to those inputs - // turn out to be vectorizable, the compiler will still vectorize them. int64_t num_elements = ShapeUtil::ElementsIn(input_shape); - cannot_be_vectorized += - absl::c_count_if(fusion_arguments_, [&](const HloInstruction* parameter) { - return ShapeUtil::ElementsIn(parameter->shape()) > num_elements; + FindFusionArguments( + fusion_roots_, fusion_boundary_fn_, [&](const HloInstruction& arg) { + if (!reachable_through_non_elementwise.contains(&arg) && + ShapeUtil::SameDimensions(input_shape, arg.shape())) { + ++can_be_vectorized; + } + + // Fusion inputs with more elements than the reduce op input must + // participate in non-elementwise operations and we assume that they are + // not vectorizable for the purpose of estimating the benefit of + // unrolling. If the kernel is unrolled even with such an assumption, + // and the accesses to those inputs turn out to be vectorizable, the + // compiler will still vectorize them. + if (ShapeUtil::ElementsIn(arg.shape()) > num_elements) { + ++cannot_be_vectorized; + } }); + if (can_be_vectorized < cannot_be_vectorized) { return false; } @@ -785,7 +793,7 @@ bool HloFusionAnalysis::CanVectorizeReduction( if (cuda_cc == nullptr) return false; if (cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) return true; if (cuda_cc->IsAtLeast(se::CudaComputeCapability::PASCAL_)) { - return SmallestInputDtypeBits() <= 32 && + return input_output_info_.smallest_input_dtype_bits <= 32 && reduction_dimensions.dimensions[kDimX] % (reduction_tiling[2] * num_threads_x) == 0; @@ -863,7 +871,8 @@ ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo( // difference, e.g. by affecting register spilling. 
int num_partial_results = 1; if (!reduction_dimensions.is_row_reduction && vectorize) { - int smallest_input_dtype_bits = SmallestInputDtypeBits(); + int smallest_input_dtype_bits = + input_output_info_.smallest_input_dtype_bits; if (smallest_input_dtype_bits <= 32) { // Make sure to use all the data read at once. // Instead of hardcoding the granularity, we can query the granularity we diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h index c07819db2d3a15..bcaf34aca46f26 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h @@ -86,26 +86,30 @@ class HloFusionAnalysis { const HloInstruction* FindHeroReduction() const; private: + // Precomputed information about inputs (arguments) and outputs (roots) of the + // fusion. + struct InputOutputInfo { + bool has_4_bit_input; + bool has_4_bit_output; + int smallest_input_dtype_bits; + }; + HloFusionAnalysis(FusionBackendConfig fusion_backend_config, std::vector fusion_roots, FusionBoundaryFn fusion_boundary_fn, - std::vector fusion_arguments, std::vector fusion_heroes, const se::DeviceDescription* device_info, std::optional tiled_transpose, - bool has_4_bit_input, bool has_4_bit_output) + InputOutputInfo input_output_info) : fusion_backend_config_(std::move(fusion_backend_config)), fusion_roots_(std::move(fusion_roots)), fusion_boundary_fn_(std::move(fusion_boundary_fn)), - fusion_arguments_(std::move(fusion_arguments)), fusion_heroes_(std::move(fusion_heroes)), device_info_(device_info), tiled_transpose_(tiled_transpose), - has_4_bit_input_(has_4_bit_input), - has_4_bit_output_(has_4_bit_output) {} + input_output_info_(std::move(input_output_info)) {} const Shape& GetElementShape() const; - int SmallestInputDtypeBits() const; int64_t MaxBeneficialColumnReductionUnrollBasedOnBlockSize() const; std::vector> GroupDisjointReductions() const; @@ -125,14 +129,10 @@ class HloFusionAnalysis { FusionBackendConfig fusion_backend_config_; std::vector fusion_roots_; FusionBoundaryFn fusion_boundary_fn_; - // The HLO instructions that are inputs into the fusion. These instructions - // are /outside/ the fusion. - std::vector fusion_arguments_; std::vector fusion_heroes_; const se::DeviceDescription* device_info_; std::optional tiled_transpose_; - const bool has_4_bit_input_ = false; - const bool has_4_bit_output_ = false; + InputOutputInfo input_output_info_; std::optional reduction_codegen_info_; std::optional transpose_tiling_scheme_; From 1ff4dd878149db614be71ebf7412abea4ae921c3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 20 Nov 2023 10:16:23 -0800 Subject: [PATCH 296/391] Lower tensor.dim to MHLO in ShapeLegalizeToHLO Currently this supports static dim only. PiperOrigin-RevId: 584061103 --- .../shape_legalize_to_hlo.cc | 22 +++++++++++++++++++ .../Dialect/mhlo/shape_legalize_to_hlo.mlir | 21 ++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc index b07a589afdbe23..b5b45d35818d04 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_legalize_to_hlo/shape_legalize_to_hlo.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "mhlo/IR/hlo_ops.h" #include "mhlo/transforms/passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -28,6 +29,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/TypeID.h" #include "mlir/Transforms/DialectConversion.h" @@ -250,6 +252,25 @@ struct ConvertShapeOfOpPattern : public OpRewritePattern { } }; +struct ConvertTensorDimPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::DimOp op, + PatternRewriter& rewriter) const override { + // We only support getting static index. + auto constIndex = + dyn_cast_or_null(op.getIndex().getDefiningOp()); + if (!constIndex) { + return failure(); + } + + auto dim = rewriter.create( + op->getLoc(), op.getSource(), constIndex.value()); + auto dimIndex = castToIndex(rewriter, op.getLoc(), dim); + rewriter.replaceOp(op, dimIndex); + return success(); + } +}; + template struct CastOperandsPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -337,6 +358,7 @@ struct ShapeLegalizeToHloPass patterns.add(&getContext()); patterns.add(&getContext()); patterns.add>(&getContext()); + patterns.add(&getContext()); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir index bb5af8494888ec..fcdd268c0c4208 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/shape_legalize_to_hlo.mlir @@ -88,3 +88,24 @@ func.func @shape_of_ranked_to_shape(%arg0: tensor) -> !shape.shape { %0 = shape.shape_of %arg0 : tensor -> !shape.shape func.return %0 : !shape.shape } + + +// ----- + +// CHECK-LABEL: func.func @tensor_dim +func.func @tensor_dim(%arg0: tensor) -> index { + %c0 = arith.constant 0 : index + %dim = tensor.dim %arg0, %c0 : tensor + func.return %dim : index + // CHECK: %[[DIM_SIZE:.*]] = "mhlo.get_dimension_size"(%arg0) {dimension = 0 : i64} : (tensor) -> tensor + // CHECK-NEXT: %[[DIM_SIZE_INDEX:.*]] = builtin.unrealized_conversion_cast %[[DIM_SIZE]] : tensor to index + // CHECK-NEXT: return %[[DIM_SIZE_INDEX]] : index +} + +// ----- + +func.func @tensor_dim_dynamic(%arg0: tensor, %arg1: index) -> index { + // expected-error@+1 {{failed to legalize operation 'tensor.dim' that was explicitly marked illegal}} + %dim = tensor.dim %arg0, %arg1 : tensor + func.return %dim : index +} From 3dc9b41d5c3a7ab71a5c4da4301e68f8e5bd4062 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 20 Nov 2023 10:21:53 -0800 Subject: [PATCH 297/391] [XLA:GPU] Use LazyRE2 so regexes don't show up in compile time profiles PiperOrigin-RevId: 584062609 --- third_party/xla/xla/stream_executor/gpu/asm_compiler.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/stream_executor/gpu/asm_compiler.cc b/third_party/xla/xla/stream_executor/gpu/asm_compiler.cc index e86f86ab9df0f9..4cdd13f44f75db 100644 --- a/third_party/xla/xla/stream_executor/gpu/asm_compiler.cc +++ b/third_party/xla/xla/stream_executor/gpu/asm_compiler.cc @@ -86,10 
+86,11 @@ tsl::StatusOr> GetToolVersion( return tsl::errors::FailedPrecondition( "Couldn't get ptxas/nvlink version string: ", tool_version.status()); } + static constexpr LazyRE2 kVersionRegex = {R"(\bV(\d+)\.(\d+)\.(\d+)\b)"}; std::array version; - std::string vmaj_str, vmin_str, vdot_str; - if (!RE2::PartialMatch(tool_version.value(), R"(\bV(\d+)\.(\d+)\.(\d+)\b)", - &vmaj_str, &vmin_str, &vdot_str) || + absl::string_view vmaj_str, vmin_str, vdot_str; + if (!RE2::PartialMatch(tool_version.value(), *kVersionRegex, &vmaj_str, + &vmin_str, &vdot_str) || !absl::SimpleAtoi(vmaj_str, &version[0]) || !absl::SimpleAtoi(vmin_str, &version[1]) || !absl::SimpleAtoi(vdot_str, &version[2])) { From 825f9ba778faa5e70430163d8325e14e689cd56d Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 20 Nov 2023 10:44:06 -0800 Subject: [PATCH 298/391] Add GCS URI location for storing macOS nightly wheel artifacts PiperOrigin-RevId: 584069494 --- ci/official/envs/nightly_macos_arm64_py310 | 1 + ci/official/envs/nightly_macos_arm64_py311 | 1 + ci/official/envs/nightly_macos_arm64_py312 | 1 + ci/official/envs/nightly_macos_arm64_py39 | 1 + 4 files changed, 4 insertions(+) diff --git a/ci/official/envs/nightly_macos_arm64_py310 b/ci/official/envs/nightly_macos_arm64_py310 index 38ae51e9e2f587..4c950b626e1daa 100644 --- a/ci/official/envs/nightly_macos_arm64_py310 +++ b/ci/official/envs/nightly_macos_arm64_py310 @@ -6,6 +6,7 @@ TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.10 TFCI_UPLOAD_WHL_GCS_ENABLE=1 +TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-nightly/macos-arm64/$(date -I)" TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION diff --git a/ci/official/envs/nightly_macos_arm64_py311 b/ci/official/envs/nightly_macos_arm64_py311 index 218b292dd41f3f..6d7699640758a3 100644 --- a/ci/official/envs/nightly_macos_arm64_py311 +++ b/ci/official/envs/nightly_macos_arm64_py311 @@ -6,6 +6,7 @@ TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.11 TFCI_UPLOAD_WHL_GCS_ENABLE=1 +TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-nightly/macos-arm64/$(date -I)" TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION diff --git a/ci/official/envs/nightly_macos_arm64_py312 b/ci/official/envs/nightly_macos_arm64_py312 index 7d89c9f31118d4..add4488fb4650b 100644 --- a/ci/official/envs/nightly_macos_arm64_py312 +++ b/ci/official/envs/nightly_macos_arm64_py312 @@ -6,6 +6,7 @@ TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.12 TFCI_UPLOAD_WHL_GCS_ENABLE=1 +TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-nightly/macos-arm64/$(date -I)" TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION diff --git a/ci/official/envs/nightly_macos_arm64_py39 b/ci/official/envs/nightly_macos_arm64_py39 index 4a3c24353b69c1..f969f20d782ee7 100644 --- a/ci/official/envs/nightly_macos_arm64_py39 +++ b/ci/official/envs/nightly_macos_arm64_py39 @@ -8,6 +8,7 @@ TFCI_DOCKER_ENABLE=0 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 TFCI_PYTHON_VERSION=3.9 TFCI_UPLOAD_WHL_GCS_ENABLE=1 +TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-nightly/macos/arm64/$(date -I)" TFCI_WHL_AUDIT_ENABLE= TFCI_WHL_SIZE_LIMIT=240M TFCI_PYENV_INSTALL_LOCAL_ENABLE=$TFCI_PYTHON_VERSION From 762697635c421ab904a29c1f33fd5727617992aa Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 20 Nov 2023 11:12:51 -0800 Subject: [PATCH 299/391] Prevent all `reduce-reduce` fusions. 
We can't currently properly cost-model these, so do this instead. Benchmarks are neutral to positive, compile time improves by about 10%. PiperOrigin-RevId: 584078428 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/xla/service/gpu/priority_fusion.cc | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 042acefef80e23..27997a6b74ec5b 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2083,6 +2083,7 @@ cc_library( deps = [ ":fusion_process_dump_proto_cc", ":gpu_fusible", + ":hlo_traversal", "//xla:shape_util", "//xla:statusor", "//xla:xla_data_proto_cc", diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/priority_fusion.cc index 8df808c1bfad82..32569acfbfc5e4 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/priority_fusion.cc @@ -42,6 +42,7 @@ limitations under the License. #include "xla/service/fusion_queue.h" #include "xla/service/gpu/fusion_process_dump.pb.h" #include "xla/service/gpu/gpu_fusible.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model.h" #include "xla/service/instruction_fusion.h" @@ -466,8 +467,21 @@ FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, return can_fuse; } - // Avoid cases where we'd create a fusion that hit limitations in ptxas. Would - // be nice to model this with cost instead. + // Avoid fusing reduce into reduce. Our cost model doesn't currently + // understand this case due to a lack of tiling analysis. + // TODO(b/312200883): Remove this. + auto contains_reduce = [&](const HloInstruction* instr) { + return HloAnyOf({instr}, MakeSingleInstructionFusion(*instr), + [](const HloInstruction& node) { + return node.opcode() == HloOpcode::kReduce; + }); + }; + if (contains_reduce(producer) && contains_reduce(consumer)) { + return "both the producer and the consumer contain a reduce"; + } + + // Avoid cases where we'd create a fusion that hit limitations in ptxas. + // Would be nice to model this with cost instead. if (auto fits_budget = FusionFitsInBudget(*consumer, *producer, device_info_, /*is_consumer_producer_fusion=*/true); From a933bca2bd8766f236054b37db807a6f54bf5d52 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 20 Nov 2023 11:21:22 -0800 Subject: [PATCH 300/391] Re-enable layering check for target. 
PiperOrigin-RevId: 584080786 --- tensorflow/core/common_runtime/BUILD | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index 586025cca162d1..d0349fd275813a 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -329,7 +329,6 @@ cc_library( srcs = ["all_to_all.cc"], hdrs = ["all_to_all.h"], copts = tf_copts(), - features = ["-layering_check"], deps = [ ":base_collective_executor", ":collective_rma_local", @@ -341,7 +340,7 @@ cc_library( ":process_util", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/platform:blocking_counter", ], alwayslink = 1, ) From b6633b5c23f59397859085f63ae3a7425e6658c4 Mon Sep 17 00:00:00 2001 From: Marcello Maggioni Date: Mon, 20 Nov 2023 11:38:26 -0800 Subject: [PATCH 301/391] [XLA] Make complex resharding more aggressive. Support cases where target dimension is not 1. PiperOrigin-RevId: 584085511 --- .../xla/xla/service/spmd/spmd_partitioner.cc | 16 ++++++------ .../xla/service/spmd/spmd_partitioner_test.cc | 26 +++++++++++++++++++ 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc index e76dcce79f4f4f..068e94eac10fa7 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc @@ -1894,12 +1894,11 @@ PatternMatchUnmergeSharding(const Shape& shape, const Shape& base_shape, 0) { auto get_reshaped_sharding = [&](int64_t target_dim) -> std::optional { - if (source.tile_assignment().dim(target_dim) != 1) { - return std::nullopt; - } - if (source.tile_assignment().dim(i) != - target.tile_assignment().dim(i) * - target.tile_assignment().dim(target_dim)) { + if (source.tile_assignment().dim(target_dim) == + target.tile_assignment().dim(target_dim) || + source.tile_assignment().dim(i) != + target.tile_assignment().dim(i) * + target.tile_assignment().dim(target_dim)) { VLOG(10) << "Skipped for target dim different from dimension_size " << target_dim << " src size: " << source.tile_assignment().dim(i) @@ -1912,7 +1911,7 @@ PatternMatchUnmergeSharding(const Shape& shape, const Shape& base_shape, }; for (int j = i - 1; j >= 0; --j) { if (auto reshaped_sharding = get_reshaped_sharding(j)) { - VLOG(10) << "Triggered Unmerge to Right"; + VLOG(10) << "Triggered Unmerge to Right i = " << i << ",j = " << j; std::vector dimensions( reshaped_sharding->tile_assignment().dimensions().begin(), reshaped_sharding->tile_assignment().dimensions().end()); @@ -1934,7 +1933,7 @@ PatternMatchUnmergeSharding(const Shape& shape, const Shape& base_shape, } for (int j = i + 1; j < target.TiledDataRank(); ++j) { if (auto reshaped_sharding = get_reshaped_sharding(j)) { - VLOG(10) << "Triggered Unmerge to Left"; + VLOG(10) << "Triggered Unmerge to Left i = " << i << ",j = " << j; std::vector dimensions( reshaped_sharding->tile_assignment().dimensions().begin(), reshaped_sharding->tile_assignment().dimensions().end()); @@ -2092,6 +2091,7 @@ std::optional PartitionedHlo::TryComplexReshardHandling( VLOG(10) << "Reshaped shape: " << reshaped.hlo()->shape().ToString(); VLOG(10) << "Reshaped base_shape: " << reshaped.base_shape().ToString(); VLOG(10) << "Before sharding: " << before_sharding.ToString(); + VLOG(10) << "Reshaped: " << reshaped.hlo()->ToString(); auto reshard = reshaped.ReshardNoCache(new_reshaped_sharding, 
/*pad_value=*/std::nullopt, /*allow_full_replication=*/false); diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc index 8e4b21c1807a1b..4a1c45cc6e8552 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc @@ -13988,6 +13988,32 @@ ENTRY %entry { op::Shape("bf16[8,2048,16384]"))); } +TEST_P(SpmdPartitioningTest, ComplexReshapeReshard) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY %extracted_computation (param: f32[13,128,312,16,312]) -> f32[13,39936,4992] { + %param = f32[13,128,312,16,312]{4,2,3,1,0} parameter(0) + %copy.1261 = f32[13,128,312,16,312]{4,3,2,1,0} copy(f32[13,128,312,16,312]{4,2,3,1,0} %param), sharding={devices=[1,32,1,2,1,2]<=[2,64]T(1,0) last_tile_dim_replicate} + %reshape.27217 = f32[13,39936,4992]{2,1,0} reshape(f32[13,128,312,16,312]{4,3,2,1,0} %copy.1261), sharding={devices=[1,2,32,2]<=[2,32,2]T(2,1,0) last_tile_dim_replicate} + %copy.1260 = f32[13,39936,4992]{2,1,0} copy(f32[13,39936,4992]{2,1,0} %reshape.27217), sharding={devices=[1,2,32,2]<=[2,32,2]T(2,1,0) last_tile_dim_replicate} + ROOT %copy = f32[13,39936,4992]{2,1,0} copy(f32[13,39936,4992]{2,1,0} %copy.1260) +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/128, + /*conv_halo_exchange_always_on_lhs=*/true, + /*choose_faster_windowed_einsum=*/true, + /*unroll_windowed_einsum=*/false, + /*bidirectional_windowed_einsum=*/true, + /*threshold_for_windowed_einsum_mib=*/-1)); + XLA_VLOG_LINES(1, module->ToString()); + // Check an all-to-all is emitted for resharding. + auto all_to_all = FindInstruction(module.get(), HloOpcode::kAllToAll); + EXPECT_NE(all_to_all, nullptr); +} + } // namespace } // namespace spmd } // namespace xla From 195875ecf28e4b361d97e0e67c3f6c35299baede Mon Sep 17 00:00:00 2001 From: Krasimir Georgiev Date: Mon, 20 Nov 2023 12:00:33 -0800 Subject: [PATCH 302/391] Integrate LLVM at llvm/llvm-project@9bdbb8226e70 Updates LLVM usage to match [9bdbb8226e70](https://github.com/llvm/llvm-project/commit/9bdbb8226e70) PiperOrigin-RevId: 584091615 --- third_party/llvm/generated.patch | 4583 +----------------------------- third_party/llvm/workspace.bzl | 4 +- 2 files changed, 14 insertions(+), 4573 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 6220f1e8bd02f2..bf4982f08ea1bf 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,4572 +1,13 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/compiler-rt/test/msan/vararg_shadow.cpp b/compiler-rt/test/msan/vararg_shadow.cpp ---- a/compiler-rt/test/msan/vararg_shadow.cpp -+++ b/compiler-rt/test/msan/vararg_shadow.cpp -@@ -3,8 +3,8 @@ - // Without -fno-sanitize-memory-param-retval we can't even pass poisoned values. - // RUN: %clangxx_msan -fno-sanitize-memory-param-retval -fsanitize-memory-track-origins=0 -O3 %s -o %t - --// The most of targets fail the test. --// XFAIL: target={{(x86|aarch64|loongarch64|mips|powerpc64).*}} -+// Nothing works yet. 
-+// XFAIL: target={{(aarch64|loongarch64|mips|powerpc64).*}} - - #include - #include -diff -ruN --strip-trailing-cr a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp ---- a/llvm/lib/Target/X86/X86ISelLowering.cpp -+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp -@@ -49796,8 +49796,8 @@ - } - } - -- // If we also load/broadcast this to a wider type, then just extract the -- // lowest subvector. -+ // If we also broadcast this to a wider type, then just extract the lowest -+ // subvector. - if (Ext == ISD::NON_EXTLOAD && Subtarget.hasAVX() && Ld->isSimple() && - (RegVT.is128BitVector() || RegVT.is256BitVector())) { - SDValue Ptr = Ld->getBasePtr(); -@@ -49805,9 +49805,8 @@ - for (SDNode *User : Chain->uses()) { - if (User != N && - (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD || -- User->getOpcode() == X86ISD::VBROADCAST_LOAD || -- ISD::isNormalLoad(User)) && -- cast(User)->getChain() == Chain && -+ User->getOpcode() == X86ISD::VBROADCAST_LOAD) && -+ cast(User)->getChain() == Chain && - !User->hasAnyUseOfValue(1) && - User->getValueSizeInBits(0).getFixedValue() > - RegVT.getFixedSizeInBits()) { -@@ -49820,13 +49819,9 @@ - Extract = DAG.getBitcast(RegVT, Extract); - return DCI.CombineTo(N, Extract, SDValue(User, 1)); - } -- if ((User->getOpcode() == X86ISD::VBROADCAST_LOAD || -- (ISD::isNormalLoad(User) && -- cast(User)->getBasePtr() != Ptr)) && -+ if (User->getOpcode() == X86ISD::VBROADCAST_LOAD && - getTargetConstantFromBasePtr(Ptr)) { -- // See if we are loading a constant that has also been broadcast or -- // we are loading a constant that also matches in the lower -- // bits of a longer constant (but from a different constant pool ptr). -+ // See if we are loading a constant that has also been broadcast. - APInt Undefs, UserUndefs; - SmallVector Bits, UserBits; - if (getTargetConstantBitsFromNode(SDValue(N, 0), 8, Undefs, Bits) && -diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp ---- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp -+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp -@@ -4669,16 +4669,22 @@ - - /// Compute the shadow address for a given va_arg. - Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, -- unsigned ArgOffset, unsigned ArgSize) { -- // Make sure we don't overflow __msan_va_arg_tls. -- if (ArgOffset + ArgSize > kParamTLSSize) -- return nullptr; -+ unsigned ArgOffset) { - Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); - Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); - return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0), - "_msarg_va_s"); - } - -+ /// Compute the shadow address for a given va_arg. -+ Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, -+ unsigned ArgOffset, unsigned ArgSize) { -+ // Make sure we don't overflow __msan_va_arg_tls. -+ if (ArgOffset + ArgSize > kParamTLSSize) -+ return nullptr; -+ return getShadowPtrForVAArgument(Ty, IRB, ArgOffset); -+ } -+ - /// Compute the origin address for a given va_arg. 
- Value *getOriginPtrForVAArgument(IRBuilder<> &IRB, int ArgOffset) { - Value *Base = IRB.CreatePointerCast(MS.VAArgOriginTLS, MS.IntptrTy); -@@ -4772,6 +4778,24 @@ - unsigned FpOffset = AMD64GpEndOffset; - unsigned OverflowOffset = AMD64FpEndOffset; - const DataLayout &DL = F.getParent()->getDataLayout(); -+ -+ auto CleanUnusedTLS = [&](Value *ShadowBase, unsigned BaseOffset) { -+ // Make sure we don't overflow __msan_va_arg_tls. -+ if (OverflowOffset <= kParamTLSSize) -+ return false; // Not needed, end is not reacheed. -+ -+ // The tails of __msan_va_arg_tls is not large enough to fit full -+ // value shadow, but it will be copied to backup anyway. Make it -+ // clean. -+ if (BaseOffset < kParamTLSSize) { -+ Value *TailSize = ConstantInt::getSigned(IRB.getInt32Ty(), -+ kParamTLSSize - BaseOffset); -+ IRB.CreateMemSet(ShadowBase, ConstantInt::getNullValue(IRB.getInt8Ty()), -+ TailSize, Align(8)); -+ } -+ return true; // Incomplete -+ }; -+ - for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) { - bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams(); - bool IsByVal = CB.paramHasAttr(ArgNo, Attribute::ByVal); -@@ -4784,19 +4808,22 @@ - assert(A->getType()->isPointerTy()); - Type *RealTy = CB.getParamByValType(ArgNo); - uint64_t ArgSize = DL.getTypeAllocSize(RealTy); -- Value *ShadowBase = getShadowPtrForVAArgument( -- RealTy, IRB, OverflowOffset, alignTo(ArgSize, 8)); -+ uint64_t AlignedSize = alignTo(ArgSize, 8); -+ unsigned BaseOffset = OverflowOffset; -+ Value *ShadowBase = -+ getShadowPtrForVAArgument(RealTy, IRB, OverflowOffset); - Value *OriginBase = nullptr; - if (MS.TrackOrigins) - OriginBase = getOriginPtrForVAArgument(IRB, OverflowOffset); -- OverflowOffset += alignTo(ArgSize, 8); -- if (!ShadowBase) -- continue; -+ OverflowOffset += AlignedSize; -+ -+ if (CleanUnusedTLS(ShadowBase, BaseOffset)) -+ continue; // We have no space to copy shadow there. -+ - Value *ShadowPtr, *OriginPtr; - std::tie(ShadowPtr, OriginPtr) = - MSV.getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), kShadowTLSAlignment, - /*isStore*/ false); -- - IRB.CreateMemCpy(ShadowBase, kShadowTLSAlignment, ShadowPtr, - kShadowTLSAlignment, ArgSize); - if (MS.TrackOrigins) -@@ -4811,36 +4838,39 @@ - Value *ShadowBase, *OriginBase = nullptr; - switch (AK) { - case AK_GeneralPurpose: -- ShadowBase = -- getShadowPtrForVAArgument(A->getType(), IRB, GpOffset, 8); -+ ShadowBase = getShadowPtrForVAArgument(A->getType(), IRB, GpOffset); - if (MS.TrackOrigins) - OriginBase = getOriginPtrForVAArgument(IRB, GpOffset); - GpOffset += 8; -+ assert(GpOffset <= kParamTLSSize); - break; - case AK_FloatingPoint: -- ShadowBase = -- getShadowPtrForVAArgument(A->getType(), IRB, FpOffset, 16); -+ ShadowBase = getShadowPtrForVAArgument(A->getType(), IRB, FpOffset); - if (MS.TrackOrigins) - OriginBase = getOriginPtrForVAArgument(IRB, FpOffset); - FpOffset += 16; -+ assert(FpOffset <= kParamTLSSize); - break; - case AK_Memory: - if (IsFixed) - continue; - uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); -+ uint64_t AlignedSize = alignTo(ArgSize, 8); -+ unsigned BaseOffset = OverflowOffset; - ShadowBase = -- getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset, 8); -- if (MS.TrackOrigins) -+ getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset); -+ if (MS.TrackOrigins) { - OriginBase = getOriginPtrForVAArgument(IRB, OverflowOffset); -- OverflowOffset += alignTo(ArgSize, 8); -+ } -+ OverflowOffset += AlignedSize; -+ if (CleanUnusedTLS(ShadowBase, BaseOffset)) -+ continue; // We have no space to copy shadow there. 
- } - // Take fixed arguments into account for GpOffset and FpOffset, - // but don't actually store shadows for them. - // TODO(glider): don't call get*PtrForVAArgument() for them. - if (IsFixed) - continue; -- if (!ShadowBase) -- continue; - Value *Shadow = MSV.getShadow(A); - IRB.CreateAlignedStore(Shadow, ShadowBase, kShadowTLSAlignment); - if (MS.TrackOrigins) { -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/broadcast-elm-cross-splat-vec.ll b/llvm/test/CodeGen/X86/broadcast-elm-cross-splat-vec.ll ---- a/llvm/test/CodeGen/X86/broadcast-elm-cross-splat-vec.ll -+++ b/llvm/test/CodeGen/X86/broadcast-elm-cross-splat-vec.ll -@@ -1400,7 +1400,7 @@ - ; AVX-64-LABEL: f4xi64_i128: - ; AVX-64: # %bb.0: - ; AVX-64-NEXT: vextractf128 $1, %ymm0, %xmm1 --; AVX-64-NEXT: vmovdqa {{.*#+}} xmm2 = [0,1] -+; AVX-64-NEXT: vmovdqa {{.*#+}} xmm2 = [0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0] - ; AVX-64-NEXT: vpaddq %xmm2, %xmm1, %xmm1 - ; AVX-64-NEXT: vpaddq %xmm2, %xmm0, %xmm0 - ; AVX-64-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0 -@@ -1535,7 +1535,7 @@ - ; AVX-64-NEXT: vextractf128 $1, %ymm1, %xmm2 - ; AVX-64-NEXT: vmovdqa {{.*#+}} xmm3 = [2,3] - ; AVX-64-NEXT: vpaddq %xmm3, %xmm2, %xmm2 --; AVX-64-NEXT: vmovdqa {{.*#+}} xmm4 = [0,1] -+; AVX-64-NEXT: vmovdqa {{.*#+}} xmm4 = [0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0] - ; AVX-64-NEXT: vpaddq %xmm4, %xmm1, %xmm1 - ; AVX-64-NEXT: vinsertf128 $1, %xmm2, %ymm1, %ymm1 - ; AVX-64-NEXT: vextractf128 $1, %ymm0, %xmm2 -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-7.ll b/llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-7.ll ---- a/llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-7.ll -+++ b/llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-7.ll -@@ -2157,7 +2157,7 @@ - ; AVX2-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm10[2,3,0,1] - ; AVX2-SLOW-NEXT: vpblendw {{.*#+}} ymm10 = ymm10[0,1,2],ymm11[3],ymm10[4,5,6,7,8,9,10],ymm11[11],ymm10[12,13,14,15] - ; AVX2-SLOW-NEXT: vpshufb {{.*#+}} ymm11 = ymm10[2,3,2,3,2,3,2,3,8,9,8,9,6,7,4,5,18,19,18,19,18,19,18,19,24,25,24,25,22,23,20,21] --; AVX2-SLOW-NEXT: vmovdqa {{.*#+}} xmm10 = <255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0> -+; AVX2-SLOW-NEXT: vmovdqa {{.*#+}} xmm10 = [255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0] - ; AVX2-SLOW-NEXT: vpblendvb %ymm10, %ymm8, %ymm11, %ymm8 - ; AVX2-SLOW-NEXT: vpblendd {{.*#+}} ymm11 = ymm5[0,1],ymm6[2],ymm5[3,4,5],ymm6[6],ymm5[7] - ; AVX2-SLOW-NEXT: vextracti128 $1, %ymm11, %xmm12 -@@ -2329,7 +2329,7 @@ - ; AVX2-FAST-NEXT: vmovdqa {{.*#+}} ymm12 = <2,5,1,u,4,u,u,u> - ; AVX2-FAST-NEXT: vpermd %ymm11, %ymm12, %ymm11 - ; AVX2-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm11[2,3,2,3,2,3,2,3,8,9,0,1,6,7,8,9,18,19,18,19,18,19,18,19,24,25,16,17,22,23,24,25] --; AVX2-FAST-NEXT: vmovdqa {{.*#+}} xmm11 = <255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0> -+; AVX2-FAST-NEXT: vmovdqa {{.*#+}} xmm11 = [255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0] - ; AVX2-FAST-NEXT: vpblendvb %ymm11, %ymm10, %ymm12, %ymm10 - ; AVX2-FAST-NEXT: vpblendd {{.*#+}} ymm12 = ymm4[0,1],ymm6[2],ymm4[3,4,5],ymm6[6],ymm4[7] - ; AVX2-FAST-NEXT: vextracti128 $1, %ymm12, %xmm13 -@@ -2496,7 +2496,7 @@ - ; AVX2-FAST-PERLANE-NEXT: vpermq {{.*#+}} ymm12 = ymm11[2,3,0,1] - ; AVX2-FAST-PERLANE-NEXT: vpblendw {{.*#+}} ymm11 = ymm11[0,1,2],ymm12[3],ymm11[4,5,6,7,8,9,10],ymm12[11],ymm11[12,13,14,15] - ; AVX2-FAST-PERLANE-NEXT: vpshufb {{.*#+}} ymm12 = ymm11[2,3,2,3,2,3,2,3,8,9,8,9,6,7,4,5,18,19,18,19,18,19,18,19,24,25,24,25,22,23,20,21] --; AVX2-FAST-PERLANE-NEXT: vmovdqa {{.*#+}} 
xmm11 = <255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0> -+; AVX2-FAST-PERLANE-NEXT: vmovdqa {{.*#+}} xmm11 = [255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0] - ; AVX2-FAST-PERLANE-NEXT: vpblendvb %ymm11, %ymm8, %ymm12, %ymm8 - ; AVX2-FAST-PERLANE-NEXT: vpblendd {{.*#+}} ymm12 = ymm5[0,1],ymm6[2],ymm5[3,4,5],ymm6[6],ymm5[7] - ; AVX2-FAST-PERLANE-NEXT: vextracti128 $1, %ymm12, %xmm13 -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-load-i8-stride-5.ll b/llvm/test/CodeGen/X86/vector-interleaved-load-i8-stride-5.ll ---- a/llvm/test/CodeGen/X86/vector-interleaved-load-i8-stride-5.ll -+++ b/llvm/test/CodeGen/X86/vector-interleaved-load-i8-stride-5.ll -@@ -1685,7 +1685,7 @@ - ; AVX2-ONLY-NEXT: # ymm10 = mem[0,1,0,1] - ; AVX2-ONLY-NEXT: vpblendvb %ymm10, %ymm7, %ymm8, %ymm7 - ; AVX2-ONLY-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[u,u,u,u,u,u,u,u,u,u,u,u,u,1,6,11,16,21,26,31,20,25,30,19,24,29,u,u,u,u,u,u] --; AVX2-ONLY-NEXT: vmovdqa {{.*#+}} xmm10 = <255,255,255,255,255,255,255,255,255,255,255,255,255,0,0,0> -+; AVX2-ONLY-NEXT: vmovdqa {{.*#+}} xmm10 = [255,255,255,255,255,255,255,255,255,255,255,255,255,0,0,0] - ; AVX2-ONLY-NEXT: vpblendvb %ymm10, %ymm6, %ymm7, %ymm6 - ; AVX2-ONLY-NEXT: vmovdqa 144(%rdi), %xmm7 - ; AVX2-ONLY-NEXT: vpshufb {{.*#+}} xmm11 = xmm7[u,u,u,u,u,u,u,u,u,u],zero,zero,zero,xmm7[1,6,11] -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-3.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-3.ll ---- a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-3.ll -+++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-3.ll -@@ -1238,12 +1238,13 @@ - ; AVX512F-NEXT: vshufi64x2 {{.*#+}} zmm3 = zmm3[0,1,2,3],zmm6[4,5,6,7] - ; AVX512F-NEXT: vmovdqa (%rdx), %ymm6 - ; AVX512F-NEXT: vmovdqa 32(%rdx), %ymm7 --; AVX512F-NEXT: vmovdqa {{.*#+}} ymm9 = [128,128,10,11,128,128,128,128,12,13,128,128,128,128,14,15,128,128,128,128,16,17,128,128,128,128,18,19,128,128,128,128] --; AVX512F-NEXT: vpshufb %ymm9, %ymm7, %ymm10 --; AVX512F-NEXT: vmovdqa {{.*#+}} ymm11 = <5,5,u,6,6,u,7,7> --; AVX512F-NEXT: vpermd %ymm7, %ymm11, %ymm7 --; AVX512F-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm7, %ymm7 --; AVX512F-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm7 -+; AVX512F-NEXT: vmovdqa {{.*#+}} ymm9 = <5,5,u,6,6,u,7,7> -+; AVX512F-NEXT: vpermd %ymm7, %ymm9, %ymm9 -+; AVX512F-NEXT: vmovdqa {{.*#+}} ymm10 = [0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0] -+; AVX512F-NEXT: vpandn %ymm9, %ymm10, %ymm9 -+; AVX512F-NEXT: vmovdqa {{.*#+}} ymm10 = [128,128,10,11,128,128,128,128,12,13,128,128,128,128,14,15,128,128,128,128,16,17,128,128,128,128,18,19,128,128,128,128] -+; AVX512F-NEXT: vpshufb %ymm10, %ymm7, %ymm7 -+; AVX512F-NEXT: vinserti64x4 $1, %ymm9, %zmm7, %zmm7 - ; AVX512F-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm7 - ; AVX512F-NEXT: vmovdqa (%rdi), %ymm3 - ; AVX512F-NEXT: vpshufb %ymm5, %ymm3, %ymm3 -@@ -1258,7 +1259,7 @@ - ; AVX512F-NEXT: vpshufb %xmm2, %xmm0, %xmm0 - ; AVX512F-NEXT: vinserti128 $1, %xmm5, %ymm0, %ymm0 - ; AVX512F-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm0[0,1,2,3],zmm3[4,5,6,7] --; AVX512F-NEXT: vpshufb %ymm9, %ymm6, %ymm1 -+; AVX512F-NEXT: vpshufb %ymm10, %ymm6, %ymm1 - ; AVX512F-NEXT: vmovdqa {{.*#+}} ymm2 = - ; AVX512F-NEXT: vpermd %ymm6, %ymm2, %ymm2 - ; AVX512F-NEXT: vmovdqa {{.*#+}} ymm3 = [65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535,65535,0,65535] -diff -ruN --strip-trailing-cr 
a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-5.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-5.ll ---- a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-5.ll -+++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-5.ll -@@ -2831,15 +2831,15 @@ - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,1,0,0] - ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm9, %zmm2 - ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm11, %zmm2 --; AVX512F-SLOW-NEXT: vmovdqa (%r8), %ymm10 --; AVX512F-SLOW-NEXT: vmovdqa 32(%r8), %ymm9 -+; AVX512F-SLOW-NEXT: vmovdqa (%r8), %ymm9 -+; AVX512F-SLOW-NEXT: vmovdqa 32(%r8), %ymm10 - ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm11 = [128,128,128,128,12,13,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128] --; AVX512F-SLOW-NEXT: vpshufb %ymm11, %ymm9, %ymm4 --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,1,1,1] -+; AVX512F-SLOW-NEXT: vpshufb %ymm11, %ymm10, %ymm4 -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,1,1,1] - ; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} ymm21 = [65535,65535,65535,0,65535,65535,65535,65535,0,65535,65535,65535,65535,0,65535,65535] --; AVX512F-SLOW-NEXT: vpandnq %ymm9, %ymm21, %ymm9 --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm9, %zmm9 --; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm9 -+; AVX512F-SLOW-NEXT: vpandnq %ymm10, %ymm21, %ymm10 -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm10, %zmm10 -+; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm10 - ; AVX512F-SLOW-NEXT: vmovdqa (%rdx), %xmm2 - ; AVX512F-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm4 = xmm12[0],xmm2[0],xmm12[1],xmm2[1],xmm12[2],xmm2[2],xmm12[3],xmm2[3] - ; AVX512F-SLOW-NEXT: vpshufb %xmm13, %xmm4, %xmm4 -@@ -2860,7 +2860,7 @@ - ; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm7 = [65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535] - ; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm2, %zmm7, %zmm4 - ; AVX512F-SLOW-NEXT: vpbroadcastq (%r8), %ymm2 --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm10[0,1,1,1] -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm9[0,1,1,1] - ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm2, %zmm2 - ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm2 - ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm18[0,1,2,1,4,5,6,5] -@@ -2909,15 +2909,16 @@ - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] - ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm1, %zmm5, %zmm1 - ; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm16, %zmm1 --; AVX512F-SLOW-NEXT: vpshufb %ymm11, %ymm10, %ymm0 --; AVX512F-SLOW-NEXT: vpbroadcastq 16(%r8), %ymm3 --; AVX512F-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm3, %ymm3 --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm0, %zmm0 -+; AVX512F-SLOW-NEXT: vpbroadcastq 16(%r8), %ymm0 -+; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm3 = [65535,0,65535,65535,65535,65535,0,65535,65535,65535,65535,0,65535,65535,65535,65535] -+; AVX512F-SLOW-NEXT: vpandn %ymm0, %ymm3, %ymm0 -+; AVX512F-SLOW-NEXT: vpshufb %ymm11, %ymm9, %ymm3 -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm3, %zmm0 - ; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm0 - ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm0, 64(%r9) - ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm4, 256(%r9) - ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm2, (%r9) --; AVX512F-SLOW-NEXT: vmovdqa64 %zmm9, 
192(%r9) -+; AVX512F-SLOW-NEXT: vmovdqa64 %zmm10, 192(%r9) - ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm19, 128(%r9) - ; AVX512F-SLOW-NEXT: vzeroupper - ; AVX512F-SLOW-NEXT: retq -@@ -3018,10 +3019,11 @@ - ; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm7, %zmm13, %zmm7 - ; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} zmm20 = [65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535,65535,0,0,65535,65535] - ; AVX512F-FAST-NEXT: vpternlogq $226, %zmm3, %zmm20, %zmm7 --; AVX512F-FAST-NEXT: vmovdqa64 %ymm24, %ymm3 --; AVX512F-FAST-NEXT: vpshufb %ymm3, %ymm0, %ymm0 - ; AVX512F-FAST-NEXT: vpbroadcastq 16(%r8), %ymm3 --; AVX512F-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm3, %ymm3 -+; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm13 = [65535,0,65535,65535,65535,65535,0,65535,65535,65535,65535,0,65535,65535,65535,65535] -+; AVX512F-FAST-NEXT: vpandn %ymm3, %ymm13, %ymm3 -+; AVX512F-FAST-NEXT: vmovdqa64 %ymm24, %ymm11 -+; AVX512F-FAST-NEXT: vpshufb %ymm11, %ymm0, %ymm0 - ; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm0, %zmm0 - ; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm0 - ; AVX512F-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm13 = [30,31,28,29,26,27,30,31,30,31,28,29,30,31,28,29,30,31,28,29,26,27,30,31,30,31,28,29,30,31,28,29] -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll ---- a/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll -+++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i16-stride-7.ll -@@ -2522,7 +2522,8 @@ - ; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm12, %xmm10, %xmm12 - ; AVX512F-ONLY-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax - ; AVX512F-ONLY-FAST-NEXT: vpbroadcastd 8(%rax), %ymm10 --; AVX512F-ONLY-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm10, %ymm10 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} ymm20 = [65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535] -+; AVX512F-ONLY-FAST-NEXT: vpandnq %ymm10, %ymm20, %ymm10 - ; AVX512F-ONLY-FAST-NEXT: vpunpcklwd {{.*#+}} xmm8 = xmm8[0],xmm14[0],xmm8[1],xmm14[1],xmm8[2],xmm14[2],xmm8[3],xmm14[3] - ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm14 = xmm8[0,1,2,3,8,9,10,11,14,15,12,13,14,15,12,13] - ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm8 = xmm8[0,1,2,3,6,7,4,5,6,7,4,5,12,13,14,15] -@@ -2788,7 +2789,8 @@ - ; AVX512DQ-FAST-NEXT: vpshufb %xmm12, %xmm10, %xmm12 - ; AVX512DQ-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax - ; AVX512DQ-FAST-NEXT: vpbroadcastd 8(%rax), %ymm10 --; AVX512DQ-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm10, %ymm10 -+; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} ymm20 = [65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535] -+; AVX512DQ-FAST-NEXT: vpandnq %ymm10, %ymm20, %ymm10 - ; AVX512DQ-FAST-NEXT: vpunpcklwd {{.*#+}} xmm8 = xmm8[0],xmm14[0],xmm8[1],xmm14[1],xmm8[2],xmm14[2],xmm8[3],xmm14[3] - ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm14 = xmm8[0,1,2,3,8,9,10,11,14,15,12,13,14,15,12,13] - ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm8 = xmm8[0,1,2,3,6,7,4,5,6,7,4,5,12,13,14,15] -@@ -5250,336 +5252,330 @@ - ; - ; AVX512F-ONLY-SLOW-LABEL: store_i16_stride7_vf32: - ; AVX512F-ONLY-SLOW: # %bb.0: --; AVX512F-ONLY-SLOW-NEXT: subq $632, %rsp # imm = 0x278 -+; AVX512F-ONLY-SLOW-NEXT: subq $648, %rsp # imm = 0x288 - ; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rcx), %ymm1 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm9 = 
[128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128] --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm9, %ymm1, %ymm0 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm1, %ymm16 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdx), %ymm2 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm13 = --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm13, %ymm2, %ymm1 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm29 --; AVX512F-ONLY-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rsi), %ymm1 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm12 = [128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128,128,128,128,128] --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm12, %ymm1, %ymm0 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa %ymm1, %ymm15 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdi), %ymm2 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm14 = <12,13,14,15,128,128,u,u,u,u,u,u,u,u,u,u,16,17,128,128,u,u,u,u,u,u,u,u,16,17,18,19> --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm14, %ymm2, %ymm1 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm17 --; AVX512F-ONLY-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r9), %ymm10 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm1 = [128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128] --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm1, %ymm10, %ymm2 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r8), %ymm11 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm0 = --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm0, %ymm11, %ymm3 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm1, (%rsp) # 32-byte Spill -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm0 = [128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128] -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm0, %ymm1, %ymm2 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdx), %ymm4 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm1 = -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm1, %ymm4, %ymm3 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm4, %ymm30 - ; AVX512F-ONLY-SLOW-NEXT: vpor %ymm2, %ymm3, %ymm2 - ; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r9), %xmm3 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r8), %xmm6 --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm2 = ymm11[3,3,3,3,7,7,7,7] --; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm4 = ymm10[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[2,2,2,3,6,6,6,7] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm4 = ymm4[0,1],ymm2[2],ymm4[3,4],ymm2[5],ymm4[6,7,8,9],ymm2[10],ymm4[11,12],ymm2[13],ymm4[14,15] --; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm2 = xmm6[0],xmm3[0],xmm6[1],xmm3[1],xmm6[2],xmm3[2],xmm6[3],xmm3[3] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm7 = [4,5,4,5,4,5,6,7,16,17,16,17,16,17,17,19] --; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm5 = xmm2[0,1,3,2,4,5,6,7] --; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm5, %zmm4, %zmm7 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm7, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rcx), %xmm8 --; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm3 = xmm6[4],xmm3[4],xmm6[5],xmm3[5],xmm6[6],xmm3[6],xmm6[7],xmm3[7] --; 
AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,4,5,7,6] --; AVX512F-ONLY-SLOW-NEXT: vpshufb {{.*#+}} xmm3 = xmm3[0,1,2,3,4,5,4,5,6,7,10,11,8,9,10,11] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm4 = <16,18,19,19,19,19,u,u,0,1,0,1,2,3,2,3> --; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm2, %zmm3, %zmm4 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rsi), %ymm2 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm4 = [128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128,128,128,128,128] -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm4, %ymm2, %ymm2 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdi), %ymm5 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm9 = <12,13,14,15,128,128,u,u,u,u,u,u,u,u,u,u,16,17,128,128,u,u,u,u,u,u,u,u,16,17,18,19> -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm9, %ymm5, %ymm3 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm5, %ymm31 -+; AVX512F-ONLY-SLOW-NEXT: vpor %ymm2, %ymm3, %ymm2 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r9), %ymm15 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm2 = [128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128] -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm2, %ymm15, %ymm5 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r8), %ymm6 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm3 = -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm3, %ymm6, %ymm7 -+; AVX512F-ONLY-SLOW-NEXT: vpor %ymm5, %ymm7, %ymm5 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm5, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r9), %xmm7 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r8), %xmm8 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rsi), %ymm5 -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm4, %ymm5, %ymm4 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdi), %ymm11 -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm9, %ymm11, %ymm9 -+; AVX512F-ONLY-SLOW-NEXT: vpor %ymm4, %ymm9, %ymm4 - ; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdi), %ymm3 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rsi), %ymm6 --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm12, %ymm6, %ymm2 --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm14, %ymm3, %ymm4 --; AVX512F-ONLY-SLOW-NEXT: vpor %ymm2, %ymm4, %ymm2 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rcx), %ymm12 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdx), %ymm7 --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm9, %ymm12, %ymm2 --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm13, %ymm7, %ymm4 --; AVX512F-ONLY-SLOW-NEXT: vpor %ymm2, %ymm4, %ymm2 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r9), %ymm13 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r8), %ymm14 --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm1, %ymm13, %ymm1 --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm0, %ymm14, %ymm0 --; AVX512F-ONLY-SLOW-NEXT: vpor %ymm1, %ymm0, %ymm0 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rcx), %ymm13 -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm0, %ymm13, %ymm0 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdx), %ymm14 -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm1, %ymm14, %ymm1 -+; AVX512F-ONLY-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 - ; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdx), %xmm0 --; 
AVX512F-ONLY-SLOW-NEXT: vpbroadcastq {{.*#+}} xmm2 = [6,7,4,5,0,0,8,9,6,7,4,5,0,0,8,9] --; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm2, %xmm8, %xmm1 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm2, %xmm20 --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm2 = xmm0[1,1,2,2] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm1 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] --; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm2 = xmm0[0],xmm8[0],xmm0[1],xmm8[1],xmm0[2],xmm8[2],xmm0[3],xmm8[3] --; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm2 = xmm2[0,1,3,2,4,5,6,7] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm4 = --; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm2, %zmm1, %zmm4 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r9), %xmm1 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r8), %xmm2 --; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm4 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3] --; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm5 = xmm4[0,1,2,3,4,5,7,6] --; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm4 = xmm4[0,1,3,2,4,5,6,7] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm9 = <0,1,0,1,0,1,1,3,16,18,19,19,19,19,u,u> --; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm5, %zmm4, %zmm9 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r9), %ymm4 -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm2, %ymm4, %ymm1 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%r8), %ymm0 -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm3, %ymm0, %ymm2 -+; AVX512F-ONLY-SLOW-NEXT: vpor %ymm1, %ymm2, %ymm1 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm6[3,3,3,3,7,7,7,7] -+; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm2 = ymm15[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm2 = ymm2[2,2,2,3,6,6,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm1[2],ymm2[3,4],ymm1[5],ymm2[6,7,8,9],ymm1[10],ymm2[11,12],ymm1[13],ymm2[14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm8[0],xmm7[0],xmm8[1],xmm7[1],xmm8[2],xmm7[2],xmm8[3],xmm7[3] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm9 = [4,5,4,5,4,5,6,7,16,17,16,17,16,17,17,19] -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm3 = xmm1[0,1,3,2,4,5,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm3, %zmm2, %zmm9 - ; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm9, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vprold $16, %ymm10, %ymm4 --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm11[1,2,2,3,5,6,6,7] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm4 = ymm5[0,1],ymm4[2],ymm5[3,4],ymm4[5],ymm5[6,7,8,9],ymm4[10],ymm5[11,12],ymm4[13],ymm5[14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm11[0,0,2,1,4,4,6,5] --; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm9 = ymm10[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm9 = ymm9[0,0,0,0,4,4,4,4] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm9[0,1,2],ymm5[3],ymm9[4,5],ymm5[6],ymm9[7,8,9,10],ymm5[11],ymm9[12,13],ymm5[14],ymm9[15] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm27 = [2,2,3,3,10,9,11,10] --; AVX512F-ONLY-SLOW-NEXT: vpermi2q %zmm4, %zmm5, %zmm27 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rcx), %xmm9 -+; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm2 = xmm8[4],xmm7[4],xmm8[5],xmm7[5],xmm8[6],xmm7[6],xmm8[7],xmm7[7] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} xmm10 = [0,1,2,3,4,5,4,5,6,7,10,11,8,9,10,11] -+; AVX512F-ONLY-SLOW-NEXT: vpshufhw 
{{.*#+}} xmm1 = xmm1[0,1,2,3,4,5,7,6] -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm10, %xmm2, %xmm2 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm3 = <16,18,19,19,19,19,u,u,0,1,0,1,2,3,2,3> -+; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm1, %zmm2, %zmm3 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdx), %xmm2 -+; AVX512F-ONLY-SLOW-NEXT: vpbroadcastq {{.*#+}} xmm7 = [6,7,4,5,0,0,8,9,6,7,4,5,0,0,8,9] -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm7, %xmm9, %xmm1 -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm3 = xmm2[1,1,2,2] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm1 = xmm3[0],xmm1[1],xmm3[2,3],xmm1[4],xmm3[5,6],xmm1[7] -+; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm3 = xmm2[0],xmm9[0],xmm2[1],xmm9[1],xmm2[2],xmm9[2],xmm2[3],xmm9[3] -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm3 = xmm3[0,1,3,2,4,5,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm8 = -+; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm3, %zmm1, %zmm8 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm8, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r9), %xmm3 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%r8), %xmm12 -+; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm12[0],xmm3[0],xmm12[1],xmm3[1],xmm12[2],xmm3[2],xmm12[3],xmm3[3] -+; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm8 = xmm1[0,1,2,3,4,5,7,6] -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm1 = xmm1[0,1,3,2,4,5,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm26 = <0,1,0,1,0,1,1,3,16,18,19,19,19,19,u,u> -+; AVX512F-ONLY-SLOW-NEXT: vpermi2d %zmm8, %zmm1, %zmm26 -+; AVX512F-ONLY-SLOW-NEXT: vprold $16, %ymm15, %ymm1 -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm8 = ymm6[1,2,2,3,5,6,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm1 = ymm8[0,1],ymm1[2],ymm8[3,4],ymm1[5],ymm8[6,7,8,9],ymm1[10],ymm8[11,12],ymm1[13],ymm8[14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,2,1,4,4,6,5] -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm8 = ymm15[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm8 = ymm8[0,0,0,0,4,4,4,4] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm6 = ymm8[0,1,2],ymm6[3],ymm8[4,5],ymm6[6],ymm8[7,8,9,10],ymm6[11],ymm8[12,13],ymm6[14],ymm8[15] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm28 = [2,2,3,3,10,9,11,10] -+; AVX512F-ONLY-SLOW-NEXT: vpermi2q %zmm1, %zmm6, %zmm28 - ; AVX512F-ONLY-SLOW-NEXT: movq {{[0-9]+}}(%rsp), %rax --; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm0[4],xmm8[4],xmm0[5],xmm8[5],xmm0[6],xmm8[6],xmm0[7],xmm8[7] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm0, %xmm31 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rax), %ymm4 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm4, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 8(%rax), %ymm0 --; AVX512F-ONLY-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm5 = [12,13,128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128] --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm5, %ymm4, %ymm4 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm5, %ymm28 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm0, %zmm0 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm2 = xmm2[4],xmm1[4],xmm2[5],xmm1[5],xmm2[6],xmm1[6],xmm2[7],xmm1[7] --; AVX512F-ONLY-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = 
ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,u,u,24,25,26,27,u,u,26,27,26,27] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm7[2,2,2,2,6,6,6,6] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm0[0,1],ymm1[2],ymm0[3,4],ymm1[5],ymm0[6,7,8,9],ymm1[10],ymm0[11,12],ymm1[13],ymm0[14,15] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm0, %ymm30 --; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm12[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm7[3,3,3,3,7,7,7,7] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm1 = xmm2[4],xmm9[4],xmm2[5],xmm9[5],xmm2[6],xmm9[6],xmm2[7],xmm9[7] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm1, %xmm25 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rax), %ymm8 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm8, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 8(%rax), %ymm1 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm2 = [65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535] -+; AVX512F-ONLY-SLOW-NEXT: vpandn %ymm1, %ymm2, %ymm1 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm6 = [12,13,128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128] -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm6, %ymm8, %ymm2 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm1, %zmm1 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm12 = xmm12[4],xmm3[4],xmm12[5],xmm3[5],xmm12[6],xmm3[6],xmm12[7],xmm3[7] -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm1 = ymm5[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[0,0,2,1,4,4,6,5] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm11[1,1,1,1,5,5,5,5] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm3[2],ymm1[3,4],ymm3[5],ymm1[6,7,8,9],ymm3[10],ymm1[11,12],ymm3[13],ymm1[14,15] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm1 = ymm13[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[0,0,0,0,4,4,4,4] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[0,1,1,3,4,5,5,7] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm1 = ymm3[0,1],ymm1[2],ymm3[3,4],ymm1[5],ymm3[6,7,8,9],ymm1[10],ymm3[11,12],ymm1[13],ymm3[14,15] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rax), %ymm1 -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm1[0,1,1,3,4,5,5,7] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm3 = ymm3[0,2,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm8 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] -+; AVX512F-ONLY-SLOW-NEXT: vpandn %ymm3, %ymm8, %ymm3 -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm6, %ymm1, %ymm6 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm6, %zmm2 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm3 = ymm4[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[0,0,0,0,4,4,4,4] -+; 
AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm0[0,0,2,1,4,4,6,5] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0,1,2],ymm6[3],ymm3[4,5],ymm6[6],ymm3[7,8,9,10],ymm6[11],ymm3[12,13],ymm6[14],ymm3[15] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm27 -+; AVX512F-ONLY-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm15 = [22,23,26,27,0,0,24,25,26,27,0,0,26,27,26,27,22,23,26,27,0,0,24,25,26,27,0,0,26,27,26,27] -+; AVX512F-ONLY-SLOW-NEXT: # ymm15 = mem[0,1,0,1] -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm15, %ymm13, %ymm3 -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm14[2,2,2,2,6,6,6,6] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0,1],ymm6[2],ymm3[3,4],ymm6[5],ymm3[6,7,8,9],ymm6[10],ymm3[11,12],ymm6[13],ymm3[14,15] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm22 -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[3,3,3,3,7,7,7,7] -+; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm6 = ymm13[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[2,2,2,2,6,6,6,6] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0],ymm6[1],ymm3[2,3],ymm6[4],ymm3[5,6,7,8],ymm6[9],ymm3[10,11],ymm6[12],ymm3[13,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm17 -+; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm3 = ymm5[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[2,2,2,2,6,6,6,6] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm11[2,2,2,2,6,6,6,6] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm6[0],ymm3[1],ymm6[2,3],ymm3[4],ymm6[5,6,7,8],ymm3[9],ymm6[10,11],ymm3[12],ymm6[13,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm16 -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm11[3,3,3,3,7,7,7,7] -+; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm5 = ymm5[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm5[2,2,2,2,6,6,6,6] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm13 = ymm3[0,1,2],ymm5[3],ymm3[4,5],ymm5[6],ymm3[7,8,9,10],ymm5[11],ymm3[12,13],ymm5[14],ymm3[15] -+; AVX512F-ONLY-SLOW-NEXT: vprold $16, %ymm4, %ymm3 -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm0[1,2,2,3,5,6,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm5[0,1],ymm3[2],ymm5[3,4],ymm3[5],ymm5[6,7,8,9],ymm3[10],ymm5[11,12],ymm3[13],ymm5[14,15] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm24 -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[3,3,3,3,7,7,7,7] -+; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm3 = ymm4[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[2,2,2,3,6,6,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm3[0,1],ymm0[2],ymm3[3,4],ymm0[5],ymm3[6,7,8,9],ymm0[10],ymm3[11,12],ymm0[13],ymm3[14,15] - ; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm0, %ymm23 --; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm6[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vbroadcasti64x4 {{.*#+}} zmm0 = [6,5,0,0,7,6,0,7,6,5,0,0,7,6,0,7] -+; AVX512F-ONLY-SLOW-NEXT: # zmm0 = mem[0,1,2,3,0,1,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vpermd %zmm1, %zmm0, %zmm29 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu (%rsp), %ymm11 # 32-byte Reload -+; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm11[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] - ; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm3[2,2,2,2,6,6,6,6] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm10 = 
ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm6[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm30[3,3,3,3,7,7,7,7] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm9 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm14 # 32-byte Reload -+; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm14[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] - ; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm3[3,3,3,3,7,7,7,7] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm1[0,1,2],ymm0[3],ymm1[4,5],ymm0[6],ymm1[7,8,9,10],ymm0[11],ymm1[12,13],ymm0[14],ymm1[15] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm0, %ymm19 --; AVX512F-ONLY-SLOW-NEXT: vprold $16, %ymm13, %ymm0 --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm14[1,2,2,3,5,6,6,7] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm1[0,1],ymm0[2],ymm1[3,4],ymm0[5],ymm1[6,7,8,9],ymm0[10],ymm1[11,12],ymm0[13],ymm1[14,15] --; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm13[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,3,6,6,6,7] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm14[3,3,3,3,7,7,7,7] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm0[0,1],ymm1[2],ymm0[3,4],ymm1[5],ymm0[6,7,8,9],ymm1[10],ymm0[11,12],ymm1[13],ymm0[14,15] --; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm16, %ymm4 --; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm4[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm29[3,3,3,3,7,7,7,7] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm5, %ymm18 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa %ymm15, %ymm8 --; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm15[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm17[3,3,3,3,7,7,7,7] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm1[0,1,2],ymm0[3],ymm1[4,5],ymm0[6],ymm1[7,8,9,10],ymm0[11],ymm1[12,13],ymm0[14],ymm1[15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm31[3,3,3,3,7,7,7,7] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm8 = ymm1[0,1,2],ymm0[3],ymm1[4,5],ymm0[6],ymm1[7,8,9,10],ymm0[11],ymm1[12,13],ymm0[14],ymm1[15] - ; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdi), %xmm0 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rsi), %xmm9 --; AVX512F-ONLY-SLOW-NEXT: vprold $16, %xmm9, %xmm1 --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm15 = xmm0[1,1,2,3] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm1 = xmm15[0,1],xmm1[2],xmm15[3,4],xmm1[5],xmm15[6,7] --; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm0[0],xmm9[0],xmm0[1],xmm9[1],xmm0[2],xmm9[2],xmm0[3],xmm9[3] --; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = 
xmm9[4],xmm0[4],xmm9[5],xmm0[5],xmm9[6],xmm0[6],xmm9[7],xmm0[7] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rcx), %xmm9 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm20, %xmm0 --; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm0, %xmm9, %xmm0 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdx), %xmm15 --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm11 = xmm15[1,1,2,2] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm0 = xmm11[0],xmm0[1],xmm11[2,3],xmm0[4],xmm11[5,6],xmm0[7] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm0, %ymm26 --; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm0 = xmm15[0],xmm9[0],xmm15[1],xmm9[1],xmm15[2],xmm9[2],xmm15[3],xmm9[3] --; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm9 = xmm15[4],xmm9[4],xmm15[5],xmm9[5],xmm15[6],xmm9[6],xmm15[7],xmm9[7] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm9, %xmm25 --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[1,1,1,1,5,5,5,5] --; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm6 = ymm6[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,2,1,4,4,6,5] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm6[0,1],ymm3[2],ymm6[3,4],ymm3[5],ymm6[6,7,8,9],ymm3[10],ymm6[11,12],ymm3[13],ymm6[14,15] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm3, %ymm24 --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm7[0,1,1,3,4,5,5,7] --; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm6 = ymm12[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,0,0,4,4,4,4] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1],ymm6[2],ymm3[3,4],ymm6[5],ymm3[6,7,8,9],ymm6[10],ymm3[11,12],ymm6[13],ymm3[14,15] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm3, %ymm22 --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[0,0,2,1,4,4,6,5] --; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm6 = ymm13[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,0,0,4,4,4,4] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm6[0,1,2],ymm3[3],ymm6[4,5],ymm3[6],ymm6[7,8,9,10],ymm3[11],ymm6[12,13],ymm3[14],ymm6[15] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm3, %ymm21 --; AVX512F-ONLY-SLOW-NEXT: vbroadcasti64x4 {{.*#+}} zmm3 = [6,5,0,0,7,6,0,7,6,5,0,0,7,6,0,7] --; AVX512F-ONLY-SLOW-NEXT: # zmm3 = mem[0,1,2,3,0,1,2,3] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rax), %ymm6 --; AVX512F-ONLY-SLOW-NEXT: vpermd %zmm6, %zmm3, %zmm3 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm28, %ymm7 --; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm7, %ymm6, %ymm11 --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,1,1,3,4,5,5,7] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,2,2,3] --; AVX512F-ONLY-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm6, %ymm6 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm6, %zmm11, %zmm6 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} xmm7 = [0,1,2,3,4,5,4,5,6,7,10,11,8,9,10,11] --; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm7, %xmm2, %xmm11 --; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm7, %xmm1, %xmm1 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdi), %xmm13 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rsi), %xmm14 --; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm2 = xmm13[0],xmm14[0],xmm13[1],xmm14[1],xmm13[2],xmm14[2],xmm13[3],xmm14[3] --; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm7, %xmm2, %xmm2 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm20 --; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm15 = xmm14[4],xmm13[4],xmm14[5],xmm13[5],xmm14[6],xmm13[6],xmm14[7],xmm13[7] --; 
AVX512F-ONLY-SLOW-NEXT: vprold $16, %xmm14, %xmm14 --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm13 = xmm13[1,1,2,3] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm2 = xmm13[0,1],xmm14[2],xmm13[3,4],xmm14[5],xmm13[6,7] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm16 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa %ymm4, %ymm2 --; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm14 = ymm4[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm14 = ymm14[0,0,0,0,4,4,4,4] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm29, %ymm12 --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm9 = ymm29[0,1,1,3,4,5,5,7] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm14 = ymm9[0,1],ymm14[2],ymm9[3,4],ymm14[5],ymm9[6,7,8,9],ymm14[10],ymm9[11,12],ymm14[13],ymm9[14,15] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm31, %xmm4 --; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm9 = xmm4[0,2,3,3,4,5,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rsi), %xmm1 -+; AVX512F-ONLY-SLOW-NEXT: vprold $16, %xmm1, %xmm3 -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm6 = xmm0[1,1,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm2 = xmm6[0,1],xmm3[2],xmm6[3,4],xmm3[5],xmm6[6,7] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm2, %ymm21 -+; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm3 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3] -+; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm1[4],xmm0[4],xmm1[5],xmm0[5],xmm1[6],xmm0[6],xmm1[7],xmm0[7] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm0, %xmm20 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rcx), %xmm1 -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm7, %xmm1, %xmm0 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa 32(%rdx), %xmm6 -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm7 = xmm6[1,1,2,2] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm0 = xmm7[0],xmm0[1],xmm7[2,3],xmm0[4],xmm7[5,6],xmm0[7] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %ymm0, %ymm19 -+; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm5 = xmm6[0],xmm1[0],xmm6[1],xmm1[1],xmm6[2],xmm1[2],xmm6[3],xmm1[3] -+; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm6[4],xmm1[4],xmm6[5],xmm1[5],xmm6[6],xmm1[6],xmm6[7],xmm1[7] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm0, %xmm18 -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm10, %xmm12, %xmm2 -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm10, %xmm3, %xmm4 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rdi), %xmm3 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa (%rsi), %xmm6 -+; AVX512F-ONLY-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm7 = xmm3[0],xmm6[0],xmm3[1],xmm6[1],xmm3[2],xmm6[2],xmm3[3],xmm6[3] -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %xmm10, %xmm7, %xmm10 -+; AVX512F-ONLY-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm6[4],xmm3[4],xmm6[5],xmm3[5],xmm6[6],xmm3[6],xmm6[7],xmm3[7] -+; AVX512F-ONLY-SLOW-NEXT: vprold $16, %xmm6, %xmm6 -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm3 = xmm3[1,1,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} xmm7 = xmm3[0,1],xmm6[2],xmm3[3,4],xmm6[5],xmm3[6,7] -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm3 = ymm11[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[0,0,0,0,4,4,4,4] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm30[0,1,1,3,4,5,5,7] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm6 = ymm6[0,1],ymm3[2],ymm6[3,4],ymm3[5],ymm6[6,7,8,9],ymm3[10],ymm6[11,12],ymm3[13],ymm6[14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufb %ymm15, %ymm11, %ymm3 -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm30[2,2,2,2,6,6,6,6] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm15 = 
ymm3[0,1],ymm12[2],ymm3[3,4],ymm12[5],ymm3[6,7,8,9],ymm12[10],ymm3[11,12],ymm12[13],ymm3[14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm3 = ymm14[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[0,0,2,1,4,4,6,5] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm31[1,1,1,1,5,5,5,5] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1],ymm12[2],ymm3[3,4],ymm12[5],ymm3[6,7,8,9],ymm12[10],ymm3[11,12],ymm12[13],ymm3[14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm31[2,2,2,2,6,6,6,6] -+; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm14 = ymm14[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm14 = ymm14[2,2,2,2,6,6,6,6] -+; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm12 = ymm12[0],ymm14[1],ymm12[2,3],ymm14[4],ymm12[5,6,7,8],ymm14[9],ymm12[10,11],ymm14[12],ymm12[13,14,15] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm25, %xmm1 -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm14 = xmm1[0,2,3,3,4,5,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm14[0,0,2,1] -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm0 = xmm0[2,1,2,3,4,5,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,5,5,4] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm1 = ymm0[0,0,1,3] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm2 = ymm2[0,0,1,1] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm22[2,2,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm31 = ymm17[0,2,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm16[0,2,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,1,3,3] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,2,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm5 = xmm5[0,1,3,2,4,5,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm5 = xmm5[0,0,1,1] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,1,3,3] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm4[0,0,1,1] -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm31, %zmm30, %zmm30 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm13, %zmm0, %zmm0 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm13 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $184, %zmm30, %zmm13, %zmm0 -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,1,1,3] -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm9, %zmm5 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm8, %zmm4 -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $226, %zmm5, %zmm13, %zmm4 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm14, %zmm5 # 32-byte Folded Reload -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm8 # 32-byte Folded Reload -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm31 = [65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535] -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $226, %zmm5, %zmm31, %zmm8 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm1 # 32-byte Folded Reload -+; AVX512F-ONLY-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm1 = zmm2[0,1,2,3],zmm1[4,5,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vpermq $234, {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Folded Reload -+; AVX512F-ONLY-SLOW-NEXT: # ymm2 = mem[2,2,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vpermq $182, 
{{[-0-9]+}}(%r{{[sb]}}p), %ymm5 # 32-byte Folded Reload -+; AVX512F-ONLY-SLOW-NEXT: # ymm5 = mem[2,1,3,2] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm27[2,2,3,3] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm24[2,1,3,2] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm16 = ymm23[2,2,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vpshufd $254, {{[-0-9]+}}(%r{{[sb]}}p), %ymm17 # 32-byte Folded Reload -+; AVX512F-ONLY-SLOW-NEXT: # ymm17 = mem[2,3,3,3,6,7,7,7] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm22 = ymm21[0,0,2,1] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm20, %xmm9 -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm11 = xmm9[2,1,2,3,4,5,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm11 = xmm11[0,1,2,3,4,5,5,4] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm11[0,0,1,3] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm19[0,0,1,1] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm18, %xmm9 -+; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm9 = xmm9[0,2,3,3,4,5,6,7] - ; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,0,2,1] --; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm15 = xmm15[2,1,2,3,4,5,6,7] --; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm15 = xmm15[0,1,2,3,4,5,5,4] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm15 = ymm15[0,0,1,3] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm11[0,0,1,1] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm28 = ymm30[2,2,2,3] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm29 = ymm23[0,2,2,3] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,2,2,3] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm19[2,1,3,3] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm18[0,2,2,3] --; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm0 = xmm0[0,1,3,2,4,5,6,7] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,1,1] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm31 = ymm5[2,1,3,3] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm1 = ymm1[0,0,1,1] --; AVX512F-ONLY-SLOW-NEXT: vpshufb {{.*#+}} ymm2 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,u,u,24,25,26,27,u,u,26,27,26,27] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm12[2,2,2,2,6,6,6,6] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm4[2],ymm2[3,4],ymm4[5],ymm2[6,7,8,9],ymm4[10],ymm2[11,12],ymm4[13],ymm2[14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} ymm4 = ymm8[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[0,0,2,1,4,4,6,5] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm17[1,1,1,1,5,5,5,5] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm4 = ymm4[0,1],ymm5[2],ymm4[3,4],ymm5[5],ymm4[6,7,8,9],ymm5[10],ymm4[11,12],ymm5[13],ymm4[14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm17[2,2,2,2,6,6,6,6] --; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} ymm12 = ymm8[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] --; AVX512F-ONLY-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm12[2,2,2,2,6,6,6,6] --; AVX512F-ONLY-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm5[0],ymm12[1],ymm5[2,3],ymm12[4],ymm5[5,6,7,8],ymm12[9],ymm5[10,11],ymm12[12],ymm5[13,14,15] --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm29, %zmm28, %zmm12 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm7 --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm10 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $184, %zmm12, %zmm10, %zmm7 --; AVX512F-ONLY-SLOW-NEXT: vpermq 
{{.*#+}} ymm0 = ymm0[0,1,1,3] --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm30, %zmm0 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm1, %zmm31, %zmm1 --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm10, %zmm1 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm9, %zmm0 # 32-byte Folded Reload --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm15, %zmm9 # 32-byte Folded Reload --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm10 = [65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535] --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm10, %zmm9 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm0 # 32-byte Folded Reload --; AVX512F-ONLY-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm11[0,1,2,3],zmm0[4,5,6,7] --; AVX512F-ONLY-SLOW-NEXT: vpermq $182, {{[-0-9]+}}(%r{{[sb]}}p), %ymm11 # 32-byte Folded Reload --; AVX512F-ONLY-SLOW-NEXT: # ymm11 = mem[2,1,3,2] --; AVX512F-ONLY-SLOW-NEXT: vpermq $234, {{[-0-9]+}}(%r{{[sb]}}p), %ymm12 # 32-byte Folded Reload --; AVX512F-ONLY-SLOW-NEXT: # ymm12 = mem[2,2,2,3] --; AVX512F-ONLY-SLOW-NEXT: vpshufd $254, {{[-0-9]+}}(%r{{[sb]}}p), %ymm15 # 32-byte Folded Reload --; AVX512F-ONLY-SLOW-NEXT: # ymm15 = mem[2,3,3,3,6,7,7,7] --; AVX512F-ONLY-SLOW-NEXT: vpermq $96, {{[-0-9]+}}(%r{{[sb]}}p), %ymm17 # 32-byte Folded Reload --; AVX512F-ONLY-SLOW-NEXT: # ymm17 = mem[0,0,2,1] --; AVX512F-ONLY-SLOW-NEXT: vpshuflw $230, {{[-0-9]+}}(%r{{[sb]}}p), %xmm8 # 16-byte Folded Reload --; AVX512F-ONLY-SLOW-NEXT: # xmm8 = mem[2,1,2,3,4,5,6,7] --; AVX512F-ONLY-SLOW-NEXT: vpshufhw {{.*#+}} xmm8 = xmm8[0,1,2,3,4,5,5,4] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm8[0,0,1,3] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm18 = ymm26[0,0,1,1] --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %xmm25, %xmm13 --; AVX512F-ONLY-SLOW-NEXT: vpshuflw {{.*#+}} xmm13 = xmm13[0,2,3,3,4,5,6,7] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm13[0,0,2,1] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm19 = ymm24[2,2,2,3] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm28 = ymm22[2,1,3,2] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm29 = ymm21[2,2,3,3] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm20[0,0,1,1] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm31 = ymm16[0,0,2,1] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,1,3,2] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,2,2,3] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,2,2,3] --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,2,2,3] --; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm16 # 64-byte Reload --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm16, %zmm0 --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm0 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm12, %zmm11, %zmm9 --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm9 --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm3 --; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm15[2,1,3,2] --; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 32(%rax), %ymm9 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm9, %zmm7, %zmm7 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), 
%zmm9, %zmm7 --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm7 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm17, %zmm1 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm13, %zmm18, %zmm8 --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $226, %zmm1, %zmm10, %zmm8 --; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 36(%rax), %ymm1 --; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 40(%rax), %ymm9 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm9, %zmm1, %zmm1 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm1 -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,1,1] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm7[0,0,2,1] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm6 = ymm6[2,1,3,2] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm15 = ymm15[2,2,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm3 = ymm3[2,2,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm12 = ymm12[0,2,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm18 # 64-byte Reload -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm18, %zmm1 - ; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm1 - ; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm19, %zmm8, %zmm8 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm28, %zmm9, %zmm9 --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm9 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm29, %zmm0, %zmm8 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm10 # 64-byte Reload --; AVX512F-ONLY-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm8 = zmm10[0,1,2,3],zmm8[4,5,6,7] --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm6 --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm6 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm31, %zmm30, %zmm8 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm8 --; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd (%rax), %ymm9 --; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 4(%rax), %ymm10 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm10, %zmm9, %zmm9 --; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm10 # 64-byte Reload --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm10, %zmm9 --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm9 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm14, %zmm2 --; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm4, %zmm4 --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm4 --; AVX512F-ONLY-SLOW-NEXT: vbroadcasti64x4 {{.*#+}} zmm2 = [0,5,4,0,0,6,5,0,0,5,4,0,0,6,5,0] --; AVX512F-ONLY-SLOW-NEXT: # zmm2 = mem[0,1,2,3,0,1,2,3] --; AVX512F-ONLY-SLOW-NEXT: vpermd (%rax), %zmm2, %zmm2 --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm27, %zmm2 --; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm2 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm8, %zmm2 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 
{{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm8, %zmm5 -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm5 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm13, %zmm0, %zmm2 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload -+; AVX512F-ONLY-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm2 = zmm8[0,1,2,3],zmm2[4,5,6,7] -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm2 -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm2 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm16, %zmm14, %zmm5 -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm5 -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm29 -+; AVX512F-ONLY-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm17[2,1,3,2] -+; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 32(%rax), %ymm5 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm0, %zmm0 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Reload -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm0 -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm0 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm11, %zmm22, %zmm4 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm9, %zmm30, %zmm5 -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $226, %zmm4, %zmm31, %zmm5 -+; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 36(%rax), %ymm4 -+; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 40(%rax), %ymm8 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm4, %zmm4 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm4 -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm4 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm5 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Reload -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm5 -+; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd (%rax), %ymm7 -+; AVX512F-ONLY-SLOW-NEXT: vpbroadcastd 4(%rax), %ymm8 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm7, %zmm7 -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm26, %zmm7 -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm7 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm15, %zmm6, %zmm5 -+; AVX512F-ONLY-SLOW-NEXT: vinserti64x4 $1, %ymm12, %zmm3, %zmm3 -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm3 -+; AVX512F-ONLY-SLOW-NEXT: vbroadcasti64x4 {{.*#+}} zmm5 = [0,5,4,0,0,6,5,0,0,5,4,0,0,6,5,0] -+; AVX512F-ONLY-SLOW-NEXT: # zmm5 = mem[0,1,2,3,0,1,2,3] -+; AVX512F-ONLY-SLOW-NEXT: vpermd (%rax), %zmm5, %zmm5 -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm28, %zmm5 -+; AVX512F-ONLY-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm5 - ; AVX512F-ONLY-SLOW-NEXT: movq {{[0-9]+}}(%rsp), %rax --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm2, 128(%rax) --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm9, (%rax) --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm6, 320(%rax) --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm1, 256(%rax) --; 
AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm7, 192(%rax) --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm0, 64(%rax) --; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm3, 384(%rax) --; AVX512F-ONLY-SLOW-NEXT: addq $632, %rsp # imm = 0x278 -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm5, 128(%rax) -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm7, (%rax) -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm2, 320(%rax) -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm4, 256(%rax) -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm0, 192(%rax) -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm1, 64(%rax) -+; AVX512F-ONLY-SLOW-NEXT: vmovdqa64 %zmm29, 384(%rax) -+; AVX512F-ONLY-SLOW-NEXT: addq $648, %rsp # imm = 0x288 - ; AVX512F-ONLY-SLOW-NEXT: vzeroupper - ; AVX512F-ONLY-SLOW-NEXT: retq - ; -@@ -5613,9 +5609,9 @@ - ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r9), %ymm13 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128] - ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm13, %ymm6 --; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r8), %ymm14 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r8), %ymm15 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm3 = --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm3, %ymm14, %ymm7 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm3, %ymm15, %ymm7 - ; AVX512F-ONLY-FAST-NEXT: vporq %ymm6, %ymm7, %ymm25 - ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm4, %ymm10, %ymm4 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdi), %ymm6 -@@ -5629,8 +5625,9 @@ - ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm12, %ymm1 - ; AVX512F-ONLY-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r9), %ymm15 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm15, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r9), %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm1, %ymm14 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r8), %ymm4 - ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm3, %ymm4, %ymm1 - ; AVX512F-ONLY-FAST-NEXT: vporq %ymm0, %ymm1, %ymm21 -@@ -5661,11 +5658,11 @@ - ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rsi), %xmm1 - ; AVX512F-ONLY-FAST-NEXT: vpunpcklwd {{.*#+}} xmm3 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3] - ; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm3, %xmm3 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm5, %xmm30 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm5, %xmm8 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm19 = [2,1,3,3,8,8,9,9] - ; AVX512F-ONLY-FAST-NEXT: vpermi2q %zmm3, %zmm2, %zmm19 - ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm13[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,26,27,28,29,26,27,28,29,26,27,28,29,30,31,30,31] --; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[3,3,3,3,7,7,7,7] -+; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm15[3,3,3,3,7,7,7,7] - ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm3[2],ymm2[3,4],ymm3[5],ymm2[6,7,8,9],ymm3[10],ymm2[11,12],ymm3[13],ymm2[14,15] - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm20 = [2,2,2,3,8,8,8,9] - ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r9), %xmm3 -@@ -5693,52 +5690,54 @@ - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm27 = [0,0,0,1,8,9,9,11] - ; AVX512F-ONLY-FAST-NEXT: vpermi2q %zmm0, %zmm3, %zmm27 - ; AVX512F-ONLY-FAST-NEXT: vprold $16, %ymm13, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[1,2,2,3,5,6,6,7] -+; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm15[1,2,2,3,5,6,6,7] - ; AVX512F-ONLY-FAST-NEXT: 
vpblendw {{.*#+}} ymm0 = ymm3[0,1],ymm0[2],ymm3[3,4],ymm0[5],ymm3[6,7,8,9],ymm0[10],ymm3[11,12],ymm0[13],ymm3[14,15] - ; AVX512F-ONLY-FAST-NEXT: vpbroadcastd {{.*#+}} ymm5 = [18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21] - ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm13, %ymm3 --; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm7 = ymm14[0,0,2,1,4,4,6,5] -+; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm7 = ymm15[0,0,2,1,4,4,6,5] - ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1,2],ymm7[3],ymm3[4,5],ymm7[6],ymm3[7,8,9,10],ymm7[11],ymm3[12,13],ymm7[14],ymm3[15] - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm28 = [2,2,3,3,10,9,11,10] - ; AVX512F-ONLY-FAST-NEXT: vpermi2q %zmm0, %zmm3, %zmm28 --; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdi), %xmm8 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdi), %xmm15 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rsi), %xmm0 --; AVX512F-ONLY-FAST-NEXT: vpunpckhwd {{.*#+}} xmm3 = xmm0[4],xmm8[4],xmm0[5],xmm8[5],xmm0[6],xmm8[6],xmm0[7],xmm8[7] --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm3, %xmm3 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm3, %ymm17 -+; AVX512F-ONLY-FAST-NEXT: vpunpckhwd {{.*#+}} xmm3 = xmm0[4],xmm15[4],xmm0[5],xmm15[5],xmm0[6],xmm15[6],xmm0[7],xmm15[7] -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm3, %xmm9 - ; AVX512F-ONLY-FAST-NEXT: vpunpckhwd {{.*#+}} xmm1 = xmm1[4],xmm2[4],xmm1[5],xmm2[5],xmm1[6],xmm2[6],xmm1[7],xmm2[7] - ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm25, %zmm0, %zmm2 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm30, %xmm9 --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm1, %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm1, %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm8, %xmm18 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm25 = <0,0,1,1,12,13,u,15> - ; AVX512F-ONLY-FAST-NEXT: vpermi2q %zmm2, %zmm1, %zmm25 - ; AVX512F-ONLY-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax - ; AVX512F-ONLY-FAST-NEXT: vpbroadcastd 8(%rax), %ymm1 --; AVX512F-ONLY-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm1, %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535] -+; AVX512F-ONLY-FAST-NEXT: vpandn %ymm1, %ymm2, %ymm1 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rax), %ymm3 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [12,13,128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128] - ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm3, %ymm7 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm3, %ymm16 - ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm7, %zmm1, %zmm30 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm29, %ymm1 - ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,18,19,20,21,18,19,20,21,24,25,26,27,22,23,22,23] - ; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[1,1,1,1,5,5,5,5] --; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm6[2],ymm1[3,4],ymm6[5],ymm1[6,7,8,9],ymm6[10],ymm1[11,12],ymm6[13],ymm1[14,15] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm16 -+; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm8 = ymm1[0,1],ymm6[2],ymm1[3,4],ymm6[5],ymm1[6,7,8,9],ymm6[10],ymm1[11,12],ymm6[13],ymm1[14,15] - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm31, %ymm13 - ; AVX512F-ONLY-FAST-NEXT: vpunpcklwd {{.*#+}} ymm1 = ymm13[0,0,1,1,2,2,3,3,8,8,9,9,10,10,11,11] - ; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm6 = ymm12[0,1,1,3,4,5,5,7] - ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm7 = 
ymm6[0,1],ymm1[2],ymm6[3,4],ymm1[5],ymm6[6,7,8,9],ymm1[10],ymm6[11,12],ymm1[13],ymm6[14,15] --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm15, %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm14, %ymm3 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm14, %ymm1 - ; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm5 = ymm4[0,0,2,1,4,4,6,5] - ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1,2],ymm5[3],ymm1[4,5],ymm5[6],ymm1[7,8,9,10],ymm5[11],ymm1[12,13],ymm5[14],ymm1[15] - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm29 = <0,1,u,3,10,10,11,11> - ; AVX512F-ONLY-FAST-NEXT: vpermi2q %zmm1, %zmm21, %zmm29 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = - ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rax), %ymm6 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm6, %ymm1 --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = --; AVX512F-ONLY-FAST-NEXT: vpermd %ymm6, %ymm2, %ymm2 --; AVX512F-ONLY-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm2, %ymm2 --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm1, %zmm14 -+; AVX512F-ONLY-FAST-NEXT: vpermd %ymm6, %ymm1, %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm5 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] -+; AVX512F-ONLY-FAST-NEXT: vpandn %ymm1, %ymm5, %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm6, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm14 - ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm13[14,15,12,13,u,u,u,u,u,u,u,u,u,u,u,u,30,31,28,29,u,u,u,u,30,31,28,29,u,u,u,u] - ; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm5 = ymm12[3,3,3,3,7,7,7,7] - ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm2 = ymm5[0],ymm2[1],ymm5[2,3],ymm2[4],ymm5[5,6,7,8],ymm2[9],ymm5[10,11],ymm2[12],ymm5[13,14,15] -@@ -5749,21 +5748,21 @@ - ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm13 = ymm13[0,1],ymm12[2],ymm13[3,4],ymm12[5],ymm13[6,7,8,9],ymm12[10],ymm13[11,12],ymm12[13],ymm13[14,15] - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm21 = [2,2,2,3,8,10,10,11] - ; AVX512F-ONLY-FAST-NEXT: vpermt2q %zmm2, %zmm21, %zmm13 --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,26,27,28,29,26,27,28,29,26,27,28,29,30,31,30,31] -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,26,27,28,29,26,27,28,29,26,27,28,29,30,31,30,31] - ; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm12 = ymm4[3,3,3,3,7,7,7,7] - ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm12[2],ymm2[3,4],ymm12[5],ymm2[6,7,8,9],ymm12[10],ymm2[11,12],ymm12[13],ymm2[14,15] --; AVX512F-ONLY-FAST-NEXT: vprold $16, %ymm15, %ymm12 -+; AVX512F-ONLY-FAST-NEXT: vprold $16, %ymm3, %ymm12 - ; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[1,2,2,3,5,6,6,7] - ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm12 = ymm4[0,1],ymm12[2],ymm4[3,4],ymm12[5],ymm4[6,7,8,9],ymm12[10],ymm4[11,12],ymm12[13],ymm4[14,15] - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm31 = [2,1,3,2,10,10,10,11] - ; AVX512F-ONLY-FAST-NEXT: vpermt2q %zmm2, %zmm31, %zmm12 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm18 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] --; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm22, %zmm18, %zmm13 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm17 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] -+; 
AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm22, %zmm17, %zmm13 - ; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm13, %zmm12 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 (%rax), %zmm15 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 (%rax), %zmm3 - ; AVX512F-ONLY-FAST-NEXT: vbroadcasti64x4 {{.*#+}} zmm4 = [30,5,0,0,31,6,0,31,30,5,0,0,31,6,0,31] - ; AVX512F-ONLY-FAST-NEXT: # zmm4 = mem[0,1,2,3,0,1,2,3] --; AVX512F-ONLY-FAST-NEXT: vpermi2d %zmm15, %zmm6, %zmm4 -+; AVX512F-ONLY-FAST-NEXT: vpermi2d %zmm3, %zmm6, %zmm4 - ; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm12, %zmm4 - ; AVX512F-ONLY-FAST-NEXT: vpunpckhwd {{.*#+}} xmm6 = xmm11[4],xmm10[4],xmm11[5],xmm10[5],xmm11[6],xmm10[6],xmm11[7],xmm10[7] - ; AVX512F-ONLY-FAST-NEXT: vpbroadcastq {{.*#+}} xmm12 = [6,7,4,5,0,0,8,9,6,7,4,5,0,0,8,9] -@@ -5778,14 +5777,15 @@ - ; AVX512F-ONLY-FAST-NEXT: # xmm6 = xmm1[4],mem[4],xmm1[5],mem[5],xmm1[6],mem[6],xmm1[7],mem[7] - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm24, %xmm1 - ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm13 = xmm1[0,1,2,3,8,9,10,11,14,15,12,13,14,15,12,13] --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm6, %xmm6 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm18, %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm6, %xmm6 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm22 = [0,1,1,3,8,8,9,9] - ; AVX512F-ONLY-FAST-NEXT: vpermt2q %zmm6, %zmm22, %zmm13 - ; AVX512F-ONLY-FAST-NEXT: vprold $16, %xmm0, %xmm6 --; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} xmm2 = xmm8[1,1,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpshufd {{.*#+}} xmm2 = xmm15[1,1,2,3] - ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} xmm2 = xmm2[0,1],xmm6[2],xmm2[3,4],xmm6[5],xmm2[6,7] --; AVX512F-ONLY-FAST-NEXT: vpunpcklwd {{.*#+}} xmm0 = xmm8[0],xmm0[0],xmm8[1],xmm0[1],xmm8[2],xmm0[2],xmm8[3],xmm0[3] --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm0, %xmm0 -+; AVX512F-ONLY-FAST-NEXT: vpunpcklwd {{.*#+}} xmm0 = xmm15[0],xmm0[0],xmm15[1],xmm0[1],xmm15[2],xmm0[2],xmm15[3],xmm0[3] -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm0, %xmm0 - ; AVX512F-ONLY-FAST-NEXT: vpermt2q %zmm2, %zmm11, %zmm0 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rcx), %xmm2 - ; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm12, %xmm2, %xmm6 -@@ -5814,8 +5814,8 @@ - ; AVX512F-ONLY-FAST-NEXT: vpblendw {{.*#+}} ymm5 = ymm10[0,1],ymm5[2],ymm10[3,4],ymm5[5],ymm10[6,7,8,9],ymm5[10],ymm10[11,12],ymm5[13],ymm10[14,15] - ; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} xmm10 = xmm12[0,2,3,3,4,5,6,7] - ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,2,1] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm17[0,0,1,3] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm16[2,2,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,0,1,3] -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,2,2,3] - ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm7[2,1,3,2] - ; AVX512F-ONLY-FAST-NEXT: vpermt2q %zmm0, %zmm31, %zmm5 - ; AVX512F-ONLY-FAST-NEXT: vpbroadcastd (%rax), %ymm0 -@@ -5834,8 +5834,8 @@ - ; AVX512F-ONLY-FAST-NEXT: vpermt2q %zmm2, %zmm21, %zmm12 - ; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm12 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = <6,u,u,u,7,u,u,7> --; AVX512F-ONLY-FAST-NEXT: vpermd %ymm3, %ymm2, %ymm2 --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm15, %zmm3 -+; AVX512F-ONLY-FAST-NEXT: vpermd %ymm16, %ymm2, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm16, %zmm3, %zmm3 - ; AVX512F-ONLY-FAST-NEXT: vbroadcasti64x4 {{.*#+}} zmm5 = 
[0,13,4,0,0,14,5,0,0,13,4,0,0,14,5,0] - ; AVX512F-ONLY-FAST-NEXT: # zmm5 = mem[0,1,2,3,0,1,2,3] - ; AVX512F-ONLY-FAST-NEXT: vpermd %zmm3, %zmm5, %zmm3 -@@ -5844,7 +5844,7 @@ - ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm5 # 32-byte Folded Reload - ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm9, %zmm9 # 32-byte Folded Reload - ; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm5, %zmm11, %zmm9 --; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm23, %zmm18, %zmm19 -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm23, %zmm17, %zmm19 - ; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm25, %zmm30 - ; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm30 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Reload -@@ -5872,336 +5872,330 @@ - ; - ; AVX512DQ-SLOW-LABEL: store_i16_stride7_vf32: - ; AVX512DQ-SLOW: # %bb.0: --; AVX512DQ-SLOW-NEXT: subq $632, %rsp # imm = 0x278 -+; AVX512DQ-SLOW-NEXT: subq $648, %rsp # imm = 0x288 - ; AVX512DQ-SLOW-NEXT: vmovdqa (%rcx), %ymm1 --; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm9 = [128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128] --; AVX512DQ-SLOW-NEXT: vpshufb %ymm9, %ymm1, %ymm0 --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm1, %ymm16 --; AVX512DQ-SLOW-NEXT: vmovdqa (%rdx), %ymm2 --; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm13 = --; AVX512DQ-SLOW-NEXT: vpshufb %ymm13, %ymm2, %ymm1 --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm29 --; AVX512DQ-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 --; AVX512DQ-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-SLOW-NEXT: vmovdqa (%rsi), %ymm1 --; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm12 = [128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128,128,128,128,128] --; AVX512DQ-SLOW-NEXT: vpshufb %ymm12, %ymm1, %ymm0 --; AVX512DQ-SLOW-NEXT: vmovdqa %ymm1, %ymm15 --; AVX512DQ-SLOW-NEXT: vmovdqa (%rdi), %ymm2 --; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm14 = <12,13,14,15,128,128,u,u,u,u,u,u,u,u,u,u,16,17,128,128,u,u,u,u,u,u,u,u,16,17,18,19> --; AVX512DQ-SLOW-NEXT: vpshufb %ymm14, %ymm2, %ymm1 --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm17 --; AVX512DQ-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 --; AVX512DQ-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-SLOW-NEXT: vmovdqa (%r9), %ymm10 --; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm1 = [128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128] --; AVX512DQ-SLOW-NEXT: vpshufb %ymm1, %ymm10, %ymm2 --; AVX512DQ-SLOW-NEXT: vmovdqa (%r8), %ymm11 --; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm0 = --; AVX512DQ-SLOW-NEXT: vpshufb %ymm0, %ymm11, %ymm3 -+; AVX512DQ-SLOW-NEXT: vmovdqu %ymm1, (%rsp) # 32-byte Spill -+; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm0 = [128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128] -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm0, %ymm1, %ymm2 -+; AVX512DQ-SLOW-NEXT: vmovdqa (%rdx), %ymm4 -+; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm1 = -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm1, %ymm4, %ymm3 -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm4, %ymm30 - ; AVX512DQ-SLOW-NEXT: vpor %ymm2, %ymm3, %ymm2 - ; AVX512DQ-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r9), %xmm3 --; AVX512DQ-SLOW-NEXT: 
vmovdqa 32(%r8), %xmm6 --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm2 = ymm11[3,3,3,3,7,7,7,7] --; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm4 = ymm10[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[2,2,2,3,6,6,6,7] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm4 = ymm4[0,1],ymm2[2],ymm4[3,4],ymm2[5],ymm4[6,7,8,9],ymm2[10],ymm4[11,12],ymm2[13],ymm4[14,15] --; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm2 = xmm6[0],xmm3[0],xmm6[1],xmm3[1],xmm6[2],xmm3[2],xmm6[3],xmm3[3] --; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm7 = [4,5,4,5,4,5,6,7,16,17,16,17,16,17,17,19] --; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm5 = xmm2[0,1,3,2,4,5,6,7] --; AVX512DQ-SLOW-NEXT: vpermi2d %zmm5, %zmm4, %zmm7 --; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm7, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-SLOW-NEXT: vmovdqa (%rcx), %xmm8 --; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm3 = xmm6[4],xmm3[4],xmm6[5],xmm3[5],xmm6[6],xmm3[6],xmm6[7],xmm3[7] --; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,4,5,7,6] --; AVX512DQ-SLOW-NEXT: vpshufb {{.*#+}} xmm3 = xmm3[0,1,2,3,4,5,4,5,6,7,10,11,8,9,10,11] --; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm4 = <16,18,19,19,19,19,u,u,0,1,0,1,2,3,2,3> --; AVX512DQ-SLOW-NEXT: vpermi2d %zmm2, %zmm3, %zmm4 -+; AVX512DQ-SLOW-NEXT: vmovdqa (%rsi), %ymm2 -+; AVX512DQ-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm4 = [128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128,128,128,128,128,128,128,128,128] -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm4, %ymm2, %ymm2 -+; AVX512DQ-SLOW-NEXT: vmovdqa (%rdi), %ymm5 -+; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm9 = <12,13,14,15,128,128,u,u,u,u,u,u,u,u,u,u,16,17,128,128,u,u,u,u,u,u,u,u,16,17,18,19> -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm9, %ymm5, %ymm3 -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm5, %ymm31 -+; AVX512DQ-SLOW-NEXT: vpor %ymm2, %ymm3, %ymm2 -+; AVX512DQ-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-SLOW-NEXT: vmovdqa (%r9), %ymm15 -+; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm2 = [128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128] -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm2, %ymm15, %ymm5 -+; AVX512DQ-SLOW-NEXT: vmovdqa (%r8), %ymm6 -+; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm3 = -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm3, %ymm6, %ymm7 -+; AVX512DQ-SLOW-NEXT: vpor %ymm5, %ymm7, %ymm5 -+; AVX512DQ-SLOW-NEXT: vmovdqu %ymm5, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r9), %xmm7 -+; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r8), %xmm8 -+; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rsi), %ymm5 -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm4, %ymm5, %ymm4 -+; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdi), %ymm11 -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm9, %ymm11, %ymm9 -+; AVX512DQ-SLOW-NEXT: vpor %ymm4, %ymm9, %ymm4 - ; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdi), %ymm3 --; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rsi), %ymm6 --; AVX512DQ-SLOW-NEXT: vpshufb %ymm12, %ymm6, %ymm2 --; AVX512DQ-SLOW-NEXT: vpshufb %ymm14, %ymm3, %ymm4 --; AVX512DQ-SLOW-NEXT: vpor %ymm2, %ymm4, %ymm2 --; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rcx), %ymm12 --; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdx), %ymm7 --; AVX512DQ-SLOW-NEXT: vpshufb %ymm9, %ymm12, %ymm2 --; AVX512DQ-SLOW-NEXT: vpshufb %ymm13, %ymm7, %ymm4 --; 
AVX512DQ-SLOW-NEXT: vpor %ymm2, %ymm4, %ymm2 --; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r9), %ymm13 --; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r8), %ymm14 --; AVX512DQ-SLOW-NEXT: vpshufb %ymm1, %ymm13, %ymm1 --; AVX512DQ-SLOW-NEXT: vpshufb %ymm0, %ymm14, %ymm0 --; AVX512DQ-SLOW-NEXT: vpor %ymm1, %ymm0, %ymm0 -+; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rcx), %ymm13 -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm0, %ymm13, %ymm0 -+; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdx), %ymm14 -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm1, %ymm14, %ymm1 -+; AVX512DQ-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 - ; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-SLOW-NEXT: vmovdqa (%rdx), %xmm0 --; AVX512DQ-SLOW-NEXT: vpbroadcastq {{.*#+}} xmm2 = [6,7,4,5,0,0,8,9,6,7,4,5,0,0,8,9] --; AVX512DQ-SLOW-NEXT: vpshufb %xmm2, %xmm8, %xmm1 --; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm2, %xmm20 --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm2 = xmm0[1,1,2,2] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm1 = xmm2[0],xmm1[1],xmm2[2,3],xmm1[4],xmm2[5,6],xmm1[7] --; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm2 = xmm0[0],xmm8[0],xmm0[1],xmm8[1],xmm0[2],xmm8[2],xmm0[3],xmm8[3] --; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm2 = xmm2[0,1,3,2,4,5,6,7] --; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm4 = --; AVX512DQ-SLOW-NEXT: vpermi2d %zmm2, %zmm1, %zmm4 --; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-SLOW-NEXT: vmovdqa (%r9), %xmm1 --; AVX512DQ-SLOW-NEXT: vmovdqa (%r8), %xmm2 --; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm4 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3] --; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm5 = xmm4[0,1,2,3,4,5,7,6] --; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm4 = xmm4[0,1,3,2,4,5,6,7] --; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm9 = <0,1,0,1,0,1,1,3,16,18,19,19,19,19,u,u> --; AVX512DQ-SLOW-NEXT: vpermi2d %zmm5, %zmm4, %zmm9 -+; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r9), %ymm4 -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm2, %ymm4, %ymm1 -+; AVX512DQ-SLOW-NEXT: vmovdqa 32(%r8), %ymm0 -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm3, %ymm0, %ymm2 -+; AVX512DQ-SLOW-NEXT: vpor %ymm1, %ymm2, %ymm1 -+; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm6[3,3,3,3,7,7,7,7] -+; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm2 = ymm15[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm2 = ymm2[2,2,2,3,6,6,6,7] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm1[2],ymm2[3,4],ymm1[5],ymm2[6,7,8,9],ymm1[10],ymm2[11,12],ymm1[13],ymm2[14,15] -+; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm8[0],xmm7[0],xmm8[1],xmm7[1],xmm8[2],xmm7[2],xmm8[3],xmm7[3] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm9 = [4,5,4,5,4,5,6,7,16,17,16,17,16,17,17,19] -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm3 = xmm1[0,1,3,2,4,5,6,7] -+; AVX512DQ-SLOW-NEXT: vpermi2d %zmm3, %zmm2, %zmm9 - ; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm9, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-SLOW-NEXT: vprold $16, %ymm10, %ymm4 --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm11[1,2,2,3,5,6,6,7] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm4 = ymm5[0,1],ymm4[2],ymm5[3,4],ymm4[5],ymm5[6,7,8,9],ymm4[10],ymm5[11,12],ymm4[13],ymm5[14,15] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm11[0,0,2,1,4,4,6,5] --; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm9 = ymm10[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] --; AVX512DQ-SLOW-NEXT: vpshufd 
{{.*#+}} ymm9 = ymm9[0,0,0,0,4,4,4,4] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm9[0,1,2],ymm5[3],ymm9[4,5],ymm5[6],ymm9[7,8,9,10],ymm5[11],ymm9[12,13],ymm5[14],ymm9[15] --; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm27 = [2,2,3,3,10,9,11,10] --; AVX512DQ-SLOW-NEXT: vpermi2q %zmm4, %zmm5, %zmm27 -+; AVX512DQ-SLOW-NEXT: vmovdqa (%rcx), %xmm9 -+; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm2 = xmm8[4],xmm7[4],xmm8[5],xmm7[5],xmm8[6],xmm7[6],xmm8[7],xmm7[7] -+; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} xmm10 = [0,1,2,3,4,5,4,5,6,7,10,11,8,9,10,11] -+; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm1 = xmm1[0,1,2,3,4,5,7,6] -+; AVX512DQ-SLOW-NEXT: vpshufb %xmm10, %xmm2, %xmm2 -+; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm3 = <16,18,19,19,19,19,u,u,0,1,0,1,2,3,2,3> -+; AVX512DQ-SLOW-NEXT: vpermi2d %zmm1, %zmm2, %zmm3 -+; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512DQ-SLOW-NEXT: vmovdqa (%rdx), %xmm2 -+; AVX512DQ-SLOW-NEXT: vpbroadcastq {{.*#+}} xmm7 = [6,7,4,5,0,0,8,9,6,7,4,5,0,0,8,9] -+; AVX512DQ-SLOW-NEXT: vpshufb %xmm7, %xmm9, %xmm1 -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm3 = xmm2[1,1,2,2] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm1 = xmm3[0],xmm1[1],xmm3[2,3],xmm1[4],xmm3[5,6],xmm1[7] -+; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm3 = xmm2[0],xmm9[0],xmm2[1],xmm9[1],xmm2[2],xmm9[2],xmm2[3],xmm9[3] -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm3 = xmm3[0,1,3,2,4,5,6,7] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm8 = -+; AVX512DQ-SLOW-NEXT: vpermi2d %zmm3, %zmm1, %zmm8 -+; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm8, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512DQ-SLOW-NEXT: vmovdqa (%r9), %xmm3 -+; AVX512DQ-SLOW-NEXT: vmovdqa (%r8), %xmm12 -+; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm12[0],xmm3[0],xmm12[1],xmm3[1],xmm12[2],xmm3[2],xmm12[3],xmm3[3] -+; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm8 = xmm1[0,1,2,3,4,5,7,6] -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm1 = xmm1[0,1,3,2,4,5,6,7] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm26 = <0,1,0,1,0,1,1,3,16,18,19,19,19,19,u,u> -+; AVX512DQ-SLOW-NEXT: vpermi2d %zmm8, %zmm1, %zmm26 -+; AVX512DQ-SLOW-NEXT: vprold $16, %ymm15, %ymm1 -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm8 = ymm6[1,2,2,3,5,6,6,7] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm1 = ymm8[0,1],ymm1[2],ymm8[3,4],ymm1[5],ymm8[6,7,8,9],ymm1[10],ymm8[11,12],ymm1[13],ymm8[14,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,2,1,4,4,6,5] -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm8 = ymm15[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm8 = ymm8[0,0,0,0,4,4,4,4] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm6 = ymm8[0,1,2],ymm6[3],ymm8[4,5],ymm6[6],ymm8[7,8,9,10],ymm6[11],ymm8[12,13],ymm6[14],ymm8[15] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm28 = [2,2,3,3,10,9,11,10] -+; AVX512DQ-SLOW-NEXT: vpermi2q %zmm1, %zmm6, %zmm28 - ; AVX512DQ-SLOW-NEXT: movq {{[0-9]+}}(%rsp), %rax --; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm0[4],xmm8[4],xmm0[5],xmm8[5],xmm0[6],xmm8[6],xmm0[7],xmm8[7] --; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm0, %xmm31 --; AVX512DQ-SLOW-NEXT: vmovdqa (%rax), %ymm4 --; AVX512DQ-SLOW-NEXT: vmovdqu %ymm4, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-SLOW-NEXT: vpbroadcastd 8(%rax), %ymm0 --; AVX512DQ-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 --; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm5 = [12,13,128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128] --; AVX512DQ-SLOW-NEXT: vpshufb 
%ymm5, %ymm4, %ymm4 --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm5, %ymm28 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm0, %zmm0 --; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm2 = xmm2[4],xmm1[4],xmm2[5],xmm1[5],xmm2[6],xmm1[6],xmm2[7],xmm1[7] --; AVX512DQ-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,u,u,24,25,26,27,u,u,26,27,26,27] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm7[2,2,2,2,6,6,6,6] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm0[0,1],ymm1[2],ymm0[3,4],ymm1[5],ymm0[6,7,8,9],ymm1[10],ymm0[11,12],ymm1[13],ymm0[14,15] --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm0, %ymm30 --; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm12[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm7[3,3,3,3,7,7,7,7] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] -+; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm1 = xmm2[4],xmm9[4],xmm2[5],xmm9[5],xmm2[6],xmm9[6],xmm2[7],xmm9[7] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm1, %xmm25 -+; AVX512DQ-SLOW-NEXT: vmovdqa (%rax), %ymm8 -+; AVX512DQ-SLOW-NEXT: vmovdqu %ymm8, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-SLOW-NEXT: vpbroadcastd 8(%rax), %ymm1 -+; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm2 = [65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535] -+; AVX512DQ-SLOW-NEXT: vpandn %ymm1, %ymm2, %ymm1 -+; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm6 = [12,13,128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128] -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm6, %ymm8, %ymm2 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm1, %zmm1 -+; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm12 = xmm12[4],xmm3[4],xmm12[5],xmm3[5],xmm12[6],xmm3[6],xmm12[7],xmm3[7] -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm1 = ymm5[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[0,0,2,1,4,4,6,5] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm11[1,1,1,1,5,5,5,5] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm3[2],ymm1[3,4],ymm3[5],ymm1[6,7,8,9],ymm3[10],ymm1[11,12],ymm3[13],ymm1[14,15] -+; AVX512DQ-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm1 = ymm13[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[0,0,0,0,4,4,4,4] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[0,1,1,3,4,5,5,7] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm1 = ymm3[0,1],ymm1[2],ymm3[3,4],ymm1[5],ymm3[6,7,8,9],ymm1[10],ymm3[11,12],ymm1[13],ymm3[14,15] -+; AVX512DQ-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rax), %ymm1 -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm1[0,1,1,3,4,5,5,7] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm3 = ymm3[0,2,2,3] -+; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm8 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] -+; AVX512DQ-SLOW-NEXT: vpandn %ymm3, %ymm8, %ymm3 -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm6, %ymm1, %ymm6 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm6, %zmm2 -+; AVX512DQ-SLOW-NEXT: vmovdqu64 %zmm2, 
{{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm3 = ymm4[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[0,0,0,0,4,4,4,4] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm0[0,0,2,1,4,4,6,5] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0,1,2],ymm6[3],ymm3[4,5],ymm6[6],ymm3[7,8,9,10],ymm6[11],ymm3[12,13],ymm6[14],ymm3[15] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm27 -+; AVX512DQ-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm15 = [22,23,26,27,0,0,24,25,26,27,0,0,26,27,26,27,22,23,26,27,0,0,24,25,26,27,0,0,26,27,26,27] -+; AVX512DQ-SLOW-NEXT: # ymm15 = mem[0,1,0,1] -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm15, %ymm13, %ymm3 -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm14[2,2,2,2,6,6,6,6] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0,1],ymm6[2],ymm3[3,4],ymm6[5],ymm3[6,7,8,9],ymm6[10],ymm3[11,12],ymm6[13],ymm3[14,15] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm22 -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[3,3,3,3,7,7,7,7] -+; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm6 = ymm13[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[2,2,2,2,6,6,6,6] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm3[0],ymm6[1],ymm3[2,3],ymm6[4],ymm3[5,6,7,8],ymm6[9],ymm3[10,11],ymm6[12],ymm3[13,14,15] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm17 -+; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm3 = ymm5[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[2,2,2,2,6,6,6,6] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm11[2,2,2,2,6,6,6,6] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm6[0],ymm3[1],ymm6[2,3],ymm3[4],ymm6[5,6,7,8],ymm3[9],ymm6[10,11],ymm3[12],ymm6[13,14,15] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm16 -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm11[3,3,3,3,7,7,7,7] -+; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm5 = ymm5[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm5[2,2,2,2,6,6,6,6] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm13 = ymm3[0,1,2],ymm5[3],ymm3[4,5],ymm5[6],ymm3[7,8,9,10],ymm5[11],ymm3[12,13],ymm5[14],ymm3[15] -+; AVX512DQ-SLOW-NEXT: vprold $16, %ymm4, %ymm3 -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm0[1,2,2,3,5,6,6,7] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm5[0,1],ymm3[2],ymm5[3,4],ymm3[5],ymm5[6,7,8,9],ymm3[10],ymm5[11,12],ymm3[13],ymm5[14,15] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm24 -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[3,3,3,3,7,7,7,7] -+; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm3 = ymm4[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[2,2,2,3,6,6,6,7] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm3[0,1],ymm0[2],ymm3[3,4],ymm0[5],ymm3[6,7,8,9],ymm0[10],ymm3[11,12],ymm0[13],ymm3[14,15] - ; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm0, %ymm23 --; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm6[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm3[2,2,2,2,6,6,6,6] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm10 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] --; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm6[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm3[3,3,3,3,7,7,7,7] --; 
AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm1[0,1,2],ymm0[3],ymm1[4,5],ymm0[6],ymm1[7,8,9,10],ymm0[11],ymm1[12,13],ymm0[14],ymm1[15] --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm0, %ymm19 --; AVX512DQ-SLOW-NEXT: vprold $16, %ymm13, %ymm0 --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm14[1,2,2,3,5,6,6,7] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm1[0,1],ymm0[2],ymm1[3,4],ymm0[5],ymm1[6,7,8,9],ymm0[10],ymm1[11,12],ymm0[13],ymm1[14,15] --; AVX512DQ-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm13[0,1,2,3,5,6,7,7,8,9,10,11,13,14,15,15] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,3,6,6,6,7] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm14[3,3,3,3,7,7,7,7] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm0 = ymm0[0,1],ymm1[2],ymm0[3,4],ymm1[5],ymm0[6,7,8,9],ymm1[10],ymm0[11,12],ymm1[13],ymm0[14,15] --; AVX512DQ-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm16, %ymm4 --; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm4[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] -+; AVX512DQ-SLOW-NEXT: vbroadcasti32x8 {{.*#+}} zmm0 = [6,5,0,0,7,6,0,7,6,5,0,0,7,6,0,7] -+; AVX512DQ-SLOW-NEXT: # zmm0 = mem[0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7] -+; AVX512DQ-SLOW-NEXT: vpermd %zmm1, %zmm0, %zmm29 -+; AVX512DQ-SLOW-NEXT: vmovdqu (%rsp), %ymm11 # 32-byte Reload -+; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm11[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] - ; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm29[3,3,3,3,7,7,7,7] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm5, %ymm18 --; AVX512DQ-SLOW-NEXT: vmovdqa %ymm15, %ymm8 --; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm15[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm30[3,3,3,3,7,7,7,7] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm9 = ymm1[0],ymm0[1],ymm1[2,3],ymm0[4],ymm1[5,6,7,8],ymm0[9],ymm1[10,11],ymm0[12],ymm1[13,14,15] -+; AVX512DQ-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm14 # 32-byte Reload -+; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm0 = ymm14[0,1,2,3,7,6,6,7,8,9,10,11,15,14,14,15] - ; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[2,2,2,2,6,6,6,6] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm17[3,3,3,3,7,7,7,7] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm1[0,1,2],ymm0[3],ymm1[4,5],ymm0[6],ymm1[7,8,9,10],ymm0[11],ymm1[12,13],ymm0[14],ymm1[15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm31[3,3,3,3,7,7,7,7] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm8 = ymm1[0,1,2],ymm0[3],ymm1[4,5],ymm0[6],ymm1[7,8,9,10],ymm0[11],ymm1[12,13],ymm0[14],ymm1[15] - ; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdi), %xmm0 --; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rsi), %xmm9 --; AVX512DQ-SLOW-NEXT: vprold $16, %xmm9, %xmm1 --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm15 = xmm0[1,1,2,3] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm1 = xmm15[0,1],xmm1[2],xmm15[3,4],xmm1[5],xmm15[6,7] --; AVX512DQ-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm1 = xmm0[0],xmm9[0],xmm0[1],xmm9[1],xmm0[2],xmm9[2],xmm0[3],xmm9[3] --; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm9[4],xmm0[4],xmm9[5],xmm0[5],xmm9[6],xmm0[6],xmm9[7],xmm0[7] --; AVX512DQ-SLOW-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill --; 
AVX512DQ-SLOW-NEXT: vmovdqa 32(%rcx), %xmm9 --; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm20, %xmm0 --; AVX512DQ-SLOW-NEXT: vpshufb %xmm0, %xmm9, %xmm0 --; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdx), %xmm15 --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm11 = xmm15[1,1,2,2] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm0 = xmm11[0],xmm0[1],xmm11[2,3],xmm0[4],xmm11[5,6],xmm0[7] --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm0, %ymm26 --; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm0 = xmm15[0],xmm9[0],xmm15[1],xmm9[1],xmm15[2],xmm9[2],xmm15[3],xmm9[3] --; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm9 = xmm15[4],xmm9[4],xmm15[5],xmm9[5],xmm15[6],xmm9[6],xmm15[7],xmm9[7] --; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm9, %xmm25 --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[1,1,1,1,5,5,5,5] --; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm6 = ymm6[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,2,1,4,4,6,5] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm6[0,1],ymm3[2],ymm6[3,4],ymm3[5],ymm6[6,7,8,9],ymm3[10],ymm6[11,12],ymm3[13],ymm6[14,15] --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm3, %ymm24 --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm7[0,1,1,3,4,5,5,7] --; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm6 = ymm12[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,0,0,4,4,4,4] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1],ymm6[2],ymm3[3,4],ymm6[5],ymm3[6,7,8,9],ymm6[10],ymm3[11,12],ymm6[13],ymm3[14,15] --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm3, %ymm22 --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[0,0,2,1,4,4,6,5] --; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm6 = ymm13[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,0,0,0,4,4,4,4] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm6[0,1,2],ymm3[3],ymm6[4,5],ymm3[6],ymm6[7,8,9,10],ymm3[11],ymm6[12,13],ymm3[14],ymm6[15] --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm3, %ymm21 --; AVX512DQ-SLOW-NEXT: vbroadcasti32x8 {{.*#+}} zmm3 = [6,5,0,0,7,6,0,7,6,5,0,0,7,6,0,7] --; AVX512DQ-SLOW-NEXT: # zmm3 = mem[0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7] --; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rax), %ymm6 --; AVX512DQ-SLOW-NEXT: vpermd %zmm6, %zmm3, %zmm3 --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm28, %ymm7 --; AVX512DQ-SLOW-NEXT: vpshufb %ymm7, %ymm6, %ymm11 --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[0,1,1,3,4,5,5,7] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,2,2,3] --; AVX512DQ-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm6, %ymm6 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm6, %zmm11, %zmm6 --; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} xmm7 = [0,1,2,3,4,5,4,5,6,7,10,11,8,9,10,11] --; AVX512DQ-SLOW-NEXT: vpshufb %xmm7, %xmm2, %xmm11 --; AVX512DQ-SLOW-NEXT: vpshufb %xmm7, %xmm1, %xmm1 --; AVX512DQ-SLOW-NEXT: vmovdqa (%rdi), %xmm13 --; AVX512DQ-SLOW-NEXT: vmovdqa (%rsi), %xmm14 --; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm2 = xmm13[0],xmm14[0],xmm13[1],xmm14[1],xmm13[2],xmm14[2],xmm13[3],xmm14[3] --; AVX512DQ-SLOW-NEXT: vpshufb %xmm7, %xmm2, %xmm2 --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm20 --; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm15 = xmm14[4],xmm13[4],xmm14[5],xmm13[5],xmm14[6],xmm13[6],xmm14[7],xmm13[7] --; AVX512DQ-SLOW-NEXT: vprold $16, %xmm14, %xmm14 --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm13 = xmm13[1,1,2,3] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm2 = xmm13[0,1],xmm14[2],xmm13[3,4],xmm14[5],xmm13[6,7] --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm16 --; AVX512DQ-SLOW-NEXT: vmovdqa %ymm4, %ymm2 --; 
AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm14 = ymm4[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm14 = ymm14[0,0,0,0,4,4,4,4] --; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm29, %ymm12 --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm9 = ymm29[0,1,1,3,4,5,5,7] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm14 = ymm9[0,1],ymm14[2],ymm9[3,4],ymm14[5],ymm9[6,7,8,9],ymm14[10],ymm9[11,12],ymm14[13],ymm9[14,15] --; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm31, %xmm4 --; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm9 = xmm4[0,2,3,3,4,5,6,7] -+; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rsi), %xmm1 -+; AVX512DQ-SLOW-NEXT: vprold $16, %xmm1, %xmm3 -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm6 = xmm0[1,1,2,3] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm2 = xmm6[0,1],xmm3[2],xmm6[3,4],xmm3[5],xmm6[6,7] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm2, %ymm21 -+; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm3 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3] -+; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm1[4],xmm0[4],xmm1[5],xmm0[5],xmm1[6],xmm0[6],xmm1[7],xmm0[7] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm0, %xmm20 -+; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rcx), %xmm1 -+; AVX512DQ-SLOW-NEXT: vpshufb %xmm7, %xmm1, %xmm0 -+; AVX512DQ-SLOW-NEXT: vmovdqa 32(%rdx), %xmm6 -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm7 = xmm6[1,1,2,2] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm0 = xmm7[0],xmm0[1],xmm7[2,3],xmm0[4],xmm7[5,6],xmm0[7] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %ymm0, %ymm19 -+; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm5 = xmm6[0],xmm1[0],xmm6[1],xmm1[1],xmm6[2],xmm1[2],xmm6[3],xmm1[3] -+; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm6[4],xmm1[4],xmm6[5],xmm1[5],xmm6[6],xmm1[6],xmm6[7],xmm1[7] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm0, %xmm18 -+; AVX512DQ-SLOW-NEXT: vpshufb %xmm10, %xmm12, %xmm2 -+; AVX512DQ-SLOW-NEXT: vpshufb %xmm10, %xmm3, %xmm4 -+; AVX512DQ-SLOW-NEXT: vmovdqa (%rdi), %xmm3 -+; AVX512DQ-SLOW-NEXT: vmovdqa (%rsi), %xmm6 -+; AVX512DQ-SLOW-NEXT: vpunpcklwd {{.*#+}} xmm7 = xmm3[0],xmm6[0],xmm3[1],xmm6[1],xmm3[2],xmm6[2],xmm3[3],xmm6[3] -+; AVX512DQ-SLOW-NEXT: vpshufb %xmm10, %xmm7, %xmm10 -+; AVX512DQ-SLOW-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm6[4],xmm3[4],xmm6[5],xmm3[5],xmm6[6],xmm3[6],xmm6[7],xmm3[7] -+; AVX512DQ-SLOW-NEXT: vprold $16, %xmm6, %xmm6 -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm3 = xmm3[1,1,2,3] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} xmm7 = xmm3[0,1],xmm6[2],xmm3[3,4],xmm6[5],xmm3[6,7] -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm3 = ymm11[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[0,0,0,0,4,4,4,4] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm30[0,1,1,3,4,5,5,7] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm6 = ymm6[0,1],ymm3[2],ymm6[3,4],ymm3[5],ymm6[6,7,8,9],ymm3[10],ymm6[11,12],ymm3[13],ymm6[14,15] -+; AVX512DQ-SLOW-NEXT: vpshufb %ymm15, %ymm11, %ymm3 -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm30[2,2,2,2,6,6,6,6] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm15 = ymm3[0,1],ymm12[2],ymm3[3,4],ymm12[5],ymm3[6,7,8,9],ymm12[10],ymm3[11,12],ymm12[13],ymm3[14,15] -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm3 = ymm14[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[0,0,2,1,4,4,6,5] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm31[1,1,1,1,5,5,5,5] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1],ymm12[2],ymm3[3,4],ymm12[5],ymm3[6,7,8,9],ymm12[10],ymm3[11,12],ymm12[13],ymm3[14,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = 
ymm31[2,2,2,2,6,6,6,6] -+; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm14 = ymm14[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm14 = ymm14[2,2,2,2,6,6,6,6] -+; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm12 = ymm12[0],ymm14[1],ymm12[2,3],ymm14[4],ymm12[5,6,7,8],ymm14[9],ymm12[10,11],ymm14[12],ymm12[13,14,15] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm25, %xmm1 -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm14 = xmm1[0,2,3,3,4,5,6,7] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm14[0,0,2,1] -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm0 = xmm0[2,1,2,3,4,5,6,7] -+; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,5,5,4] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm1 = ymm0[0,0,1,3] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm2 = ymm2[0,0,1,1] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm22[2,2,2,3] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm31 = ymm17[0,2,2,3] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm16[0,2,2,3] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,1,3,3] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,2,2,3] -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm5 = xmm5[0,1,3,2,4,5,6,7] -+; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm5 = xmm5[0,0,1,1] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,1,3,3] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm4[0,0,1,1] -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm31, %zmm30, %zmm30 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm13, %zmm0, %zmm0 -+; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm13 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] -+; AVX512DQ-SLOW-NEXT: vpternlogq $184, %zmm30, %zmm13, %zmm0 -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,1,1,3] -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm9, %zmm5 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm8, %zmm4 -+; AVX512DQ-SLOW-NEXT: vpternlogq $226, %zmm5, %zmm13, %zmm4 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm14, %zmm5 # 32-byte Folded Reload -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm8 # 32-byte Folded Reload -+; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm31 = [65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535] -+; AVX512DQ-SLOW-NEXT: vpternlogq $226, %zmm5, %zmm31, %zmm8 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm1 # 32-byte Folded Reload -+; AVX512DQ-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm1 = zmm2[0,1,2,3],zmm1[4,5,6,7] -+; AVX512DQ-SLOW-NEXT: vpermq $234, {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Folded Reload -+; AVX512DQ-SLOW-NEXT: # ymm2 = mem[2,2,2,3] -+; AVX512DQ-SLOW-NEXT: vpermq $182, {{[-0-9]+}}(%r{{[sb]}}p), %ymm5 # 32-byte Folded Reload -+; AVX512DQ-SLOW-NEXT: # ymm5 = mem[2,1,3,2] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm27[2,2,3,3] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm24[2,1,3,2] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm16 = ymm23[2,2,2,3] -+; AVX512DQ-SLOW-NEXT: vpshufd $254, {{[-0-9]+}}(%r{{[sb]}}p), %ymm17 # 32-byte Folded Reload -+; AVX512DQ-SLOW-NEXT: # ymm17 = mem[2,3,3,3,6,7,7,7] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm22 = ymm21[0,0,2,1] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm20, %xmm9 -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm11 = xmm9[2,1,2,3,4,5,6,7] -+; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm11 = xmm11[0,1,2,3,4,5,5,4] -+; 
AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm11[0,0,1,3] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm19[0,0,1,1] -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm18, %xmm9 -+; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm9 = xmm9[0,2,3,3,4,5,6,7] - ; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,0,2,1] --; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm15 = xmm15[2,1,2,3,4,5,6,7] --; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm15 = xmm15[0,1,2,3,4,5,5,4] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm15 = ymm15[0,0,1,3] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm11[0,0,1,1] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm28 = ymm30[2,2,2,3] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm29 = ymm23[0,2,2,3] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,2,2,3] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm19[2,1,3,3] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm18[0,2,2,3] --; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm0 = xmm0[0,1,3,2,4,5,6,7] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,1,1] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm31 = ymm5[2,1,3,3] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm1 = ymm1[0,0,1,1] --; AVX512DQ-SLOW-NEXT: vpshufb {{.*#+}} ymm2 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,u,u,24,25,26,27,u,u,26,27,26,27] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm12[2,2,2,2,6,6,6,6] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm4[2],ymm2[3,4],ymm4[5],ymm2[6,7,8,9],ymm4[10],ymm2[11,12],ymm4[13],ymm2[14,15] --; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} ymm4 = ymm8[1,2,3,3,4,5,6,7,9,10,11,11,12,13,14,15] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[0,0,2,1,4,4,6,5] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm17[1,1,1,1,5,5,5,5] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm4 = ymm4[0,1],ymm5[2],ymm4[3,4],ymm5[5],ymm4[6,7,8,9],ymm5[10],ymm4[11,12],ymm5[13],ymm4[14,15] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm5 = ymm17[2,2,2,2,6,6,6,6] --; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} ymm12 = ymm8[0,1,2,3,5,4,6,7,8,9,10,11,13,12,14,15] --; AVX512DQ-SLOW-NEXT: vpshufd {{.*#+}} ymm12 = ymm12[2,2,2,2,6,6,6,6] --; AVX512DQ-SLOW-NEXT: vpblendw {{.*#+}} ymm5 = ymm5[0],ymm12[1],ymm5[2,3],ymm12[4],ymm5[5,6,7,8],ymm12[9],ymm5[10,11],ymm12[12],ymm5[13,14,15] --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm29, %zmm28, %zmm12 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm7 --; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm10 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] --; AVX512DQ-SLOW-NEXT: vpternlogq $184, %zmm12, %zmm10, %zmm7 --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm0[0,1,1,3] --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm30, %zmm0 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm1, %zmm31, %zmm1 --; AVX512DQ-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm10, %zmm1 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm9, %zmm0 # 32-byte Folded Reload --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm15, %zmm9 # 32-byte Folded Reload --; AVX512DQ-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm10 = [65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535] --; AVX512DQ-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm10, %zmm9 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm0 # 32-byte Folded Reload --; AVX512DQ-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm0 = 
zmm11[0,1,2,3],zmm0[4,5,6,7] --; AVX512DQ-SLOW-NEXT: vpermq $182, {{[-0-9]+}}(%r{{[sb]}}p), %ymm11 # 32-byte Folded Reload --; AVX512DQ-SLOW-NEXT: # ymm11 = mem[2,1,3,2] --; AVX512DQ-SLOW-NEXT: vpermq $234, {{[-0-9]+}}(%r{{[sb]}}p), %ymm12 # 32-byte Folded Reload --; AVX512DQ-SLOW-NEXT: # ymm12 = mem[2,2,2,3] --; AVX512DQ-SLOW-NEXT: vpshufd $254, {{[-0-9]+}}(%r{{[sb]}}p), %ymm15 # 32-byte Folded Reload --; AVX512DQ-SLOW-NEXT: # ymm15 = mem[2,3,3,3,6,7,7,7] --; AVX512DQ-SLOW-NEXT: vpermq $96, {{[-0-9]+}}(%r{{[sb]}}p), %ymm17 # 32-byte Folded Reload --; AVX512DQ-SLOW-NEXT: # ymm17 = mem[0,0,2,1] --; AVX512DQ-SLOW-NEXT: vpshuflw $230, {{[-0-9]+}}(%r{{[sb]}}p), %xmm8 # 16-byte Folded Reload --; AVX512DQ-SLOW-NEXT: # xmm8 = mem[2,1,2,3,4,5,6,7] --; AVX512DQ-SLOW-NEXT: vpshufhw {{.*#+}} xmm8 = xmm8[0,1,2,3,4,5,5,4] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm8[0,0,1,3] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm18 = ymm26[0,0,1,1] --; AVX512DQ-SLOW-NEXT: vmovdqa64 %xmm25, %xmm13 --; AVX512DQ-SLOW-NEXT: vpshuflw {{.*#+}} xmm13 = xmm13[0,2,3,3,4,5,6,7] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm13[0,0,2,1] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm19 = ymm24[2,2,2,3] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm28 = ymm22[2,1,3,2] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm29 = ymm21[2,2,3,3] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm30 = ymm20[0,0,1,1] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm31 = ymm16[0,0,2,1] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,1,3,2] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,2,2,3] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,2,2,3] --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,2,2,3] --; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm16 # 64-byte Reload --; AVX512DQ-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm16, %zmm0 --; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm0 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm12, %zmm11, %zmm9 --; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm9 --; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm3 --; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm15[2,1,3,2] --; AVX512DQ-SLOW-NEXT: vpbroadcastd 32(%rax), %ymm9 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm9, %zmm7, %zmm7 --; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload --; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm7 --; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm7 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm17, %zmm1 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm13, %zmm18, %zmm8 --; AVX512DQ-SLOW-NEXT: vpternlogq $226, %zmm1, %zmm10, %zmm8 --; AVX512DQ-SLOW-NEXT: vpbroadcastd 36(%rax), %ymm1 --; AVX512DQ-SLOW-NEXT: vpbroadcastd 40(%rax), %ymm9 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm9, %zmm1, %zmm1 --; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload --; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm1 -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,1,1] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm7[0,0,2,1] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm6 = ymm6[2,1,3,2] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm15 = ymm15[2,2,2,3] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm3 = ymm3[2,2,2,3] -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm12 = ymm12[0,2,2,3] -+; AVX512DQ-SLOW-NEXT: vmovdqu64 
{{[-0-9]+}}(%r{{[sb]}}p), %zmm18 # 64-byte Reload -+; AVX512DQ-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm18, %zmm1 - ; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm1 - ; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm19, %zmm8, %zmm8 --; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm28, %zmm9, %zmm9 --; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm9 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm29, %zmm0, %zmm8 --; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm10 # 64-byte Reload --; AVX512DQ-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm8 = zmm10[0,1,2,3],zmm8[4,5,6,7] --; AVX512DQ-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm6 --; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm6 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm31, %zmm30, %zmm8 --; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm9 # 64-byte Reload --; AVX512DQ-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm8 --; AVX512DQ-SLOW-NEXT: vpbroadcastd (%rax), %ymm9 --; AVX512DQ-SLOW-NEXT: vpbroadcastd 4(%rax), %ymm10 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm10, %zmm9, %zmm9 --; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm10 # 64-byte Reload --; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm10, %zmm9 --; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm9 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm14, %zmm2 --; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm4, %zmm4 --; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm4 --; AVX512DQ-SLOW-NEXT: vbroadcasti32x8 {{.*#+}} zmm2 = [0,5,4,0,0,6,5,0,0,5,4,0,0,6,5,0] --; AVX512DQ-SLOW-NEXT: # zmm2 = mem[0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7] --; AVX512DQ-SLOW-NEXT: vpermd (%rax), %zmm2, %zmm2 --; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm27, %zmm2 --; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm2 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm8, %zmm2 -+; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm8, %zmm5 -+; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm5 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm13, %zmm0, %zmm2 -+; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload -+; AVX512DQ-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm2 = zmm8[0,1,2,3],zmm2[4,5,6,7] -+; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload -+; AVX512DQ-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm2 -+; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm2 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm16, %zmm14, %zmm5 -+; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm5 -+; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm29 -+; AVX512DQ-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm17[2,1,3,2] -+; AVX512DQ-SLOW-NEXT: vpbroadcastd 32(%rax), %ymm5 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm0, %zmm0 -+; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Reload -+; AVX512DQ-SLOW-NEXT: vpternlogq $216, 
{{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm0 -+; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm0 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm11, %zmm22, %zmm4 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm9, %zmm30, %zmm5 -+; AVX512DQ-SLOW-NEXT: vpternlogq $226, %zmm4, %zmm31, %zmm5 -+; AVX512DQ-SLOW-NEXT: vpbroadcastd 36(%rax), %ymm4 -+; AVX512DQ-SLOW-NEXT: vpbroadcastd 40(%rax), %ymm8 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm4, %zmm4 -+; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload -+; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm4 -+; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm4 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm7, %zmm10, %zmm5 -+; AVX512DQ-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Reload -+; AVX512DQ-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm5 -+; AVX512DQ-SLOW-NEXT: vpbroadcastd (%rax), %ymm7 -+; AVX512DQ-SLOW-NEXT: vpbroadcastd 4(%rax), %ymm8 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm7, %zmm7 -+; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm26, %zmm7 -+; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm7 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm15, %zmm6, %zmm5 -+; AVX512DQ-SLOW-NEXT: vinserti64x4 $1, %ymm12, %zmm3, %zmm3 -+; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm3 -+; AVX512DQ-SLOW-NEXT: vbroadcasti32x8 {{.*#+}} zmm5 = [0,5,4,0,0,6,5,0,0,5,4,0,0,6,5,0] -+; AVX512DQ-SLOW-NEXT: # zmm5 = mem[0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7] -+; AVX512DQ-SLOW-NEXT: vpermd (%rax), %zmm5, %zmm5 -+; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm28, %zmm5 -+; AVX512DQ-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm5 - ; AVX512DQ-SLOW-NEXT: movq {{[0-9]+}}(%rsp), %rax --; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm2, 128(%rax) --; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm9, (%rax) --; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm6, 320(%rax) --; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm1, 256(%rax) --; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm7, 192(%rax) --; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm0, 64(%rax) --; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm3, 384(%rax) --; AVX512DQ-SLOW-NEXT: addq $632, %rsp # imm = 0x278 -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm5, 128(%rax) -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm7, (%rax) -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm2, 320(%rax) -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm4, 256(%rax) -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm0, 192(%rax) -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm1, 64(%rax) -+; AVX512DQ-SLOW-NEXT: vmovdqa64 %zmm29, 384(%rax) -+; AVX512DQ-SLOW-NEXT: addq $648, %rsp # imm = 0x288 - ; AVX512DQ-SLOW-NEXT: vzeroupper - ; AVX512DQ-SLOW-NEXT: retq - ; -@@ -6235,9 +6229,9 @@ - ; AVX512DQ-FAST-NEXT: vmovdqa (%r9), %ymm13 - ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128,128,128] - ; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm13, %ymm6 --; AVX512DQ-FAST-NEXT: vmovdqa (%r8), %ymm14 -+; AVX512DQ-FAST-NEXT: vmovdqa (%r8), %ymm15 - ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm3 = --; AVX512DQ-FAST-NEXT: vpshufb %ymm3, %ymm14, %ymm7 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm3, %ymm15, %ymm7 - ; AVX512DQ-FAST-NEXT: vporq %ymm6, %ymm7, %ymm25 - ; AVX512DQ-FAST-NEXT: vpshufb %ymm4, %ymm10, %ymm4 - ; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdi), %ymm6 -@@ -6251,8 +6245,9 @@ 
- ; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm12, %ymm1 - ; AVX512DQ-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 - ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-FAST-NEXT: vmovdqa 32(%r9), %ymm15 --; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm15, %ymm0 -+; AVX512DQ-FAST-NEXT: vmovdqa 32(%r9), %ymm1 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm0 -+; AVX512DQ-FAST-NEXT: vmovdqa %ymm1, %ymm14 - ; AVX512DQ-FAST-NEXT: vmovdqa 32(%r8), %ymm4 - ; AVX512DQ-FAST-NEXT: vpshufb %ymm3, %ymm4, %ymm1 - ; AVX512DQ-FAST-NEXT: vporq %ymm0, %ymm1, %ymm21 -@@ -6283,11 +6278,11 @@ - ; AVX512DQ-FAST-NEXT: vmovdqa 32(%rsi), %xmm1 - ; AVX512DQ-FAST-NEXT: vpunpcklwd {{.*#+}} xmm3 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3] - ; AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm3, %xmm3 --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm5, %xmm30 -+; AVX512DQ-FAST-NEXT: vmovdqa %xmm5, %xmm8 - ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm19 = [2,1,3,3,8,8,9,9] - ; AVX512DQ-FAST-NEXT: vpermi2q %zmm3, %zmm2, %zmm19 - ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm13[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,26,27,28,29,26,27,28,29,26,27,28,29,30,31,30,31] --; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[3,3,3,3,7,7,7,7] -+; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm15[3,3,3,3,7,7,7,7] - ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm3[2],ymm2[3,4],ymm3[5],ymm2[6,7,8,9],ymm3[10],ymm2[11,12],ymm3[13],ymm2[14,15] - ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm20 = [2,2,2,3,8,8,8,9] - ; AVX512DQ-FAST-NEXT: vmovdqa 32(%r9), %xmm3 -@@ -6315,52 +6310,54 @@ - ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm27 = [0,0,0,1,8,9,9,11] - ; AVX512DQ-FAST-NEXT: vpermi2q %zmm0, %zmm3, %zmm27 - ; AVX512DQ-FAST-NEXT: vprold $16, %ymm13, %ymm0 --; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm14[1,2,2,3,5,6,6,7] -+; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm3 = ymm15[1,2,2,3,5,6,6,7] - ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm0 = ymm3[0,1],ymm0[2],ymm3[3,4],ymm0[5],ymm3[6,7,8,9],ymm0[10],ymm3[11,12],ymm0[13],ymm3[14,15] - ; AVX512DQ-FAST-NEXT: vpbroadcastd {{.*#+}} ymm5 = [18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21,18,19,20,21] - ; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm13, %ymm3 --; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm7 = ymm14[0,0,2,1,4,4,6,5] -+; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm7 = ymm15[0,0,2,1,4,4,6,5] - ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm3 = ymm3[0,1,2],ymm7[3],ymm3[4,5],ymm7[6],ymm3[7,8,9,10],ymm7[11],ymm3[12,13],ymm7[14],ymm3[15] - ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm28 = [2,2,3,3,10,9,11,10] - ; AVX512DQ-FAST-NEXT: vpermi2q %zmm0, %zmm3, %zmm28 --; AVX512DQ-FAST-NEXT: vmovdqa (%rdi), %xmm8 -+; AVX512DQ-FAST-NEXT: vmovdqa (%rdi), %xmm15 - ; AVX512DQ-FAST-NEXT: vmovdqa (%rsi), %xmm0 --; AVX512DQ-FAST-NEXT: vpunpckhwd {{.*#+}} xmm3 = xmm0[4],xmm8[4],xmm0[5],xmm8[5],xmm0[6],xmm8[6],xmm0[7],xmm8[7] --; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm3, %xmm3 --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm3, %ymm17 -+; AVX512DQ-FAST-NEXT: vpunpckhwd {{.*#+}} xmm3 = xmm0[4],xmm15[4],xmm0[5],xmm15[5],xmm0[6],xmm15[6],xmm0[7],xmm15[7] -+; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm3, %xmm9 - ; AVX512DQ-FAST-NEXT: vpunpckhwd {{.*#+}} xmm1 = xmm1[4],xmm2[4],xmm1[5],xmm2[5],xmm1[6],xmm2[6],xmm1[7],xmm2[7] - ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm25, %zmm0, %zmm2 --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm30, %xmm9 --; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm1, %xmm1 -+; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm1, %xmm1 -+; AVX512DQ-FAST-NEXT: vmovdqa64 
%xmm8, %xmm18 - ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm25 = <0,0,1,1,12,13,u,15> - ; AVX512DQ-FAST-NEXT: vpermi2q %zmm2, %zmm1, %zmm25 - ; AVX512DQ-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax - ; AVX512DQ-FAST-NEXT: vpbroadcastd 8(%rax), %ymm1 --; AVX512DQ-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm1, %ymm1 -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535] -+; AVX512DQ-FAST-NEXT: vpandn %ymm1, %ymm2, %ymm1 - ; AVX512DQ-FAST-NEXT: vmovdqa (%rax), %ymm3 - ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [12,13,128,128,128,128,128,128,128,128,128,128,128,128,14,15,128,128,128,128,128,128,128,128,128,128,128,128,16,17,128,128] - ; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm3, %ymm7 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm3, %ymm16 - ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm7, %zmm1, %zmm30 - ; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm29, %ymm1 - ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,18,19,20,21,18,19,20,21,24,25,26,27,22,23,22,23] - ; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[1,1,1,1,5,5,5,5] --; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1],ymm6[2],ymm1[3,4],ymm6[5],ymm1[6,7,8,9],ymm6[10],ymm1[11,12],ymm6[13],ymm1[14,15] --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm16 -+; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm8 = ymm1[0,1],ymm6[2],ymm1[3,4],ymm6[5],ymm1[6,7,8,9],ymm6[10],ymm1[11,12],ymm6[13],ymm1[14,15] - ; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm31, %ymm13 - ; AVX512DQ-FAST-NEXT: vpunpcklwd {{.*#+}} ymm1 = ymm13[0,0,1,1,2,2,3,3,8,8,9,9,10,10,11,11] - ; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm6 = ymm12[0,1,1,3,4,5,5,7] - ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm7 = ymm6[0,1],ymm1[2],ymm6[3,4],ymm1[5],ymm6[6,7,8,9],ymm1[10],ymm6[11,12],ymm1[13],ymm6[14,15] --; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm15, %ymm1 -+; AVX512DQ-FAST-NEXT: vmovdqa %ymm14, %ymm3 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm14, %ymm1 - ; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm5 = ymm4[0,0,2,1,4,4,6,5] - ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm1 = ymm1[0,1,2],ymm5[3],ymm1[4,5],ymm5[6],ymm1[7,8,9,10],ymm5[11],ymm1[12,13],ymm5[14],ymm1[15] - ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm29 = <0,1,u,3,10,10,11,11> - ; AVX512DQ-FAST-NEXT: vpermi2q %zmm1, %zmm21, %zmm29 -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = - ; AVX512DQ-FAST-NEXT: vmovdqa 32(%rax), %ymm6 --; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm6, %ymm1 --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = --; AVX512DQ-FAST-NEXT: vpermd %ymm6, %ymm2, %ymm2 --; AVX512DQ-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm2, %ymm2 --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm1, %zmm14 -+; AVX512DQ-FAST-NEXT: vpermd %ymm6, %ymm1, %ymm1 -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm5 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] -+; AVX512DQ-FAST-NEXT: vpandn %ymm1, %ymm5, %ymm1 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm6, %ymm2 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm14 - ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm13[14,15,12,13,u,u,u,u,u,u,u,u,u,u,u,u,30,31,28,29,u,u,u,u,30,31,28,29,u,u,u,u] - ; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm5 = ymm12[3,3,3,3,7,7,7,7] - ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm2 = ymm5[0],ymm2[1],ymm5[2,3],ymm2[4],ymm5[5,6,7,8],ymm2[9],ymm5[10,11],ymm2[12],ymm5[13,14,15] -@@ -6371,21 +6368,21 @@ - ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm13 = 
ymm13[0,1],ymm12[2],ymm13[3,4],ymm12[5],ymm13[6,7,8,9],ymm12[10],ymm13[11,12],ymm12[13],ymm13[14,15] - ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm21 = [2,2,2,3,8,10,10,11] - ; AVX512DQ-FAST-NEXT: vpermt2q %zmm2, %zmm21, %zmm13 --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,26,27,28,29,26,27,28,29,26,27,28,29,30,31,30,31] -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,26,27,28,29,26,27,28,29,26,27,28,29,30,31,30,31] - ; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm12 = ymm4[3,3,3,3,7,7,7,7] - ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm2 = ymm2[0,1],ymm12[2],ymm2[3,4],ymm12[5],ymm2[6,7,8,9],ymm12[10],ymm2[11,12],ymm12[13],ymm2[14,15] --; AVX512DQ-FAST-NEXT: vprold $16, %ymm15, %ymm12 -+; AVX512DQ-FAST-NEXT: vprold $16, %ymm3, %ymm12 - ; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[1,2,2,3,5,6,6,7] - ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm12 = ymm4[0,1],ymm12[2],ymm4[3,4],ymm12[5],ymm4[6,7,8,9],ymm12[10],ymm4[11,12],ymm12[13],ymm4[14,15] - ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm31 = [2,1,3,2,10,10,10,11] - ; AVX512DQ-FAST-NEXT: vpermt2q %zmm2, %zmm31, %zmm12 --; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm18 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] --; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm22, %zmm18, %zmm13 -+; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm17 = [65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535,0,0,65535,65535,65535,65535,65535] -+; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm22, %zmm17, %zmm13 - ; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm13, %zmm12 --; AVX512DQ-FAST-NEXT: vmovdqa64 (%rax), %zmm15 -+; AVX512DQ-FAST-NEXT: vmovdqa64 (%rax), %zmm3 - ; AVX512DQ-FAST-NEXT: vbroadcasti32x8 {{.*#+}} zmm4 = [30,5,0,0,31,6,0,31,30,5,0,0,31,6,0,31] - ; AVX512DQ-FAST-NEXT: # zmm4 = mem[0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7] --; AVX512DQ-FAST-NEXT: vpermi2d %zmm15, %zmm6, %zmm4 -+; AVX512DQ-FAST-NEXT: vpermi2d %zmm3, %zmm6, %zmm4 - ; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm12, %zmm4 - ; AVX512DQ-FAST-NEXT: vpunpckhwd {{.*#+}} xmm6 = xmm11[4],xmm10[4],xmm11[5],xmm10[5],xmm11[6],xmm10[6],xmm11[7],xmm10[7] - ; AVX512DQ-FAST-NEXT: vpbroadcastq {{.*#+}} xmm12 = [6,7,4,5,0,0,8,9,6,7,4,5,0,0,8,9] -@@ -6400,14 +6397,15 @@ - ; AVX512DQ-FAST-NEXT: # xmm6 = xmm1[4],mem[4],xmm1[5],mem[5],xmm1[6],mem[6],xmm1[7],mem[7] - ; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm24, %xmm1 - ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm13 = xmm1[0,1,2,3,8,9,10,11,14,15,12,13,14,15,12,13] --; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm6, %xmm6 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm18, %xmm1 -+; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm6, %xmm6 - ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm22 = [0,1,1,3,8,8,9,9] - ; AVX512DQ-FAST-NEXT: vpermt2q %zmm6, %zmm22, %zmm13 - ; AVX512DQ-FAST-NEXT: vprold $16, %xmm0, %xmm6 --; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} xmm2 = xmm8[1,1,2,3] -+; AVX512DQ-FAST-NEXT: vpshufd {{.*#+}} xmm2 = xmm15[1,1,2,3] - ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} xmm2 = xmm2[0,1],xmm6[2],xmm2[3,4],xmm6[5],xmm2[6,7] --; AVX512DQ-FAST-NEXT: vpunpcklwd {{.*#+}} xmm0 = xmm8[0],xmm0[0],xmm8[1],xmm0[1],xmm8[2],xmm0[2],xmm8[3],xmm0[3] --; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm0, %xmm0 -+; AVX512DQ-FAST-NEXT: vpunpcklwd {{.*#+}} xmm0 = 
xmm15[0],xmm0[0],xmm15[1],xmm0[1],xmm15[2],xmm0[2],xmm15[3],xmm0[3] -+; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm0, %xmm0 - ; AVX512DQ-FAST-NEXT: vpermt2q %zmm2, %zmm11, %zmm0 - ; AVX512DQ-FAST-NEXT: vmovdqa (%rcx), %xmm2 - ; AVX512DQ-FAST-NEXT: vpshufb %xmm12, %xmm2, %xmm6 -@@ -6436,8 +6434,8 @@ - ; AVX512DQ-FAST-NEXT: vpblendw {{.*#+}} ymm5 = ymm10[0,1],ymm5[2],ymm10[3,4],ymm5[5],ymm10[6,7,8,9],ymm5[10],ymm10[11,12],ymm5[13],ymm10[14,15] - ; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} xmm10 = xmm12[0,2,3,3,4,5,6,7] - ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,2,1] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm17[0,0,1,3] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm16[2,2,2,3] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[0,0,1,3] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,2,2,3] - ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm7[2,1,3,2] - ; AVX512DQ-FAST-NEXT: vpermt2q %zmm0, %zmm31, %zmm5 - ; AVX512DQ-FAST-NEXT: vpbroadcastd (%rax), %ymm0 -@@ -6456,8 +6454,8 @@ - ; AVX512DQ-FAST-NEXT: vpermt2q %zmm2, %zmm21, %zmm12 - ; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm12 - ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = <6,u,u,u,7,u,u,7> --; AVX512DQ-FAST-NEXT: vpermd %ymm3, %ymm2, %ymm2 --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm15, %zmm3 -+; AVX512DQ-FAST-NEXT: vpermd %ymm16, %ymm2, %ymm2 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm16, %zmm3, %zmm3 - ; AVX512DQ-FAST-NEXT: vbroadcasti32x8 {{.*#+}} zmm5 = [0,13,4,0,0,14,5,0,0,13,4,0,0,14,5,0] - ; AVX512DQ-FAST-NEXT: # zmm5 = mem[0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7] - ; AVX512DQ-FAST-NEXT: vpermd %zmm3, %zmm5, %zmm3 -@@ -6466,7 +6464,7 @@ - ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm5 # 32-byte Folded Reload - ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm9, %zmm9 # 32-byte Folded Reload - ; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm5, %zmm11, %zmm9 --; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm23, %zmm18, %zmm19 -+; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm23, %zmm17, %zmm19 - ; AVX512DQ-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm25, %zmm30 - ; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm30 - ; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Reload -@@ -11708,7 +11706,7 @@ - ; AVX512F-ONLY-SLOW-NEXT: vmovdqa {{.*#+}} ymm9 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] - ; AVX512F-ONLY-SLOW-NEXT: vmovdqa 96(%r8), %ymm12 - ; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm12, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-SLOW-NEXT: vpshufb {{.*#+}} ymm6 = ymm12[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[14,15,u,u,u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[16,17,u,u,u,u],zero,zero -+; AVX512F-ONLY-SLOW-NEXT: vpshufb {{.*#+}} ymm6 = ymm12[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[14,15],zero,zero,ymm12[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[16,17],zero,zero,ymm12[u,u],zero,zero - ; AVX512F-ONLY-SLOW-NEXT: vpternlogq $248, %ymm9, %ymm7, %ymm6 - ; AVX512F-ONLY-SLOW-NEXT: vmovdqa 96(%r9), %ymm15 - ; AVX512F-ONLY-SLOW-NEXT: vmovdqu %ymm15, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -@@ -12367,7 +12365,7 @@ - ; AVX512F-ONLY-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm9 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm13 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] - ; AVX512F-ONLY-FAST-NEXT: 
vmovdqa 96(%r8), %ymm8 --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm8[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[14,15,u,u,u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[16,17,u,u,u,u],zero,zero -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm8[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[14,15],zero,zero,ymm8[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[16,17],zero,zero,ymm8[u,u],zero,zero - ; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, %ymm13, %ymm9, %ymm12 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm13, %ymm14 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa 96(%r9), %ymm13 -@@ -13031,7 +13029,7 @@ - ; AVX512DQ-SLOW-NEXT: vmovdqa {{.*#+}} ymm9 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] - ; AVX512DQ-SLOW-NEXT: vmovdqa 96(%r8), %ymm12 - ; AVX512DQ-SLOW-NEXT: vmovdqu %ymm12, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-SLOW-NEXT: vpshufb {{.*#+}} ymm6 = ymm12[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[14,15,u,u,u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[16,17,u,u,u,u],zero,zero -+; AVX512DQ-SLOW-NEXT: vpshufb {{.*#+}} ymm6 = ymm12[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[14,15],zero,zero,ymm12[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm12[16,17],zero,zero,ymm12[u,u],zero,zero - ; AVX512DQ-SLOW-NEXT: vpternlogq $248, %ymm9, %ymm7, %ymm6 - ; AVX512DQ-SLOW-NEXT: vmovdqa 96(%r9), %ymm15 - ; AVX512DQ-SLOW-NEXT: vmovdqu %ymm15, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -@@ -13690,7 +13688,7 @@ - ; AVX512DQ-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm9 - ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm13 = [65535,65535,65535,65535,65535,0,65535,65535,65535,65535,65535,65535,0,65535,65535,65535] - ; AVX512DQ-FAST-NEXT: vmovdqa 96(%r8), %ymm8 --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm8[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[14,15,u,u,u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[16,17,u,u,u,u],zero,zero -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm8[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[14,15],zero,zero,ymm8[u,u],zero,zero,zero,zero,zero,zero,zero,zero,ymm8[16,17],zero,zero,ymm8[u,u],zero,zero - ; AVX512DQ-FAST-NEXT: vpternlogq $248, %ymm13, %ymm9, %ymm12 - ; AVX512DQ-FAST-NEXT: vmovdqa %ymm13, %ymm14 - ; AVX512DQ-FAST-NEXT: vmovdqa 96(%r9), %ymm13 -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-5.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-5.ll ---- a/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-5.ll -+++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-5.ll -@@ -3967,7 +3967,8 @@ - ; AVX512F-SLOW-NEXT: vpshufb %ymm1, %ymm0, %ymm1 - ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm0 = ymm0[0,2,1,1,4,6,5,5] - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm0[2,3,3,2] --; AVX512F-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 -+; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm14 = [255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255] -+; AVX512F-SLOW-NEXT: vpandn %ymm0, %ymm14, %ymm0 - ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm25 - ; AVX512F-SLOW-NEXT: vpbroadcastq {{.*#+}} ymm0 = [9,14,11,0,13,10,15,12,9,14,11,0,13,10,15,12,9,14,11,0,13,10,15,12,9,14,11,0,13,10,15,12] - ; AVX512F-SLOW-NEXT: vpshufb %ymm0, %ymm9, %ymm9 -@@ -4063,63 +4064,63 @@ - ; - ; AVX512F-FAST-LABEL: store_i8_stride5_vf64: - ; AVX512F-FAST: # %bb.0: --; AVX512F-FAST-NEXT: vmovdqa 32(%rsi), %ymm7 -+; 
AVX512F-FAST-NEXT: vmovdqa 32(%rsi), %ymm6 - ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm13 = [128,128,13,128,128,128,128,14,128,128,128,128,15,128,128,128,128,16,128,128,128,128,17,128,128,128,128,18,128,128,128,128] --; AVX512F-FAST-NEXT: vpshufb %ymm13, %ymm7, %ymm0 --; AVX512F-FAST-NEXT: vmovdqa 32(%rdi), %ymm3 --; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm6 = <12,13,128,15,12,13,14,128,12,13,14,15,128,u,u,u,16,128,18,19,16,17,128,19,16,17,18,128,16,17,18,19> --; AVX512F-FAST-NEXT: vpshufb %ymm6, %ymm3, %ymm1 -+; AVX512F-FAST-NEXT: vpshufb %ymm13, %ymm6, %ymm0 -+; AVX512F-FAST-NEXT: vmovdqa 32(%rdi), %ymm2 -+; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm7 = <12,13,128,15,12,13,14,128,12,13,14,15,128,u,u,u,16,128,18,19,16,17,128,19,16,17,18,128,16,17,18,19> -+; AVX512F-FAST-NEXT: vpshufb %ymm7, %ymm2, %ymm1 - ; AVX512F-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 - ; AVX512F-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512F-FAST-NEXT: vmovdqa 32(%rdi), %xmm1 - ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} xmm15 = <8,128,u,7,128,9,128,u,128,u,10,128,12,128,u,11> - ; AVX512F-FAST-NEXT: vpshufb %xmm15, %xmm1, %xmm0 - ; AVX512F-FAST-NEXT: vmovdqa64 %xmm1, %xmm18 --; AVX512F-FAST-NEXT: vmovdqa 32(%rsi), %xmm2 -+; AVX512F-FAST-NEXT: vmovdqa 32(%rsi), %xmm3 - ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} xmm4 = <128,8,u,128,7,128,9,u,11,u,128,10,128,12,u,128> --; AVX512F-FAST-NEXT: vpshufb %xmm4, %xmm2, %xmm1 -+; AVX512F-FAST-NEXT: vpshufb %xmm4, %xmm3, %xmm1 - ; AVX512F-FAST-NEXT: vmovdqa64 %xmm4, %xmm25 --; AVX512F-FAST-NEXT: vmovdqa64 %xmm2, %xmm17 -+; AVX512F-FAST-NEXT: vmovdqa64 %xmm3, %xmm17 - ; AVX512F-FAST-NEXT: vpor %xmm0, %xmm1, %xmm0 - ; AVX512F-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512F-FAST-NEXT: vmovdqa 32(%rcx), %ymm9 - ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm0 = [128,128,128,128,13,128,128,128,128,14,128,128,128,128,15,128,128,128,128,16,128,128,128,128,17,128,128,128,128,18,128,128] --; AVX512F-FAST-NEXT: vpshufb %ymm0, %ymm9, %ymm2 -+; AVX512F-FAST-NEXT: vpshufb %ymm0, %ymm9, %ymm3 - ; AVX512F-FAST-NEXT: vmovdqa 32(%rdx), %ymm8 - ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = - ; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm8, %ymm4 --; AVX512F-FAST-NEXT: vpor %ymm2, %ymm4, %ymm2 --; AVX512F-FAST-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-FAST-NEXT: vpor %ymm3, %ymm4, %ymm3 -+; AVX512F-FAST-NEXT: vmovdqu %ymm3, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512F-FAST-NEXT: vmovdqa 32(%rcx), %xmm10 - ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} xmm4 = <128,6,128,8,u,128,7,128,9,128,11,u,128,10,128,12> --; AVX512F-FAST-NEXT: vpshufb %xmm4, %xmm10, %xmm2 -+; AVX512F-FAST-NEXT: vpshufb %xmm4, %xmm10, %xmm3 - ; AVX512F-FAST-NEXT: vmovdqa64 %xmm4, %xmm26 - ; AVX512F-FAST-NEXT: vmovdqa 32(%rdx), %xmm11 - ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} xmm14 = <6,128,8,128,u,7,128,9,128,11,128,u,10,128,12,128> - ; AVX512F-FAST-NEXT: vpshufb %xmm14, %xmm11, %xmm4 --; AVX512F-FAST-NEXT: vporq %xmm2, %xmm4, %xmm21 --; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm7[11,u,u,10,u,12,u,u,u,u,13,u,15,u,u,14,27,u,u,26,u,28,u,u,u,u,29,u,31,u,u,30] --; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm4 = ymm7[3,u,5,u,u,4,u,6,u,8,u,u,7,u,9,u,19,u,21,u,u,20,u,22,u,24,u,u,23,u,25,u] --; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm4, %zmm22 --; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm3[26],zero,ymm3[28],zero,zero,ymm3[27],zero,ymm3[29],zero,ymm3[31],zero,zero,ymm3[30],zero --; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm3 
= ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm3[21],zero,zero,ymm3[20],zero,ymm3[22],zero,ymm3[24],zero,zero,ymm3[23],zero,ymm3[25],zero,zero --; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm23 -+; AVX512F-FAST-NEXT: vporq %xmm3, %xmm4, %xmm19 -+; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm3 = ymm6[11,u,u,10,u,12,u,u,u,u,13,u,15,u,u,14,27,u,u,26,u,28,u,u,u,u,29,u,31,u,u,30] -+; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm4 = ymm6[3,u,5,u,u,4,u,6,u,8,u,u,7,u,9,u,19,u,21,u,u,20,u,22,u,24,u,u,23,u,25,u] -+; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm4, %zmm22 -+; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm3 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm2[26],zero,ymm2[28],zero,zero,ymm2[27],zero,ymm2[29],zero,ymm2[31],zero,zero,ymm2[30],zero -+; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm2[21],zero,zero,ymm2[20],zero,ymm2[22],zero,ymm2[24],zero,zero,ymm2[23],zero,ymm2[25],zero,zero -+; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm2, %zmm23 - ; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm8[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm8[27],zero,zero,ymm8[26],zero,ymm8[28],zero,ymm8[30],zero,zero,ymm8[29],zero,ymm8[31],zero,zero - ; AVX512F-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm4 = [128,128,19,128,21,128,128,20,128,22,128,24,128,128,23,128,128,128,19,128,21,128,128,20,128,22,128,24,128,128,23,128] - ; AVX512F-FAST-NEXT: # ymm4 = mem[0,1,0,1] - ; AVX512F-FAST-NEXT: vpshufb %ymm4, %ymm9, %ymm3 - ; AVX512F-FAST-NEXT: vmovdqa64 %ymm4, %ymm30 - ; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm24 --; AVX512F-FAST-NEXT: vmovdqa (%rcx), %ymm7 --; AVX512F-FAST-NEXT: vpshufb %ymm0, %ymm7, %ymm0 --; AVX512F-FAST-NEXT: vmovdqa (%rdx), %ymm12 --; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm12, %ymm1 --; AVX512F-FAST-NEXT: vporq %ymm0, %ymm1, %ymm19 -+; AVX512F-FAST-NEXT: vmovdqa (%rcx), %ymm12 -+; AVX512F-FAST-NEXT: vpshufb %ymm0, %ymm12, %ymm0 -+; AVX512F-FAST-NEXT: vmovdqa (%rdx), %ymm6 -+; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm6, %ymm1 -+; AVX512F-FAST-NEXT: vporq %ymm0, %ymm1, %ymm20 - ; AVX512F-FAST-NEXT: vmovdqa (%rsi), %ymm5 - ; AVX512F-FAST-NEXT: vpshufb %ymm13, %ymm5, %ymm0 - ; AVX512F-FAST-NEXT: vmovdqa (%rdi), %ymm4 --; AVX512F-FAST-NEXT: vpshufb %ymm6, %ymm4, %ymm1 --; AVX512F-FAST-NEXT: vporq %ymm0, %ymm1, %ymm20 -+; AVX512F-FAST-NEXT: vpshufb %ymm7, %ymm4, %ymm1 -+; AVX512F-FAST-NEXT: vporq %ymm0, %ymm1, %ymm21 - ; AVX512F-FAST-NEXT: vmovdqa (%rdi), %xmm1 - ; AVX512F-FAST-NEXT: vpshufb %xmm15, %xmm1, %xmm0 - ; AVX512F-FAST-NEXT: vmovdqa64 %xmm1, %xmm16 -@@ -4127,9 +4128,9 @@ - ; AVX512F-FAST-NEXT: vmovdqa64 %xmm25, %xmm1 - ; AVX512F-FAST-NEXT: vpshufb %xmm1, %xmm3, %xmm2 - ; AVX512F-FAST-NEXT: vporq %xmm0, %xmm2, %xmm28 --; AVX512F-FAST-NEXT: vmovdqa (%rcx), %xmm6 -+; AVX512F-FAST-NEXT: vmovdqa (%rcx), %xmm7 - ; AVX512F-FAST-NEXT: vmovdqa64 %xmm26, %xmm0 --; AVX512F-FAST-NEXT: vpshufb %xmm0, %xmm6, %xmm0 -+; AVX512F-FAST-NEXT: vpshufb %xmm0, %xmm7, %xmm0 - ; AVX512F-FAST-NEXT: vmovdqa (%rdx), %xmm2 - ; AVX512F-FAST-NEXT: vpshufb %xmm14, %xmm2, %xmm14 - ; AVX512F-FAST-NEXT: vporq %xmm0, %xmm14, %xmm29 -@@ -4141,25 +4142,26 @@ - ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = [12,128,128,128,128,13,128,128,128,128,14,128,128,128,128,15,128,128,128,128,16,128,128,128,128,17,128,128,128,128,18,128] - ; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm15 - ; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm15, %zmm14, %zmm27 --; AVX512F-FAST-NEXT: vmovdqa64 (%r8), %zmm26 -+; AVX512F-FAST-NEXT: vmovdqa64 (%r8), %zmm25 - ; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} 
zmm31 = <4,u,5,5,5,5,u,6,30,30,30,u,31,31,31,31> --; AVX512F-FAST-NEXT: vpermi2d %zmm26, %zmm0, %zmm31 --; AVX512F-FAST-NEXT: vmovdqa (%r8), %ymm0 --; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm1 -+; AVX512F-FAST-NEXT: vpermi2d %zmm25, %zmm0, %zmm31 - ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm15 = <4,u,5,5,5,5,u,6> -+; AVX512F-FAST-NEXT: vmovdqa (%r8), %ymm0 - ; AVX512F-FAST-NEXT: vpermd %ymm0, %ymm15, %ymm15 --; AVX512F-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm15, %ymm15 --; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm15, %zmm1, %zmm25 -+; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} ymm26 = [255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255,255,0,255,255,255] -+; AVX512F-FAST-NEXT: vpandnq %ymm15, %ymm26, %ymm15 -+; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm1 -+; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm15, %zmm1, %zmm26 - ; AVX512F-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [9,14,11,0,13,10,15,12,9,14,11,0,13,10,15,12,9,14,11,0,13,10,15,12,9,14,11,0,13,10,15,12] - ; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm9, %ymm9 - ; AVX512F-FAST-NEXT: vmovdqa64 %ymm30, %ymm13 --; AVX512F-FAST-NEXT: vpshufb %ymm13, %ymm7, %ymm15 --; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm7, %ymm1 --; AVX512F-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm7 = [18,19,128,21,128,21,20,128,22,128,24,128,22,23,128,25,18,19,128,21,128,21,20,128,22,128,24,128,22,23,128,25] --; AVX512F-FAST-NEXT: # ymm7 = mem[0,1,0,1] --; AVX512F-FAST-NEXT: vpshufb %ymm7, %ymm8, %ymm8 --; AVX512F-FAST-NEXT: vpshufb %ymm7, %ymm12, %ymm7 --; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm12[27],zero,zero,ymm12[26],zero,ymm12[28],zero,ymm12[30],zero,zero,ymm12[29],zero,ymm12[31],zero,zero -+; AVX512F-FAST-NEXT: vpshufb %ymm13, %ymm12, %ymm15 -+; AVX512F-FAST-NEXT: vpshufb %ymm1, %ymm12, %ymm1 -+; AVX512F-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm12 = [18,19,128,21,128,21,20,128,22,128,24,128,22,23,128,25,18,19,128,21,128,21,20,128,22,128,24,128,22,23,128,25] -+; AVX512F-FAST-NEXT: # ymm12 = mem[0,1,0,1] -+; AVX512F-FAST-NEXT: vpshufb %ymm12, %ymm8, %ymm8 -+; AVX512F-FAST-NEXT: vpshufb %ymm12, %ymm6, %ymm12 -+; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm6 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm6[27],zero,zero,ymm6[26],zero,ymm6[28],zero,ymm6[30],zero,zero,ymm6[29],zero,ymm6[31],zero,zero - ; AVX512F-FAST-NEXT: vpunpcklbw {{.*#+}} xmm10 = xmm10[0],xmm11[0],xmm10[1],xmm11[1],xmm10[2],xmm11[2],xmm10[3],xmm11[3],xmm10[4],xmm11[4],xmm10[5],xmm11[5],xmm10[6],xmm11[6],xmm10[7],xmm11[7] - ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm30 = ymm9[2,2,3,3] - ; AVX512F-FAST-NEXT: vpshufb {{.*#+}} ymm11 = ymm5[3,u,5,u,u,4,u,6,u,8,u,u,7,u,9,u,19,u,21,u,u,20,u,22,u,24,u,u,23,u,25,u] -@@ -4171,11 +4173,11 @@ - ; AVX512F-FAST-NEXT: vpunpcklbw {{.*#+}} xmm14 = xmm14[0],xmm13[0],xmm14[1],xmm13[1],xmm14[2],xmm13[2],xmm14[3],xmm13[3],xmm14[4],xmm13[4],xmm14[5],xmm13[5],xmm14[6],xmm13[6],xmm14[7],xmm13[7] - ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,2,3,3] - ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm15 = ymm15[2,2,3,3] --; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm7[2,2,3,3] -+; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,2,3,3] - ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm11 = ymm11[2,2,3,3] - ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,2,3,3] - ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,2,3,3] --; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,2,3,3] -+; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm6[2,2,3,3] - ; AVX512F-FAST-NEXT: vmovdqa64 %xmm16, %xmm13 - ; 
AVX512F-FAST-NEXT: vpunpcklbw {{.*#+}} xmm3 = xmm13[0],xmm3[0],xmm13[1],xmm3[1],xmm13[2],xmm3[2],xmm13[3],xmm3[3],xmm13[4],xmm3[4],xmm13[5],xmm3[5],xmm13[6],xmm3[6],xmm13[7],xmm3[7] - ; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} xmm13 = <0,1,4,5,u,2,3,6,7,10,11,u,8,9,12,13> -@@ -4188,42 +4190,42 @@ - ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,2,3,3] - ; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm14[0,0,1,1] - ; AVX512F-FAST-NEXT: vinserti32x4 $2, %xmm28, %zmm3, %zmm3 --; AVX512F-FAST-NEXT: vpunpcklbw {{.*#+}} xmm2 = xmm6[0],xmm2[0],xmm6[1],xmm2[1],xmm6[2],xmm2[2],xmm6[3],xmm2[3],xmm6[4],xmm2[4],xmm6[5],xmm2[5],xmm6[6],xmm2[6],xmm6[7],xmm2[7] -+; AVX512F-FAST-NEXT: vpunpcklbw {{.*#+}} xmm2 = xmm7[0],xmm2[0],xmm7[1],xmm2[1],xmm7[2],xmm2[2],xmm7[3],xmm2[3],xmm7[4],xmm2[4],xmm7[5],xmm2[5],xmm7[6],xmm2[6],xmm7[7],xmm2[7] - ; AVX512F-FAST-NEXT: vpshufb %xmm13, %xmm2, %xmm2 - ; AVX512F-FAST-NEXT: vinserti32x4 $2, %xmm29, %zmm2, %zmm2 --; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm26, %zmm0 --; AVX512F-FAST-NEXT: vpermq $80, {{[-0-9]+}}(%r{{[sb]}}p), %ymm6 # 32-byte Folded Reload --; AVX512F-FAST-NEXT: # ymm6 = mem[0,0,1,1] --; AVX512F-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm6, %zmm6 # 32-byte Folded Reload --; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm21[0,0,1,1] -+; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm25, %zmm0 -+; AVX512F-FAST-NEXT: vpermq $80, {{[-0-9]+}}(%r{{[sb]}}p), %ymm7 # 32-byte Folded Reload -+; AVX512F-FAST-NEXT: # ymm7 = mem[0,0,1,1] -+; AVX512F-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm7, %zmm7 # 32-byte Folded Reload -+; AVX512F-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm19[0,0,1,1] - ; AVX512F-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm13, %zmm13 # 32-byte Folded Reload - ; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} zmm16 = [255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0] --; AVX512F-FAST-NEXT: vpternlogq $226, %zmm6, %zmm16, %zmm13 --; AVX512F-FAST-NEXT: vpor %ymm7, %ymm15, %ymm6 --; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm6, %zmm19, %zmm6 --; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm7 = [18374966859431608575,18374966859431608575,18446463693966278400,18446463693966278400] --; AVX512F-FAST-NEXT: vpternlogq $248, %ymm7, %ymm11, %ymm9 --; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm9, %zmm20, %zmm9 --; AVX512F-FAST-NEXT: vpternlogq $226, %zmm6, %zmm16, %zmm9 --; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm6 = zmm22[2,2,3,3,6,6,7,7] -+; AVX512F-FAST-NEXT: vpternlogq $226, %zmm7, %zmm16, %zmm13 -+; AVX512F-FAST-NEXT: vpor %ymm15, %ymm12, %ymm7 -+; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm7, %zmm20, %zmm7 -+; AVX512F-FAST-NEXT: vmovdqa {{.*#+}} ymm12 = [18374966859431608575,18374966859431608575,18446463693966278400,18446463693966278400] -+; AVX512F-FAST-NEXT: vpternlogq $248, %ymm12, %ymm11, %ymm9 -+; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm9, %zmm21, %zmm9 -+; AVX512F-FAST-NEXT: vpternlogq $226, %zmm7, %zmm16, %zmm9 -+; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm7 = zmm22[2,2,3,3,6,6,7,7] - ; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm11 = zmm23[2,2,3,3,6,6,7,7] --; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm11 --; AVX512F-FAST-NEXT: vpternlogq $248, %ymm7, %ymm1, %ymm12 --; AVX512F-FAST-NEXT: vpandq %ymm7, %ymm30, %ymm1 -+; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm11 -+; AVX512F-FAST-NEXT: vpternlogq $248, %ymm12, 
%ymm1, %ymm6 -+; AVX512F-FAST-NEXT: vpandq %ymm12, %ymm30, %ymm1 - ; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm8, %zmm1 --; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm6 = zmm24[2,2,3,3,6,6,7,7] --; AVX512F-FAST-NEXT: vporq %zmm6, %zmm1, %zmm1 --; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} zmm6 = [0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255] --; AVX512F-FAST-NEXT: vpternlogq $226, %zmm11, %zmm6, %zmm1 --; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm10, %zmm12, %zmm7 -+; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm7 = zmm24[2,2,3,3,6,6,7,7] -+; AVX512F-FAST-NEXT: vporq %zmm7, %zmm1, %zmm1 -+; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} zmm7 = [0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255,0,0,255,255,255] -+; AVX512F-FAST-NEXT: vpternlogq $226, %zmm11, %zmm7, %zmm1 -+; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm10, %zmm6, %zmm6 - ; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm5, %ymm4 - ; AVX512F-FAST-NEXT: vinserti64x4 $1, %ymm14, %zmm4, %zmm4 --; AVX512F-FAST-NEXT: vpternlogq $226, %zmm7, %zmm6, %zmm4 -+; AVX512F-FAST-NEXT: vpternlogq $226, %zmm6, %zmm7, %zmm4 - ; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm13, %zmm27 - ; AVX512F-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm31 --; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm25 -+; AVX512F-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm26 - ; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} zmm1 = <6,6,6,u,7,7,7,7,u,8,8,8,8,u,9,9> --; AVX512F-FAST-NEXT: vpermd %zmm26, %zmm1, %zmm1 -+; AVX512F-FAST-NEXT: vpermd %zmm25, %zmm1, %zmm1 - ; AVX512F-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm1 - ; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm3 = zmm3[0,0,1,1,4,4,5,5] - ; AVX512F-FAST-NEXT: vpermq {{.*#+}} zmm2 = zmm2[0,0,1,1,4,4,5,5] -@@ -4231,7 +4233,7 @@ - ; AVX512F-FAST-NEXT: vmovdqa64 {{.*#+}} zmm3 = - ; AVX512F-FAST-NEXT: vpermd %zmm0, %zmm3, %zmm0 - ; AVX512F-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm0 --; AVX512F-FAST-NEXT: vmovdqa64 %zmm25, 64(%r9) -+; AVX512F-FAST-NEXT: vmovdqa64 %zmm26, 64(%r9) - ; AVX512F-FAST-NEXT: vmovdqa64 %zmm0, (%r9) - ; AVX512F-FAST-NEXT: vmovdqa64 %zmm1, 128(%r9) - ; AVX512F-FAST-NEXT: vmovdqa64 %zmm31, 256(%r9) -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-7.ll b/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-7.ll ---- a/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-7.ll -+++ b/llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-7.ll -@@ -7382,14 +7382,14 @@ - ; - ; AVX512F-SLOW-LABEL: store_i8_stride7_vf64: - ; AVX512F-SLOW: # %bb.0: --; AVX512F-SLOW-NEXT: subq $1464, %rsp # imm = 0x5B8 -+; AVX512F-SLOW-NEXT: subq $1416, %rsp # imm = 0x588 - ; AVX512F-SLOW-NEXT: vmovdqa (%rsi), %ymm1 - ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,ymm1[14],zero,zero,zero,zero,zero,zero,ymm1[15],zero,zero,zero,zero,zero,zero,ymm1[16],zero,zero,zero,zero,zero,zero,ymm1[17],zero,zero,zero,zero,zero,zero,ymm1[18] --; AVX512F-SLOW-NEXT: vmovdqa %ymm1, %ymm9 -+; AVX512F-SLOW-NEXT: vmovdqa %ymm1, %ymm12 - ; AVX512F-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; 
AVX512F-SLOW-NEXT: vmovdqa (%rdi), %ymm2 - ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[0,1,14],zero,ymm2[12,13,0,1,14,15],zero,ymm2[3,12,13,2,3,16],zero,ymm2[30,31,28,29,16,17],zero,ymm2[31,18,19,28,29,18],zero --; AVX512F-SLOW-NEXT: vmovdqa %ymm2, %ymm10 -+; AVX512F-SLOW-NEXT: vmovdqa %ymm2, %ymm9 - ; AVX512F-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512F-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 - ; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -@@ -7400,46 +7400,44 @@ - ; AVX512F-SLOW-NEXT: vmovdqa (%rdx), %ymm8 - ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm2 = [0,1,0,1,14,128,14,15,0,1,14,15,128,13,14,15,16,17,16,128,30,31,30,31,16,17,128,31,28,29,30,31] - ; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm8, %ymm1 --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm2, %ymm17 -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm2, %ymm28 - ; AVX512F-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 - ; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512F-SLOW-NEXT: vmovdqa (%r8), %ymm0 - ; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm2 = [128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128] --; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm0, %ymm0 --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm2, %ymm20 --; AVX512F-SLOW-NEXT: vmovdqa (%r9), %ymm1 --; AVX512F-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,zero,zero,zero,ymm0[14],zero,zero,zero,zero,zero,zero,ymm0[15],zero,zero,zero,zero,zero,zero,ymm0[16],zero,zero,zero,zero,zero,zero,ymm0[17],zero,zero,zero,zero -+; AVX512F-SLOW-NEXT: vmovdqa (%r9), %ymm2 - ; AVX512F-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm3 = [13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0,13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0] - ; AVX512F-SLOW-NEXT: # ymm3 = mem[0,1,0,1] --; AVX512F-SLOW-NEXT: vpshufb %ymm3, %ymm1, %ymm1 --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm3, %ymm21 --; AVX512F-SLOW-NEXT: vpor %ymm0, %ymm1, %ymm0 --; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-SLOW-NEXT: vmovdqa 32(%r9), %ymm13 --; AVX512F-SLOW-NEXT: vmovdqa 32(%r8), %ymm14 --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm14[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm14[27],zero,zero,zero,zero,ymm14[30],zero,ymm14[28],zero,zero,zero,zero,ymm14[31],zero,ymm14[29] --; AVX512F-SLOW-NEXT: vmovdqu %ymm14, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm1 = ymm13[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm13[25],zero,ymm13[23],zero,zero,zero,zero,ymm13[26],zero,ymm13[24],zero,zero --; AVX512F-SLOW-NEXT: vmovdqu %ymm13, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-SLOW-NEXT: vpshufb %ymm3, %ymm2, %ymm1 -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm3, %ymm17 -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm2, %ymm16 -+; AVX512F-SLOW-NEXT: vporq %ymm0, %ymm1, %ymm23 -+; AVX512F-SLOW-NEXT: vmovdqa 32(%r9), %ymm10 -+; AVX512F-SLOW-NEXT: vmovdqa 32(%r8), %ymm11 -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm11[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm11[27],zero,zero,zero,zero,ymm11[30],zero,ymm11[28],zero,zero,zero,zero,ymm11[31],zero,ymm11[29] -+; AVX512F-SLOW-NEXT: vmovdqu %ymm11, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm1 = 
ymm10[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm10[25],zero,ymm10[23],zero,zero,zero,zero,ymm10[26],zero,ymm10[24],zero,zero - ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-SLOW-NEXT: vmovdqa 32(%rcx), %ymm6 --; AVX512F-SLOW-NEXT: vmovdqa 32(%rdx), %ymm11 --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm11[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm11[30],zero,ymm11[28],zero,zero,zero,zero,ymm11[31],zero,ymm11[29],zero,zero --; AVX512F-SLOW-NEXT: vmovdqu %ymm11, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-SLOW-NEXT: vmovdqa 32(%rcx), %ymm5 -+; AVX512F-SLOW-NEXT: vmovdqa 32(%rdx), %ymm6 -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm6[30],zero,ymm6[28],zero,zero,zero,zero,ymm6[31],zero,ymm6[29],zero,zero -+; AVX512F-SLOW-NEXT: vmovdqu %ymm6, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512F-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm2 = [128,128,25,128,23,128,128,128,128,26,128,24,128,128,128,128,128,128,25,128,23,128,128,128,128,26,128,24,128,128,128,128] - ; AVX512F-SLOW-NEXT: # ymm2 = mem[0,1,0,1] --; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm6, %ymm1 -+; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm5, %ymm1 -+; AVX512F-SLOW-NEXT: vmovdqu %ymm5, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-SLOW-NEXT: vmovdqa 32(%rsi), %ymm5 --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm5[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm5[30],zero,ymm5[28],zero,zero,zero,zero,ymm5[31],zero,ymm5[29],zero,zero,zero -+; AVX512F-SLOW-NEXT: vmovdqa 32(%rsi), %ymm1 -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero,zero -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm1, %ymm21 - ; AVX512F-SLOW-NEXT: vmovdqa 32(%rdi), %ymm4 - ; AVX512F-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm1 = [128,23,128,128,128,128,26,128,24,128,128,128,128,27,128,25,128,23,128,128,128,128,26,128,24,128,128,128,128,27,128,25] - ; AVX512F-SLOW-NEXT: # ymm1 = mem[0,1,0,1] - ; AVX512F-SLOW-NEXT: vpshufb %ymm1, %ymm4, %ymm3 --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm4, %ymm19 -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm4, %ymm20 - ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm3, %zmm0 - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill - ; AVX512F-SLOW-NEXT: movq {{[0-9]+}}(%rsp), %rax -@@ -7461,179 +7459,182 @@ - ; AVX512F-SLOW-NEXT: vmovdqa64 %ymm8, %ymm18 - ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-SLOW-NEXT: vpshufb %ymm1, %ymm10, %ymm1 --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm2 = ymm9[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm9[21],zero,ymm9[19],zero,zero,zero,zero,ymm9[22],zero,ymm9[20],zero,zero -+; AVX512F-SLOW-NEXT: vpshufb %ymm1, %ymm9, %ymm1 -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm2 = ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm12[21],zero,ymm12[19],zero,zero,zero,zero,ymm12[22],zero,ymm12[20],zero,zero - ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm1 - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-SLOW-NEXT: vmovdqa (%rax), %ymm3 --; AVX512F-SLOW-NEXT: vpshufb %ymm0, 
%ymm3, %ymm0 -+; AVX512F-SLOW-NEXT: vmovdqa (%rax), %ymm1 -+; AVX512F-SLOW-NEXT: vpshufb %ymm0, %ymm1, %ymm0 - ; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm25 = --; AVX512F-SLOW-NEXT: vpshuflw {{.*#+}} ymm1 = ymm3[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] --; AVX512F-SLOW-NEXT: vpermi2d %zmm0, %zmm1, %zmm25 --; AVX512F-SLOW-NEXT: vmovdqa 32(%rdx), %xmm2 -+; AVX512F-SLOW-NEXT: vpshuflw {{.*#+}} ymm2 = ymm1[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] -+; AVX512F-SLOW-NEXT: vpermi2d %zmm0, %zmm2, %zmm25 -+; AVX512F-SLOW-NEXT: vmovdqa 32(%rdx), %xmm3 - ; AVX512F-SLOW-NEXT: vmovdqa 32(%rcx), %xmm15 --; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm1 = --; AVX512F-SLOW-NEXT: vpshufb %xmm1, %xmm15, %xmm0 --; AVX512F-SLOW-NEXT: vmovdqa64 %xmm1, %xmm16 -+; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm2 = -+; AVX512F-SLOW-NEXT: vpshufb %xmm2, %xmm15, %xmm0 -+; AVX512F-SLOW-NEXT: vmovdqa64 %xmm2, %xmm19 - ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm4 = --; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm2, %xmm1 --; AVX512F-SLOW-NEXT: vmovdqa64 %xmm4, %xmm23 --; AVX512F-SLOW-NEXT: vmovdqa64 %xmm2, %xmm30 --; AVX512F-SLOW-NEXT: vpor %xmm0, %xmm1, %xmm0 -+; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm3, %xmm2 -+; AVX512F-SLOW-NEXT: vmovdqa64 %xmm4, %xmm29 -+; AVX512F-SLOW-NEXT: vmovdqa64 %xmm3, %xmm30 -+; AVX512F-SLOW-NEXT: vpor %xmm0, %xmm2, %xmm0 - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-SLOW-NEXT: vmovdqa 32(%rdi), %xmm7 --; AVX512F-SLOW-NEXT: vmovdqa 32(%rsi), %xmm1 --; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm8 = --; AVX512F-SLOW-NEXT: vpshufb %xmm8, %xmm1, %xmm0 -+; AVX512F-SLOW-NEXT: vmovdqa 32(%rdi), %xmm8 -+; AVX512F-SLOW-NEXT: vmovdqa 32(%rsi), %xmm0 -+; AVX512F-SLOW-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill -+; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm7 = -+; AVX512F-SLOW-NEXT: vpshufb %xmm7, %xmm0, %xmm0 - ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm4 = --; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm7, %xmm2 --; AVX512F-SLOW-NEXT: vmovdqa64 %xmm7, %xmm22 -+; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm8, %xmm2 -+; AVX512F-SLOW-NEXT: vmovdqa64 %xmm8, %xmm22 - ; AVX512F-SLOW-NEXT: vpor %xmm0, %xmm2, %xmm0 - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm7 = <0,u,0,u,2,3,u,1,u,18,u,19,18,u,19,u> -+; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm8 = <0,u,0,u,2,3,u,1,u,18,u,19,18,u,19,u> - ; AVX512F-SLOW-NEXT: vmovdqa 32(%rax), %xmm2 --; AVX512F-SLOW-NEXT: vmovdqa %xmm2, (%rsp) # 16-byte Spill -+; AVX512F-SLOW-NEXT: vmovdqa %xmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill - ; AVX512F-SLOW-NEXT: vpshufhw {{.*#+}} xmm0 = xmm2[0,1,2,3,4,5,5,6] - ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} xmm2 = xmm2[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] --; AVX512F-SLOW-NEXT: vpermi2d %zmm0, %zmm2, %zmm7 --; AVX512F-SLOW-NEXT: vmovdqu64 %zmm7, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-SLOW-NEXT: vpermi2d %zmm0, %zmm2, %zmm8 -+; AVX512F-SLOW-NEXT: vmovdqu64 %zmm8, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill - ; AVX512F-SLOW-NEXT: vmovdqa 32(%r9), %xmm0 --; AVX512F-SLOW-NEXT: vmovdqa 32(%r8), %xmm2 --; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm7 = <128,4,u,u,u,128,7,128,5,u,u,u,128,8,128,6> --; AVX512F-SLOW-NEXT: vpshufb %xmm7, %xmm0, %xmm9 --; AVX512F-SLOW-NEXT: vmovdqa64 %xmm0, %xmm31 --; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm12 = <4,128,u,u,u,7,128,5,128,u,u,u,8,128,6,128> --; AVX512F-SLOW-NEXT: vpshufb %xmm12, %xmm2, %xmm10 --; AVX512F-SLOW-NEXT: vporq %xmm9, %xmm10, %xmm24 -+; AVX512F-SLOW-NEXT: vmovdqa 32(%r8), %xmm13 -+; 
AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm12 = <128,4,u,u,u,128,7,128,5,u,u,u,128,8,128,6> -+; AVX512F-SLOW-NEXT: vpshufb %xmm12, %xmm0, %xmm8 -+; AVX512F-SLOW-NEXT: vmovdqa64 %xmm0, %xmm26 -+; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm14 = <4,128,u,u,u,7,128,5,128,u,u,u,8,128,6,128> -+; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm13, %xmm9 -+; AVX512F-SLOW-NEXT: vporq %xmm8, %xmm9, %xmm24 - ; AVX512F-SLOW-NEXT: vmovdqa64 %ymm27, %ymm0 -+; AVX512F-SLOW-NEXT: vpshufb %ymm0, %ymm5, %ymm8 -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm28, %ymm0 - ; AVX512F-SLOW-NEXT: vpshufb %ymm0, %ymm6, %ymm9 --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm6, %ymm26 --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm17, %ymm0 --; AVX512F-SLOW-NEXT: vpshufb %ymm0, %ymm11, %ymm10 --; AVX512F-SLOW-NEXT: vpor %ymm9, %ymm10, %ymm0 -+; AVX512F-SLOW-NEXT: vpor %ymm8, %ymm9, %ymm0 - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm9 = zero,zero,zero,ymm5[14],zero,zero,zero,zero,zero,zero,ymm5[15],zero,zero,zero,zero,zero,zero,ymm5[16],zero,zero,zero,zero,zero,zero,ymm5[17],zero,zero,zero,zero,zero,zero,ymm5[18] --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm5, %ymm27 --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm19, %ymm0 --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm10 = ymm0[0,1,14],zero,ymm0[12,13,0,1,14,15],zero,ymm0[3,12,13,2,3,16],zero,ymm0[30,31,28,29,16,17],zero,ymm0[31,18,19,28,29,18],zero --; AVX512F-SLOW-NEXT: vpor %ymm9, %ymm10, %ymm5 -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm21, %ymm3 -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm8 = zero,zero,zero,ymm3[14],zero,zero,zero,zero,zero,zero,ymm3[15],zero,zero,zero,zero,zero,zero,ymm3[16],zero,zero,zero,zero,zero,zero,ymm3[17],zero,zero,zero,zero,zero,zero,ymm3[18] -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm20, %ymm0 -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm9 = ymm0[0,1,14],zero,ymm0[12,13,0,1,14,15],zero,ymm0[3,12,13,2,3,16],zero,ymm0[30,31,28,29,16,17],zero,ymm0[31,18,19,28,29,18],zero -+; AVX512F-SLOW-NEXT: vpor %ymm8, %ymm9, %ymm5 - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm5, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm20, %ymm9 --; AVX512F-SLOW-NEXT: vpshufb %ymm9, %ymm14, %ymm9 --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm21, %ymm6 --; AVX512F-SLOW-NEXT: vpshufb %ymm6, %ymm13, %ymm10 --; AVX512F-SLOW-NEXT: vpor %ymm9, %ymm10, %ymm5 -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm8 = zero,zero,zero,zero,zero,zero,ymm11[14],zero,zero,zero,zero,zero,zero,ymm11[15],zero,zero,zero,zero,zero,zero,ymm11[16],zero,zero,zero,zero,zero,zero,ymm11[17],zero,zero,zero,zero -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm17, %ymm6 -+; AVX512F-SLOW-NEXT: vpshufb %ymm6, %ymm10, %ymm9 -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm10, %ymm28 -+; AVX512F-SLOW-NEXT: vpor %ymm8, %ymm9, %ymm5 - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm5, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-SLOW-NEXT: vmovdqa (%rsi), %xmm9 --; AVX512F-SLOW-NEXT: vpshufb %xmm8, %xmm9, %xmm8 --; AVX512F-SLOW-NEXT: vmovdqa (%rdi), %xmm10 --; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm10, %xmm4 --; AVX512F-SLOW-NEXT: vporq %xmm8, %xmm4, %xmm21 --; AVX512F-SLOW-NEXT: vmovdqa (%rcx), %xmm5 --; AVX512F-SLOW-NEXT: vmovdqa64 %xmm16, %xmm4 --; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm5, %xmm4 --; AVX512F-SLOW-NEXT: vmovdqa %xmm5, %xmm11 --; AVX512F-SLOW-NEXT: vmovdqa (%rdx), %xmm6 --; AVX512F-SLOW-NEXT: vmovdqa64 %xmm23, %xmm5 --; AVX512F-SLOW-NEXT: vpshufb %xmm5, %xmm6, %xmm8 --; AVX512F-SLOW-NEXT: vporq %xmm4, %xmm8, %xmm19 -+; AVX512F-SLOW-NEXT: vmovdqa (%rsi), %xmm6 -+; AVX512F-SLOW-NEXT: vpshufb %xmm7, %xmm6, %xmm5 -+; 
AVX512F-SLOW-NEXT: vmovdqa64 %xmm6, %xmm20 -+; AVX512F-SLOW-NEXT: vmovdqa (%rdi), %xmm9 -+; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm9, %xmm4 -+; AVX512F-SLOW-NEXT: vporq %xmm5, %xmm4, %xmm21 -+; AVX512F-SLOW-NEXT: vmovdqa (%rcx), %xmm2 -+; AVX512F-SLOW-NEXT: vmovdqa64 %xmm19, %xmm4 -+; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm2, %xmm4 -+; AVX512F-SLOW-NEXT: vmovdqa (%rdx), %xmm10 -+; AVX512F-SLOW-NEXT: vmovdqa64 %xmm29, %xmm5 -+; AVX512F-SLOW-NEXT: vpshufb %xmm5, %xmm10, %xmm7 -+; AVX512F-SLOW-NEXT: vporq %xmm4, %xmm7, %xmm19 - ; AVX512F-SLOW-NEXT: vmovdqa (%r9), %xmm5 - ; AVX512F-SLOW-NEXT: vmovdqa %xmm5, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill --; AVX512F-SLOW-NEXT: vpshufb %xmm7, %xmm5, %xmm4 --; AVX512F-SLOW-NEXT: vmovdqa (%r8), %xmm8 --; AVX512F-SLOW-NEXT: vmovdqa %xmm8, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill --; AVX512F-SLOW-NEXT: vpshufb %xmm12, %xmm8, %xmm7 --; AVX512F-SLOW-NEXT: vpor %xmm4, %xmm7, %xmm4 -+; AVX512F-SLOW-NEXT: vpshufb %xmm12, %xmm5, %xmm4 -+; AVX512F-SLOW-NEXT: vmovdqa (%r8), %xmm7 -+; AVX512F-SLOW-NEXT: vmovdqa %xmm7, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill -+; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm7, %xmm6 -+; AVX512F-SLOW-NEXT: vpor %xmm4, %xmm6, %xmm4 - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm4 # 32-byte Folded Reload --; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm7 = xmm8[8],xmm5[8],xmm8[9],xmm5[9],xmm8[10],xmm5[10],xmm8[11],xmm5[11],xmm8[12],xmm5[12],xmm8[13],xmm5[13],xmm8[14],xmm5[14],xmm8[15],xmm5[15] -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm23, %zmm0, %zmm4 -+; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm6 = xmm7[8],xmm5[8],xmm7[9],xmm5[9],xmm7[10],xmm5[10],xmm7[11],xmm5[11],xmm7[12],xmm5[12],xmm7[13],xmm5[13],xmm7[14],xmm5[14],xmm7[15],xmm5[15] - ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm5 = --; AVX512F-SLOW-NEXT: vpshufb %xmm5, %xmm7, %xmm7 --; AVX512F-SLOW-NEXT: vmovdqa64 %xmm5, %xmm29 --; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm4 = zmm7[0,1,0,1],zmm4[4,5,6,7] -+; AVX512F-SLOW-NEXT: vpshufb %xmm5, %xmm6, %xmm6 -+; AVX512F-SLOW-NEXT: vmovdqa64 %xmm5, %xmm27 -+; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm4 = zmm6[0,1,0,1],zmm4[4,5,6,7] - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-SLOW-NEXT: vmovdqa (%rax), %xmm13 --; AVX512F-SLOW-NEXT: vpshufhw {{.*#+}} xmm4 = xmm13[0,1,2,3,4,5,5,6] -+; AVX512F-SLOW-NEXT: vmovdqa (%rax), %xmm12 -+; AVX512F-SLOW-NEXT: vpshufhw {{.*#+}} xmm4 = xmm12[0,1,2,3,4,5,5,6] - ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} xmm4 = xmm4[2,2,3,3] - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm4[0,1,0,1] --; AVX512F-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm4, %ymm4 --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm12 = zero,ymm3[13],zero,zero,zero,zero,zero,zero,ymm3[14],zero,zero,zero,zero,zero,zero,ymm3[15],zero,zero,zero,zero,zero,zero,ymm3[16],zero,zero,zero,zero,zero,zero,ymm3[17],zero,zero --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm12, %zmm4, %zmm23 -+; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm11 = [255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255] -+; AVX512F-SLOW-NEXT: vpandn %ymm4, %ymm11, %ymm4 -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm11 = zero,ymm1[13],zero,zero,zero,zero,zero,zero,ymm1[14],zero,zero,zero,zero,zero,zero,ymm1[15],zero,zero,zero,zero,zero,zero,ymm1[16],zero,zero,zero,zero,zero,zero,ymm1[17],zero,zero -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm11, %zmm4, %zmm23 - ; 
AVX512F-SLOW-NEXT: vmovdqa64 %ymm18, %ymm4 - ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm4 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm4[30],zero,ymm4[28],zero,zero,zero,zero,ymm4[31],zero,ymm4[29],zero,zero - ; AVX512F-SLOW-NEXT: vmovdqa64 %ymm4, %ymm18 - ; AVX512F-SLOW-NEXT: vpbroadcastq {{.*#+}} ymm5 = [13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14] --; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm14 # 32-byte Reload -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm16, %ymm14 - ; AVX512F-SLOW-NEXT: vpshufb %ymm5, %ymm14, %ymm4 --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm5, %ymm28 -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm5, %ymm29 - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,3,2,3] --; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm5 # 32-byte Reload --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm12 = ymm5[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm5[27],zero,zero,zero,zero,ymm5[30],zero,ymm5[28],zero,zero,zero,zero,ymm5[31],zero,ymm5[29] --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] --; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm4, %ymm12 --; AVX512F-SLOW-NEXT: vmovdqa64 %xmm31, %xmm7 --; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm4 = xmm2[0],xmm7[0],xmm2[1],xmm7[1],xmm2[2],xmm7[2],xmm2[3],xmm7[3],xmm2[4],xmm7[4],xmm2[5],xmm7[5],xmm2[6],xmm7[6],xmm2[7],xmm7[7] -+; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm8 # 32-byte Reload -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm11 = ymm8[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm8[27],zero,zero,zero,zero,ymm8[30],zero,ymm8[28],zero,zero,zero,zero,ymm8[31],zero,ymm8[29] -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm11[2,3,2,3] -+; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm4, %ymm11 -+; AVX512F-SLOW-NEXT: vmovdqa64 %xmm26, %xmm7 -+; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm4 = xmm13[0],xmm7[0],xmm13[1],xmm7[1],xmm13[2],xmm7[2],xmm13[3],xmm7[3],xmm13[4],xmm7[4],xmm13[5],xmm7[5],xmm13[6],xmm7[6],xmm13[7],xmm7[7] - ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} xmm4 = xmm4[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] --; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm4 = zmm12[0,1,2,3],zmm4[0,1,0,1] -+; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm4 = zmm11[0,1,2,3],zmm4[0,1,0,1] - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm3 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,28,29,26,27,28,29,30,31,30,31,28,29,28,29,30,31] --; AVX512F-SLOW-NEXT: vmovdqu %ymm3, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-SLOW-NEXT: vpshufhw {{.*#+}} ymm3 = ymm0[0,1,2,3,6,7,7,6,8,9,10,11,14,15,15,14] -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,28,29,26,27,28,29,30,31,30,31,28,29,28,29,30,31] -+; AVX512F-SLOW-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-SLOW-NEXT: vpshufhw {{.*#+}} ymm1 = ymm0[0,1,2,3,6,7,7,6,8,9,10,11,14,15,15,14] - ; AVX512F-SLOW-NEXT: vmovdqa64 %ymm0, %ymm17 --; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm3 = ymm3[2,2,3,3,6,6,7,7] -+; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[2,2,3,3,6,6,7,7] - ; AVX512F-SLOW-NEXT: vpbroadcastq {{.*#+}} ymm4 = [9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10] --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm27, %ymm8 --; AVX512F-SLOW-NEXT: vpshufb %ymm4, %ymm8, %ymm12 --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm12, %zmm20 --; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Reload --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} 
ymm0 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm3[30],zero,ymm3[28],zero,zero,zero,zero,ymm3[31],zero,ymm3[29],zero,zero,zero -+; AVX512F-SLOW-NEXT: vmovdqa %ymm3, %ymm6 -+; AVX512F-SLOW-NEXT: vpshufb %ymm4, %ymm3, %ymm11 -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm1, %zmm11, %zmm26 -+; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero,zero - ; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-SLOW-NEXT: vpshufb %ymm4, %ymm3, %ymm3 -+; AVX512F-SLOW-NEXT: vpshufb %ymm4, %ymm1, %ymm1 - ; AVX512F-SLOW-NEXT: vpshuflw $233, {{[-0-9]+}}(%r{{[sb]}}p), %ymm4 # 32-byte Folded Reload - ; AVX512F-SLOW-NEXT: # ymm4 = mem[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] - ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm4 = ymm4[0,0,1,1,4,4,5,5] --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm4, %zmm0 -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm1, %zmm4, %zmm0 - ; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill - ; AVX512F-SLOW-NEXT: vmovdqa64 %xmm30, %xmm0 --; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm3 = xmm0[0],xmm15[0],xmm0[1],xmm15[1],xmm0[2],xmm15[2],xmm0[3],xmm15[3],xmm0[4],xmm15[4],xmm0[5],xmm15[5],xmm0[6],xmm15[6],xmm0[7],xmm15[7] --; AVX512F-SLOW-NEXT: vmovdqa64 %xmm3, %xmm16 --; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm3 = xmm15[8],xmm0[8],xmm15[9],xmm0[9],xmm15[10],xmm0[10],xmm15[11],xmm0[11],xmm15[12],xmm0[12],xmm15[13],xmm0[13],xmm15[14],xmm0[14],xmm15[15],xmm0[15] --; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm4 = xmm11[8],xmm6[8],xmm11[9],xmm6[9],xmm11[10],xmm6[10],xmm11[11],xmm6[11],xmm11[12],xmm6[12],xmm11[13],xmm6[13],xmm11[14],xmm6[14],xmm11[15],xmm6[15] --; AVX512F-SLOW-NEXT: vmovdqa %xmm11, %xmm12 -+; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm1 = xmm0[0],xmm15[0],xmm0[1],xmm15[1],xmm0[2],xmm15[2],xmm0[3],xmm15[3],xmm0[4],xmm15[4],xmm0[5],xmm15[5],xmm0[6],xmm15[6],xmm0[7],xmm15[7] -+; AVX512F-SLOW-NEXT: vmovdqa64 %xmm1, %xmm16 -+; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm15[8],xmm0[8],xmm15[9],xmm0[9],xmm15[10],xmm0[10],xmm15[11],xmm0[11],xmm15[12],xmm0[12],xmm15[13],xmm0[13],xmm15[14],xmm0[14],xmm15[15],xmm0[15] -+; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm4 = xmm2[8],xmm10[8],xmm2[9],xmm10[9],xmm2[10],xmm10[10],xmm2[11],xmm10[11],xmm2[12],xmm10[12],xmm2[13],xmm10[13],xmm2[14],xmm10[14],xmm2[15],xmm10[15] -+; AVX512F-SLOW-NEXT: vmovdqa %xmm2, %xmm11 - ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm15 = <6,3,2,u,u,u,9,8,5,4,u,u,u,11,10,7> - ; AVX512F-SLOW-NEXT: vpshufb %xmm15, %xmm4, %xmm0 - ; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-SLOW-NEXT: vpshufb %xmm15, %xmm3, %xmm3 -+; AVX512F-SLOW-NEXT: vpshufb %xmm15, %xmm1, %xmm1 - ; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Reload --; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm3, %zmm0, %zmm30 -+; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm1, %zmm0, %zmm30 - ; AVX512F-SLOW-NEXT: vmovdqa64 %xmm22, %xmm0 -+; AVX512F-SLOW-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload - ; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm15 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3],xmm0[4],xmm1[4],xmm0[5],xmm1[5],xmm0[6],xmm1[6],xmm0[7],xmm1[7] - ; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm1 = 
xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] --; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm3 = xmm9[8],xmm10[8],xmm9[9],xmm10[9],xmm9[10],xmm10[10],xmm9[11],xmm10[11],xmm9[12],xmm10[12],xmm9[13],xmm10[13],xmm9[14],xmm10[14],xmm9[15],xmm10[15] -+; AVX512F-SLOW-NEXT: vmovdqa64 %xmm20, %xmm5 -+; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm3 = xmm5[8],xmm9[8],xmm5[9],xmm9[9],xmm5[10],xmm9[10],xmm5[11],xmm9[11],xmm5[12],xmm9[12],xmm5[13],xmm9[13],xmm5[14],xmm9[14],xmm5[15],xmm9[15] - ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm4 = <2,u,u,u,9,8,5,4,u,u,u,11,10,7,6,u> - ; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm3, %xmm0 - ; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512F-SLOW-NEXT: vpshufb %xmm4, %xmm1, %xmm1 - ; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Reload - ; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm1, %zmm0, %zmm22 --; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm2[8],xmm7[8],xmm2[9],xmm7[9],xmm2[10],xmm7[10],xmm2[11],xmm7[11],xmm2[12],xmm7[12],xmm2[13],xmm7[13],xmm2[14],xmm7[14],xmm2[15],xmm7[15] --; AVX512F-SLOW-NEXT: vmovdqa64 %xmm29, %xmm1 -+; AVX512F-SLOW-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm13[8],xmm7[8],xmm13[9],xmm7[9],xmm13[10],xmm7[10],xmm13[11],xmm7[11],xmm13[12],xmm7[12],xmm13[13],xmm7[13],xmm13[14],xmm7[14],xmm13[15],xmm7[15] -+; AVX512F-SLOW-NEXT: vmovdqa64 %xmm27, %xmm1 - ; AVX512F-SLOW-NEXT: vpshufb %xmm1, %xmm0, %xmm0 - ; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm24 = zmm24[0,1,0,1],zmm0[0,1,0,1] - ; AVX512F-SLOW-NEXT: vpbroadcastq {{.*#+}} ymm2 = [11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12] - ; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload - ; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm0, %ymm4 --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm26, %ymm0 --; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm0, %ymm7 -+; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload -+; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm0, %ymm13 - ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[18],zero,zero,zero,zero,ymm0[21],zero,ymm0[19],zero,zero,zero,zero,ymm0[22],zero,ymm0[20] --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm0, %ymm29 -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm0, %ymm31 - ; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm11 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25],zero,ymm0[23],zero,ymm0[21,22,23,26],zero,ymm0[24],zero,ymm0[28,29,26,27] -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm7 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25],zero,ymm0[23],zero,ymm0[21,22,23,26],zero,ymm0[24],zero,ymm0[28,29,26,27] - ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm0 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,18],zero,ymm0[18,19,20,21],zero,ymm0[19],zero,ymm0[25,26,27,22],zero,ymm0[20],zero --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm0, %ymm26 --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm8 = ymm8[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm8[21],zero,ymm8[19],zero,zero,zero,zero,ymm8[22],zero,ymm8[20],zero,zero --; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm28, %ymm0 -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm0, %ymm20 -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm6 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm6[21],zero,ymm6[19],zero,zero,zero,zero,ymm6[22],zero,ymm6[20],zero,zero -+; 
AVX512F-SLOW-NEXT: vmovdqa64 %ymm28, %ymm1 -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm29, %ymm0 - ; AVX512F-SLOW-NEXT: vpshufb %ymm0, %ymm1, %ymm3 - ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm2 = ymm14[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm14[25],zero,ymm14[23],zero,zero,zero,zero,ymm14[26],zero,ymm14[24],zero,zero - ; AVX512F-SLOW-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -@@ -7647,40 +7648,40 @@ - ; AVX512F-SLOW-NEXT: # ymm2 = mem[0,1,0,1] - ; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload - ; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm0, %ymm1 --; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm5, %ymm2 --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm2, %ymm27 -+; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm8, %ymm2 -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm2, %ymm29 - ; AVX512F-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm2 = [20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128,20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128] - ; AVX512F-SLOW-NEXT: # ymm2 = mem[0,1,0,1] --; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm5, %ymm5 --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm5, %ymm28 -+; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm8, %ymm8 -+; AVX512F-SLOW-NEXT: vmovdqa64 %ymm8, %ymm27 - ; AVX512F-SLOW-NEXT: vpshufb %ymm2, %ymm0, %ymm0 --; AVX512F-SLOW-NEXT: vmovdqa64 %ymm0, %ymm31 -+; AVX512F-SLOW-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512F-SLOW-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload - ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} ymm2 = zero,ymm0[13],zero,zero,zero,zero,zero,zero,ymm0[14],zero,zero,zero,zero,zero,zero,ymm0[15],zero,zero,zero,zero,zero,zero,ymm0[16],zero,zero,zero,zero,zero,zero,ymm0[17],zero,zero - ; AVX512F-SLOW-NEXT: vpshuflw {{.*#+}} ymm14 = ymm0[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] - ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm14 = ymm14[0,1,1,3,4,5,5,7] - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,2,3,2] --; AVX512F-SLOW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm14, %ymm14 -+; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} ymm28 = [255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255] -+; AVX512F-SLOW-NEXT: vpandnq %ymm14, %ymm28, %ymm14 - ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm14, %zmm2, %zmm2 --; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm10 = xmm10[0],xmm9[0],xmm10[1],xmm9[1],xmm10[2],xmm9[2],xmm10[3],xmm9[3],xmm10[4],xmm9[4],xmm10[5],xmm9[5],xmm10[6],xmm9[6],xmm10[7],xmm9[7] -+; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm9 = xmm9[0],xmm5[0],xmm9[1],xmm5[1],xmm9[2],xmm5[2],xmm9[3],xmm5[3],xmm9[4],xmm5[4],xmm9[5],xmm5[5],xmm9[6],xmm5[6],xmm9[7],xmm5[7] - ; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm14 = <0,1,u,u,u,6,7,2,3,u,u,u,8,9,4,5> - ; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm15, %xmm15 --; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm10, %xmm10 --; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm21, %zmm10, %zmm0 --; AVX512F-SLOW-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm6 = xmm6[0],xmm12[0],xmm6[1],xmm12[1],xmm6[2],xmm12[2],xmm6[3],xmm12[3],xmm6[4],xmm12[4],xmm6[5],xmm12[5],xmm6[6],xmm12[6],xmm6[7],xmm12[7] --; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm14 = <4,5,0,1,u,u,u,6,7,2,3,u,u,u,8,9> -+; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm9, %xmm9 -+; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm21, %zmm9, %zmm14 -+; AVX512F-SLOW-NEXT: vpunpcklbw {{.*#+}} xmm9 = xmm10[0],xmm11[0],xmm10[1],xmm11[1],xmm10[2],xmm11[2],xmm10[3],xmm11[3],xmm10[4],xmm11[4],xmm10[5],xmm11[5],xmm10[6],xmm11[6],xmm10[7],xmm11[7] 
-+; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} xmm10 = <4,5,0,1,u,u,u,6,7,2,3,u,u,u,8,9> - ; AVX512F-SLOW-NEXT: vmovdqa64 %xmm16, %xmm0 --; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm0, %xmm10 --; AVX512F-SLOW-NEXT: vpshufb %xmm14, %xmm6, %xmm6 --; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm19, %zmm6, %zmm6 --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm12 = ymm4[2,3,2,3] --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm14 = ymm18[2,3,2,3] -+; AVX512F-SLOW-NEXT: vpshufb %xmm10, %xmm0, %xmm8 -+; AVX512F-SLOW-NEXT: vpshufb %xmm10, %xmm9, %xmm9 -+; AVX512F-SLOW-NEXT: vinserti32x4 $2, %xmm19, %zmm9, %zmm9 -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm10 = ymm4[2,3,2,3] -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm18[2,3,2,3] - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm19 = ymm3[2,3,2,3] - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm1[2,3,2,3] --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm7 = ymm7[2,3,2,3] --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm11[2,3,2,3] --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm8[2,3,2,3] -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,3,2,3] -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm7[2,3,2,3] -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm6[2,3,2,3] - ; AVX512F-SLOW-NEXT: vmovdqa64 %ymm17, %ymm1 - ; AVX512F-SLOW-NEXT: vpshuflw {{.*#+}} ymm1 = ymm1[1,2,2,3,4,5,6,7,9,10,10,11,12,13,14,15] - ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm21 = ymm1[0,0,1,1,4,4,5,5] -@@ -7690,28 +7691,29 @@ - ; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] - ; AVX512F-SLOW-NEXT: vshufi64x2 $0, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm1 # 64-byte Folded Reload - ; AVX512F-SLOW-NEXT: # zmm1 = zmm1[0,1,0,1],mem[0,1,0,1] --; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} xmm3 = xmm13[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] --; AVX512F-SLOW-NEXT: vpshuflw {{.*#+}} xmm8 = xmm13[1,1,0,0,4,5,6,7] --; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} xmm8 = xmm8[0,1,2,0] --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm8, %zmm3 --; AVX512F-SLOW-NEXT: vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm19, %ymm8 --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm0, %zmm0 --; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Folded Reload --; AVX512F-SLOW-NEXT: # zmm8 = mem[2,3,2,3,6,7,6,7] --; AVX512F-SLOW-NEXT: vporq %zmm8, %zmm0, %zmm0 --; AVX512F-SLOW-NEXT: vmovdqa {{.*#+}} ymm8 = [18374966859431673855,18446463693966278655,18374966859431673855,18446463693966278655] --; AVX512F-SLOW-NEXT: vpand %ymm7, %ymm8, %ymm7 -+; AVX512F-SLOW-NEXT: vpshufb {{.*#+}} xmm3 = xmm12[4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] -+; AVX512F-SLOW-NEXT: vpshuflw {{.*#+}} xmm6 = xmm12[1,1,0,0,4,5,6,7] -+; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} xmm6 = xmm6[0,1,2,0] -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm3, %zmm6, %zmm3 -+; AVX512F-SLOW-NEXT: vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm19, %ymm6 -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm6, %zmm0, %zmm0 -+; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm6 # 64-byte Folded Reload -+; AVX512F-SLOW-NEXT: # zmm6 = mem[2,3,2,3,6,7,6,7] -+; AVX512F-SLOW-NEXT: vporq %zmm6, %zmm0, %zmm0 -+; AVX512F-SLOW-NEXT: vbroadcasti128 {{.*#+}} ymm6 = [18374966859431673855,18446463693966278655,18374966859431673855,18446463693966278655] -+; AVX512F-SLOW-NEXT: # ymm6 = mem[0,1,0,1] -+; AVX512F-SLOW-NEXT: vpand %ymm6, %ymm13, %ymm7 - ; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm7, %zmm5, %zmm5 - ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Folded Reload - ; AVX512F-SLOW-NEXT: # zmm7 = mem[2,3,2,3,6,7,6,7] - ; AVX512F-SLOW-NEXT: vporq %zmm7, %zmm5, 
%zmm5 - ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Folded Reload - ; AVX512F-SLOW-NEXT: # zmm7 = mem[2,3,2,3,6,7,6,7] --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm9 = zmm20[2,3,2,3,6,7,6,7] --; AVX512F-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm9 --; AVX512F-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm9 -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm12 = zmm26[2,3,2,3,6,7,6,7] -+; AVX512F-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm12 -+; AVX512F-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm12 - ; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm5 = [255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255] --; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm5, %zmm9 -+; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm5, %zmm12 - ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Folded Reload - ; AVX512F-SLOW-NEXT: # zmm0 = mem[2,3,2,3,6,7,6,7] - ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Folded Reload -@@ -7726,80 +7728,80 @@ - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm0 = zmm30[0,1,0,1,4,5,4,5] - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm7 = zmm22[0,1,0,1,4,5,4,5] - ; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm0, %zmm5, %zmm7 --; AVX512F-SLOW-NEXT: vpternlogq $248, %ymm8, %ymm12, %ymm14 -+; AVX512F-SLOW-NEXT: vpternlogq $248, %ymm6, %ymm10, %ymm11 - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm0 = ymm21[2,3,2,3] --; AVX512F-SLOW-NEXT: vpternlogq $236, %ymm8, %ymm4, %ymm0 --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm10[0,1,0,1] --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm14, %zmm4 -+; AVX512F-SLOW-NEXT: vpternlogq $236, %ymm6, %ymm4, %ymm0 -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm8[0,1,0,1] -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm11, %zmm4 - ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm5 # 32-byte Folded Reload - ; AVX512F-SLOW-NEXT: # ymm5 = mem[2,3,2,3] --; AVX512F-SLOW-NEXT: vpshufhw $190, {{[-0-9]+}}(%r{{[sb]}}p), %ymm8 # 32-byte Folded Reload --; AVX512F-SLOW-NEXT: # ymm8 = mem[0,1,2,3,6,7,7,6,8,9,10,11,14,15,15,14] --; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm8 = ymm8[2,2,3,3,6,6,7,7] --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,3,2,3] --; AVX512F-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm5, %ymm8 -+; AVX512F-SLOW-NEXT: vpshufhw $190, {{[-0-9]+}}(%r{{[sb]}}p), %ymm6 # 32-byte Folded Reload -+; AVX512F-SLOW-NEXT: # ymm6 = mem[0,1,2,3,6,7,7,6,8,9,10,11,14,15,15,14] -+; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} ymm6 = ymm6[2,2,3,3,6,6,7,7] -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm6 = ymm6[2,3,2,3] -+; AVX512F-SLOW-NEXT: vpternlogq $236, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm5, %ymm6 - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm5 = ymm15[0,1,0,1] --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm8, %zmm5 --; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm8 = [255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255] --; AVX512F-SLOW-NEXT: vpternlogq $184, %zmm4, %zmm8, %zmm5 --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm29[2,3,2,3] --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm11 = ymm26[2,3,2,3] --; AVX512F-SLOW-NEXT: vpor %ymm4, %ymm11, 
%ymm4 --; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm10 # 64-byte Reload --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm10, %zmm4 --; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm10 # 64-byte Reload --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm10, %zmm0 --; AVX512F-SLOW-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm11 # 32-byte Folded Reload --; AVX512F-SLOW-NEXT: # ymm11 = mem[0,1,0,1] --; AVX512F-SLOW-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm12 # 32-byte Folded Reload --; AVX512F-SLOW-NEXT: # ymm12 = mem[0,1,0,1] --; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm14 # 32-byte Folded Reload --; AVX512F-SLOW-NEXT: # ymm14 = mem[2,3,2,3] --; AVX512F-SLOW-NEXT: vpshuflw $5, (%rsp), %xmm15 # 16-byte Folded Reload -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm5, %zmm6, %zmm5 -+; AVX512F-SLOW-NEXT: vmovdqa64 {{.*#+}} zmm6 = [255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255] -+; AVX512F-SLOW-NEXT: vpternlogq $184, %zmm4, %zmm6, %zmm5 -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm31[2,3,2,3] -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm8 = ymm20[2,3,2,3] -+; AVX512F-SLOW-NEXT: vpor %ymm4, %ymm8, %ymm4 -+; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm8, %zmm4 -+; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm0, %zmm8, %zmm0 -+; AVX512F-SLOW-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm8 # 32-byte Folded Reload -+; AVX512F-SLOW-NEXT: # ymm8 = mem[0,1,0,1] -+; AVX512F-SLOW-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm10 # 32-byte Folded Reload -+; AVX512F-SLOW-NEXT: # ymm10 = mem[0,1,0,1] -+; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm11 # 32-byte Folded Reload -+; AVX512F-SLOW-NEXT: # ymm11 = mem[2,3,2,3] -+; AVX512F-SLOW-NEXT: vpshuflw $5, {{[-0-9]+}}(%r{{[sb]}}p), %xmm15 # 16-byte Folded Reload - ; AVX512F-SLOW-NEXT: # xmm15 = mem[1,1,0,0,4,5,6,7] - ; AVX512F-SLOW-NEXT: vpshufd {{.*#+}} xmm15 = xmm15[0,1,2,0] - ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm17 # 32-byte Folded Reload - ; AVX512F-SLOW-NEXT: # ymm17 = mem[2,3,2,3] --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm18 = ymm27[2,3,2,3] -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm18 = ymm29[2,3,2,3] - ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm19 # 32-byte Folded Reload - ; AVX512F-SLOW-NEXT: # ymm19 = mem[2,3,2,3] --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm20 = ymm28[2,3,2,3] -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm20 = ymm27[2,3,2,3] - ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm21 # 32-byte Folded Reload - ; AVX512F-SLOW-NEXT: # ymm21 = mem[2,3,2,3] --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm22 = ymm31[2,3,2,3] --; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm4, %zmm8, %zmm0 --; AVX512F-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm11, %zmm4 # 32-byte Folded Reload --; AVX512F-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm12, %zmm8 # 32-byte Folded Reload --; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm8 -+; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %ymm22 # 32-byte Folded Reload -+; AVX512F-SLOW-NEXT: # ymm22 = mem[2,3,2,3] -+; AVX512F-SLOW-NEXT: vpternlogq $226, %zmm4, %zmm6, %zmm0 -+; AVX512F-SLOW-NEXT: 
vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm8, %zmm4 # 32-byte Folded Reload -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm6 # 32-byte Folded Reload -+; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm6 - ; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload - ; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm23 --; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm23 -+; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm23 - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} ymm4 = ymm15[0,0,1,0] --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm14, %zmm4 --; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload --; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm4 -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm4, %zmm11, %zmm4 -+; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm6 # 64-byte Reload -+; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm4 - ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm4 - ; AVX512F-SLOW-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Folded Reload - ; AVX512F-SLOW-NEXT: # zmm5 = mem[2,3,2,3,6,7,6,7] --; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm5 --; AVX512F-SLOW-NEXT: vporq %ymm17, %ymm18, %ymm8 --; AVX512F-SLOW-NEXT: vporq %ymm19, %ymm20, %ymm9 --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm8, %zmm0, %zmm8 --; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm8 = zmm9[0,1,2,3],zmm8[4,5,6,7] --; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm25 -+; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm12, %zmm5 -+; AVX512F-SLOW-NEXT: vporq %ymm17, %ymm18, %ymm6 -+; AVX512F-SLOW-NEXT: vporq %ymm19, %ymm20, %ymm8 -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm6, %zmm0, %zmm6 -+; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm6 = zmm8[0,1,2,3],zmm6[4,5,6,7] -+; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm25 - ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm13, %zmm25 --; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload --; AVX512F-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm24 -+; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm6 # 64-byte Reload -+; AVX512F-SLOW-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm24 - ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm24 --; AVX512F-SLOW-NEXT: vporq %ymm21, %ymm22, %ymm7 --; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm7, %zmm0, %zmm7 --; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Reload --; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm7 = zmm8[0,1,2,3],zmm7[4,5,6,7] --; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm2 -+; AVX512F-SLOW-NEXT: vporq %ymm21, %ymm22, %ymm6 -+; AVX512F-SLOW-NEXT: vinserti64x4 $1, %ymm6, %zmm0, %zmm6 -+; AVX512F-SLOW-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Reload -+; AVX512F-SLOW-NEXT: vshufi64x2 {{.*#+}} zmm6 = zmm7[0,1,2,3],zmm6[4,5,6,7] -+; AVX512F-SLOW-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm2 - ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm2 --; AVX512F-SLOW-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 
64-byte Folded Reload --; AVX512F-SLOW-NEXT: # zmm0 = mem[0,1,0,1,4,5,4,5] --; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm6 = zmm6[0,1,0,1,4,5,4,5] -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm0 = zmm14[0,1,0,1,4,5,4,5] -+; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm6 = zmm9[0,1,0,1,4,5,4,5] - ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm6 - ; AVX512F-SLOW-NEXT: vpermq {{.*#+}} zmm0 = zmm3[0,0,1,0,4,4,5,4] - ; AVX512F-SLOW-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm0 -@@ -7812,35 +7814,33 @@ - ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm5, 384(%rax) - ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm4, 192(%rax) - ; AVX512F-SLOW-NEXT: vmovdqa64 %zmm23, 64(%rax) --; AVX512F-SLOW-NEXT: addq $1464, %rsp # imm = 0x5B8 -+; AVX512F-SLOW-NEXT: addq $1416, %rsp # imm = 0x588 - ; AVX512F-SLOW-NEXT: vzeroupper - ; AVX512F-SLOW-NEXT: retq - ; - ; AVX512F-ONLY-FAST-LABEL: store_i8_stride7_vf64: - ; AVX512F-ONLY-FAST: # %bb.0: --; AVX512F-ONLY-FAST-NEXT: subq $1256, %rsp # imm = 0x4E8 -+; AVX512F-ONLY-FAST-NEXT: subq $1496, %rsp # imm = 0x5D8 - ; AVX512F-ONLY-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax --; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rcx), %ymm2 --; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdx), %ymm1 --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero --; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm1, %ymm14 --; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm2[25],zero,ymm2[23],zero,zero,zero,zero,ymm2[26],zero,ymm2[24],zero,zero,zero,zero --; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm2, %ymm13 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rcx), %ymm7 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdx), %ymm15 -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm15[30],zero,ymm15[28],zero,zero,zero,zero,ymm15[31],zero,ymm15[29],zero,zero -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm15, %ymm17 -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm7[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm7[25],zero,ymm7[23],zero,zero,zero,zero,ymm7[26],zero,ymm7[24],zero,zero,zero,zero - ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rsi), %ymm1 --; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdi), %ymm2 --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero,zero --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm16 --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm2[23],zero,zero,zero,zero,ymm2[26],zero,ymm2[24],zero,zero,zero,zero,ymm2[27],zero,ymm2[25] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm2, %ymm17 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rsi), %ymm15 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdi), %ymm3 -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm15[30],zero,ymm15[28],zero,zero,zero,zero,ymm15[31],zero,ymm15[29],zero,zero,zero -+; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm15, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = 
ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm3[23],zero,zero,zero,zero,ymm3[26],zero,ymm3[24],zero,zero,zero,zero,ymm3[27],zero,ymm3[25] -+; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm3, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill - ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r9), %ymm4 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r8), %ymm1 - ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm1[27],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm18 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm19 - ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm4[25],zero,ymm4[23],zero,zero,zero,zero,ymm4[26],zero,ymm4[24],zero,zero - ; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm4, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 -@@ -7853,431 +7853,446 @@ - ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill - ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rsi), %ymm1 - ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,ymm1[14],zero,zero,zero,zero,zero,zero,ymm1[15],zero,zero,zero,zero,zero,zero,ymm1[16],zero,zero,zero,zero,zero,zero,ymm1[17],zero,zero,zero,zero,zero,zero,ymm1[18] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm23 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm26 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdi), %ymm1 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[0,1,14],zero,ymm1[12,13,0,1,14,15],zero,ymm1[3,12,13,2,3,16],zero,ymm1[30,31,28,29,16,17],zero,ymm1[31,18,19,28,29,18],zero - ; AVX512F-ONLY-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rcx), %ymm1 --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm7 = [128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128,128] --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm7, %ymm1, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm26 --; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdx), %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm6 = [128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128,128] -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm6, %ymm1, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm30 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdx), %ymm10 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [0,1,0,1,14,128,14,15,0,1,14,15,128,13,14,15,16,17,16,128,30,31,30,31,16,17,128,31,28,29,30,31] --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm3 --; AVX512F-ONLY-FAST-NEXT: vpor %ymm0, %ymm3, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r8), %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm10, %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm2, %ymm25 -+; AVX512F-ONLY-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm5 = 
[128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128] --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm0, %ymm3 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm5, %ymm25 --; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r9), %ymm5 --; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm5, (%rsp) # 32-byte Spill --; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0,13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0] --; AVX512F-ONLY-FAST-NEXT: # ymm0 = mem[0,1,0,1] --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm5, %ymm5 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm0, %ymm30 --; AVX512F-ONLY-FAST-NEXT: vporq %ymm3, %ymm5, %ymm24 --; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdx), %xmm3 --; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rcx), %xmm6 --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm8 = --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm6, %xmm5 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm6, %xmm28 --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm9 = --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm3, %xmm6 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm9, %xmm19 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm3, %xmm29 --; AVX512F-ONLY-FAST-NEXT: vpor %xmm5, %xmm6, %xmm3 --; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdi), %xmm10 --; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rsi), %xmm6 --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm11 = --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm11, %xmm6, %xmm5 --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm3 = --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm3, %xmm10, %xmm9 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm10, %xmm27 --; AVX512F-ONLY-FAST-NEXT: vpor %xmm5, %xmm9, %xmm5 --; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm5, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r9), %xmm15 --; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r8), %xmm10 --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm5 = <128,4,u,u,u,128,7,128,5,u,u,u,128,8,128,6> --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm15, %xmm9 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r8), %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,zero,zero,zero,ymm1[14],zero,zero,zero,zero,zero,zero,ymm1[15],zero,zero,zero,zero,zero,zero,ymm1[16],zero,zero,zero,zero,zero,zero,ymm1[17],zero,zero,zero,zero -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm22 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r9), %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0,13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0] -+; AVX512F-ONLY-FAST-NEXT: # ymm5 = mem[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm2, %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm5, %ymm29 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm2, %ymm31 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vporq %ymm0, %ymm1, %ymm23 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdx), %xmm5 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rcx), %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm2 = -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm2, %xmm1, %xmm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm2, %xmm18 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm1, %xmm20 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm2 = -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm2, %xmm5, %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm5, %xmm21 -+; 
AVX512F-ONLY-FAST-NEXT: vpor %xmm0, %xmm1, %xmm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rdi), %xmm11 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rsi), %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm14 = -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm14, %xmm1, %xmm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm1, %xmm28 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm1 = -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm11, %xmm5 -+; AVX512F-ONLY-FAST-NEXT: vpor %xmm0, %xmm5, %xmm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r9), %xmm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%r8), %xmm9 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm8 = <128,4,u,u,u,128,7,128,5,u,u,u,128,8,128,6> -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm0, %xmm12 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm0, %xmm16 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm0 = <4,128,u,u,u,7,128,5,128,u,u,u,8,128,6,128> --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm0, %xmm10, %xmm12 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm0, %xmm21 --; AVX512F-ONLY-FAST-NEXT: vporq %xmm9, %xmm12, %xmm22 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm7, %ymm13, %ymm7 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm13, %ymm20 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm14, %ymm2 --; AVX512F-ONLY-FAST-NEXT: vpor %ymm7, %ymm2, %ymm2 --; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm16, %ymm7 --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = zero,zero,zero,ymm7[14],zero,zero,zero,zero,zero,zero,ymm7[15],zero,zero,zero,zero,zero,zero,ymm7[16],zero,zero,zero,zero,zero,zero,ymm7[17],zero,zero,zero,zero,zero,zero,ymm7[18] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm17, %ymm7 --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[0,1,14],zero,ymm7[12,13,0,1,14,15],zero,ymm7[3,12,13,2,3,16],zero,ymm7[30,31,28,29,16,17],zero,ymm7[31,18,19,28,29,18],zero --; AVX512F-ONLY-FAST-NEXT: vpor %ymm2, %ymm7, %ymm2 --; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm18, %ymm2 --; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %ymm18, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm0, %xmm9, %xmm13 -+; AVX512F-ONLY-FAST-NEXT: vpor %xmm12, %xmm13, %xmm12 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm12, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm6, %ymm7, %ymm6 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm7, %ymm24 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm17, %ymm13 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm25, %ymm7 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm7, %ymm2, %ymm2 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm4, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vpor %ymm2, %ymm0, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rsi), %xmm13 --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm11, %xmm13, %xmm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdi), %xmm9 --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm3, %xmm9, %xmm2 --; AVX512F-ONLY-FAST-NEXT: vporq %xmm0, %xmm2, %xmm31 --; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rcx), %xmm14 --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm14, %xmm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdx), %xmm8 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm19, %xmm2 --; AVX512F-ONLY-FAST-NEXT: vpshufb 
%xmm2, %xmm8, %xmm2 --; AVX512F-ONLY-FAST-NEXT: vpor %xmm0, %xmm2, %xmm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm7, %ymm13, %ymm7 -+; AVX512F-ONLY-FAST-NEXT: vpor %ymm6, %ymm7, %ymm6 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm6, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm6 = zero,zero,zero,ymm15[14],zero,zero,zero,zero,zero,zero,ymm15[15],zero,zero,zero,zero,zero,zero,ymm15[16],zero,zero,zero,zero,zero,zero,ymm15[17],zero,zero,zero,zero,zero,zero,ymm15[18] -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm3[0,1,14],zero,ymm3[12,13,0,1,14,15],zero,ymm3[3,12,13,2,3,16],zero,ymm3[30,31,28,29,16,17],zero,ymm3[31,18,19,28,29,18],zero -+; AVX512F-ONLY-FAST-NEXT: vpor %ymm6, %ymm7, %ymm3 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm19, %ymm3 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %ymm19, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm3 = zero,zero,zero,zero,zero,zero,ymm3[14],zero,zero,zero,zero,zero,zero,ymm3[15],zero,zero,zero,zero,zero,zero,ymm3[16],zero,zero,zero,zero,zero,zero,ymm3[17],zero,zero,zero,zero -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm29, %ymm6 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm6, %ymm4, %ymm6 -+; AVX512F-ONLY-FAST-NEXT: vpor %ymm3, %ymm6, %ymm3 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rsi), %xmm4 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm14, %xmm4, %xmm3 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm4, %xmm17 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdi), %xmm7 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm7, %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vpor %xmm3, %xmm1, %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rcx), %xmm3 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm18, %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm3, %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm3, %xmm12 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rdx), %xmm5 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm2, %xmm5, %xmm2 -+; AVX512F-ONLY-FAST-NEXT: vpor %xmm1, %xmm2, %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill - ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r9), %xmm2 --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm2, %xmm0 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm2, %xmm1 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm2, %xmm3 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill - ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%r8), %xmm4 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm21, %xmm2 --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm2, %xmm4, %xmm2 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm0, %xmm4, %xmm2 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill --; AVX512F-ONLY-FAST-NEXT: vpor %xmm0, %xmm2, %xmm0 -+; AVX512F-ONLY-FAST-NEXT: vpor %xmm1, %xmm2, %xmm0 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill - ; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[23],zero,zero,zero,zero,ymm0[26],zero,ymm0[24],zero,zero,zero,zero,ymm0[27],zero,ymm0[25] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm23, %ymm12 --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = 
ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm12[21],zero,ymm12[19],zero,zero,zero,zero,ymm12[22],zero,ymm12[20],zero,zero --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[23],zero,zero,zero,zero,ymm0[26],zero,ymm0[24],zero,zero,zero,zero,ymm0[27],zero,ymm0[25] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm26, %ymm6 -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm6[21],zero,ymm6[19],zero,zero,zero,zero,ymm6[22],zero,ymm6[20],zero,zero -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm0 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm26, %ymm11 --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm11[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm11[25],zero,ymm11[23],zero,zero,zero,zero,ymm11[26],zero,ymm11[24],zero,zero,zero,zero --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm11[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm11[18],zero,zero,zero,zero,ymm11[21],zero,ymm11[19],zero,zero,zero,zero,ymm11[22],zero,ymm11[20] --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm15 -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm15[25],zero,ymm15[23],zero,zero,zero,zero,ymm15[26],zero,ymm15[24],zero,zero,zero,zero -+; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20] -+; AVX512F-ONLY-FAST-NEXT: # ymm0 = mem[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm15, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm0, %ymm25 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm8 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm0 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm2 = [24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27,24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27] --; AVX512F-ONLY-FAST-NEXT: # ymm2 = mem[0,1,0,1] --; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128,18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128] --; AVX512F-ONLY-FAST-NEXT: # ymm5 = mem[0,1,0,1] --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm2, %ymm19 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm1, %ymm2 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm5, %ymm30 --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 -+; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm14 = [24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27,24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27] -+; AVX512F-ONLY-FAST-NEXT: # ymm14 = mem[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128,18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128] -+; AVX512F-ONLY-FAST-NEXT: # ymm0 = mem[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm14, %ymm10, %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm10, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm0, %ymm29 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm0 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; 
AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm24, %zmm0, %zmm0 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm23, %zmm0, %zmm1 - ; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm2 = xmm4[8],xmm3[8],xmm4[9],xmm3[9],xmm4[10],xmm3[10],xmm4[11],xmm3[11],xmm4[12],xmm3[12],xmm4[13],xmm3[13],xmm4[14],xmm3[14],xmm4[15],xmm3[15] --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm3 = --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm3, %xmm2, %xmm2 --; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm2[0,1,0,1],zmm0[4,5,6,7] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm4 = -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm4, %xmm2, %xmm2 -+; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm2[0,1,0,1],zmm1[4,5,6,7] - ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm2 = [2,2,3,3,2,2,3,3] --; AVX512F-ONLY-FAST-NEXT: # ymm2 = mem[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vbroadcasti32x4 {{.*#+}} ymm23 = [2,2,3,3,2,2,3,3] -+; AVX512F-ONLY-FAST-NEXT: # ymm23 = mem[0,1,2,3,0,1,2,3] - ; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rax), %xmm0 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill --; AVX512F-ONLY-FAST-NEXT: vpshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,5,5,6] --; AVX512F-ONLY-FAST-NEXT: vpermd %ymm0, %ymm2, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rax), %ymm4 --; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm4, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm5 = [128,13,128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128] --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm5, %ymm4, %ymm4 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm5, %ymm18 --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm4, %zmm0, %zmm24 --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm0, %ymm23 -+; AVX512F-ONLY-FAST-NEXT: vpshufhw {{.*#+}} xmm2 = xmm0[0,1,2,3,4,5,5,6] -+; AVX512F-ONLY-FAST-NEXT: vpermd %ymm2, %ymm23, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm3 = [255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255] -+; AVX512F-ONLY-FAST-NEXT: vpandn %ymm2, %ymm3, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa (%rax), %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = [128,13,128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128] -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm3 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm19 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm2, %zmm18 -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm10[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm10[30],zero,ymm10[28],zero,zero,zero,zero,ymm10[31],zero,ymm10[29],zero,zero -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm0, %ymm26 - ; AVX512F-ONLY-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14] --; AVX512F-ONLY-FAST-NEXT: vmovdqu (%rsp), %ymm0 # 32-byte Reload --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm25 --; 
AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm0[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm1[27],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vbroadcasti32x4 {{.*#+}} ymm26 = [18374967954648269055,71777218572844800,18374967954648269055,71777218572844800] --; AVX512F-ONLY-FAST-NEXT: # ymm26 = mem[0,1,2,3,0,1,2,3] --; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, %ymm26, %ymm0, %ymm1 --; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm0 = xmm10[0],xmm15[0],xmm10[1],xmm15[1],xmm10[2],xmm15[2],xmm10[3],xmm15[3],xmm10[4],xmm15[4],xmm10[5],xmm15[5],xmm10[6],xmm15[6],xmm10[7],xmm15[7] --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] --; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm21 = zmm1[0,1,2,3],zmm0[0,1,0,1] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm29, %xmm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm28, %xmm1 --; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm7 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3],xmm0[4],xmm1[4],xmm0[5],xmm1[5],xmm0[6],xmm1[6],xmm0[7],xmm1[7] --; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] --; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm14[8],xmm8[8],xmm14[9],xmm8[9],xmm14[10],xmm8[10],xmm14[11],xmm8[11],xmm14[12],xmm8[12],xmm14[13],xmm8[13],xmm14[14],xmm8[14],xmm14[15],xmm8[15] --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm4 = <6,3,2,u,u,u,9,8,5,4,u,u,u,11,10,7> --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm4, %xmm1, %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm31, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm30 -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm22, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm3 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[27],zero,zero,zero,zero,ymm0[30],zero,ymm0[28],zero,zero,zero,zero,ymm0[31],zero,ymm0[29] -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm3 = ymm3[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vbroadcasti32x4 {{.*#+}} ymm31 = [18374967954648269055,71777218572844800,18374967954648269055,71777218572844800] -+; AVX512F-ONLY-FAST-NEXT: # ymm31 = mem[0,1,2,3,0,1,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, %ymm31, %ymm2, %ymm3 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm16, %xmm10 -+; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm2 = xmm9[0],xmm10[0],xmm9[1],xmm10[1],xmm9[2],xmm10[2],xmm9[3],xmm10[3],xmm9[4],xmm10[4],xmm9[5],xmm10[5],xmm9[6],xmm10[6],xmm9[7],xmm10[7] -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm2 = xmm2[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] -+; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm16 = zmm3[0,1,2,3],zmm2[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm21, %xmm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm20, %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm2 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3],xmm0[4],xmm1[4],xmm0[5],xmm1[5],xmm0[6],xmm1[6],xmm0[7],xmm1[7] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm2, %xmm22 -+; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm2 = 
xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] -+; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm3 = xmm12[8],xmm5[8],xmm12[9],xmm5[9],xmm12[10],xmm5[10],xmm12[11],xmm5[11],xmm12[12],xmm5[12],xmm12[13],xmm5[13],xmm12[14],xmm5[14],xmm12[15],xmm5[15] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm5, %xmm20 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm12, %xmm21 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm0 = <6,3,2,u,u,u,9,8,5,4,u,u,u,11,10,7> -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm0, %xmm3, %xmm1 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm0, %xmm2, %xmm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm1 # 64-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm28, %xmm0 -+; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm5 = xmm11[0],xmm0[0],xmm11[1],xmm0[1],xmm11[2],xmm0[2],xmm11[3],xmm0[3],xmm11[4],xmm0[4],xmm11[5],xmm0[5],xmm11[6],xmm0[6],xmm11[7],xmm0[7] -+; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm0[8],xmm11[8],xmm0[9],xmm11[9],xmm0[10],xmm11[10],xmm0[11],xmm11[11],xmm0[12],xmm11[12],xmm0[13],xmm11[13],xmm0[14],xmm11[14],xmm0[15],xmm11[15] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm17, %xmm3 -+; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm2 = xmm3[8],xmm7[8],xmm3[9],xmm7[9],xmm3[10],xmm7[10],xmm3[11],xmm7[11],xmm3[12],xmm7[12],xmm3[13],xmm7[13],xmm3[14],xmm7[14],xmm3[15],xmm7[15] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm1 = <2,u,u,u,9,8,5,4,u,u,u,11,10,7,6,u> -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm2, %xmm2 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm0, %xmm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm1 # 64-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm9[8],xmm10[8],xmm9[9],xmm10[9],xmm9[10],xmm10[10],xmm9[11],xmm10[11],xmm9[12],xmm10[12],xmm9[13],xmm10[13],xmm9[14],xmm10[14],xmm9[15],xmm10[15] -+; AVX512F-ONLY-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12] -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm8, %ymm12 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm24, %ymm11 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm11, %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm27 = ymm1[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm14, %ymm13, %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm28 = ymm1[2,3,2,3] - ; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm4, %xmm0, %xmm0 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm1 # 64-byte Reload --; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm28 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm27, %xmm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm6, %xmm1 --; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm6 = xmm0[0],xmm6[0],xmm0[1],xmm6[1],xmm0[2],xmm6[2],xmm0[3],xmm6[3],xmm0[4],xmm6[4],xmm0[5],xmm6[5],xmm0[6],xmm6[6],xmm0[7],xmm6[7] --; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm1 = 
xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] --; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm4 = xmm13[8],xmm9[8],xmm13[9],xmm9[9],xmm13[10],xmm9[10],xmm13[11],xmm9[11],xmm13[12],xmm9[12],xmm13[13],xmm9[13],xmm13[14],xmm9[14],xmm13[15],xmm9[15] --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm5 = <2,u,u,u,9,8,5,4,u,u,u,11,10,7,6,u> --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm4, %xmm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm1, %xmm1 --; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Reload --; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, %xmm1, %zmm0, %zmm27 --; AVX512F-ONLY-FAST-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm10[8],xmm15[8],xmm10[9],xmm15[9],xmm10[10],xmm15[10],xmm10[11],xmm15[11],xmm10[12],xmm15[12],xmm10[13],xmm15[13],xmm10[14],xmm15[14],xmm10[15],xmm15[15] --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm3, %xmm1, %xmm1 --; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm22[0,1,0,1],zmm1[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm1[0,1,0,1],zmm0[0,1,0,1] - ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill - ; AVX512F-ONLY-FAST-NEXT: vmovdqa 32(%rax), %xmm0 --; AVX512F-ONLY-FAST-NEXT: vpshufhw {{.*#+}} xmm1 = xmm0[0,1,2,3,4,5,5,6] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm0, %xmm29 --; AVX512F-ONLY-FAST-NEXT: vpermd %ymm1, %ymm2, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vpshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,5,5,6] -+; AVX512F-ONLY-FAST-NEXT: vpermd %ymm0, %ymm23, %ymm0 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512F-ONLY-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12] --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm11, %ymm5 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm20, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm1 --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm22 = ymm1[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Reload --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm19, %ymm1 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm2, %ymm1 --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm20 = ymm1[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm10 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[18],zero,zero,zero,zero,ymm0[21],zero,ymm0[19],zero,zero,zero,zero,ymm0[22],zero,ymm0[20] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm2, %ymm4 --; AVX512F-ONLY-FAST-NEXT: vpbroadcastq {{.*#+}} ymm2 = [9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm16, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm0, %ymm1 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm2, %ymm30 --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm19 = ymm1[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm11 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm0[21],zero,ymm0[19],zero,zero,zero,zero,ymm0[22],zero,ymm0[20],zero,zero -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm25, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm11, %ymm10 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm15 # 32-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm0 
= ymm15[9,u,7,u,u,u,u,10,u,8,u,u,u,u,11,u,25,u,23,u,u,u,u,26,u,24,u,u,u,u,27,u] -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm24 = ymm0[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm29, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm13, %ymm8 -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm11 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm15[21],zero,ymm15[19],zero,zero,zero,zero,ymm15[22],zero,ymm15[20],zero,zero - ; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm18, %ymm1 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm3 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm19, %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm4 - ; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} ymm0 = ymm0[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] - ; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = [4,5,4,5,5,7,4,5] - ; AVX512F-ONLY-FAST-NEXT: vpermd %ymm0, %ymm1, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm3, %zmm16 --; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm0 = xmm9[0],xmm13[0],xmm9[1],xmm13[1],xmm9[2],xmm13[2],xmm9[3],xmm13[3],xmm9[4],xmm13[4],xmm9[5],xmm13[5],xmm9[6],xmm13[6],xmm9[7],xmm13[7] --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm9 = <0,1,u,u,u,6,7,2,3,u,u,u,8,9,4,5> --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm6, %xmm6 --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm9, %xmm0, %xmm0 --; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm9 = [28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29,28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29] --; AVX512F-ONLY-FAST-NEXT: # ymm9 = mem[0,1,0,1] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm17, %ymm3 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm9, %ymm3, %ymm15 --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm17 = ymm15[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} ymm25 = [255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255] -+; AVX512F-ONLY-FAST-NEXT: vpandnq %ymm0, %ymm25, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm4, %zmm23 -+; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29,28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29] -+; AVX512F-ONLY-FAST-NEXT: # ymm0 = mem[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm15 # 32-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm15, %ymm9 -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm17 = ymm9[2,3,2,3] - ; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm25, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm2 - ; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm13 --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm18 = ymm13[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, %xmm31, %zmm0, %zmm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm0 = xmm8[0],xmm14[0],xmm8[1],xmm14[1],xmm8[2],xmm14[2],xmm8[3],xmm14[3],xmm8[4],xmm14[4],xmm8[5],xmm14[5],xmm8[6],xmm14[6],xmm8[7],xmm14[7] --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm8 = <4,5,0,1,u,u,u,6,7,2,3,u,u,u,8,9> --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm7, %xmm7 --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm8, %xmm0, %xmm13 --; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte 
Reload --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm8 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm0[23],zero,ymm0[23,24,25,26],zero,ymm0[24],zero,ymm0[30,31] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm0, %ymm25 --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm15 = ymm8[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm12, %ymm1 --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm12[30],zero,ymm12[28],zero,zero,zero,zero,ymm12[31],zero,ymm12[29],zero,zero,zero --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Reload --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm9, %ymm2, %ymm9 --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,1,0,1] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm5 = ymm5[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm23[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm7[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm29 -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm25 = ymm13[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm7 = xmm7[0],xmm3[0],xmm7[1],xmm3[1],xmm7[2],xmm3[2],xmm7[3],xmm3[3],xmm7[4],xmm3[4],xmm7[5],xmm3[5],xmm7[6],xmm3[6],xmm7[7],xmm3[7] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm13 = <0,1,u,u,u,6,7,2,3,u,u,u,8,9,4,5> -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm13, %xmm5, %xmm5 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm13, %xmm7, %xmm7 -+; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, {{[-0-9]+}}(%r{{[sb]}}p), %zmm7, %zmm1 # 16-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm1[23],zero,ymm1[23,24,25,26],zero,ymm1[24],zero,ymm1[30,31] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm1, %ymm30 -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm19 = ymm7[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa %ymm6, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm6[30],zero,ymm6[28],zero,zero,zero,zero,ymm6[31],zero,ymm6[29],zero,zero,zero -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm7[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm3, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm0[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm12[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm26[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm20, %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm21, %xmm4 -+; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{.*#+}} xmm12 = xmm1[0],xmm4[0],xmm1[1],xmm4[1],xmm1[2],xmm4[2],xmm1[3],xmm4[3],xmm1[4],xmm4[4],xmm1[5],xmm4[5],xmm1[6],xmm4[6],xmm1[7],xmm4[7] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm1 = <4,5,0,1,u,u,u,6,7,2,3,u,u,u,8,9> -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm22, %xmm4 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm4, %xmm13 -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm13[0,1,0,1] - ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: 
vpermq {{.*#+}} ymm8 = ymm11[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, {{[-0-9]+}}(%r{{[sb]}}p), %zmm13, %zmm31 # 16-byte Folded Reload --; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm11 = [18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23,18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23] --; AVX512F-ONLY-FAST-NEXT: # ymm11 = mem[0,1,0,1] --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm11, %ymm3, %ymm14 --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm3 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm3, %ymm1, %ymm13 --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm11, %ymm2, %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm11[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm1, %xmm12, %xmm1 -+; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm12 = [18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23,18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23] -+; AVX512F-ONLY-FAST-NEXT: # ymm12 = mem[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm12, %ymm15, %ymm9 -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vinserti32x4 $2, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm11 # 16-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[9,u,7,u,u,u,u,10,u,8,u,u,u,u,11,u,25,u,23,u,u,u,u,26,u,24,u,u,u,u,27,u] - ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vpor %ymm12, %ymm9, %ymm9 --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm6, %zmm9, %zmm6 --; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm9 = [18374966859431673855,18446463693966278655,18374966859431673855,18446463693966278655] --; AVX512F-ONLY-FAST-NEXT: # ymm9 = mem[0,1,0,1] --; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, %ymm9, %ymm5, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm7, %zmm0, %zmm7 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm12, %ymm3, %ymm12 -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpor %ymm6, %ymm14, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm5, %zmm2, %zmm2 -+; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [18374966859431673855,18446463693966278655,18374966859431673855,18446463693966278655] -+; AVX512F-ONLY-FAST-NEXT: # ymm5 = mem[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, %ymm5, %ymm7, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm13, %zmm0, %zmm7 - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm0 = [255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255] --; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm6, %zmm0, %zmm7 --; AVX512F-ONLY-FAST-NEXT: vpor %ymm4, %ymm10, %ymm4 --; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Reload --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm4, %zmm2, %zmm4 --; AVX512F-ONLY-FAST-NEXT: vpor %ymm8, %ymm14, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm2, %zmm0, %zmm7 -+; AVX512F-ONLY-FAST-NEXT: vpor %ymm10, %ymm8, %ymm2 - ; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Reload --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm5 --; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm4, %zmm0, %zmm5 --; AVX512F-ONLY-FAST-NEXT: vpandq %ymm9, %ymm22, 
%ymm0 --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm20, %zmm0 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 -+; AVX512F-ONLY-FAST-NEXT: vpor %ymm4, %ymm9, %ymm3 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm4, %zmm6 -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm2, %zmm0, %zmm6 -+; AVX512F-ONLY-FAST-NEXT: vpandq %ymm5, %ymm27, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm28, %zmm0 - ; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload - ; AVX512F-ONLY-FAST-NEXT: # zmm2 = mem[2,3,2,3,6,7,6,7] - ; AVX512F-ONLY-FAST-NEXT: vporq %zmm2, %zmm0, %zmm0 --; AVX512F-ONLY-FAST-NEXT: vpandq %ymm26, %ymm19, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vpandq %ymm31, %ymm24, %ymm2 - ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm17, %zmm2, %zmm2 --; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Folded Reload --; AVX512F-ONLY-FAST-NEXT: # zmm4 = mem[2,3,2,3,6,7,6,7] --; AVX512F-ONLY-FAST-NEXT: vporq %zmm4, %zmm2, %zmm2 -+; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: # zmm3 = mem[2,3,2,3,6,7,6,7] -+; AVX512F-ONLY-FAST-NEXT: vporq %zmm3, %zmm2, %zmm2 - ; AVX512F-ONLY-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm2 --; AVX512F-ONLY-FAST-NEXT: vpandq %ymm26, %ymm18, %ymm0 --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm15, %zmm0 --; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Folded Reload --; AVX512F-ONLY-FAST-NEXT: # zmm4 = mem[2,3,2,3,6,7,6,7] --; AVX512F-ONLY-FAST-NEXT: vporq %zmm4, %zmm0, %zmm0 --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm4 = [255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255] --; AVX512F-ONLY-FAST-NEXT: vpternlogq $184, %zmm2, %zmm4, %zmm0 --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} zmm2 = zmm28[0,1,0,1,4,5,4,5] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} zmm8 = zmm27[0,1,0,1,4,5,4,5] --; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm2, %zmm4, %zmm8 --; AVX512F-ONLY-FAST-NEXT: vpandq %ymm26, %ymm13, %ymm2 --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm1, %zmm1 -+; AVX512F-ONLY-FAST-NEXT: vpandq %ymm31, %ymm25, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm19, %zmm0 -+; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: # zmm3 = mem[2,3,2,3,6,7,6,7] -+; AVX512F-ONLY-FAST-NEXT: vporq %zmm3, %zmm0, %zmm3 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 {{.*#+}} zmm0 = [255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255] -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $184, %zmm2, %zmm0, %zmm3 -+; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: # zmm2 = mem[0,1,0,1,4,5,4,5] -+; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: # zmm8 = mem[0,1,0,1,4,5,4,5] -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm2, %zmm0, %zmm8 -+; AVX512F-ONLY-FAST-NEXT: vpandq %ymm31, %ymm1, %ymm1 -+; 
AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm12, %zmm1 - ; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload - ; AVX512F-ONLY-FAST-NEXT: # zmm2 = mem[2,3,2,3,6,7,6,7] - ; AVX512F-ONLY-FAST-NEXT: vporq %zmm2, %zmm1, %zmm1 - ; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload - ; AVX512F-ONLY-FAST-NEXT: # zmm2 = mem[2,3,2,3,6,7,6,7] --; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm6 # 64-byte Folded Reload --; AVX512F-ONLY-FAST-NEXT: # zmm6 = mem[2,3,2,3,6,7,6,7] --; AVX512F-ONLY-FAST-NEXT: vporq %zmm2, %zmm6, %zmm9 --; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm1, %zmm4, %zmm9 --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload --; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{[-0-9]+}}(%r{{[sb]}}p), %xmm1, %xmm1 # 16-byte Folded Reload --; AVX512F-ONLY-FAST-NEXT: # xmm1 = xmm1[0],mem[0],xmm1[1],mem[1],xmm1[2],mem[2],xmm1[3],mem[3],xmm1[4],mem[4],xmm1[5],mem[5],xmm1[6],mem[6],xmm1[7],mem[7] --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] --; AVX512F-ONLY-FAST-NEXT: vshufi64x2 $0, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm18 # 64-byte Folded Reload --; AVX512F-ONLY-FAST-NEXT: # zmm18 = zmm1[0,1,0,1],mem[0,1,0,1] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %xmm29, %xmm3 --; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} xmm2 = xmm3[1,1,0,0,4,5,6,7] --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm4 = [0,1,0,1,2,0,0,1] --; AVX512F-ONLY-FAST-NEXT: vpermd %ymm2, %ymm4, %ymm19 --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload --; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} xmm6 = xmm1[1,1,0,0,4,5,6,7] --; AVX512F-ONLY-FAST-NEXT: vpermd %ymm6, %ymm4, %ymm17 --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm6 = [4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm6, %xmm3, %xmm10 --; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm6, %xmm1, %xmm6 --; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm11 = [128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22] --; AVX512F-ONLY-FAST-NEXT: # ymm11 = mem[0,1,0,1] --; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm11, %ymm1, %ymm12 --; AVX512F-ONLY-FAST-NEXT: vmovdqu (%rsp), %ymm1 # 32-byte Reload --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm13 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm1[25],zero,ymm1[23],zero,zero,zero,zero,ymm1[26],zero,ymm1[24],zero,zero --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm11, %ymm1, %ymm11 --; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm14 = [20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128,20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128] --; AVX512F-ONLY-FAST-NEXT: # ymm14 = mem[0,1,0,1] --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm25, %ymm1 --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm14, %ymm1, %ymm15 --; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Reload --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm2[23],zero,ymm2[23,24,25,26],zero,ymm2[24],zero,ymm2[30,31] --; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm14, %ymm2, %ymm14 --; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Reload --; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} ymm4 = ymm3[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] --; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [4,5,4,5,5,7,4,5] 
--; AVX512F-ONLY-FAST-NEXT: vpermd %ymm4, %ymm2, %ymm20 --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,28,29,26,27,28,29,30,31,30,31,28,29,28,29,30,31] --; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm4 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,24,25,22,23,24,25,26,27,26,27,24,25] --; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm22 # 64-byte Folded Reload --; AVX512F-ONLY-FAST-NEXT: # zmm22 = mem[2,3,2,3,6,7,6,7] --; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm22 --; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: # zmm5 = mem[2,3,2,3,6,7,6,7] -+; AVX512F-ONLY-FAST-NEXT: vporq %zmm2, %zmm5, %zmm22 -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $226, %zmm1, %zmm0, %zmm22 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vpunpcklbw {{[-0-9]+}}(%r{{[sb]}}p), %xmm0, %xmm0 # 16-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: # xmm0 = xmm0[0],mem[0],xmm0[1],mem[1],xmm0[2],mem[2],xmm0[3],mem[3],xmm0[4],mem[4],xmm0[5],mem[5],xmm0[6],mem[6],xmm0[7],mem[7] -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] -+; AVX512F-ONLY-FAST-NEXT: vshufi64x2 $0, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm26 # 64-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: # zmm26 = zmm0[0,1,0,1],mem[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm4 # 16-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} xmm1 = xmm4[1,1,0,0,4,5,6,7] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [0,1,0,1,2,0,0,1] -+; AVX512F-ONLY-FAST-NEXT: vpermd %ymm1, %ymm2, %ymm19 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} xmm5 = xmm0[1,1,0,0,4,5,6,7] -+; AVX512F-ONLY-FAST-NEXT: vpermd %ymm5, %ymm2, %ymm17 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} xmm5 = [4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm4, %xmm10 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %xmm5, %xmm0, %xmm5 -+; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm12 = [128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22] -+; AVX512F-ONLY-FAST-NEXT: # ymm12 = mem[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm29, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm12, %ymm0, %ymm13 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm14 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm0[25],zero,ymm0[23],zero,zero,zero,zero,ymm0[26],zero,ymm0[24],zero,zero -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm12, %ymm0, %ymm12 -+; AVX512F-ONLY-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128,20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128] - ; AVX512F-ONLY-FAST-NEXT: # ymm0 = mem[0,1,0,1] --; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm23 # 32-byte Folded Reload --; AVX512F-ONLY-FAST-NEXT: # ymm23 = mem[0,1,0,1] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,1,0] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %ymm30, %ymm1 -+; 
AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm1, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm9 # 32-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm9[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm9[23],zero,ymm9[23,24,25,26],zero,ymm9[24],zero,ymm9[30,31] -+; AVX512F-ONLY-FAST-NEXT: vpshufb %ymm0, %ymm9, %ymm0 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm4 # 32-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vpshuflw {{.*#+}} ymm15 = ymm4[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] -+; AVX512F-ONLY-FAST-NEXT: vmovdqa {{.*#+}} ymm9 = [4,5,4,5,5,7,4,5] -+; AVX512F-ONLY-FAST-NEXT: vpermd %ymm15, %ymm9, %ymm20 -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm15 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,28,29,26,27,28,29,30,31,30,31,28,29,28,29,30,31] -+; AVX512F-ONLY-FAST-NEXT: vpshufb {{.*#+}} ymm9 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,24,25,22,23,24,25,26,27,26,27,24,25] -+; AVX512F-ONLY-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm24 # 64-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: # zmm24 = mem[2,3,2,3,6,7,6,7] -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm24 -+; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: # ymm3 = mem[0,1,0,1] -+; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm25 # 32-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: # ymm25 = mem[0,1,0,1] - ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm15 = ymm15[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,0,1,0] -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,1,0] - ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm11 = ymm11[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,0,1,0] - ; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,3,2,3] --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm0 # 32-byte Folded Reload --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm23, %zmm23 # 32-byte Folded Reload --; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm23 --; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Reload --; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm24 --; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm23, %zmm24 --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm19, %zmm2, %zmm0 --; AVX512F-ONLY-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm21 --; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm21 --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm0 # 32-byte Folded Reload --; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Reload --; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm0 --; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm0 --; AVX512F-ONLY-FAST-NEXT: vpor %ymm12, %ymm15, %ymm2 -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] -+; 
AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm0[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,3,2,3] -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm3, %zmm3 # 32-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm25, %zmm25 # 32-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm25 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm18 -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm25, %zmm18 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm19, %zmm15, %zmm3 -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm16 -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm16 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm3 # 32-byte Folded Reload -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm3 -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm3 -+; AVX512F-ONLY-FAST-NEXT: vpor %ymm2, %ymm13, %ymm2 - ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm0, %zmm2 --; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Reload --; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm2 = zmm7[0,1,2,3],zmm2[4,5,6,7] --; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm16 --; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm16 -+; AVX512F-ONLY-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload -+; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm2 = zmm4[0,1,2,3],zmm2[4,5,6,7] -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm23 -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm23 - ; AVX512F-ONLY-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload - ; AVX512F-ONLY-FAST-NEXT: # zmm2 = mem[0,1,0,1,4,5,4,5] --; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} zmm5 = zmm31[0,1,0,1,4,5,4,5] --; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm5 --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm6, %zmm17, %zmm2 --; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm18, %zmm2 --; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm2 --; AVX512F-ONLY-FAST-NEXT: vpor %ymm1, %ymm13, %ymm1 --; AVX512F-ONLY-FAST-NEXT: vpor %ymm11, %ymm14, %ymm5 -+; AVX512F-ONLY-FAST-NEXT: vpermq {{.*#+}} zmm6 = zmm11[0,1,0,1,4,5,4,5] -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm6 -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm5, %zmm17, %zmm2 -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm26, %zmm2 -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm2 -+; AVX512F-ONLY-FAST-NEXT: vpor %ymm1, %ymm14, %ymm1 -+; AVX512F-ONLY-FAST-NEXT: vpor %ymm0, %ymm12, %ymm0 - ; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm1 --; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm1 = zmm5[0,1,2,3],zmm1[4,5,6,7] --; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm4, %zmm20, 
%zmm4 --; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm4 --; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm4 -+; AVX512F-ONLY-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm0[0,1,2,3],zmm1[4,5,6,7] -+; AVX512F-ONLY-FAST-NEXT: vinserti64x4 $1, %ymm9, %zmm20, %zmm1 -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm1 -+; AVX512F-ONLY-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm22, %zmm1 - ; AVX512F-ONLY-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm4, 128(%rax) -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm1, 128(%rax) - ; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm2, (%rax) --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm16, 320(%rax) --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm0, 256(%rax) --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm21, 192(%rax) --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm24, 64(%rax) --; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm22, 384(%rax) --; AVX512F-ONLY-FAST-NEXT: addq $1256, %rsp # imm = 0x4E8 -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm23, 320(%rax) -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm3, 256(%rax) -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm16, 192(%rax) -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm18, 64(%rax) -+; AVX512F-ONLY-FAST-NEXT: vmovdqa64 %zmm24, 384(%rax) -+; AVX512F-ONLY-FAST-NEXT: addq $1496, %rsp # imm = 0x5D8 - ; AVX512F-ONLY-FAST-NEXT: vzeroupper - ; AVX512F-ONLY-FAST-NEXT: retq - ; - ; AVX512DQ-FAST-LABEL: store_i8_stride7_vf64: - ; AVX512DQ-FAST: # %bb.0: --; AVX512DQ-FAST-NEXT: subq $1256, %rsp # imm = 0x4E8 -+; AVX512DQ-FAST-NEXT: subq $1496, %rsp # imm = 0x5D8 - ; AVX512DQ-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax --; AVX512DQ-FAST-NEXT: vmovdqa 32(%rcx), %ymm2 --; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdx), %ymm1 --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero --; AVX512DQ-FAST-NEXT: vmovdqa %ymm1, %ymm14 --; AVX512DQ-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm2[25],zero,ymm2[23],zero,zero,zero,zero,ymm2[26],zero,ymm2[24],zero,zero,zero,zero --; AVX512DQ-FAST-NEXT: vmovdqa %ymm2, %ymm13 -+; AVX512DQ-FAST-NEXT: vmovdqa 32(%rcx), %ymm7 -+; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdx), %ymm15 -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm15[30],zero,ymm15[28],zero,zero,zero,zero,ymm15[31],zero,ymm15[29],zero,zero -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm15, %ymm17 -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm7[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm7[25],zero,ymm7[23],zero,zero,zero,zero,ymm7[26],zero,ymm7[24],zero,zero,zero,zero - ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 - ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-FAST-NEXT: vmovdqa 32(%rsi), %ymm1 --; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdi), %ymm2 --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero,zero --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm16 --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm2[23],zero,zero,zero,zero,ymm2[26],zero,ymm2[24],zero,zero,zero,zero,ymm2[27],zero,ymm2[25] --; AVX512DQ-FAST-NEXT: 
vmovdqa64 %ymm2, %ymm17 -+; AVX512DQ-FAST-NEXT: vmovdqa 32(%rsi), %ymm15 -+; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdi), %ymm3 -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm15[30],zero,ymm15[28],zero,zero,zero,zero,ymm15[31],zero,ymm15[29],zero,zero,zero -+; AVX512DQ-FAST-NEXT: vmovdqu %ymm15, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm3[23],zero,zero,zero,zero,ymm3[26],zero,ymm3[24],zero,zero,zero,zero,ymm3[27],zero,ymm3[25] -+; AVX512DQ-FAST-NEXT: vmovdqu %ymm3, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 - ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill - ; AVX512DQ-FAST-NEXT: vmovdqa 32(%r9), %ymm4 - ; AVX512DQ-FAST-NEXT: vmovdqa 32(%r8), %ymm1 - ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm1[27],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29] --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm18 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm19 - ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm4[25],zero,ymm4[23],zero,zero,zero,zero,ymm4[26],zero,ymm4[24],zero,zero - ; AVX512DQ-FAST-NEXT: vmovdqu %ymm4, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0 -@@ -8290,403 +8305,420 @@ - ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill - ; AVX512DQ-FAST-NEXT: vmovdqa (%rsi), %ymm1 - ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,ymm1[14],zero,zero,zero,zero,zero,zero,ymm1[15],zero,zero,zero,zero,zero,zero,ymm1[16],zero,zero,zero,zero,zero,zero,ymm1[17],zero,zero,zero,zero,zero,zero,ymm1[18] --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm23 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm26 - ; AVX512DQ-FAST-NEXT: vmovdqa (%rdi), %ymm1 - ; AVX512DQ-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[0,1,14],zero,ymm1[12,13,0,1,14,15],zero,ymm1[3,12,13,2,3,16],zero,ymm1[30,31,28,29,16,17],zero,ymm1[31,18,19,28,29,18],zero - ; AVX512DQ-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 - ; AVX512DQ-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill - ; AVX512DQ-FAST-NEXT: vmovdqa (%rcx), %ymm1 --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm7 = [128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128,128] --; AVX512DQ-FAST-NEXT: vpshufb %ymm7, %ymm1, %ymm0 --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm26 --; AVX512DQ-FAST-NEXT: vmovdqa (%rdx), %ymm1 -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm6 = [128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128,128] -+; AVX512DQ-FAST-NEXT: vpshufb %ymm6, %ymm1, %ymm0 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm30 -+; AVX512DQ-FAST-NEXT: vmovdqa (%rdx), %ymm10 - ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [0,1,0,1,14,128,14,15,0,1,14,15,128,13,14,15,16,17,16,128,30,31,30,31,16,17,128,31,28,29,30,31] --; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm3 --; AVX512DQ-FAST-NEXT: vpor %ymm0, %ymm3, %ymm0 --; AVX512DQ-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-FAST-NEXT: vmovdqa (%r8), %ymm0 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm10, %ymm1 -+; AVX512DQ-FAST-NEXT: vmovdqa64 
%ymm2, %ymm25 -+; AVX512DQ-FAST-NEXT: vpor %ymm0, %ymm1, %ymm0 - ; AVX512DQ-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm5 = [128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128,128,128] --; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm0, %ymm3 --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm5, %ymm25 --; AVX512DQ-FAST-NEXT: vmovdqa (%r9), %ymm5 --; AVX512DQ-FAST-NEXT: vmovdqu %ymm5, (%rsp) # 32-byte Spill --; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0,13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0] --; AVX512DQ-FAST-NEXT: # ymm0 = mem[0,1,0,1] --; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm5, %ymm5 --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm0, %ymm30 --; AVX512DQ-FAST-NEXT: vporq %ymm3, %ymm5, %ymm24 --; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdx), %xmm3 --; AVX512DQ-FAST-NEXT: vmovdqa 32(%rcx), %xmm6 --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm8 = --; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm6, %xmm5 --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm6, %xmm28 --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm9 = --; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm3, %xmm6 --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm9, %xmm19 --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm3, %xmm29 --; AVX512DQ-FAST-NEXT: vpor %xmm5, %xmm6, %xmm3 --; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdi), %xmm10 --; AVX512DQ-FAST-NEXT: vmovdqa 32(%rsi), %xmm6 --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm11 = --; AVX512DQ-FAST-NEXT: vpshufb %xmm11, %xmm6, %xmm5 --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm3 = --; AVX512DQ-FAST-NEXT: vpshufb %xmm3, %xmm10, %xmm9 --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm10, %xmm27 --; AVX512DQ-FAST-NEXT: vpor %xmm5, %xmm9, %xmm5 --; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm5, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-FAST-NEXT: vmovdqa 32(%r9), %xmm15 --; AVX512DQ-FAST-NEXT: vmovdqa 32(%r8), %xmm10 --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm5 = <128,4,u,u,u,128,7,128,5,u,u,u,128,8,128,6> --; AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm15, %xmm9 -+; AVX512DQ-FAST-NEXT: vmovdqa (%r8), %ymm1 -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = zero,zero,zero,zero,zero,zero,ymm1[14],zero,zero,zero,zero,zero,zero,ymm1[15],zero,zero,zero,zero,zero,zero,ymm1[16],zero,zero,zero,zero,zero,zero,ymm1[17],zero,zero,zero,zero -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm22 -+; AVX512DQ-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-FAST-NEXT: vmovdqa (%r9), %ymm2 -+; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0,13,0,0,0,128,16,128,14,0,0,0,128,17,128,15,0] -+; AVX512DQ-FAST-NEXT: # ymm5 = mem[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm2, %ymm1 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm5, %ymm29 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm2, %ymm31 -+; AVX512DQ-FAST-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-FAST-NEXT: vporq %ymm0, %ymm1, %ymm23 -+; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdx), %xmm5 -+; AVX512DQ-FAST-NEXT: vmovdqa 32(%rcx), %xmm1 -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm2 = -+; AVX512DQ-FAST-NEXT: vpshufb %xmm2, %xmm1, %xmm0 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm2, %xmm18 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm1, %xmm20 -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm2 = -+; AVX512DQ-FAST-NEXT: vpshufb %xmm2, %xmm5, %xmm1 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm5, %xmm21 -+; AVX512DQ-FAST-NEXT: vpor %xmm0, 
%xmm1, %xmm0 -+; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512DQ-FAST-NEXT: vmovdqa 32(%rdi), %xmm11 -+; AVX512DQ-FAST-NEXT: vmovdqa 32(%rsi), %xmm1 -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm14 = -+; AVX512DQ-FAST-NEXT: vpshufb %xmm14, %xmm1, %xmm0 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm1, %xmm28 -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm1 = -+; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm11, %xmm5 -+; AVX512DQ-FAST-NEXT: vpor %xmm0, %xmm5, %xmm0 -+; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512DQ-FAST-NEXT: vmovdqa 32(%r9), %xmm0 -+; AVX512DQ-FAST-NEXT: vmovdqa 32(%r8), %xmm9 -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm8 = <128,4,u,u,u,128,7,128,5,u,u,u,128,8,128,6> -+; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm0, %xmm12 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm0, %xmm16 - ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm0 = <4,128,u,u,u,7,128,5,128,u,u,u,8,128,6,128> --; AVX512DQ-FAST-NEXT: vpshufb %xmm0, %xmm10, %xmm12 --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm0, %xmm21 --; AVX512DQ-FAST-NEXT: vporq %xmm9, %xmm12, %xmm22 --; AVX512DQ-FAST-NEXT: vpshufb %ymm7, %ymm13, %ymm7 --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm13, %ymm20 --; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm14, %ymm2 --; AVX512DQ-FAST-NEXT: vpor %ymm7, %ymm2, %ymm2 --; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm16, %ymm7 --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = zero,zero,zero,ymm7[14],zero,zero,zero,zero,zero,zero,ymm7[15],zero,zero,zero,zero,zero,zero,ymm7[16],zero,zero,zero,zero,zero,zero,ymm7[17],zero,zero,zero,zero,zero,zero,ymm7[18] --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm17, %ymm7 --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm7[0,1,14],zero,ymm7[12,13,0,1,14,15],zero,ymm7[3,12,13,2,3,16],zero,ymm7[30,31,28,29,16,17],zero,ymm7[31,18,19,28,29,18],zero --; AVX512DQ-FAST-NEXT: vpor %ymm2, %ymm7, %ymm2 --; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm18, %ymm2 --; AVX512DQ-FAST-NEXT: vmovdqu64 %ymm18, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-FAST-NEXT: vpshufb %xmm0, %xmm9, %xmm13 -+; AVX512DQ-FAST-NEXT: vpor %xmm12, %xmm13, %xmm12 -+; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm12, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512DQ-FAST-NEXT: vpshufb %ymm6, %ymm7, %ymm6 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm7, %ymm24 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm17, %ymm13 - ; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm25, %ymm7 --; AVX512DQ-FAST-NEXT: vpshufb %ymm7, %ymm2, %ymm2 --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm0 --; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm4, %ymm0 --; AVX512DQ-FAST-NEXT: vpor %ymm2, %ymm0, %ymm0 --; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-FAST-NEXT: vmovdqa (%rsi), %xmm13 --; AVX512DQ-FAST-NEXT: vpshufb %xmm11, %xmm13, %xmm0 --; AVX512DQ-FAST-NEXT: vmovdqa (%rdi), %xmm9 --; AVX512DQ-FAST-NEXT: vpshufb %xmm3, %xmm9, %xmm2 --; AVX512DQ-FAST-NEXT: vporq %xmm0, %xmm2, %xmm31 --; AVX512DQ-FAST-NEXT: vmovdqa (%rcx), %xmm14 --; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm14, %xmm0 --; AVX512DQ-FAST-NEXT: vmovdqa (%rdx), %xmm8 --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm19, %xmm2 --; AVX512DQ-FAST-NEXT: vpshufb %xmm2, %xmm8, %xmm2 --; AVX512DQ-FAST-NEXT: vpor %xmm0, %xmm2, %xmm0 --; AVX512DQ-FAST-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill -+; AVX512DQ-FAST-NEXT: vpshufb %ymm7, %ymm13, %ymm7 -+; AVX512DQ-FAST-NEXT: vpor %ymm6, %ymm7, %ymm6 -+; 
AVX512DQ-FAST-NEXT: vmovdqu64 %zmm6, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm6 = zero,zero,zero,ymm15[14],zero,zero,zero,zero,zero,zero,ymm15[15],zero,zero,zero,zero,zero,zero,ymm15[16],zero,zero,zero,zero,zero,zero,ymm15[17],zero,zero,zero,zero,zero,zero,ymm15[18] -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm3[0,1,14],zero,ymm3[12,13,0,1,14,15],zero,ymm3[3,12,13,2,3,16],zero,ymm3[30,31,28,29,16,17],zero,ymm3[31,18,19,28,29,18],zero -+; AVX512DQ-FAST-NEXT: vpor %ymm6, %ymm7, %ymm3 -+; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm19, %ymm3 -+; AVX512DQ-FAST-NEXT: vmovdqu64 %ymm19, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm3 = zero,zero,zero,zero,zero,zero,ymm3[14],zero,zero,zero,zero,zero,zero,ymm3[15],zero,zero,zero,zero,zero,zero,ymm3[16],zero,zero,zero,zero,zero,zero,ymm3[17],zero,zero,zero,zero -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm29, %ymm6 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm6, %ymm4, %ymm6 -+; AVX512DQ-FAST-NEXT: vpor %ymm3, %ymm6, %ymm3 -+; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm3, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512DQ-FAST-NEXT: vmovdqa (%rsi), %xmm4 -+; AVX512DQ-FAST-NEXT: vpshufb %xmm14, %xmm4, %xmm3 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm4, %xmm17 -+; AVX512DQ-FAST-NEXT: vmovdqa (%rdi), %xmm7 -+; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm7, %xmm1 -+; AVX512DQ-FAST-NEXT: vpor %xmm3, %xmm1, %xmm1 -+; AVX512DQ-FAST-NEXT: vmovdqa %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill -+; AVX512DQ-FAST-NEXT: vmovdqa (%rcx), %xmm3 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm18, %xmm1 -+; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm3, %xmm1 -+; AVX512DQ-FAST-NEXT: vmovdqa %xmm3, %xmm12 -+; AVX512DQ-FAST-NEXT: vmovdqa (%rdx), %xmm5 -+; AVX512DQ-FAST-NEXT: vpshufb %xmm2, %xmm5, %xmm2 -+; AVX512DQ-FAST-NEXT: vpor %xmm1, %xmm2, %xmm1 -+; AVX512DQ-FAST-NEXT: vmovdqa %xmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill - ; AVX512DQ-FAST-NEXT: vmovdqa (%r9), %xmm2 --; AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm2, %xmm0 -+; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm2, %xmm1 - ; AVX512DQ-FAST-NEXT: vmovdqa %xmm2, %xmm3 - ; AVX512DQ-FAST-NEXT: vmovdqa %xmm2, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill - ; AVX512DQ-FAST-NEXT: vmovdqa (%r8), %xmm4 --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm21, %xmm2 --; AVX512DQ-FAST-NEXT: vpshufb %xmm2, %xmm4, %xmm2 -+; AVX512DQ-FAST-NEXT: vpshufb %xmm0, %xmm4, %xmm2 - ; AVX512DQ-FAST-NEXT: vmovdqa %xmm4, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill --; AVX512DQ-FAST-NEXT: vpor %xmm0, %xmm2, %xmm0 -+; AVX512DQ-FAST-NEXT: vpor %xmm1, %xmm2, %xmm0 - ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill - ; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[23],zero,zero,zero,zero,ymm0[26],zero,ymm0[24],zero,zero,zero,zero,ymm0[27],zero,ymm0[25] --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm23, %ymm12 --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm12[21],zero,ymm12[19],zero,zero,zero,zero,ymm12[22],zero,ymm12[20],zero,zero --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[23],zero,zero,zero,zero,ymm0[26],zero,ymm0[24],zero,zero,zero,zero,ymm0[27],zero,ymm0[25] -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm26, %ymm6 -+; AVX512DQ-FAST-NEXT: 
vpshufb {{.*#+}} ymm2 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm6[21],zero,ymm6[19],zero,zero,zero,zero,ymm6[22],zero,ymm6[20],zero,zero -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm0 - ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm26, %ymm11 --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm11[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm11[25],zero,ymm11[23],zero,zero,zero,zero,ymm11[26],zero,ymm11[24],zero,zero,zero,zero --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm11[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm11[18],zero,zero,zero,zero,ymm11[21],zero,ymm11[19],zero,zero,zero,zero,ymm11[22],zero,ymm11[20] --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm15 -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,ymm15[25],zero,ymm15[23],zero,zero,zero,zero,ymm15[26],zero,ymm15[24],zero,zero,zero,zero -+; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20] -+; AVX512DQ-FAST-NEXT: # ymm0 = mem[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm15, %ymm2 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm0, %ymm25 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm8 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm0 - ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm2 = [24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27,24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27] --; AVX512DQ-FAST-NEXT: # ymm2 = mem[0,1,0,1] --; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128,18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128] --; AVX512DQ-FAST-NEXT: # ymm5 = mem[0,1,0,1] --; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm0 --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm2, %ymm19 --; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm1, %ymm2 --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm5, %ymm30 --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm2, %zmm0 -+; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm14 = [24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27,24,25,128,23,128,21,22,23,26,128,24,128,28,29,26,27] -+; AVX512DQ-FAST-NEXT: # ymm14 = mem[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128,18,128,18,19,20,21,128,19,128,25,26,27,22,128,20,128] -+; AVX512DQ-FAST-NEXT: # ymm0 = mem[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vpshufb %ymm14, %ymm10, %ymm1 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm10, %ymm2 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm0, %ymm29 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm2, %zmm0 - ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm24, %zmm0, %zmm0 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm23, %zmm0, %zmm1 - ; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm2 = xmm4[8],xmm3[8],xmm4[9],xmm3[9],xmm4[10],xmm3[10],xmm4[11],xmm3[11],xmm4[12],xmm3[12],xmm4[13],xmm3[13],xmm4[14],xmm3[14],xmm4[15],xmm3[15] --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm3 = --; AVX512DQ-FAST-NEXT: vpshufb %xmm3, %xmm2, %xmm2 --; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm2[0,1,0,1],zmm0[4,5,6,7] -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm4 = -+; AVX512DQ-FAST-NEXT: vpshufb %xmm4, %xmm2, %xmm2 -+; AVX512DQ-FAST-NEXT: 
vshufi64x2 {{.*#+}} zmm0 = zmm2[0,1,0,1],zmm1[4,5,6,7] - ; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm2 = [2,2,3,3,2,2,3,3] --; AVX512DQ-FAST-NEXT: # ymm2 = mem[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vbroadcasti32x4 {{.*#+}} ymm23 = [2,2,3,3,2,2,3,3] -+; AVX512DQ-FAST-NEXT: # ymm23 = mem[0,1,2,3,0,1,2,3] - ; AVX512DQ-FAST-NEXT: vmovdqa (%rax), %xmm0 - ; AVX512DQ-FAST-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill --; AVX512DQ-FAST-NEXT: vpshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,5,5,6] --; AVX512DQ-FAST-NEXT: vpermd %ymm0, %ymm2, %ymm0 --; AVX512DQ-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 --; AVX512DQ-FAST-NEXT: vmovdqa (%rax), %ymm4 --; AVX512DQ-FAST-NEXT: vmovdqu %ymm4, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm5 = [128,13,128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128] --; AVX512DQ-FAST-NEXT: vpshufb %ymm5, %ymm4, %ymm4 --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm5, %ymm18 --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm4, %zmm0, %zmm24 --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29],zero,zero --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm0, %ymm23 -+; AVX512DQ-FAST-NEXT: vpshufhw {{.*#+}} xmm2 = xmm0[0,1,2,3,4,5,5,6] -+; AVX512DQ-FAST-NEXT: vpermd %ymm2, %ymm23, %ymm2 -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm3 = [255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255] -+; AVX512DQ-FAST-NEXT: vpandn %ymm2, %ymm3, %ymm2 -+; AVX512DQ-FAST-NEXT: vmovdqa (%rax), %ymm0 -+; AVX512DQ-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = [128,13,128,128,128,128,128,128,14,128,128,128,128,128,128,15,128,128,128,128,128,128,16,128,128,128,128,128,128,17,128,128] -+; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm3 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm19 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm2, %zmm18 -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm10[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm10[30],zero,ymm10[28],zero,zero,zero,zero,ymm10[31],zero,ymm10[29],zero,zero -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm0, %ymm26 - ; AVX512DQ-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14] --; AVX512DQ-FAST-NEXT: vmovdqu (%rsp), %ymm0 # 32-byte Reload --; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm0 --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm25 --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm0[2,3,2,3] --; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm1[27],zero,zero,zero,zero,ymm1[30],zero,ymm1[28],zero,zero,zero,zero,ymm1[31],zero,ymm1[29] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] --; AVX512DQ-FAST-NEXT: vbroadcasti64x2 {{.*#+}} ymm26 = [18374967954648269055,71777218572844800,18374967954648269055,71777218572844800] --; AVX512DQ-FAST-NEXT: # ymm26 = mem[0,1,0,1] --; AVX512DQ-FAST-NEXT: vpternlogq $248, %ymm26, %ymm0, %ymm1 --; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm0 = 
xmm10[0],xmm15[0],xmm10[1],xmm15[1],xmm10[2],xmm15[2],xmm10[3],xmm15[3],xmm10[4],xmm15[4],xmm10[5],xmm15[5],xmm10[6],xmm15[6],xmm10[7],xmm15[7] --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] --; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm21 = zmm1[0,1,2,3],zmm0[0,1,0,1] --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm29, %xmm0 --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm28, %xmm1 --; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm7 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3],xmm0[4],xmm1[4],xmm0[5],xmm1[5],xmm0[6],xmm1[6],xmm0[7],xmm1[7] --; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] --; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm14[8],xmm8[8],xmm14[9],xmm8[9],xmm14[10],xmm8[10],xmm14[11],xmm8[11],xmm14[12],xmm8[12],xmm14[13],xmm8[13],xmm14[14],xmm8[14],xmm14[15],xmm8[15] --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm4 = <6,3,2,u,u,u,9,8,5,4,u,u,u,11,10,7> --; AVX512DQ-FAST-NEXT: vpshufb %xmm4, %xmm1, %xmm1 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm31, %ymm0 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm2 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm30 -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm22, %ymm0 -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm3 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[27],zero,zero,zero,zero,ymm0[30],zero,ymm0[28],zero,zero,zero,zero,ymm0[31],zero,ymm0[29] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm3 = ymm3[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vbroadcasti64x2 {{.*#+}} ymm31 = [18374967954648269055,71777218572844800,18374967954648269055,71777218572844800] -+; AVX512DQ-FAST-NEXT: # ymm31 = mem[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vpternlogq $248, %ymm31, %ymm2, %ymm3 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm16, %xmm10 -+; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm2 = xmm9[0],xmm10[0],xmm9[1],xmm10[1],xmm9[2],xmm10[2],xmm9[3],xmm10[3],xmm9[4],xmm10[4],xmm9[5],xmm10[5],xmm9[6],xmm10[6],xmm9[7],xmm10[7] -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm2 = xmm2[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] -+; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm16 = zmm3[0,1,2,3],zmm2[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm21, %xmm0 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm20, %xmm1 -+; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm2 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3],xmm0[4],xmm1[4],xmm0[5],xmm1[5],xmm0[6],xmm1[6],xmm0[7],xmm1[7] -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm2, %xmm22 -+; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm2 = xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] -+; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm3 = xmm12[8],xmm5[8],xmm12[9],xmm5[9],xmm12[10],xmm5[10],xmm12[11],xmm5[11],xmm12[12],xmm5[12],xmm12[13],xmm5[13],xmm12[14],xmm5[14],xmm12[15],xmm5[15] -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm5, %xmm20 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm12, %xmm21 -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm0 = <6,3,2,u,u,u,9,8,5,4,u,u,u,11,10,7> -+; AVX512DQ-FAST-NEXT: vpshufb %xmm0, %xmm3, %xmm1 - ; AVX512DQ-FAST-NEXT: vmovdqu %ymm1, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-FAST-NEXT: vpshufb %xmm0, %xmm2, %xmm0 -+; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm1 # 64-byte Reload -+; AVX512DQ-FAST-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm0 -+; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 
64-byte Spill -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm28, %xmm0 -+; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm5 = xmm11[0],xmm0[0],xmm11[1],xmm0[1],xmm11[2],xmm0[2],xmm11[3],xmm0[3],xmm11[4],xmm0[4],xmm11[5],xmm0[5],xmm11[6],xmm0[6],xmm11[7],xmm0[7] -+; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm0[8],xmm11[8],xmm0[9],xmm11[9],xmm0[10],xmm11[10],xmm0[11],xmm11[11],xmm0[12],xmm11[12],xmm0[13],xmm11[13],xmm0[14],xmm11[14],xmm0[15],xmm11[15] -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm17, %xmm3 -+; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm2 = xmm3[8],xmm7[8],xmm3[9],xmm7[9],xmm3[10],xmm7[10],xmm3[11],xmm7[11],xmm3[12],xmm7[12],xmm3[13],xmm7[13],xmm3[14],xmm7[14],xmm3[15],xmm7[15] -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm1 = <2,u,u,u,9,8,5,4,u,u,u,11,10,7,6,u> -+; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm2, %xmm2 -+; AVX512DQ-FAST-NEXT: vmovdqu %ymm2, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill -+; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm0, %xmm0 -+; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm1 # 64-byte Reload -+; AVX512DQ-FAST-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm0 -+; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm0 = xmm9[8],xmm10[8],xmm9[9],xmm10[9],xmm9[10],xmm10[10],xmm9[11],xmm10[11],xmm9[12],xmm10[12],xmm9[13],xmm10[13],xmm9[14],xmm10[14],xmm9[15],xmm10[15] -+; AVX512DQ-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12] -+; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm8, %ymm12 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm24, %ymm11 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm11, %ymm1 -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm27 = ymm1[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vpshufb %ymm14, %ymm13, %ymm1 -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm28 = ymm1[2,3,2,3] - ; AVX512DQ-FAST-NEXT: vpshufb %xmm4, %xmm0, %xmm0 - ; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm1 # 64-byte Reload --; AVX512DQ-FAST-NEXT: vinserti32x4 $2, %xmm0, %zmm1, %zmm28 --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm27, %xmm0 --; AVX512DQ-FAST-NEXT: vmovdqa %xmm6, %xmm1 --; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm6 = xmm0[0],xmm6[0],xmm0[1],xmm6[1],xmm0[2],xmm6[2],xmm0[3],xmm6[3],xmm0[4],xmm6[4],xmm0[5],xmm6[5],xmm0[6],xmm6[6],xmm0[7],xmm6[7] --; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm1[8],xmm0[8],xmm1[9],xmm0[9],xmm1[10],xmm0[10],xmm1[11],xmm0[11],xmm1[12],xmm0[12],xmm1[13],xmm0[13],xmm1[14],xmm0[14],xmm1[15],xmm0[15] --; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm4 = xmm13[8],xmm9[8],xmm13[9],xmm9[9],xmm13[10],xmm9[10],xmm13[11],xmm9[11],xmm13[12],xmm9[12],xmm13[13],xmm9[13],xmm13[14],xmm9[14],xmm13[15],xmm9[15] --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm5 = <2,u,u,u,9,8,5,4,u,u,u,11,10,7,6,u> --; AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm4, %xmm0 --; AVX512DQ-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm1, %xmm1 --; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Reload --; AVX512DQ-FAST-NEXT: vinserti32x4 $2, %xmm1, %zmm0, %zmm27 --; AVX512DQ-FAST-NEXT: vpunpckhbw {{.*#+}} xmm1 = xmm10[8],xmm15[8],xmm10[9],xmm15[9],xmm10[10],xmm15[10],xmm10[11],xmm15[11],xmm10[12],xmm15[12],xmm10[13],xmm15[13],xmm10[14],xmm15[14],xmm10[15],xmm15[15] --; AVX512DQ-FAST-NEXT: vpshufb %xmm3, %xmm1, %xmm1 --; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm22[0,1,0,1],zmm1[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm1[0,1,0,1],zmm0[0,1,0,1] - ; 
AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill - ; AVX512DQ-FAST-NEXT: vmovdqa 32(%rax), %xmm0 --; AVX512DQ-FAST-NEXT: vpshufhw {{.*#+}} xmm1 = xmm0[0,1,2,3,4,5,5,6] --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm0, %xmm29 --; AVX512DQ-FAST-NEXT: vpermd %ymm1, %ymm2, %ymm0 -+; AVX512DQ-FAST-NEXT: vmovdqa %xmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 16-byte Spill -+; AVX512DQ-FAST-NEXT: vpshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,4,5,5,6] -+; AVX512DQ-FAST-NEXT: vpermd %ymm0, %ymm23, %ymm0 - ; AVX512DQ-FAST-NEXT: vmovdqu %ymm0, {{[-0-9]+}}(%r{{[sb]}}p) # 32-byte Spill --; AVX512DQ-FAST-NEXT: vpbroadcastq {{.*#+}} ymm1 = [11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12,11,0,0,0,15,14,13,12] --; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm11, %ymm5 --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm20, %ymm0 --; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm1 --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm22 = ymm1[2,3,2,3] --; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Reload --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm19, %ymm1 --; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm2, %ymm1 --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm20 = ymm1[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm10 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,ymm0[18],zero,zero,zero,zero,ymm0[21],zero,ymm0[19],zero,zero,zero,zero,ymm0[22],zero,ymm0[20] --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm0 --; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm2, %ymm4 --; AVX512DQ-FAST-NEXT: vpbroadcastq {{.*#+}} ymm2 = [9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10,9,8,7,0,0,0,11,10] --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm16, %ymm0 --; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm0, %ymm1 --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm2, %ymm30 --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm19 = ymm1[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm11 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm0[21],zero,ymm0[19],zero,zero,zero,zero,ymm0[22],zero,ymm0[20],zero,zero -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm25, %ymm0 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm11, %ymm10 -+; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm15 # 32-byte Reload -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm0 = ymm15[9,u,7,u,u,u,u,10,u,8,u,u,u,u,11,u,25,u,23,u,u,u,u,26,u,24,u,u,u,u,27,u] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm24 = ymm0[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm29, %ymm0 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm13, %ymm8 -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm11 = ymm15[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm15[21],zero,ymm15[19],zero,zero,zero,zero,ymm15[22],zero,ymm15[20],zero,zero - ; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm18, %ymm1 --; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm3 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm19, %ymm1 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm1, %ymm0, %ymm4 - ; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} ymm0 = ymm0[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] - ; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm1 = [4,5,4,5,5,7,4,5] - ; AVX512DQ-FAST-NEXT: vpermd %ymm0, %ymm1, %ymm0 --; AVX512DQ-FAST-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm3, %zmm16 --; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm0 = xmm9[0],xmm13[0],xmm9[1],xmm13[1],xmm9[2],xmm13[2],xmm9[3],xmm13[3],xmm9[4],xmm13[4],xmm9[5],xmm13[5],xmm9[6],xmm13[6],xmm9[7],xmm13[7] --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm9 = <0,1,u,u,u,6,7,2,3,u,u,u,8,9,4,5> --; 
AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm6, %xmm6 --; AVX512DQ-FAST-NEXT: vpshufb %xmm9, %xmm0, %xmm0 --; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm9 = [28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29,28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29] --; AVX512DQ-FAST-NEXT: # ymm9 = mem[0,1,0,1] --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm17, %ymm3 --; AVX512DQ-FAST-NEXT: vpshufb %ymm9, %ymm3, %ymm15 --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm17 = ymm15[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} ymm25 = [255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255,0,255,255,255,255,255,255] -+; AVX512DQ-FAST-NEXT: vpandnq %ymm0, %ymm25, %ymm0 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm4, %zmm23 -+; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29,28,29,30,128,28,128,30,31,30,31,128,29,128,31,28,29] -+; AVX512DQ-FAST-NEXT: # ymm0 = mem[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm15 # 32-byte Reload -+; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm15, %ymm9 -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm17 = ymm9[2,3,2,3] - ; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm25, %ymm2 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm2 - ; AVX512DQ-FAST-NEXT: vpshufb %ymm2, %ymm1, %ymm13 --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm18 = ymm13[2,3,2,3] --; AVX512DQ-FAST-NEXT: vinserti32x4 $2, %xmm31, %zmm0, %zmm0 --; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm0, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill --; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm0 = xmm8[0],xmm14[0],xmm8[1],xmm14[1],xmm8[2],xmm14[2],xmm8[3],xmm14[3],xmm8[4],xmm14[4],xmm8[5],xmm14[5],xmm8[6],xmm14[6],xmm8[7],xmm14[7] --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm8 = <4,5,0,1,u,u,u,6,7,2,3,u,u,u,8,9> --; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm7, %xmm7 --; AVX512DQ-FAST-NEXT: vpshufb %xmm8, %xmm0, %xmm13 --; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm8 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm0[23],zero,ymm0[23,24,25,26],zero,ymm0[24],zero,ymm0[30,31] --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm0, %ymm25 --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm15 = ymm8[2,3,2,3] --; AVX512DQ-FAST-NEXT: vmovdqa %ymm12, %ymm1 --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm12 = ymm12[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm12[30],zero,ymm12[28],zero,zero,zero,zero,ymm12[31],zero,ymm12[29],zero,zero,zero --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] --; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Reload --; AVX512DQ-FAST-NEXT: vpshufb %ymm9, %ymm2, %ymm9 --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,1,0,1] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm5 = ymm5[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm23[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm7[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm29 -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm25 = ymm13[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm7 = xmm7[0],xmm3[0],xmm7[1],xmm3[1],xmm7[2],xmm3[2],xmm7[3],xmm3[3],xmm7[4],xmm3[4],xmm7[5],xmm3[5],xmm7[6],xmm3[6],xmm7[7],xmm3[7] -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm13 = <0,1,u,u,u,6,7,2,3,u,u,u,8,9,4,5> -+; AVX512DQ-FAST-NEXT: vpshufb %xmm13, %xmm5, %xmm5 -+; AVX512DQ-FAST-NEXT: vpshufb %xmm13, %xmm7, %xmm7 -+; 
AVX512DQ-FAST-NEXT: vinserti32x4 $2, {{[-0-9]+}}(%r{{[sb]}}p), %zmm7, %zmm1 # 16-byte Folded Reload -+; AVX512DQ-FAST-NEXT: vmovdqu64 %zmm1, {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Spill -+; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm1[23],zero,ymm1[23,24,25,26],zero,ymm1[24],zero,ymm1[30,31] -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm1, %ymm30 -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm19 = ymm7[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vmovdqa %ymm6, %ymm2 -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm7 = ymm6[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,ymm6[30],zero,ymm6[28],zero,zero,zero,zero,ymm6[31],zero,ymm6[29],zero,zero,zero -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm7[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Reload -+; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm3, %ymm0 -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm0[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm7 = ymm12[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm26[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm20, %xmm1 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm21, %xmm4 -+; AVX512DQ-FAST-NEXT: vpunpcklbw {{.*#+}} xmm12 = xmm1[0],xmm4[0],xmm1[1],xmm4[1],xmm1[2],xmm4[2],xmm1[3],xmm4[3],xmm1[4],xmm4[4],xmm1[5],xmm4[5],xmm1[6],xmm4[6],xmm1[7],xmm4[7] -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm1 = <4,5,0,1,u,u,u,6,7,2,3,u,u,u,8,9> -+; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm22, %xmm4 -+; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm4, %xmm13 -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm13[0,1,0,1] - ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm11[2,3,2,3] --; AVX512DQ-FAST-NEXT: vinserti32x4 $2, {{[-0-9]+}}(%r{{[sb]}}p), %zmm13, %zmm31 # 16-byte Folded Reload --; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm11 = [18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23,18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23] --; AVX512DQ-FAST-NEXT: # ymm11 = mem[0,1,0,1] --; AVX512DQ-FAST-NEXT: vpshufb %ymm11, %ymm3, %ymm14 --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,3,2,3] --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm3 --; AVX512DQ-FAST-NEXT: vpshufb %ymm3, %ymm1, %ymm13 --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpshufb %ymm11, %ymm2, %ymm1 -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm8 = ymm8[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm11[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vpshufb %xmm1, %xmm12, %xmm1 -+; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm12 = [18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23,18,19,20,21,128,19,128,21,20,21,22,128,20,128,22,23] -+; AVX512DQ-FAST-NEXT: # ymm12 = mem[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vpshufb %ymm12, %ymm15, %ymm9 -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vinserti32x4 $2, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm11 # 16-byte Folded Reload -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[9,u,7,u,u,u,u,10,u,8,u,u,u,u,11,u,25,u,23,u,u,u,u,26,u,24,u,u,u,u,27,u] - ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpor %ymm12, %ymm9, %ymm9 --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm6, %zmm9, %zmm6 --; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm9 = 
[18374966859431673855,18446463693966278655,18374966859431673855,18446463693966278655] --; AVX512DQ-FAST-NEXT: # ymm9 = mem[0,1,0,1] --; AVX512DQ-FAST-NEXT: vpternlogq $248, %ymm9, %ymm5, %ymm0 --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm7, %zmm0, %zmm7 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm12, %ymm3, %ymm12 -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vpor %ymm6, %ymm14, %ymm2 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm5, %zmm2, %zmm2 -+; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm5 = [18374966859431673855,18446463693966278655,18374966859431673855,18446463693966278655] -+; AVX512DQ-FAST-NEXT: # ymm5 = mem[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vpternlogq $248, %ymm5, %ymm7, %ymm0 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm13, %zmm0, %zmm7 - ; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm0 = [255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255] --; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm6, %zmm0, %zmm7 --; AVX512DQ-FAST-NEXT: vpor %ymm4, %ymm10, %ymm4 --; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Reload --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm4, %zmm2, %zmm4 --; AVX512DQ-FAST-NEXT: vpor %ymm8, %ymm14, %ymm2 -+; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm2, %zmm0, %zmm7 -+; AVX512DQ-FAST-NEXT: vpor %ymm10, %ymm8, %ymm2 - ; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Reload --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm5 --; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm4, %zmm0, %zmm5 --; AVX512DQ-FAST-NEXT: vpandq %ymm9, %ymm22, %ymm0 --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm20, %zmm0 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm3, %zmm2 -+; AVX512DQ-FAST-NEXT: vpor %ymm4, %ymm9, %ymm3 -+; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm3, %zmm4, %zmm6 -+; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm2, %zmm0, %zmm6 -+; AVX512DQ-FAST-NEXT: vpandq %ymm5, %ymm27, %ymm0 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm28, %zmm0 - ; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload - ; AVX512DQ-FAST-NEXT: # zmm2 = mem[2,3,2,3,6,7,6,7] - ; AVX512DQ-FAST-NEXT: vporq %zmm2, %zmm0, %zmm0 --; AVX512DQ-FAST-NEXT: vpandq %ymm26, %ymm19, %ymm2 -+; AVX512DQ-FAST-NEXT: vpandq %ymm31, %ymm24, %ymm2 - ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm17, %zmm2, %zmm2 --; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Folded Reload --; AVX512DQ-FAST-NEXT: # zmm4 = mem[2,3,2,3,6,7,6,7] --; AVX512DQ-FAST-NEXT: vporq %zmm4, %zmm2, %zmm2 -+; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Folded Reload -+; AVX512DQ-FAST-NEXT: # zmm3 = mem[2,3,2,3,6,7,6,7] -+; AVX512DQ-FAST-NEXT: vporq %zmm3, %zmm2, %zmm2 - ; AVX512DQ-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm2 --; AVX512DQ-FAST-NEXT: vpandq %ymm26, %ymm18, %ymm0 --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm15, %zmm0 --; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Folded Reload --; AVX512DQ-FAST-NEXT: # zmm4 = mem[2,3,2,3,6,7,6,7] --; AVX512DQ-FAST-NEXT: vporq %zmm4, %zmm0, %zmm0 --; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm4 = 
[255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255] --; AVX512DQ-FAST-NEXT: vpternlogq $184, %zmm2, %zmm4, %zmm0 --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} zmm2 = zmm28[0,1,0,1,4,5,4,5] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} zmm8 = zmm27[0,1,0,1,4,5,4,5] --; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm2, %zmm4, %zmm8 --; AVX512DQ-FAST-NEXT: vpandq %ymm26, %ymm13, %ymm2 --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm1, %zmm1 -+; AVX512DQ-FAST-NEXT: vpandq %ymm31, %ymm25, %ymm0 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm0, %zmm19, %zmm0 -+; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Folded Reload -+; AVX512DQ-FAST-NEXT: # zmm3 = mem[2,3,2,3,6,7,6,7] -+; AVX512DQ-FAST-NEXT: vporq %zmm3, %zmm0, %zmm3 -+; AVX512DQ-FAST-NEXT: vmovdqa64 {{.*#+}} zmm0 = [255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255,255,255,255,255,0,0,255] -+; AVX512DQ-FAST-NEXT: vpternlogq $184, %zmm2, %zmm0, %zmm3 -+; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload -+; AVX512DQ-FAST-NEXT: # zmm2 = mem[0,1,0,1,4,5,4,5] -+; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm8 # 64-byte Folded Reload -+; AVX512DQ-FAST-NEXT: # zmm8 = mem[0,1,0,1,4,5,4,5] -+; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm2, %zmm0, %zmm8 -+; AVX512DQ-FAST-NEXT: vpandq %ymm31, %ymm1, %ymm1 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm12, %zmm1 - ; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload - ; AVX512DQ-FAST-NEXT: # zmm2 = mem[2,3,2,3,6,7,6,7] - ; AVX512DQ-FAST-NEXT: vporq %zmm2, %zmm1, %zmm1 - ; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload - ; AVX512DQ-FAST-NEXT: # zmm2 = mem[2,3,2,3,6,7,6,7] --; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm6 # 64-byte Folded Reload --; AVX512DQ-FAST-NEXT: # zmm6 = mem[2,3,2,3,6,7,6,7] --; AVX512DQ-FAST-NEXT: vporq %zmm2, %zmm6, %zmm9 --; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm1, %zmm4, %zmm9 --; AVX512DQ-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload --; AVX512DQ-FAST-NEXT: vpunpcklbw {{[-0-9]+}}(%r{{[sb]}}p), %xmm1, %xmm1 # 16-byte Folded Reload --; AVX512DQ-FAST-NEXT: # xmm1 = xmm1[0],mem[0],xmm1[1],mem[1],xmm1[2],mem[2],xmm1[3],mem[3],xmm1[4],mem[4],xmm1[5],mem[5],xmm1[6],mem[6],xmm1[7],mem[7] --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] --; AVX512DQ-FAST-NEXT: vshufi64x2 $0, {{[-0-9]+}}(%r{{[sb]}}p), %zmm1, %zmm18 # 64-byte Folded Reload --; AVX512DQ-FAST-NEXT: # zmm18 = zmm1[0,1,0,1],mem[0,1,0,1] --; AVX512DQ-FAST-NEXT: vmovdqa64 %xmm29, %xmm3 --; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} xmm2 = xmm3[1,1,0,0,4,5,6,7] --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm4 = [0,1,0,1,2,0,0,1] --; AVX512DQ-FAST-NEXT: vpermd %ymm2, %ymm4, %ymm19 --; AVX512DQ-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm1 # 16-byte Reload --; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} xmm6 = xmm1[1,1,0,0,4,5,6,7] --; AVX512DQ-FAST-NEXT: vpermd %ymm6, %ymm4, %ymm17 --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm6 = [4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] --; AVX512DQ-FAST-NEXT: vpshufb %xmm6, %xmm3, %xmm10 --; AVX512DQ-FAST-NEXT: vpshufb %xmm6, %xmm1, %xmm6 --; AVX512DQ-FAST-NEXT: 
vbroadcasti128 {{.*#+}} ymm11 = [128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22] --; AVX512DQ-FAST-NEXT: # ymm11 = mem[0,1,0,1] --; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm1 # 32-byte Reload --; AVX512DQ-FAST-NEXT: vpshufb %ymm11, %ymm1, %ymm12 --; AVX512DQ-FAST-NEXT: vmovdqu (%rsp), %ymm1 # 32-byte Reload --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm13 = ymm1[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm1[25],zero,ymm1[23],zero,zero,zero,zero,ymm1[26],zero,ymm1[24],zero,zero --; AVX512DQ-FAST-NEXT: vpshufb %ymm11, %ymm1, %ymm11 --; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm14 = [20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128,20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128] --; AVX512DQ-FAST-NEXT: # ymm14 = mem[0,1,0,1] --; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm25, %ymm1 --; AVX512DQ-FAST-NEXT: vpshufb %ymm14, %ymm1, %ymm15 --; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm2 # 32-byte Reload --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm2[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm2[23],zero,ymm2[23,24,25,26],zero,ymm2[24],zero,ymm2[30,31] --; AVX512DQ-FAST-NEXT: vpshufb %ymm14, %ymm2, %ymm14 --; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Reload --; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} ymm4 = ymm3[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] --; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [4,5,4,5,5,7,4,5] --; AVX512DQ-FAST-NEXT: vpermd %ymm4, %ymm2, %ymm20 --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm2 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,28,29,26,27,28,29,30,31,30,31,28,29,28,29,30,31] --; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm4 = ymm3[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,24,25,22,23,24,25,26,27,26,27,24,25] --; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm22 # 64-byte Folded Reload --; AVX512DQ-FAST-NEXT: # zmm22 = mem[2,3,2,3,6,7,6,7] --; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm22 --; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Folded Reload -+; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm5 # 64-byte Folded Reload -+; AVX512DQ-FAST-NEXT: # zmm5 = mem[2,3,2,3,6,7,6,7] -+; AVX512DQ-FAST-NEXT: vporq %zmm2, %zmm5, %zmm22 -+; AVX512DQ-FAST-NEXT: vpternlogq $226, %zmm1, %zmm0, %zmm22 -+; AVX512DQ-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload -+; AVX512DQ-FAST-NEXT: vpunpcklbw {{[-0-9]+}}(%r{{[sb]}}p), %xmm0, %xmm0 # 16-byte Folded Reload -+; AVX512DQ-FAST-NEXT: # xmm0 = xmm0[0],mem[0],xmm0[1],mem[1],xmm0[2],mem[2],xmm0[3],mem[3],xmm0[4],mem[4],xmm0[5],mem[5],xmm0[6],mem[6],xmm0[7],mem[7] -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[u,u,4,5,0,1,u,u,u,6,7,2,3,u,u,u] -+; AVX512DQ-FAST-NEXT: vshufi64x2 $0, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm26 # 64-byte Folded Reload -+; AVX512DQ-FAST-NEXT: # zmm26 = zmm0[0,1,0,1],mem[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm4 # 16-byte Reload -+; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} xmm1 = xmm4[1,1,0,0,4,5,6,7] -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm2 = [0,1,0,1,2,0,0,1] -+; AVX512DQ-FAST-NEXT: vpermd %ymm1, %ymm2, %ymm19 -+; AVX512DQ-FAST-NEXT: vmovdqa {{[-0-9]+}}(%r{{[sb]}}p), %xmm0 # 16-byte Reload -+; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} xmm5 = xmm0[1,1,0,0,4,5,6,7] -+; AVX512DQ-FAST-NEXT: vpermd %ymm5, %ymm2, %ymm17 -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} xmm5 = [4,5,4,5,4,5,8,9,6,7,6,7,6,7,6,7] -+; 
AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm4, %xmm10 -+; AVX512DQ-FAST-NEXT: vpshufb %xmm5, %xmm0, %xmm5 -+; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm12 = [128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22,128,20,128,18,128,128,128,128,21,128,19,128,128,128,128,22] -+; AVX512DQ-FAST-NEXT: # ymm12 = mem[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm29, %ymm0 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm12, %ymm0, %ymm13 -+; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm0 # 32-byte Reload -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm14 = ymm0[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u],zero,zero,zero,zero,ymm0[25],zero,ymm0[23],zero,zero,zero,zero,ymm0[26],zero,ymm0[24],zero,zero -+; AVX512DQ-FAST-NEXT: vpshufb %ymm12, %ymm0, %ymm12 -+; AVX512DQ-FAST-NEXT: vbroadcasti128 {{.*#+}} ymm0 = [20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128,20,128,18,128,20,21,20,21,128,19,128,19,20,21,22,128] - ; AVX512DQ-FAST-NEXT: # ymm0 = mem[0,1,0,1] --; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm23 # 32-byte Folded Reload --; AVX512DQ-FAST-NEXT: # ymm23 = mem[0,1,0,1] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,1,0] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vmovdqa64 %ymm30, %ymm1 -+; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm1, %ymm2 -+; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm9 # 32-byte Reload -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm1 = ymm9[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,24,25,24,25],zero,ymm9[23],zero,ymm9[23,24,25,26],zero,ymm9[24],zero,ymm9[30,31] -+; AVX512DQ-FAST-NEXT: vpshufb %ymm0, %ymm9, %ymm0 -+; AVX512DQ-FAST-NEXT: vmovdqu {{[-0-9]+}}(%r{{[sb]}}p), %ymm4 # 32-byte Reload -+; AVX512DQ-FAST-NEXT: vpshuflw {{.*#+}} ymm15 = ymm4[2,1,1,2,4,5,6,7,10,9,9,10,12,13,14,15] -+; AVX512DQ-FAST-NEXT: vmovdqa {{.*#+}} ymm9 = [4,5,4,5,5,7,4,5] -+; AVX512DQ-FAST-NEXT: vpermd %ymm15, %ymm9, %ymm20 -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm15 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,28,29,26,27,28,29,30,31,30,31,28,29,28,29,30,31] -+; AVX512DQ-FAST-NEXT: vpshufb {{.*#+}} ymm9 = ymm4[u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u,22,23,26,27,24,25,22,23,24,25,26,27,26,27,24,25] -+; AVX512DQ-FAST-NEXT: vpermq $238, {{[-0-9]+}}(%r{{[sb]}}p), %zmm24 # 64-byte Folded Reload -+; AVX512DQ-FAST-NEXT: # zmm24 = mem[2,3,2,3,6,7,6,7] -+; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm24 -+; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm3 # 32-byte Folded Reload -+; AVX512DQ-FAST-NEXT: # ymm3 = mem[0,1,0,1] -+; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %ymm25 # 32-byte Folded Reload -+; AVX512DQ-FAST-NEXT: # ymm25 = mem[0,1,0,1] - ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm15 = ymm15[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm6 = ymm6[0,0,1,0] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm10 = ymm10[0,0,1,0] - ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm13 = ymm13[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm11 = ymm11[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm2 = ymm2[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm5 = ymm5[0,0,1,0] - ; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm14 = ymm14[2,3,2,3] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm4 = ymm4[2,3,2,3] --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm0, %zmm0 # 32-byte Folded Reload --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm23, 
%zmm23 # 32-byte Folded Reload --; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm23 --; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm0 # 64-byte Reload --; AVX512DQ-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm24 --; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm23, %zmm24 --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm19, %zmm2, %zmm0 --; AVX512DQ-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm21 --; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm21 --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm0 # 32-byte Folded Reload --; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Reload --; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm0 --; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm0 --; AVX512DQ-FAST-NEXT: vpor %ymm12, %ymm15, %ymm2 -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm12 = ymm12[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm0 = ymm0[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} ymm9 = ymm9[2,3,2,3] -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm3, %zmm3 # 32-byte Folded Reload -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm25, %zmm25 # 32-byte Folded Reload -+; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm25 -+; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm3 # 64-byte Reload -+; AVX512DQ-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm18 -+; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm25, %zmm18 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm19, %zmm15, %zmm3 -+; AVX512DQ-FAST-NEXT: vpternlogq $228, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm3, %zmm16 -+; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm7, %zmm16 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, {{[-0-9]+}}(%r{{[sb]}}p), %zmm10, %zmm3 # 32-byte Folded Reload -+; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload -+; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm4, %zmm3 -+; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm8, %zmm3 -+; AVX512DQ-FAST-NEXT: vpor %ymm2, %ymm13, %ymm2 - ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm2, %zmm0, %zmm2 --; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm7 # 64-byte Reload --; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm2 = zmm7[0,1,2,3],zmm2[4,5,6,7] --; AVX512DQ-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm16 --; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm16 -+; AVX512DQ-FAST-NEXT: vmovdqu64 {{[-0-9]+}}(%r{{[sb]}}p), %zmm4 # 64-byte Reload -+; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm2 = zmm4[0,1,2,3],zmm2[4,5,6,7] -+; AVX512DQ-FAST-NEXT: vpternlogq $248, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm23 -+; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm23 - ; AVX512DQ-FAST-NEXT: vpermq $68, {{[-0-9]+}}(%r{{[sb]}}p), %zmm2 # 64-byte Folded Reload - ; AVX512DQ-FAST-NEXT: # zmm2 = mem[0,1,0,1,4,5,4,5] --; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} zmm5 = zmm31[0,1,0,1,4,5,4,5] --; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm5 --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm6, 
%zmm17, %zmm2 --; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm18, %zmm2 --; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm5, %zmm2 --; AVX512DQ-FAST-NEXT: vpor %ymm1, %ymm13, %ymm1 --; AVX512DQ-FAST-NEXT: vpor %ymm11, %ymm14, %ymm5 -+; AVX512DQ-FAST-NEXT: vpermq {{.*#+}} zmm6 = zmm11[0,1,0,1,4,5,4,5] -+; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2, %zmm6 -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm5, %zmm17, %zmm2 -+; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm26, %zmm2 -+; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm6, %zmm2 -+; AVX512DQ-FAST-NEXT: vpor %ymm1, %ymm14, %ymm1 -+; AVX512DQ-FAST-NEXT: vpor %ymm0, %ymm12, %ymm0 - ; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm1, %zmm0, %zmm1 --; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm1 = zmm5[0,1,2,3],zmm1[4,5,6,7] --; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm4, %zmm20, %zmm4 --; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm1, %zmm4 --; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm9, %zmm4 -+; AVX512DQ-FAST-NEXT: vshufi64x2 {{.*#+}} zmm0 = zmm0[0,1,2,3],zmm1[4,5,6,7] -+; AVX512DQ-FAST-NEXT: vinserti64x4 $1, %ymm9, %zmm20, %zmm1 -+; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm0, %zmm1 -+; AVX512DQ-FAST-NEXT: vpternlogq $216, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm22, %zmm1 - ; AVX512DQ-FAST-NEXT: movq {{[0-9]+}}(%rsp), %rax --; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm4, 128(%rax) -+; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm1, 128(%rax) - ; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm2, (%rax) --; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm16, 320(%rax) --; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm0, 256(%rax) --; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm21, 192(%rax) --; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm24, 64(%rax) --; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm22, 384(%rax) --; AVX512DQ-FAST-NEXT: addq $1256, %rsp # imm = 0x4E8 -+; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm23, 320(%rax) -+; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm3, 256(%rax) -+; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm16, 192(%rax) -+; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm18, 64(%rax) -+; AVX512DQ-FAST-NEXT: vmovdqa64 %zmm24, 384(%rax) -+; AVX512DQ-FAST-NEXT: addq $1496, %rsp # imm = 0x5D8 - ; AVX512DQ-FAST-NEXT: vzeroupper - ; AVX512DQ-FAST-NEXT: retq - ; -diff -ruN --strip-trailing-cr a/llvm/test/Instrumentation/MemorySanitizer/X86/vararg_shadow.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/vararg_shadow.ll ---- a/llvm/test/Instrumentation/MemorySanitizer/X86/vararg_shadow.ll -+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/vararg_shadow.ll -@@ -1303,6 +1303,7 @@ - ; CHECK-NEXT: [[TMP64:%.*]] = xor i64 [[TMP63]], 87960930222080 - ; CHECK-NEXT: [[TMP65:%.*]] = inttoptr i64 [[TMP64]] to ptr - ; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 8 inttoptr (i64 add (i64 ptrtoint (ptr @__msan_va_arg_tls to i64), i64 688) to ptr), ptr align 8 [[TMP65]], i64 64, i1 false) -+; CHECK-NEXT: call void @llvm.memset.p0.i32(ptr align 8 inttoptr (i64 add (i64 ptrtoint (ptr @__msan_va_arg_tls to i64), i64 752) to ptr), i8 0, i32 48, i1 false) - ; CHECK-NEXT: store i64 1280, ptr @__msan_va_arg_overflow_size_tls, align 8 - ; CHECK-NEXT: call void (ptr, i32, ...) 
@_Z5test2I11LongDouble4EvT_iz(ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], i32 noundef 20, ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]], ptr noundef nonnull byval([[STRUCT_LONGDOUBLE4]]) align 16 [[ARG]]) - ; CHECK-NEXT: ret void +diff -ruN --strip-trailing-cr a/llvm/test/DebugInfo/X86/debug-names-types.ll b/llvm/test/DebugInfo/X86/debug-names-types.ll +--- a/llvm/test/DebugInfo/X86/debug-names-types.ll ++++ b/llvm/test/DebugInfo/X86/debug-names-types.ll +@@ -5,7 +5,7 @@ + ; RUN: llc -mtriple=x86_64 -generate-type-units -dwarf-version=5 -filetype=obj %s -o %t + ; RUN: llvm-dwarfdump -debug-info -debug-names %t | FileCheck %s + +-; RUN: llc -mtriple=x86_64 -generate-type-units -dwarf-version=5 -filetype=obj -split-dwarf-file=mainTypes.dwo --split-dwarf-output=mainTypes.dwo %s -o %t ++; RUN: llc -mtriple=x86_64 -generate-type-units -dwarf-version=5 -filetype=obj -split-dwarf-file=%t.mainTypes.dwo --split-dwarf-output=%t.mainTypes.dwo %s -o %t + ; RUN: llvm-readelf --sections %t | FileCheck %s --check-prefixes=CHECK-SPLIT + + ; CHECK-SPLIT-NOT: .debug_names diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 1a88322162b80d..e7665a06309746 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "506c47df00bbd9e527ecc5ac6e192b5fe5daa2c5" - LLVM_SHA256 = "5db67a5293810e6aebd2f757d660dbe2271fc86016c56cfcc56a89705dc22d80" + LLVM_COMMIT = "9bdbb8226e70fb248b40a4b5002699ee9eeeda93" + LLVM_SHA256 = "edec60ee200b44c9ed33376b277297e917f0a9695beb0ec75a5f3392544625a5" tf_http_archive( name = name, From 1bdb2dbb1473f25315c0bdf633812114e5ebbf1e Mon Sep 17 00:00:00 2001 From: Tristen Allen Date: Mon, 20 Nov 2023 12:50:45 -0800 Subject: [PATCH 303/391] Disable test_saved_model in linear_operator_inversion_test. This test is currently failing under the `xla_gpu` config - an issue exposed by the changes in 0736fdb. All this change does is disable that test temporarily until it can be fixed. This change preserves the status quo since the test was previously (silently) disabled. 
PiperOrigin-RevId: 584104512
---
 .../kernel_tests/linalg/linear_operator_inversion_test.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_inversion_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_inversion_test.py
index ca5cb6e0d1f7a9..bcb1360ea6b2eb 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_inversion_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_inversion_test.py
@@ -33,6 +33,13 @@ class LinearOperatorInversionTest(
     linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
   """Most tests done in the base class LinearOperatorDerivedClassTest."""

+  # TODO: b/311343496 - Re-enable this test.
+  @staticmethod
+  def skip_these_tests() -> list[str]:
+    return [
+        "test_saved_model",
+    ]
+
   def tearDown(self):
     config.enable_tensor_float_32_execution(self.tf32_keep_)

From 4ca42b511e2a2b30edb710b32e4b4720c08bfc9f Mon Sep 17 00:00:00 2001
From: Steven Toribio
Date: Mon, 20 Nov 2023 12:55:42 -0800
Subject: [PATCH 304/391] Update Flatbuffers Download to be compatible with Android API level >= 23

PiperOrigin-RevId: 584105785
---
 third_party/flatbuffers/workspace.bzl | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/third_party/flatbuffers/workspace.bzl b/third_party/flatbuffers/workspace.bzl
index 1aa9b2ff2d00ba..a0b943d7a9487b 100644
--- a/third_party/flatbuffers/workspace.bzl
+++ b/third_party/flatbuffers/workspace.bzl
@@ -2,12 +2,20 @@

 load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

+# _FLATBUFFERS_GIT_COMMIT / _FLATBUFFERS_SHA256 were added due to an urgent change being made to
+# Flatbuffers that needed to be updated in order for Flatbuffers/TfLite to be compatible with Android
+# API level >= 23. They can be removed with the next official Flatbuffers release / update.
+_FLATBUFFERS_GIT_COMMIT = "7d6d99c6befa635780a4e944d37ebfd58e68a108"
+
+# curl -L https://github.com/google/flatbuffers/archive/<_FLATBUFFERS_GIT_COMMIT>.tar.gz | shasum -a 256
+_FLATBUFFERS_SHA256 = "d27761f6b2fb1017ec00ed317a7b98cb7aed86b81d90528b498fb17ec13579a1"
+
 def repo():
     tf_http_archive(
         name = "flatbuffers",
-        strip_prefix = "flatbuffers-23.5.26",
-        sha256 = "1cce06b17cddd896b6d73cc047e36a254fb8df4d7ea18a46acf16c4c0cd3f3f3",
-        urls = tf_mirror_urls("https://github.com/google/flatbuffers/archive/v23.5.26.tar.gz"),
+        strip_prefix = "flatbuffers-%s" % _FLATBUFFERS_GIT_COMMIT,
+        sha256 = _FLATBUFFERS_SHA256,
+        urls = tf_mirror_urls("https://github.com/google/flatbuffers/archive/%s.tar.gz" % _FLATBUFFERS_GIT_COMMIT),
         build_file = "//third_party/flatbuffers:flatbuffers.BUILD",
         system_build_file = "//third_party/flatbuffers:BUILD.system",
         link_files = {

From 53beaa5f7771e67915d580618bec6e735e040224 Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Mon, 20 Nov 2023 13:01:58 -0800
Subject: [PATCH 305/391] [XLA] Avoid making an extra copy of constant literals when compiling to llvm IR.

Before this change the constants were first copied from HLO to LHLO MLIR, and then from LHLO MLIR into an intermediate buffer - `GpuExecutable::ConstantInfo`.

This change turns the intermediate buffer into an alias to the data stored in the HLO graph in most cases. In the case of 4-bit integers we still need to make a copy since the needed data format is different. However, 4-bit ints are very rare.

Especially with large constants, this change significantly reduces the peak memory usage.
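As a rough standalone sketch of the owned-versus-aliased buffer idea described above (simplified and using only the standard library; the class name here is invented, while the real implementation is `DenseDataIntermediate` in `ir_emission_utils.h` below, built on `absl::Span`):

    #include <cstddef>
    #include <cstdint>
    #include <utility>
    #include <variant>
    #include <vector>

    // Holds constant bytes either by value (owned) or as a view into storage
    // that outlives this object (aliased). Aliasing large literals avoids an
    // extra copy; owning is reserved for cases that must repack the data.
    class OwnedOrAliasedBytes {
     public:
      static OwnedOrAliasedBytes Own(std::vector<uint8_t> bytes) {
        OwnedOrAliasedBytes b;
        b.data_ = std::move(bytes);
        return b;
      }
      static OwnedOrAliasedBytes Alias(const uint8_t* ptr, size_t size) {
        OwnedOrAliasedBytes b;
        b.data_ = View{ptr, size};
        return b;
      }
      const uint8_t* data() const {
        return std::holds_alternative<View>(data_)
                   ? std::get<View>(data_).ptr
                   : std::get<std::vector<uint8_t>>(data_).data();
      }
      size_t size() const {
        return std::holds_alternative<View>(data_)
                   ? std::get<View>(data_).size
                   : std::get<std::vector<uint8_t>>(data_).size();
      }

     private:
      struct View {
        const uint8_t* ptr;
        size_t size;
      };
      std::variant<std::vector<uint8_t>, View> data_;
    };

With this shape, the constant-emission path can treat owned and aliased bytes uniformly; only cases that must repack, such as 4-bit integers, pay for an allocation.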
The copy of the data in LHLO MLIR remains, even though it has a shorter lifetime than the HLO data. In theory, it could be possible to also avoid the copy in the LHLO MLIR, but I have not looked into how much effort that would require (likely a lot).

PiperOrigin-RevId: 584107265
---
 third_party/xla/xla/service/gpu/BUILD | 8 +-
 .../service/gpu/compile_module_to_llvm_ir.cc | 2 +-
 .../xla/xla/service/gpu/gpu_compiler.cc | 6 +-
 .../xla/xla/service/gpu/gpu_executable.cc | 12 +-
 .../xla/xla/service/gpu/gpu_executable.h | 3 +-
 .../xla/xla/service/gpu/ir_emission_utils.cc | 113 +++---------------
 .../xla/xla/service/gpu/ir_emission_utils.h | 37 +++++-
 .../xla/service/gpu/ir_emission_utils_test.cc | 51 ++++----
 .../xla/xla/service/gpu/ir_emitter_context.cc | 13 +-
 .../xla/xla/service/gpu/ir_emitter_context.h | 3 +-
 .../xla/xla/service/gpu/ir_emitter_nested.cc | 6 +-
 .../xla/service/gpu/ir_emitter_unnested.cc | 34 +++---
 .../xla/xla/service/gpu/ir_emitter_unnested.h | 2 +-
 third_party/xla/xla/stream_executor/BUILD | 1 +
 .../xla/xla/stream_executor/cuda/BUILD | 1 +
 .../xla/stream_executor/cuda/cuda_executor.cc | 3 +-
 third_party/xla/xla/stream_executor/gpu/BUILD | 1 +
 .../xla/stream_executor/gpu/gpu_executor.h | 3 +-
 .../stream_executor/rocm/rocm_gpu_executor.cc | 2 +-
 .../stream_executor_internal.h | 3 +-
 .../stream_executor/stream_executor_pimpl.cc | 5 +-
 .../stream_executor/stream_executor_pimpl.h | 2 +-
 22 files changed, 148 insertions(+), 163 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD
index 27997a6b74ec5b..df56ca9966ac8e 100644
--- a/third_party/xla/xla/service/gpu/BUILD
+++ b/third_party/xla/xla/service/gpu/BUILD
@@ -253,10 +253,12 @@ cc_library(
     deps = [
         ":gpu_constants",
         ":gpu_executable",
+        ":ir_emission_utils",
         "//xla/service:buffer_assignment",
         "//xla/service:name_uniquer",
-        "//xla/stream_executor",
         "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/algorithm:container",
+        "@llvm-project//llvm:Support",
         "@llvm-project//llvm:ir_headers",
         "@llvm-project//mlir:IR",
     ],
@@ -297,6 +299,7 @@ cc_library(
         ":target_util",
         ":thunk",
         "//xla:autotuning_proto_cc",
+        "//xla:literal",
         "//xla:permutation_util",
         "//xla:shape_util",
         "//xla:status",
@@ -400,6 +403,7 @@ cc_library(
         ":backend_configs_cc",
         ":hlo_fusion_analysis",
         ":hlo_to_ir_bindings",
+        ":ir_emission_utils",
         ":ir_emitter_context",
         ":kernel_reuse_cache",
         ":target_util",
@@ -1027,6 +1031,7 @@ cc_library(
        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
@@ -1069,6 +1074,7 @@ cc_library(
     deps = [
         ":hlo_traversal",
         ":target_util",
+        "//xla:literal",
         "//xla:shape_util",
         "//xla:status",
         "//xla:status_macros",
diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
index 1ce7531cff042d..fb8debda20771f 100644
--- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
+++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
@@ -465,7 +465,7 @@ void RemoveUnusedAndUninitializedGlobals(
   for (const auto& info : constants) {
     // Empty content means the constant is initialized in the LLVM IR, so we
     // must not remove it.
- if (!info.content.empty()) { + if (!info.content.span().empty()) { llvm::GlobalVariable* global = llvm_module->getGlobalVariable(info.symbol_name); CHECK(global != nullptr); diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 8f78780c107da9..f83e2a7823aa24 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -286,7 +286,8 @@ class GpuAotCompilationResult : public AotCompilationResult { auto* cst_proto = xla_runtime_gpu_executable_.add_constants(); cst_proto->set_symbol_name(cst.symbol_name); cst_proto->set_allocation_index(cst.allocation_index); - cst_proto->set_content(cst.content.data(), cst.content.size()); + cst_proto->set_content(cst.content.span().data(), + cst.content.span().size()); } } @@ -332,7 +333,8 @@ StatusOr> GpuAotCompilationResult::LoadExecutable( for (auto& cst : xla_runtime_gpu_executable_.constants()) { GpuExecutable::ConstantInfo constant = { cst.symbol_name(), - {cst.content().begin(), cst.content().end()}, + DenseDataIntermediate::Own( + std::vector{cst.content().begin(), cst.content().end()}), cst.allocation_index()}; constants.push_back(std::move(constant)); } diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index 6ed3e3a171aa74..adc8571ca83d2c 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "mlir/Parser/Parser.h" // from @llvm-project @@ -337,19 +338,20 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { VLOG(3) << "Resolved global " << info.symbol_name << " to " << global.opaque(); - if (!info.content.empty()) { + if (!info.content.span().empty()) { // This means the constant did not have an initializer in the PTX and // therefore must be initialized by XLA here. - stream->ThenMemcpy(&global, info.content.data(), info.content.size()); + stream->ThenMemcpy(&global, info.content.span().data(), + info.content.span().size()); submitted_mem_copies = true; } } else { // The constant was not defined in the PTX and therefore must be both // allocated and initialized by XLA here. - CHECK(!info.content.empty()); + CHECK(!info.content.span().empty()); - TF_ASSIGN_OR_RETURN( - auto shared, executor->CreateOrShareConstant(stream, info.content)); + TF_ASSIGN_OR_RETURN(auto shared, executor->CreateOrShareConstant( + stream, info.content.span())); global = *shared; VLOG(3) << "Allocated (or shared) global " << info.symbol_name << " at " << global.opaque(); diff --git a/third_party/xla/xla/service/gpu/gpu_executable.h b/third_party/xla/xla/service/gpu/gpu_executable.h index 064b535a0cff42..228096cdc2bb55 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.h +++ b/third_party/xla/xla/service/gpu/gpu_executable.h @@ -36,6 +36,7 @@ limitations under the License. 
#include "xla/service/buffer_assignment.h" #include "xla/service/executable.h" #include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/non_atomically_upgradeable_rw_lock.h" #include "xla/service/gpu/runtime/executable.h" #include "xla/service/gpu/thunk.h" @@ -64,7 +65,7 @@ class GpuExecutable : public Executable { struct ConstantInfo { std::string symbol_name; - std::vector content; + DenseDataIntermediate content; int allocation_index = -1; }; diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index 3f94385d270e52..3fc8f0c302a481 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -64,6 +64,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/primitive_util.h" @@ -1005,104 +1006,26 @@ bool IsAMDGPU(const llvm::Module* module) { return llvm::Triple(module->getTargetTriple()).isAMDGPU(); } -namespace { -template -void CopyDenseElementsBy(mlir::DenseElementsAttr data, - std::vector* output) { - output->resize(data.getNumElements() * sizeof(T)); - int64_t i = 0; - for (T element : data.getValues()) { - std::memcpy(&(*output)[i], &element, sizeof(T)); - i += sizeof(T); +StatusOr LiteralToXlaFormat(const Literal& literal) { + PrimitiveType element_type = literal.shape().element_type(); + if (!primitive_util::IsArrayType(element_type)) { + return Internal("Unsupported type in LiteralToXlaFormat"); } -} - -template <> -void CopyDenseElementsBy(mlir::DenseElementsAttr data, - std::vector* output) { - output->resize(CeilOfRatio(data.getNumElements(), int64_t{2})); - absl::Span output_span = - absl::MakeSpan(reinterpret_cast(output->data()), output->size()); - PackInt4(data.getRawData(), output_span); -} -} // namespace - -Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data, - std::vector* output) { - mlir::Type element_type = data.getType().getElementType(); - // TODO(hinsu): Support remaining XLA primitive types. 
- if (element_type.isInteger(1)) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isInteger(4)) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isInteger(8)) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isInteger(16)) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isInteger(32)) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isInteger(64)) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isFloat8E5M2()) { - CopyDenseElementsBy(data, output); - return OkStatus(); + int64_t byte_size = literal.size_bytes(); + if (primitive_util::Is4BitType(element_type)) { + std::vector output(CeilOfRatio(byte_size, int64_t{2})); + absl::Span output_span = + absl::MakeSpan(reinterpret_cast(output.data()), output.size()); + PackInt4( + absl::MakeSpan(reinterpret_cast(literal.untyped_data()), + byte_size), + output_span); + return DenseDataIntermediate::Own(std::move(output)); } - if (element_type.isFloat8E4M3FN()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isFloat8E4M3B11FNUZ()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isFloat8E5M2FNUZ()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isFloat8E4M3FNUZ()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isBF16()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isF16()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isF32()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (element_type.isF64()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (auto complex_type = element_type.dyn_cast()) { - if (complex_type.getElementType().isF32()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - if (complex_type.getElementType().isF64()) { - CopyDenseElementsBy(data, output); - return OkStatus(); - } - } - return Internal("Unsupported type in CopyDenseElementsDataToXlaFormat"); + + return DenseDataIntermediate::Alias(absl::MakeSpan( + reinterpret_cast(literal.untyped_data()), byte_size)); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index 80e7b45501ce02..172d405086f947 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -19,16 +19,21 @@ limitations under the License. #include #include #include +#include +#include #include #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/literal.h" #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/statusor.h" namespace xla { namespace gpu { @@ -231,8 +236,36 @@ std::string GetIrNameFromLoc(mlir::Location loc); // Whether the module's target is an AMD GPU. bool IsAMDGPU(const llvm::Module* module); -Status CopyDenseElementsDataToXlaFormat(mlir::DenseElementsAttr data, - std::vector* output); +// This class stores either a non-owning reference or owns data that represents +// a dense array in XLA format. 
It is used for intermediate storage during IR +// constant emission. +class DenseDataIntermediate { + public: + // Creates an instance of DenseDataIntermediate that owns the provided vector. + static DenseDataIntermediate Own(std::vector owned) { + DenseDataIntermediate di; + di.data_ = std::move(owned); + return di; + } + + // Creates an instance of DenseDataIntermediate that aliases the input. + static DenseDataIntermediate Alias(absl::Span aliased) { + DenseDataIntermediate di; + di.data_ = aliased; + return di; + } + + // Returns a reference to the data this object represents. + absl::Span span() const { + return data_.index() == 0 ? absl::Span(std::get<0>(data_)) + : std::get<1>(data_); + } + + private: + std::variant, absl::Span> data_; +}; + +StatusOr LiteralToXlaFormat(const Literal& literal); } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc index 4aa4a539e858d9..21a05ae26a757d 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc @@ -478,35 +478,40 @@ ENTRY entry { } TEST_F(IrEmissionUtilsTest, LiteralToAttrToXlaFormat) { - mlir::MLIRContext context; - context.loadDialect(); - mlir::Builder builder(&context); - - // int16 + // int16, should be aliased. { - Literal x = LiteralUtil::CreateR2({{0, 1, 2}, {3, 4, 5}}); - TF_ASSERT_OK_AND_ASSIGN(mlir::DenseElementsAttr attr, - CreateDenseElementsAttrFromLiteral(x, builder)); - - std::vector data; - TF_ASSERT_OK(CopyDenseElementsDataToXlaFormat(attr, &data)); - for (int i = 0; i < 6; i++) { - int16_t x; - memcpy(&x, &data[i * 2], 2); - EXPECT_EQ(x, i); - } + Literal literal = LiteralUtil::CreateR2({{0, 1, 2}, {3, 4, 5}}); + + TF_ASSERT_OK_AND_ASSIGN(DenseDataIntermediate data, + LiteralToXlaFormat(literal)); + EXPECT_EQ(data.span().size(), literal.size_bytes()); + EXPECT_EQ(reinterpret_cast(data.span().data()), + literal.untyped_data()); } - // int4 + // int4, even, should be a new (unaliased) packed array. { - Literal x = LiteralUtil::CreateR2( + Literal literal = LiteralUtil::CreateR2( {{s4(0), s4(1), s4(2)}, {s4(3), s4(4), s4(5)}}); - TF_ASSERT_OK_AND_ASSIGN(mlir::DenseElementsAttr attr, - CreateDenseElementsAttrFromLiteral(x, builder)); - std::vector data; - TF_ASSERT_OK(CopyDenseElementsDataToXlaFormat(attr, &data)); - EXPECT_EQ(data, std::vector({0x01, 0x23, 0x45})); + TF_ASSERT_OK_AND_ASSIGN(DenseDataIntermediate data, + LiteralToXlaFormat(literal)); + EXPECT_EQ(data.span(), std::vector({0x01, 0x23, 0x45})); + EXPECT_NE(reinterpret_cast(data.span().data()), + literal.untyped_data()); + } + + // int4, odd, should be a new (unaliased) packed array. + { + Literal literal = LiteralUtil::CreateR2( + {{u4(0), u4(1), u4(2)}, {u4(3), u4(4), u4(5)}, {u4(6), u4(7), u4(8)}}); + + TF_ASSERT_OK_AND_ASSIGN(DenseDataIntermediate data, + LiteralToXlaFormat(literal)); + EXPECT_EQ(data.span(), + std::vector({0x01, 0x23, 0x45, 0x67, 0x80})); + EXPECT_NE(reinterpret_cast(data.span().data()), + literal.untyped_data()); } } diff --git a/third_party/xla/xla/service/gpu/ir_emitter_context.cc b/third_party/xla/xla/service/gpu/ir_emitter_context.cc index af3a8530b5426c..9af81cb90fd454 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_context.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_context.cc @@ -20,7 +20,10 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" +#include "llvm/ADT/ArrayRef.h" #include "xla/service/gpu/gpu_constants.h" +#include "xla/service/gpu/ir_emission_utils.h" namespace xla { namespace gpu { @@ -29,7 +32,7 @@ void IrEmitterContext::emit_constant(int64_t num_elements, int64_t bytes_per_element, absl::string_view symbol_name, int allocation_idx, - llvm::ArrayRef content, + DenseDataIntermediate content, llvm::IRBuilder<>* b) { // LLVM and PTXAS don't deal well with large constants, so we only emit very // small constants directly in LLVM IR. Larger constants are emitted with @@ -50,15 +53,17 @@ void IrEmitterContext::emit_constant(int64_t num_elements, GpuExecutable::ConstantInfo info; llvm::Constant* initializer = [&]() -> llvm::Constant* { if (!should_emit_initializer) { - info.content = content; + info.content = std::move(content); return llvm::ConstantAggregateZero::get(global_type); } std::vector padded(kMinConstAllocationInBytes, 0); - absl::c_copy(content, padded.begin()); + absl::c_copy(content.span(), padded.begin()); return llvm::ConstantDataArray::get( llvm_module_->getContext(), - needs_padding ? llvm::ArrayRef(padded) : content); + needs_padding ? llvm::ArrayRef(padded) + : llvm::ArrayRef(content.span().data(), + content.span().size())); }(); // These globals will be looked up by name by GpuExecutable so we need to diff --git a/third_party/xla/xla/service/gpu/ir_emitter_context.h b/third_party/xla/xla/service/gpu/ir_emitter_context.h index 3e9787a645b569..faa4379bc8e7d7 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_context.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_context.h @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/gpu_executable.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/name_uniquer.h" #include "xla/stream_executor/device_description.h" @@ -91,7 +92,7 @@ class IrEmitterContext { // element, given symbol name and content. void emit_constant(int64_t num_elements, int64_t bytes_per_element, absl::string_view symbol_name, int allocation_idx, - llvm::ArrayRef content, llvm::IRBuilder<>* b); + DenseDataIntermediate content, llvm::IRBuilder<>* b); const DebugOptions& debug_options() const { return hlo_module_->config().debug_options(); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_nested.cc b/third_party/xla/xla/service/gpu/ir_emitter_nested.cc index f040ae279340f5..b531a5655fb064 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_nested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_nested.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -25,6 +26,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/kernel_reuse_cache.h" @@ -252,7 +254,9 @@ Status IrEmitterNested::EmitConstants(const HloComputation& computation) { global_name, /*allocation_idx=*/-1, - llvm::ArrayRef(base, base + literal.size_bytes()), &b_); + DenseDataIntermediate::Alias( + absl::MakeSpan(base, base + literal.size_bytes())), + &b_); } return OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 489a367a0c7405..d2208d57df7c2d 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -83,6 +83,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -405,27 +406,19 @@ Status IrEmitterUnnested::EmitUnreachable(mlir::Operation* op, return OkStatus(); } -Status IrEmitterUnnested::EmitConstant(mlir::Operation* op) { +Status IrEmitterUnnested::EmitConstant(mlir::Operation* op, + const Literal& literal) { auto get_global = mlir::cast(op); auto module = get_global->getParentOfType(); auto global = mlir::cast( module.lookupSymbol(get_global.getName())); - auto literal = global.getInitialValue()->dyn_cast(); - TF_RET_CHECK(literal); - std::vector content; - TF_RETURN_IF_ERROR(CopyDenseElementsDataToXlaFormat(literal, &content)); - int num_elements, element_bytes; - if (literal.getType().getElementType().isInteger(4)) { - // Treat int4 constant as int8 constant with half the number of elements - TF_RET_CHECK(content.size() == - (literal.getType().getNumElements() + 1) / 2); - num_elements = content.size(); - element_bytes = 1; - } else { - num_elements = literal.getType().getNumElements(); - TF_ASSIGN_OR_RETURN( - element_bytes, GetElementTypeBytes(literal.getType().getElementType())); - } + TF_ASSIGN_OR_RETURN(DenseDataIntermediate content, + LiteralToXlaFormat(literal)); + + int element_bytes = primitive_util::ByteWidth(literal.shape().element_type()); + TF_RET_CHECK(content.span().size() % element_bytes == 0); + // Treat int4 constant as int8 constant with half the number of elements. 
+ int num_elements = content.span().size() / element_bytes; int64_t arg_index = global->getAttrOfType("lmhlo.alloc").getInt(); @@ -433,7 +426,7 @@ Status IrEmitterUnnested::EmitConstant(mlir::Operation* op) { ir_emitter_context_->emit_constant(num_elements, element_bytes, global.getSymName(), allocation_index, - content, &b_); + std::move(content), &b_); return OkStatus(); } @@ -3375,7 +3368,10 @@ Status IrEmitterUnnested::EmitOp( } if (mlir::isa(op)) { - return EmitConstant(op); + const HloConstantInstruction* hlo_const_instr = + DynCast(hlo_for_lmhlo.at(op)); + TF_RET_CHECK(hlo_const_instr); + return EmitConstant(op, hlo_const_instr->literal()); } if (auto call = mlir::dyn_cast(op)) { diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h index 484910eef12de2..b13764fd06dcc2 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h @@ -127,7 +127,7 @@ class IrEmitterUnnested : public IrEmitter { // IrEmitterUnnested handles the following instructions differently from // IrEmitter. It also mixes in some special handling for custom kernels // via the ThunkEmitter. - Status EmitConstant(mlir::Operation* op); + Status EmitConstant(mlir::Operation* op, const Literal& literal); Status EmitConditional( mlir::Operation* op, diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 9406c210b646d0..67bf9a3d8db120 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -370,6 +370,7 @@ cc_library( "//xla/stream_executor/platform", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index b719a603d089af..bdf4aa6c6c95bc 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -555,6 +555,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@com_google_absl//absl/strings:str_format", "//xla/stream_executor:command_buffer", "//xla/stream_executor:kernel", diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 4a9bbaa92e9625..08656ad60119ba 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/cuda/cuda_diagnostics.h" #include "xla/stream_executor/cuda/cuda_driver.h" @@ -346,7 +347,7 @@ int fpus_per_core(int cc_major, int cc_minor) { tsl::StatusOr> GpuExecutor::CreateOrShareConstant(Stream* stream, - const std::vector& content) { + absl::Span content) { absl::MutexLock lock{&shared_constants_mu_}; // We assume all constants are uniquely identified by this hash. 
In the // (highly unlikely) event of a hash collision, the program will likely crash diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 1ea30b06ef827f..9f8b083d150d93 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -160,6 +160,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:fingerprint", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index 35a261fda6ace7..536356e1f7eec6 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -34,6 +34,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/gpu_kernel.h" @@ -113,7 +114,7 @@ class GpuExecutor : public internal::StreamExecutorInterface { // content. Or, if a device with identical content is already on-device, // returns a pointer to that buffer with shared ownership. tsl::StatusOr> CreateOrShareConstant( - Stream* stream, const std::vector& content) override; + Stream* stream, absl::Span content) override; tsl::Status Launch(Stream* stream, const ThreadDim& thread_dims, const BlockDim& block_dims, const Kernel& k, diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc index bd596857070335..296b2f7edf12d5 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_gpu_executor.cc @@ -121,7 +121,7 @@ int fpus_per_core(std::string gcn_arch_name) { tsl::StatusOr> GpuExecutor::CreateOrShareConstant(Stream* stream, - const std::vector& content) { + absl::Span content) { return tsl::errors::Unimplemented("Not implemented for ROCm"); } diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index 1928d1afc913bf..3c98029d3dbb7a 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -31,6 +31,7 @@ limitations under the License. 
#include "absl/functional/any_invocable.h" #include "absl/status/status.h" +#include "absl/types/span.h" #include "xla/stream_executor/allocator_stats.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" @@ -241,7 +242,7 @@ class StreamExecutorInterface { return absl::UnimplementedError("Not Implemented"); } virtual tsl::StatusOr> - CreateOrShareConstant(Stream* stream, const std::vector& content) { + CreateOrShareConstant(Stream* stream, absl::Span content) { return absl::UnimplementedError("Not Implemented"); } virtual tsl::Status Launch(Stream* stream, const ThreadDim& thread_dims, diff --git a/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc b/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc index 83bfbcfd8e4cc8..a58beb89c6bab2 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc +++ b/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/synchronization/notification.h" +#include "absl/types/span.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/fft.h" @@ -192,8 +193,8 @@ bool StreamExecutor::UnloadModule(ModuleHandle module_handle) { tsl::StatusOr> StreamExecutor::CreateOrShareConstant(Stream* stream, - const std::vector& content) { - return implementation_->CreateOrShareConstant(stream, std::move(content)); + absl::Span content) { + return implementation_->CreateOrShareConstant(stream, content); } void StreamExecutor::Deallocate(DeviceMemoryBase* mem) { diff --git a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h index a7f16930aa8b58..d7235077d44f43 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_pimpl.h +++ b/third_party/xla/xla/stream_executor/stream_executor_pimpl.h @@ -128,7 +128,7 @@ class StreamExecutor { bool UnloadModule(ModuleHandle module_handle); tsl::StatusOr> CreateOrShareConstant( - Stream* stream, const std::vector& content); + Stream* stream, absl::Span content); // Synchronously allocates an array on the device of type T with element_count // elements. From d0164e424b8df8b2573b7c4f86e2ae23118f679c Mon Sep 17 00:00:00 2001 From: Tristen Allen Date: Mon, 20 Nov 2023 13:10:30 -0800 Subject: [PATCH 306/391] Increase shard count on tests flaking due to timeout. 0736fdb in effect re-enabled tests in these targets that had been disabled for a while. This is causing these tests to time out inconsistently and therefore flake. This change increases the shard count for these tests to avoid such a timeout. 
PiperOrigin-RevId: 584109682 --- tensorflow/python/kernel_tests/linalg/BUILD | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index f695417f3fbca3..2298216f4aaed9 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -412,7 +412,7 @@ cuda_py_strict_test( name = "linear_operator_low_rank_update_test", size = "medium", srcs = ["linear_operator_low_rank_update_test.py"], - shard_count = 10, + shard_count = 15, tags = ["optonly"], deps = [ "//tensorflow/python/framework:config", @@ -516,7 +516,7 @@ cuda_py_strict_test( name = "linear_operator_tridiag_test", size = "medium", srcs = ["linear_operator_tridiag_test.py"], - shard_count = 5, + shard_count = 10, tags = [ "no_windows_gpu", "optonly", From c52fd22c711650dd6905e00c35066042ec9a10b7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 20 Nov 2023 13:12:02 -0800 Subject: [PATCH 307/391] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/4416bda9204dadf7d280e9993a0f7b876bad1f5d. PiperOrigin-RevId: 584110186 --- third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 617614dba84506..af94f99d39a697 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "383e601d6140ee499349c4bb53085eb4a891f500" - TFRT_SHA256 = "edacc0434ee28a2203f6699b0cac3feed1bf2fe8b011f59e5bcfcb48a74e4bcb" + TFRT_COMMIT = "4416bda9204dadf7d280e9993a0f7b876bad1f5d" + TFRT_SHA256 = "13a1fff504670708ba6d83aeeced41aa988f67bcaadc36de08040c35824b8624" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 617614dba84506..af94f99d39a697 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "383e601d6140ee499349c4bb53085eb4a891f500" - TFRT_SHA256 = "edacc0434ee28a2203f6699b0cac3feed1bf2fe8b011f59e5bcfcb48a74e4bcb" + TFRT_COMMIT = "4416bda9204dadf7d280e9993a0f7b876bad1f5d" + TFRT_SHA256 = "13a1fff504670708ba6d83aeeced41aa988f67bcaadc36de08040c35824b8624" tf_http_archive( name = "tf_runtime", From c729393846af0a10b4c0cf0df088f37b8c44ac6f Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 20 Nov 2023 13:12:08 -0800 Subject: [PATCH 308/391] Add a default constructor for ffi::Span std::span has one, see https://en.cppreference.com/w/cpp/container/span/span. 
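As a minimal illustration of the semantics the new constructor mirrors, written against C++20 `std::span` rather than `xla::ffi::Span` (the empty-view behaviour assumed here is the one described on the cppreference page linked above):

    #include <cassert>
    #include <span>
    #include <vector>

    int main() {
      // A default-constructed span is an empty view: data() == nullptr, size() == 0.
      std::span<const float> view;
      assert(view.data() == nullptr);
      assert(view.empty());

      // It can later be rebound once real storage exists, which is the usual
      // reason to want default construction for members and locals.
      std::vector<float> storage = {1.0f, 2.0f, 3.0f};
      view = std::span<const float>(storage);
      assert(view.size() == 3);
      return 0;
    }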
PiperOrigin-RevId: 584110227 --- third_party/xla/xla/ffi/api/ffi.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index 3c0f3f5f50f1dc..005827dcb4f24b 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -61,6 +61,8 @@ enum class DataType : uint8_t { template class Span { public: + constexpr Span() : data_(nullptr), size_(0) {} + Span(T* data, size_t size) : data_(data), size_(size) {} Span(const std::vector>& vec) // NOLINT : Span(vec.data(), vec.size()) {} From 9a31876fa426ee12a3d6306e937ba59e7ab9b33a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 20 Nov 2023 13:13:56 -0800 Subject: [PATCH 309/391] Skip constant op if the float value is NaN or INF. PiperOrigin-RevId: 584110640 --- .../lite/quantization/quantization_driver.cc | 11 +++++++ .../stablehlo/tests/prepare_quantize.mlir | 33 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 57a5c93556c4cf..62c2733d2b510c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -24,10 +25,12 @@ limitations under the License. #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -777,6 +780,14 @@ void QuantizationDriver::PreprocessConstantOps() { auto type = cst.getType().dyn_cast(); if (!type || !type.getElementType().isa()) return; + // Skip if the value is NaN or INF. + // Otherwise the illegal scale/zp will be calculated. 
+ auto float_attr = cst.getValueAttr().dyn_cast(); + if (float_attr) { + auto cst_float_falue = float_attr.getValues()[0]; + if (!cst_float_falue.isFinite()) return; + } + Value value = cst.getResult(); builder_.setInsertionPoint(cst); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir index 8f38f889f28e33..a873f30a20cff8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_quantize.mlir @@ -105,3 +105,36 @@ func.func @merge_consecutive_qcast(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, % %6 = "quantfork.stats"(%5) {layerStats = dense<[-1.5726943, 4.6875381]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> func.return %3, %6 : tensor<*xf32>, tensor<*xf32> } + +// ----- + +// CHECK-LABEL: func @skip_nan_inf_constant +// CHECK-SAME: (%[[ARG_0:.*]]: tensor) -> tensor +func.func @skip_nan_inf_constant(%arg0: tensor) -> tensor { + // CHECK: %[[cst0:.*]] = stablehlo.constant + // CHECK: %[[cst1:.*]] = stablehlo.constant + // CHECK: %[[cst2:.*]] = stablehlo.constant + // CHECK: %[[cst3:.*]] = stablehlo.constant + // CHECK-NOT: %[[q0:.*]] = "quantfork.qcast"(%[[cst0]]) + // CHECK-NOT: %[[q1:.*]] = "quantfork.qcast"(%[[cst1]]) + // CHECK: %[[q2:.*]] = "quantfork.qcast"(%[[cst2]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq2:.*]] = "quantfork.dcast"(%[[q2]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[q3:.*]] = "quantfork.qcast"(%[[cst3]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq3:.*]] = "quantfork.dcast"(%[[q3]]) + // CHECK-SAME: quant.uniform + %0 = stablehlo.constant dense<0xFF800000> : tensor + %1 = stablehlo.constant dense<0x7FC00000> : tensor + %2 = stablehlo.constant dense<6.000000e+00> : tensor + %3 = stablehlo.constant dense<0.000000e+00> : tensor + %4 = "stablehlo.add"(%0, %1) : (tensor, tensor) -> tensor + %5 = stablehlo.clamp %3, %arg0, %2 : (tensor, tensor, tensor) -> tensor + %6 = "stablehlo.reduce_window"(%5, %4) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %7 : tensor + }) {padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor, tensor) -> tensor + return %6 : tensor +} From c2e2e1c152dbe4273712c19049c6d57dd75cf2df Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 20 Nov 2023 13:16:41 -0800 Subject: [PATCH 310/391] [xla:gpu:runtime] Support `GetObjFile` for non-`JitExecutable`. 
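The accessor becomes const, and the header keeps a mutable overload that simply delegates to it through a const_cast (see GpuRuntimeExecutable::executable() in the diff below). A generic sketch of that delegation idiom, with placeholder names, assuming the const overload does the real work:

  class Widget {};

  class Holder {
   public:
    const Widget& widget() const { return widget_; }  // real lookup lives here
    Widget& widget() {
      // Reuse the const implementation instead of duplicating it.
      return const_cast<Widget&>(static_cast<const Holder&>(*this).widget());
    }

   private:
    Widget widget_;
  };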
PiperOrigin-RevId: 584111333 --- third_party/xla/xla/service/gpu/runtime/executable.cc | 7 ++----- third_party/xla/xla/service/gpu/runtime/executable.h | 6 +++++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/executable.cc b/third_party/xla/xla/service/gpu/runtime/executable.cc index 823c94f6809c68..50e3c5fc44404d 100644 --- a/third_party/xla/xla/service/gpu/runtime/executable.cc +++ b/third_party/xla/xla/service/gpu/runtime/executable.cc @@ -494,7 +494,7 @@ Status GpuRuntimeExecutable::Execute( //===----------------------------------------------------------------------===// -Executable& GpuRuntimeExecutable::executable() { +const Executable& GpuRuntimeExecutable::executable() const { if (auto* jit = std::get_if>(&executable_)) { return *(*jit)->DefaultExecutable(); } @@ -502,10 +502,7 @@ Executable& GpuRuntimeExecutable::executable() { } StatusOr GpuRuntimeExecutable::GetObjFile() const { - const auto* jit = std::get_if>(&executable_); - if (!jit) return InternalError("ObjFile is not available"); - - if (auto obj_file = (*jit)->DefaultExecutable()->obj_file()) + if (auto obj_file = executable().obj_file()) return std::string_view(obj_file->getBuffer()); return InternalError("gpu runtime executable didn't save the obj file"); diff --git a/third_party/xla/xla/service/gpu/runtime/executable.h b/third_party/xla/xla/service/gpu/runtime/executable.h index e3951647fedf44..fbc95fe960ac45 100644 --- a/third_party/xla/xla/service/gpu/runtime/executable.h +++ b/third_party/xla/xla/service/gpu/runtime/executable.h @@ -139,7 +139,11 @@ class GpuRuntimeExecutable { // Depending on the state of `executable_` returns a reference to active // Xla runtime executable. - runtime::Executable& executable(); + runtime::Executable& executable() { + return const_cast( + const_cast(this)->executable()); + } + const runtime::Executable& executable() const; std::vector buffer_sizes_; From 1cfc5badc4ec119b4cc2b5cfe8fa8b3c104256a8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 20 Nov 2023 13:20:37 -0800 Subject: [PATCH 311/391] [Triton] Fix a dominance violation bug in `TritonGPURemoveLayoutConversionsPass`. 
openai/triton PR: https://github.com/openai/triton/pull/2659 PiperOrigin-RevId: 584112220 --- third_party/triton/cl582925648.patch | 24 +++++++++++++++++++ third_party/triton/workspace.bzl | 1 + .../xla/third_party/triton/cl582925648.patch | 24 +++++++++++++++++++ .../xla/third_party/triton/workspace.bzl | 1 + 4 files changed, 50 insertions(+) create mode 100644 third_party/triton/cl582925648.patch create mode 100644 third_party/xla/third_party/triton/cl582925648.patch diff --git a/third_party/triton/cl582925648.patch b/third_party/triton/cl582925648.patch new file mode 100644 index 00000000000000..9d86a1e9a21d32 --- /dev/null +++ b/third_party/triton/cl582925648.patch @@ -0,0 +1,24 @@ +diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +@@ -787,7 +787,7 @@ static void rewriteSlice(SetVector(op)) { + auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); + for (Value operand : yieldOp.getOperands()) { +diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir +--- a/test/TritonGPU/combine.mlir ++++ b/test/TritonGPU/combine.mlir +@@ -53,7 +53,7 @@ tt.func @remat(%arg0: i32) -> tensor<102 + // CHECK: %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]> + // CHECK: %4 = arith.muli %0, %2 : tensor<1024xi32, [[$target_layout]]> + // CHECK: %5 = arith.muli %1, %3 : tensor<1024xi32, [[$target_layout]]> +- // CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[$target_layout]]> ++ // CHECK: %6 = arith.addi %5, %4 : tensor<1024xi32, [[$target_layout]]> + // CHECK: tt.return %6 : tensor<1024xi32, [[$target_layout]]> + } + diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index bd4c52b02d48a5..e981a633d29614 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -18,5 +18,6 @@ def repo(): "//third_party/triton:b311157761.patch", "//third_party/triton:cl568176943.patch", "//third_party/triton:b304456327.patch", + "//third_party/triton:cl582925648.patch", ], ) diff --git a/third_party/xla/third_party/triton/cl582925648.patch b/third_party/xla/third_party/triton/cl582925648.patch new file mode 100644 index 00000000000000..9d86a1e9a21d32 --- /dev/null +++ b/third_party/xla/third_party/triton/cl582925648.patch @@ -0,0 +1,24 @@ +diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +@@ -787,7 +787,7 @@ static void rewriteSlice(SetVector(op)) { + auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); + for (Value operand : yieldOp.getOperands()) { +diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir +--- a/test/TritonGPU/combine.mlir ++++ b/test/TritonGPU/combine.mlir +@@ -53,7 +53,7 @@ tt.func @remat(%arg0: i32) -> tensor<102 + // CHECK: %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]> + // CHECK: %4 = arith.muli %0, %2 : tensor<1024xi32, [[$target_layout]]> + // CHECK: %5 = arith.muli %1, %3 : tensor<1024xi32, [[$target_layout]]> +- // CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[$target_layout]]> ++ // CHECK: %6 = arith.addi %5, %4 : tensor<1024xi32, [[$target_layout]]> + // 
CHECK: tt.return %6 : tensor<1024xi32, [[$target_layout]]> + } + diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index bd4c52b02d48a5..e981a633d29614 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -18,5 +18,6 @@ def repo(): "//third_party/triton:b311157761.patch", "//third_party/triton:cl568176943.patch", "//third_party/triton:b304456327.patch", + "//third_party/triton:cl582925648.patch", ], ) From 1a6ed44670279ff61e9d984067127acf5704344f Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 20 Nov 2023 14:10:03 -0800 Subject: [PATCH 312/391] [stream_executor] Add support for conditional If command #6973 PiperOrigin-RevId: 584124746 --- .../xla/xla/stream_executor/command_buffer.cc | 19 +++++ .../xla/xla/stream_executor/command_buffer.h | 33 ++++++-- .../xla/xla/stream_executor/cuda/BUILD | 8 ++ .../cuda/cuda_command_buffer_test.cc | 71 ++++++++++++++++ .../cuda/cuda_conditional_kernels.cu.cc | 46 +++++++++++ .../xla/stream_executor/cuda/cuda_driver.cc | 82 ++++++++++++++++--- .../xla/stream_executor/cuda/cuda_executor.cc | 10 +++ .../cuda/cuda_test_kernels.cu.cc | 9 ++ .../stream_executor/cuda/cuda_test_kernels.h | 3 + third_party/xla/xla/stream_executor/gpu/BUILD | 6 +- .../stream_executor/gpu/gpu_command_buffer.cc | 72 +++++++++++++++- .../stream_executor/gpu/gpu_command_buffer.h | 41 +++++++++- .../xla/xla/stream_executor/gpu/gpu_driver.h | 38 +++++++++ .../xla/stream_executor/gpu/gpu_executor.h | 7 ++ .../xla/xla/stream_executor/gpu/gpu_types.h | 10 +++ .../stream_executor_internal.h | 9 ++ 16 files changed, 438 insertions(+), 26 deletions(-) create mode 100644 third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc diff --git a/third_party/xla/xla/stream_executor/command_buffer.cc b/third_party/xla/xla/stream_executor/command_buffer.cc index 9eeb5dad43be6b..f04889419bd350 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.cc +++ b/third_party/xla/xla/stream_executor/command_buffer.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" @@ -70,6 +71,20 @@ CommandBuffer& CommandBuffer::operator=(CommandBuffer&&) = default; return command_buffer; } +const internal::CommandBufferInterface* CommandBuffer::implementation() const { + return implementation_.get(); +} + +internal::CommandBufferInterface* CommandBuffer::implementation() { + return implementation_.get(); +} + +/*static*/ CommandBuffer CommandBuffer::Wrap( + StreamExecutor* executor, + std::unique_ptr implementation) { + return CommandBuffer(executor, std::move(implementation)); +} + CommandBuffer::CommandBuffer( StreamExecutor* executor, std::unique_ptr implementation) @@ -91,6 +106,10 @@ tsl::Status CommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, return implementation_->MemcpyDeviceToDevice(dst, src, size); } +tsl::Status CommandBuffer::If(DeviceMemory pred, Builder then_builder) { + return implementation_->If(executor_, pred, std::move(then_builder)); +} + CommandBuffer::Mode CommandBuffer::mode() const { return implementation_->mode(); } diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index dade330b666579..5a28b409646395 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -17,10 +17,11 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_COMMAND_BUFFER_H_ #include +#include #include -#include #include "absl/functional/any_invocable.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "tsl/platform/errors.h" @@ -48,6 +49,9 @@ class CommandBufferInterface; // device. class CommandBuffer { public: + // Builder constructs nested command buffers owned by a parent command buffer. + using Builder = std::function; + ~CommandBuffer(); CommandBuffer(CommandBuffer&&); CommandBuffer& operator=(CommandBuffer&&); @@ -109,6 +113,12 @@ class CommandBuffer { tsl::Status MemcpyDeviceToDevice(DeviceMemoryBase* dst, const DeviceMemoryBase& src, uint64_t size); + // Adds a conditional operation that will execute a command buffer constructed + // by `then_builder` if predicate is true. Builder should not call `Update` or + // `Finalize` on command buffer argument, parent command buffer is responsible + // for updating and finalizing conditional command buffers. + tsl::Status If(DeviceMemory pred, Builder then_builder); + // Finalizes command buffer and makes it executable. Once command buffer is // finalized no commands can be added to it. tsl::Status Finalize(); @@ -130,15 +140,22 @@ class CommandBuffer { // Returns command buffer state. 
State state() const; - internal::CommandBufferInterface* implementation() { - return implementation_.get(); - } - StreamExecutor* executor() const { return executor_; } - const internal::CommandBufferInterface* implementation() const { - return implementation_.get(); - } + //===--------------------------------------------------------------------===// + // Semi-internal APIs + //===--------------------------------------------------------------------===// + + // Following APIs are public, but considered to be implementation detail and + // discouraged from uses outside of StreamExecutor package. + const internal::CommandBufferInterface* implementation() const; + internal::CommandBufferInterface* implementation(); + + // Wraps platform-specific command buffer implementation into a top-level + // StreamExecutor command buffer. + static CommandBuffer Wrap( + StreamExecutor* executor, + std::unique_ptr implementation); private: CommandBuffer( diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index bdf4aa6c6c95bc..92fbe7c4616e68 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -443,6 +443,13 @@ cuda_library( deps = ["@local_config_cuda//cuda:cuda_headers"], ) +cuda_library( + name = "cuda_conditional_kernels", + srcs = if_cuda_is_configured(["cuda_conditional_kernels.cu.cc"]), + visibility = ["//visibility:public"], + deps = ["@local_config_cuda//cuda:cuda_headers"], +) + xla_test( name = "cuda_kernel_test", srcs = if_cuda_is_configured(["cuda_kernel_test.cc"]), @@ -471,6 +478,7 @@ xla_test( "//xla/stream_executor:multi_platform_manager", "//xla/stream_executor:platform", "@com_google_absl//absl/log:check", + "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index 7f37980caa7f88..b6a142102e5bf9 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/log/check.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/cuda/cuda_test_kernels.h" #include "xla/stream_executor/kernel.h" @@ -216,6 +217,76 @@ TEST(CudaCommandBufferTest, LaunchNestedCommandBuffer) { ASSERT_EQ(dst, expected); } +TEST(CudaCommandBufferTest, ConditionalIf) { +#if CUDA_VERSION < 12030 + GTEST_SKIP() << "CUDA graph conditionals are not supported"; +#endif + + Platform* platform = MultiPlatformManager::PlatformWithName("CUDA").value(); + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + AddI32Kernel add(executor); + + { // Load addition kernel. 
+ MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32CudaKernel(), "add"); + TF_ASSERT_OK(executor->GetKernel(spec, &add)); + } + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=2, c=0, pred=true + DeviceMemory pred = executor->AllocateArray(1, 0); + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + DeviceMemory c = executor->AllocateArray(length, 0); + + constexpr bool kTrue = true; + stream.ThenMemcpy(&pred, &kTrue, 1); + stream.ThenMemset32(&a, 1, byte_length); + stream.ThenMemset32(&b, 2, byte_length); + stream.ThenMemZero(&c, byte_length); + + // if (pred == true) c = a + b + CommandBuffer::Builder then_builder = [&](CommandBuffer* then_cmd) { + return then_cmd->Launch(add, ThreadDim(), BlockDim(4), a, b, c); + }; + + // Create a command buffer with a single conditional operation. + auto cmd_buffer = CommandBuffer::Create(executor).value(); + TF_ASSERT_OK(cmd_buffer.If(pred, then_builder)); + TF_ASSERT_OK(cmd_buffer.Finalize()); + + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + + // Copy `c` data back to host. + std::vector dst(4, 42); + stream.ThenMemcpy(dst.data(), c, byte_length); + + std::vector expected = {3, 3, 3, 3}; + ASSERT_EQ(dst, expected); + + // Reset predicate to false and clear output buffer. + constexpr bool kFalse = false; + stream.ThenMemcpy(&pred, &kFalse, 1); + stream.ThenMemZero(&c, byte_length); + + // Submit the same command buffer, but this time it should not execute + // conditional branch as conditional handle should be updated to false. + TF_ASSERT_OK(executor->Submit(&stream, cmd_buffer)); + + stream.ThenMemcpy(dst.data(), c, byte_length); + std::vector zeroes = {0, 0, 0, 0}; + ASSERT_EQ(dst, zeroes); + + // TODO(ezhulenev): Test conditional command buffer updates. +} + //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc new file mode 100644 index 00000000000000..52cebe09965b38 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc @@ -0,0 +1,46 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "third_party/gpus/cuda/include/cuda.h" + +namespace stream_executor { +namespace cuda { +namespace { + +#if CUDA_VERSION >= 12030 + +__global__ void SetCondition(cudaGraphConditionalHandle handle, + bool* predicate) { + if (*predicate) { + cudaGraphSetConditional(handle, 1); + } else { + cudaGraphSetConditional(handle, 0); + } +} + +#else +__global__ void SetCondition() {} +#endif // CUDA_VERSION >= 12030 + +} // namespace +} // namespace cuda + +namespace gpu { +void* GetSetConditionKernel() { + return reinterpret_cast(&cuda::SetCondition); +} +} // namespace gpu + +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index f1387689e0abe8..3a32a8f8df1801 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/base/casts.h" @@ -59,17 +60,17 @@ static constexpr bool FLAGS_gpuexec_cuda_driver_inject_init_error = false; static constexpr bool FLAGS_gpuexec_cuda_sync_around_driver_calls = false; static constexpr bool FLAGS_gpuexec_cuda_device_0_only = false; -#define RETURN_IF_CUDA_RES_ERROR(expr, ...) \ - do { \ - CUresult _res = (expr); \ - if (ABSL_PREDICT_FALSE(_res != CUDA_SUCCESS)) { \ - if (_res == CUDA_ERROR_OUT_OF_MEMORY) \ - return tsl::errors::ResourceExhausted( \ - __VA_ARGS__, ":", ::stream_executor::gpu::ToString(_res)); \ - else \ - return tsl::errors::Internal(__VA_ARGS__, ": ", \ - ::stream_executor::gpu::ToString(_res)); \ - } \ +#define RETURN_IF_CUDA_RES_ERROR(expr, ...) \ + do { \ + CUresult _res = (expr); \ + if (ABSL_PREDICT_FALSE(_res != CUDA_SUCCESS)) { \ + if (_res == CUDA_ERROR_OUT_OF_MEMORY) \ + return absl::ResourceExhaustedError(absl::StrCat( \ + __VA_ARGS__, ":", ::stream_executor::gpu::ToString(_res))); \ + else \ + return absl::InternalError(absl::StrCat( \ + __VA_ARGS__, ": ", ::stream_executor::gpu::ToString(_res))); \ + } \ } while (0) #define FAIL_IF_CUDA_RES_ERROR(expr, ...) \ @@ -715,6 +716,65 @@ GpuDriver::GraphNodeGetType(CUgraphNode node) { return status == CU_STREAM_CAPTURE_STATUS_ACTIVE; } +/* static */ tsl::Status GpuDriver::GraphConditionalHandleCreate( + GpuGraphConditionalHandle* handle, CUgraph graph, GpuContext* context, + unsigned int default_launch_value, unsigned int flags) { + VLOG(2) << "Create conditional handle for a graph " << graph + << "; context: " << context + << "; default_launch_value: " << default_launch_value + << "; flags: " << flags; + +#if CUDA_VERSION >= 12030 + RETURN_IF_CUDA_RES_ERROR( + cuGraphConditionalHandleCreate(handle, graph, context->context(), + default_launch_value, flags), + "Failed to create conditional handle for a CUDA graph"); +#else + return absl::UnimplementedError( + "CUDA graph conditional nodes are not implemented"); +#endif // CUDA_VERSION >= 12030 + return ::tsl::OkStatus(); +} + +/* static */ tsl::StatusOr +GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, + absl::Span deps, + const GpuGraphNodeParams& params) { +#if CUDA_VERSION >= 12030 + // Add conditional node to a graph. 
+ if (auto* conditional = std::get_if(¶ms)) { + VLOG(2) << "Add conditional node to a graph " << graph + << "; deps: " << deps.size(); + + CUgraphNodeParams cu_params; + memset(&cu_params, 0, sizeof(cu_params)); + + cu_params.type = CU_GRAPH_NODE_TYPE_CONDITIONAL; + cu_params.conditional.handle = conditional->handle; + cu_params.conditional.ctx = conditional->context->context(); + cu_params.conditional.size = 1; + + switch (conditional->type) { + case GpuDriver::GpuGraphConditionalNodeParams::Type::kIf: + cu_params.conditional.type = CU_GRAPH_COND_TYPE_IF; + break; + } + + RETURN_IF_CUDA_RES_ERROR( + cuGraphAddNode(node, graph, deps.data(), deps.size(), &cu_params), + "Failed to add conditional node to a CUDA graph"); + + GpuGraphConditionalNodeParams::Result result; + result.graph = cu_params.conditional.phGraph_out[0]; + + VLOG(2) << "Created conditional CUDA graph " << result.graph; + return result; + } +#endif // CUDA_VERSION >= 12030 + + return absl::UnimplementedError("unsupported node type"); +} + /* static */ tsl::Status GpuDriver::GraphAddKernelNode( CUgraphNode* node, CUgraph graph, absl::Span deps, absl::string_view kernel_name, CUfunction function, unsigned int grid_dim_x, diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 08656ad60119ba..643435a17ee2b8 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -903,6 +903,16 @@ GpuExecutor::GetCommandBufferImplementation(CommandBuffer::Mode mode) { return std::make_unique(mode, /*parent=*/this, graph); } +std::unique_ptr +GpuExecutor::GetCommandBufferImplementation(CommandBuffer::Mode mode, + GpuGraphHandle graph, + bool is_owned_graph) { + VLOG(2) << "Create CUDA command buffer (CUDA graph) from existing graph " + << graph << "; is_owned_graph=" << is_owned_graph; + return std::make_unique(mode, /*parent=*/this, graph, + is_owned_graph); +} + void* GpuExecutor::platform_specific_context() { return context_; } GpuContext* GpuExecutor::gpu_context() { return context_; } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.cu.cc b/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.cu.cc index 7aed936d4e9b5a..171deb030e0b83 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.cu.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "xla/stream_executor/cuda/cuda_test_kernels.h" +#include + namespace stream_executor::cuda::internal { __global__ void AddI32(int32_t* a, int32_t* b, int32_t* c) { @@ -22,6 +24,11 @@ __global__ void AddI32(int32_t* a, int32_t* b, int32_t* c) { c[index] = a[index] + b[index]; } +__global__ void MulI32(int32_t* a, int32_t* b, int32_t* c) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + c[index] = a[index] * b[index]; +} + __global__ void AddI32Ptrs3(Ptrs3 ptrs) { int index = threadIdx.x + blockIdx.x * blockDim.x; ptrs.c[index] = ptrs.a[index] + ptrs.b[index]; @@ -29,6 +36,8 @@ __global__ void AddI32Ptrs3(Ptrs3 ptrs) { void* GetAddI32CudaKernel() { return reinterpret_cast(&AddI32); } +void* GetMulI32CudaKernel() { return reinterpret_cast(&MulI32); } + void* GetAddI32Ptrs3CudaKernel() { return reinterpret_cast(&AddI32Ptrs3); } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.h b/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.h index 709216c804bc7f..c02682b22f9d13 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_test_kernels.h @@ -85,6 +85,9 @@ struct Ptrs3 { // Returns a pointer to device kernel compiled from the CUDA C++ code above. void* GetAddI32CudaKernel(); +// Returns a pointer to device kernel doing multiplication instead of addition. +void* GetMulI32CudaKernel(); + // Returns a pointer to device kernel compiled from the CUDA C++ but with all // three pointers passed to argument as an instance of `Ptr3` template to test // StreamExecutor arguments packing for custom C++ types. diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 9f8b083d150d93..aac97bf84a779c 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -74,6 +74,7 @@ cc_library( ":gpu_types_header", "//xla/stream_executor:device_options", "//xla/stream_executor:stream_executor_headers", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ] + if_libtpu( @@ -115,7 +116,10 @@ cc_library( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", - ], + "@local_tsl//tsl/platform:statusor", + ] + if_cuda_is_configured([ + "//xla/stream_executor/cuda:cuda_conditional_kernels", + ]), ) cc_library( diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 0f98cb3be02da9..bd65fdb666a610 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -27,12 +27,14 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_kernel.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" @@ -89,8 +91,11 @@ static int64_t NotifyExecDestroyed() { //===----------------------------------------------------------------------===// GpuCommandBuffer::GpuCommandBuffer(Mode mode, GpuExecutor* parent, - GpuGraphHandle graph) - : mode_(mode), parent_(parent), graph_(graph) {} + GpuGraphHandle graph, bool is_owned_graph) + : mode_(mode), + parent_(parent), + graph_(graph), + is_owned_graph_(is_owned_graph) {} GpuCommandBuffer::~GpuCommandBuffer() { if (exec_ != nullptr) { @@ -100,7 +105,7 @@ GpuCommandBuffer::~GpuCommandBuffer() { auto st = GpuDriver::DestroyGraphExec(exec_); CHECK(st.ok()) << "Failed to destroy GPU graph exec: " << st.message(); } - if (graph_ != nullptr) { + if (graph_ != nullptr && is_owned_graph_) { auto st = GpuDriver::DestroyGraph(graph_); CHECK(st.ok()) << "Failed to destroy GPU graph: " << st.message(); } @@ -242,6 +247,67 @@ tsl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst, return UnsupportedStateError(state_); } +tsl::Status GpuCommandBuffer::If(StreamExecutor* executor, + DeviceMemory predicate, + CommandBuffer::Builder then_builder) { + DCHECK(executor->implementation() == parent_); // NOLINT + + SetConditionKernel set_condition(executor); + + { // Load kernels that update condition handle value. + MultiKernelLoaderSpec spec(/*arity=*/1); + spec.AddInProcessSymbol(gpu::GetSetConditionKernel(), "set_condition"); + TF_RETURN_IF_ERROR(executor->GetKernel(spec, &set_condition)); + } + + using ConditionalParams = GpuDriver::GpuGraphConditionalNodeParams; + using ConditionalResult = GpuDriver::GpuGraphConditionalNodeParams::Result; + + // Conditional command buffers always created in nested mode. + CommandBuffer::Mode nested = CommandBuffer::Mode::kNested; + + if (state_ == State::kCreate) { + // Create a handle for a conditional node. + GpuGraphConditionalHandle handle; + TF_RETURN_IF_ERROR(GpuDriver::GraphConditionalHandleCreate( + &handle, graph_, parent_->gpu_context(), 0, 0)); + + // Add a kernel to update conditional handle value based on a predicate. + TF_RETURN_IF_ERROR( + Launch(set_condition, ThreadDim(), BlockDim(), handle, predicate)); + + // Add conditional node to the graph. + Dependencies deps = GetDependencies(); + GpuGraphNodeHandle* node = &nodes_.emplace_back(); + + ConditionalParams params; + params.type = ConditionalParams::Type::kIf; + params.handle = handle; + params.context = parent_->gpu_context(); + + TF_ASSIGN_OR_RETURN( + GpuDriver::GpuGraphNodeResult result, + GpuDriver::GraphAddNode(node, graph_, absl::MakeSpan(deps), params)); + + // Set up conditional command buffer. + GpuGraphHandle then_graph = std::get(result).graph; + + // Wrap conditional graph into command buffer and pass it to the builder. 
+ auto then_command_buffer = CommandBuffer::Wrap( + executor, + parent_->GetCommandBufferImplementation(nested, then_graph, false)); + TF_RETURN_IF_ERROR(then_builder(&then_command_buffer)); + + return tsl::OkStatus(); + } + + // TODO(ezhulenev): For command buffer update we need to keep conditional + // handle for the command and command buffer itself as it has a mapping to + // node handles required for updates. + + return UnsupportedStateError(state_); +} + tsl::Status GpuCommandBuffer::Finalize() { TF_RETURN_IF_ERROR(CheckNotFinalized()); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 70215b30d4f24a..517e124158cce2 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -23,11 +23,13 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor_internal.h" +#include "tsl/platform/errors.h" #include "tsl/platform/status.h" namespace stream_executor::gpu { @@ -37,7 +39,7 @@ namespace stream_executor::gpu { class GpuCommandBuffer : public internal::CommandBufferInterface { public: GpuCommandBuffer(CommandBuffer::Mode mode, GpuExecutor* parent, - GpuGraphHandle graph); + GpuGraphHandle graph, bool is_owned_graph = true); ~GpuCommandBuffer() override; tsl::Status Trace(Stream* stream, @@ -52,6 +54,9 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { const DeviceMemoryBase& src, uint64_t size) override; + tsl::Status If(StreamExecutor* executor, DeviceMemory predicate, + CommandBuffer::Builder then_builder) override; + tsl::Status Finalize() override; tsl::Status Update() override; @@ -61,6 +66,12 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { CommandBuffer::Mode mode() const override { return mode_; } CommandBuffer::State state() const override { return state_; } + // A helper template for launching typed kernels. + template + tsl::Status Launch(const TypedKernel& kernel, + const ThreadDim& threads, const BlockDim& blocks, + Args... args); + // We track the total number of allocated and alive executable graphs in the // process to track the command buffers resource usage. Executable graph // allocates resources on a GPU devices (rule of thumb is ~8kb per node), so @@ -80,6 +91,9 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { } private: + using SetConditionKernel = + TypedKernel>; + // TODO(ezhulenev): Currently we serialize all Gpu nodes by adding a // dependency between all nodes added to a command buffer. We need a concept // of a barrier at a command buffer level. 
@@ -105,8 +119,9 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { CommandBuffer::State state_ = CommandBuffer::State::kCreate; GpuExecutor* parent_; // not owned, must outlive *this - GpuGraphHandle graph_ = nullptr; // owned handle - GpuGraphExecHandle exec_ = nullptr; // owned handle + GpuGraphHandle graph_ = nullptr; // owned if `is_owned_graph_ == true` + bool is_owned_graph_ = true; // ownership of `graph_` + GpuGraphExecHandle exec_ = nullptr; // owned // Handles to graph nodes corresponding to command buffer commands. Owned by // the `graph_` instance. @@ -120,6 +135,26 @@ class GpuCommandBuffer : public internal::CommandBufferInterface { int64_t num_updates_ = 0; }; +template +inline tsl::Status GpuCommandBuffer::Launch( + const TypedKernel& kernel, const ThreadDim& threads, + const BlockDim& blocks, Args... args) { + auto kernel_args = PackKernelArgs(kernel, args...); + TF_RETURN_IF_ERROR(Launch(threads, blocks, kernel, *kernel_args)); + return tsl::OkStatus(); +} + +//===----------------------------------------------------------------------===// +// Implementation details device kernels required by GpuCommandBuffer. +//===----------------------------------------------------------------------===// + +// See `cuda_conditional_kernels.cu.cc` for CUDA implementations. These are +// various kernels that update Gpu conditionals based on the device memory +// values, and allow implementing on-device control flow via conditional command +// buffers. + +void* GetSetConditionKernel(); + } // namespace stream_executor::gpu #endif // XLA_STREAM_EXECUTOR_GPU_GPU_COMMAND_BUFFER_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index d48faa11f7d9db..58592442908112 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -21,7 +21,9 @@ limitations under the License. #include #include +#include +#include "absl/types/span.h" #include "xla/stream_executor/device_options.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/platform.h" @@ -424,6 +426,42 @@ class GpuDriver { // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g57c87f4ba6af41825627cdd4e5a8c52b static tsl::Status DeviceGraphMemTrim(GpuDeviceHandle device); + // Creates a conditional handle. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gece6f3b9e85d0edb8484d625fe567376 + static tsl::Status GraphConditionalHandleCreate( + GpuGraphConditionalHandle* handle, GpuGraphHandle graph, + GpuContext* context, unsigned int default_launch_value, + unsigned int flags); + + // Conditional node parameters. + // https://docs.nvidia.com/cuda/cuda-driver-api/structCUDA__CONDITIONAL__NODE__PARAMS.html#structCUDA__CONDITIONAL__NODE__PARAMS + struct GpuGraphConditionalNodeParams { + // Conditional node type. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1g04ade961d0263336423eb216fbe514da + enum class Type { kIf }; + + // A struct for returning output arguments back to the caller. 
+ struct Result { + GpuGraphHandle graph; + }; + + Type type; + GpuGraphConditionalHandle handle; + GpuContext* context; + }; + + // Graph node parameters + // https://docs.nvidia.com/cuda/cuda-driver-api/structCUgraphNodeParams.html#structCUgraphNodeParams + using GpuGraphNodeParams = std::variant; + using GpuGraphNodeResult = + std::variant; + + // Adds a node of arbitrary type to a graph. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g4210c258cbba352040a26d1b4e658f9d + static tsl::StatusOr GraphAddNode( + GpuGraphNodeHandle* node, GpuGraphHandle graph, + absl::Span deps, const GpuGraphNodeParams& params); + // Creates a kernel execution node and adds it to a graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index 536356e1f7eec6..6bc5ee304498f7 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -263,6 +263,13 @@ class GpuExecutor : public internal::StreamExecutorInterface { tsl::StatusOr> GetCommandBufferImplementation(CommandBuffer::Mode mode) override; + // Wraps existing Gpu graph handle into an instance of Gpu command buffer. + // This is required for wrapping nested graphs constructed for conditional + // nodes and owned by a parent graph executable. + std::unique_ptr + GetCommandBufferImplementation(CommandBuffer::Mode mode, GpuGraphHandle graph, + bool is_owned_graph); + void* platform_specific_context() override; GpuContext* gpu_context(); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_types.h b/third_party/xla/xla/stream_executor/gpu/gpu_types.h index dea81d66a1d59d..9702232c9a565d 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_types.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_types.h @@ -36,6 +36,10 @@ limitations under the License. namespace stream_executor { namespace gpu { +// An empty struct to be used as a handle for all unsupported features in +// current CUDA/HIP version. +struct UnsupportedGpuFeature {}; + #if TENSORFLOW_USE_ROCM using GpuContextHandle = hipCtx_t; @@ -79,6 +83,12 @@ using GpuGraphHandle = CUgraph; using GpuGraphExecHandle = CUgraphExec; using GpuGraphNodeHandle = CUgraphNode; +#if CUDA_VERSION >= 12030 +using GpuGraphConditionalHandle = CUgraphConditionalHandle; +#else +using GpuGraphConditionalHandle = UnsupportedGpuFeature; +#endif // #if CUDA_VERSION >= 12030 + #endif } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index 3c98029d3dbb7a..b2fb31023ecb7f 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -141,6 +141,15 @@ class CommandBufferInterface { const DeviceMemoryBase& src, uint64_t size) = 0; + // For all conditional command APIs defined below, nested command buffers + // constructed for conditional branches owned by *this and should never be + // finalized or updated inside builders. + + // Adds a conditional operation that will run a command buffer constructed by + // `then_builder` if `predicate` value is `true`. 
+ virtual tsl::Status If(StreamExecutor* executor, DeviceMemory predicate, + CommandBuffer::Builder then_builder) = 0; + // Finalizes command buffer and makes it executable. Once command buffer is // finalized no commands can be added to it. virtual tsl::Status Finalize() = 0; From 494aeef21ea93c13c8aa70dc5044021f7eb39dfa Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 20 Nov 2023 14:44:52 -0800 Subject: [PATCH 313/391] [xla:ffi] Add support for passing dictionaries in attributes Also add support for interpreting all "top level" attributes as `Dictionary` to allow defining FFI handlers with variadic number of attributes decoded by the handler itself. PiperOrigin-RevId: 584133415 --- third_party/xla/xla/ffi/api/api.h | 217 +++++++++++++----- third_party/xla/xla/ffi/api/c_api.h | 1 + third_party/xla/xla/ffi/call_frame.cc | 96 +++++--- third_party/xla/xla/ffi/call_frame.h | 62 +++-- third_party/xla/xla/ffi/ffi_test.cc | 114 ++++++++- .../service/gpu/runtime3/custom_call_thunk.cc | 5 +- .../service/gpu/runtime3/custom_call_thunk.h | 5 +- 7 files changed, 387 insertions(+), 113 deletions(-) diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index f5f638ce5bdc60..06754dc19b6e30 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -148,23 +148,59 @@ XLA_FFI_Error* Ffi::CheckStructSize(const XLA_FFI_Api* api, // Type tags for distinguishing handler argument types //===----------------------------------------------------------------------===// -// Forward declare class defined below. -class RemainingArgs; - namespace internal { +// WARNING: A lot of template metaprogramming on top of C++ variadic templates +// parameter packs. We need this to be able to pattern match FFI handler +// signature at compile time. + +// A type tag to forward all remaining args as `RemainingArgs`. +struct RemainingArgsTag {}; + // A type tag to distinguish arguments tied to the attributes in the // `Binding` variadic template argument. template struct AttrTag {}; +// A type tag to forward all attributes as `Dictionary`. +struct AttrsTag {}; + // A type tag to distinguish arguments extracted from an execution context. template struct CtxTag {}; +//----------------------------------------------------------------------------// +// A template for counting tagged arguments in the Ts pack (i.e. attributes). +//----------------------------------------------------------------------------// + +template