From 5feadf761fca4b1f6db29edd0769115474c94291 Mon Sep 17 00:00:00 2001 From: orangekame3 Date: Sat, 30 Sep 2023 09:57:24 +0900 Subject: [PATCH 001/670] remove ioutil pkg --- tensorflow/go/example_inception_inference_test.go | 7 +++---- tensorflow/go/genop/internal/genop.go | 6 +++--- tensorflow/go/genop/main.go | 5 ++--- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tensorflow/go/example_inception_inference_test.go b/tensorflow/go/example_inception_inference_test.go index 475619c55a5472..13a9316298a6d9 100644 --- a/tensorflow/go/example_inception_inference_test.go +++ b/tensorflow/go/example_inception_inference_test.go @@ -22,14 +22,13 @@ import ( "flag" "fmt" "io" - "io/ioutil" "log" "net/http" "os" "path/filepath" - "github.com/tensorflow/tensorflow/tensorflow/go/op" tf "github.com/tensorflow/tensorflow/tensorflow/go" + "github.com/tensorflow/tensorflow/tensorflow/go/op" ) func Example() { @@ -88,7 +87,7 @@ func Example() { log.Fatal(err) } - model, err := ioutil.ReadFile(modelfile) + model, err := os.ReadFile(modelfile) if err != nil { log.Fatal(err) } @@ -145,7 +144,7 @@ func printBestLabel(probabilities []float32, labels []string) { // Convert the image in filename to a Tensor suitable as input to the Inception model. func makeTensorFromImage(filename string) (*tf.Tensor, error) { - bytes, err := ioutil.ReadFile(filename) + bytes, err := os.ReadFile(filename) if err != nil { return nil, err } diff --git a/tensorflow/go/genop/internal/genop.go b/tensorflow/go/genop/internal/genop.go index 2b72b236a813a6..0c92d7e309aaca 100644 --- a/tensorflow/go/genop/internal/genop.go +++ b/tensorflow/go/genop/internal/genop.go @@ -39,7 +39,7 @@ import "C" import ( "fmt" "io" - "io/ioutil" + "os" "path" "reflect" "sort" @@ -96,7 +96,7 @@ func registeredOps() (*odpb.OpList, *apiDefMap, error) { } func updateAPIDefs(m *apiDefMap, dir string) error { - files, err := ioutil.ReadDir(dir) + files, err := os.ReadDir(dir) if err != nil { return err } @@ -104,7 +104,7 @@ func updateAPIDefs(m *apiDefMap, dir string) error { if file.IsDir() || !strings.HasSuffix(file.Name(), ".pbtxt") { continue } - data, err := ioutil.ReadFile(path.Join(dir, file.Name())) + data, err := os.ReadFile(path.Join(dir, file.Name())) if err != nil { return fmt.Errorf("failed to read %q: %v", file.Name(), err) } diff --git a/tensorflow/go/genop/main.go b/tensorflow/go/genop/main.go index 87c1d27c3b53d7..370a9aaec10a80 100644 --- a/tensorflow/go/genop/main.go +++ b/tensorflow/go/genop/main.go @@ -21,7 +21,6 @@ import ( "bytes" "flag" "go/format" - "io/ioutil" "log" "os" "path/filepath" @@ -42,7 +41,7 @@ func main() { log.Fatal("-outfile must be set") } if *header != "" { - hdr, err := ioutil.ReadFile(*header) + hdr, err := os.ReadFile(*header) if err != nil { log.Fatalf("Unable to read %s: %v", *header, err) } @@ -64,7 +63,7 @@ func main() { if err != nil { log.Fatalf("Failed to generate valid source? 
'go fmt' failed: %v", err) } - if err := ioutil.WriteFile(*filename, formatted, 0644); err != nil { + if err := os.WriteFile(*filename, formatted, 0644); err != nil { log.Fatalf("Failed to write to %q: %v", *filename, err) } } From fc1ee3bcc2d07d30bcd9480280796eb014bb0e2b Mon Sep 17 00:00:00 2001 From: Gauri1 Deshpande Date: Fri, 19 Jan 2024 10:11:06 -0800 Subject: [PATCH 002/670] [onednn] Enable auto_mixed_precision for fp16 on cpu --- .../optimizers/auto_mixed_precision.cc | 51 ++++++++++----- .../optimizers/auto_mixed_precision_lists.h | 24 +++++-- .../optimizers/auto_mixed_precision_test.cc | 63 ++++++++++++++++--- 3 files changed, 109 insertions(+), 29 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc index 7cab9376515a87..e8331ea8318490 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc @@ -47,6 +47,7 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/util/env_var.h" +#include "tensorflow/core/util/util.h" namespace tensorflow { namespace grappler { @@ -1028,6 +1029,8 @@ std::unordered_map GetDevices(Cluster* cluster) { return devices; } +int GetNumGPUs(const Cluster& cluster); + class AutoMixedPrecisionImpl { public: // CastType indicates the type of inserted Cast op @@ -1038,7 +1041,8 @@ class AutoMixedPrecisionImpl { AutoMixedPrecisionImpl(Cluster* cluster, const std::unordered_set& nodes_to_preserve, GraphDef* graph, string id, - AutoMixedPrecisionMode mode) + AutoMixedPrecisionMode mode, + const bool run_fp16_on_cpu) : devices_(GetDevices(cluster)), virtual_placer_(devices_), nodes_to_preserve_(nodes_to_preserve), @@ -1053,7 +1057,9 @@ class AutoMixedPrecisionImpl { target_dtype_((mode_ == AutoMixedPrecisionMode::CUDA || mode_ == AutoMixedPrecisionMode::CPU) ? 
DT_HALF - : DT_BFLOAT16) {} + : DT_BFLOAT16), + num_gpus_(GetNumGPUs(*cluster)), + run_fp16_on_cpu_(run_fp16_on_cpu) {} Status Optimize(); @@ -1063,8 +1069,8 @@ class AutoMixedPrecisionImpl { std::unique_ptr get_mixed_precision_lists() const { switch (mode_) { case AutoMixedPrecisionMode::CUDA: - return std::make_unique(cuda_version_, - cudnn_version_); + return std::make_unique( + cuda_version_, cudnn_version_, run_fp16_on_cpu_); case AutoMixedPrecisionMode::BF16: return std::make_unique(); case AutoMixedPrecisionMode::CPU: @@ -1147,6 +1153,8 @@ class AutoMixedPrecisionImpl { gtl::FlatSet f16_clearlist_; absl::flat_hash_set should_process_nodes_; DataType target_dtype_; // Either DT_HALF or DT_BFLOAT16 + int num_gpus_ = 0; + bool run_fp16_on_cpu_ = false; }; NodeDef AutoMixedPrecisionImpl::BuildCastNode( @@ -1421,10 +1429,15 @@ Status AutoMixedPrecisionImpl::Optimize() { string device_type; switch (mode_) { case AutoMixedPrecisionMode::CUDA: - device_type = DEVICE_GPU; - should_process = - !MustPreserve(node) && IsOnDevice(node, device_type) && - (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node)); + if (!run_fp16_on_cpu_) { + device_type = DEVICE_GPU; + should_process = + !MustPreserve(node) && IsOnDevice(node, device_type) && + (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node)); + } else { + device_type = DEVICE_CPU; + should_process = !MustPreserve(node) && IsOnDevice(node, device_type); + } break; case AutoMixedPrecisionMode::BF16: case AutoMixedPrecisionMode::CPU: @@ -1857,7 +1870,7 @@ void AutoMixedPrecisionImpl::AddInferToAllowIfFollowAllow( const absl::flat_hash_set& deny_set, absl::flat_hash_set* allow_set) const { // Currently only target for oneDNN - if (mode_ != AutoMixedPrecisionMode::BF16) { + if (mode_ != AutoMixedPrecisionMode::BF16 && !run_fp16_on_cpu_) { return; } for (int item_idx = 0; item_idx < graph_type_view_.num_nodes(); ++item_idx) { @@ -2298,11 +2311,19 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item, *output = item.graph; int num_gpus = GetNumGPUs(*cluster); + bool run_fp16_on_cpu = false; if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) { - // AutoMixedPrecision is currently only tuned for GPU. - LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name() - << " graph optimizer"; - return OkStatus(); + // No GPUs to run AutoMixedPrecision in FP16. + // Check if CPU supports + if (!IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF)) { + LOG(WARNING) << "No support for " << name() << " graph optimizer on CPU/GPU"; + return OkStatus(); + } else { + run_fp16_on_cpu = true; + LOG(INFO) << "Running " << name() << " graph optimizer on CPU"; + } + } else { + LOG(INFO) << "Running " << name() << " graph optimizer on GPU"; } if (num_gpus >= 1 && mode_ == AutoMixedPrecisionMode::BF16) { @@ -2312,11 +2333,11 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item, // Optimize the output graph in-place. 
AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output, - item.id, mode_); + item.id, mode_, run_fp16_on_cpu); if (item.id == "tf_graph") { LOG(INFO) << "Running " << name() << " graph optimizer"; } else { - VLOG(1) << "Running " << name() << " graph optimizer on " << item.id; + VLOG(INFO) << "Running " << name() << " graph optimizer on " << item.id; } Status status = optimizer.Optimize(); if (!status.ok()) { diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h index f8f393a1cb960f..63a45f33b977ce 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/util/env_var.h" +#include "tensorflow/core/util/util.h" namespace tensorflow { namespace grappler { @@ -106,8 +107,11 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { } public: - AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version) - : cuda_version_(cuda_version), cudnn_version_(cudnn_version) {} + AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version, + bool run_fp16_on_cpu = false) + : cuda_version_(cuda_version), + cudnn_version_(cudnn_version), + run_fp16_on_cpu_(run_fp16_on_cpu) {} gtl::FlatSet AllowList() override { auto list = gtl::FlatSet{ @@ -143,13 +147,13 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { #if TENSORFLOW_USE_ROCM if (true) { #else - if (cuda_version_ >= 9010) { + if (cuda_version_ >= 9010 || run_fp16_on_cpu_) { // Fp16 BatchMatMul is slow before CUDA 9.1. #endif list.insert("BatchMatMul"); list.insert("BatchMatMulV2"); } - if (cudnn_version_ >= 7602) { + if (cudnn_version_ >= 7602 || run_fp16_on_cpu_) { // Fp16 3D conv is slow before CUDNN 7.6.2. list.insert("Conv3D"); list.insert("Conv3DBackpropFilter"); @@ -157,7 +161,7 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { list.insert("Conv3DBackpropInput"); list.insert("Conv3DBackpropInputV2"); } - if (cudnn_version_ >= 8000) { + if (cudnn_version_ >= 8000 || run_fp16_on_cpu_) { list.insert("DepthwiseConv2dNative"); list.insert("DepthwiseConv2dNativeBackpropFilter"); list.insert("DepthwiseConv2dNativeBackpropInput"); @@ -220,6 +224,11 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { "Tanh", "TanhGrad", }; + if (run_fp16_on_cpu_) { + list.insert("Rsqrt"); + list.insert("Square"); + list.insert("SquaredDifference"); + } UpdateList("INFERLIST", &list); // For backwards compatibility, keeping the original env variable here. // TODO(reedwm): This should be removed if we don't have active users. 
@@ -352,6 +361,10 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { "Where", "ZerosLike", }; + if (run_fp16_on_cpu_) { + list.insert("ResizeBilinear"); + list.insert("ScatterNd"); + } AddTensorListOps(&list); UpdateList("CLEARLIST", &list); return list; @@ -360,6 +373,7 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { private: int cuda_version_; int cudnn_version_; + bool run_fp16_on_cpu_; }; class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc index 689185fb08923d..91ce01425d5c1e 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc @@ -93,7 +93,7 @@ void VerifyGraphsEquivalent(const GraphDef& original_graph, // because otherwise the optimizer will not turn clearlist nodes to float16. // When looking at clearlist nodes, this optimizer checks if the nodes have a // float16 GPU OpKernel, but without CUDA/HIP there are no GPU OpKernels at all. -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL const std::pair kMinGPUArch = {7, 0}; @@ -112,19 +112,33 @@ class AutoMixedPrecisionTest : public GrapplerTest { if (gpu_available_) { virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1)); } else { - DeviceProperties device_properties; - device_properties.set_type("GPU"); + if( num_gpus > 0) { + DeviceProperties device_properties; + device_properties.set_type("GPU"); #if GOOGLE_CUDA - device_properties.mutable_environment()->insert({"architecture", "7"}); - device_properties.mutable_environment()->insert({"cuda", "9010"}); + device_properties.mutable_environment()->insert({"architecture", "7"}); + device_properties.mutable_environment()->insert({"cuda", "9010"}); #else - device_properties.mutable_environment()->insert( - {"architecture", "gfx906"}); + device_properties.mutable_environment()->insert( + {"architecture", "gfx906"}); #endif - virtual_cluster_.reset( - new VirtualCluster({{"/GPU:1", device_properties}})); + virtual_cluster_.reset( + new VirtualCluster({{"/GPU:1", device_properties}})); + } else { + // try running on CPU + DeviceProperties device_properties; + device_properties.set_type("CPU"); + virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0)); + } } TF_CHECK_OK(virtual_cluster_->Provision()); + + run_fp16_on_cpu_ = false; +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + run_fp16_on_cpu_ = IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF); +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 + + skip_test_ = !gpu_available_ && (!IsMKLEnabled() || !run_fp16_on_cpu_); } void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); } @@ -172,6 +186,8 @@ class AutoMixedPrecisionTest : public GrapplerTest { double input_min, double input_max, double atol, double rtol, const std::function& test_op_factory) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; + int size = 128; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output eye = ops::Const(s.WithOpName("eye"), @@ -210,11 +226,18 @@ class AutoMixedPrecisionTest : public GrapplerTest { } } + bool ShouldSkipTest() { + return skip_test_; + } + std::unique_ptr virtual_cluster_; bool gpu_available_; + bool skip_test_; + bool run_fp16_on_cpu_; }; TEST_F(AutoMixedPrecisionTest, NoOp) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device 
doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.234f, {32}); Output deny1 = ops::Exp(s.WithOpName("deny1"), input); @@ -252,6 +275,7 @@ TEST_F(AutoMixedPrecisionTest, NoOp) { } TEST_F(AutoMixedPrecisionTest, AlreadyFp16) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32}); Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF); @@ -290,6 +314,7 @@ TEST_F(AutoMixedPrecisionTest, AlreadyFp16) { } TEST_F(AutoMixedPrecisionTest, Simple) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output deny1 = ops::Exp(s.WithOpName("deny1"), input); @@ -339,6 +364,7 @@ TEST_F(AutoMixedPrecisionTest, Simple) { } TEST_F(AutoMixedPrecisionTest, NoInferOp) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "TREAT_INFER_AS_DENY", 1 /* replace */); @@ -391,6 +417,7 @@ TEST_F(AutoMixedPrecisionTest, NoInferOp) { } TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output clr1 = ops::Relu(s.WithOpName("clr1"), input); @@ -430,6 +457,7 @@ TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) { } TEST_F(AutoMixedPrecisionTest, PreserveFetches) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); @@ -475,6 +503,10 @@ TEST_F(AutoMixedPrecisionTest, PreserveFetches) { } TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) { + if (GetNumAvailableGPUs() == 0) { + GTEST_SKIP() << "This test is not required on CPU"; + } + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output clr1 = ops::Relu(s.WithOpName("clr1"), input); @@ -516,6 +548,7 @@ TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) { } TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT); @@ -560,6 +593,7 @@ TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) { } TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); // Uses NHWC data format because non-GPU execution does not support NCHW. 
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {8, 56, 56, 16}); @@ -619,6 +653,7 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) { } TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); @@ -659,6 +694,7 @@ TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) { } TEST_F(AutoMixedPrecisionTest, ExistingCast) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), true, {32, 32}); Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_FLOAT); @@ -691,6 +727,7 @@ TEST_F(AutoMixedPrecisionTest, ExistingCast) { } TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output deny1 = ops::Exp(s.WithOpName("deny1"), input); @@ -753,6 +790,7 @@ TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) { } TEST_F(AutoMixedPrecisionTest, TensorListSetGet) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32, 32}; auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT); @@ -824,6 +862,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) { } TEST_F(AutoMixedPrecisionTest, TensorListPushPop) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32, 32}; auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT); @@ -887,6 +926,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushPop) { } TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32}; Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); @@ -937,6 +977,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) { } TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32, 32}; auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT); @@ -997,6 +1038,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) { } TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; // This test passes a tensor list handle through a function with its own // Tensor List ops inside to test that the types are not changed to a // conflicting state. 
@@ -1105,6 +1147,7 @@ bool IsSupportedGPU(const Cluster& cluster) { } TEST_F(AutoMixedPrecisionTest, BatchMatMul) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 33, {64, 32, 32}); Output allow1 = ops::BatchMatMul(s.WithOpName("allow1"), input, input); @@ -1437,6 +1480,7 @@ class AutoMixedPrecisionSimulateGpuTest : public GrapplerTest { } }; +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM TEST_F(AutoMixedPrecisionSimulateGpuTest, Simple_NoGpu) { TestSimple(tensorflow::Scope::NewRootScope(), /* is_optimized= */ false); } @@ -1456,6 +1500,7 @@ TEST_F(AutoMixedPrecisionSimulateGpuTest, Simple_SimulatedGpu_CpuScope) { } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL #if INTEL_MKL From 9a3c8d59eb65fac7253b8f9a0db92dc93e22e374 Mon Sep 17 00:00:00 2001 From: Gauri1 Deshpande Date: Tue, 23 Jan 2024 12:41:51 -0800 Subject: [PATCH 003/670] Address review comments - update comments as per guidelines. --- tensorflow/core/grappler/optimizers/auto_mixed_precision.cc | 2 +- .../core/grappler/optimizers/auto_mixed_precision_test.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc index e8331ea8318490..466de8be2b4e83 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc @@ -2314,7 +2314,7 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item, bool run_fp16_on_cpu = false; if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) { // No GPUs to run AutoMixedPrecision in FP16. - // Check if CPU supports + // Check if CPU supports FP16. if (!IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF)) { LOG(WARNING) << "No support for " << name() << " graph optimizer on CPU/GPU"; return OkStatus(); diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc index 91ce01425d5c1e..3dc34150cb0806 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc @@ -125,7 +125,7 @@ class AutoMixedPrecisionTest : public GrapplerTest { virtual_cluster_.reset( new VirtualCluster({{"/GPU:1", device_properties}})); } else { - // try running on CPU + // When no GPUs are available, try running on CPU. 
DeviceProperties device_properties; device_properties.set_type("CPU"); virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0)); From 27b1a9623f72a5acd13a864d1d00b47a85ca43fe Mon Sep 17 00:00:00 2001 From: Gauri1 Deshpande Date: Tue, 30 Jan 2024 12:14:03 -0800 Subject: [PATCH 004/670] Address review comments --- .../optimizers/auto_mixed_precision_test.cc | 37 +++++++++---------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc index 3dc34150cb0806..c6234ac74a6cb7 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc @@ -111,34 +111,32 @@ class AutoMixedPrecisionTest : public GrapplerTest { #endif if (gpu_available_) { virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1)); - } else { - if( num_gpus > 0) { - DeviceProperties device_properties; - device_properties.set_type("GPU"); + } else if( num_gpus > 0) { + DeviceProperties device_properties; + device_properties.set_type("GPU"); #if GOOGLE_CUDA - device_properties.mutable_environment()->insert({"architecture", "7"}); - device_properties.mutable_environment()->insert({"cuda", "9010"}); + device_properties.mutable_environment()->insert({"architecture", "7"}); + device_properties.mutable_environment()->insert({"cuda", "9010"}); #else - device_properties.mutable_environment()->insert( - {"architecture", "gfx906"}); + device_properties.mutable_environment()->insert( + {"architecture", "gfx906"}); #endif - virtual_cluster_.reset( - new VirtualCluster({{"/GPU:1", device_properties}})); - } else { - // When no GPUs are available, try running on CPU. - DeviceProperties device_properties; - device_properties.set_type("CPU"); - virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0)); - } + virtual_cluster_.reset( + new VirtualCluster({{"/GPU:1", device_properties}})); + } else { + // When no GPUs are available, try running on CPU. 
+ DeviceProperties device_properties; + device_properties.set_type("CPU"); + virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0)); } TF_CHECK_OK(virtual_cluster_->Provision()); - run_fp16_on_cpu_ = false; + bool run_fp16_on_cpu = false; #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - run_fp16_on_cpu_ = IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF); + run_fp16_on_cpu = IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF); #endif // INTEL_MKL && ENABLE_ONEDNN_V3 - skip_test_ = !gpu_available_ && (!IsMKLEnabled() || !run_fp16_on_cpu_); + skip_test_ = !gpu_available_ && (!IsMKLEnabled() || !run_fp16_on_cpu); } void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); } @@ -233,7 +231,6 @@ class AutoMixedPrecisionTest : public GrapplerTest { std::unique_ptr virtual_cluster_; bool gpu_available_; bool skip_test_; - bool run_fp16_on_cpu_; }; TEST_F(AutoMixedPrecisionTest, NoOp) { From 9a3ef902974965aacfaf2192b3a7253dcd0609f8 Mon Sep 17 00:00:00 2001 From: Gauri1 Deshpande Date: Fri, 9 Feb 2024 12:30:25 -0800 Subject: [PATCH 005/670] Address review comments --- .../optimizers/auto_mixed_precision.cc | 75 +++--- .../optimizers/auto_mixed_precision.h | 12 +- .../optimizers/auto_mixed_precision_lists.h | 40 ++-- .../optimizers/auto_mixed_precision_test.cc | 221 ++++++++++-------- .../grappler/optimizers/meta_optimizer.cc | 10 +- 5 files changed, 198 insertions(+), 160 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc index 466de8be2b4e83..40201896086e4f 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc @@ -47,7 +47,6 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/util/env_var.h" -#include "tensorflow/core/util/util.h" namespace tensorflow { namespace grappler { @@ -1029,8 +1028,6 @@ std::unordered_map GetDevices(Cluster* cluster) { return devices; } -int GetNumGPUs(const Cluster& cluster); - class AutoMixedPrecisionImpl { public: // CastType indicates the type of inserted Cast op @@ -1041,8 +1038,7 @@ class AutoMixedPrecisionImpl { AutoMixedPrecisionImpl(Cluster* cluster, const std::unordered_set& nodes_to_preserve, GraphDef* graph, string id, - AutoMixedPrecisionMode mode, - const bool run_fp16_on_cpu) + AutoMixedPrecisionMode mode) : devices_(GetDevices(cluster)), virtual_placer_(devices_), nodes_to_preserve_(nodes_to_preserve), @@ -1055,11 +1051,10 @@ class AutoMixedPrecisionImpl { num_nonvar_casts_to_f16_(0), mode_(mode), target_dtype_((mode_ == AutoMixedPrecisionMode::CUDA || - mode_ == AutoMixedPrecisionMode::CPU) + mode_ == AutoMixedPrecisionMode::CPU || + mode_ == AutoMixedPrecisionMode::FP16_CPU) ? DT_HALF - : DT_BFLOAT16), - num_gpus_(GetNumGPUs(*cluster)), - run_fp16_on_cpu_(run_fp16_on_cpu) {} + : DT_BFLOAT16) {} Status Optimize(); @@ -1069,16 +1064,20 @@ class AutoMixedPrecisionImpl { std::unique_ptr get_mixed_precision_lists() const { switch (mode_) { case AutoMixedPrecisionMode::CUDA: - return std::make_unique( - cuda_version_, cudnn_version_, run_fp16_on_cpu_); + return std::make_unique(cuda_version_, + cudnn_version_); case AutoMixedPrecisionMode::BF16: - return std::make_unique(); + return std::make_unique( + AutoMixedPrecisionMode::BF16); case AutoMixedPrecisionMode::CPU: // Note: this is not a typo here. 
AutoMixedPrecisionListsCuda is used // intentionally to make CPU and GPU have the same fp16 ops. return std::make_unique( /*cuda_version=*/10000, // Hardcode cuda and cudnn version so /*cudnn_version=*/8000); // CPU emulates the same ops on GPU. + case AutoMixedPrecisionMode::FP16_CPU: + return std::make_unique( + AutoMixedPrecisionMode::FP16_CPU); } } Status PrintDebugLogs(bool preop, size_t timestamp); @@ -1153,8 +1152,6 @@ class AutoMixedPrecisionImpl { gtl::FlatSet f16_clearlist_; absl::flat_hash_set should_process_nodes_; DataType target_dtype_; // Either DT_HALF or DT_BFLOAT16 - int num_gpus_ = 0; - bool run_fp16_on_cpu_ = false; }; NodeDef AutoMixedPrecisionImpl::BuildCastNode( @@ -1392,9 +1389,11 @@ Status AutoMixedPrecisionImpl::Optimize() { "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level)); optimization_level = absl::AsciiStrToUpper(optimization_level); force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL"; - if (force_all_fp16_ && mode_ == AutoMixedPrecisionMode::BF16) { - // Many ops do not support bfloat16 on the CPU so we disallowing forcing to - // bfloat16. + if (force_all_fp16_ && + (mode_ == AutoMixedPrecisionMode::BF16 || + mode_ == AutoMixedPrecisionMode::FP16_CPU)) { + // Many ops do not support bfloat16/fp16 on the CPU. So, disallowing + // forcing to bfloat16/fp16. return errors::InvalidArgument( "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL cannot be set to " "UNSAFE_FORCE_ALL when oneDNN is used"); @@ -1429,18 +1428,14 @@ Status AutoMixedPrecisionImpl::Optimize() { string device_type; switch (mode_) { case AutoMixedPrecisionMode::CUDA: - if (!run_fp16_on_cpu_) { - device_type = DEVICE_GPU; - should_process = - !MustPreserve(node) && IsOnDevice(node, device_type) && - (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node)); - } else { - device_type = DEVICE_CPU; - should_process = !MustPreserve(node) && IsOnDevice(node, device_type); - } + device_type = DEVICE_GPU; + should_process = + !MustPreserve(node) && IsOnDevice(node, device_type) && + (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node)); break; case AutoMixedPrecisionMode::BF16: case AutoMixedPrecisionMode::CPU: + case AutoMixedPrecisionMode::FP16_CPU: device_type = DEVICE_CPU; should_process = !MustPreserve(node) && IsOnDevice(node, device_type); break; @@ -1870,7 +1865,8 @@ void AutoMixedPrecisionImpl::AddInferToAllowIfFollowAllow( const absl::flat_hash_set& deny_set, absl::flat_hash_set* allow_set) const { // Currently only target for oneDNN - if (mode_ != AutoMixedPrecisionMode::BF16 && !run_fp16_on_cpu_) { + if (mode_ != AutoMixedPrecisionMode::BF16 && + mode_ != AutoMixedPrecisionMode::FP16_CPU) { return; } for (int item_idx = 0; item_idx < graph_type_view_.num_nodes(); ++item_idx) { @@ -2311,20 +2307,19 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item, *output = item.graph; int num_gpus = GetNumGPUs(*cluster); - bool run_fp16_on_cpu = false; if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) { // No GPUs to run AutoMixedPrecision in FP16. - // Check if CPU supports FP16. 
- if (!IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF)) { - LOG(WARNING) << "No support for " << name() << " graph optimizer on CPU/GPU"; - return OkStatus(); - } else { - run_fp16_on_cpu = true; - LOG(INFO) << "Running " << name() << " graph optimizer on CPU"; - } - } else { - LOG(INFO) << "Running " << name() << " graph optimizer on GPU"; + LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name() + << " graph optimizer"; + return OkStatus(); + } + // Check if CPU supports FP16 + if (mode_ == AutoMixedPrecisionMode::FP16_CPU && + !IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF)) { + LOG(WARNING) << "No support for " << name() << " graph optimizer on CPU"; + return OkStatus(); } + LOG(INFO) << "Running " << name() << " graph optimizer "; if (num_gpus >= 1 && mode_ == AutoMixedPrecisionMode::BF16) { LOG(WARNING) << "Note: GPUs detected. Using " << name() @@ -2333,11 +2328,11 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item, // Optimize the output graph in-place. AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output, - item.id, mode_, run_fp16_on_cpu); + item.id, mode_); if (item.id == "tf_graph") { LOG(INFO) << "Running " << name() << " graph optimizer"; } else { - VLOG(INFO) << "Running " << name() << " graph optimizer on " << item.id; + LOG(INFO) << "Running " << name() << " graph optimizer on " << item.id; } Status status = optimizer.Optimize(); if (!status.ok()) { diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision.h index 0807d740f1448c..3f478ec3038534 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.h +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.h @@ -26,15 +26,16 @@ namespace grappler { // CUDA: convert to float16 on GPU // BF16: convert to bfloat16 on CPU // CPU: emulate float16 on CPU without changing operator kernel -enum class AutoMixedPrecisionMode { CUDA, BF16, CPU }; +// FP16_CPU : convert to float16 on CPU +enum class AutoMixedPrecisionMode { CUDA, BF16, CPU, FP16_CPU }; // Convert data types to float16 or bfloat16 where appropriate to improve // performance on GPUs or CPUs. class AutoMixedPrecision : public GraphOptimizer { public: - // If 'mode' is CUDA, converts nodes to float16 on Nvidia GPUs. If BF16, - // converts nodes to bfloat16 on CPUs in order to take advantage of oneDNN - // performance improvements with bfloat16. + // If 'mode' is CUDA, converts nodes to float16 on Nvidia GPUs. If BF16 or + // FP16_CPU, converts nodes to bfloat16/fp16 on CPUs in order to take + // advantage of oneDNN performance improvements with bfloat16/fp16. explicit AutoMixedPrecision( AutoMixedPrecisionMode mode = AutoMixedPrecisionMode::CUDA) : mode_(mode) {} @@ -49,6 +50,9 @@ class AutoMixedPrecision : public GraphOptimizer { return "auto_mixed_precision_onednn_bfloat16"; case AutoMixedPrecisionMode::CPU: return "auto_mixed_precision_cpu"; + case AutoMixedPrecisionMode::FP16_CPU: + // Note: use same config for FP16 on CPU & GPU. 
+ return "auto_mixed_precision"; default: LOG(FATAL) << "Invalid value for AutoMixedPrecisionMode: " // Crash Ok << static_cast(mode_); diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h index 63a45f33b977ce..4a520fa3377e8a 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h @@ -107,11 +107,8 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { } public: - AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version, - bool run_fp16_on_cpu = false) - : cuda_version_(cuda_version), - cudnn_version_(cudnn_version), - run_fp16_on_cpu_(run_fp16_on_cpu) {} + AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version) + : cuda_version_(cuda_version), cudnn_version_(cudnn_version) {} gtl::FlatSet AllowList() override { auto list = gtl::FlatSet{ @@ -147,13 +144,13 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { #if TENSORFLOW_USE_ROCM if (true) { #else - if (cuda_version_ >= 9010 || run_fp16_on_cpu_) { + if (cuda_version_ >= 9010) { // Fp16 BatchMatMul is slow before CUDA 9.1. #endif list.insert("BatchMatMul"); list.insert("BatchMatMulV2"); } - if (cudnn_version_ >= 7602 || run_fp16_on_cpu_) { + if (cudnn_version_ >= 7602) { // Fp16 3D conv is slow before CUDNN 7.6.2. list.insert("Conv3D"); list.insert("Conv3DBackpropFilter"); @@ -161,7 +158,7 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { list.insert("Conv3DBackpropInput"); list.insert("Conv3DBackpropInputV2"); } - if (cudnn_version_ >= 8000 || run_fp16_on_cpu_) { + if (cudnn_version_ >= 8000) { list.insert("DepthwiseConv2dNative"); list.insert("DepthwiseConv2dNativeBackpropFilter"); list.insert("DepthwiseConv2dNativeBackpropInput"); @@ -224,11 +221,6 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { "Tanh", "TanhGrad", }; - if (run_fp16_on_cpu_) { - list.insert("Rsqrt"); - list.insert("Square"); - list.insert("SquaredDifference"); - } UpdateList("INFERLIST", &list); // For backwards compatibility, keeping the original env variable here. // TODO(reedwm): This should be removed if we don't have active users. @@ -361,10 +353,6 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { "Where", "ZerosLike", }; - if (run_fp16_on_cpu_) { - list.insert("ResizeBilinear"); - list.insert("ScatterNd"); - } AddTensorListOps(&list); UpdateList("CLEARLIST", &list); return list; @@ -373,12 +361,11 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { private: int cuda_version_; int cudnn_version_; - bool run_fp16_on_cpu_; }; class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { public: - AutoMixedPrecisionListsMkl() {} + AutoMixedPrecisionListsMkl(AutoMixedPrecisionMode mode) : mode_(mode) {} // Only ops which are supported by MKL in bfloat16 should be added to the // allow list, infer list, or clear list. 
@@ -417,13 +404,14 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { "BiasAddGrad", "BiasAddV1", "Erf", + "Erfc", "FusedBatchNormV2", "FusedBatchNormGradV2", "FusedBatchNormV3", "FusedBatchNormGradV3", + "Inv", "LeakyRelu", "LeakyReluGrad", - "Mean", "Mul", "Sub", "Elu", @@ -449,9 +437,12 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { "Sqrt", "Square", "SquaredDifference", - "Sum", "Tanh", "TanhGrad"}; + if (mode_ != AutoMixedPrecisionMode::FP16_CPU) { + list.insert("Mean"); + list.insert("Sum"); + } UpdateList("INFERLIST", &list); // For backwards compatibility, keeping the original env variable here. // TODO(reedwm): This should be removed if we don't have active users. @@ -469,6 +460,10 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { "SoftmaxCrossEntropyWithLogits", "SparseSoftmaxCrossEntropyWithLogits", }; + if (mode_ == AutoMixedPrecisionMode::FP16_CPU) { + list.insert("Mean"); + list.insert("Sum"); + } UpdateList("DENYLIST", &list); // For backwards compatibility, keeping the original env variable here. // TODO(reedwm): This should be removed if we don't have active users. @@ -505,6 +500,7 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { "Greater", "GreaterEqual", "Identity", + "IdentityN", "IsFinite", "IsInf", "IsNan", @@ -576,6 +572,8 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { UpdateList("CLEARLIST", &list); return list; } + private: + AutoMixedPrecisionMode mode_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc index c6234ac74a6cb7..c8b2e18318a143 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc @@ -128,15 +128,14 @@ class AutoMixedPrecisionTest : public GrapplerTest { DeviceProperties device_properties; device_properties.set_type("CPU"); virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0)); - } - TF_CHECK_OK(virtual_cluster_->Provision()); - bool run_fp16_on_cpu = false; + bool run_fp16_on_cpu = false; #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - run_fp16_on_cpu = IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF); + run_fp16_on_cpu = IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF); #endif // INTEL_MKL && ENABLE_ONEDNN_V3 + } + TF_CHECK_OK(virtual_cluster_->Provision()); - skip_test_ = !gpu_available_ && (!IsMKLEnabled() || !run_fp16_on_cpu); } void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); } @@ -183,8 +182,8 @@ class AutoMixedPrecisionTest : public GrapplerTest { void TestSimpleUnaryInferOp( double input_min, double input_max, double atol, double rtol, const std::function& - test_op_factory) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; + test_op_factory, AutoMixedPrecisionMode mode) { + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; int size = 128; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -203,7 +202,7 @@ class AutoMixedPrecisionTest : public GrapplerTest { std::vector> feed = {{"input", input_tensor}}; auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -224,17 +223,27 @@ class AutoMixedPrecisionTest : public GrapplerTest { } } - 
bool ShouldSkipTest() { - return skip_test_; + bool ShouldSkipTest(AutoMixedPrecisionMode mode) { + if (mode == AutoMixedPrecisionMode::CUDA && GetNumAvailableGPUs() > 0 || + mode == AutoMixedPrecisionMode::FP16_CPU && is_fp16_enabled_on_cpu_) { + return false; + } else { + return true; + } } std::unique_ptr virtual_cluster_; bool gpu_available_; - bool skip_test_; + bool is_fp16_enabled_on_cpu_; }; -TEST_F(AutoMixedPrecisionTest, NoOp) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +class AutoMixedPrecisionParamTest : public AutoMixedPrecisionTest, + public ::testing::WithParamInterface< + AutoMixedPrecisionMode> {}; + +TEST_P(AutoMixedPrecisionParamTest, NoOp) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.234f, {32}); Output deny1 = ops::Exp(s.WithOpName("deny1"), input); @@ -248,7 +257,7 @@ TEST_F(AutoMixedPrecisionTest, NoOp) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -271,8 +280,9 @@ TEST_F(AutoMixedPrecisionTest, NoOp) { } } -TEST_F(AutoMixedPrecisionTest, AlreadyFp16) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, AlreadyFp16) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32}); Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF); @@ -287,7 +297,7 @@ TEST_F(AutoMixedPrecisionTest, AlreadyFp16) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); VLOG(1) << output.DebugString(); @@ -310,8 +320,9 @@ TEST_F(AutoMixedPrecisionTest, AlreadyFp16) { } } -TEST_F(AutoMixedPrecisionTest, Simple) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, Simple) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output deny1 = ops::Exp(s.WithOpName("deny1"), input); @@ -331,7 +342,7 @@ TEST_F(AutoMixedPrecisionTest, Simple) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -360,11 +371,11 @@ TEST_F(AutoMixedPrecisionTest, Simple) { } } -TEST_F(AutoMixedPrecisionTest, NoInferOp) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, NoInferOp) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "TREAT_INFER_AS_DENY", 1 /* replace 
*/); - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output deny1 = ops::Exp(s.WithOpName("deny1"), input); @@ -384,7 +395,7 @@ TEST_F(AutoMixedPrecisionTest, NoInferOp) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -413,8 +424,9 @@ TEST_F(AutoMixedPrecisionTest, NoInferOp) { unsetenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL"); } -TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, BidirectionalClearChain) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output clr1 = ops::Relu(s.WithOpName("clr1"), input); @@ -430,7 +442,7 @@ TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -453,8 +465,9 @@ TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) { } } -TEST_F(AutoMixedPrecisionTest, PreserveFetches) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, PreserveFetches) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); @@ -472,7 +485,7 @@ TEST_F(AutoMixedPrecisionTest, PreserveFetches) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -499,11 +512,11 @@ TEST_F(AutoMixedPrecisionTest, PreserveFetches) { } } -TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) { - if (GetNumAvailableGPUs() == 0) { +TEST_P(AutoMixedPrecisionParamTest, PreserveCPUNodes) { + AutoMixedPrecisionMode mode = GetParam(); + if (mode == AutoMixedPrecisionMode::FP16_CPU) { GTEST_SKIP() << "This test is not required on CPU"; } - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output clr1 = ops::Relu(s.WithOpName("clr1"), input); @@ -521,7 +534,7 @@ TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -544,8 +557,9 @@ TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) { } } -TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, 
PreserveIdentityAfterVariable) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT); @@ -565,7 +579,7 @@ TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) { std::vector> feed = {{"var1", var1_tensor}}; auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -589,8 +603,9 @@ TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) { } } -TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, FusedBatchNorm) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); // Uses NHWC data format because non-GPU execution does not support NCHW. Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {8, 56, 56, 16}); @@ -623,7 +638,7 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -649,8 +664,9 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) { } } -TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, RepeatedAndListTypeAttrs) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); @@ -666,7 +682,7 @@ TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -690,8 +706,9 @@ TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) { } } -TEST_F(AutoMixedPrecisionTest, ExistingCast) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, ExistingCast) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), true, {32, 32}); Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_FLOAT); @@ -703,7 +720,7 @@ TEST_F(AutoMixedPrecisionTest, ExistingCast) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -723,8 +740,9 @@ TEST_F(AutoMixedPrecisionTest, ExistingCast) { } } 
-TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, RecurrentEdgeColorMismatch) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output deny1 = ops::Exp(s.WithOpName("deny1"), input); @@ -757,7 +775,7 @@ TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) { const_node->add_input("^mrg1"); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -786,8 +804,9 @@ TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) { } } -TEST_F(AutoMixedPrecisionTest, TensorListSetGet) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, TensorListSetGet) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32, 32}; auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT); @@ -829,7 +848,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -858,8 +877,9 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) { } } -TEST_F(AutoMixedPrecisionTest, TensorListPushPop) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, TensorListPushPop) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32, 32}; auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT); @@ -893,7 +913,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushPop) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -922,8 +942,9 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushPop) { } } -TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, TensorListFromTensor) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32}; Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); @@ -948,7 +969,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -973,8 +994,9 @@ 
TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) { } } -TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, TensorListPushBackBatchAndConcatLists) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32, 32}; auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT); @@ -1009,7 +1031,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -1034,8 +1056,9 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) { } } -TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, TensorListThroughFunction) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; // This test passes a tensor list handle through a function with its own // Tensor List ops inside to test that the types are not changed to a // conflicting state. @@ -1096,7 +1119,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -1143,8 +1166,9 @@ bool IsSupportedGPU(const Cluster& cluster) { #endif } -TEST_F(AutoMixedPrecisionTest, BatchMatMul) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; +TEST_P(AutoMixedPrecisionParamTest, BatchMatMul) { + AutoMixedPrecisionMode mode = GetParam(); + if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 33, {64, 32, 32}); Output allow1 = ops::BatchMatMul(s.WithOpName("allow1"), input, input); @@ -1155,7 +1179,7 @@ TEST_F(AutoMixedPrecisionTest, BatchMatMul) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer; + AutoMixedPrecision optimizer(mode); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -1179,110 +1203,120 @@ TEST_F(AutoMixedPrecisionTest, BatchMatMul) { } } -TEST_F(AutoMixedPrecisionTest, EluOp) { +TEST_P(AutoMixedPrecisionParamTest, EluOp) { TestSimpleUnaryInferOp( -5, 5, 1.0e-3, 1.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Elu(scope, input); - }); + }, GetParam()); } -TEST_F(AutoMixedPrecisionTest, ErfOp) { +TEST_P(AutoMixedPrecisionParamTest, ErfOp) { TestSimpleUnaryInferOp( -5, 5, 1.0e-3, -1, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Erf(scope, input); - }); + }, GetParam()); } -TEST_F(AutoMixedPrecisionTest, ErfcOp) { +TEST_P(AutoMixedPrecisionParamTest, ErfcOp) { TestSimpleUnaryInferOp( -5, 5, 1.0e-3, -1, [](const tensorflow::Scope& scope, Output input) -> Output { 
return ops::Erfc(scope, input); - }); + }, GetParam()); } -TEST_F(AutoMixedPrecisionTest, InvOp) { +TEST_P(AutoMixedPrecisionParamTest, InvOp) { TestSimpleUnaryInferOp( 0.01, 10, -1, 1.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Inv(scope, input); - }); + }, GetParam()); } -TEST_F(AutoMixedPrecisionTest, LogOp) { +TEST_P(AutoMixedPrecisionParamTest, LogOp) { TestSimpleUnaryInferOp( 0.01, 10, 1.0e-3, 2.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Log(scope, input); - }); + }, GetParam()); } -TEST_F(AutoMixedPrecisionTest, Log1pOp) { +TEST_P(AutoMixedPrecisionParamTest, Log1pOp) { TestSimpleUnaryInferOp( -0.99, 9, 1.0e-3, 5.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Log1p(scope, input); - }); + }, GetParam()); } -TEST_F(AutoMixedPrecisionTest, LogSoftmaxOp) { +TEST_P(AutoMixedPrecisionParamTest, LogSoftmaxOp) { TestSimpleUnaryInferOp( -8, 8, -1, 1.0e-2, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::LogSoftmax(scope, input); - }); + }, GetParam()); } -TEST_F(AutoMixedPrecisionTest, ReciprocalOp) { +TEST_P(AutoMixedPrecisionParamTest, ReciprocalOp) { TestSimpleUnaryInferOp( 0.01, 10, -1, 1.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Reciprocal(scope, input); - }); + }, GetParam()); } -TEST_F(AutoMixedPrecisionTest, SigmoidOp) { +TEST_P(AutoMixedPrecisionParamTest, SigmoidOp) { TestSimpleUnaryInferOp( -5, 5, 1.0e-3, -1, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Sigmoid(scope, input); - }); + }, GetParam()); } -TEST_F(AutoMixedPrecisionTest, SoftmaxOp) { +TEST_P(AutoMixedPrecisionParamTest, SoftmaxOp) { TestSimpleUnaryInferOp( -8, 8, 2.0e-3, -1, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Softmax(scope, input); - }); + }, GetParam()); } -TEST_F(AutoMixedPrecisionTest, SoftplusOp) { +TEST_P(AutoMixedPrecisionParamTest, SoftplusOp) { TestSimpleUnaryInferOp( -5, 5, 2.0e-3, 2.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Softplus(scope, input); - }); + }, GetParam()); } -TEST_F(AutoMixedPrecisionTest, SqrtOp) { +TEST_P(AutoMixedPrecisionParamTest, SqrtOp) { TestSimpleUnaryInferOp( 0, 10, 1.0e-3, 1.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Sqrt(scope, input); - }); + }, GetParam()); } -TEST_F(AutoMixedPrecisionTest, TanhOp) { +TEST_P(AutoMixedPrecisionParamTest, TanhOp) { TestSimpleUnaryInferOp( -5, 5, 1.0e-3, -1, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Tanh(scope, input); - }); + }, GetParam()); } +INSTANTIATE_TEST_SUITE_P(AutoMixedPrecisionTest, AutoMixedPrecisionParamTest, + ::testing::ValuesIn({ +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + AutoMixedPrecisionMode::CUDA, +#endif +#if INTEL_MKL + AutoMixedPrecisionMode::FP16_CPU +#endif + })); + class AutoMixedPrecisionCpuTest : public GrapplerTest { protected: void SetUp() override { @@ -1761,6 +1795,7 @@ TEST_F(AutoMixedPrecisionMklTest, InferFollowUpStreamDeny) { test::ExpectClose(tensors_expected[i], tensors[i]); } } + #endif // INTEL_MKL } // namespace diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 999e7c0dc6d092..6bfa08a78866dc 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -374,8 +374,14 @@ Status MetaOptimizer::InitializeOptimizers( if 
(AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision()) && AutoMixedPrecisionEnabled( plugin_configs.toggle_config["auto_mixed_precision"])) { - optimizers->push_back( - std::make_unique(AutoMixedPrecisionMode::CUDA)); + if (device_types.size() == 1 && + device_types.find("CPU") != device_types.end()) { + optimizers->push_back( + std::make_unique(AutoMixedPrecisionMode::FP16_CPU)); + } else { + optimizers->push_back( + std::make_unique(AutoMixedPrecisionMode::CUDA)); + } } #ifdef INTEL_MKL if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_onednn_bfloat16()) && From 0ab5b80f9bf952e3eb6f5fb1df0881be7cdd0959 Mon Sep 17 00:00:00 2001 From: Gauri1 Deshpande Date: Mon, 12 Feb 2024 10:27:02 -0800 Subject: [PATCH 006/670] minor change --- .../core/grappler/optimizers/auto_mixed_precision_test.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc index c8b2e18318a143..780c4835b1c45e 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc @@ -1311,8 +1311,7 @@ INSTANTIATE_TEST_SUITE_P(AutoMixedPrecisionTest, AutoMixedPrecisionParamTest, ::testing::ValuesIn({ #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM AutoMixedPrecisionMode::CUDA, -#endif -#if INTEL_MKL +#elif INTEL_MKL AutoMixedPrecisionMode::FP16_CPU #endif })); From 182a17386f94f26fa7f9753eb03ea0b6263ef8b6 Mon Sep 17 00:00:00 2001 From: Gauri1 Deshpande Date: Fri, 16 Feb 2024 15:24:11 -0800 Subject: [PATCH 007/670] Address review comments --- .../optimizers/auto_mixed_precision.cc | 22 +- .../optimizers/auto_mixed_precision_lists.h | 93 +++++---- .../optimizers/auto_mixed_precision_test.cc | 191 +++++++++--------- .../grappler/optimizers/meta_optimizer.cc | 16 +- 4 files changed, 160 insertions(+), 162 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc index 40201896086e4f..8d3ca6f758aa70 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc @@ -1064,19 +1064,19 @@ class AutoMixedPrecisionImpl { std::unique_ptr get_mixed_precision_lists() const { switch (mode_) { case AutoMixedPrecisionMode::CUDA: - return std::make_unique(cuda_version_, - cudnn_version_); + return std::make_unique( + cuda_version_, cudnn_version_, AutoMixedPrecisionMode::CUDA); case AutoMixedPrecisionMode::BF16: - return std::make_unique( - AutoMixedPrecisionMode::BF16); + return std::make_unique(); case AutoMixedPrecisionMode::CPU: - // Note: this is not a typo here. AutoMixedPrecisionListsCuda is used + // Note: this is not a typo here. AutoMixedPrecisionListsFp16 is used // intentionally to make CPU and GPU have the same fp16 ops. - return std::make_unique( + return std::make_unique( /*cuda_version=*/10000, // Hardcode cuda and cudnn version so - /*cudnn_version=*/8000); // CPU emulates the same ops on GPU. + /*cudnn_version=*/8000, // CPU emulates the same ops on GPU. 
+ AutoMixedPrecisionMode::CPU); case AutoMixedPrecisionMode::FP16_CPU: - return std::make_unique( + return std::make_unique(0, 0, AutoMixedPrecisionMode::FP16_CPU); } } @@ -1865,8 +1865,7 @@ void AutoMixedPrecisionImpl::AddInferToAllowIfFollowAllow( const absl::flat_hash_set& deny_set, absl::flat_hash_set* allow_set) const { // Currently only target for oneDNN - if (mode_ != AutoMixedPrecisionMode::BF16 && - mode_ != AutoMixedPrecisionMode::FP16_CPU) { + if (mode_ != AutoMixedPrecisionMode::BF16) { return; } for (int item_idx = 0; item_idx < graph_type_view_.num_nodes(); ++item_idx) { @@ -2319,7 +2318,6 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item, LOG(WARNING) << "No support for " << name() << " graph optimizer on CPU"; return OkStatus(); } - LOG(INFO) << "Running " << name() << " graph optimizer "; if (num_gpus >= 1 && mode_ == AutoMixedPrecisionMode::BF16) { LOG(WARNING) << "Note: GPUs detected. Using " << name() @@ -2332,7 +2330,7 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item, if (item.id == "tf_graph") { LOG(INFO) << "Running " << name() << " graph optimizer"; } else { - LOG(INFO) << "Running " << name() << " graph optimizer on " << item.id; + VLOG(1) << "Running " << name() << " graph optimizer on " << item.id; } Status status = optimizer.Optimize(); if (!status.ok()) { diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h index 4a520fa3377e8a..810a3ea8d6f6d8 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h @@ -95,7 +95,7 @@ class AutoMixedPrecisionLists { } }; -class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { +class AutoMixedPrecisionListsFp16 : public AutoMixedPrecisionLists { private: static bool IsPseudoFastMath() { string optimization_level; @@ -107,50 +107,60 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { } public: - AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version) - : cuda_version_(cuda_version), cudnn_version_(cudnn_version) {} + AutoMixedPrecisionListsFp16(int cuda_version, int cudnn_version, + AutoMixedPrecisionMode mode) + : cuda_version_(cuda_version), cudnn_version_(cudnn_version) { + if (mode == AutoMixedPrecisionMode::CUDA || + mode == AutoMixedPrecisionMode::CPU) { + use_cuda_ = true; + } else if (mode == AutoMixedPrecisionMode::FP16_CPU) { + use_onednn_ = true; + } + } gtl::FlatSet AllowList() override { auto list = gtl::FlatSet{ - "BlockLSTM", - "BlockLSTMV2", - "BlockLSTMGrad", - "BlockLSTMGradV2", "Conv2D", "Conv2DBackpropFilter", "Conv2DBackpropInput", - "CudnnRNN", - "CudnnRNNBackprop", - "CudnnRNNBackpropV2", - "CudnnRNNBackpropV3", - "CudnnRNNV2", - "CudnnRNNV3", "Einsum", - "FusedConv2DBiasActivation", - "FusedSparseConvGpuV2", - "GRUBlockCell", - "GRUBlockCellGrad", - "LSTMBlockCell", - "LSTMBlockCellGrad", "MatMul", - "Mha", - "MhaV2", - "Tmlp", - "TmlpV2", - "TmlpV3", - "Pmlp", - "FastUnsortedSegmentMax", }; + if (use_cuda_) { + list.insert("BlockLSTM"); + list.insert("BlockLSTMV2"); + list.insert("BlockLSTMGrad"); + list.insert("BlockLSTMGradV2"); + list.insert("CudnnRNN"); + list.insert("CudnnRNNBackprop"); + list.insert("CudnnRNNBackpropV2"); + list.insert("CudnnRNNBackpropV3"); + list.insert("CudnnRNNV2"); + list.insert("CudnnRNNV3"); + list.insert("FusedConv2DBiasActivation"); + list.insert("FusedSparseConvGpuV2"); + 
list.insert("GRUBlockCell"); + list.insert("GRUBlockCellGrad"); + list.insert("LSTMBlockCell"); + list.insert("LSTMBlockCellGrad"); + list.insert("Mha"); + list.insert("MhaV2"); + list.insert("Tmlp"); + list.insert("TmlpV2"); + list.insert("TmlpV3"); + list.insert("Pmlp"); + list.insert("FastUnsortedSegmentMax"); + } #if TENSORFLOW_USE_ROCM if (true) { #else - if (cuda_version_ >= 9010) { + if ((use_cuda_ && cuda_version_ >= 9010) || use_onednn_ ) { // Fp16 BatchMatMul is slow before CUDA 9.1. #endif list.insert("BatchMatMul"); list.insert("BatchMatMulV2"); } - if (cudnn_version_ >= 7602) { + if ((use_cuda_ && cudnn_version_ >= 7602) || use_onednn_) { // Fp16 3D conv is slow before CUDNN 7.6.2. list.insert("Conv3D"); list.insert("Conv3DBackpropFilter"); @@ -158,7 +168,7 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { list.insert("Conv3DBackpropInput"); list.insert("Conv3DBackpropInputV2"); } - if (cudnn_version_ >= 8000) { + if ((use_cuda_ && cudnn_version_ >= 8000) || use_onednn_) { list.insert("DepthwiseConv2dNative"); list.insert("DepthwiseConv2dNativeBackpropFilter"); list.insert("DepthwiseConv2dNativeBackpropInput"); @@ -172,7 +182,7 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { } gtl::FlatSet InferList() override { - if (IsPseudoFastMath()) { + if (IsPseudoFastMath() && use_cuda_) { return gtl::FlatSet{}; } @@ -221,6 +231,11 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { "Tanh", "TanhGrad", }; + if (use_onednn_) { + list.insert("Rsqrt"); + list.insert("Square"); + list.insert("SquaredDifference"); + } UpdateList("INFERLIST", &list); // For backwards compatibility, keeping the original env variable here. // TODO(reedwm): This should be removed if we don't have active users. @@ -229,7 +244,7 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { } gtl::FlatSet DenyList() override { - if (IsPseudoFastMath()) { + if (IsPseudoFastMath() && use_cuda_) { return gtl::FlatSet{}; } @@ -252,7 +267,7 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { } gtl::FlatSet ClearList() override { - if (IsPseudoFastMath()) { + if (IsPseudoFastMath() && use_cuda_) { return gtl::FlatSet{}; } @@ -361,11 +376,13 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { private: int cuda_version_; int cudnn_version_; + bool use_cuda_; + bool use_onednn_; }; class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { public: - AutoMixedPrecisionListsMkl(AutoMixedPrecisionMode mode) : mode_(mode) {} + AutoMixedPrecisionListsMkl() {} // Only ops which are supported by MKL in bfloat16 should be added to the // allow list, infer list, or clear list. @@ -439,10 +456,6 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { "SquaredDifference", "Tanh", "TanhGrad"}; - if (mode_ != AutoMixedPrecisionMode::FP16_CPU) { - list.insert("Mean"); - list.insert("Sum"); - } UpdateList("INFERLIST", &list); // For backwards compatibility, keeping the original env variable here. // TODO(reedwm): This should be removed if we don't have active users. @@ -460,10 +473,6 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { "SoftmaxCrossEntropyWithLogits", "SparseSoftmaxCrossEntropyWithLogits", }; - if (mode_ == AutoMixedPrecisionMode::FP16_CPU) { - list.insert("Mean"); - list.insert("Sum"); - } UpdateList("DENYLIST", &list); // For backwards compatibility, keeping the original env variable here. // TODO(reedwm): This should be removed if we don't have active users. 
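As context for the UpdateList calls above: each list can also be adjusted at run time through environment variables instead of a rebuild. A minimal sketch follows; the variable names and the comma-separated value format follow the TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_<LIST>_ADD / _REMOVE pattern that UpdateList appears to parse, so treat them as assumptions to check against that helper.

```python
import os

# Assumed names, following the "<prefix><LIST>_ADD/_REMOVE" convention; for
# example, force Mean and Sum onto the deny list without editing this header.
# Set the variables before TensorFlow is imported so grappler sees them.
os.environ["TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_DENYLIST_ADD"] = "Mean,Sum"

import tensorflow as tf  # imported after the override is in place
```

That is roughly the user-level equivalent of the Mean/Sum special-casing removed above.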
@@ -572,8 +581,6 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { UpdateList("CLEARLIST", &list); return list; } - private: - AutoMixedPrecisionMode mode_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc index 780c4835b1c45e..60dd0b88eed075 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc @@ -99,43 +99,47 @@ const std::pair kMinGPUArch = {7, 0}; class AutoMixedPrecisionTest : public GrapplerTest { protected: + void SetMode(AutoMixedPrecisionMode mode) { + mode_ = mode; + } void SetUp() override { - int num_gpus = GetNumAvailableGPUs(); - // If GPUs are available, require that they all satisfy the min arch. - gpu_available_ = (num_gpus > 0); + if (mode_ == AutoMixedPrecisionMode::CUDA) { + int num_gpus = GetNumAvailableGPUs(); + // If GPUs are available, require that they all satisfy the min arch. + gpu_available_ = (num_gpus > 0); #if GOOGLE_CUDA - gpu_available_ = - gpu_available_ && (num_gpus == GetNumAvailableGPUs(kMinGPUArch)); + gpu_available_ = + gpu_available_ && (num_gpus == GetNumAvailableGPUs(kMinGPUArch)); #else // Here we force Tensorflow to use the virtual GFX906 - gpu_available_ = false; + gpu_available_ = false; #endif - if (gpu_available_) { - virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1)); - } else if( num_gpus > 0) { - DeviceProperties device_properties; - device_properties.set_type("GPU"); + if (gpu_available_) { + virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1)); + } else { + DeviceProperties device_properties; + device_properties.set_type("GPU"); #if GOOGLE_CUDA - device_properties.mutable_environment()->insert({"architecture", "7"}); - device_properties.mutable_environment()->insert({"cuda", "9010"}); + device_properties.mutable_environment()->insert({"architecture", "7"}); + device_properties.mutable_environment()->insert({"cuda", "9010"}); #else - device_properties.mutable_environment()->insert( - {"architecture", "gfx906"}); + device_properties.mutable_environment()->insert( + {"architecture", "gfx906"}); #endif - virtual_cluster_.reset( - new VirtualCluster({{"/GPU:1", device_properties}})); - } else { + virtual_cluster_.reset( + new VirtualCluster({{"/GPU:1", device_properties}})); + } + } else if (mode_ == AutoMixedPrecisionMode::FP16_CPU) { // When no GPUs are available, try running on CPU. 
DeviceProperties device_properties; device_properties.set_type("CPU"); virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0)); - bool run_fp16_on_cpu = false; + is_fp16_enabled_on_cpu_ = false; #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - run_fp16_on_cpu = IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF); + is_fp16_enabled_on_cpu_ = IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF); #endif // INTEL_MKL && ENABLE_ONEDNN_V3 } TF_CHECK_OK(virtual_cluster_->Provision()); - } void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); } @@ -182,8 +186,8 @@ class AutoMixedPrecisionTest : public GrapplerTest { void TestSimpleUnaryInferOp( double input_min, double input_max, double atol, double rtol, const std::function& - test_op_factory, AutoMixedPrecisionMode mode) { - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + test_op_factory) { + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; int size = 128; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -202,7 +206,7 @@ class AutoMixedPrecisionTest : public GrapplerTest { std::vector> feed = {{"input", input_tensor}}; auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -223,9 +227,10 @@ class AutoMixedPrecisionTest : public GrapplerTest { } } - bool ShouldSkipTest(AutoMixedPrecisionMode mode) { - if (mode == AutoMixedPrecisionMode::CUDA && GetNumAvailableGPUs() > 0 || - mode == AutoMixedPrecisionMode::FP16_CPU && is_fp16_enabled_on_cpu_) { + bool ShouldSkipTest() { + if (mode_ == AutoMixedPrecisionMode::CUDA && GetNumAvailableGPUs() > 0 || + (mode_ == AutoMixedPrecisionMode::FP16_CPU && IsMKLEnabled() && + is_fp16_enabled_on_cpu_)) { return false; } else { return true; @@ -235,15 +240,24 @@ class AutoMixedPrecisionTest : public GrapplerTest { std::unique_ptr virtual_cluster_; bool gpu_available_; bool is_fp16_enabled_on_cpu_; + AutoMixedPrecisionMode mode_; }; class AutoMixedPrecisionParamTest : public AutoMixedPrecisionTest, public ::testing::WithParamInterface< - AutoMixedPrecisionMode> {}; + AutoMixedPrecisionMode> { + + protected: + void SetUp() override { + mode_ = GetParam(); + AutoMixedPrecisionTest::SetMode(mode_); + AutoMixedPrecisionTest::SetUp(); + } + AutoMixedPrecisionMode mode_; +}; TEST_P(AutoMixedPrecisionParamTest, NoOp) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.234f, {32}); Output deny1 = ops::Exp(s.WithOpName("deny1"), input); @@ -257,7 +271,7 @@ TEST_P(AutoMixedPrecisionParamTest, NoOp) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -281,8 +295,7 @@ TEST_P(AutoMixedPrecisionParamTest, NoOp) { } TEST_P(AutoMixedPrecisionParamTest, AlreadyFp16) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s 
= tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32}); Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF); @@ -297,7 +310,7 @@ TEST_P(AutoMixedPrecisionParamTest, AlreadyFp16) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); VLOG(1) << output.DebugString(); @@ -321,8 +334,7 @@ TEST_P(AutoMixedPrecisionParamTest, AlreadyFp16) { } TEST_P(AutoMixedPrecisionParamTest, Simple) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output deny1 = ops::Exp(s.WithOpName("deny1"), input); @@ -342,7 +354,7 @@ TEST_P(AutoMixedPrecisionParamTest, Simple) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -372,8 +384,7 @@ TEST_P(AutoMixedPrecisionParamTest, Simple) { } TEST_P(AutoMixedPrecisionParamTest, NoInferOp) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "TREAT_INFER_AS_DENY", 1 /* replace */); tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -395,7 +406,7 @@ TEST_P(AutoMixedPrecisionParamTest, NoInferOp) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -425,8 +436,7 @@ TEST_P(AutoMixedPrecisionParamTest, NoInferOp) { } TEST_P(AutoMixedPrecisionParamTest, BidirectionalClearChain) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output clr1 = ops::Relu(s.WithOpName("clr1"), input); @@ -442,7 +452,7 @@ TEST_P(AutoMixedPrecisionParamTest, BidirectionalClearChain) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -466,8 +476,7 @@ TEST_P(AutoMixedPrecisionParamTest, BidirectionalClearChain) { } TEST_P(AutoMixedPrecisionParamTest, PreserveFetches) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output allow1 = 
ops::MatMul(s.WithOpName("allow1"), input, input); @@ -485,7 +494,7 @@ TEST_P(AutoMixedPrecisionParamTest, PreserveFetches) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -513,8 +522,7 @@ TEST_P(AutoMixedPrecisionParamTest, PreserveFetches) { } TEST_P(AutoMixedPrecisionParamTest, PreserveCPUNodes) { - AutoMixedPrecisionMode mode = GetParam(); - if (mode == AutoMixedPrecisionMode::FP16_CPU) { + if (mode_ == AutoMixedPrecisionMode::FP16_CPU) { GTEST_SKIP() << "This test is not required on CPU"; } tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -534,7 +542,7 @@ TEST_P(AutoMixedPrecisionParamTest, PreserveCPUNodes) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -558,8 +566,7 @@ TEST_P(AutoMixedPrecisionParamTest, PreserveCPUNodes) { } TEST_P(AutoMixedPrecisionParamTest, PreserveIdentityAfterVariable) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT); @@ -579,7 +586,7 @@ TEST_P(AutoMixedPrecisionParamTest, PreserveIdentityAfterVariable) { std::vector> feed = {{"var1", var1_tensor}}; auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -604,8 +611,7 @@ TEST_P(AutoMixedPrecisionParamTest, PreserveIdentityAfterVariable) { } TEST_P(AutoMixedPrecisionParamTest, FusedBatchNorm) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); // Uses NHWC data format because non-GPU execution does not support NCHW. 
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {8, 56, 56, 16}); @@ -638,7 +644,7 @@ TEST_P(AutoMixedPrecisionParamTest, FusedBatchNorm) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -665,8 +671,7 @@ TEST_P(AutoMixedPrecisionParamTest, FusedBatchNorm) { } TEST_P(AutoMixedPrecisionParamTest, RepeatedAndListTypeAttrs) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); @@ -682,7 +687,7 @@ TEST_P(AutoMixedPrecisionParamTest, RepeatedAndListTypeAttrs) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -707,8 +712,7 @@ TEST_P(AutoMixedPrecisionParamTest, RepeatedAndListTypeAttrs) { } TEST_P(AutoMixedPrecisionParamTest, ExistingCast) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), true, {32, 32}); Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_FLOAT); @@ -720,7 +724,7 @@ TEST_P(AutoMixedPrecisionParamTest, ExistingCast) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -741,8 +745,7 @@ TEST_P(AutoMixedPrecisionParamTest, ExistingCast) { } TEST_P(AutoMixedPrecisionParamTest, RecurrentEdgeColorMismatch) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output deny1 = ops::Exp(s.WithOpName("deny1"), input); @@ -775,7 +778,7 @@ TEST_P(AutoMixedPrecisionParamTest, RecurrentEdgeColorMismatch) { const_node->add_input("^mrg1"); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -805,8 +808,7 @@ TEST_P(AutoMixedPrecisionParamTest, RecurrentEdgeColorMismatch) { } TEST_P(AutoMixedPrecisionParamTest, TensorListSetGet) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32, 32}; auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, 
DT_FLOAT); @@ -848,7 +850,7 @@ TEST_P(AutoMixedPrecisionParamTest, TensorListSetGet) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -878,8 +880,7 @@ TEST_P(AutoMixedPrecisionParamTest, TensorListSetGet) { } TEST_P(AutoMixedPrecisionParamTest, TensorListPushPop) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32, 32}; auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT); @@ -913,7 +914,7 @@ TEST_P(AutoMixedPrecisionParamTest, TensorListPushPop) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -943,8 +944,7 @@ TEST_P(AutoMixedPrecisionParamTest, TensorListPushPop) { } TEST_P(AutoMixedPrecisionParamTest, TensorListFromTensor) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32}; Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); @@ -969,7 +969,7 @@ TEST_P(AutoMixedPrecisionParamTest, TensorListFromTensor) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -995,8 +995,7 @@ TEST_P(AutoMixedPrecisionParamTest, TensorListFromTensor) { } TEST_P(AutoMixedPrecisionParamTest, TensorListPushBackBatchAndConcatLists) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32, 32}; auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT); @@ -1031,7 +1030,7 @@ TEST_P(AutoMixedPrecisionParamTest, TensorListPushBackBatchAndConcatLists) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -1057,8 +1056,7 @@ TEST_P(AutoMixedPrecisionParamTest, TensorListPushBackBatchAndConcatLists) { } TEST_P(AutoMixedPrecisionParamTest, TensorListThroughFunction) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; // This test passes a tensor list handle through a function with its own // Tensor List ops inside to test that the types are not changed to a // conflicting state. 
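The comment above sums up the scenario being guarded: a tensor list handle built in the outer graph is read and written inside a function, so the element dtype picked by the rewrite has to stay consistent on both sides of the call. Roughly the Python-level shape of such a graph is sketched below (an illustration with made-up shapes, not the test itself):

```python
import tensorflow as tf

@tf.function
def through_function(handle):
    # A function with its own TensorList ops; after any rewrite, the element
    # dtype seen here must agree with the one used by the caller.
    item = tf.raw_ops.TensorListGetItem(
        input_handle=handle, index=0, element_shape=[32, 32],
        element_dtype=tf.float32)
    return tf.raw_ops.TensorListSetItem(
        input_handle=handle, index=1, item=tf.matmul(item, item))

handle = tf.raw_ops.TensorListReserve(
    element_shape=[32, 32], num_elements=2, element_dtype=tf.float32)
handle = tf.raw_ops.TensorListSetItem(
    input_handle=handle, index=0, item=tf.ones([32, 32]))
handle = through_function(handle)
result = tf.raw_ops.TensorListGetItem(
    input_handle=handle, index=1, element_shape=[32, 32],
    element_dtype=tf.float32)
```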
@@ -1119,7 +1117,7 @@ TEST_P(AutoMixedPrecisionParamTest, TensorListThroughFunction) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -1167,8 +1165,7 @@ bool IsSupportedGPU(const Cluster& cluster) { } TEST_P(AutoMixedPrecisionParamTest, BatchMatMul) { - AutoMixedPrecisionMode mode = GetParam(); - if (ShouldSkipTest(mode)) GTEST_SKIP() << "This device doesn't support FP16"; + if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 33, {64, 32, 32}); Output allow1 = ops::BatchMatMul(s.WithOpName("allow1"), input, input); @@ -1179,7 +1176,7 @@ TEST_P(AutoMixedPrecisionParamTest, BatchMatMul) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - AutoMixedPrecision optimizer(mode); + AutoMixedPrecision optimizer(mode_); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); @@ -1208,7 +1205,7 @@ TEST_P(AutoMixedPrecisionParamTest, EluOp) { -5, 5, 1.0e-3, 1.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Elu(scope, input); - }, GetParam()); + }); } TEST_P(AutoMixedPrecisionParamTest, ErfOp) { @@ -1216,7 +1213,7 @@ TEST_P(AutoMixedPrecisionParamTest, ErfOp) { -5, 5, 1.0e-3, -1, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Erf(scope, input); - }, GetParam()); + }); } TEST_P(AutoMixedPrecisionParamTest, ErfcOp) { @@ -1224,7 +1221,7 @@ TEST_P(AutoMixedPrecisionParamTest, ErfcOp) { -5, 5, 1.0e-3, -1, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Erfc(scope, input); - }, GetParam()); + }); } TEST_P(AutoMixedPrecisionParamTest, InvOp) { @@ -1232,7 +1229,7 @@ TEST_P(AutoMixedPrecisionParamTest, InvOp) { 0.01, 10, -1, 1.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Inv(scope, input); - }, GetParam()); + }); } TEST_P(AutoMixedPrecisionParamTest, LogOp) { @@ -1240,7 +1237,7 @@ TEST_P(AutoMixedPrecisionParamTest, LogOp) { 0.01, 10, 1.0e-3, 2.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Log(scope, input); - }, GetParam()); + }); } TEST_P(AutoMixedPrecisionParamTest, Log1pOp) { @@ -1248,7 +1245,7 @@ TEST_P(AutoMixedPrecisionParamTest, Log1pOp) { -0.99, 9, 1.0e-3, 5.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Log1p(scope, input); - }, GetParam()); + }); } TEST_P(AutoMixedPrecisionParamTest, LogSoftmaxOp) { @@ -1256,7 +1253,7 @@ TEST_P(AutoMixedPrecisionParamTest, LogSoftmaxOp) { -8, 8, -1, 1.0e-2, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::LogSoftmax(scope, input); - }, GetParam()); + }); } TEST_P(AutoMixedPrecisionParamTest, ReciprocalOp) { @@ -1264,7 +1261,7 @@ TEST_P(AutoMixedPrecisionParamTest, ReciprocalOp) { 0.01, 10, -1, 1.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Reciprocal(scope, input); - }, GetParam()); + }); } TEST_P(AutoMixedPrecisionParamTest, SigmoidOp) { @@ -1272,7 +1269,7 @@ TEST_P(AutoMixedPrecisionParamTest, SigmoidOp) { -5, 5, 1.0e-3, -1, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Sigmoid(scope, input); - }, GetParam()); + }); } TEST_P(AutoMixedPrecisionParamTest, SoftmaxOp) 
{ @@ -1280,7 +1277,7 @@ TEST_P(AutoMixedPrecisionParamTest, SoftmaxOp) { -8, 8, 2.0e-3, -1, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Softmax(scope, input); - }, GetParam()); + }); } TEST_P(AutoMixedPrecisionParamTest, SoftplusOp) { @@ -1288,7 +1285,7 @@ TEST_P(AutoMixedPrecisionParamTest, SoftplusOp) { -5, 5, 2.0e-3, 2.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Softplus(scope, input); - }, GetParam()); + }); } TEST_P(AutoMixedPrecisionParamTest, SqrtOp) { @@ -1296,7 +1293,7 @@ TEST_P(AutoMixedPrecisionParamTest, SqrtOp) { 0, 10, 1.0e-3, 1.0e-3, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Sqrt(scope, input); - }, GetParam()); + }); } TEST_P(AutoMixedPrecisionParamTest, TanhOp) { @@ -1304,14 +1301,15 @@ TEST_P(AutoMixedPrecisionParamTest, TanhOp) { -5, 5, 1.0e-3, -1, [](const tensorflow::Scope& scope, Output input) -> Output { return ops::Tanh(scope, input); - }, GetParam()); + }); } INSTANTIATE_TEST_SUITE_P(AutoMixedPrecisionTest, AutoMixedPrecisionParamTest, ::testing::ValuesIn({ #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM AutoMixedPrecisionMode::CUDA, -#elif INTEL_MKL +#endif +#if INTEL_MKL AutoMixedPrecisionMode::FP16_CPU #endif })); @@ -1794,7 +1792,6 @@ TEST_F(AutoMixedPrecisionMklTest, InferFollowUpStreamDeny) { test::ExpectClose(tensors_expected[i], tensors[i]); } } - #endif // INTEL_MKL } // namespace diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 6bfa08a78866dc..3c0f37d2e9ea4a 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -372,16 +372,12 @@ Status MetaOptimizer::InitializeOptimizers( optimizers->push_back(std::make_unique()); } if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision()) && - AutoMixedPrecisionEnabled( - plugin_configs.toggle_config["auto_mixed_precision"])) { - if (device_types.size() == 1 && - device_types.find("CPU") != device_types.end()) { - optimizers->push_back( - std::make_unique(AutoMixedPrecisionMode::FP16_CPU)); - } else { - optimizers->push_back( - std::make_unique(AutoMixedPrecisionMode::CUDA)); - } + AutoMixedPrecisionEnabled( + plugin_configs.toggle_config["auto_mixed_precision"])) { + optimizers->push_back( + std::make_unique(AutoMixedPrecisionMode::FP16_CPU)); + optimizers->push_back( + std::make_unique(AutoMixedPrecisionMode::CUDA)); } #ifdef INTEL_MKL if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_onednn_bfloat16()) && From 9780d75ad68275bf30b81c7cf02688fb87e5f1e2 Mon Sep 17 00:00:00 2001 From: Gauri1 Deshpande Date: Fri, 16 Feb 2024 19:59:22 -0800 Subject: [PATCH 008/670] Address re-review comments --- .../optimizers/auto_mixed_precision.cc | 6 +-- .../optimizers/auto_mixed_precision.h | 4 +- .../optimizers/auto_mixed_precision_lists.h | 7 +-- .../optimizers/auto_mixed_precision_test.cc | 50 ++++--------------- .../grappler/optimizers/meta_optimizer.cc | 2 + .../core/protobuf/rewriter_config.proto | 4 +- tensorflow/python/framework/config.py | 4 +- 7 files changed, 24 insertions(+), 53 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc index 8d3ca6f758aa70..d70ef54aee2533 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc @@ -1069,8 +1069,6 @@ class AutoMixedPrecisionImpl { case 
AutoMixedPrecisionMode::BF16: return std::make_unique(); case AutoMixedPrecisionMode::CPU: - // Note: this is not a typo here. AutoMixedPrecisionListsFp16 is used - // intentionally to make CPU and GPU have the same fp16 ops. return std::make_unique( /*cuda_version=*/10000, // Hardcode cuda and cudnn version so /*cudnn_version=*/8000, // CPU emulates the same ops on GPU. @@ -2308,14 +2306,14 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item, int num_gpus = GetNumGPUs(*cluster); if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) { // No GPUs to run AutoMixedPrecision in FP16. - LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name() + VLOG(1) << "No (suitable) GPUs detected, skipping " << name() << " graph optimizer"; return OkStatus(); } // Check if CPU supports FP16 if (mode_ == AutoMixedPrecisionMode::FP16_CPU && !IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF)) { - LOG(WARNING) << "No support for " << name() << " graph optimizer on CPU"; + VLOG(1) << "No support for " << name() << " graph optimizer on CPU"; return OkStatus(); } diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision.h index 3f478ec3038534..c26b640765f3d4 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.h +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.h @@ -51,8 +51,8 @@ class AutoMixedPrecision : public GraphOptimizer { case AutoMixedPrecisionMode::CPU: return "auto_mixed_precision_cpu"; case AutoMixedPrecisionMode::FP16_CPU: - // Note: use same config for FP16 on CPU & GPU. - return "auto_mixed_precision"; + // Note: using different name than GPU for ease of debugging. + return "auto_mixed_precision_onednn_float16"; default: LOG(FATAL) << "Invalid value for AutoMixedPrecisionMode: " // Crash Ok << static_cast(mode_); diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h index 810a3ea8d6f6d8..5c4cf2940f1720 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h @@ -112,9 +112,13 @@ class AutoMixedPrecisionListsFp16 : public AutoMixedPrecisionLists { : cuda_version_(cuda_version), cudnn_version_(cudnn_version) { if (mode == AutoMixedPrecisionMode::CUDA || mode == AutoMixedPrecisionMode::CPU) { + // Note: this is not a typo here. use_cuda_ is set to true for the CPU + // intentionally to make CPU and GPU have the same fp16 ops. 
use_cuda_ = true; + use_onednn_ = false; } else if (mode == AutoMixedPrecisionMode::FP16_CPU) { use_onednn_ = true; + use_cuda_ = false; } } @@ -421,12 +425,10 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { "BiasAddGrad", "BiasAddV1", "Erf", - "Erfc", "FusedBatchNormV2", "FusedBatchNormGradV2", "FusedBatchNormV3", "FusedBatchNormGradV3", - "Inv", "LeakyRelu", "LeakyReluGrad", "Mul", @@ -509,7 +511,6 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { "Greater", "GreaterEqual", "Identity", - "IdentityN", "IsFinite", "IsInf", "IsNan", diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc index 60dd0b88eed075..9ac263c89068ab 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc @@ -89,12 +89,11 @@ void VerifyGraphsEquivalent(const GraphDef& original_graph, } } -// Currently, this test suite only passes when TensorFlow passes with CUDA/HIP, -// because otherwise the optimizer will not turn clearlist nodes to float16. -// When looking at clearlist nodes, this optimizer checks if the nodes have a -// float16 GPU OpKernel, but without CUDA/HIP there are no GPU OpKernels at all. -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL - +// Currently on GPU, this test suite only passes when TensorFlow passes with +// CUDA/HIP, because otherwise the optimizer will not turn clearlist nodes to +// float16. When looking at clearlist nodes, this optimizer checks if the nodes +// have a float16 GPU OpKernel, but without CUDA/HIP there are no GPU OpKernels +// at all. And on CPU, this test suite passes when AMX FP16 is supported. const std::pair kMinGPUArch = {7, 0}; class AutoMixedPrecisionTest : public GrapplerTest { @@ -129,15 +128,17 @@ class AutoMixedPrecisionTest : public GrapplerTest { new VirtualCluster({{"/GPU:1", device_properties}})); } } else if (mode_ == AutoMixedPrecisionMode::FP16_CPU) { - // When no GPUs are available, try running on CPU. 
DeviceProperties device_properties; device_properties.set_type("CPU"); virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0)); - is_fp16_enabled_on_cpu_ = false; + bool is_fp16_enabled_on_cpu = false; #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - is_fp16_enabled_on_cpu_ = IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF); + is_fp16_enabled_on_cpu = IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF); #endif // INTEL_MKL && ENABLE_ONEDNN_V3 + if(!IsMKLEnabled() || !is_fp16_enabled_on_cpu) { + GTEST_SKIP() << "This device doesn't support FP16"; + } } TF_CHECK_OK(virtual_cluster_->Provision()); } @@ -187,8 +188,6 @@ class AutoMixedPrecisionTest : public GrapplerTest { double input_min, double input_max, double atol, double rtol, const std::function& test_op_factory) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; - int size = 128; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output eye = ops::Const(s.WithOpName("eye"), @@ -227,19 +226,8 @@ class AutoMixedPrecisionTest : public GrapplerTest { } } - bool ShouldSkipTest() { - if (mode_ == AutoMixedPrecisionMode::CUDA && GetNumAvailableGPUs() > 0 || - (mode_ == AutoMixedPrecisionMode::FP16_CPU && IsMKLEnabled() && - is_fp16_enabled_on_cpu_)) { - return false; - } else { - return true; - } - } - std::unique_ptr virtual_cluster_; bool gpu_available_; - bool is_fp16_enabled_on_cpu_; AutoMixedPrecisionMode mode_; }; @@ -257,7 +245,6 @@ class AutoMixedPrecisionParamTest : public AutoMixedPrecisionTest, }; TEST_P(AutoMixedPrecisionParamTest, NoOp) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.234f, {32}); Output deny1 = ops::Exp(s.WithOpName("deny1"), input); @@ -295,7 +282,6 @@ TEST_P(AutoMixedPrecisionParamTest, NoOp) { } TEST_P(AutoMixedPrecisionParamTest, AlreadyFp16) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32}); Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF); @@ -334,7 +320,6 @@ TEST_P(AutoMixedPrecisionParamTest, AlreadyFp16) { } TEST_P(AutoMixedPrecisionParamTest, Simple) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output deny1 = ops::Exp(s.WithOpName("deny1"), input); @@ -384,7 +369,6 @@ TEST_P(AutoMixedPrecisionParamTest, Simple) { } TEST_P(AutoMixedPrecisionParamTest, NoInferOp) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "TREAT_INFER_AS_DENY", 1 /* replace */); tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -436,7 +420,6 @@ TEST_P(AutoMixedPrecisionParamTest, NoInferOp) { } TEST_P(AutoMixedPrecisionParamTest, BidirectionalClearChain) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output clr1 = ops::Relu(s.WithOpName("clr1"), input); @@ -476,7 +459,6 @@ TEST_P(AutoMixedPrecisionParamTest, BidirectionalClearChain) { } TEST_P(AutoMixedPrecisionParamTest, PreserveFetches) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = 
tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); @@ -566,7 +548,6 @@ TEST_P(AutoMixedPrecisionParamTest, PreserveCPUNodes) { } TEST_P(AutoMixedPrecisionParamTest, PreserveIdentityAfterVariable) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT); @@ -611,7 +592,6 @@ TEST_P(AutoMixedPrecisionParamTest, PreserveIdentityAfterVariable) { } TEST_P(AutoMixedPrecisionParamTest, FusedBatchNorm) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); // Uses NHWC data format because non-GPU execution does not support NCHW. Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {8, 56, 56, 16}); @@ -671,7 +651,6 @@ TEST_P(AutoMixedPrecisionParamTest, FusedBatchNorm) { } TEST_P(AutoMixedPrecisionParamTest, RepeatedAndListTypeAttrs) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); @@ -712,7 +691,6 @@ TEST_P(AutoMixedPrecisionParamTest, RepeatedAndListTypeAttrs) { } TEST_P(AutoMixedPrecisionParamTest, ExistingCast) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), true, {32, 32}); Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_FLOAT); @@ -745,7 +723,6 @@ TEST_P(AutoMixedPrecisionParamTest, ExistingCast) { } TEST_P(AutoMixedPrecisionParamTest, RecurrentEdgeColorMismatch) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output deny1 = ops::Exp(s.WithOpName("deny1"), input); @@ -808,7 +785,6 @@ TEST_P(AutoMixedPrecisionParamTest, RecurrentEdgeColorMismatch) { } TEST_P(AutoMixedPrecisionParamTest, TensorListSetGet) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32, 32}; auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT); @@ -880,7 +856,6 @@ TEST_P(AutoMixedPrecisionParamTest, TensorListSetGet) { } TEST_P(AutoMixedPrecisionParamTest, TensorListPushPop) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32, 32}; auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT); @@ -944,7 +919,6 @@ TEST_P(AutoMixedPrecisionParamTest, TensorListPushPop) { } TEST_P(AutoMixedPrecisionParamTest, TensorListFromTensor) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32}; Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); @@ -995,7 +969,6 @@ TEST_P(AutoMixedPrecisionParamTest, TensorListFromTensor) { } TEST_P(AutoMixedPrecisionParamTest, TensorListPushBackBatchAndConcatLists) { - if 
(ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32, 32}; auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT); @@ -1056,7 +1029,6 @@ TEST_P(AutoMixedPrecisionParamTest, TensorListPushBackBatchAndConcatLists) { } TEST_P(AutoMixedPrecisionParamTest, TensorListThroughFunction) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; // This test passes a tensor list handle through a function with its own // Tensor List ops inside to test that the types are not changed to a // conflicting state. @@ -1165,7 +1137,6 @@ bool IsSupportedGPU(const Cluster& cluster) { } TEST_P(AutoMixedPrecisionParamTest, BatchMatMul) { - if (ShouldSkipTest()) GTEST_SKIP() << "This device doesn't support FP16"; tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 33, {64, 32, 32}); Output allow1 = ops::BatchMatMul(s.WithOpName("allow1"), input, input); @@ -1528,7 +1499,6 @@ TEST_F(AutoMixedPrecisionSimulateGpuTest, Simple_SimulatedGpu_CpuScope) { } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL #if INTEL_MKL diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 3c0f37d2e9ea4a..3687e6c307e4da 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -231,6 +231,8 @@ std::unique_ptr MetaOptimizer::MakeNewOptimizer( new AutoMixedPrecision(AutoMixedPrecisionMode::CUDA)); #ifdef INTEL_MKL if (IsMKLEnabled()) { + MK_OPT("auto_mixed_precision", "auto_mixed_precision", + new AutoMixedPrecision(AutoMixedPrecisionMode::FP16_CPU)); MK_OPT("auto_mixed_precision_mkl", "auto_mixed_precision_mkl", new AutoMixedPrecision(AutoMixedPrecisionMode::BF16)); MK_OPT("auto_mixed_precision_onednn_bfloat16", diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index 9f4042e6f8be9b..f98d1928d9e156 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -102,8 +102,8 @@ message RewriterConfig { // Enable the swap of kernel implementations based on the device placement // (default is ON). Toggle implementation_selector = 22; - // Optimize data types for CUDA (default is OFF). - // This will try to use float16 on GPU which is faster. + // Optimize data types for CUDA/oneDNN (default is OFF). + // This will try to use float16 on GPU/CPU which is faster. // Note that this can change the numerical stability of the graph and may // require the use of loss scaling to maintain model convergence. Toggle auto_mixed_precision = 23; diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py index 228bacb7f6443d..6ee1860950cc40 100644 --- a/tensorflow/python/framework/config.py +++ b/tensorflow/python/framework/config.py @@ -242,8 +242,8 @@ def set_optimizer_experimental_options(options): - implementation_selector: Enable the swap of kernel implementations based on the device placement. - auto_mixed_precision: Change certain float32 ops to float16 on Volta - GPUs and above. Without the use of loss scaling, this can cause - numerical underflow (see + GPUs and above; and on CPUs with AMX FP16 support. 
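For instance, a user opts in through the option documented here; a minimal sketch, assuming a recent build (whether float16 is actually used still depends on the device, i.e. a suitable GPU or a CPU with AMX FP16 support):

```python
import tensorflow as tf

# Turn on the auto_mixed_precision grappler rewrite for graphs traced below;
# ops on the allow list (MatMul, Conv2D, ...) may then run in float16.
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

@tf.function
def step(x, w):
    return tf.nn.relu(tf.matmul(x, w))

y = step(tf.random.normal([32, 32]), tf.random.normal([32, 32]))
```

The same toggle maps onto the RewriterConfig.auto_mixed_precision field changed above.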
Without the use of + loss scaling, this can cause numerical underflow (see `keras.mixed_precision.experimental.LossScaleOptimizer`). - disable_meta_optimizer: Disable the entire meta optimizer. - min_graph_nodes: The minimum number of nodes in a graph to optimizer. From 1668cc3c7371287fdd17988eb0e218c8231fc10f Mon Sep 17 00:00:00 2001 From: Surya <116063290+SuryanarayanaY@users.noreply.github.com> Date: Thu, 29 Feb 2024 20:37:10 +0530 Subject: [PATCH 009/670] Fix checkfail in tf.raw_ops.Substr The API tf.raw_ops.Substr currently validates whether the input args pos and len are of same shape or not.Its not checking whether these tensors are empty or not and trying to access the Tensor values directly without validating.If a user passes empty tensors it will lead to assertion failure causing core dumped error. May fixes #63036 --- tensorflow/core/kernels/substr_op.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index a7880ccc681eff..5f4b2a3a3b0d54 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -56,7 +56,12 @@ class SubstrOp : public OpKernel { errors::InvalidArgument( "pos and len should have the same shape, got: ", pos_shape.DebugString(), " vs. ", len_shape.DebugString())); - + OP_REQUIRES(context, pos_tensor.NumElements() > 0, + errors::InvalidArgument("received empty tensor pos_tensor: ", + pos_tensor.DebugString())); + OP_REQUIRES(context, len_tensor.NumElements() > 0, + errors::InvalidArgument("received empty tensor len_tensor: ", + len_tensor.DebugString())); bool is_scalar = TensorShapeUtils::IsScalar(pos_shape); if (is_scalar || input_shape == pos_shape) { From 2b164c8cd6b5ae3dd6c664127ebdf1104836eeda Mon Sep 17 00:00:00 2001 From: Surya <116063290+SuryanarayanaY@users.noreply.github.com> Date: Mon, 4 Mar 2024 20:34:41 +0530 Subject: [PATCH 010/670] Fix checkfail in DenseBincount The API raw_ops.DenseBincount lacks validation of input to be vector. It does have checking for rank<=2 but not for rank>0. Passing a scalar value causes checkfail with debug build. Reported at #63068 --- tensorflow/core/ops/math_ops.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index baa487e728d533..2e4a158add1eb8 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1896,6 +1896,9 @@ REGISTER_OP("DenseBincount") c->set_output(0, c->MakeShape({size_val})); } else if (c->Rank(c->input(0)) == 2) { c->set_output(0, c->MakeShape({c->Dim(c->input(0), 0), size_val})); + } else { + return errors::InvalidArgument("input must not be a scalar. 
" + "Recieved input of rank ", c->Rank(c->input(0))); } return absl::OkStatus(); }); From f4532fd0a6905deab9983a258b61928a8a380f3d Mon Sep 17 00:00:00 2001 From: Surya <116063290+SuryanarayanaY@users.noreply.github.com> Date: Wed, 6 Mar 2024 10:22:04 +0530 Subject: [PATCH 011/670] Update math_ops.cc Change the logic to validate rank != 0 explicitly --- tensorflow/core/ops/math_ops.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 2e4a158add1eb8..dc93372c1d9df3 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1896,9 +1896,8 @@ REGISTER_OP("DenseBincount") c->set_output(0, c->MakeShape({size_val})); } else if (c->Rank(c->input(0)) == 2) { c->set_output(0, c->MakeShape({c->Dim(c->input(0), 0), size_val})); - } else { - return errors::InvalidArgument("input must not be a scalar. " - "Recieved input of rank ", c->Rank(c->input(0))); + } else if (c->Rank(c->input(0)) == 0) { + return absl::InvalidArgumentError("The input must not be a scalar. "); } return absl::OkStatus(); }); From 994978a764b23db8be45e5b7747f327fa9e6d47e Mon Sep 17 00:00:00 2001 From: Surya <116063290+SuryanarayanaY@users.noreply.github.com> Date: Thu, 7 Mar 2024 23:00:47 +0530 Subject: [PATCH 012/670] Set output shape to rank 0 input in DensebinCount Set output shape to rank 0 input in DensebinCount Op. --- tensorflow/core/ops/math_ops.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index dc93372c1d9df3..192899b6726364 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1892,12 +1892,10 @@ REGISTER_OP("DenseBincount") return errors::InvalidArgument("size (", size_val, ") must be non-negative"); } - if (c->Rank(c->input(0)) == 1) { + if (c->Rank(c->input(0)) == 1 || c->Rank(c->input(0)) == 0) { c->set_output(0, c->MakeShape({size_val})); } else if (c->Rank(c->input(0)) == 2) { c->set_output(0, c->MakeShape({c->Dim(c->input(0), 0), size_val})); - } else if (c->Rank(c->input(0)) == 0) { - return absl::InvalidArgumentError("The input must not be a scalar. "); } return absl::OkStatus(); }); From a4efadfd30e535f50a000f7dd2853cb9bde88301 Mon Sep 17 00:00:00 2001 From: Surya <116063290+SuryanarayanaY@users.noreply.github.com> Date: Fri, 8 Mar 2024 10:05:30 +0530 Subject: [PATCH 013/670] Support for Rank 0 Input for DenseBinCount Op Support for Rank 0 Input for DenseBinCount Op. Assuming that tensor.flat works with scalar tensor also. --- tensorflow/core/kernels/bincount_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/bincount_op.cc b/tensorflow/core/kernels/bincount_op.cc index a680a0e4e7a2e3..1a1e55ed067fd3 100644 --- a/tensorflow/core/kernels/bincount_op.cc +++ b/tensorflow/core/kernels/bincount_op.cc @@ -308,7 +308,7 @@ class DenseBincountOp : public OpKernel { Tensor* out_t; functor::SetZeroFunctor fill; - if (data.dims() == 1) { + if (data.dims() <= 1) { OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({size}), &out_t)); auto out = out_t->flat(); fill(ctx->eigen_device(), out); From 176048ad39a1d928208b92e1503e2d8d2cd35d28 Mon Sep 17 00:00:00 2001 From: Surya <116063290+SuryanarayanaY@users.noreply.github.com> Date: Mon, 11 Mar 2024 15:50:05 +0530 Subject: [PATCH 014/670] Fix checkfail in UnicodeEncode Op The Op UnicodeEncode segfaults when passed 2D tensor to `input_splits`. 
It has the below check in SetShapeFn which supposed to raise exception if rank !=1 AFAIk. This seems not working for reason unknown to me. https://github.com/tensorflow/tensorflow/blob/6f64ad5d767a034df45a5eaab8b36fd688cd1217/tensorflow/core/ops/string_ops.cc#L316-L317 Same with input_values argument also. Added an explicit check in Op. --- tensorflow/core/kernels/unicode_ops.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/core/kernels/unicode_ops.cc b/tensorflow/core/kernels/unicode_ops.cc index 3d59cc034480b3..a454e6f69ec646 100644 --- a/tensorflow/core/kernels/unicode_ops.cc +++ b/tensorflow/core/kernels/unicode_ops.cc @@ -532,6 +532,10 @@ class UnicodeEncodeOp : public OpKernel { const Tensor& input_splits = context->input(1); const auto input_splits_flat = input_splits.flat(); + OP_REQUIRES( + context, input_tensor.dims() == 1 && input_splits.dims() == 1, + absl::InvalidArgumentError( + "Both the input_tensor and input_splits should be of rank 1. ")); OP_REQUIRES( context, input_splits.NumElements() > 0, errors::InvalidArgument("Input_splits should contain elements, but " From 3c7b63ecd0afc101c0c889b194e4869906054043 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 16 Mar 2024 10:21:21 -0700 Subject: [PATCH 015/670] Integrate LLVM at llvm/llvm-project@a4ca07f13b56 Updates LLVM usage to match [a4ca07f13b56](https://github.com/llvm/llvm-project/commit/a4ca07f13b56) PiperOrigin-RevId: 616433576 --- .../tensorflow/utils/dump_mlir_util_test.cc | 2 +- third_party/llvm/generated.patch | 811 ++++++++++++++---- third_party/llvm/workspace.bzl | 4 +- .../service/cpu/hlo_xla_runtime_pipeline.cc | 8 +- 4 files changed, 656 insertions(+), 169 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index bb474b1413f7ac..2efd63b29b04ef 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -125,7 +125,7 @@ TEST(DumpCrashReproducerTest, RoundtripDumpAndReadValid) { EXPECT_TRUE(mlir::MlirOptMain(output_stream->os(), std::move(input_file), registry, mlir::MlirOptMainConfig{} - .splitInputFile(false) + .splitInputFile("") .verifyDiagnostics(false) .verifyPasses(false) .allowUnregisteredDialects(false) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index b75801c374943b..575d74a4816f67 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,183 +1,670 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/AutoUpgrade.h b/llvm/include/llvm/IR/AutoUpgrade.h ---- a/llvm/include/llvm/IR/AutoUpgrade.h -+++ b/llvm/include/llvm/IR/AutoUpgrade.h -@@ -88,9 +88,6 @@ - /// info. Return true if module is modified. - bool UpgradeDebugInfo(Module &M); - -- /// Copies module attributes to the functions in the module. -- void CopyModuleAttrToFunctions(Module &M); -- - /// Check whether a string looks like an old loop attachment tag. 
- inline bool mayBeOldLoopAttachmentTag(StringRef Name) { - return Name.starts_with("llvm.vectorizer."); -diff -ruN --strip-trailing-cr a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp ---- a/llvm/lib/IR/AutoUpgrade.cpp -+++ b/llvm/lib/IR/AutoUpgrade.cpp -@@ -5178,72 +5178,6 @@ - Arg.removeAttrs(AttributeFuncs::typeIncompatible(Arg.getType())); +diff -ruN --strip-trailing-cr a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst +--- a/clang/docs/ReleaseNotes.rst ++++ b/clang/docs/ReleaseNotes.rst +@@ -201,21 +201,6 @@ + and each must be a positive integer when provided. The parameter ``x`` is required, while ``y`` and + ``z`` are optional with default value of 1. + +-- The ``_Nullable`` and ``_Nonnull`` family of type attributes can now apply +- to certain C++ class types, such as smart pointers: +- ``void useObject(std::unique_ptr _Nonnull obj);``. +- +- This works for standard library types including ``unique_ptr``, ``shared_ptr``, +- and ``function``. See +- `the attribute reference documentation `_ +- for the full list. +- +-- The ``_Nullable`` attribute can be applied to C++ class declarations: +- ``template class _Nullable MySmartPointer {};``. +- +- This allows the ``_Nullable`` and ``_Nonnull`` family of type attributes to +- apply to this class. +- + Improvements to Clang's diagnostics + ----------------------------------- + - Clang now applies syntax highlighting to the code snippets it +diff -ruN --strip-trailing-cr a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td +--- a/clang/include/clang/Basic/AttrDocs.td ++++ b/clang/include/clang/Basic/AttrDocs.td +@@ -4151,20 +4151,6 @@ + @property (assign, nullable) NSView *superview; + @property (readonly, nonnull) NSArray *subviews; + @end +- +-As well as built-in pointer types, the nullability attributes can be attached +-to C++ classes marked with the ``_Nullable`` attribute. +- +-The following C++ standard library types are considered nullable: +-``unique_ptr``, ``shared_ptr``, ``auto_ptr``, ``exception_ptr``, ``function``, +-``move_only_function`` and ``coroutine_handle``. +- +-Types should be marked nullable only where the type itself leaves nullability +-ambiguous. For example, ``std::optional`` is not marked ``_Nullable``, because +-``optional _Nullable`` is redundant and ``optional _Nonnull`` is +-not a useful type. ``std::weak_ptr`` is not nullable, because its nullability +-can change with no visible modification, so static annotation is unlikely to be +-unhelpful. + }]; } --// Check if the module attribute is present and not zero. --static bool isModuleAttributeSet(Module &M, const StringRef &ModAttr) { -- const auto *Attr = -- mdconst::extract_or_null(M.getModuleFlag(ModAttr)); -- return Attr && Attr->getZExtValue(); --} +@@ -4199,17 +4185,6 @@ + int fetch_or_zero(int * _Nullable ptr); + + a caller of ``fetch_or_zero`` can provide null. - --// Copy an attribute from module to the function if exists. --// First value of the pair is used when the module attribute is not zero --// the second otherwise. --static void --CopyModuleAttributeToFunction(Function &F, StringRef FnAttrName, -- StringRef ModAttrName, -- std::pair Values) { -- if (F.hasFnAttribute(FnAttrName)) -- return; -- F.addFnAttr(FnAttrName, isModuleAttributeSet(*F.getParent(), ModAttrName) -- ? Values.first -- : Values.second); --} +-The ``_Nullable`` attribute on classes indicates that the given class can +-represent null values, and so the ``_Nullable``, ``_Nonnull`` etc qualifiers +-make sense for this type. 
For example: - --// Copy a boolean attribute from module to the function if exists. --// Module attribute treated false if zero otherwise true. --static void CopyModuleAttributeToFunction(Function &F, StringRef AttrName) { -- CopyModuleAttributeToFunction( -- F, AttrName, AttrName, -- std::make_pair("true", "false")); +- .. code-block:: c +- +- class _Nullable ArenaPointer { ... }; +- +- ArenaPointer _Nonnull x = ...; +- ArenaPointer _Nullable y = nullptr; + }]; + } + +diff -ruN --strip-trailing-cr a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td +--- a/clang/include/clang/Basic/Attr.td ++++ b/clang/include/clang/Basic/Attr.td +@@ -2178,10 +2178,9 @@ + let Documentation = [TypeNonNullDocs]; + } + +-def TypeNullable : DeclOrTypeAttr { ++def TypeNullable : TypeAttr { + let Spellings = [CustomKeyword<"_Nullable">]; + let Documentation = [TypeNullableDocs]; +-// let Subjects = SubjectList<[CXXRecord], ErrorDiag>; + } + + def TypeNullableResult : TypeAttr { +diff -ruN --strip-trailing-cr a/clang/include/clang/Basic/Features.def b/clang/include/clang/Basic/Features.def +--- a/clang/include/clang/Basic/Features.def ++++ b/clang/include/clang/Basic/Features.def +@@ -94,7 +94,6 @@ + FEATURE(enumerator_attributes, true) + FEATURE(nullability, true) + FEATURE(nullability_on_arrays, true) +-FEATURE(nullability_on_classes, true) + FEATURE(nullability_nullable_result, true) + FEATURE(memory_sanitizer, + LangOpts.Sanitize.hasOneOf(SanitizerKind::Memory | +diff -ruN --strip-trailing-cr a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h +--- a/clang/include/clang/Parse/Parser.h ++++ b/clang/include/clang/Parse/Parser.h +@@ -3014,7 +3014,6 @@ + void DiagnoseAndSkipExtendedMicrosoftTypeAttributes(); + SourceLocation SkipExtendedMicrosoftTypeAttributes(); + void ParseMicrosoftInheritanceClassAttributes(ParsedAttributes &attrs); +- void ParseNullabilityClassAttributes(ParsedAttributes &attrs); + void ParseBorlandTypeAttributes(ParsedAttributes &attrs); + void ParseOpenCLKernelAttributes(ParsedAttributes &attrs); + void ParseOpenCLQualifiers(ParsedAttributes &Attrs); +diff -ruN --strip-trailing-cr a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h +--- a/clang/include/clang/Sema/Sema.h ++++ b/clang/include/clang/Sema/Sema.h +@@ -1655,9 +1655,6 @@ + /// Add [[gsl::Pointer]] attributes for std:: types. + void inferGslPointerAttribute(TypedefNameDecl *TD); + +- /// Add _Nullable attributes for std:: types. +- void inferNullableClassAttribute(CXXRecordDecl *CRD); +- + enum PragmaOptionsAlignKind { + POAK_Native, // #pragma options align=native + POAK_Natural, // #pragma options align=natural +diff -ruN --strip-trailing-cr a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp +--- a/clang/lib/AST/Type.cpp ++++ b/clang/lib/AST/Type.cpp +@@ -4558,15 +4558,16 @@ + case Type::Auto: + return ResultIfUnknown; + +- // Dependent template specializations could instantiate to pointer types. ++ // Dependent template specializations can instantiate to pointer ++ // types unless they're known to be specializations of a class ++ // template. + case Type::TemplateSpecialization: +- // If it's a known class template, we can already check if it's nullable. 
+- if (TemplateDecl *templateDecl = +- cast(type.getTypePtr()) +- ->getTemplateName() +- .getAsTemplateDecl()) +- if (auto *CTD = dyn_cast(templateDecl)) +- return CTD->getTemplatedDecl()->hasAttr(); ++ if (TemplateDecl *templateDecl ++ = cast(type.getTypePtr()) ++ ->getTemplateName().getAsTemplateDecl()) { ++ if (isa(templateDecl)) ++ return false; ++ } + return ResultIfUnknown; + + case Type::Builtin: +@@ -4623,17 +4624,6 @@ + } + llvm_unreachable("unknown builtin type"); + +- case Type::Record: { +- const RecordDecl *RD = cast(type)->getDecl(); +- // For template specializations, look only at primary template attributes. +- // This is a consistent regardless of whether the instantiation is known. +- if (const auto *CTSD = dyn_cast(RD)) +- return CTSD->getSpecializedTemplate() +- ->getTemplatedDecl() +- ->hasAttr(); +- return RD->hasAttr(); +- } +- + // Non-pointer types. + case Type::Complex: + case Type::LValueReference: +@@ -4651,6 +4641,7 @@ + case Type::DependentAddressSpace: + case Type::FunctionProto: + case Type::FunctionNoProto: ++ case Type::Record: + case Type::DeducedTemplateSpecialization: + case Type::Enum: + case Type::InjectedClassName: +diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp +--- a/clang/lib/CodeGen/CGCall.cpp ++++ b/clang/lib/CodeGen/CGCall.cpp +@@ -4372,8 +4372,7 @@ + NNAttr = getNonNullAttr(AC.getDecl(), PVD, ArgType, ArgNo); + + bool CanCheckNullability = false; +- if (SanOpts.has(SanitizerKind::NullabilityArg) && !NNAttr && PVD && +- !PVD->getType()->isRecordType()) { ++ if (SanOpts.has(SanitizerKind::NullabilityArg) && !NNAttr && PVD) { + auto Nullability = PVD->getType()->getNullability(); + CanCheckNullability = Nullability && + *Nullability == NullabilityKind::NonNull && +diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp +--- a/clang/lib/CodeGen/CodeGenFunction.cpp ++++ b/clang/lib/CodeGen/CodeGenFunction.cpp +@@ -979,8 +979,7 @@ + // return value. Initialize the flag to 'true' and refine it in EmitParmDecl. + if (SanOpts.has(SanitizerKind::NullabilityReturn)) { + auto Nullability = FnRetTy->getNullability(); +- if (Nullability && *Nullability == NullabilityKind::NonNull && +- !FnRetTy->isRecordType()) { ++ if (Nullability && *Nullability == NullabilityKind::NonNull) { + if (!(SanOpts.has(SanitizerKind::ReturnsNonnullAttribute) && + CurCodeDecl && CurCodeDecl->getAttr())) + RetValNullabilityPrecondition = +diff -ruN --strip-trailing-cr a/clang/lib/Parse/ParseDeclCXX.cpp b/clang/lib/Parse/ParseDeclCXX.cpp +--- a/clang/lib/Parse/ParseDeclCXX.cpp ++++ b/clang/lib/Parse/ParseDeclCXX.cpp +@@ -1494,15 +1494,6 @@ + } + } + +-void Parser::ParseNullabilityClassAttributes(ParsedAttributes &attrs) { +- while (Tok.is(tok::kw__Nullable)) { +- IdentifierInfo *AttrName = Tok.getIdentifierInfo(); +- auto Kind = Tok.getKind(); +- SourceLocation AttrNameLoc = ConsumeToken(); +- attrs.addNew(AttrName, AttrNameLoc, nullptr, AttrNameLoc, nullptr, 0, Kind); +- } -} - --// Copy an attribute from module to the function if exists. --// First value of the pair is used when the module attribute is not zero --// the second otherwise. --static void --CopyModuleAttributeToFunction(Function &F, StringRef AttrName, -- std::pair Values) { -- CopyModuleAttributeToFunction(F, AttrName, AttrName, Values); + /// Determine whether the following tokens are valid after a type-specifier + /// which could be a standalone declaration. 
This will conservatively return + /// true if there's any doubt, and is appropriate for insert-';' fixits. +@@ -1684,21 +1675,15 @@ + + ParsedAttributes attrs(AttrFactory); + // If attributes exist after tag, parse them. +- for (;;) { +- MaybeParseAttributes(PAKM_CXX11 | PAKM_Declspec | PAKM_GNU, attrs); +- // Parse inheritance specifiers. +- if (Tok.isOneOf(tok::kw___single_inheritance, +- tok::kw___multiple_inheritance, +- tok::kw___virtual_inheritance)) { +- ParseMicrosoftInheritanceClassAttributes(attrs); +- continue; +- } +- if (Tok.is(tok::kw__Nullable)) { +- ParseNullabilityClassAttributes(attrs); +- continue; +- } +- break; +- } ++ MaybeParseAttributes(PAKM_CXX11 | PAKM_Declspec | PAKM_GNU, attrs); ++ ++ // Parse inheritance specifiers. ++ if (Tok.isOneOf(tok::kw___single_inheritance, tok::kw___multiple_inheritance, ++ tok::kw___virtual_inheritance)) ++ ParseMicrosoftInheritanceClassAttributes(attrs); ++ ++ // Allow attributes to precede or succeed the inheritance specifiers. ++ MaybeParseAttributes(PAKM_CXX11 | PAKM_Declspec | PAKM_GNU, attrs); + + // Source location used by FIXIT to insert misplaced + // C++11 attributes +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaAttr.cpp b/clang/lib/Sema/SemaAttr.cpp +--- a/clang/lib/Sema/SemaAttr.cpp ++++ b/clang/lib/Sema/SemaAttr.cpp +@@ -215,18 +215,6 @@ + inferGslPointerAttribute(Record, Record); + } + +-void Sema::inferNullableClassAttribute(CXXRecordDecl *CRD) { +- static llvm::StringSet<> Nullable{ +- "auto_ptr", "shared_ptr", "unique_ptr", "exception_ptr", +- "coroutine_handle", "function", "move_only_function", +- }; +- +- if (CRD->isInStdNamespace() && Nullable.count(CRD->getName()) && +- !CRD->hasAttr()) +- for (Decl *Redecl : CRD->redecls()) +- Redecl->addAttr(TypeNullableAttr::CreateImplicit(Context)); -} - --void llvm::CopyModuleAttrToFunctions(Module &M) { -- Triple T(M.getTargetTriple()); -- if (!T.isThumb() && !T.isARM() && !T.isAArch64()) + void Sema::ActOnPragmaOptionsAlign(PragmaOptionsAlignKind Kind, + SourceLocation PragmaLoc) { + PragmaMsStackAction Action = Sema::PSK_Reset; +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp +--- a/clang/lib/Sema/SemaChecking.cpp ++++ b/clang/lib/Sema/SemaChecking.cpp +@@ -27,7 +27,6 @@ + #include "clang/AST/ExprObjC.h" + #include "clang/AST/ExprOpenMP.h" + #include "clang/AST/FormatString.h" +-#include "clang/AST/IgnoreExpr.h" + #include "clang/AST/NSAPI.h" + #include "clang/AST/NonTrivialTypeVisitor.h" + #include "clang/AST/OperationKinds.h" +@@ -7358,14 +7357,6 @@ + /// + /// Returns true if the value evaluates to null. + static bool CheckNonNullExpr(Sema &S, const Expr *Expr) { +- // Treat (smart) pointers constructed from nullptr as null, whether we can +- // const-evaluate them or not. +- // This must happen first: the smart pointer expr might have _Nonnull type! +- if (isa( +- IgnoreExprNodes(Expr, IgnoreImplicitAsWrittenSingleStep, +- IgnoreElidableImplicitConstructorSingleStep))) +- return true; +- + // If the expression has non-null type, it doesn't evaluate to null. 
+ if (auto nullability = Expr->IgnoreImplicit()->getType()->getNullability()) { + if (*nullability == NullabilityKind::NonNull) +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp +--- a/clang/lib/Sema/SemaDeclAttr.cpp ++++ b/clang/lib/Sema/SemaDeclAttr.cpp +@@ -5976,20 +5976,6 @@ + D->addAttr(::new (S.Context) BuiltinAliasAttr(S.Context, AL, Ident)); + } + +-static void handleNullableTypeAttr(Sema &S, Decl *D, const ParsedAttr &AL) { +- if (AL.isUsedAsTypeAttr()) - return; - -- for (Function &F : M.getFunctionList()) { -- if (F.isDeclaration()) -- continue; +- if (auto *CRD = dyn_cast(D); +- !CRD || !(CRD->isClass() || CRD->isStruct())) { +- S.Diag(AL.getRange().getBegin(), diag::err_attribute_wrong_decl_type_str) +- << AL << AL.isRegularKeywordAttribute() << "classes"; +- return; +- } - -- if (!F.hasFnAttribute("sign-return-address")) { -- StringRef SignType = "none"; -- if (isModuleAttributeSet(M, "sign-return-address")) -- SignType = "non-leaf"; +- handleSimpleAttribute(S, D, AL); +-} - -- if (isModuleAttributeSet(M, "sign-return-address-all")) -- SignType = "all"; + static void handlePreferredTypeAttr(Sema &S, Decl *D, const ParsedAttr &AL) { + if (!AL.hasParsedType()) { + S.Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1; +@@ -9959,10 +9945,6 @@ + case ParsedAttr::AT_UsingIfExists: + handleSimpleAttribute(S, D, AL); + break; - -- F.addFnAttr("sign-return-address", SignType); -- } -- CopyModuleAttributeToFunction(F, "branch-target-enforcement"); -- CopyModuleAttributeToFunction(F, "branch-protection-pauth-lr"); -- CopyModuleAttributeToFunction(F, "guarded-control-stack"); -- CopyModuleAttributeToFunction( -- F, "sign-return-address-key", -- std::make_pair("b_key", "a_key")); +- case ParsedAttr::AT_TypeNullable: +- handleNullableTypeAttr(S, D, AL); +- break; + } + } + +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp +--- a/clang/lib/Sema/SemaDecl.cpp ++++ b/clang/lib/Sema/SemaDecl.cpp +@@ -18254,10 +18254,8 @@ + if (PrevDecl) + mergeDeclAttributes(New, PrevDecl); + +- if (auto *CXXRD = dyn_cast(New)) { ++ if (auto *CXXRD = dyn_cast(New)) + inferGslOwnerPointerAttribute(CXXRD); +- inferNullableClassAttribute(CXXRD); - } + + // If there's a #pragma GCC visibility in scope, set the visibility of this + // record. +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp +--- a/clang/lib/Sema/SemaInit.cpp ++++ b/clang/lib/Sema/SemaInit.cpp +@@ -7075,11 +7075,6 @@ + hasCopyOrMoveCtorParam(S.Context, + getConstructorInfo(Step.Function.FoundDecl)); + +- // A smart pointer constructed from a nullable pointer is nullable. +- if (NumArgs == 1 && !Kind.isExplicitCast()) +- S.diagnoseNullableToNonnullConversion( +- Entity.getType(), Args.front()->getType(), Kind.getLocation()); +- + // Determine the arguments required to actually perform the constructor + // call. + if (S.CompleteConstructorCall(Constructor, Step.Type, Args, Loc, +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp +--- a/clang/lib/Sema/SemaOverload.cpp ++++ b/clang/lib/Sema/SemaOverload.cpp +@@ -14797,13 +14797,6 @@ + } + } + +- // Check for nonnull = nullable. +- // This won't be caught in the arg's initialization: the parameter to +- // the assignment operator is not marked nonnull. +- if (Op == OO_Equal) +- diagnoseNullableToNonnullConversion(Args[0]->getType(), +- Args[1]->getType(), OpLoc); +- + // Convert the arguments. 
+ if (CXXMethodDecl *Method = dyn_cast(FnDecl)) { + // Best->Access is only meaningful for class members. +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp +--- a/clang/lib/Sema/SemaTemplate.cpp ++++ b/clang/lib/Sema/SemaTemplate.cpp +@@ -2171,7 +2171,6 @@ + + AddPushedVisibilityAttribute(NewClass); + inferGslOwnerPointerAttribute(NewClass); +- inferNullableClassAttribute(NewClass); + + if (TUK != TUK_Friend) { + // Per C++ [basic.scope.temp]p2, skip the template parameter scopes. +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp +--- a/clang/lib/Sema/SemaType.cpp ++++ b/clang/lib/Sema/SemaType.cpp +@@ -4711,18 +4711,6 @@ + return false; + } + +-// Whether this is a type broadly expected to have nullability attached. +-// These types are affected by `#pragma assume_nonnull`, and missing nullability +-// will be diagnosed with -Wnullability-completeness. +-static bool shouldHaveNullability(QualType T) { +- return T->canHaveNullability(/*ResultIfUnknown=*/false) && +- // For now, do not infer/require nullability on C++ smart pointers. +- // It's unclear whether the pragma's behavior is useful for C++. +- // e.g. treating type-aliases and template-type-parameters differently +- // from types of declarations can be surprising. +- !isa(T); -} - - static bool isOldLoopArgument(Metadata *MD) { - auto *T = dyn_cast_or_null(MD); - if (!T) -diff -ruN --strip-trailing-cr a/llvm/lib/Linker/IRMover.cpp b/llvm/lib/Linker/IRMover.cpp ---- a/llvm/lib/Linker/IRMover.cpp -+++ b/llvm/lib/Linker/IRMover.cpp -@@ -1606,11 +1606,6 @@ - // Loop over all of the linked values to compute type mappings. - computeTypeMapping(); - -- // Convert module level attributes to function level attributes because -- // after merging modules the attributes might change and would have different -- // effect on the functions as the original module would have. -- CopyModuleAttrToFunctions(*SrcM); -- - std::reverse(Worklist.begin(), Worklist.end()); - while (!Worklist.empty()) { - GlobalValue *GV = Worklist.back(); -diff -ruN --strip-trailing-cr a/llvm/test/Linker/link-arm-and-thumb.ll b/llvm/test/Linker/link-arm-and-thumb.ll ---- a/llvm/test/Linker/link-arm-and-thumb.ll -+++ b/llvm/test/Linker/link-arm-and-thumb.ll -@@ -13,12 +13,11 @@ - ret i32 %add + static TypeSourceInfo *GetFullTypeForDeclarator(TypeProcessingState &state, + QualType declSpecType, + TypeSourceInfo *TInfo) { +@@ -4841,7 +4829,8 @@ + // inner pointers. + complainAboutMissingNullability = CAMN_InnerPointers; + +- if (shouldHaveNullability(T) && !T->getNullability()) { ++ if (T->canHaveNullability(/*ResultIfUnknown*/ false) && ++ !T->getNullability()) { + // Note that we allow but don't require nullability on dependent types. + ++NumPointersRemaining; + } +@@ -5064,7 +5053,8 @@ + // If the type itself could have nullability but does not, infer pointer + // nullability and perform consistency checking. + if (S.CodeSynthesisContexts.empty()) { +- if (shouldHaveNullability(T) && !T->getNullability()) { ++ if (T->canHaveNullability(/*ResultIfUnknown*/ false) && ++ !T->getNullability()) { + if (isVaList(T)) { + // Record that we've seen a pointer, but do nothing else. 
+ if (NumPointersRemaining > 0) +diff -ruN --strip-trailing-cr a/clang/test/Sema/nullability.c b/clang/test/Sema/nullability.c +--- a/clang/test/Sema/nullability.c ++++ b/clang/test/Sema/nullability.c +@@ -248,5 +248,3 @@ + void (^withTypedefBad)(INTS _Nonnull [2]) = // expected-error {{nullability specifier '_Nonnull' cannot be applied to non-pointer type 'INTS' (aka 'int[4]')}} + ^(INTS _Nonnull x[2]) {}; // expected-error {{nullability specifier '_Nonnull' cannot be applied to non-pointer type 'INTS' (aka 'int[4]')}} } +- +-struct _Nullable NotCplusplusClass {}; // expected-error {{'_Nullable' attribute only applies to classes}} +diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/nullability.cpp b/clang/test/SemaCXX/nullability.cpp +--- a/clang/test/SemaCXX/nullability.cpp ++++ b/clang/test/SemaCXX/nullability.cpp +@@ -4,10 +4,6 @@ + #else + # error nullability feature should be defined + #endif +-#if __has_feature(nullability_on_classes) +-#else +-# error smart-pointer feature should be defined +-#endif + + #include "nullability-completeness.h" + +@@ -31,7 +27,6 @@ + struct AddNonNull { + typedef _Nonnull T type; // expected-error{{nullability specifier '_Nonnull' cannot be applied to non-pointer type 'int'}} + // expected-error@-1{{nullability specifier '_Nonnull' cannot be applied to non-pointer type 'std::nullptr_t'}} +- // expected-error@-2{{nullability specifier '_Nonnull' cannot be applied to non-pointer type 'NotPtr'}} + }; --; CHECK: define i32 @main() [[MAIN_ATTRS:#[0-9]+]] -+; CHECK: define i32 @main() { - ; CHECK: define i32 @foo(i32 %a, i32 %b) [[ARM_ATTRS:#[0-9]+]] - ; CHECK: define i32 @bar(i32 %a, i32 %b) [[THUMB_ATTRS:#[0-9]+]] - --; CHECK: attributes [[MAIN_ATTRS]] = { {{.*}} } --; CHECK: attributes [[ARM_ATTRS]] = { {{.*}} "target-features"="-thumb-mode" } --; CHECK: attributes [[THUMB_ATTRS]] = { {{.*}} "target-features"="+thumb-mode" } -+; CHECK: attributes [[ARM_ATTRS]] = { "target-features"="-thumb-mode" } -+; CHECK: attributes [[THUMB_ATTRS]] = { "target-features"="+thumb-mode" } - - ; STDERR-NOT: warning: Linking two modules of different target triples: -diff -ruN --strip-trailing-cr a/llvm/test/LTO/AArch64/link-branch-target-enforcement.ll b/llvm/test/LTO/AArch64/link-branch-target-enforcement.ll ---- a/llvm/test/LTO/AArch64/link-branch-target-enforcement.ll -+++ b/llvm/test/LTO/AArch64/link-branch-target-enforcement.ll -@@ -32,7 +32,6 @@ - ; CHECK-DUMP:
: - ; CHECK-DUMP: bl 0x8 - ; CHECK-DUMP: : --; CHECK-DUMP: paciasp - - ; `main` doesn't support BTI while `foo` does, so in the binary - ; we should see only PAC which is supported by both. -diff -ruN --strip-trailing-cr a/llvm/test/LTO/AArch64/link-sign-return-address.ll b/llvm/test/LTO/AArch64/link-sign-return-address.ll ---- a/llvm/test/LTO/AArch64/link-sign-return-address.ll -+++ b/llvm/test/LTO/AArch64/link-sign-return-address.ll -@@ -1,43 +0,0 @@ --; Testcase to check that module with different branch-target-enforcement can --; be mixed. --; --; RUN: llvm-as %s -o %t1.bc --; RUN: llvm-as %p/Inputs/foo.ll -o %t2.bc --; RUN: llvm-lto -exported-symbol main \ --; RUN: -exported-symbol foo \ --; RUN: -filetype=obj \ --; RUN: %t2.bc %t1.bc \ --; RUN: -o %t1.exe 2>&1 --; RUN: llvm-objdump -d %t1.exe | FileCheck --check-prefix=CHECK-DUMP %s --; RUN: llvm-readelf -n %t1.exe | FileCheck --allow-empty --check-prefix=CHECK-PROP %s -- --target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" --target triple = "aarch64-unknown-linux-gnu" -- --declare i32 @foo(); -- --define i32 @main() { --entry: -- %add = call i32 @foo() -- ret i32 %add + typedef AddNonNull::type nonnull_int_ptr_1; +@@ -40,33 +35,6 @@ + + typedef AddNonNull::type nonnull_non_pointer_1; // expected-note{{in instantiation of template class 'AddNonNull' requested here}} + +-// Nullability on C++ class types (smart pointers). +-struct NotPtr{}; +-typedef AddNonNull::type nonnull_non_pointer_2; // expected-note{{in instantiation}} +-struct _Nullable SmartPtr{ +- SmartPtr(); +- SmartPtr(nullptr_t); +- SmartPtr(const SmartPtr&); +- SmartPtr(SmartPtr&&); +- SmartPtr &operator=(const SmartPtr&); +- SmartPtr &operator=(SmartPtr&&); +-}; +-typedef AddNonNull::type nonnull_smart_pointer_1; +-template struct _Nullable SmartPtrTemplate{}; +-typedef AddNonNull>::type nonnull_smart_pointer_2; +-namespace std { inline namespace __1 { +- template class unique_ptr {}; +- template class function; +- template class function {}; +-} } +-typedef AddNonNull>::type nonnull_smart_pointer_3; +-typedef AddNonNull>::type nonnull_smart_pointer_4; +- +-class Derived : public SmartPtr {}; +-Derived _Nullable x; // expected-error {{'_Nullable' cannot be applied}} +-class DerivedPrivate : private SmartPtr {}; +-DerivedPrivate _Nullable y; // expected-error {{'_Nullable' cannot be applied}} +- + // Non-null checking within a template. 
+ template + struct AddNonNull2 { +@@ -86,7 +54,6 @@ + void (X::* accepts_nonnull_3)(_Nonnull int *ptr); + void accepts_nonnull_4(_Nonnull int *ptr); + void (&accepts_nonnull_5)(_Nonnull int *ptr) = accepts_nonnull_4; +-void accepts_nonnull_6(SmartPtr _Nonnull); + + void test_accepts_nonnull_null_pointer_literal(X *x) { + accepts_nonnull_1(0); // expected-warning{{null passed to a callee that requires a non-null argument}} +@@ -94,8 +61,6 @@ + (x->*accepts_nonnull_3)(0); // expected-warning{{null passed to a callee that requires a non-null argument}} + accepts_nonnull_4(0); // expected-warning{{null passed to a callee that requires a non-null argument}} + accepts_nonnull_5(0); // expected-warning{{null passed to a callee that requires a non-null argument}} +- +- accepts_nonnull_6(nullptr); // expected-warning{{null passed to a callee that requires a non-null argument}} + } + + template +@@ -106,7 +71,6 @@ + template void test_accepts_nonnull_null_pointer_literal_template<&accepts_nonnull_4>(); // expected-note{{instantiation of function template specialization}} + + void TakeNonnull(void *_Nonnull); +-void TakeSmartNonnull(SmartPtr _Nonnull); + // Check different forms of assignment to a nonull type from a nullable one. + void AssignAndInitNonNull() { + void *_Nullable nullable; +@@ -117,26 +81,12 @@ + void *_Nonnull nonnull; + nonnull = nullable; // expected-warning{{implicit conversion from nullable pointer 'void * _Nullable' to non-nullable pointer type 'void * _Nonnull'}} + nonnull = {nullable}; // expected-warning{{implicit conversion from nullable pointer 'void * _Nullable' to non-nullable pointer type 'void * _Nonnull'}} ++ + TakeNonnull(nullable); //expected-warning{{implicit conversion from nullable pointer 'void * _Nullable' to non-nullable pointer type 'void * _Nonnull}} + TakeNonnull(nonnull); // OK +- nonnull = (void *_Nonnull)nullable; // explicit cast OK +- +- SmartPtr _Nullable s_nullable; +- SmartPtr _Nonnull s(s_nullable); // expected-warning{{implicit conversion from nullable pointer 'SmartPtr _Nullable' to non-nullable pointer type 'SmartPtr _Nonnull'}} +- SmartPtr _Nonnull s2{s_nullable}; // expected-warning{{implicit conversion from nullable pointer 'SmartPtr _Nullable' to non-nullable pointer type 'SmartPtr _Nonnull'}} +- SmartPtr _Nonnull s3 = {s_nullable}; // expected-warning{{implicit conversion from nullable pointer 'SmartPtr _Nullable' to non-nullable pointer type 'SmartPtr _Nonnull'}} +- SmartPtr _Nonnull s4 = s_nullable; // expected-warning{{implicit conversion from nullable pointer 'SmartPtr _Nullable' to non-nullable pointer type 'SmartPtr _Nonnull'}} +- SmartPtr _Nonnull s_nonnull; +- s_nonnull = s_nullable; // expected-warning{{implicit conversion from nullable pointer 'SmartPtr _Nullable' to non-nullable pointer type 'SmartPtr _Nonnull'}} +- s_nonnull = {s_nullable}; // no warning here - might be nice? 
+- TakeSmartNonnull(s_nullable); //expected-warning{{implicit conversion from nullable pointer 'SmartPtr _Nullable' to non-nullable pointer type 'SmartPtr _Nonnull}} +- TakeSmartNonnull(s_nonnull); // OK +- s_nonnull = (SmartPtr _Nonnull)s_nullable; // explicit cast OK +- s_nonnull = static_cast(s_nullable); // explicit cast OK + } + + void *_Nullable ReturnNullable(); +-SmartPtr _Nullable ReturnSmartNullable(); + + void AssignAndInitNonNullFromFn() { + void *_Nonnull p(ReturnNullable()); // expected-warning{{implicit conversion from nullable pointer 'void * _Nullable' to non-nullable pointer type 'void * _Nonnull'}} +@@ -146,16 +96,8 @@ + void *_Nonnull nonnull; + nonnull = ReturnNullable(); // expected-warning{{implicit conversion from nullable pointer 'void * _Nullable' to non-nullable pointer type 'void * _Nonnull'}} + nonnull = {ReturnNullable()}; // expected-warning{{implicit conversion from nullable pointer 'void * _Nullable' to non-nullable pointer type 'void * _Nonnull'}} +- TakeNonnull(ReturnNullable()); //expected-warning{{implicit conversion from nullable pointer 'void * _Nullable' to non-nullable pointer type 'void * _Nonnull}} + +- SmartPtr _Nonnull s(ReturnSmartNullable()); // expected-warning{{implicit conversion from nullable pointer 'SmartPtr _Nullable' to non-nullable pointer type 'SmartPtr _Nonnull'}} +- SmartPtr _Nonnull s2{ReturnSmartNullable()}; // expected-warning{{implicit conversion from nullable pointer 'SmartPtr _Nullable' to non-nullable pointer type 'SmartPtr _Nonnull'}} +- SmartPtr _Nonnull s3 = {ReturnSmartNullable()}; // expected-warning{{implicit conversion from nullable pointer 'SmartPtr _Nullable' to non-nullable pointer type 'SmartPtr _Nonnull'}} +- SmartPtr _Nonnull s4 = ReturnSmartNullable(); // expected-warning{{implicit conversion from nullable pointer 'SmartPtr _Nullable' to non-nullable pointer type 'SmartPtr _Nonnull'}} +- SmartPtr _Nonnull s_nonnull; +- s_nonnull = ReturnSmartNullable(); // expected-warning{{implicit conversion from nullable pointer 'SmartPtr _Nullable' to non-nullable pointer type 'SmartPtr _Nonnull'}} +- s_nonnull = {ReturnSmartNullable()}; +- TakeSmartNonnull(ReturnSmartNullable()); // expected-warning{{implicit conversion from nullable pointer 'SmartPtr _Nullable' to non-nullable pointer type 'SmartPtr _Nonnull'}} ++ TakeNonnull(ReturnNullable()); //expected-warning{{implicit conversion from nullable pointer 'void * _Nullable' to non-nullable pointer type 'void * _Nonnull}} + } + + void ConditionalExpr(bool c) { +diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp ++++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +@@ -1019,6 +1019,7 @@ + const DataLayout &DL = getDataLayout(); + + // GlobalVariables are always constant pointers themselves. 
++ PointerType *PTy = GVar->getType(); + Type *ETy = GVar->getValueType(); + + if (GVar->hasExternalLinkage()) { +@@ -1026,9 +1027,6 @@ + O << ".visible "; + else + O << ".extern "; +- } else if (GVar->hasCommonLinkage() && +- GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) { +- O << ".common "; + } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() || + GVar->hasAvailableExternallyLinkage() || + GVar->hasCommonLinkage()) { +@@ -1140,7 +1138,7 @@ + } + + O << "."; +- emitPTXAddressSpace(GVar->getAddressSpace(), O); ++ emitPTXAddressSpace(PTy->getAddressSpace(), O); + + if (isManaged(*GVar)) { + if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) { +@@ -1169,8 +1167,8 @@ + // Ptx allows variable initilization only for constant and global state + // spaces. + if (GVar->hasInitializer()) { +- if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) || +- (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) { ++ if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) || ++ (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) { + const Constant *Initializer = GVar->getInitializer(); + // 'undef' is treated as there is no value specified. + if (!Initializer->isNullValue() && !isa(Initializer)) { +@@ -1185,7 +1183,7 @@ + !isa(GVar->getInitializer())) { + report_fatal_error("initial value of '" + GVar->getName() + + "' is not allowed in addrspace(" + +- Twine(GVar->getAddressSpace()) + ")"); ++ Twine(PTy->getAddressSpace()) + ")"); + } + } + } +@@ -1204,8 +1202,8 @@ + ElementSize = DL.getTypeStoreSize(ETy); + // Ptx allows variable initilization only for constant and + // global state spaces. +- if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) || +- (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) && ++ if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) || ++ (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) && + GVar->hasInitializer()) { + const Constant *Initializer = GVar->getInitializer(); + if (!isa(Initializer) && !Initializer->isNullValue()) { +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/common-linkage.ll b/llvm/test/CodeGen/NVPTX/common-linkage.ll +--- a/llvm/test/CodeGen/NVPTX/common-linkage.ll ++++ b/llvm/test/CodeGen/NVPTX/common-linkage.ll +@@ -1,26 +0,0 @@ +-; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s +-; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 | %ptxas-verify %} +- +-; CHECK: .common .global .align 4 .u32 g +-@g = common addrspace(1) global i32 0, align 4 +- +-; CHECK: .weak .const .align 4 .u32 c +-@c = common addrspace(4) global i32 0, align 4 +- +-; CHECK: .weak .shared .align 4 .u32 s +-@s = common addrspace(3) global i32 0, align 4 +- +-define i32 @f1() { +- %1 = load i32, ptr addrspace(1) @g +- ret i32 %1 -} - --!llvm.module.flags = !{!0, !1, !2, !3 } --!0 = !{i32 8, !"branch-target-enforcement", i32 0} --!1 = !{i32 8, !"sign-return-address", i32 0} --!2 = !{i32 8, !"sign-return-address-all", i32 0} --!3 = !{i32 8, !"sign-return-address-with-bkey", i32 0} -- --; CHECK-DUMP: : --; CHECK-DUMP: paciasp --; CHECK-DUMP: mov w0, #0x2a --; CHECK-DUMP: autiasp --; CHECK-DUMP: ret --; CHECK-DUMP:
: --; CHECK-DUMP-NOT: paciasp --; CHECK-DUMP: str x30, --; CHECK-DUMP: bl 0x14 -- --; `main` doesn't support PAC sign-return-address while `foo` does, so in the binary --; we should not see anything. --; CHECK-PROP-NOT: Properties: aarch64 feature: PAC -\ No newline at end of file +-define i32 @f4() { +- %1 = load i32, ptr addrspace(4) @c +- ret i32 %1 +-} +- +-define i32 @f3() { +- %1 = load i32, ptr addrspace(3) @s +- ret i32 %1 +-} +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/weak-global.ll b/llvm/test/CodeGen/NVPTX/weak-global.ll +--- a/llvm/test/CodeGen/NVPTX/weak-global.ll ++++ b/llvm/test/CodeGen/NVPTX/weak-global.ll +@@ -1,7 +1,7 @@ + ; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 | FileCheck %s + ; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 | %ptxas-verify %} + +-; CHECK: .common .global .align 4 .u32 g ++; CHECK: .weak .global .align 4 .u32 g + @g = common addrspace(1) global i32 zeroinitializer + + define i32 @func0() { diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 772469ed4698c1..c190989fc46286 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "3b5e7c83a6e226d5bd7ed2e9b67449b64812074c" - LLVM_SHA256 = "7fa7a38aade8b5fa2f7719cd3b6e2f038fed1b00d7369cdb05b490085de79c91" + LLVM_COMMIT = "a4ca07f13b560b4f6fa5459eef7159e4f9ee9a6b" + LLVM_SHA256 = "fb936389d46b3ce7ee423c0d788e5359da8ce41cfe8996847719920c6f60b044" tf_http_archive( name = name, diff --git a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc index 6cc266d9d8d3e8..1bcab94dae3df8 100644 --- a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc +++ b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc @@ -179,13 +179,13 @@ static Status CreateHloXlaPipeline( } pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createCanonicalizerPass()); - mlir::bufferization::BufferResultsToOutParamsOptions out_params_options; - out_params_options.filterFn = [](mlir::func::FuncOp* func) { + mlir::bufferization::BufferResultsToOutParamsOpts out_params_opts; + out_params_opts.filterFn = [](mlir::func::FuncOp* func) { // Only transform the entry point. 
return func->getSymName() == "main"; }; - pm.addPass(mlir::bufferization::createBufferResultsToOutParamsPass( - out_params_options)); + pm.addPass( + mlir::bufferization::createBufferResultsToOutParamsPass(out_params_opts)); pm.addNestedPass( mlir::bufferization::createPromoteBuffersToStackPass(nullptr)); From 33a86ec0896b35d1def2c53576ad9cc016796f34 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sat, 16 Mar 2024 15:03:28 -0700 Subject: [PATCH 016/670] [xla:gpu][NFC] Add test for multiple sliced operands for AddressComputationThunk PiperOrigin-RevId: 616467493 --- .../runtime/address_computation_thunk_test.cc | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc index 60c2f808677324..c7a8c6b88a7653 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc @@ -310,6 +310,143 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) { ASSERT_FALSE(thunk.ExecuteOnStream(params).ok()); } +TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) { + se::StreamExecutor* executor = GpuExecutor(); + + se::Stream stream(executor); + TF_ASSERT_OK(stream.Initialize()); + + int64_t length = sizeof(float) * 2 * 4; + int64_t out_length = sizeof(float) * 1; + int64_t offset_length = sizeof(int64_t) * 2; + int64_t slice_length = sizeof(float) * 3; + + // Step 1: + // Prepare embedded and address computation thunks. + + // Preparing buffer allocation slices for thunk creations. + BufferAllocation alloc_lhs(/*index=*/0, length, /*color=*/0); + BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, length); + + BufferAllocation alloc_rhs(/*index=*/1, length, /*color=*/0); + BufferAllocation::Slice slice_rhs(&alloc_rhs, 0, length); + + BufferAllocation alloc_out(/*index=*/2, out_length, /*color=*/0); + BufferAllocation::Slice slice_out(&alloc_out, 0, out_length); + + BufferAllocation alloc_workspace(/*index=*/3, 1024 * 1024, /*color=*/0); + BufferAllocation::Slice slice_workspace(&alloc_workspace, 0, 1024 * 1024); + + BufferAllocation alloc_lhs_offset(/*index=*/4, offset_length, /*color=*/0); + BufferAllocation::Slice slice_lhs_offset(&alloc_lhs_offset, 0, offset_length); + + BufferAllocation alloc_rhs_offset(/*index=*/5, offset_length, /*color=*/0); + BufferAllocation::Slice slice_rhs_offset(&alloc_rhs_offset, 0, offset_length); + + BufferAllocation alloc_lhs_fake(/*index=*/0, slice_length, /*color=*/0); + BufferAllocation::Slice slice_lhs_fake(&alloc_lhs_fake, 0, slice_length); + + BufferAllocation alloc_rhs_fake(/*index=*/1, slice_length, /*color=*/0); + BufferAllocation::Slice slice_rhs_fake(&alloc_rhs_fake, 0, slice_length); + + // Preparing config for GEMM thunk. + auto config = + GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, + ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0}, + ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0, + 0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt, + se::blas::kDefaultComputePrecision, false, false); + ASSERT_TRUE(config.ok()); + + // Creating embedded GEMM thunk. + ThunkSequence seq; + seq.emplace_back(std::make_unique( + Thunk::ThunkInfo(nullptr), config.value(), slice_lhs_fake, slice_rhs_fake, + slice_out, slice_workspace, /*deterministic=*/true)); + + // Wrapping address computation thunk around the GEMM thunk. 
+ AddressComputationThunk thunk( + Thunk::ThunkInfo(nullptr), + std::make_unique(std::move(seq)), {slice_lhs, slice_rhs}, + {slice_out, slice_workspace}, {slice_lhs_offset, slice_rhs_offset}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), + ShapeUtil::MakeShape(PrimitiveType::F32, {8, 1})}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), + ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1})}); + + // Step 2: + // Execute address computation thunk. + // + // Given a `lhs` tensor of shape f32[2,4]{1,0} + // The `lhs` slice that we want to use will be equivalent to this static + // slice op: + // f32[1,3]{1,0} slice(lhs), slice={[0:1], [1:4]} + + // Preparing memory for thunk arguments. + // lhs = [1.0, 2.0, 3.0, 4.0, + // 5.0, 6.0, 7.0, 8.0] + std::vector arr{1, 2, 3, 4, 5, 6, 7, 8}; + se::DeviceMemory lhs = executor->AllocateArray(2 * 4); + TF_ASSERT_OK(stream.Memcpy(&lhs, arr.data(), length)); + + // Given a `rhs` tensor of shape f32[8,1]{1,0} + // The `rhs` slice that we want to use will be equivalent to this static + // slice op: + // f32[3,1]{1,0} slice(rhs), slice={[2:5], [0:1]} + // rhs = [1.0, + // 2.0, + // 3.0, + // 4.0, + // 5.0, + // 6.0, + // 7.0, + // 8.0] + se::DeviceMemory rhs = executor->AllocateArray(8); + std::vector rhs_arr(8, 1); + TF_ASSERT_OK(stream.Memcpy(&rhs, arr.data(), length)); + + se::DeviceMemory out = executor->AllocateArray(1); + TF_ASSERT_OK(stream.MemZero(&out, out_length)); + + se::DeviceMemory workspace = + executor->AllocateArray(1024 * 1024); + TF_ASSERT_OK(stream.MemZero(&workspace, 1024 * 1024)); + + se::DeviceMemory lhs_offset = executor->AllocateArray(2); + std::vector lhs_offset_arr{0, 1}; + TF_ASSERT_OK( + stream.Memcpy(&lhs_offset, lhs_offset_arr.data(), offset_length)); + + se::DeviceMemory rhs_offset = executor->AllocateArray(2); + std::vector rhs_offset_arr{2, 0}; + TF_ASSERT_OK( + stream.Memcpy(&rhs_offset, rhs_offset_arr.data(), offset_length)); + + // Preparing parameters for thunk execution. + ServiceExecutableRunOptions run_options; + BufferAllocations allocations( + {lhs, rhs, out, workspace, lhs_offset, rhs_offset}, 0, + executor->GetAllocator()); + + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + run_options, allocations, &stream, &stream, {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK( + thunk.Initialize({executor, source, &allocations, &stream, &stream})); + + // Execute address computation thunk and verify that it executed a GEMM on the + // right slices. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + // Copy `out` data back to host for verification. 
+ std::vector dst(1, 0); + TF_ASSERT_OK(stream.Memcpy(dst.data(), out, out_length)); + + ASSERT_EQ(dst, std::vector({2 * 3 + 3 * 4 + 4 * 5})); +} + static absl::Status Memcpy(se::Stream* stream, ffi::BufferBase src, ffi::BufferBase dst) { return stream->MemcpyD2D( From 1283253eb46b9e1caa2c1caa4316dc545272691c Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sat, 16 Mar 2024 16:05:15 -0700 Subject: [PATCH 017/670] [xla:gpu] Add support for sliced results in AddressComputationThunk PiperOrigin-RevId: 616475170 --- .../gpu/runtime/address_computation_thunk.cc | 167 ++++++++++++++---- .../gpu/runtime/address_computation_thunk.h | 26 ++- .../runtime/address_computation_thunk_test.cc | 151 +++++++++++++++- 3 files changed, 293 insertions(+), 51 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc index 07ae9ac30f67c8..8affba065d2d78 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc @@ -47,32 +47,55 @@ AddressComputationThunk::AddressComputationThunk( std::vector> operands, std::vector> results, std::vector> - offset_buffer_indices, - std::vector> orig_shapes, - std::vector> sliced_shapes) + operand_offset_buffer_indices, + std::vector> operand_orig_shapes, + std::vector> operand_sliced_shapes, + std::vector> + result_offset_buffer_indices, + std::vector> result_orig_shapes, + std::vector> result_sliced_shapes) : Thunk(Kind::kAddressComputation, thunk_info), embedded_thunk_(std::make_unique( ThunkInfo(thunk_info.op), std::move(*embedded_thunk))), embedded_thunk_operands_(std::move(operands)), embedded_thunk_results_(std::move(results)), - offset_buffer_indices_(std::move(offset_buffer_indices)), - orig_shapes_(std::move(orig_shapes)), - sliced_shapes_(std::move(sliced_shapes)) {} + operand_offset_buffer_indices_(std::move(operand_offset_buffer_indices)), + operand_orig_shapes_(std::move(operand_orig_shapes)), + operand_sliced_shapes_(std::move(operand_sliced_shapes)), + result_offset_buffer_indices_(std::move(result_offset_buffer_indices)), + result_orig_shapes_(std::move(result_orig_shapes)), + result_sliced_shapes_(std::move(result_sliced_shapes)) {} absl::Status AddressComputationThunk::Prepare( const PrepareParams& params, ResourceRequests& resource_requests) { auto num_operands = embedded_thunk_operands_.size(); - TF_RET_CHECK(num_operands == offset_buffer_indices_.size()); - TF_RET_CHECK(num_operands == orig_shapes_.size()); - TF_RET_CHECK(num_operands == sliced_shapes_.size()); + TF_RET_CHECK(num_operands == operand_offset_buffer_indices_.size()); + TF_RET_CHECK(num_operands == operand_orig_shapes_.size()); + TF_RET_CHECK(num_operands == operand_sliced_shapes_.size()); for (unsigned i = 0; i < num_operands; ++i) { - if (sliced_shapes_[i].has_value()) { + if (operand_sliced_shapes_[i].has_value()) { TF_RET_CHECK(embedded_thunk_operands_[i].has_value()); - TF_RET_CHECK(offset_buffer_indices_[i].has_value()); - TF_RET_CHECK(sliced_shapes_[i]->IsArray()); - TF_RET_CHECK(orig_shapes_[i].has_value() && orig_shapes_[i]->IsArray()); + TF_RET_CHECK(operand_offset_buffer_indices_[i].has_value()); + TF_RET_CHECK(operand_sliced_shapes_[i]->IsArray()); + TF_RET_CHECK(operand_orig_shapes_[i].has_value() && + operand_orig_shapes_[i]->IsArray()); + } + } + + auto num_results = embedded_thunk_results_.size(); + TF_RET_CHECK(num_results == result_offset_buffer_indices_.size()); + 
TF_RET_CHECK(num_results == result_orig_shapes_.size()); + TF_RET_CHECK(num_results == result_sliced_shapes_.size()); + for (unsigned i = 0; i < num_results; ++i) { + if (result_sliced_shapes_[i].has_value()) { + TF_RET_CHECK(embedded_thunk_results_[i].has_value()); + TF_RET_CHECK(result_offset_buffer_indices_[i].has_value()); + TF_RET_CHECK(result_sliced_shapes_[i]->IsArray()); + TF_RET_CHECK(result_orig_shapes_[i].has_value() && + result_orig_shapes_[i]->IsArray()); } } + TF_RETURN_IF_ERROR(embedded_thunk_->Prepare(params, resource_requests)); return absl::OkStatus(); } @@ -81,16 +104,38 @@ absl::Status AddressComputationThunk::Initialize( const InitializeParams& params) { TF_RETURN_IF_ERROR(embedded_thunk_->Initialize(params)); - unsigned num_offsets = 0; - for (auto maybe_shape : sliced_shapes_) { - num_offsets += (maybe_shape == std::nullopt) ? 1 : maybe_shape->rank(); + unsigned operand_offset_count = 0; + for (auto maybe_shape : operand_sliced_shapes_) { + operand_offset_count += + (maybe_shape == std::nullopt) ? 1 : maybe_shape->rank(); } - absl::MutexLock lock(&mutex_); - if (auto it = offsets_.find(params.executor); it == offsets_.end()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr allocation, - params.executor->HostMemoryAllocate(num_offsets * sizeof(int64_t))); - offsets_.emplace(params.executor, std::move(allocation)); + + { + absl::MutexLock lock(&mutex_); + if (auto it = operand_offsets_.find(params.executor); + it == operand_offsets_.end()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr allocation, + params.executor->HostMemoryAllocate( + operand_offset_count * sizeof(int64_t))); + operand_offsets_.emplace(params.executor, std::move(allocation)); + } + } + + unsigned result_offset_count = 0; + for (auto maybe_shape : result_sliced_shapes_) { + result_offset_count += + (maybe_shape == std::nullopt) ? 1 : maybe_shape->rank(); + } + + { + absl::MutexLock lock(&mutex_); + if (auto it = result_offsets_.find(params.executor); + it == result_offsets_.end()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr allocation, + params.executor->HostMemoryAllocate( + result_offset_count * sizeof(int64_t))); + result_offsets_.emplace(params.executor, std::move(allocation)); + } } return absl::OkStatus(); @@ -99,16 +144,17 @@ absl::Status AddressComputationThunk::Initialize( absl::Status AddressComputationThunk::ExecuteOnStream( const ExecuteParams& params) { auto& stream = *params.stream; + std::vector new_buffers; + const BufferAllocations& orig_allocations = *params.buffer_allocations; - // Get memory allocation for copying offsets from device. - int64_t* offsets_base = [&] { + // Get memory allocation for copying operand offsets from device. 
+ int64_t* operand_offsets_base = [&] { absl::MutexLock lock(&mutex_); - return reinterpret_cast(offsets_.at(stream.parent())->opaque()); + return reinterpret_cast( + operand_offsets_.at(stream.parent())->opaque()); }(); - std::vector new_buffers; - const BufferAllocations& orig_allocations = *params.buffer_allocations; - for (unsigned i = 0; i < offset_buffer_indices_.size(); ++i) { + for (unsigned i = 0; i < operand_offset_buffer_indices_.size(); ++i) { if (embedded_thunk_operands_[i] == std::nullopt) { new_buffers.push_back(se::DeviceMemoryBase()); continue; @@ -116,18 +162,18 @@ absl::Status AddressComputationThunk::ExecuteOnStream( se::DeviceMemoryBase orig_operand = orig_allocations.GetDeviceAddress(*embedded_thunk_operands_[i]); - if (offset_buffer_indices_[i] == std::nullopt) { + if (operand_offset_buffer_indices_[i] == std::nullopt) { new_buffers.push_back(orig_operand); continue; } se::DeviceMemoryBase offset_src = - orig_allocations.GetDeviceAddress(*offset_buffer_indices_[i]); + orig_allocations.GetDeviceAddress(*operand_offset_buffer_indices_[i]); // Copy the ith offset from device to host. - const Shape& src_shape = *orig_shapes_[i]; - const Shape& dst_shape = *sliced_shapes_[i]; - int64_t* offset_dst = &offsets_base[i]; + const Shape& src_shape = *operand_orig_shapes_[i]; + const Shape& dst_shape = *operand_sliced_shapes_[i]; + int64_t* offset_dst = &operand_offsets_base[i]; TF_RETURN_IF_ERROR(stream.Memcpy(offset_dst, offset_src, dst_shape.rank() * sizeof(int64_t))); @@ -155,15 +201,58 @@ absl::Status AddressComputationThunk::ExecuteOnStream( new_buffers.push_back(orig_operand.GetByteSlice(new_offset, new_size)); } - // TODO(vuson): handle DUS too. For now just copy the results over. - for (auto result : embedded_thunk_results_) { - if (result == std::nullopt) { + // Get memory allocation for copying result offsets from device. + int64_t* result_offsets_base = [&] { + absl::MutexLock lock(&mutex_); + return reinterpret_cast( + result_offsets_.at(stream.parent())->opaque()); + }(); + + for (unsigned i = 0; i < result_offset_buffer_indices_.size(); ++i) { + if (embedded_thunk_results_[i] == std::nullopt) { new_buffers.push_back(se::DeviceMemoryBase()); - } else { - se::DeviceMemoryBase orig_result = - orig_allocations.GetDeviceAddress(*result); + continue; + } + + se::DeviceMemoryBase orig_result = + orig_allocations.GetDeviceAddress(*embedded_thunk_results_[i]); + if (result_offset_buffer_indices_[i] == std::nullopt) { new_buffers.push_back(orig_result); + continue; + } + + se::DeviceMemoryBase offset_src = + orig_allocations.GetDeviceAddress(*result_offset_buffer_indices_[i]); + + // Copy the ith offset from device to host. + const Shape& src_shape = *result_orig_shapes_[i]; + const Shape& dst_shape = *result_sliced_shapes_[i]; + int64_t* offset_dst = &result_offsets_base[i]; + TF_RETURN_IF_ERROR(stream.Memcpy(offset_dst, offset_src, + dst_shape.rank() * sizeof(int64_t))); + + if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { + return absl::InternalError(absl::StrFormat( + "Failed to retrieve all slice offset values on stream %p: %s", + &stream, blocked.message())); + } + + // Compute new slice. No need to copy the content to new buffers as we can + // reuse the original buffers since slices are contiguous. 
+ TF_RET_CHECK(IsContiguousSlice(src_shape, dst_shape)); + + int64_t new_size = ShapeUtil::ByteSizeOf(dst_shape); + BufferAllocation::Slice orig_slice = *embedded_thunk_results_[i]; + + int64_t new_offset = orig_slice.offset(); + std::vector slice_starts(offset_dst, + offset_dst + dst_shape.rank()); + for (auto [start, stride] : + llvm::zip(slice_starts, *ShapeUtil::ByteStrides(src_shape))) { + new_offset += start * stride; } + + new_buffers.push_back(orig_result.GetByteSlice(new_offset, new_size)); } // Safe to create a local BufferAllocations here since buffers are only slices diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h index abb6d89ed1f59c..d4bdbfe287d9b1 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h @@ -47,9 +47,13 @@ class AddressComputationThunk : public Thunk { std::vector> operands, std::vector> results, std::vector> - offset_buffer_indices, - std::vector> orig_shapes, - std::vector> sliced_shapes); + operand_offset_buffer_indices, + std::vector> operand_orig_shapes, + std::vector> operand_sliced_shapes, + std::vector> + result_offset_buffer_indices, + std::vector> result_orig_shapes, + std::vector> result_sliced_shapes); AddressComputationThunk(const AddressComputationThunk&) = delete; AddressComputationThunk& operator=(const AddressComputationThunk&) = delete; @@ -66,16 +70,22 @@ class AddressComputationThunk : public Thunk { std::vector> embedded_thunk_results_; std::vector> - offset_buffer_indices_; - - std::vector> orig_shapes_; - std::vector> sliced_shapes_; + operand_offset_buffer_indices_; + std::vector> operand_orig_shapes_; + std::vector> operand_sliced_shapes_; + std::vector> + result_offset_buffer_indices_; + std::vector> result_orig_shapes_; + std::vector> result_sliced_shapes_; // Pinned host memory for transferring offset values from device to host. absl::Mutex mutex_; absl::flat_hash_map> - offsets_ ABSL_GUARDED_BY(mutex_); + operand_offsets_ ABSL_GUARDED_BY(mutex_); + absl::flat_hash_map> + result_offsets_ ABSL_GUARDED_BY(mutex_); }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc index c7a8c6b88a7653..e783cdea0ba6a3 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc @@ -124,7 +124,9 @@ TEST(AddressComputationThunkTest, SlicedGemm) { std::make_unique(std::move(seq)), {slice_lhs, slice_rhs}, {slice_out, slice_workspace}, {slice_lhs_offset, std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt}, - {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt}); + {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt}, + {std::nullopt, std::nullopt}, {std::nullopt, std::nullopt}, + {std::nullopt, std::nullopt}); // Step 2: // Execute address computation thunk. 
@@ -246,7 +248,9 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) { {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), ShapeUtil::MakeShape(PrimitiveType::F32, {4, 3})}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}), - ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2})}); + ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2})}, + {std::nullopt, std::nullopt}, {std::nullopt, std::nullopt}, + {std::nullopt, std::nullopt}); // Step 2: // Execute address computation thunk. @@ -372,7 +376,9 @@ TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) { {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), ShapeUtil::MakeShape(PrimitiveType::F32, {8, 1})}, {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), - ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1})}); + ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1})}, + {std::nullopt, std::nullopt}, {std::nullopt, std::nullopt}, + {std::nullopt, std::nullopt}); // Step 2: // Execute address computation thunk. @@ -520,7 +526,8 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { {slice_offset}, {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 8})}, // Make sure to pass a dst shape with the same rank as src shape (i.e. // original slice result and not bitcasted one) - {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 8, 8})}); + {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 8, 8})}, {std::nullopt}, + {std::nullopt}, {std::nullopt}); // Step 2: // Execute address computation thunk. @@ -573,4 +580,140 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { ASSERT_EQ(out, ref); } +TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { + se::StreamExecutor* executor = GpuExecutor(); + + se::Stream stream(executor); + TF_ASSERT_OK(stream.Initialize()); + + int64_t src_count = 8 * 8 * 10 * 2; + int64_t dst_count = 2 * 2 * 2 * 2; + int64_t slice_count = 2 * 2; + int64_t src_length = sizeof(int32_t) * src_count; + int64_t dst_length = sizeof(int32_t) * dst_count; + int64_t offset_length = sizeof(int64_t) * 4; + int64_t slice_length = sizeof(int32_t) * slice_count; + + // Step 1: + // Prepare embedded and address computation thunks. + + // Preparing buffer allocation slices for thunk creations. + BufferAllocation alloc_src(/*index=*/0, src_length, /*color=*/0); + BufferAllocation::Slice slice_src(&alloc_src, 0, src_length); + + BufferAllocation alloc_dst(/*index=*/1, dst_length, /*color=*/0); + BufferAllocation::Slice slice_dst(&alloc_dst, 0, dst_length); + + BufferAllocation alloc_src_offset(/*index=*/2, offset_length, /*color=*/0); + BufferAllocation::Slice slice_src_offset(&alloc_src_offset, 0, offset_length); + + BufferAllocation alloc_dst_offset(/*index=*/3, offset_length, /*color=*/0); + BufferAllocation::Slice slice_dst_offset(&alloc_dst_offset, 0, offset_length); + + // Fake slices for embedded thunk creation. + BufferAllocation alloc_src_fake(/*index=*/0, slice_length, /*color=*/0); + BufferAllocation::Slice slice_src_fake(&alloc_src_fake, 0, slice_length); + + BufferAllocation alloc_dst_fake(/*index=*/1, slice_length, /*color=*/0); + BufferAllocation::Slice slice_dst_fake(&alloc_dst_fake, 0, slice_length); + + // Preparing custom call thunk: setting up call target and operands + results + // buffers. 
+ auto handler = xla::ffi::FindHandler("__xla_test$$memcpy", PLATFORM); + ASSERT_TRUE(handler.ok()); + + std::vector> operands{ + CustomCallThunk::Slice{slice_src_fake, + ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2})}}; + std::vector> results{ + CustomCallThunk::Slice{slice_dst_fake, + ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2})}}; + + // Creating embedded custom call thunk. + ThunkSequence seq; + seq.emplace_back(std::make_unique( + Thunk::ThunkInfo(nullptr), *handler, operands, results, + /*attributes=*/CustomCallThunk::AttributesMap(), + /*called_computation=*/nullptr)); + + // Wrapping address computation thunk around the custom call thunk. + AddressComputationThunk thunk( + Thunk::ThunkInfo(nullptr), + std::make_unique(std::move(seq)), {slice_src}, {slice_dst}, + {slice_src_offset}, + {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 2})}, + // Make sure to pass a dst shape with the same rank as src shape (i.e. + // original slice result and not bitcasted one) + {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2})}, + {slice_dst_offset}, + {{ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2, 2, 2})}}, + {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2})}); + + // Step 2: + // Execute address computation thunk. + // + // Given a `src` tensor of shape s32[8,8,10,2]{3,2,1,0} + // The `src` slice that we want to copy from will be equivalent to this static + // slice op: + // s32[1,1,2,2]{3,2,1,0} slice(src), slice={[3:4], [5:6], [2:4], [0:2]} + // + // Given a `dst` tensor of shape s32[2,2,2,2]{3,2,1,0} + // The `dst` slice that we want to copy into will be equivalent to this static + // slice op: + // s32[1,1,2,2]{3,2,1,0} slice(dst), slice={[1:2], [1:2], [0:2], [0:2]} + + // Preparing memory for thunk arguments. + se::DeviceMemory src = executor->AllocateArray(src_count); + std::vector src_arr(src_count, 0); + for (unsigned i = 0; i < src_count; ++i) src_arr[i] = i; + TF_ASSERT_OK(stream.Memcpy(&src, src_arr.data(), src_length)); + + se::DeviceMemory dst = executor->AllocateArray(dst_count); + TF_ASSERT_OK(stream.MemZero(&dst, dst_length)); + + se::DeviceMemory src_offset = executor->AllocateArray(4); + std::vector src_offset_arr{3, 5, 2, 0}; + TF_ASSERT_OK( + stream.Memcpy(&src_offset, src_offset_arr.data(), offset_length)); + + se::DeviceMemory dst_offset = executor->AllocateArray(4); + std::vector dst_offset_arr{1, 1, 0, 0}; + TF_ASSERT_OK( + stream.Memcpy(&dst_offset, dst_offset_arr.data(), offset_length)); + + // Preparing parameters for thunk execution. + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({src, dst, src_offset, dst_offset}, 0, + executor->GetAllocator()); + + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + run_options, allocations, &stream, &stream, {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK( + thunk.Initialize({executor, source, &allocations, &stream, &stream})); + + // Executing address computation thunk. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + // Copying `dst` data back to host for verification. + std::vector out(dst_count, 0); + TF_ASSERT_OK(stream.Memcpy(out.data(), dst, dst_length)); + + // Verifying that the right slice of `src` was copied to `dst`. 
+ std::vector ref(dst_count, 0); + int64_t src_offset_val = + src_offset_arr[3] + + 2 * (src_offset_arr[2] + + 10 * (src_offset_arr[1] + 8 * src_offset_arr[0])); + int64_t dst_offset_val = + dst_offset_arr[3] + + 2 * (dst_offset_arr[2] + 2 * (dst_offset_arr[1] + 2 * dst_offset_arr[0])); + std::copy(src_arr.begin() + src_offset_val, + src_arr.begin() + src_offset_val + slice_count, + ref.begin() + dst_offset_val); + ASSERT_EQ(out, ref); +} + } // namespace xla::gpu From de598b814a737e17292c99fe8036c33d2f170141 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 17 Mar 2024 02:02:06 -0700 Subject: [PATCH 018/670] compat: Update forward compatibility horizon to 2024-03-17 PiperOrigin-RevId: 616552576 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 5a949daa30884e..382dfdf2eb7712 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 16) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 17) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 042a7c1bbcc7636ee7ea6d9469a061296c1ddf97 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 17 Mar 2024 02:02:13 -0700 Subject: [PATCH 019/670] Update GraphDef version to 1804. PiperOrigin-RevId: 616552596 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 8fa4e8122aab1e..c96dd2e0380234 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1803 // Updated: 2024/3/16 +#define TF_GRAPH_DEF_VERSION 1804 // Updated: 2024/3/17 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 229b345cbe5afd07c55c51a5a5325ce9016200d5 Mon Sep 17 00:00:00 2001 From: "Jae H. 
Yoo" Date: Sun, 17 Mar 2024 17:13:00 -0700 Subject: [PATCH 020/670] Add BFLOAT16 to TFLite flatbuffer schema PiperOrigin-RevId: 616668681 --- .../compiler/mlir/lite/utils/convert_type.cc | 8 ++++++++ tensorflow/lite/core/api/BUILD | 3 ++- .../lite/core/api/flatbuffer_conversions.cc | 3 +++ .../lite/core/api/flatbuffer_conversions_test.cc | 8 ++++++++ tensorflow/lite/core/c/c_api_types.h | 1 + tensorflow/lite/core/c/common.cc | 2 ++ tensorflow/lite/core/c/common.h | 7 +++++++ tensorflow/lite/core/c/common_test.cc | 1 + tensorflow/lite/core/tools/verifier.cc | 3 +++ tensorflow/lite/delegates/flex/BUILD | 3 ++- tensorflow/lite/delegates/flex/util.cc | 8 ++++++++ tensorflow/lite/delegates/flex/util_test.cc | 3 +++ .../delegates/gpu/common/model_builder_helper.h | 2 ++ tensorflow/lite/objc/apis/TFLTensor.h | 3 +++ tensorflow/lite/objc/sources/TFLCommonUtil.mm | 2 ++ tensorflow/lite/optional_debug_tools.cc | 2 ++ tensorflow/lite/python/interpreter_wrapper/BUILD | 1 + .../lite/python/interpreter_wrapper/numpy.cc | 8 ++++++-- .../lite/python/optimize/calibration_wrapper.cc | 2 ++ tensorflow/lite/schema/schema.fbs | 1 + tensorflow/lite/schema/schema_generated.h | 15 +++++++++------ .../lite/tools/serialization/enum_mapping.h | 2 ++ tensorflow/lite/util.cc | 3 +++ 23 files changed, 81 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index 9b215e77b89529..e09030ceb7515f 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -34,6 +34,8 @@ namespace errors = tensorflow::errors; tflite::TensorType ConvertTypeToTensorType(mlir::Type type) { if (type.isF16()) { return tflite::TensorType_FLOAT16; + } else if (type.isBF16()) { + return tflite::TensorType_BFLOAT16; } else if (type.isF32()) { return tflite::TensorType_FLOAT32; } else if (type.isF64()) { @@ -81,6 +83,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { switch (type) { case tflite::TensorType_FLOAT16: return builder.getF16Type(); + case tflite::TensorType_BFLOAT16: + return builder.getBF16Type(); case tflite::TensorType_FLOAT32: return builder.getF32Type(); case tflite::TensorType_FLOAT64: @@ -128,6 +132,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) { return tensorflow::DT_COMPLEX128; case tflite::TensorType_FLOAT16: return tensorflow::DT_HALF; + case tflite::TensorType_BFLOAT16: + return tensorflow::DT_BFLOAT16; case tflite::TensorType_FLOAT32: return tensorflow::DT_FLOAT; case tflite::TensorType_FLOAT64: @@ -170,6 +176,8 @@ absl::StatusOr TfTypeToTflType(tensorflow::DataType type) { return tflite::TensorType_COMPLEX128; case tensorflow::DT_HALF: return tflite::TensorType_FLOAT16; + case tensorflow::DT_BFLOAT16: + return tflite::TensorType_BFLOAT16; case tensorflow::DT_FLOAT: return tflite::TensorType_FLOAT32; case tensorflow::DT_DOUBLE: diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD index a0e28f1ccaaf8b..1d6e1ca1eed47a 100644 --- a/tensorflow/lite/core/api/BUILD +++ b/tensorflow/lite/core/api/BUILD @@ -1,6 +1,6 @@ +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/lite:build_def.bzl", "tflite_copts") load("//tensorflow/lite:special_rules.bzl", "op_resolver_internal_visibility_allowlist") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], 
@@ -154,6 +154,7 @@ cc_test( ":api", "//tensorflow/lite:string", "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/core/c:common", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 10feeb3fc2c7dd..d36c2b69f4058a 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -1017,6 +1017,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, case TensorType_FLOAT16: *type = kTfLiteFloat16; return kTfLiteOk; + case TensorType_BFLOAT16: + *type = kTfLiteBFloat16; + return kTfLiteOk; case TensorType_FLOAT32: *type = kTfLiteFloat32; return kTfLiteOk; diff --git a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc index 6e08e6880e5522..87c897dfc0928e 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/builtin_op_data.h" +#include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/string_type.h" @@ -189,6 +190,13 @@ TEST_F(FlatbufferConversionsTest, TestConvertTensorTypeFloat16) { EXPECT_EQ(kTfLiteFloat16, type); } +TEST_F(FlatbufferConversionsTest, TestConvertTensorTypeBFloat16) { + TfLiteType type; + EXPECT_EQ(kTfLiteOk, + ConvertTensorType(TensorType_BFLOAT16, &type, &mock_reporter_)); + EXPECT_EQ(kTfLiteBFloat16, type); +} + TEST_F(FlatbufferConversionsTest, TestConvertTensorTypeInt4) { TfLiteType type; EXPECT_EQ(kTfLiteOk, diff --git a/tensorflow/lite/core/c/c_api_types.h b/tensorflow/lite/core/c/c_api_types.h index 1170025cbab9a2..32cefa839f4452 100644 --- a/tensorflow/lite/core/c/c_api_types.h +++ b/tensorflow/lite/core/c/c_api_types.h @@ -133,6 +133,7 @@ typedef enum { kTfLiteUInt32 = 16, kTfLiteUInt16 = 17, kTfLiteInt4 = 18, + kTfLiteBFloat16 = 19, } TfLiteType; /// Legacy. Will be deprecated in favor of `TfLiteAffineQuantization`. diff --git a/tensorflow/lite/core/c/common.cc b/tensorflow/lite/core/c/common.cc index fd7c415f96e634..7afecdbe885199 100644 --- a/tensorflow/lite/core/c/common.cc +++ b/tensorflow/lite/core/c/common.cc @@ -370,6 +370,8 @@ const char* TfLiteTypeGetName(TfLiteType type) { return "STRING"; case kTfLiteFloat16: return "FLOAT16"; + case kTfLiteBFloat16: + return "BFLOAT16"; case kTfLiteFloat64: return "FLOAT64"; case kTfLiteResource: diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h index 4e4890164d3aa6..9801bde9ddc6ea 100644 --- a/tensorflow/lite/core/c/common.h +++ b/tensorflow/lite/core/c/common.h @@ -296,6 +296,13 @@ typedef struct TfLiteFloat16 { uint16_t data; } TfLiteFloat16; +/// bfloat16 data type compatible with the Google Brain definition. +/// https://cloud.google.com/tpu/docs/bfloat16. +/// This provides 1 bit of sign, 8 bits of exponent, and 7 bits of mantissa. +typedef struct TfLiteBFloat16 { + uint16_t data; +} TfLiteBFloat16; + /// Return the name of a given type, for error reporting purposes. 
const char* TfLiteTypeGetName(TfLiteType type); diff --git a/tensorflow/lite/core/c/common_test.cc b/tensorflow/lite/core/c/common_test.cc index d2bc137378656e..58fd8654d8b171 100644 --- a/tensorflow/lite/core/c/common_test.cc +++ b/tensorflow/lite/core/c/common_test.cc @@ -107,6 +107,7 @@ TEST(Types, TestTypeNames) { EXPECT_EQ(type_name(kTfLiteFloat64), "FLOAT64"); EXPECT_EQ(type_name(kTfLiteFloat32), "FLOAT32"); EXPECT_EQ(type_name(kTfLiteFloat16), "FLOAT16"); + EXPECT_EQ(type_name(kTfLiteBFloat16), "BFLOAT16"); EXPECT_EQ(type_name(kTfLiteInt16), "INT16"); EXPECT_EQ(type_name(kTfLiteUInt16), "UINT16"); EXPECT_EQ(type_name(kTfLiteInt32), "INT32"); diff --git a/tensorflow/lite/core/tools/verifier.cc b/tensorflow/lite/core/tools/verifier.cc index cdf8959d55483f..c878f7c392d14a 100644 --- a/tensorflow/lite/core/tools/verifier.cc +++ b/tensorflow/lite/core/tools/verifier.cc @@ -409,6 +409,9 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer, case TensorType_FLOAT16: bytes_required *= sizeof(uint16_t); break; + case TensorType_BFLOAT16: + bytes_required *= sizeof(uint16_t); + break; case TensorType_FLOAT64: bytes_required *= sizeof(double); break; diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD index 77a16ffa032865..5f126f68124cf8 100644 --- a/tensorflow/lite/delegates/flex/BUILD +++ b/tensorflow/lite/delegates/flex/BUILD @@ -7,10 +7,10 @@ load( "tf_opts_nortti_if_lite_protos", "tf_opts_nortti_if_mobile", ) +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/lite:build_def.bzl", "tflite_copts") load("//tensorflow/lite:special_rules.bzl", "internal_visibility_allowlist") load("//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_cc_library", "tflite_flex_shared_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") default_visibility = [ "//tensorflow/compiler/mlir/lite:__subpackages__", @@ -322,6 +322,7 @@ cc_library( "//tensorflow/lite:kernel_api", "//tensorflow/lite:string_util", "//tensorflow/lite:util", + "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/core/c:common", "//tensorflow/lite/kernels/internal:tensor", "@com_google_absl//absl/strings:str_format", diff --git a/tensorflow/lite/delegates/flex/util.cc b/tensorflow/lite/delegates/flex/util.cc index 8a115a4f33cf64..9940fadb8d7625 100644 --- a/tensorflow/lite/delegates/flex/util.cc +++ b/tensorflow/lite/delegates/flex/util.cc @@ -17,10 +17,12 @@ limitations under the License. 
#include #include "absl/strings/str_format.h" +#include "tensorflow/c/tf_datatype.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/string_util.h" @@ -74,6 +76,8 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) { return TF_FLOAT; case kTfLiteFloat16: return TF_HALF; + case kTfLiteBFloat16: + return TF_BFLOAT16; case kTfLiteFloat64: return TF_DOUBLE; case kTfLiteInt16: @@ -116,6 +120,8 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) { return kTfLiteFloat32; case TF_HALF: return kTfLiteFloat16; + case TF_BFLOAT16: + return kTfLiteBFloat16; case TF_DOUBLE: return kTfLiteFloat64; case TF_INT16: @@ -186,6 +192,8 @@ const char* TfLiteTypeToTfTypeName(TfLiteType type) { return "string"; case kTfLiteFloat16: return "float16"; + case kTfLiteBFloat16: + return "bfloat16"; case kTfLiteFloat64: return "float64"; case kTfLiteResource: diff --git a/tensorflow/lite/delegates/flex/util_test.cc b/tensorflow/lite/delegates/flex/util_test.cc index c7361314aa38f5..7dfea9e6437c9d 100644 --- a/tensorflow/lite/delegates/flex/util_test.cc +++ b/tensorflow/lite/delegates/flex/util_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/c/tf_datatype.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/lite/core/c/c_api_types.h" @@ -118,6 +119,7 @@ TEST(UtilTest, TypeConversionsFromTFLite) { EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType)); EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32)); EXPECT_EQ(TF_HALF, GetTensorFlowDataType(kTfLiteFloat16)); + EXPECT_EQ(TF_BFLOAT16, GetTensorFlowDataType(kTfLiteBFloat16)); EXPECT_EQ(TF_DOUBLE, GetTensorFlowDataType(kTfLiteFloat64)); EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16)); EXPECT_EQ(TF_INT32, GetTensorFlowDataType(kTfLiteInt32)); @@ -136,6 +138,7 @@ TEST(UtilTest, TypeConversionsFromTFLite) { TEST(UtilTest, TypeConversionsFromTensorFlow) { EXPECT_EQ(kTfLiteFloat16, GetTensorFlowLiteType(TF_HALF)); + EXPECT_EQ(kTfLiteBFloat16, GetTensorFlowLiteType(TF_BFLOAT16)); EXPECT_EQ(kTfLiteFloat32, GetTensorFlowLiteType(TF_FLOAT)); EXPECT_EQ(kTfLiteFloat64, GetTensorFlowLiteType(TF_DOUBLE)); EXPECT_EQ(kTfLiteInt16, GetTensorFlowLiteType(TF_INT16)); diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h index 14384ce5be9a1c..27bb621c40dea9 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h @@ -154,6 +154,8 @@ absl::Status CreateVectorCopyData(const TfLiteTensor& src, T* dst) { return absl::OkStatus(); case kTfLiteFloat16: return absl::UnimplementedError("src can't be float16."); + case kTfLiteBFloat16: + return absl::UnimplementedError("src can't be bfloat16."); case kTfLiteFloat64: for (int i = 0; i < n; ++i) { dst[i] = tflite::GetTensorData(&src)[i]; diff --git a/tensorflow/lite/objc/apis/TFLTensor.h b/tensorflow/lite/objc/apis/TFLTensor.h index cd60b2144a0e6c..deaf52f9e5843f 100644 --- a/tensorflow/lite/objc/apis/TFLTensor.h +++ b/tensorflow/lite/objc/apis/TFLTensor.h @@ -52,6 +52,9 @@ typedef NS_ENUM(NSUInteger, TFLTensorDataType) { /** 64-bit double precision floating point. 
*/ TFLTensorDataTypeFloat64, + + /** 16-bit bfloat16 floating point. */ + TFLTensorDataTypeBFloat16, }; /** diff --git a/tensorflow/lite/objc/sources/TFLCommonUtil.mm b/tensorflow/lite/objc/sources/TFLCommonUtil.mm index 57362ceabb6597..8f9e37ebb421b6 100644 --- a/tensorflow/lite/objc/sources/TFLCommonUtil.mm +++ b/tensorflow/lite/objc/sources/TFLCommonUtil.mm @@ -32,6 +32,8 @@ TFLTensorDataType TFLTensorDataTypeFromCTensor(const TfLiteTensor *cTensor) { return TFLTensorDataTypeFloat32; case kTfLiteFloat16: return TFLTensorDataTypeFloat16; + case kTfLiteBFloat16: + return TFLTensorDataTypeBFloat16; case kTfLiteFloat64: return TFLTensorDataTypeFloat64; case kTfLiteInt32: diff --git a/tensorflow/lite/optional_debug_tools.cc b/tensorflow/lite/optional_debug_tools.cc index ce6e9e4973f702..9b716cdffb17c9 100644 --- a/tensorflow/lite/optional_debug_tools.cc +++ b/tensorflow/lite/optional_debug_tools.cc @@ -336,6 +336,8 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteComplex128"; case kTfLiteFloat16: return "kTfLiteFloat16"; + case kTfLiteBFloat16: + return "kTfLiteBFloat16"; case kTfLiteFloat64: return "kTfLiteFloat64"; case kTfLiteResource: diff --git a/tensorflow/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD index fa0af673063325..ed111e41efee0a 100644 --- a/tensorflow/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/lite/python/interpreter_wrapper/BUILD @@ -13,6 +13,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/lite:string_util", + "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/core/c:common", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", # buildcleaner: keep diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.cc b/tensorflow/lite/python/interpreter_wrapper/numpy.cc index 0e07563702fcb0..45146cf88b0616 100644 --- a/tensorflow/lite/python/interpreter_wrapper/numpy.cc +++ b/tensorflow/lite/python/interpreter_wrapper/numpy.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #define TFLITE_IMPORT_NUMPY // See numpy.h for explanation. 
+#include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/python/interpreter_wrapper/numpy.h" -#include - namespace tflite { namespace python { @@ -38,6 +39,9 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { return NPY_FLOAT32; case kTfLiteFloat16: return NPY_FLOAT16; + case kTfLiteBFloat16: + // TODO(b/329491949): NPY_BFLOAT16 currently doesn't exist + return NPY_FLOAT16; case kTfLiteFloat64: return NPY_FLOAT64; case kTfLiteInt32: diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc index ffccf71a40635e..65f5dfe49d51ca 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc @@ -114,6 +114,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { return TensorType_FLOAT32; case kTfLiteFloat16: return TensorType_FLOAT16; + case kTfLiteBFloat16: + return TensorType_BFLOAT16; case kTfLiteFloat64: return TensorType_FLOAT64; case kTfLiteInt32: diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 382462f938d93b..fe9ee4c11cc5c9 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -58,6 +58,7 @@ enum TensorType : byte { UINT32 = 15, UINT16 = 16, INT4 = 17, + BFLOAT16 = 18, } // Custom quantization parameters for experimenting with new quantization diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index b416555e837c3f..79d78c1fc84341 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -700,11 +700,12 @@ enum TensorType : int8_t { TensorType_UINT32 = 15, TensorType_UINT16 = 16, TensorType_INT4 = 17, + TensorType_BFLOAT16 = 18, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_INT4 + TensorType_MAX = TensorType_BFLOAT16 }; -inline const TensorType (&EnumValuesTensorType())[18] { +inline const TensorType (&EnumValuesTensorType())[19] { static const TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, @@ -723,13 +724,14 @@ inline const TensorType (&EnumValuesTensorType())[18] { TensorType_VARIANT, TensorType_UINT32, TensorType_UINT16, - TensorType_INT4 + TensorType_INT4, + TensorType_BFLOAT16 }; return values; } inline const char * const *EnumNamesTensorType() { - static const char * const names[19] = { + static const char * const names[20] = { "FLOAT32", "FLOAT16", "INT32", @@ -748,13 +750,14 @@ inline const char * const *EnumNamesTensorType() { "UINT32", "UINT16", "INT4", + "BFLOAT16", nullptr }; return names; } inline const char *EnumNameTensorType(TensorType e) { - if (::flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_INT4)) return ""; + if (::flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_BFLOAT16)) return ""; const size_t index = static_cast(e); return EnumNamesTensorType()[index]; } diff --git a/tensorflow/lite/tools/serialization/enum_mapping.h b/tensorflow/lite/tools/serialization/enum_mapping.h index 574b1ee3e21cf7..d218b66258581f 100644 --- a/tensorflow/lite/tools/serialization/enum_mapping.h +++ b/tensorflow/lite/tools/serialization/enum_mapping.h @@ -64,6 +64,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { return TensorType_FLOAT32; case kTfLiteFloat16: return TensorType_FLOAT16; + case kTfLiteBFloat16: + return TensorType_BFLOAT16; case kTfLiteFloat64: return TensorType_FLOAT64; case kTfLiteInt32: diff --git a/tensorflow/lite/util.cc b/tensorflow/lite/util.cc index d0d385a310d732..cecda0e5eb44a1 100644 --- a/tensorflow/lite/util.cc +++ b/tensorflow/lite/util.cc @@ -118,6 +118,9 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type, case kTfLiteFloat16: *bytes = sizeof(TfLiteFloat16); break; + case kTfLiteBFloat16: + *bytes = sizeof(TfLiteBFloat16); + break; case kTfLiteFloat64: *bytes = sizeof(double); break; From eac0721765d3f5ee31a9a36baff8ce70352012f4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Mar 2024 02:01:57 -0700 Subject: [PATCH 021/670] compat: Update forward compatibility horizon to 2024-03-18 PiperOrigin-RevId: 616753201 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 382dfdf2eb7712..813819ae0aec8d 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 17) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 18) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 1449a4f07665788459a3cb37fbe4354835e592ef Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Mar 2024 02:02:10 -0700 Subject: [PATCH 022/670] Update GraphDef version to 1805. 
PiperOrigin-RevId: 616753252 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index c96dd2e0380234..b199c37ee80142 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1804 // Updated: 2024/3/17 +#define TF_GRAPH_DEF_VERSION 1805 // Updated: 2024/3/18 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 4c0445cde011df70ea111c09d2fa6a9f412b8f93 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 18 Mar 2024 07:05:12 -0700 Subject: [PATCH 023/670] Fix MOF transpose fusions. The current code attempts to evaluate the epilogue for each transpose, but it needs to be evaluated once for all transposes together. PiperOrigin-RevId: 616815857 --- .../xla/service/gpu/fusions/transpose_mlir.cc | 202 +++++------------- .../xla/service/gpu/fusions/transpose_mlir.h | 15 +- .../service/gpu/model/indexing_analysis.cc | 4 + 3 files changed, 67 insertions(+), 154 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index 9311097e0093f8..8f3f4ef37480b4 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -111,17 +111,16 @@ Tiling ComputeTransposeTiling(const TransposeDescription& tiled_transpose) { } // Returns transpose heroes that should be codegened via shmem. -absl::flat_hash_set GetShMemTranposes( +std::vector GetShMemTransposes( const HloFusionAnalysis& analysis) { - absl::flat_hash_set tranposes_to_tile; + ConstHloInstructionSet transposes_to_tile; for (const auto [hero, root] : llvm::zip(analysis.fusion_heroes(), analysis.fusion_roots())) { - if (!GetDescriptionForTiledTransposeEmitter(*root, *hero)) { - continue; + if (GetDescriptionForTiledTransposeEmitter(*root, *hero)) { + transposes_to_tile.insert(hero); } - tranposes_to_tile.insert(hero); } - return tranposes_to_tile; + return {transposes_to_tile.begin(), transposes_to_tile.end()}; } } // namespace @@ -129,7 +128,7 @@ absl::flat_hash_set GetShMemTranposes( MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) : analysis_(analysis), tiling_(ComputeTransposeTiling(analysis.tiled_transpose())), - shmem_transposes_(GetShMemTranposes(analysis)) { + shmem_transposes_(GetShMemTransposes(analysis)) { for (auto [root, hero] : llvm::zip(analysis_.fusion_roots(), analysis_.fusion_heroes())) { if (auto transpose = GetDescriptionForTiledTransposeEmitter(*root, *hero)) { @@ -143,14 +142,7 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) const HloFusionAnalysis& analysis) { // If there is a hero, which does not have a transpose, the codegen might // fail because of the incorrect thread ID mapping for that particular case. 
- for (const auto [hero, root] : - llvm::zip(analysis.fusion_heroes(), analysis.fusion_roots())) { - if (!GetDescriptionForTiledTransposeEmitter(*root, *hero)) { - return false; - } - } - return mlir_converter::IsHloConversionSupported( - analysis.fusion(), analysis.device_info().gpu_compute_capability()); + return GetShMemTransposes(analysis).size() == analysis.fusion_heroes().size(); } std::optional MlirTransposeFusion::ComputeThreadIdToOutputIndexing( @@ -161,7 +153,11 @@ std::optional MlirTransposeFusion::ComputeThreadIdToOutputIndexing( // Non-transpose roots are elementwise by definition. return ComputeThreadIdToInputIndexing(root_index, 0, ctx); } + return ComputeThreadIdToOutputIndexing(hero, ctx); +} +IndexingMap MlirTransposeFusion::ComputeThreadIdToOutputIndexing( + const HloInstruction& hero, MLIRContext* ctx) const { // The block offsets are permuted, but the thread offsets remain the same. auto block_offset = GetBlockOffsetsForTiling(tiling_, ctx) .getSubMap(std::vector{permutation_.begin(), @@ -180,10 +176,8 @@ std::optional MlirTransposeFusion::ComputeThreadIdToOutputIndexing( return map; } -std::optional MlirTransposeFusion::ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, MLIRContext* ctx) const { - const auto& hero = *analysis_.fusion_heroes()[root_index]; - +IndexingMap MlirTransposeFusion::ComputeThreadIdToInputIndexing( + const HloInstruction& hero, MLIRContext* ctx) const { auto map = ComposeIndexingMaps( GetIndexingMapForTiling(tiling_, ctx), GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), ctx)); @@ -242,32 +236,12 @@ absl::StatusOr> MlirTransposeFusion::EmitWriteToShMemMlir( int num_inputs = fusion.fused_instructions_computation()->num_parameters(); int num_outputs = entry_function.getArguments().size() - num_inputs; - SmallPtrSet emitted_heros; - SmallVector shmem_intermediate_result; - for (const auto& [root_index, hero_and_root] : llvm::enumerate( - llvm::zip(analysis_.fusion_heroes(), analysis_.fusion_roots()))) { - const HloInstruction* transpose = std::get<0>(hero_and_root); - const HloInstruction* root = std::get<1>(hero_and_root); - - // The same hero can occure only multiple (hero, root) pair. We should emit - // the write to shmem only once. - if (!emitted_heros.insert(transpose).second) { - continue; - } - - // Skip non-transpose heroes and handle them in EmitReadFromShMemMlir. - auto description = - GetDescriptionForTiledTransposeEmitter(*root, *transpose); - if (!description.has_value()) { - continue; - } - - auto input_indexing = ComputeThreadIdToInputIndexing( - root_index, /*hero_operand_index=*/0, builder.getContext()); - TF_RET_CHECK(input_indexing) << "Indexing is never nullopt"; + for (auto* transpose : shmem_transposes_) { + auto input_indexing = + ComputeThreadIdToInputIndexing(*transpose, builder.getContext()); IndexingMap shmem_input_indexing = - GetSharedMemoryWriteIndexingMap(*input_indexing, permutation_[2]); + GetSharedMemoryWriteIndexingMap(input_indexing, permutation_[2]); // Allocate shared memory. const HloInstruction* transpose_operand = transpose->operand(0); @@ -278,11 +252,11 @@ absl::StatusOr> MlirTransposeFusion::EmitWriteToShMemMlir( // Emit loop that writes subgraphs of transpose operands to shmem. 
auto shmem_result = EmitThreadLoopNest( - builder, {shmem}, *input_indexing, + builder, {shmem}, input_indexing, [&](ValueRange output_tensors, ValueRange dim_values, ValueRange symbol_values) -> SmallVector { auto input_indices = - ApplyAffineMap(input_indexing->GetAffineMap(), dim_values, + ApplyAffineMap(input_indexing.GetAffineMap(), dim_values, symbol_values, builder); auto shmem_indices = ApplyAffineMap(shmem_input_indexing.GetAffineMap(), dim_values, @@ -313,115 +287,43 @@ absl::Status MlirTransposeFusion::EmitReadFromShMemMlir( const HloFusionInstruction& fusion, const mlir_converter::PartitionedComputations& computations, const CallTargetProvider& call_targets, ValueRange shmem_tensors) const { - SmallVector result_tensors; - int num_inputs = fusion.fused_instructions_computation()->num_parameters(); - SmallPtrSet hero_roots{ - analysis_.fusion_roots().begin(), analysis_.fusion_roots().end()}; - - // Cache for root indexing per hero. If multiple roots use the same hero, they - // will have identical indexing. - absl::flat_hash_map root_to_hero_indexing; - - int transpose_hero_count = 0; - - // Map from hero instruction to shmem tensor value. - absl::flat_hash_map hero_to_shmem_tensor; - ValueRange output_tensor_args = entry_function.getArguments().drop_front(num_inputs); + auto output_indexing = ComputeThreadIdToOutputIndexing( + *shmem_transposes_.front(), builder.getContext()); + auto shmem_output_indexing = + GetSharedMemoryReadIndexingMap(output_indexing, permutation_[2]); + auto epilogue_indexing = ComputeEpilogueInputToOutputIndexing( + shmem_transposes_.front(), builder.getContext()); + auto root_indexing = ComposeIndexingMaps(output_indexing, epilogue_indexing); + auto result_tensors = EmitThreadLoopNest( + builder, output_tensor_args, output_indexing, + [&](ValueRange output_tensors, ValueRange dim_values, + ValueRange symbol_values) -> SmallVector { + auto shmem_indices = + ApplyAffineMap(shmem_output_indexing.GetAffineMap(), dim_values, + symbol_values, builder); + llvm::SmallVector transpose_values; + for (auto shmem : shmem_tensors) { + transpose_values.push_back( + builder.create(shmem, shmem_indices)); + } + auto root_indices = ApplyAffineMap(root_indexing.GetAffineMap(), + dim_values, symbol_values, builder); + auto result_scalars = + EmitEpilogue(computations, entry_function, transpose_values, + root_indices, builder); + SmallVector results; + results.reserve(output_tensor_args.size()); + for (auto [tensor, value] : llvm::zip(output_tensors, result_scalars)) { + results.push_back( + builder.create(value, tensor, root_indices)); + } + return results; + }); - for (const auto& [root_index, hero_and_root] : llvm::enumerate( - llvm::zip(analysis_.fusion_heroes(), analysis_.fusion_roots()))) { - const HloInstruction* transpose = std::get<0>(hero_and_root); - const HloInstruction* root = std::get<1>(hero_and_root); - - auto* mlir_context = builder.getContext(); - auto output_indexing = - ComputeThreadIdToOutputIndexing(root_index, mlir_context); - TF_RET_CHECK(output_indexing) << "Indexing is never nullopt"; - - if (!root_to_hero_indexing.contains(transpose)) { - auto epilogue_indexing = ComputeEpilogueInputToOutputIndexing( - transpose, mlir_context, - /*is_root=*/[&](const HloInstruction* instr) { - return hero_roots.contains(instr); - }); - root_to_hero_indexing.emplace( - transpose, ComposeIndexingMaps(*output_indexing, epilogue_indexing)); - } - - const IndexingMap& root_indexing = root_to_hero_indexing.at(transpose); - - IndexingMap shmem_output_indexing = - 
GetSharedMemoryReadIndexingMap(*output_indexing, permutation_[2]); - auto description = - GetDescriptionForTiledTransposeEmitter(*root, *transpose); - - if (description.has_value()) { - auto subresult_tensors = EmitThreadLoopNest( - builder, output_tensor_args[root_index], *output_indexing, - [&](ValueRange output_tensors, ValueRange dim_values, - ValueRange symbol_values) -> SmallVector { - auto root_indices = - ApplyAffineMap(root_indexing.GetAffineMap(), dim_values, - symbol_values, builder); - auto shmem_indices = - ApplyAffineMap(shmem_output_indexing.GetAffineMap(), dim_values, - symbol_values, builder); - - if (!hero_to_shmem_tensor.contains(transpose)) { - hero_to_shmem_tensor[transpose] = - shmem_tensors[transpose_hero_count]; - ++transpose_hero_count; - } - - mlir::Value value = builder.create( - hero_to_shmem_tensor[transpose], shmem_indices); - auto result_scalars = EmitEpilogue(computations, entry_function, - value, root_indices, builder); - SmallVector results; - results.reserve(output_tensor_args.size()); - for (auto [tensor, value] : - llvm::zip(output_tensors, result_scalars)) { - results.push_back( - builder.create(value, tensor, root_indices)); - } - return results; - }); - result_tensors.append(subresult_tensors.begin(), subresult_tensors.end()); - } else { - auto indexing = ComputeThreadIdToOutputIndexing(0, builder.getContext()); - TF_RET_CHECK(indexing) << "Indexing is never nullopt"; - auto subresult_tensors = EmitThreadLoopNest( - builder, output_tensor_args, *indexing, - [&](ValueRange output_tensors, ValueRange dim_values, - ValueRange symbol_values) -> SmallVector { - auto output_indices = ApplyAffineMap( - indexing->GetAffineMap(), dim_values, symbol_values, builder); - - // Generate the operands for the root function: input tensors + - // output indices. - llvm::SmallVector operands( - entry_function.getArguments().take_front(num_inputs)); - absl::c_copy(output_indices, std::back_inserter(operands)); - - auto result_scalars = - builder.create(call_targets(root), operands); - - SmallVector results; - results.reserve(output_tensor_args.size()); - for (auto [tensor, value] : - llvm::zip(output_tensors, result_scalars.getResults())) { - results.push_back( - builder.create(value, tensor, output_indices)); - } - return results; - }); - result_tensors.append(subresult_tensors.begin(), subresult_tensors.end()); - } - } builder.create(result_tensors); return absl::OkStatus(); } @@ -429,7 +331,7 @@ absl::Status MlirTransposeFusion::EmitReadFromShMemMlir( std::vector MlirTransposeFusion::GetInstructionsWithCustomCodegen( const HloFusionInstruction& fusion) const { - return {shmem_transposes_.begin(), shmem_transposes_.end()}; + return GetShMemTransposes(analysis_); } absl::Status MlirTransposeFusion::EmitEntryFunction( diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h index 3df6073f5d924e..58c8d6265ae838 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h @@ -17,15 +17,14 @@ limitations under the License. 
#include #include +#include -#include "absl/container/flat_hash_set.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/ValueRange.h" // from @llvm-project -#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" @@ -59,9 +58,17 @@ class MlirTransposeFusion : public MlirFusionEmitterBase { std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override; + mlir::MLIRContext* ctx) const override { + return ComputeThreadIdToInputIndexing( + *analysis_.fusion_heroes()[root_index], ctx); + } protected: + IndexingMap ComputeThreadIdToInputIndexing(const HloInstruction& hero, + mlir::MLIRContext* ctx) const; + IndexingMap ComputeThreadIdToOutputIndexing(const HloInstruction& hero, + mlir::MLIRContext* ctx) const; + absl::Status EmitEntryFunction( const mlir_converter::PartitionedComputations& computations, const mlir_converter::CallTargetProvider& call_targets, @@ -87,7 +94,7 @@ class MlirTransposeFusion : public MlirFusionEmitterBase { const HloFusionAnalysis& analysis_; Tiling tiling_; Vector3 permutation_; - absl::flat_hash_set shmem_transposes_; + std::vector shmem_transposes_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc index eca6ccd0dc067b..a6a14c28ca8161 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc @@ -1188,6 +1188,10 @@ HloInstructionIndexing ComputeInputToOutputIndexing(const HloInstruction* instr, if (auto transpose = DynCast(instr)) { return ComputeInputToOutputTransposeOpIndexing(transpose, ctx); } + if (instr->opcode() == HloOpcode::kTuple) { + return HloInstructionIndexing::FromIndexingMaps( + {CreateIdentityMap(instr->shape().tuple_shapes(input_id), ctx)}); + } // If we cannot compute input-to-output indexing, we return std::nullopt for // every op result. 
int64_t num_results = From d1b0fb4020e8020e41c86f18fc6685df82a96188 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Mon, 18 Mar 2024 08:05:01 -0700 Subject: [PATCH 024/670] Support sparse dots in GemmFusion pass The codegen will only support this for NVidia GPUs, which have the following restrictions: - only 2:4 structured sparsity is allowed; - only the first dot operand may be sparse; PiperOrigin-RevId: 616829925 --- .../xla/xla/service/gpu/gemm_fusion.cc | 35 ++++++++++-- .../xla/xla/service/gpu/gemm_fusion_test.cc | 57 +++++++++++++++++++ 2 files changed, 87 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.cc b/third_party/xla/xla/service/gpu/gemm_fusion.cc index 999989208d73ce..e98904b364443f 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion.cc @@ -162,14 +162,17 @@ struct HlosAndRequirements { HloInstruction& FuseDot(const HloDotInstruction& dot, const HloInstruction& fused_lhs, const HloInstruction& fused_rhs, + std::optional fused_meta, HloComputation::Builder& builder // append ) { - CHECK_EQ(dot.operand_count(), 2); VLOG(3) << "Fusing " << dot.ToString(); - std::array hlo_new_operands = { + std::vector hlo_new_operands = { const_cast(&fused_lhs), const_cast(&fused_rhs)}; + if (fused_meta.has_value()) { + hlo_new_operands.push_back(const_cast(fused_meta.value())); + } return *builder.AddInstruction( dot.CloneWithNewOperands(dot.shape(), hlo_new_operands)); } @@ -620,12 +623,33 @@ absl::StatusOr CreateDotFusion( return can_handle; } + // Verify sparse dot constraints. + if (dot.sparse_operands()) { + const SparsityDescriptor& descriptor = dot.sparsity().front(); + if (dot.sparse_operands() != 1 || descriptor.index() != 0) { + return InvalidArgument("Sparsity is only supported on left operand"); + } + if (descriptor.type() != SparsityType::SPARSITY_STRUCTURED_N_M || + descriptor.n() != 2 || descriptor.m() != 4) { + return InvalidArgument("Only 2:4 structured sparsity is supported"); + } + // DotDimensionSorter pass makes sure the sparse dimension is minor. + CHECK_EQ(descriptor.dimension(), dot.operand(0)->shape().rank() - 1); + } + HlosAndRequirements lhs_hlos_and_reqs = FuseDotOperand( dot, /*operand_index=*/0, gpu_version, builder, fusion_inputs); HlosAndRequirements rhs_hlos_and_reqs = FuseDotOperand( dot, /*operand_index=*/1, gpu_version, builder, fusion_inputs); - HloInstruction& fused_dot = FuseDot(dot, *lhs_hlos_and_reqs.fused_hlo, - *rhs_hlos_and_reqs.fused_hlo, builder); + std::optional meta_hlo; + if (dot.sparse_operands()) { + HlosAndRequirements meta_hlos_and_reqs = FuseDotOperand( + dot, /*operand_index=*/2, gpu_version, builder, fusion_inputs); + meta_hlo.emplace(meta_hlos_and_reqs.fused_hlo); + } + HloInstruction& fused_dot = + FuseDot(dot, *lhs_hlos_and_reqs.fused_hlo, *rhs_hlos_and_reqs.fused_hlo, + meta_hlo, builder); // For now the RHS doesn't support splits, so it also doesn't impose any // requirements. 
HlosAndRequirements fused_output_and_reqs = @@ -642,7 +666,8 @@ absl::StatusOr CreateDotFusion( dot.precision_config().algorithm(); if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 || algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 || - dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any()) { + dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() || + dot.sparse_operands()) { return FusionDecision{}; } diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_test.cc index 43c1c155fd4189..bdb1be455024f0 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_test.cc @@ -1148,6 +1148,63 @@ ENTRY e { })"); } +class SparseDotTest : public GemmFusionTest {}; + +TEST_F(SparseDotTest, DotWithSparseLhsOperandIsRewritten) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule test +ENTRY main { + lhs = f16[2,16] parameter(0) + rhs = f16[32,2] parameter(1) + meta = u16[2,2] parameter(2) + ROOT dot = f32[2,2] dot(lhs, rhs, meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 +})") + .value(); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + + MatchHloModule(*module, R"( +; CHECK-LABEL: ENTRY %main ({{.*}}: f16[2,16], {{.*}}: f16[32,2], {{.*}}: u16[2,2]) -> f32[2,2] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,16]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f16[32,2]{1,0} parameter(1) +; CHECK-NEXT: [[META:%[^ ]+]] = u16[2,2]{1,0} parameter(2) +; CHECK: ROOT {{.*}} = f32[2,2]{1,0} +; CHECK-SAME: fusion(f16[2,16]{1,0} [[P0]], f16[32,2]{1,0} [[P1]], u16[2,2]{1,0} [[META]]), +; CHECK-SAME: kind=kCustom +; CHECK-SAME: __triton_gemm +})"); +} + +TEST_F(SparseDotTest, DotWithSparseRhsOperandIsNotSupported) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule test +ENTRY main { + lhs = f16[2,32] parameter(0) + rhs = f16[16,2] parameter(1) + meta = u16[2,2] parameter(2) + ROOT dot = f32[2,2] dot(lhs, rhs, meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=R.0@2:4 +})") + .value(); + auto result = GemmFusion(gpu_version_).Run(module.get()); + EXPECT_FALSE(result.ok()); +} + +TEST_F(SparseDotTest, UnsupportedSparsityType) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule test +ENTRY main { + lhs = f16[2,8] parameter(0) + rhs = f16[32,2] parameter(1) + meta = u16[2,1] parameter(2) + ROOT dot = f32[2,2] dot(lhs, rhs, meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@1:4 +})") + .value(); + auto result = GemmFusion(gpu_version_).Run(module.get()); + EXPECT_FALSE(result.ok()); +} + } // namespace } // namespace gpu } // namespace xla From 3a717efddf65842ff10f24f71eecbc47490a0422 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Mon, 18 Mar 2024 08:18:49 -0700 Subject: [PATCH 025/670] Support sparse dots in GemmFusionAutotuner pass 1) Add `allow_cublas` flag and set it to false for sparse dots (we cannot run cublas for reference, as it doesn't support sparsity). 2) Make sure the configs that are not supported by the codegen are excluded. Specifically, if there are more threads than metadata values, it'd fail. 3) For deviceless compilations, apply the `ReduceTileSizes` to the default config, as otherwise it produces an incorrect config for sparse dots (too many threads). 
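
For illustration, a minimal standalone sketch of the tile filter described in
point 2), assuming 2:4 structured sparsity (one metadata element covers 16
dense elements of the LHS tile) and a 32-thread warp; the helper name
SparseTilingHasEnoughMetadata, the kWarpSize constant, and the sample tilings
below are made up for this sketch and are not part of the pass itself:

    #include <cstdint>
    #include <iostream>

    namespace {

    constexpr int64_t kWarpSize = 32;  // assumption: NVIDIA warp size

    // A sparse tiling is usable only if every thread gets at least one
    // sparsity-metadata element, i.e. block_m * block_k / 16 >= thread count.
    bool SparseTilingHasEnoughMetadata(int64_t block_m, int64_t block_k,
                                       int64_t num_warps) {
      const int64_t meta_elements = block_m * block_k / 16;
      return meta_elements >= num_warps * kWarpSize;
    }

    }  // namespace

    int main() {
      // 16x16 tile, 4 warps: 16 metadata elements for 128 threads -> exclude.
      std::cout << SparseTilingHasEnoughMetadata(16, 16, 4) << "\n";  // prints 0
      // 64x128 tile, 4 warps: 512 metadata elements for 128 threads -> keep.
      std::cout << SparseTilingHasEnoughMetadata(64, 128, 4) << "\n";  // prints 1
      return 0;
    }

In the patch below, the exhaustive search skips configs for which this
predicate is false, while ReduceTileSizes instead clamps num_warps so that the
predicate holds.
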
PiperOrigin-RevId: 616833330 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/gemm_fusion_autotuner.cc | 80 +++++++++++++------ .../service/gpu/gemm_fusion_autotuner_test.cc | 34 +++++++- 3 files changed, 90 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index bc0f46f49bef75..59d7cc553ec9cc 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -766,6 +766,7 @@ xla_test( ":backend_configs_cc", ":gemm_fusion", ":gemm_fusion_autotuner", + ":ir_emission_utils", ":matmul_utils", "//xla:autotuning_proto_cc", "//xla:error_spec", diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc index a53affc736bb3c..e68364f71d3903 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc @@ -188,6 +188,9 @@ class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor { // This contains all alternative Triton GEMM configs related to one fusion. struct GemmConfigSet { std::vector configs; + // Setting this to true disallows verification and fallback to cuBLAS, and + // the usage of cuDNN. + bool has_sparsity = false; }; using CuDnnPlanId = int64_t; @@ -259,10 +262,12 @@ class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { fusion->GetModule()->config().debug_options(); auto cuda_comp = std::get(config_.GetGpuComputeCapability()); - return {GetPossibleMatmulAutotuneConfigs( - *Cast(hlo_query::GetFirstInstructionWithOpcode( - *fusion->called_computations().at(0), HloOpcode::kDot)), - cuda_comp, debug_options, config_.ExhaustiveTilingSearch())}; + const HloDotInstruction* dot_instr = + Cast(hlo_query::GetFirstInstructionWithOpcode( + *fusion->called_computations().at(0), HloOpcode::kDot)); + auto configs = GetPossibleMatmulAutotuneConfigs( + *dot_instr, cuda_comp, debug_options, config_.ExhaustiveTilingSearch()); + return {configs, /*has_sparsity=*/dot_instr->sparse_operands() > 0}; } AutotuneConfig config_; @@ -294,8 +299,11 @@ TileSizeLimit GetUpperLimit(const HloDotInstruction& dot) { std::max(tsl::NextPowerOfTwoS64(m), kMinTileSize); const int64_t block_n_limit = std::max(tsl::NextPowerOfTwoS64(n), kMinTileSize); + // Increase minimum tile size for the contracting dimension proportionally + // to the sparsity multiplier (assume 2:4 structured sparsity). const int64_t block_k_limit = - std::max(tsl::NextPowerOfTwoS64(k), kMinTileSize); + std::max(tsl::NextPowerOfTwoS64(k), + kMinTileSize * (dot.sparse_operands() ? 2 : 1)); return {block_m_limit, block_n_limit, block_k_limit}; } @@ -345,6 +353,12 @@ std::vector GetExhaustiveMatmulAutotuneConfigs( if (block_k > limit.block_k) { continue; } + // Sparse meta should have at least one element per thread. + // Note: only 2:4 structured sparsity is currently supported. + if (dot.sparse_operands() && + block_m * block_k / 16 < num_warps * WarpSize()) { + continue; + } for (int split_k : SPLIT_K) { if (split_k > std::min(max_split_k, @@ -429,6 +443,13 @@ std::vector ReduceTileSizes( config.block_k = std::min(config.block_k, limit.block_k); config.split_k = std::min( config.split_k, GetSplitKLimit(config.block_k, limit.block_k)); + // Sparse meta should have at least one element per thread. + // Note: only 2:4 structured sparsity is currently supported. 
+ if (dot.sparse_operands()) { + int meta_elements = config.block_m * config.block_k / 16; + config.num_warps = + std::min(config.num_warps, meta_elements / WarpSize()); + } } // Remove duplicates. @@ -632,16 +653,16 @@ CompileMany(const AutotuneConfig& config, AutotunerCompileUtil& util, if (IsFusionKind(hlo, kTritonGemmFusionKind)) { config_count += gemm_config_set.configs.size(); - if (IsCuDnnEnabled(config, debug_opts) && + if (!gemm_config_set.has_sparsity && IsCuDnnEnabled(config, debug_opts) && HasAlgorithmSupportedByCudnn(hlo)) { config_count += GetCuDnnPlanCount(hlo, config); } } else if (IsFusionKind(hlo, kCuDnnFusionKind)) { config_count += GetCuDnnPlanCount(hlo, config); } + // Reference config for verification (uses cuBLAS). + config_count += !gemm_config_set.has_sparsity; } - // cuBLAS configs: one per fusion. - config_count += gemm_config_sets.size(); std::atomic done_count = 0; std::atomic good_count = 0; @@ -756,16 +777,19 @@ CompileMany(const AutotuneConfig& config, AutotunerCompileUtil& util, }); } - thread_pool->Schedule([&, fusion] { - absl::StatusOr has_executable = - compile_reference_executable(fusion); - TF_CHECK_OK(has_executable.status()); - log(has_executable.value()); - counter.DecrementCount(); - }); + if (!gemm_config_set.has_sparsity) { + thread_pool->Schedule([&, fusion] { + absl::StatusOr has_executable = + compile_reference_executable(fusion); + TF_CHECK_OK(has_executable.status()); + log(has_executable.value()); + counter.DecrementCount(); + }); + } if (IsFusionKind(*fusion, kCuDnnFusionKind) || (IsFusionKind(*fusion, kTritonGemmFusionKind) && + !gemm_config_set.has_sparsity && IsCuDnnEnabled(config, debug_opts) && HasAlgorithmSupportedByCudnn(*fusion))) { const int plan_count = GetCuDnnPlanCount(*fusion, config); @@ -803,12 +827,15 @@ CompileMany(const AutotuneConfig& config, AutotunerCompileUtil& util, log(has_executable); } - TF_ASSIGN_OR_RETURN(bool has_executable, - compile_reference_executable(fusion)); - log(has_executable); + if (!gemm_config_set.has_sparsity) { + TF_ASSIGN_OR_RETURN(bool has_executable, + compile_reference_executable(fusion)); + log(has_executable); + } if (IsFusionKind(*fusion, kCuDnnFusionKind) || (IsFusionKind(*fusion, kTritonGemmFusionKind) && + !gemm_config_set.has_sparsity && IsCuDnnEnabled(config, debug_opts) && HasAlgorithmSupportedByCudnn(*fusion))) { const int plan_count = GetCuDnnPlanCount(*fusion, config); @@ -864,11 +891,10 @@ absl::StatusOr Execute(const AutotuneConfig& config, input_shapes.push_back(param->shape()); } - // Run with cuBLAS. + // Run with cuBLAS (optional). std::optional reference_buffer; - absl::Duration cublas_duration; - { - TF_RET_CHECK(executable_set.reference != nullptr); + absl::Duration cublas_duration = absl::InfiniteDuration(); + if (executable_set.reference != nullptr) { TF_ASSIGN_OR_RETURN(std::optional output, util.ProfileExecutable(&*executable_set.reference, stream, inputs, input_shapes)); @@ -925,7 +951,9 @@ absl::StatusOr Execute(const AutotuneConfig& config, *res.mutable_run_time() = tsl::proto_utils::ToDurationProto(profiling_output->duration); - if (config.should_check_correctness()) { + // Reference buffer is available when `config.should_check_correctness()` + // is set and reference executable was compiled. 
+ if (reference_buffer.has_value()) { TF_ASSIGN_OR_RETURN( se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, rz_allocator.CheckRedzones()); @@ -1157,7 +1185,11 @@ absl::StatusOr GemmFusionAutotuner::Run( if (IsFusionKind(*fusion, kCuDnnFusionKind)) { res.mutable_algorithm()->set_algo_id(-1); } else { - *res.mutable_triton() = kDefaultGemmTiling.ToProto(); + const HloDotInstruction* dot_instr = + Cast(hlo_query::GetFirstInstructionWithOpcode( + *fusion->called_computations().at(0), HloOpcode::kDot)); + auto config = ReduceTileSizes(*dot_instr, {kDefaultGemmTiling}).front(); + *res.mutable_triton() = config.ToProto(); } *res.mutable_run_time() = tsl::proto_utils::ToDurationProto(absl::ZeroDuration()); diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc index 4d886a8f68988d..a0d3e85a782356 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gemm_fusion.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_pipeline.h" @@ -694,7 +695,7 @@ ENTRY e { RunFileCheck( module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), R"( -// CHECK: backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"32","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"}}} +// CHECK: backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"16","block_k":"16","split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"}}} )")); EXPECT_TRUE(filecheck_matches); } else { @@ -770,6 +771,37 @@ ENTRY e { [](const TritonGemmConfig& config) { return config.split_k == 1; })); } +class GemmFusionAutotunerConfigTest + : public StatelessAutotunerTest, + public ::testing::WithParamInterface {}; + +TEST_P(GemmFusionAutotunerConfigTest, SparseDotDiscardsUnsupportedTiles) { + const std::string kHloText = R"( +HloModule test +ENTRY wais { + lhs = f16[5,1600] parameter(0) + rhs = f16[3200,10] parameter(1) + meta = u16[5,200] parameter(2) + ROOT dot = f32[5,10] dot(lhs, rhs, meta), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + auto dot = + Cast(module->entry_computation()->root_instruction()); + + auto configs = GetPossibleMatmulAutotuneConfigs( + *dot, se::CudaComputeCapability{8, 0}, GetDebugOptionsForTest(), + /*exhaustive_tiling_search=*/GetParam()); + for (const auto& config : configs) { + int metadata_size = config.block_m * config.block_k / 16; + EXPECT_LE(config.num_warps * WarpSize(), metadata_size); + EXPECT_GT(config.block_k, 16); // kMinTileSize + } +} + +INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerConfigSweep, + GemmFusionAutotunerConfigTest, ::testing::Bool()); + } // namespace } // namespace gpu } // namespace xla From ea0619d9091b50a6204cef9085953deef9d98c91 Mon Sep 17 00:00:00 2001 From: Eunjae Kim Date: Mon, 18 Mar 2024 08:18:52 -0700 Subject: [PATCH 026/670] Insert a task to the low priority task queue 
when the criticality is one of the fixed list of low priority criticalities and support padding the high priority batch with the unbatched tasks given via the ProcessBatchCallBack. PiperOrigin-RevId: 616833344 --- tensorflow/core/kernels/BUILD | 17 +- tensorflow/core/kernels/batch_kernels_test.cc | 331 ++++++++++++++++-- tensorflow/core/kernels/batching_util/BUILD | 3 + .../batching_util/batch_resource_base.cc | 149 +++++--- .../batching_util/batch_resource_base.h | 34 +- .../batching_util/shared_batch_scheduler.h | 122 +++++-- .../shared_batch_scheduler_test.cc | 234 ++++++++++++- 7 files changed, 763 insertions(+), 127 deletions(-) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index aa313ddbd3b032..f97593aaf43898 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1645,19 +1645,34 @@ cc_library( tf_cc_test( name = "batch_kernels_test", - size = "small", + size = "medium", srcs = ["batch_kernels_test.cc"], features = ["-layering_check"], deps = [ ":batch_kernel_test_util", ":batch_kernels", + ":cwise_op", ":function_ops", ":shape_ops", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:math_ops_op_lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/framework:types_proto_cc", "//tensorflow/core/kernels/batching_util:warmup", + "//tensorflow/core/platform:status", + "//tensorflow/core/protobuf:for_core_protos_cc", + "//tensorflow/core/public:version", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:blocking_counter", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:refcount", + "@local_tsl//tsl/platform:status", ], ) diff --git a/tensorflow/core/kernels/batch_kernels_test.cc b/tensorflow/core/kernels/batch_kernels_test.cc index 68ae309504ffb6..320d17b14396c2 100644 --- a/tensorflow/core/kernels/batch_kernels_test.cc +++ b/tensorflow/core/kernels/batch_kernels_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/strings/match.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/device_factory.h" @@ -39,6 +40,7 @@ limitations under the License. #include "tensorflow/core/public/version.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/blocking_counter.h" +#include "tsl/platform/criticality.h" #include "tsl/platform/errors.h" #include "tsl/platform/refcount.h" #include "tsl/platform/status.h" @@ -65,10 +67,283 @@ TEST_P(BatchFunctionKernelTest, EnableAdaptiveScheduler) { INSTANTIATE_TEST_SUITE_P(Params, BatchFunctionKernelTest, ::testing::Bool()); -class BatchFunctionKernelParallelWarmupTestState : public OpsTestBase { +class SharedBatchFunctionTestState : public OpsTestBase { public: // Init test fixture with a batch kernel instance. 
- Status Init(bool enable_splitting) { + void CreateFunctionLibraryRuntime() { + pflr_ = std::make_unique( + device_mgr_.get(), Env::Default(), /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(), + /*thread_pool=*/nullptr, /*parent=*/nullptr, + /*session_metadata=*/nullptr, + Rendezvous::Factory{[](const int64_t, const DeviceMgr *device_mgr, + tsl::core::RefCountPtr *r) { + *r = tsl::core::RefCountPtr( + new IntraProcessRendezvous(device_mgr)); + return absl::OkStatus(); + }}); + } +}; + +class BatchFunctionTestState : public SharedBatchFunctionTestState { + public: + // Init test fixture with a batch kernel instance. The caller guarantees that + // the device pointer is valid throughout the life of this class. + absl::Status Init(Device *device, bool enable_low_priority_queue) { + // Override the per-test/per-op device with a given device so that it can + // be shared between ops. + device_ = device; + + NameAttrList f; + f.set_name("ShapeEnforcingFunction"); + FunctionDef func = FunctionDefHelper::Create( + // function_name + f.name(), + // in_def + {"x:int64"}, + // out_def + {"o:int64"}, + // attr_def + {}, + // node_def + {{{"o"}, + "EnsureShape", + {"x"}, + {{"T", DataType::DT_INT64}, {"shape", TensorShape({4, 2})}}}}, + // ret_def + {{"o", "o:output"}}); + TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(func)); + SharedBatchFunctionTestState::CreateFunctionLibraryRuntime(); + + std::vector inputs( + {NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64})}); + TF_RETURN_IF_ERROR(NodeDefBuilder("BatchTPUInput", "BatchFunction") + .Attr("max_batch_size", 4) + .Attr("num_batch_threads", 4) + .Attr("allowed_batch_sizes", {4}) + .Attr("batch_timeout_micros", 5000000) + .Attr("max_enqueued_batches", 10) + .Attr("low_priority_max_batch_size", + enable_low_priority_queue ? 64 : 0) + .Attr("low_priority_batch_timeout_micros", + enable_low_priority_queue ? 50000000 : 0) + .Attr("low_priority_allowed_batch_sizes", + enable_low_priority_queue ? std::vector{1} + : std::vector()) + .Attr("low_priority_max_enqueued_batches", + enable_low_priority_queue ? 100 : 0) + .Attr("Tin", {DataType::DT_INT64}) + .Input(inputs) + .Attr("Tcaptured", std::vector{}) + .Input(std::vector{}) + .Attr("Tout", std::vector{DT_INT64}) + .Attr("f", f) + .Finalize(node_def())); + return OpsTestBase::InitOp(); + } + + void TestBody() override {} +}; + +class BatchFunctionTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + // The device needs to be shared in each test case and within each test case + // only. + cpu_device_ = + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); + } + std::unique_ptr cpu_device_; +}; + +TEST_P(BatchFunctionTest, BatchingWorksWithoutCriticality) { + SessionMetadata session_metadata; + session_metadata.set_name("test_model"); + session_metadata.set_version(123); + + bool enable_low_priority_queue = GetParam(); + { + tsl::BlockingCounter blocking_counter(4); + // 8 threads run the batch op with no explicit criticality set. They are + // eventually batched to form a tensor with [4, 2] shape which is verified + // within the function. 
+ for (int i = 0; i < 4; ++i) { + Env::Default()->SchedClosure([&]() { + ASSERT_EQ(tsl::criticality::GetCriticality(), + tsl::criticality::Criticality::kCritical); + + BatchFunctionTestState test_state; + test_state.set_session_metadata(session_metadata); + TF_ASSERT_OK( + test_state.Init(cpu_device_.get(), enable_low_priority_queue)); + test_state.AddInputFromList(TensorShape({1, 2}), {123, 456}); + TF_EXPECT_OK(test_state.RunOpKernel()); + + test::ExpectTensorEqual( + *test_state.GetOutput(0), + test::AsTensor({123, 456}, TensorShape({1, 2}))); + blocking_counter.DecrementCount(); + }); + } + + blocking_counter.Wait(); + } +} + +TEST_P(BatchFunctionTest, PaddingWorksWithoutCriticality) { + SessionMetadata session_metadata; + session_metadata.set_name("test_model"); + session_metadata.set_version(123); + + bool enable_low_priority_queue = GetParam(); + { + tsl::BlockingCounter blocking_counter(2); + // 2 threads run the batch op with no explicit criticality set. They are + // eventually batched and padded to form a tensor with [4, 2] shape which is + // verified within the function. + for (int i = 0; i < 2; ++i) { + Env::Default()->SchedClosure([&]() { + ASSERT_EQ(tsl::criticality::GetCriticality(), + tsl::criticality::Criticality::kCritical); + + BatchFunctionTestState test_state; + test_state.set_session_metadata(session_metadata); + TF_ASSERT_OK( + test_state.Init(cpu_device_.get(), enable_low_priority_queue)); + test_state.AddInputFromList(TensorShape({1, 2}), {123, 456}); + TF_EXPECT_OK(test_state.RunOpKernel()); + + test::ExpectTensorEqual( + *test_state.GetOutput(0), + test::AsTensor({123, 456}, TensorShape({1, 2}))); + blocking_counter.DecrementCount(); + }); + } + + blocking_counter.Wait(); + } +} + +#if defined(PLATFORM_GOOGLE) +TEST_P(BatchFunctionTest, BatchingWorks) { + SessionMetadata session_metadata; + session_metadata.set_name("test_model"); + session_metadata.set_version(123); + + bool enable_low_priority_queue = GetParam(); + { + tsl::BlockingCounter blocking_counter(4); + // 2 threads run the batch op with critical plus and 2 threads run the batch + // op with sheddable. They are eventually batched to form a tensor with [4, + // 2] shape which is verified within the function. 
+ for (int i = 0; i < 2; ++i) { + Env::Default()->SchedClosure([&]() { + tsl::criticality::ScopedCriticality scoped_criticality( + tsl::criticality::Criticality::kCriticalPlus); + ASSERT_EQ(tsl::criticality::GetCriticality(), + tsl::criticality::Criticality::kCriticalPlus); + + BatchFunctionTestState test_state; + test_state.set_session_metadata(session_metadata); + TF_ASSERT_OK( + test_state.Init(cpu_device_.get(), enable_low_priority_queue)); + test_state.AddInputFromList(TensorShape({1, 2}), {123, 456}); + TF_EXPECT_OK(test_state.RunOpKernel()); + + test::ExpectTensorEqual( + *test_state.GetOutput(0), + test::AsTensor({123, 456}, TensorShape({1, 2}))); + blocking_counter.DecrementCount(); + }); + } + + for (int i = 0; i < 2; ++i) { + Env::Default()->SchedClosure([&]() { + tsl::criticality::ScopedCriticality scoped_criticality( + tsl::criticality::Criticality::kSheddable); + ASSERT_EQ(tsl::criticality::GetCriticality(), + tsl::criticality::Criticality::kSheddable); + + BatchFunctionTestState test_state; + test_state.set_session_metadata(session_metadata); + TF_ASSERT_OK( + test_state.Init(cpu_device_.get(), enable_low_priority_queue)); + test_state.AddInputFromList(TensorShape({1, 2}), {234, 567}); + TF_EXPECT_OK(test_state.RunOpKernel()); + + test::ExpectTensorEqual( + *test_state.GetOutput(0), + test::AsTensor({234, 567}, TensorShape({1, 2}))); + blocking_counter.DecrementCount(); + }); + } + + blocking_counter.Wait(); + } +} + +TEST_P(BatchFunctionTest, PaddingWorks) { + SessionMetadata session_metadata; + session_metadata.set_name("test_model"); + session_metadata.set_version(123); + + bool enable_low_priority_queue = GetParam(); + { + tsl::BlockingCounter blocking_counter(2); + // 1 thread run the batch op with critical plus and 1 threads run the batch + // op with sheddable. They are eventually batched and padded to form a + // tensor with [4, 2] shape which is verified within the function. 
+ Env::Default()->SchedClosure([&]() { + tsl::criticality::ScopedCriticality scoped_criticality( + tsl::criticality::Criticality::kCriticalPlus); + ASSERT_EQ(tsl::criticality::GetCriticality(), + tsl::criticality::Criticality::kCriticalPlus); + + BatchFunctionTestState test_state; + test_state.set_session_metadata(session_metadata); + TF_ASSERT_OK( + test_state.Init(cpu_device_.get(), enable_low_priority_queue)); + test_state.AddInputFromList(TensorShape({1, 2}), {123, 456}); + TF_EXPECT_OK(test_state.RunOpKernel()); + + test::ExpectTensorEqual( + *test_state.GetOutput(0), + test::AsTensor({123, 456}, TensorShape({1, 2}))); + blocking_counter.DecrementCount(); + }); + + Env::Default()->SchedClosure([&]() { + tsl::criticality::ScopedCriticality scoped_criticality( + tsl::criticality::Criticality::kSheddable); + ASSERT_EQ(tsl::criticality::GetCriticality(), + tsl::criticality::Criticality::kSheddable); + + BatchFunctionTestState test_state; + test_state.set_session_metadata(session_metadata); + TF_ASSERT_OK( + test_state.Init(cpu_device_.get(), enable_low_priority_queue)); + test_state.AddInputFromList(TensorShape({1, 2}), {234, 567}); + TF_EXPECT_OK(test_state.RunOpKernel()); + + test::ExpectTensorEqual( + *test_state.GetOutput(0), + test::AsTensor({234, 567}, TensorShape({1, 2}))); + blocking_counter.DecrementCount(); + }); + + blocking_counter.Wait(); + } +} +#endif + +INSTANTIATE_TEST_SUITE_P(BatchFunctionTest, BatchFunctionTest, + ::testing::Bool()); + +class BatchFunctionKernelParallelWarmupTestState + : public SharedBatchFunctionTestState { + public: + // Init test fixture with a batch kernel instance. + absl::Status Init(bool enable_splitting) { static auto *const cpu_device = []() { auto device = DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); @@ -98,40 +373,29 @@ class BatchFunctionKernelParallelWarmupTestState : public OpsTestBase { // ret_def {{"o", "o:output"}}); TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(func)); - - pflr_ = std::make_unique( - device_mgr_.get(), Env::Default(), /*config=*/nullptr, - TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(), - /*thread_pool=*/nullptr, /*parent=*/nullptr, - /*session_metadata=*/nullptr, - Rendezvous::Factory{[](const int64_t, const DeviceMgr *device_mgr, - tsl::core::RefCountPtr *r) { - *r = tsl::core::RefCountPtr( - new IntraProcessRendezvous(device_mgr)); - return absl::OkStatus(); - }}); + SharedBatchFunctionTestState::CreateFunctionLibraryRuntime(); std::vector inputs( {NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64})}); - TF_CHECK_OK(NodeDefBuilder("BatchTPUInput", "BatchFunction") - .Attr("max_batch_size", enable_splitting ? 16 : 8) - .Attr("num_batch_threads", 8) - .Attr("allowed_batch_sizes", {2, 4, 8}) - .Attr("batch_timeout_micros", 1000000) - .Attr("max_enqueued_batches", 10) - .Attr("enable_large_batch_splitting", true) - .Attr("low_priority_max_batch_size", 64) - .Attr("low_priority_batch_timeout_micros", 8000) - .Attr("low_priority_allowed_batch_sizes", {32, 64}) - .Attr("low_priority_max_enqueued_batches", 1000) - .Attr("Tin", {DataType::DT_INT64}) - .Input(inputs) - .Attr("Tcaptured", std::vector{}) - .Input(std::vector{}) - .Attr("Tout", std::vector{DT_INT64}) - .Attr("f", f) - .Finalize(node_def())); - return InitOp(); + TF_RETURN_IF_ERROR(NodeDefBuilder("BatchTPUInput", "BatchFunction") + .Attr("max_batch_size", enable_splitting ? 
16 : 8) + .Attr("num_batch_threads", 8) + .Attr("allowed_batch_sizes", {2, 4, 8}) + .Attr("batch_timeout_micros", 1000000) + .Attr("max_enqueued_batches", 10) + .Attr("enable_large_batch_splitting", true) + .Attr("low_priority_max_batch_size", 64) + .Attr("low_priority_batch_timeout_micros", 8000) + .Attr("low_priority_allowed_batch_sizes", {32, 64}) + .Attr("low_priority_max_enqueued_batches", 1000) + .Attr("Tin", {DataType::DT_INT64}) + .Input(inputs) + .Attr("Tcaptured", std::vector{}) + .Input(std::vector{}) + .Attr("Tout", std::vector{DT_INT64}) + .Attr("f", f) + .Finalize(node_def())); + return OpsTestBase::InitOp(); } void TestBody() override {} @@ -200,5 +464,6 @@ TEST_P(BatchFunctionKernelParallelWarmupTest, ParallelWarmup) { INSTANTIATE_TEST_SUITE_P(BatchFunctionKernelParallelWarmupTestSuite, BatchFunctionKernelParallelWarmupTest, ::testing::Bool()); + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index 2d9e06650e54b5..d34bd7331a35d5 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -161,6 +161,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/time", + "@local_tsl//tsl/platform:criticality", "@local_tsl//tsl/platform:errors", ], ) @@ -180,6 +181,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/time", + "@local_tsl//tsl/platform:criticality", ], alwayslink = 1, ) @@ -391,6 +393,7 @@ cc_library( "//tensorflow/core/util:incremental_barrier", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index 98a83fda8833a5..51d744616db8c6 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" +#include "absl/functional/bind_front.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -50,6 +51,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" #include "tensorflow/core/kernels/batching_util/concat_split_util.h" #include "tensorflow/core/kernels/batching_util/input_split_metadata.h" #include "tensorflow/core/kernels/batching_util/threadsafe_status.h" @@ -272,6 +274,17 @@ const string& GetModelName(OpKernelContext* ctx) { return ctx->session_metadata()->name(); } +// Returns the sum of the task sizes. The caller must guarantee that the +// unique_ptrs in the argument vectors are not null. 
+int GetTotalTaskSize( + const std::vector>& tasks) { + int tasks_size = 0; + for (const auto& task : tasks) { + tasks_size += task->size(); + } + return tasks_size; +} + } // namespace std::unique_ptr @@ -617,17 +630,22 @@ int BatchResourceBase::RoundToLowestAllowedBatchSize(int batch_size) const { } Status BatchResourceBase::ConcatInputTensors( - const BatchT& batch, OpKernelContext* context, - std::vector* concatenated_tensors) const { + const BatchT& batch, + const std::vector>& unbatched_tasks, + OpKernelContext* context, std::vector* concatenated_tensors) const { if (batch.num_tasks() == 0) { return errors::InvalidArgument("Empty batch."); } + + int unbatched_tasks_size = GetTotalTaskSize(unbatched_tasks); const bool just_for_warmup = batch.task(0).forced_warmup_batch_size > 0; const int padded_batch_size = - just_for_warmup ? batch.task(0).forced_warmup_batch_size - : RoundToLowestAllowedBatchSize(batch.size()); + just_for_warmup + ? batch.task(0).forced_warmup_batch_size + : RoundToLowestAllowedBatchSize(batch.size() + unbatched_tasks_size); const int padding_amount = - just_for_warmup ? padded_batch_size : padded_batch_size - batch.size(); + just_for_warmup ? padded_batch_size + : padded_batch_size - batch.size() - unbatched_tasks_size; profiler::TraceMe trace_me([padded_batch_size, padding_amount, disable_padding = batcher_queue_options_.disable_padding]() { @@ -636,6 +654,9 @@ Status BatchResourceBase::ConcatInputTensors( {"padding_amount", padding_amount}, {"disable_padding", disable_padding}}); }); + // TODO(b/316379576): Add metrics for the breakdown between the size of the + // original batch size and the unbatched task size and update the batch size + // to include the unbatched tasks. RecordPaddingSize(padding_amount, GetModelName(context), padded_batch_size, context->op_kernel().name()); RecordPaddingSizeV2(padding_amount, GetModelName(context), padded_batch_size, @@ -660,10 +681,14 @@ Status BatchResourceBase::ConcatInputTensors( if (just_for_warmup) { to_concatenate.reserve(padding_amount); } else { - to_concatenate.reserve(batch.num_tasks() + padding_amount); + to_concatenate.reserve(batch.num_tasks() + unbatched_tasks.size() + + padding_amount); for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) { to_concatenate.push_back(batch.task(task_idx).inputs.at(i)); } + for (int task_idx = 0; task_idx < unbatched_tasks.size(); ++task_idx) { + to_concatenate.push_back(unbatched_tasks[task_idx]->inputs.at(i)); + } } // Add padding as needed if padding is allowed. Use the first row of the @@ -794,7 +819,8 @@ Status BatchResourceBase::ConcatInputTensors( } Status BatchResourceBase::SplitOutputTensors( - const std::vector& combined_outputs, BatchT* batch) const { + const std::vector& combined_outputs, BatchT* batch, + std::vector>& unbatched_tasks) const { DCHECK_GE(batch->num_tasks(), 1); if (batch->num_tasks() < 1) { return errors::Internal("Batch size expected to be positive; was ", @@ -802,14 +828,20 @@ Status BatchResourceBase::SplitOutputTensors( } std::vector task_sizes_plus_optional_padding; - task_sizes_plus_optional_padding.reserve(batch->num_tasks()); + task_sizes_plus_optional_padding.reserve(batch->num_tasks() + + unbatched_tasks.size()); for (int i = 0; i < batch->num_tasks(); ++i) { task_sizes_plus_optional_padding.push_back(batch->task(i).size()); } - const int padding_size = - batcher_queue_options_.disable_padding - ? 
0 - : RoundToLowestAllowedBatchSize(batch->size()) - batch->size(); + for (int i = 0; i < unbatched_tasks.size(); ++i) { + task_sizes_plus_optional_padding.push_back(unbatched_tasks[i]->size()); + } + int unbatched_tasks_size = GetTotalTaskSize(unbatched_tasks); + const int padding_size = batcher_queue_options_.disable_padding + ? 0 + : RoundToLowestAllowedBatchSize( + batch->size() + unbatched_tasks_size) - + batch->size() - unbatched_tasks_size; if (padding_size > 0) { task_sizes_plus_optional_padding.push_back(padding_size); } @@ -829,7 +861,8 @@ Status BatchResourceBase::SplitOutputTensors( "Batched output tensor has 0 dimensions"); } if (output_tensor.shape().dim_size(0) != - static_cast(batch->size() + padding_size)) { + static_cast(batch->size() + unbatched_tasks_size + + padding_size)) { return errors::FailedPrecondition( "Batched output tensor's 0th dimension does not equal the sum of " "the 0th dimension sizes of the input tensors"); @@ -861,12 +894,35 @@ Status BatchResourceBase::SplitOutputTensors( task.context->set_output(i, split_tensor[j]); } } + for (int j = 0; j < unbatched_tasks.size(); ++j) { + // The unbatched tasks are not split, so no need to handle the partial + // case separately. + unbatched_tasks[j]->context->set_output( + i, split_tensor[batch->num_tasks() + j]); + } } return absl::OkStatus(); } -void BatchResourceBase::ProcessFuncBatch(std::unique_ptr batch) const { +void BatchResourceBase::CleanUpFunctionHelper(BatchTask& task, + const Status& status) const { + WithContext wc(task.propagated_context); + if (!status.ok()) { + if (!absl::StrContains(status.message(), + "Function was cancelled before it was started")) { + task.status->Update(status); + } else { + // Do not propagate this error; Prefer a more helpful error message. + LOG(ERROR) << "ERROR!!!! " << status.message(); + } + } + task.done_callback(); +} + +void BatchResourceBase::ProcessFuncBatch( + std::unique_ptr batch, + std::vector> unbatched_tasks) const { if (batch->empty()) { return; } @@ -896,24 +952,19 @@ void BatchResourceBase::ProcessFuncBatch(std::unique_ptr batch) const { if (cleanup_done) { return; } + // TODO(b/316379576): Update this to take the unbatch task cost into + // consideration when excluding the wasted cost and propagate cost to the + // unbatched tasks. SplitBatchCostsAndRecordMetrics(model_name, batch_cost_measurements, processed_size, *batch); // Clear the measurements before unblocking the batch task, as measurements // are associated with the task's thread context. batch_cost_measurements.clear(); for (int i = 0; i < batch->num_tasks(); ++i) { - WithContext wc(batch->task(i).propagated_context); - if (!status.ok()) { - if (!absl::StrContains( - status.message(), - "Function was cancelled before it was started")) { - batch->mutable_task(i)->status->Update(status); - } else { - // Do not propagate this error; Prefer a more helpful error message. - LOG(ERROR) << "ERROR!!!! 
" << status.message(); - } - } - batch->mutable_task(i)->done_callback(); + CleanUpFunctionHelper(*batch->mutable_task(i), status); + } + for (int i = 0; i < unbatched_tasks.size(); ++i) { + CleanUpFunctionHelper(*unbatched_tasks[i], status); } cleanup_done = true; }; @@ -927,7 +978,8 @@ void BatchResourceBase::ProcessFuncBatch(std::unique_ptr batch) const { } std::vector concatenated_tensors; - status = ConcatInputTensors(*batch, last_task_context, &concatenated_tensors); + status = ConcatInputTensors(*batch, unbatched_tasks, last_task_context, + &concatenated_tensors); processed_size = RoundToLowestAllowedBatchSize(batch->size()); if (!status.ok()) { return; @@ -969,7 +1021,8 @@ void BatchResourceBase::ProcessFuncBatch(std::unique_ptr batch) const { return; } if (last_task.forced_warmup_batch_size == 0) { - final_status = SplitOutputTensors(combined_outputs, batch.get()); + final_status = SplitOutputTensors(combined_outputs, batch.get(), + unbatched_tasks); } }); } @@ -1011,7 +1064,7 @@ void BatchResourceBase::ProcessBatch(std::unique_ptr batch) const { const int num_input_edges = batch->task(0).inputs.size(); std::vector concatenated_tensors; const Status concat_status = - ConcatInputTensors(*batch, last_task_context, &concatenated_tensors); + ConcatInputTensors(*batch, {}, last_task_context, &concatenated_tensors); processed_size = RoundToLowestAllowedBatchSize(batch->size()); OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback); @@ -1081,6 +1134,20 @@ void BatchResourceBase::ProcessBatch(std::unique_ptr batch) const { return absl::OkStatus(); } +void BatchResourceBase::ProcessBatchCallBack( + std::unique_ptr> batch, + std::vector> unbatched_tasks) { + if (!session_metadata().name().empty()) { + absl::MutexLock lock(&outstanding_batch_mu_); + num_outstanding_batched_items_ -= batch->size(); + } + if (!has_process_batch_function_) { + ProcessBatch(std::move(batch)); + } else { + ProcessFuncBatch(std::move(batch), std::move(unbatched_tasks)); + } +} + // Looks up the batcher queue for 'queue_name'. If it didn't previously exist, // creates it. 
Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name, @@ -1094,23 +1161,19 @@ Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name, } std::unique_ptr new_queue; - auto process_batch_callback = [this](std::unique_ptr batch) { - if (!session_metadata().name().empty()) { - absl::MutexLock lock(&outstanding_batch_mu_); - num_outstanding_batched_items_ -= batch->size(); - } - if (!has_process_batch_function_) { - ProcessBatch(std::move(batch)); - } else { - ProcessFuncBatch(std::move(batch)); - } - }; if (batcher_) { - TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_, - process_batch_callback, &new_queue)); + TF_RETURN_IF_ERROR(batcher_->AddQueue( + batcher_queue_options_, + absl::bind_front(&BatchResourceBase::ProcessBatchCallBack, this), + &new_queue)); } else if (adaptive_batcher_) { + std::function>)> + reduced_process_batch_callback = [this](std::unique_ptr batch) { + ProcessBatchCallBack(std::move(batch), {}); + }; TF_RETURN_IF_ERROR(adaptive_batcher_->AddQueue( - adaptive_batcher_queue_options_, process_batch_callback, &new_queue)); + adaptive_batcher_queue_options_, reduced_process_batch_callback, + &new_queue)); } else { return errors::Internal("No batcher defined."); } diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h index 1bd122e1d1dc9e..60ecf980e95443 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.h +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h @@ -271,17 +271,37 @@ class BatchResourceBase : public ResourceBase { // returns 'batch_size'. int RoundToLowestAllowedBatchSize(int batch_size) const; - Status ConcatInputTensors(const BatchT& batch, OpKernelContext* context, - std::vector* concatenated_tensors) const; - - Status SplitOutputTensors(const std::vector& combined_outputs, - BatchT* batch) const; - - void ProcessFuncBatch(std::unique_ptr batch) const; + // Helper function to propagate the status to the task's context and call the + // done callback on the task. + void CleanUpFunctionHelper(BatchTask& task, const Status& status) const; + + // Concatenates the input tensors of the tasks from the batch and the + // unbatched task vector. When padding is enabled in the batcher queue, they + // are padded with garbage value up to the nearest allowed batch size. + Status ConcatInputTensors( + const BatchT& batch, + const std::vector>& unbatched_tasks, + OpKernelContext* context, + std::vector* concatenated_tensors) const; + + Status SplitOutputTensors( + const std::vector& combined_outputs, BatchT* batch, + std::vector>& unbatched_tasks) const; + + void ProcessFuncBatch( + std::unique_ptr batch, + std::vector> unbatched_tasks = {}) const; // Processes a batch of one or more BatchTask entries. void ProcessBatch(std::unique_ptr batch) const; + // Callback function that wraps the Process*Batch functions above. The caller + // of the callback must guarantee that the unique pointers passed as argument + // are not null. + void ProcessBatchCallBack( + std::unique_ptr> batch, + std::vector> unbatched_tasks); + // Emits an index tensor, which the Unbatch op will use to un-concatenate // the tensor and attribute the pieces to the right batch keys. 
The index // tensor contains, for each input: [batch_key, start_offset, end_offset] diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h index d1d8551c250bea..4b9a599a77c5ac 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/core/profiler/lib/context_types.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/lib/traceme_encode.h" +#include "tsl/platform/criticality.h" #include "tsl/platform/errors.h" namespace tensorflow { @@ -436,6 +437,15 @@ class Queue { // Same as IsEmpty(), but assumes the caller already holds a lock on 'mu_'. bool IsEmptyInternal() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Returns true iff the task is a low priority task based on the queue option. + bool IsLowPriorityTask(std::unique_ptr* task); + + // Implementation of ScheduleWithoutOrEagerSplit above. Enqueues `task` as it + // is or split it inline (eagerly) to form batches to be processed by + // `Queue::ProcessBatch` + Status ScheduleWithoutOrEagerSplitImpl(std::unique_ptr* task) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Closes the open batch residing at the back of std::deque, and inserts a // fresh open batch behind it. void StartNewBatch() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -949,6 +959,72 @@ Status Queue::ScheduleWithLazySplit(std::unique_ptr* task) { return absl::OkStatus(); } +template +bool Queue::IsLowPriorityTask(std::unique_ptr* task) { + if (!options_.enable_priority_queue) { + return false; + } + + // The criticality is defined only when the task is a derived class of + // BatchTask. + if constexpr (std::is_base_of_v) { + // TODO(b/316379576): Make the criticality and priority configurable. + return ((*task)->criticality() == + tsl::criticality::Criticality::kSheddablePlus || + (*task)->criticality() == + tsl::criticality::Criticality::kSheddable); + } + + // Otherwise, consider it a high priority task and return false. + return false; +} + +template +Status Queue::ScheduleWithoutOrEagerSplitImpl( + std::unique_ptr* task) { + // TODO(b/161857471): + // Add test coverage when when concurrent incoming batches arrives and + // use up all queue capacity. + TF_RETURN_IF_ERROR(ValidateBatchTaskQueueCapacity((*task).get())); + + std::deque>>& batches = GetBatches(); + + const int64_t open_batch_remaining_slot = + max_execution_batch_size() - batches.back()->size(); + + const int64_t input_task_size = (*task)->size(); + + std::vector> output_tasks; + + if (input_task_size <= open_batch_remaining_slot || + !options_.enable_large_batch_splitting) { + // This is the fast path when input doesn't need to be split. 
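Editorial note on the padding arithmetic described for ConcatInputTensors above (a sketch with assumed numbers and a stubbed-out rounding helper, not part of the patch): the open batch and the unbatched low priority tasks are summed first, then rounded up to the nearest allowed batch size, and only the remainder is filled with padding.

#include <vector>

// Stand-in for BatchResourceBase::RoundToLowestAllowedBatchSize: returns the
// smallest allowed batch size that fits, or the size itself if none does.
// Assumes 'allowed' is sorted ascending, as the real allowed_batch_sizes is.
int RoundUpToAllowedBatchSize(int size, const std::vector<int>& allowed) {
  for (int allowed_size : allowed) {
    if (allowed_size >= size) return allowed_size;
  }
  return size;
}

// Mirrors the computation in ConcatInputTensors: e.g. a high priority batch
// of size 3 plus unbatched low priority tasks totalling 2, with allowed
// sizes {4, 8}, is padded up to 8, so 3 rows of padding are appended after
// the real inputs.
int ComputePaddingAmount(int batch_size, int unbatched_tasks_size,
                         const std::vector<int>& allowed) {
  const int padded =
      RoundUpToAllowedBatchSize(batch_size + unbatched_tasks_size, allowed);
  return padded - batch_size - unbatched_tasks_size;
}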
+ output_tasks.push_back(std::move(*task)); + } else { + TF_RETURN_IF_ERROR(SplitInputBatchIntoSubtasks(task, &output_tasks)); + } + + for (int i = 0; i < output_tasks.size(); ++i) { + if (batches.back()->size() + output_tasks[i]->size() > + max_execution_batch_size()) { + StartNewBatch(); + } + if (batches.back()->empty()) { + open_batch_start_time_micros_ = env_->NowMicros(); + } + profiler::TraceMeProducer trace_me( + [&output_tasks, i] { + return profiler::TraceMeEncode("ScheduleOutputTask", + {{"size", output_tasks[i]->size()}}); + }, + profiler::ContextType::kSharedBatchScheduler, + batches.back()->traceme_context_id()); + batches.back()->AddTask(std::move(output_tasks[i])); + } + + return absl::OkStatus(); +} + // TODO(b/194294263): // Merge `ScheduleWithoutOrEagerSplit` and `ScheduleWithLazySplit` into // `Schedule`. @@ -969,48 +1045,18 @@ Status Queue::ScheduleWithoutOrEagerSplit( DCHECK(!closed_); - // TODO(b/161857471): - // Add test coverage when when concurrent incoming batches arrives and - // use up all queue capacity. - TF_RETURN_IF_ERROR(ValidateBatchTaskQueueCapacity((*task).get())); - - std::deque>>& batches = GetBatches(); - - const int64_t open_batch_remaining_slot = - max_execution_batch_size() - batches.back()->size(); - - const int64_t input_task_size = (*task)->size(); - - std::vector> output_tasks; - - if (input_task_size <= open_batch_remaining_slot || - !large_batch_splitting) { - // This is the fast path when input doesn't need to be split. - output_tasks.push_back(std::move(*task)); + if (IsLowPriorityTask(task)) { + // Insert the task to the low priority task queue instead of the high + // priority batch queue below. + low_priority_tasks_.AddTask(std::move(*task)); } else { - TF_RETURN_IF_ERROR(SplitInputBatchIntoSubtasks(task, &output_tasks)); - } - - for (int i = 0; i < output_tasks.size(); ++i) { - if (batches.back()->size() + output_tasks[i]->size() > - max_execution_batch_size()) { - StartNewBatch(); - } - if (batches.back()->empty()) { - open_batch_start_time_micros_ = env_->NowMicros(); - } - profiler::TraceMeProducer trace_me( - [&output_tasks, i] { - return profiler::TraceMeEncode("ScheduleOutputTask", - {{"size", output_tasks[i]->size()}}); - }, - profiler::ContextType::kSharedBatchScheduler, - batches.back()->traceme_context_id()); - batches.back()->AddTask(std::move(output_tasks[i])); + TF_RETURN_IF_ERROR(ScheduleWithoutOrEagerSplitImpl(task)); } + // Check if the batch queue has a schedulable batch and mark it schedulable + // if it not already marked. if (!schedulable_batch_) { - if (batches.size() > 1 || IsOpenBatchSchedulable()) { + if (GetBatches().size() > 1 || IsOpenBatchSchedulable()) { schedulable_batch_ = true; notify_of_schedulable_batch = true; } diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc index a703028a5e6234..29b79b3bb4b712 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" +#include #include #include #include // NOLINT(build/c++11) @@ -37,6 +38,7 @@ limitations under the License. 
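For context on the routing added in the scheduler above (an editorial sketch, not part of the patch; the free function is hypothetical): only requests whose calling context reports kSheddable or kSheddablePlus criticality are diverted to the low priority task queue, and callers opt in with tsl::criticality::ScopedCriticality exactly as the tests below do.

#include "tsl/platform/criticality.h"

// Mirrors Queue<TaskType>::IsLowPriorityTask: tasks that do not expose a
// criticality at all, or that arrive while the priority queue option is
// disabled, stay on the regular high priority batching path.
bool RoutesToLowPriorityQueue(tsl::criticality::Criticality criticality,
                              bool enable_priority_queue) {
  if (!enable_priority_queue) return false;
  return criticality == tsl::criticality::Criticality::kSheddable ||
         criticality == tsl::criticality::Criticality::kSheddablePlus;
}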
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tsl/platform/criticality.h" namespace tensorflow { namespace serving { @@ -46,19 +48,43 @@ using ::testing::HasSubstr; class FakeTask : public BatchTask { public: - explicit FakeTask(size_t size) : size_(size) {} + explicit FakeTask(size_t size, tsl::criticality::Criticality criticality = + tsl::criticality::Criticality::kCritical) + : size_(size), criticality_(criticality) {} ~FakeTask() override = default; size_t size() const override { return size_; } + tsl::criticality::Criticality criticality() const override { + return criticality_; + } + private: const size_t size_; + const tsl::criticality::Criticality criticality_; FakeTask(const FakeTask&) = delete; void operator=(const FakeTask&) = delete; }; +// Fake task taht doesn't inherit BatchTask and doesn't define criticality. The +// shared batch scheduler should still work with this task. +class FakeTaskWithoutCriticality { + public: + explicit FakeTaskWithoutCriticality(size_t size) : size_(size) {} + + ~FakeTaskWithoutCriticality() = default; + + size_t size() const { return size_; } + + private: + const size_t size_; + + FakeTaskWithoutCriticality(const FakeTaskWithoutCriticality&) = delete; + void operator=(const FakeTaskWithoutCriticality&) = delete; +}; + using Queue = BatchScheduler; using Scheduler = SharedBatchScheduler; using QueueOptions = Scheduler::QueueOptions; @@ -67,10 +93,26 @@ using SplitFunc = int first_output_task_size, int input_batch_size_limit, std::vector>* output_tasks)>; -// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on -// that task. Returns the resulting status. -Status ScheduleTask(size_t task_size, BatchScheduler* scheduler) { - std::unique_ptr task(new FakeTask(task_size)); +// Creates a FakeTask of size 'task_size' and 'criticality', and calls +// 'scheduler->Schedule()' on that task. Returns the resulting status. +// 'criticality' defaults to kCritical. +Status ScheduleTask(size_t task_size, BatchScheduler* scheduler, + tsl::criticality::Criticality criticality = + tsl::criticality::Criticality::kCritical) { + std::unique_ptr task(new FakeTask(task_size, criticality)); + Status status = scheduler->Schedule(&task); + // Schedule() should have consumed 'task' iff it returned Status::OK. + CHECK_EQ(status.ok(), task == nullptr); + return status; +} + +// Helper function similar to the function above. Creates a FakeTask of size +// 'task_size' and calls 'scheduler->Schedule()' on that task. Returns the +// resulting status. +Status ScheduleTaskWithoutCriticality( + size_t task_size, BatchScheduler* scheduler) { + std::unique_ptr task( + new FakeTaskWithoutCriticality(task_size)); Status status = scheduler->Schedule(&task); // Schedule() should have consumed 'task' iff it returned Status::OK. CHECK_EQ(status.ok(), task == nullptr); @@ -349,6 +391,101 @@ TEST_P(SharedBatchSchedulerTest, EXPECT_TRUE(queue_1_callback_called); } +// The task in the shared batch scheduler template parameter does not define +// criticality priority queue. It should work as if the priority queue is +// disabled. 
+TEST_P( + SharedBatchSchedulerTest, + CallbackWithTaskVectorOkWithPriorityQueueEnabledWithCriticalitylessTask) { + bool queue_0_callback_called = false; + auto queue_0_callback = + [&queue_0_callback_called]( + std::unique_ptr> batch, + std::vector> tasks) { + queue_0_callback_called = true; + ASSERT_TRUE(batch->IsClosed()); + ASSERT_EQ(3, batch->num_tasks()); + EXPECT_EQ(1, batch->task(0).size()); + EXPECT_EQ(3, batch->task(1).size()); + EXPECT_EQ(5, batch->task(2).size()); + EXPECT_EQ(0, tasks.size()); + }; + bool queue_1_callback_called = false; + auto queue_1_callback = + [&queue_1_callback_called]( + std::unique_ptr> batch, + std::vector> tasks) { + queue_1_callback_called = true; + ASSERT_TRUE(batch->IsClosed()); + ASSERT_EQ(2, batch->num_tasks()); + EXPECT_EQ(2, batch->task(0).size()); + EXPECT_EQ(4, batch->task(1).size()); + EXPECT_EQ(0, tasks.size()); + }; + { + SharedBatchScheduler::Options options; + options.num_batch_threads = 3; + options.env = Env::Default(); + + std::shared_ptr> + shared_batch_scheduler; + TF_CHECK_OK(SharedBatchScheduler::Create( + options, &shared_batch_scheduler)); + + // Create two queues. + + const SharedBatchScheduler::QueueOptions + queue_options = { + .input_batch_size_limit = 10, + .batch_timeout_micros = 1000 * 1000, + .max_enqueued_batches = 2, + .enable_large_batch_splitting = enable_input_batch_split(), + .split_input_task_func = + [](std::unique_ptr* input_task, + int open_batch_remaining_slot, int max_batch_size, + std::vector>* + output_tasks) -> Status { + std::unique_ptr owned_input_task = + std::move(*input_task); + const int input_task_size = owned_input_task->size(); + + const internal::InputSplitMetadata input_split_metadata( + input_task_size, open_batch_remaining_slot, max_batch_size); + + const absl::FixedArray task_sizes = + input_split_metadata.task_sizes(); + const int num_batches = task_sizes.size(); + + output_tasks->resize(num_batches); + for (int i = 0; i < num_batches; i++) { + (*output_tasks)[i] = + std::make_unique(task_sizes[i]); + } + + return absl::OkStatus(); + }, + .enable_lazy_split = enable_lazy_split(), + .max_execution_batch_size = 10, + .enable_priority_queue = true}; + + std::unique_ptr> queue_0; + TF_CHECK_OK(shared_batch_scheduler->AddQueue(queue_options, + queue_0_callback, &queue_0)); + std::unique_ptr> queue_1; + TF_CHECK_OK(shared_batch_scheduler->AddQueue(queue_options, + queue_1_callback, &queue_1)); + + // Submit tasks to the two queues. + TF_ASSERT_OK(ScheduleTaskWithoutCriticality(1, queue_0.get())); + TF_ASSERT_OK(ScheduleTaskWithoutCriticality(2, queue_1.get())); + TF_ASSERT_OK(ScheduleTaskWithoutCriticality(3, queue_0.get())); + TF_ASSERT_OK(ScheduleTaskWithoutCriticality(4, queue_1.get())); + TF_ASSERT_OK(ScheduleTaskWithoutCriticality(5, queue_0.get())); + } + EXPECT_TRUE(queue_0_callback_called); + EXPECT_TRUE(queue_1_callback_called); +} + TEST_P(SharedBatchSchedulerTest, ObeyBatchSizeConstraint) { // Set up a fake clock, which only advances when we explicitly tell it to. 
test_util::FakeClockEnv env(Env::Default()); @@ -912,6 +1049,93 @@ INSTANTIATE_TEST_SUITE_P( std::make_tuple(/*enable_input_batch_split=*/false, /*enable_lazy_split=*/false))); +using SharedBatchSchedulerPriorityTest = SharedBatchSchedulerTest; + +TEST_P(SharedBatchSchedulerPriorityTest, + CallbackWithTaskVectorOkWithPriorityQueueEnabledWithPrioritySet) { + bool queue_callback_called = false; + auto queue_callback = [&queue_callback_called]( + std::unique_ptr> batch, + std::vector> tasks) { + queue_callback_called = true; + ASSERT_TRUE(batch->IsClosed()); + ASSERT_EQ(2, batch->num_tasks()); + EXPECT_EQ(1, batch->task(0).size()); + EXPECT_EQ(3, batch->task(1).size()); + EXPECT_EQ(1, tasks.size()); + EXPECT_EQ(5, tasks[0]->size()); + }; + + { + std::shared_ptr scheduler = + CreateSharedBatchScheduler(/*num_batch_threads=*/3); + + // Create two queues. + const QueueOptions queue_options = CreateQueueOptions( + /*max_execution_batch_size=*/10, /*input_batch_size_limit=*/10, + /*batch_timeout_micros=*/1 * 1000 * 1000, /*max_enqueued_batches=*/2, + /*enable_priority_queue=*/true); + std::unique_ptr queue = + CreateQueue(scheduler, queue_options, queue_callback); + + // Submit tasks to the two queues. + TF_ASSERT_OK(ScheduleTask(1, queue.get(), + tsl::criticality::Criticality::kCriticalPlus)); + TF_ASSERT_OK(ScheduleTask(3, queue.get(), + tsl::criticality::Criticality::kCriticalPlus)); + TF_ASSERT_OK(ScheduleTask(5, queue.get(), + tsl::criticality::Criticality::kSheddable)); + } + EXPECT_TRUE(queue_callback_called); +} + +TEST_P(SharedBatchSchedulerPriorityTest, + CallbackWithTaskVectorOkWithPriorityQueueDisabledWithPrioritySet) { + bool queue_callback_called = false; + auto queue_callback = [&queue_callback_called]( + std::unique_ptr> batch, + std::vector> tasks) { + queue_callback_called = true; + ASSERT_TRUE(batch->IsClosed()); + ASSERT_EQ(3, batch->num_tasks()); + EXPECT_EQ(1, batch->task(0).size()); + EXPECT_EQ(3, batch->task(1).size()); + EXPECT_EQ(5, batch->task(2).size()); + EXPECT_EQ(0, tasks.size()); + }; + + { + std::shared_ptr scheduler = + CreateSharedBatchScheduler(/*num_batch_threads=*/3); + + // Create two queues. + const QueueOptions queue_options = CreateQueueOptions( + /*max_execution_batch_size=*/10, /*input_batch_size_limit=*/10, + /*batch_timeout_micros=*/1 * 1000 * 1000, /*max_enqueued_batches=*/2, + /*enable_priority_queue=*/false); + std::unique_ptr queue = + CreateQueue(scheduler, queue_options, queue_callback); + + // Submit tasks to the two queues. + TF_ASSERT_OK(ScheduleTask(1, queue.get(), + tsl::criticality::Criticality::kCriticalPlus)); + TF_ASSERT_OK(ScheduleTask(3, queue.get(), + tsl::criticality::Criticality::kCriticalPlus)); + TF_ASSERT_OK(ScheduleTask(5, queue.get(), + tsl::criticality::Criticality::kSheddable)); + } + EXPECT_TRUE(queue_callback_called); +} + +// Lazy split is to be removed. The mixed priority batching is only supported +// when the lazy split is not enabled. 
+INSTANTIATE_TEST_SUITE_P( + Parameter, SharedBatchSchedulerPriorityTest, + ::testing::Values(std::make_tuple(/*enable_input_batch_split=*/true, + /*enable_lazy_split=*/false), + std::make_tuple(/*enable_input_batch_split=*/false, + /*enable_lazy_split=*/false))); + #ifdef PLATFORM_GOOGLE // This benchmark relies on https://github.com/google/benchmark features, // (in particular, `Benchmark::ThreadRange`) not available in open-sourced TF From a6e55e74bba04b80bab202a53265182d06307aa1 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Mon, 18 Mar 2024 09:24:37 -0700 Subject: [PATCH 027/670] [XLA:GPU] Add IndexingContext to store MLIRContext* and RTVars registry. PiperOrigin-RevId: 616850787 --- third_party/xla/xla/service/gpu/BUILD | 1 + third_party/xla/xla/service/gpu/fusions/BUILD | 3 + .../xla/service/gpu/fusions/concatenate.cc | 2 +- .../xla/xla/service/gpu/fusions/concatenate.h | 4 +- .../service/gpu/fusions/concatenate_mlir.cc | 17 +- .../service/gpu/fusions/concatenate_mlir.h | 5 +- .../xla/service/gpu/fusions/fusion_emitter.cc | 19 +- .../xla/service/gpu/fusions/fusion_emitter.h | 6 +- .../fusions/in_place_dynamic_update_slice.h | 4 +- .../xla/service/gpu/fusions/input_slices.cc | 4 +- .../xla/service/gpu/fusions/input_slices.h | 4 +- .../service/gpu/fusions/input_slices_mlir.cc | 8 +- .../service/gpu/fusions/input_slices_mlir.h | 4 +- .../service/gpu/fusions/input_slices_test.cc | 5 +- .../xla/xla/service/gpu/fusions/loop.cc | 13 +- .../xla/xla/service/gpu/fusions/loop.h | 4 +- .../xla/xla/service/gpu/fusions/loop_mlir.cc | 17 +- .../xla/xla/service/gpu/fusions/loop_mlir.h | 4 +- .../xla/service/gpu/fusions/loop_mlir_test.cc | 16 +- .../xla/xla/service/gpu/fusions/loop_test.cc | 12 +- .../gpu/fusions/mlir/elemental_hlo_to_mlir.cc | 11 +- .../gpu/fusions/mlir/mlir_fusion_emitter.h | 1 + .../fusions/mlir/mlir_fusion_emitter_test.cc | 4 +- .../gpu/fusions/mlir/simplify_affine.cc | 5 +- .../gpu/fusions/mlir_emitter_test_base.cc | 3 +- .../gpu/fusions/mlir_emitter_test_base.h | 2 + .../xla/service/gpu/fusions/reduction_base.cc | 39 +-- .../xla/service/gpu/fusions/reduction_base.h | 13 +- .../gpu/fusions/reduction_base_test.cc | 51 ++-- .../xla/service/gpu/fusions/reduction_mlir.cc | 12 +- .../xla/xla/service/gpu/fusions/scatter.h | 4 +- .../xla/service/gpu/fusions/scatter_mlir.cc | 17 +- .../xla/service/gpu/fusions/scatter_mlir.h | 4 +- .../service/gpu/fusions/scatter_mlir_test.cc | 12 +- .../xla/xla/service/gpu/fusions/transpose.cc | 20 +- .../xla/xla/service/gpu/fusions/transpose.h | 4 +- .../xla/service/gpu/fusions/transpose_mlir.cc | 39 +-- .../xla/service/gpu/fusions/transpose_mlir.h | 14 +- .../gpu/fusions/transpose_mlir_test.cc | 16 +- .../xla/service/gpu/fusions/transpose_test.cc | 41 +-- .../xla/xla/service/gpu/ir_emitter_context.h | 4 + third_party/xla/xla/service/gpu/model/BUILD | 13 +- .../service/gpu/model/coalescing_analysis.cc | 29 ++- .../service/gpu/model/coalescing_analysis.h | 7 +- .../gpu/model/coalescing_analysis_test.cc | 6 +- .../model/gpu_indexing_performance_model.cc | 2 +- .../model/gpu_indexing_performance_model.h | 5 +- .../service/gpu/model/indexing_analysis.cc | 236 +++++++++++------- .../xla/service/gpu/model/indexing_analysis.h | 37 +-- .../gpu/model/indexing_analysis_test.cc | 14 +- .../xla/service/gpu/model/indexing_context.cc | 27 ++ .../xla/service/gpu/model/indexing_context.h | 54 ++++ .../xla/xla/service/gpu/model/indexing_map.cc | 21 +- .../xla/xla/service/gpu/model/indexing_map.h | 24 +- .../service/gpu/model/indexing_map_test.cc | 79 ++++-- 
.../service/gpu/model/indexing_test_utils.cc | 12 +- .../service/gpu/model/indexing_test_utils.h | 4 + .../xla/service/gpu/model/tile_analysis.cc | 8 +- 58 files changed, 665 insertions(+), 381 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/model/indexing_context.cc create mode 100644 third_party/xla/xla/service/gpu/model/indexing_context.h diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 59d7cc553ec9cc..508fd8f638e8a5 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -300,6 +300,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:name_uniquer", + "//xla/service/gpu/model:indexing_map", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 94e10f18bc4d55..fefac9f2e75fa0 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -216,6 +216,7 @@ cc_library( "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_map", "//xla/stream_executor:device_description", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", @@ -705,6 +706,7 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu/model:indexing_map", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", @@ -981,6 +983,7 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_map", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate.cc b/third_party/xla/xla/service/gpu/fusions/concatenate.cc index 084aece24b1c92..b8acbd4f8072d9 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate.cc +++ b/third_party/xla/xla/service/gpu/fusions/concatenate.cc @@ -58,7 +58,7 @@ ConcatenateFusion::ConcatenateFusion(const HloFusionAnalysis& analysis) : analysis_(analysis) {} std::optional ConcatenateFusion::ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const { + int64_t output_id, IndexingContext* indexing_context) const { return std::nullopt; // TODO(b/319081342): Implement this. 
} diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate.h b/third_party/xla/xla/service/gpu/fusions/concatenate.h index 997033293eff2b..5e51b50c2d1408 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate.h +++ b/third_party/xla/xla/service/gpu/fusions/concatenate.h @@ -38,11 +38,11 @@ class ConcatenateFusion : public KernelFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const override; + int64_t output_id, IndexingContext* indexing_context) const override; std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override { + IndexingContext* indexing_context) const override { // TODO(b/319081342): Implement this. return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc index 638c8ec9436c25..974365eca8efa3 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc @@ -66,17 +66,17 @@ LaunchDimensions MlirConcatenateFusion::launch_dimensions() const { std::optional MlirConcatenateFusion::ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const { + int64_t root_index, IndexingContext* indexing_context) const { return std::nullopt; } std::optional MlirConcatenateFusion::ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const { + IndexingContext* indexing_context) const { return GetDefaultThreadIdToOutputIndexingMap( launch_dimensions(), /*unroll_factor=*/1, - GetLargestConcatOperandShape(analysis_), ctx); + GetLargestConcatOperandShape(analysis_), indexing_context); } std::vector @@ -96,7 +96,8 @@ absl::Status MlirConcatenateFusion::EmitEntryFunction( const auto* concat = analysis_.fusion_heroes()[0]; mlir::ImplicitLocOpBuilder builder(entry_function.getLoc(), entry_function); builder.setInsertionPointToStart(entry_function.addEntryBlock()); - auto* ctx = entry_function.getContext(); + auto* mlir_context = entry_function.getContext(); + IndexingContext indexing_context{mlir_context}; int num_inputs = fusion.fused_instructions_computation()->num_parameters(); SmallVector input_tensors( @@ -109,13 +110,15 @@ absl::Status MlirConcatenateFusion::EmitEntryFunction( auto thread_id_to_input_map = ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/0, ctx) + /*root_index=*/0, /*hero_operand_index=*/0, &indexing_context) .value(); - auto epilogue_indexing = ComputeEpilogueInputToOutputIndexing(concat, ctx); + auto epilogue_indexing = + ComputeEpilogueInputToOutputIndexing(concat, &indexing_context); for (auto [operand_index, operand] : llvm::enumerate(concat->operands())) { auto input_to_output_map = - *ComputeInputToOutputIndexing(concat, /*input_id=*/operand_index, ctx) + *ComputeInputToOutputIndexing(concat, /*input_id=*/operand_index, + &indexing_context) .indexing_maps.front() .begin(); auto thread_id_to_output_map = ComposeIndexingMaps( diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h index 5003046bf39e41..f07a637d16c956 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h @@ -27,6 +27,7 @@ limitations under the License. 
#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_context.h" #include "xla/service/gpu/model/indexing_map.h" namespace xla { @@ -42,11 +43,11 @@ class MlirConcatenateFusion : public MlirFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override; + int64_t root_index, IndexingContext* indexing_context) const override; std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override; + IndexingContext* indexing_context) const override; protected: absl::Status EmitEntryFunction( diff --git a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc index d4d3a33ce57e78..e18557c012df62 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc @@ -119,8 +119,9 @@ absl::Status AnnotateKernelLaunchDimensions( IndexingMap KernelFusionInterface::GetDefaultThreadIdToOutputIndexingMap( const LaunchDimensions& launch_dims, int unroll_factor, - const Shape& output_shape, mlir::MLIRContext* ctx) { + const Shape& output_shape, IndexingContext* indexing_context) { std::vector output_dims(output_shape.rank()); + auto mlir_context = indexing_context->GetMLIRContext(); std::array thread_counts{ launch_dims.thread_counts_per_block().x, @@ -143,19 +144,20 @@ IndexingMap KernelFusionInterface::GetDefaultThreadIdToOutputIndexingMap( // This means that this code supports some launch grids that the parallel // loop emitter doesn't support. This is safe, since the latter CHECK fails // if its assumptions are not fulfilled. - mlir::AffineExpr c0 = mlir::getAffineConstantExpr(0, ctx); + mlir::AffineExpr c0 = mlir::getAffineConstantExpr(0, mlir_context); mlir::AffineExpr linear_index = c0; uint64_t stride = 1; for (int i = 0; i < 3; ++i) { - auto coord = mlir::getAffineDimExpr(kIndexingMapThreadIdxDims[i], ctx) + - mlir::getAffineDimExpr(kIndexingMapBlockIdxDims[i], ctx) * - thread_counts[i]; + auto coord = + mlir::getAffineDimExpr(kIndexingMapThreadIdxDims[i], mlir_context) + + mlir::getAffineDimExpr(kIndexingMapBlockIdxDims[i], mlir_context) * + thread_counts[i]; auto linear_component = coord * stride; linear_index = linear_index + linear_component; stride *= total_sizes[i]; } - mlir::AffineExpr chunk_id = mlir::getAffineSymbolExpr(0, ctx); - mlir::AffineExpr unroll_elem_id = mlir::getAffineSymbolExpr(1, ctx); + mlir::AffineExpr chunk_id = mlir::getAffineSymbolExpr(0, mlir_context); + mlir::AffineExpr unroll_elem_id = mlir::getAffineSymbolExpr(1, mlir_context); linear_index = linear_index * unroll_factor + chunk_id * unroll_factor * launch_dims.launch_bound() + @@ -187,8 +189,9 @@ IndexingMap KernelFusionInterface::GetDefaultThreadIdToOutputIndexingMap( 1}); symbol_ranges.push_back({0, unroll_factor - 1}); IndexingMap indexing_map( + indexing_context, mlir::AffineMap::get(/*dimCount=*/6, - /*symbolCount=*/2, output_dims, ctx), + /*symbolCount=*/2, output_dims, mlir_context), dimension_ranges, symbol_ranges); // Remove the unroll_elem_id symbol if unrolling divides num_elements. 
if (num_elements % unroll_factor == 0) { diff --git a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.h index dbc8e8718debe0..b5fa0f32152e32 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.h +++ b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.h @@ -77,14 +77,14 @@ class KernelFusionInterface : public FusionInterface { // unsupported (scatter, in-place DUS). Implementations will return nullopt. // Note: Work in progress, not implemented for all emitters. virtual std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const = 0; + int64_t root_index, IndexingContext* indexing_context) const = 0; // Computes an indexing map from thread to input element(s) of the root's // **hero**. Note that in many cases this is not computable from the output // indexing. The indexing may only be known for some operands of the hero. virtual std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const = 0; + IndexingContext* indexing_context) const = 0; static constexpr std::array kIndexingMapThreadIdxDims = {0, 1, 2}; static constexpr std::array kIndexingMapBlockIdxDims = {3, 4, 5}; @@ -96,7 +96,7 @@ class KernelFusionInterface : public FusionInterface { // block sizes in the given launch dimensions. static IndexingMap GetDefaultThreadIdToOutputIndexingMap( const LaunchDimensions& launch_dims, int unroll_factor, - const Shape& output_shape, mlir::MLIRContext* ctx); + const Shape& output_shape, IndexingContext* indexing_context); }; // Base class for fusions that are implemented using a single kernel, which is diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h index 12be8043b05ec1..4e4f2d82e94a80 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h @@ -67,7 +67,7 @@ class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override { + int64_t root_index, IndexingContext* indexing_context) const override { // The mapping cannot be statically computed in general, since the offsets // are unknown. return std::nullopt; @@ -75,7 +75,7 @@ class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase { std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override { + IndexingContext* indexing_context) const override { // TODO(b/319081342): Implement this. 
return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices.cc b/third_party/xla/xla/service/gpu/fusions/input_slices.cc index 85f661a8f125f5..aa1398639bd397 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices.cc +++ b/third_party/xla/xla/service/gpu/fusions/input_slices.cc @@ -183,7 +183,7 @@ LaunchDimensions InputSlicesFusion::launch_dimensions() const { } std::optional InputSlicesFusion::ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const { + int64_t output_id, IndexingContext* indexing_context) const { // The mapping here is trivial and the same for all outputs - slice offsets // are applied in the indexing from slice outputs to slice inputs. auto launch_dims = launch_dimensions(); @@ -191,7 +191,7 @@ std::optional InputSlicesFusion::ComputeThreadIdToOutputIndexing( // still use the requested output's shape for clarity. const auto& shape = analysis_.fusion_roots()[output_id]->shape(); return GetDefaultThreadIdToOutputIndexingMap(launch_dims, unroll_factor_, - shape, ctx); + shape, indexing_context); } absl::Status InputSlicesFusion::EmitKernel( diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices.h b/third_party/xla/xla/service/gpu/fusions/input_slices.h index 90f4f4e4a24d03..b1164c5df28e45 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices.h +++ b/third_party/xla/xla/service/gpu/fusions/input_slices.h @@ -48,11 +48,11 @@ class InputSlicesFusion : public KernelFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const override; + int64_t output_id, IndexingContext* indexing_context) const override; std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override { + IndexingContext* indexing_context) const override { // TODO(b/319081342): Implement this. return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc index c1108ca37e8cd3..a10babd539b2e7 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc @@ -52,7 +52,7 @@ using mlir::ValueRange; std::optional MlirInputSlicesFusion::ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const { + int64_t output_id, IndexingContext* indexing_context) const { // The mapping here is trivial and the same for all outputs - slice offsets // are applied in the indexing from slice outputs to slice inputs. auto launch_dims = launch_dimensions(); @@ -60,7 +60,7 @@ MlirInputSlicesFusion::ComputeThreadIdToOutputIndexing( // still use the requested output's shape for clarity. const auto& shape = analysis_.fusion_roots()[output_id]->shape(); return GetDefaultThreadIdToOutputIndexingMap(launch_dims, unroll_factor_, - shape, ctx); + shape, indexing_context); } LaunchDimensions MlirInputSlicesFusion::launch_dimensions() const { @@ -80,8 +80,8 @@ absl::Status MlirInputSlicesFusion::EmitEntryFunction( // We enforce that all the root shapes have identical dimensions in // IsHloOpSupported. 
- auto indexing = - ComputeThreadIdToOutputIndexing(0, entry_function.getContext()); + IndexingContext indexing_context{entry_function.getContext()}; + auto indexing = ComputeThreadIdToOutputIndexing(0, &indexing_context); TF_RET_CHECK(indexing) << "Indexing is never nullopt"; int num_inputs = fusion.fused_instructions_computation()->num_parameters(); diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.h b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.h index 1de06b963d9e59..53b9d76f97a9ca 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.h @@ -37,11 +37,11 @@ class MlirInputSlicesFusion : public MlirFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const override; + int64_t output_id, IndexingContext* indexing_context) const override; std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override { + IndexingContext* indexing_context) const override { // TODO(b/319081342): Implement this. return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc index 094bbfac7a27a9..939ab506b62dc4 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_context.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" @@ -34,6 +35,7 @@ namespace { class InputSlicesTest : public HloTestBase { public: + InputSlicesTest() : indexing_context_(&mlir_context_) {} void SetUp() override { HloTestBase::SetUp(); printer_ = @@ -44,6 +46,7 @@ class InputSlicesTest : public HloTestBase { protected: AffineMapPrinter printer_; mlir::MLIRContext mlir_context_; + IndexingContext indexing_context_; }; TEST_F(InputSlicesTest, ThreadIndexing) { @@ -76,7 +79,7 @@ TEST_F(InputSlicesTest, ThreadIndexing) { ASSERT_NE(fusion, nullptr); auto thread_id_to_output_indexing = - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_); + fusion->ComputeThreadIdToOutputIndexing(0, &indexing_context_); EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (0, diff --git a/third_party/xla/xla/service/gpu/fusions/loop.cc b/third_party/xla/xla/service/gpu/fusions/loop.cc index e7a13200fe391f..35b1f18348ac6d 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop.cc @@ -215,23 +215,24 @@ LoopFusion::LoopFusion(const HloFusionAnalysis& analysis) : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {} std::optional LoopFusion::ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const { + int64_t root_index, IndexingContext* indexing_context) const { auto launch_dims = launch_dimensions(); return GetDefaultThreadIdToOutputIndexingMap( - launch_dims, config_.unroll_factor, GetElementShape(analysis_), ctx); + launch_dims, config_.unroll_factor, 
GetElementShape(analysis_), + indexing_context); } std::optional LoopFusion::ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const { + IndexingContext* indexing_context) const { std::optional thread_id_to_output_indexing = - ComputeThreadIdToOutputIndexing(root_index, ctx); + ComputeThreadIdToOutputIndexing(root_index, indexing_context); if (!thread_id_to_output_indexing.has_value()) { return std::nullopt; } const HloInstruction* fusion_root = analysis_.fusion_roots()[root_index]; - auto output_to_input_indexing = - ComputeOutputToInputIndexing(fusion_root, /*output_id=*/0, ctx); + auto output_to_input_indexing = ComputeOutputToInputIndexing( + fusion_root, /*output_id=*/0, indexing_context); IndexingMapSet output_to_input_indexing_set = output_to_input_indexing.indexing_maps[hero_operand_index]; // Since we are computing the indexing for a non-fusion op, there is only one diff --git a/third_party/xla/xla/service/gpu/fusions/loop.h b/third_party/xla/xla/service/gpu/fusions/loop.h index e466abe66a843f..9371015cf0a356 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop.h +++ b/third_party/xla/xla/service/gpu/fusions/loop.h @@ -40,11 +40,11 @@ class LoopFusion : public KernelFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override; + int64_t root_index, IndexingContext* indexing_context) const override; std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override; + IndexingContext* indexing_context) const override; protected: absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc index 82734d06cc9c9a..0989de4bde6726 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc @@ -63,23 +63,24 @@ const Shape& GetFusionResultShape(const HloFusionAnalysis& analysis) { } // namespace std::optional MlirLoopFusion::ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const { + int64_t root_index, IndexingContext* indexing_context) const { auto launch_dims = launch_dimensions(); return GetDefaultThreadIdToOutputIndexingMap( - launch_dims, config_.unroll_factor, GetFusionResultShape(analysis_), ctx); + launch_dims, config_.unroll_factor, GetFusionResultShape(analysis_), + indexing_context); } std::optional MlirLoopFusion::ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const { + IndexingContext* indexing_context) const { std::optional thread_id_to_output_indexing = - ComputeThreadIdToOutputIndexing(root_index, ctx); + ComputeThreadIdToOutputIndexing(root_index, indexing_context); if (!thread_id_to_output_indexing.has_value()) { return std::nullopt; } const HloInstruction* fusion_root = analysis_.fusion_roots()[root_index]; - auto output_to_input_indexing = - ComputeOutputToInputIndexing(fusion_root, /*output_id=*/0, ctx); + auto output_to_input_indexing = ComputeOutputToInputIndexing( + fusion_root, /*output_id=*/0, indexing_context); IndexingMapSet output_to_input_indexing_set = output_to_input_indexing.indexing_maps[hero_operand_index]; // Since we are computing the indexing for a non-fusion op, there is only one @@ -106,8 +107,8 @@ absl::Status 
MlirLoopFusion::EmitEntryFunction( // We enforce that all the root shapes have identical dimensions in // IsHloOpSupported. - auto indexing = - ComputeThreadIdToOutputIndexing(0, entry_function.getContext()); + IndexingContext indexing_context{entry_function.getContext()}; + auto indexing = ComputeThreadIdToOutputIndexing(0, &indexing_context); TF_RET_CHECK(indexing) << "Indexing is never nullopt"; int num_inputs = fusion.fused_instructions_computation()->num_parameters(); diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.h b/third_party/xla/xla/service/gpu/fusions/loop_mlir.h index 228c8c87b5ff28..b70b7070ab626f 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.h @@ -37,11 +37,11 @@ class MlirLoopFusion : public MlirFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override; + int64_t root_index, IndexingContext* indexing_context) const override; std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override; + IndexingContext* indexing_context) const override; protected: absl::Status EmitEntryFunction( diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc index 1f3d41bddc46a0..9febfd5d565e66 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc @@ -47,8 +47,8 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) { auto* root = module->entry_computation()->root_instruction(); auto analysis = AnalyzeFusion(*root, device_info_); MlirLoopFusion fusion(analysis); - auto thread_id_to_output_indexing = - fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); + auto thread_id_to_output_indexing = fusion.ComputeThreadIdToOutputIndexing( + /*root_index=*/0, &indexing_context_); EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_), MatchIndexingString(R"( @@ -90,8 +90,8 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingNotUnrolled) { auto analysis = AnalyzeFusion(*root, device_info_); MlirLoopFusion fusion(analysis); - auto thread_id_to_output_indexing = - fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); + auto thread_id_to_output_indexing = fusion.ComputeThreadIdToOutputIndexing( + /*root_index=*/0, &indexing_context_); EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) @@ -106,7 +106,7 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingNotUnrolled) { unroll_id in [0, 0] )")); auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); + /*root_index=*/0, /*hero_operand_index=*/0, &indexing_context_); EXPECT_THAT(thread_id_to_input_indexing->ToString(thread_id_printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) @@ -142,8 +142,8 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { auto analysis = AnalyzeFusion(*root, device_info_); MlirLoopFusion fusion(analysis); - auto thread_id_to_output_indexing = - fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); + auto thread_id_to_output_indexing = fusion.ComputeThreadIdToOutputIndexing( + /*root_index=*/0, 
&indexing_context_); EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( @@ -162,7 +162,7 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { th_x + bl_x * 128 in [0, 5999] )")); auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); + /*root_index=*/0, /*hero_operand_index=*/0, &indexing_context_); EXPECT_THAT(thread_id_to_input_indexing->ToString(thread_id_printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( diff --git a/third_party/xla/xla/service/gpu/fusions/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_test.cc index 1bb5fdb8705d30..91e56aafe0a6b5 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_test.cc @@ -37,6 +37,7 @@ namespace { class LoopTest : public HloTestBase { public: + LoopTest() : indexing_context_(&mlir_context_) {} void SetUp() override { HloTestBase::SetUp(); @@ -50,6 +51,7 @@ class LoopTest : public HloTestBase { TestGpuDeviceInfo::RTXA6000DeviceInfo(); AffineMapPrinter printer_; mlir::MLIRContext mlir_context_; + IndexingContext indexing_context_; }; absl::StatusOr> GetFusion( @@ -84,7 +86,7 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); auto thread_id_to_output_indexing = loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, - &mlir_context_); + &indexing_context_); EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), MatchIndexingString(R"( @@ -127,7 +129,7 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); auto thread_id_to_output_indexing = loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, - &mlir_context_); + &indexing_context_); EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) @@ -143,7 +145,7 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { )")); auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); + /*root_index=*/0, /*hero_operand_index=*/0, &indexing_context_); EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) @@ -180,7 +182,7 @@ TEST_F(LoopTest, Broadcast) { TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); auto thread_id_to_output_indexing = loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, - &mlir_context_); + &indexing_context_); EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( @@ -200,7 +202,7 @@ TEST_F(LoopTest, Broadcast) { )")); auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); + /*root_index=*/0, /*hero_operand_index=*/0, &indexing_context_); EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index 0701425c81ca5b..8e2124dbe0e87b 
100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -67,6 +67,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_context.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/shape_util.h" #include "xla/status_macros.h" @@ -468,7 +469,8 @@ absl::StatusOr> EmitPad( const HloInstruction* instr, mlir::Type result_element_type, ValueRange indices, const OperandProvider& operand_provider, ImplicitLocOpBuilder& b) { - auto indexing = ComputeOutputToInputIndexing(instr, 0, b.getContext()); + IndexingContext indexing_context{b.getContext()}; + auto indexing = ComputeOutputToInputIndexing(instr, 0, &indexing_context); const auto& indexing_map = *indexing.indexing_maps[0].begin(); mlir::Value is_in_bounds = CheckConstraints(indexing_map, indices, {}, b); @@ -673,9 +675,10 @@ absl::StatusOr> HloToMlir( operand->shape().element_type(), builder)); arg_types.push_back(operand_element_type); } - auto input_indices = GetInputIndices( - ComputeOutputToInputIndexing(instr, 0, builder.getContext()), indices, - builder); + IndexingContext indexing_context(builder.getContext()); + auto input_indices = + GetInputIndices(ComputeOutputToInputIndexing(instr, 0, &indexing_context), + indices, builder); SmallVector operands; for (auto&& [operand_number, operand_indices] : llvm::enumerate(input_indices)) { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h index 6baf86372613e5..79836a9d5ed8a3 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h @@ -39,6 +39,7 @@ limitations under the License. #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/model/indexing_context.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/stream_executor/device_description.h" diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc index b0ed47330a5558..d5623e11ae58ee 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc @@ -61,12 +61,12 @@ class DummyCopyFusionEmitter : public MlirFusionEmitterBase { LaunchDimensions launch_dimensions() const final { return {1, 100}; } std::optional ComputeThreadIdToOutputIndexing( - int64_t, mlir::MLIRContext*) const final { + int64_t, IndexingContext*) const final { return std::nullopt; } std::optional ComputeThreadIdToInputIndexing( - int64_t, int64_t, mlir::MLIRContext*) const final { + int64_t, int64_t, IndexingContext*) const final { return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc b/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc index c86fbda41c4cf1..2507281a283ebd 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc @@ -40,6 +40,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/service/gpu/model/indexing_context.h" #include "xla/service/gpu/model/indexing_map.h" namespace xla { @@ -114,7 +115,9 @@ struct RewriteAffineApply } } - IndexingMap map(op.getAffineMap(), dim_ranges, symbol_ranges); + IndexingContext indexing_context(op->getContext()); + IndexingMap map(&indexing_context, op.getAffineMap(), dim_ranges, + symbol_ranges); map.Simplify(); auto expr = map.GetAffineMap().getResult(0); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc index 2dfc06b9e747af..bdf424c079c7a3 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc @@ -49,7 +49,8 @@ limitations under the License. namespace xla { namespace gpu { -MlirEmitterTestBaseImpl::MlirEmitterTestBaseImpl() { +MlirEmitterTestBaseImpl::MlirEmitterTestBaseImpl() + : indexing_context_(&mlir_context_) { // clang-format off mlir_context_.loadDialect< mlir::affine::AffineDialect, diff --git a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h index a299c2ea4007ba..147b57b6f84b70 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h +++ b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h @@ -28,6 +28,7 @@ limitations under the License. #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_context.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" @@ -50,6 +51,7 @@ class MlirEmitterTestBaseImpl : public HloTestBase { stream_executor::DeviceDescription device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo(); mlir::MLIRContext mlir_context_; + IndexingContext indexing_context_; AffineMapPrinter thread_id_printer_; }; diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc index 57db8735cc0e60..6ea9220034eaa8 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc @@ -317,18 +317,19 @@ ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis) { } std::optional ReductionInfo::ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const { + int64_t root_index, IndexingContext* indexing_context) const { if (!groups_.is_reduction_root[root_index]) { // Non-transpose roots are elementwise by definition. 
- return ComputeThreadIdToInputIndexing(root_index, 0, ctx); + return ComputeThreadIdToInputIndexing(root_index, 0, indexing_context); } auto* root = analysis_.fusion_roots()[root_index]; auto* hero = analysis_.fusion_heroes()[root_index]; - auto block_offsets = GetBlockOffsetsForTiling(tiling_, ctx); - auto thread_ids = DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), - tiling_.GetThreadsPerBlock(), - tiling_.GetThreadStrides()); + auto mlir_context = indexing_context->GetMLIRContext(); + auto block_offsets = GetBlockOffsetsForTiling(tiling_, mlir_context); + auto thread_ids = DelinearizeInBoundsIndex( + mlir::getAffineDimExpr(0, mlir_context), tiling_.GetThreadsPerBlock(), + tiling_.GetThreadStrides()); auto physical_shape = ShapeUtil::DeleteDimensions(hero->dimensions(), hero->operand(0)->shape()); @@ -352,9 +353,10 @@ std::optional ReductionInfo::ComputeThreadIdToOutputIndexing( auto physical_index = [&]() { if (is_row_reduction_) { IndexingMap linear_index( + indexing_context, mlir::AffineMap::get( 6, 0, block_offsets.getResult(kRowKept) + thread_ids[kRowKept], - ctx), + mlir_context), dimension_ranges, {}); int rows_per_warp = GetRowsPerWarp(); if (rows_per_warp > 1) { @@ -367,20 +369,21 @@ std::optional ReductionInfo::ComputeThreadIdToOutputIndexing( return ComposeIndexingMaps( linear_index, GetBitcastMap(ShapeUtil::MakeShape( PRED, {tiling_.GetShape()[kRowKept]}), - physical_shape, ctx)); + physical_shape, indexing_context)); } IndexingMap projected_index( + indexing_context, mlir::AffineMap::get( 6, 0, {block_offsets.getResult(kColMajorKept), block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]}, - ctx), + mlir_context), dimension_ranges, {}); projected_index.AddConstraint( mlir::getAffineDimExpr( - KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx) % + KernelFusionInterface::kIndexingMapThreadIdxDims[0], mlir_context) % WarpSize(), {0, 0}); if (!is_row_reduction_) { @@ -395,24 +398,25 @@ std::optional ReductionInfo::ComputeThreadIdToOutputIndexing( GetBitcastMap(ShapeUtil::DeleteDimension( ReductionDimensions::kColReducedDimension, tiling_.GetXlaShape()), - physical_shape, ctx)); + physical_shape, indexing_context)); }(); auto map = ComposeIndexingMaps( physical_index, - GetBitcastMap(FirstShape(hero->shape()), FirstShape(root->shape()), ctx)); + GetBitcastMap(FirstShape(hero->shape()), FirstShape(root->shape()), + indexing_context)); int group_index = groups_.group_id_per_root[root_index]; map.AddConstraint( mlir::getAffineDimExpr(KernelFusionInterface::kIndexingMapBlockIdxDims[1], - ctx), + mlir_context), {group_index, group_index}); return map; } std::optional ReductionInfo::ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const { + IndexingContext* indexing_context) const { auto* hero = analysis_.fusion_heroes()[root_index]; if (groups_.is_reduction_root[root_index] && hero_operand_index >= hero->operand_count() / 2) { @@ -421,15 +425,16 @@ std::optional ReductionInfo::ComputeThreadIdToInputIndexing( } auto map = ComposeIndexingMaps( - GetIndexingMapForTiling(tiling_, ctx), + GetIndexingMapForTiling(tiling_, indexing_context), GetBitcastMap(tiling_.GetXlaShape(), - hero->operand(hero_operand_index)->shape(), ctx)); + hero->operand(hero_operand_index)->shape(), + indexing_context)); // Only threads with the right y block index actually do anything for this // root. 
int group_index = groups_.group_id_per_root[root_index]; map.AddConstraint( mlir::getAffineDimExpr(KernelFusionInterface::kIndexingMapBlockIdxDims[1], - ctx), + indexing_context->GetMLIRContext()), {group_index, group_index}); return map; } diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.h b/third_party/xla/xla/service/gpu/fusions/reduction_base.h index 93c2ecc2681f83..89442524b7e058 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base.h +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.h @@ -57,11 +57,11 @@ class ReductionInfo { int GetRowsPerWarp() const; std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const; + int64_t root_index, IndexingContext* indexing_context) const; std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const; + IndexingContext* indexing_context) const; LaunchDimensions launch_dimensions() const; @@ -93,15 +93,16 @@ class ReductionFusionBase : public Base { : analysis_(analysis), reduction_info_(ReductionInfo::Create(analysis)) {} std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override { - return reduction_info().ComputeThreadIdToOutputIndexing(root_index, ctx); + int64_t root_index, IndexingContext* indexing_context) const override { + return reduction_info().ComputeThreadIdToOutputIndexing(root_index, + indexing_context); } std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override { + IndexingContext* indexing_context) const override { return reduction_info().ComputeThreadIdToInputIndexing( - root_index, hero_operand_index, ctx); + root_index, hero_operand_index, indexing_context); } LaunchDimensions launch_dimensions() const override { diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_base_test.cc index 2c4ffa0e9ce078..6b7e8dcc2c4f42 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base_test.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/model/indexing_context.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" @@ -35,9 +36,14 @@ namespace gpu { namespace { class ReductionTest : public HloTestBase { + public: + ReductionTest() : indexing_context_(&mlir_context_) {} + protected: stream_executor::DeviceDescription device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + mlir::MLIRContext mlir_context_; + IndexingContext indexing_context_; }; class FakeReductionFusion : public ReductionFusionBase { @@ -78,11 +84,10 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { auto* root = module->entry_computation()->root_instruction(); auto analysis = AnalyzeFusion(*root, device_info_); FakeReductionFusion fusion(analysis); - mlir::MLIRContext mlir_context; - EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), - MatchIndexingString(R"( + EXPECT_THAT(fusion.ComputeThreadIdToInputIndexing(0, 0, &indexing_context_) + ->ToString(), + MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( (d3 * 8 + d0 floordiv 32) floordiv 64, (d3 * 8 + d0 floordiv 32) mod 64, @@ -103,7 +108,7 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { d3 * 8 + d0 floordiv 32 in [0, 6399] )")); EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + fusion.ComputeThreadIdToOutputIndexing(0, &indexing_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5) -> ( (d3 * 8 + d0 floordiv 32) floordiv 64, @@ -147,11 +152,10 @@ TEST_F(ReductionTest, ThreadIndexingMultiRowReduction) { auto* root = module->entry_computation()->root_instruction(); auto analysis = AnalyzeFusion(*root, device_info_); FakeReductionFusion fusion(analysis); - mlir::MLIRContext mlir_context; - EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), - MatchIndexingString(R"( + EXPECT_THAT(fusion.ComputeThreadIdToInputIndexing(0, 0, &indexing_context_) + ->ToString(), + MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d3 + (d0 floordiv 4) floordiv 64, (d0 floordiv 4) mod 64, @@ -172,7 +176,7 @@ TEST_F(ReductionTest, ThreadIndexingMultiRowReduction) { d3 * 64 + d0 floordiv 4 in [0, 6399] )")); EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + fusion.ComputeThreadIdToOutputIndexing(0, &indexing_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5) -> ( d3 + (d0 floordiv 4) floordiv 64, @@ -217,11 +221,10 @@ TEST_F(ReductionTest, ThreadIndexingColumnReduction) { auto* root = module->entry_computation()->root_instruction(); auto analysis = AnalyzeFusion(*root, device_info_); FakeReductionFusion fusion(analysis); - mlir::MLIRContext mlir_context; - EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), - MatchIndexingString(R"( + EXPECT_THAT(fusion.ComputeThreadIdToInputIndexing(0, 0, &indexing_context_) + ->ToString(), + MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d3, d0 floordiv 32 + s1 * 32, @@ -235,7 +238,7 @@ TEST_F(ReductionTest, ThreadIndexingColumnReduction) { d0 mod 32 in [0, 31] )")); EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + fusion.ComputeThreadIdToOutputIndexing(0, &indexing_context_)->ToString(), MatchIndexingString(R"( (d0, 
d1, d2, d3, d4, d5) -> ( d3, @@ -273,10 +276,9 @@ TEST_F(ReductionTest, ThreadIndexingOutputLayout) { auto* root = module->entry_computation()->root_instruction(); auto analysis = AnalyzeFusion(*root, device_info_); FakeReductionFusion fusion(analysis); - mlir::MLIRContext mlir_context; EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + fusion.ComputeThreadIdToOutputIndexing(0, &indexing_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5) -> ( (d3 * 8 + d0 floordiv 32) floordiv 64, @@ -322,7 +324,6 @@ TEST_F(ReductionTest, ThreadIndexingSideOutput) { auto* root = module->entry_computation()->root_instruction(); auto analysis = AnalyzeFusion(*root, device_info_); FakeReductionFusion fusion(analysis); - mlir::MLIRContext mlir_context; constexpr char kExpectedIndexing[] = R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( @@ -344,11 +345,11 @@ TEST_F(ReductionTest, ThreadIndexingSideOutput) { d0 mod 32 + s2 * 32 in [0, 511] d3 * 8 + d0 floordiv 32 in [0, 6399] )"; + EXPECT_THAT(fusion.ComputeThreadIdToInputIndexing(1, 0, &indexing_context_) + ->ToString(), + MatchIndexingString(kExpectedIndexing)); EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString(), - MatchIndexingString(kExpectedIndexing)); - EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(), + fusion.ComputeThreadIdToOutputIndexing(1, &indexing_context_)->ToString(), MatchIndexingString(kExpectedIndexing)); } @@ -377,9 +378,9 @@ TEST_F(ReductionTest, bla) { FakeReductionFusion fusion(analysis); mlir::MLIRContext mlir_context; - EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), - MatchIndexingString(R"( + EXPECT_THAT(fusion.ComputeThreadIdToInputIndexing(0, 0, &indexing_context_) + ->ToString(), + MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( d3, (d0 + s2 * 512) * 2 + s3 diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index 1ff23dddcf51ba..c0e500803c0b46 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -147,7 +147,9 @@ absl::Status MlirReductionFusion::EmitReduction(EmitterState& state) const { int num_warps_row = tiling.GetThreadsPerBlock() [ReductionDimensions::kRowMinorReducedDimension] / WarpSize(); - auto ctx = state.entry_function.getContext(); + + auto* mlir_context = state.entry_function.getContext(); + IndexingContext indexing_context(mlir_context); auto zero = builder.create(0); auto lane_id = builder.create(); @@ -161,10 +163,10 @@ absl::Status MlirReductionFusion::EmitReduction(EmitterState& state) const { auto thread_ids = mlir_converter::ApplyAffineMap( mlir::AffineMap::get( /*dimCount=*/1, /*symbolCount=*/0, - DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), + DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, mlir_context), tiling.GetThreadsPerBlock(), tiling.GetThreadStrides()), - ctx), + mlir_context), {thread_id}, {}, builder); SmallVector thread_and_block_indices{thread_id, zero, zero, block_id, zero, zero}; @@ -200,7 +202,7 @@ absl::Status MlirReductionFusion::EmitReduction(EmitterState& state) const { } bool use_shared = !shared_tile_size.empty(); - auto output_indexing = ComputeThreadIdToOutputIndexing(0, ctx); + auto output_indexing = ComputeThreadIdToOutputIndexing(0, &indexing_context); auto output_indices = mlir_converter::ApplyAffineMap( 
output_indexing->GetAffineMap(), thread_and_block_indices, {}, builder); auto thread_has_output = mlir_converter::CheckConstraints( @@ -236,7 +238,7 @@ absl::Status MlirReductionFusion::EmitReduction(EmitterState& state) const { SmallVector> results; for (auto* hero : reduction_heroes_) { auto input_indexing = ComputeThreadIdToInputIndexing( - reduction_roots_.at(hero).front(), 0, ctx); + reduction_roots_.at(hero).front(), 0, &indexing_context); TF_ASSIGN_OR_RETURN( auto accumulated, state.EmitPerThreadReducedElements(*input_indexing, hero, inits[hero])); diff --git a/third_party/xla/xla/service/gpu/fusions/scatter.h b/third_party/xla/xla/service/gpu/fusions/scatter.h index 6982bbc8e6bd2c..6b0e2c5fe81eb9 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter.h +++ b/third_party/xla/xla/service/gpu/fusions/scatter.h @@ -44,7 +44,7 @@ class ScatterFusion : public KernelFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override { + int64_t root_index, IndexingContext* indexing_context) const override { // The kernel iterates over updates, whose correspondence to output // elements cannot be computed statically. return std::nullopt; @@ -52,7 +52,7 @@ class ScatterFusion : public KernelFusionEmitterBase { std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override { + IndexingContext* indexing_context) const override { // TODO(b/319081342): Implement this. return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc index 85242c0740e7b6..2f4a3e0af8ce5d 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc @@ -83,13 +83,13 @@ bool MlirScatterFusion::IsSupported(const HloFusionAnalysis& analysis) { } std::optional MlirScatterFusion::ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const { + int64_t root_index, IndexingContext* indexing_context) const { return std::nullopt; } std::optional MlirScatterFusion::ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const { + IndexingContext* indexing_context) const { auto* scatter = DynCast(analysis_.fusion_heroes().front()); int64_t scatter_operand_count = scatter->scatter_operand_count(); @@ -106,7 +106,8 @@ std::optional MlirScatterFusion::ComputeThreadIdToInputIndexing( // Compute thread id mapping based on the first update operand. Shape scatter_update_shape = scatter->scatter_updates().front()->shape(); IndexingMap scatter_update_map = GetDefaultThreadIdToOutputIndexingMap( - launch_dimensions(), config_.unroll_factor, scatter_update_shape, ctx); + launch_dimensions(), config_.unroll_factor, scatter_update_shape, + indexing_context); // For scatter indices we project indexing for scatter updates and take the // first result of the affine map only, because they coincide. @@ -114,11 +115,14 @@ std::optional MlirScatterFusion::ComputeThreadIdToInputIndexing( Shape scatter_indices_shape = scatter->scatter_indices()->shape(); CHECK_EQ(scatter_indices_shape.rank(), 2) << scatter->ToString(); // Create a map from scatter update to scatter indices. 
+ auto* mlir_context = indexing_context->GetMLIRContext(); IndexingMap updates_to_indices_map{ + indexing_context, mlir::AffineMap::get( /*dimCount=*/scatter_update_shape.rank(), /*symbolCount=*/1, - {mlir::getAffineDimExpr(0, ctx), mlir::getAffineSymbolExpr(0, ctx)}, - ctx), + {mlir::getAffineDimExpr(0, mlir_context), + mlir::getAffineSymbolExpr(0, mlir_context)}, + mlir_context), /*dim_ranges=*/RangesFromTensorSizes(scatter_update_shape.dimensions()), /*symbol_ranges=*/ RangesFromTensorSizes({scatter_indices_shape.dimensions(1)})}; @@ -185,10 +189,11 @@ absl::Status MlirScatterFusion::EmitEntryFunction( const HloInstruction* scatter_update = scatter->operand(kScatterUpdateIndex); mlir::MLIRContext* mlir_context = entry_function.getContext(); + IndexingContext indexing_context{mlir_context}; auto thread_id_to_update_map = ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/kScatterUpdateIndex, - mlir_context) + &indexing_context) .value(); thread_id_to_update_map.Simplify(); thread_id_to_update_map.RemoveUnusedSymbols(); diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h index e66e2c6a4f5a78..016a67c7c512fd 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h @@ -42,11 +42,11 @@ class MlirScatterFusion : public MlirFusionEmitterBase { static bool IsSupported(const HloFusionAnalysis& analysis); std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override; + int64_t root_index, IndexingContext* indexing_context) const override; std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override; + IndexingContext* indexing_context) const override; protected: absl::Status EmitEntryFunction( diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc index f7fdba3b97db30..dd868683d745bf 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc @@ -97,25 +97,25 @@ TEST_F(MlirScatterFusionTest, ThreadId_IndexingUnrolled) { EXPECT_THAT( fusion .ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/3, &mlir_context_) + /*root_index=*/0, /*hero_operand_index=*/3, &indexing_context_) ->ToString(thread_id_printer_), MatchIndexingString(kUpdatesIndexing)); EXPECT_THAT( fusion .ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/4, &mlir_context_) + /*root_index=*/0, /*hero_operand_index=*/4, &indexing_context_) ->ToString(thread_id_printer_), MatchIndexingString(kUpdatesIndexing)); EXPECT_THAT( fusion .ComputeThreadIdToInputIndexing( - /*root_index=*/1, /*hero_operand_index=*/3, &mlir_context_) + /*root_index=*/1, /*hero_operand_index=*/3, &indexing_context_) ->ToString(thread_id_printer_), MatchIndexingString(kUpdatesIndexing)); EXPECT_THAT( fusion .ComputeThreadIdToInputIndexing( - /*root_index=*/1, /*hero_operand_index=*/4, &mlir_context_) + /*root_index=*/1, /*hero_operand_index=*/4, &indexing_context_) ->ToString(thread_id_printer_), MatchIndexingString(kUpdatesIndexing)); @@ -137,13 +137,13 @@ TEST_F(MlirScatterFusionTest, ThreadId_IndexingUnrolled) { EXPECT_THAT( fusion .ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_) + /*root_index=*/0, /*hero_operand_index=*/2, 
&indexing_context_) ->ToString(thread_id_printer_), MatchIndexingString(kIndicesIndexing)); EXPECT_THAT( fusion .ComputeThreadIdToInputIndexing( - /*root_index=*/1, /*hero_operand_index=*/2, &mlir_context_) + /*root_index=*/1, /*hero_operand_index=*/2, &indexing_context_) ->ToString(thread_id_printer_), MatchIndexingString(kIndicesIndexing)); } diff --git a/third_party/xla/xla/service/gpu/fusions/transpose.cc b/third_party/xla/xla/service/gpu/fusions/transpose.cc index 99f113cbafbea7..fbce46e7b82665 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose.cc @@ -284,19 +284,20 @@ LaunchDimensions TransposeFusion::launch_dimensions() const { } std::optional TransposeFusion::ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const { + int64_t root_index, IndexingContext* indexing_context) const { + auto* mlir_context = indexing_context->GetMLIRContext(); const auto& hero = *analysis_.fusion_heroes()[root_index]; const auto& root = *analysis_.fusion_roots()[root_index]; if (!GetDescriptionForTiledTransposeEmitter(root, hero)) { // Non-transpose roots are elementwise by definition. - return ComputeThreadIdToInputIndexing(root_index, 0, ctx); + return ComputeThreadIdToInputIndexing(root_index, 0, indexing_context); } // The block offsets are permuted, but the thread offsets remain the same. - auto block_offset = GetBlockOffsetsForTiling(tiling_, ctx) + auto block_offset = GetBlockOffsetsForTiling(tiling_, mlir_context) .getSubMap(std::vector{permutation_.begin(), permutation_.end()}); - auto thread_offset = GetThreadOffsetsForTiling(tiling_, ctx); + auto thread_offset = GetThreadOffsetsForTiling(tiling_, mlir_context); auto permuted_tiled_shape = ShapeUtil::MakeShape(U8, Permute(tiling_.GetShape(), permutation_)); @@ -304,20 +305,21 @@ std::optional TransposeFusion::ComputeThreadIdToOutputIndexing( GetIndexingMapForTiling( block_offset, thread_offset, tiling_.GetNumThreadsPerBlock(), tiling_.GetNumBlocks(), tiling_.GetThreadTileSize(), - permuted_tiled_shape.dimensions()), - GetBitcastMap(permuted_tiled_shape, hero.shape(), ctx)); + permuted_tiled_shape.dimensions(), indexing_context), + GetBitcastMap(permuted_tiled_shape, hero.shape(), indexing_context)); map.Simplify(); return map; } std::optional TransposeFusion::ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const { + IndexingContext* indexing_context) const { const auto& hero = *analysis_.fusion_heroes()[root_index]; auto map = ComposeIndexingMaps( - GetIndexingMapForTiling(tiling_, ctx), - GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), ctx)); + GetIndexingMapForTiling(tiling_, indexing_context), + GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), + indexing_context)); map.Simplify(); return map; } diff --git a/third_party/xla/xla/service/gpu/fusions/transpose.h b/third_party/xla/xla/service/gpu/fusions/transpose.h index 899b1cb94390ae..d45cf15c762561 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose.h +++ b/third_party/xla/xla/service/gpu/fusions/transpose.h @@ -64,11 +64,11 @@ class TransposeFusion : public KernelFusionEmitterBase { LaunchDimensions launch_dimensions() const override; std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override; + int64_t root_index, IndexingContext* indexing_context) const override; std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t 
hero_operand_index, - mlir::MLIRContext* ctx) const override; + IndexingContext* indexing_context) const override; protected: absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index 8f3f4ef37480b4..4b8a2af5661935 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -146,23 +146,24 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) } std::optional MlirTransposeFusion::ComputeThreadIdToOutputIndexing( - int64_t root_index, MLIRContext* ctx) const { + int64_t root_index, IndexingContext* indexing_context) const { const auto& hero = *analysis_.fusion_heroes()[root_index]; const auto& root = *analysis_.fusion_roots()[root_index]; if (!GetDescriptionForTiledTransposeEmitter(root, hero)) { // Non-transpose roots are elementwise by definition. - return ComputeThreadIdToInputIndexing(root_index, 0, ctx); + return ComputeThreadIdToInputIndexing(root_index, 0, indexing_context); } - return ComputeThreadIdToOutputIndexing(hero, ctx); + return ComputeThreadIdToOutputIndexing(hero, indexing_context); } IndexingMap MlirTransposeFusion::ComputeThreadIdToOutputIndexing( - const HloInstruction& hero, MLIRContext* ctx) const { + const HloInstruction& hero, IndexingContext* indexing_context) const { // The block offsets are permuted, but the thread offsets remain the same. - auto block_offset = GetBlockOffsetsForTiling(tiling_, ctx) + auto* mlir_context = indexing_context->GetMLIRContext(); + auto block_offset = GetBlockOffsetsForTiling(tiling_, mlir_context) .getSubMap(std::vector{permutation_.begin(), permutation_.end()}); - auto thread_offset = GetThreadOffsetsForTiling(tiling_, ctx); + auto thread_offset = GetThreadOffsetsForTiling(tiling_, mlir_context); auto permuted_tiled_shape = ShapeUtil::MakeShape(U8, Permute(tiling_.GetShape(), permutation_)); @@ -170,17 +171,18 @@ IndexingMap MlirTransposeFusion::ComputeThreadIdToOutputIndexing( GetIndexingMapForTiling( block_offset, thread_offset, tiling_.GetNumThreadsPerBlock(), tiling_.GetNumBlocks(), tiling_.GetThreadTileSize(), - permuted_tiled_shape.dimensions()), - GetBitcastMap(permuted_tiled_shape, hero.shape(), ctx)); + permuted_tiled_shape.dimensions(), indexing_context), + GetBitcastMap(permuted_tiled_shape, hero.shape(), indexing_context)); map.Simplify(); return map; } IndexingMap MlirTransposeFusion::ComputeThreadIdToInputIndexing( - const HloInstruction& hero, MLIRContext* ctx) const { + const HloInstruction& hero, IndexingContext* indexing_context) const { auto map = ComposeIndexingMaps( - GetIndexingMapForTiling(tiling_, ctx), - GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), ctx)); + GetIndexingMapForTiling(tiling_, indexing_context), + GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), + indexing_context)); map.Simplify(); return map; } @@ -194,6 +196,7 @@ LaunchDimensions MlirTransposeFusion::launch_dimensions() const { IndexingMap GetSharedMemoryWriteIndexingMap( const IndexingMap& thread_id_indexing, int loop_dim) { auto* mlir_context = thread_id_indexing.GetMLIRContext(); + IndexingContext indexing_context{mlir_context}; AffineExpr c0 = mlir::getAffineConstantExpr(0, mlir_context); AffineExpr th_x = mlir::getAffineDimExpr(0, mlir_context); @@ -201,6 +204,7 @@ IndexingMap GetSharedMemoryWriteIndexingMap( mlir::bindSymbolsList(mlir_context, 
llvm::MutableArrayRef(tile_sizes)); IndexingMap shmem_write_indexing{ + &indexing_context, AffineMap::get( thread_id_indexing.GetDimensionCount(), thread_id_indexing.GetSymbolCount(), @@ -219,7 +223,8 @@ IndexingMap GetSharedMemoryReadIndexingMap( const IndexingMap& thread_id_indexing, int loop_dim) { IndexingMap write_indexing = GetSharedMemoryWriteIndexingMap(thread_id_indexing, loop_dim); - return IndexingMap{write_indexing.GetAffineMap().getSubMap({0, 2, 1}), + return IndexingMap{thread_id_indexing.GetIndexingContext(), + write_indexing.GetAffineMap().getSubMap({0, 2, 1}), write_indexing.GetDimensionRanges(), write_indexing.GetSymbolRanges(), write_indexing.GetConstraints()}; @@ -236,10 +241,11 @@ absl::StatusOr> MlirTransposeFusion::EmitWriteToShMemMlir( int num_inputs = fusion.fused_instructions_computation()->num_parameters(); int num_outputs = entry_function.getArguments().size() - num_inputs; + IndexingContext indexing_context{builder.getContext()}; SmallVector shmem_intermediate_result; for (auto* transpose : shmem_transposes_) { auto input_indexing = - ComputeThreadIdToInputIndexing(*transpose, builder.getContext()); + ComputeThreadIdToInputIndexing(*transpose, &indexing_context); IndexingMap shmem_input_indexing = GetSharedMemoryWriteIndexingMap(input_indexing, permutation_[2]); @@ -288,15 +294,16 @@ absl::Status MlirTransposeFusion::EmitReadFromShMemMlir( const mlir_converter::PartitionedComputations& computations, const CallTargetProvider& call_targets, ValueRange shmem_tensors) const { int num_inputs = fusion.fused_instructions_computation()->num_parameters(); - + auto* mlir_context = builder.getContext(); + IndexingContext indexing_context{mlir_context}; ValueRange output_tensor_args = entry_function.getArguments().drop_front(num_inputs); auto output_indexing = ComputeThreadIdToOutputIndexing( - *shmem_transposes_.front(), builder.getContext()); + *shmem_transposes_.front(), &indexing_context); auto shmem_output_indexing = GetSharedMemoryReadIndexingMap(output_indexing, permutation_[2]); auto epilogue_indexing = ComputeEpilogueInputToOutputIndexing( - shmem_transposes_.front(), builder.getContext()); + shmem_transposes_.front(), &indexing_context); auto root_indexing = ComposeIndexingMaps(output_indexing, epilogue_indexing); auto result_tensors = EmitThreadLoopNest( builder, output_tensor_args, output_indexing, diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h index 58c8d6265ae838..fd9f5863e8260e 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h @@ -54,20 +54,20 @@ class MlirTransposeFusion : public MlirFusionEmitterBase { static bool IsSupported(const HloFusionAnalysis& analysis); std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override; + int64_t root_index, IndexingContext* indexing_context) const override; std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override { + IndexingContext* indexing_context) const override { return ComputeThreadIdToInputIndexing( - *analysis_.fusion_heroes()[root_index], ctx); + *analysis_.fusion_heroes()[root_index], indexing_context); } protected: - IndexingMap ComputeThreadIdToInputIndexing(const HloInstruction& hero, - mlir::MLIRContext* ctx) const; - IndexingMap ComputeThreadIdToOutputIndexing(const HloInstruction& hero, - mlir::MLIRContext* ctx) const; 
+ IndexingMap ComputeThreadIdToInputIndexing( + const HloInstruction& hero, IndexingContext* indexing_context) const; + IndexingMap ComputeThreadIdToOutputIndexing( + const HloInstruction& hero, IndexingContext* indexing_context) const; absl::Status EmitEntryFunction( const mlir_converter::PartitionedComputations& computations, diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc index e1d64067afb90a..38fe0789b8eadf 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc @@ -46,9 +46,9 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing021) { auto analysis = AnalyzeFusion(*root, device_info_); MlirTransposeFusion fusion(analysis); - EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), - MatchIndexingString(R"( + EXPECT_THAT(fusion.ComputeThreadIdToInputIndexing(0, 0, &indexing_context_) + ->ToString(), + MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d3 floordiv 2, d0 floordiv 32 + s1 * 4, @@ -67,7 +67,7 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing021) { s2 in [0, 0] )")); EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + fusion.ComputeThreadIdToOutputIndexing(0, &indexing_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d3 floordiv 2, @@ -105,9 +105,9 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { auto analysis = AnalyzeFusion(*root, device_info_); MlirTransposeFusion fusion(analysis); - EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), - MatchIndexingString(R"( + EXPECT_THAT(fusion.ComputeThreadIdToInputIndexing(0, 0, &indexing_context_) + ->ToString(), + MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d3 floordiv 2, d0 floordiv 32 + (d3 * 32 + s1 * 4) mod 64, @@ -126,7 +126,7 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { s2 in [0, 0] )")); EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + fusion.ComputeThreadIdToOutputIndexing(0, &indexing_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d0 floordiv 32 + s1 * 4, diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_test.cc index d7363bbd39f382..94d3df1898ad3b 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_test.cc @@ -37,9 +37,14 @@ namespace { using ::testing::HasSubstr; class TransposeTest : public HloTestBase { + public: + TransposeTest() : indexing_context_(&mlir_context_) {} + protected: stream_executor::DeviceDescription device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + mlir::MLIRContext mlir_context_; + IndexingContext indexing_context_; }; absl::StatusOr> GetTransposeFusion( @@ -74,9 +79,9 @@ TEST_F(TransposeTest, ThreadIndexing021) { TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); mlir::MLIRContext mlir_context; - EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), - MatchIndexingString(R"( + EXPECT_THAT(fusion->ComputeThreadIdToInputIndexing(0, 0, &indexing_context_) + ->ToString(), + MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d3 floordiv 2, d0 floordiv 32 + s1 * 4, @@ -94,9 +99,9 @@ TEST_F(TransposeTest, ThreadIndexing021) { s1 in [0, 7] 
s2 in [0, 0] )")); - EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), - MatchIndexingString(R"( + EXPECT_THAT(fusion->ComputeThreadIdToOutputIndexing(0, &indexing_context_) + ->ToString(), + MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d3 floordiv 2, d0 floordiv 32 + (d3 mod 2) * 32 + s1 * 4, @@ -136,9 +141,9 @@ TEST_F(TransposeTest, ThreadIndexing201) { TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); mlir::MLIRContext mlir_context; - EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), - MatchIndexingString(R"( + EXPECT_THAT(fusion->ComputeThreadIdToInputIndexing(0, 0, &indexing_context_) + ->ToString(), + MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d3 floordiv 2, d0 floordiv 32 + (d3 * 32 + s1 * 4) mod 64, @@ -156,9 +161,9 @@ TEST_F(TransposeTest, ThreadIndexing201) { s1 in [0, 7] s2 in [0, 0] )")); - EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), - MatchIndexingString(R"( + EXPECT_THAT(fusion->ComputeThreadIdToOutputIndexing(0, &indexing_context_) + ->ToString(), + MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d0 floordiv 32 + s1 * 4, d3 floordiv 2, @@ -200,9 +205,9 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); mlir::MLIRContext mlir_context; - EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), - MatchIndexingString(R"( + EXPECT_THAT(fusion->ComputeThreadIdToInputIndexing(0, 0, &indexing_context_) + ->ToString(), + MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d0 floordiv 32 + s0 * 4, d3, @@ -222,9 +227,9 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { d0 floordiv 32 + s0 * 4 in [0, 23] d0 mod 32 in [0, 23] )")); - EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), - MatchIndexingString(R"( + EXPECT_THAT(fusion->ComputeThreadIdToOutputIndexing(0, &indexing_context_) + ->ToString(), + MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( s0, d0 floordiv 32, diff --git a/third_party/xla/xla/service/gpu/ir_emitter_context.h b/third_party/xla/xla/service/gpu/ir_emitter_context.h index cc79e4cd3c8266..2ae9a636d7fcc3 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_context.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_context.h @@ -35,6 +35,7 @@ limitations under the License. #include "xla/service/gpu/gpu_executable.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/kernel_reuse_cache.h" +#include "xla/service/gpu/model/indexing_context.h" #include "xla/service/gpu/nccl_collective_thunk.h" #include "xla/service/name_uniquer.h" #include "xla/stream_executor/device_description.h" @@ -69,6 +70,7 @@ class IrEmitterContext { platform_name_(std::move(platform_name)), gpu_device_info_(gpu_device_info), mlir_context_(mlir_context), + indexing_context_(mlir_context_), llvm_module_(llvm_module), emit_kernels_(emit_kernels) {} // Disallow copy and assign. @@ -98,6 +100,7 @@ class IrEmitterContext { return cc != nullptr ? 
*cc : se::RocmComputeCapability(); } mlir::MLIRContext* mlir_context() { return mlir_context_; } + IndexingContext* indexing_context() { return &indexing_context_; } llvm::Module* llvm_module() { return llvm_module_; } NameUniquer* name_uniquer() { return &name_uniquer_; } @@ -126,6 +129,7 @@ class IrEmitterContext { std::string platform_name_; const se::DeviceDescription& gpu_device_info_; mlir::MLIRContext* mlir_context_; + IndexingContext indexing_context_; llvm::Module* llvm_module_; NameUniquer name_uniquer_; std::vector constants_; diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 84e8e5001e101b..a3b08d3fc5b971 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -418,10 +418,18 @@ xla_cc_test( cc_library( name = "indexing_map", - srcs = ["indexing_map.cc"], - hdrs = ["indexing_map.h"], + srcs = [ + "indexing_context.cc", + "indexing_map.cc", + ], + hdrs = [ + "indexing_context.h", + "indexing_map.h", + ], deps = [ ":affine_map_printer", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -576,6 +584,7 @@ xla_cc_test( srcs = ["coalescing_analysis_test.cc"], deps = [ ":coalescing_analysis", + ":indexing_map", "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc index 2eed7c5ad26826..c697fb752f9b90 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc @@ -223,11 +223,14 @@ bool IsCoalesced(const IndexingMap& thread_id_to_input_indexing_map, if (thread_id_to_input_indexing_map.GetAffineMap().getNumResults() == 0) { return true; } - MLIRContext* mlir_context = thread_id_to_input_indexing_map.GetMLIRContext(); + IndexingContext* indexing_context = + thread_id_to_input_indexing_map.GetIndexingContext(); + mlir::MLIRContext* mlir_context = indexing_context->GetMLIRContext(); AffineExpr thread_x_dim = mlir::getAffineDimExpr( KernelFusionInterface::kIndexingMapThreadIdxDims[0], mlir_context); AffineExpr c0 = mlir::getAffineConstantExpr(0, mlir_context); IndexingMap thread_x_first_32_elements{ + indexing_context, AffineMap::get(1, 0, {thread_x_dim, c0, c0, c0, c0, c0}, mlir_context), {Interval{0, 31}}, {}}; @@ -257,7 +260,8 @@ std::optional GetThreadIdToInputMemoryLayoutsMaps( const HloFusionAdaptor& fusion_adaptor, absl::Span operands, const HloFusionAnalysis& fusion_analysis, - KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context) { + KernelFusionInterface* fusion_interface, + IndexingContext* indexing_context) { GroupedByOpIndexingMap result; for (const auto& [root_index, hero] : llvm::enumerate(fusion_analysis.fusion_heroes())) { @@ -269,7 +273,7 @@ std::optional GetThreadIdToInputMemoryLayoutsMaps( // Compute thread ID -> hero operand indexing map. 
std::optional thread_id_to_hero_operand_map = fusion_interface->ComputeThreadIdToInputIndexing( - root_index, hero_operand_index, mlir_context); + root_index, hero_operand_index, indexing_context); if (!thread_id_to_hero_operand_map.has_value()) { return std::nullopt; } @@ -277,7 +281,7 @@ std::optional GetThreadIdToInputMemoryLayoutsMaps( HloInstructionAdaptor hero_operand_adaptor(*hero_operand); GroupedByOpIndexingMap instr_indexing_keyed_by_operands = ComputeGroupedOutputToInputIndexing( - fusion_adaptor, hero_operand_adaptor, mlir_context); + fusion_adaptor, hero_operand_adaptor, indexing_context); // For every operand compute thread ID -> physical layout of operand // indexing map. for (const HloInstruction* operand : operands) { @@ -291,11 +295,11 @@ std::optional GetThreadIdToInputMemoryLayoutsMaps( IndexingMap operand_logical_to_physical_map = GetIndexingMapFromLogicalToPhysicalLayout(operand_shape, - mlir_context); + indexing_context); IndexingMap operand_physical_to_linearized_shape = GetBitcastMap( ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( operand_shape), - GetLinearizedShape(operand_shape), mlir_context); + GetLinearizedShape(operand_shape), indexing_context); IndexingMap operand_logical_to_linearized_physical_shape = operand_logical_to_physical_map * operand_physical_to_linearized_shape; @@ -330,12 +334,12 @@ CoalescingAnalysis::CoalescingAnalysis( const HloInstruction* instr, absl::Span operands, const HloFusionAnalysis& fusion_analysis, - KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context, + KernelFusionInterface* fusion_interface, IndexingContext* indexing_context, bool use_heuristic) { auto fusion_adaptor = HloFusionAdaptor::ForInstruction(instr); if (!use_heuristic && ComputeCoalescingForAllOperands( *fusion_adaptor, operands, fusion_analysis, - fusion_interface, mlir_context)) { + fusion_interface, indexing_context)) { return; } // If ComputeCoalescingForAllOperands fails, fallback to using the heuristic. @@ -347,12 +351,12 @@ CoalescingAnalysis::CoalescingAnalysis( const HloInstruction* producer, const HloInstruction* consumer, absl::Span operands, const HloFusionAnalysis& fusion_analysis, - KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context, + KernelFusionInterface* fusion_interface, IndexingContext* indexing_context, bool use_heuristic) { ProducerConsumerFusion fusion_adaptor(producer, consumer); if (!use_heuristic && ComputeCoalescingForAllOperands(fusion_adaptor, operands, fusion_analysis, - fusion_interface, mlir_context)) { + fusion_interface, indexing_context)) { return; } // If ComputeCoalescingForAllOperands fails, fallback to using the heuristic. 
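The call-site pattern after this change is the one used by the updated coalescing_analysis_test.cc later in this patch: callers that used to hand CoalescingAnalysis a raw mlir::MLIRContext* now wrap it in an IndexingContext and pass that instead. A condensed sketch of that pattern follows; the root, analysis, and fusion variables are assumed to be set up exactly as in that test, and this is an illustration only, not an additional hunk.

  // Wrap the MLIR context once and hand the wrapper to the analysis.
  mlir::MLIRContext mlir_context;
  IndexingContext indexing_context(&mlir_context);
  CoalescingAnalysis coalescing_analysis(root, root->operands(), analysis,
                                         fusion, &indexing_context,
                                         /*use_heuristic=*/false);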
@@ -364,11 +368,12 @@ bool CoalescingAnalysis::ComputeCoalescingForAllOperands( const HloFusionAdaptor& fusion_adaptor, absl::Span operands, const HloFusionAnalysis& fusion_analysis, - KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context) { + KernelFusionInterface* fusion_interface, + IndexingContext* indexing_context) { std::optional thread_id_to_input_memory_layouts = GetThreadIdToInputMemoryLayoutsMaps(fusion_adaptor, operands, fusion_analysis, fusion_interface, - mlir_context); + indexing_context); if (!thread_id_to_input_memory_layouts.has_value()) { return false; } diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.h b/third_party/xla/xla/service/gpu/model/coalescing_analysis.h index 300036aa453bae..86e93dcad69d3b 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.h @@ -38,7 +38,7 @@ class CoalescingAnalysis { absl::Span operands, const HloFusionAnalysis& fusion_analysis, KernelFusionInterface* fusion_interface = nullptr, - mlir::MLIRContext* mlir_context = nullptr, + IndexingContext* indexing_context = nullptr, bool use_heuristic = true); // Computes read coalescing for operands of fused `producer` and `consumer`. @@ -47,7 +47,7 @@ class CoalescingAnalysis { absl::Span operands, const HloFusionAnalysis& fusion_analysis, KernelFusionInterface* fusion_interface = nullptr, - mlir::MLIRContext* mlir_context = nullptr, + IndexingContext* indexing_context = nullptr, bool use_heuristic = true); // Returns true if the operand is read coalesced. @@ -58,7 +58,8 @@ class CoalescingAnalysis { const HloFusionAdaptor& fusion_adaptor, absl::Span operands, const HloFusionAnalysis& fusion_analysis, - KernelFusionInterface* fusion_interface, mlir::MLIRContext* mlir_context); + KernelFusionInterface* fusion_interface, + IndexingContext* indexing_context = nullptr); absl::flat_hash_map coalescing_per_operand_; bool is_coalesced_computed_by_heuristic_ = false; diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc index 18a69aa6bf404b..5a788bb1e0fee1 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/indexing_context.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -44,6 +45,8 @@ using ::testing::ElementsAre; class CoalescingTest : public HloTestBase { public: + CoalescingTest() : indexing_context_(&mlir_context_) {} + std::vector IsReadCoalescedPerOperand(absl::string_view hlo_string) { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); HloInstruction* root = module->entry_computation()->root_instruction(); @@ -58,7 +61,7 @@ class CoalescingTest : public HloTestBase { EXPECT_TRUE(emitter.ok()); CoalescingAnalysis coalescing_analysis(root, root->operands(), analysis, - fusion, &mlir_context_, + fusion, &indexing_context_, /*use_heuristic=*/false); std::vector results; @@ -80,6 +83,7 @@ class CoalescingTest : public HloTestBase { stream_executor::DeviceDescription device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo(); mlir::MLIRContext mlir_context_; + IndexingContext indexing_context_; }; TEST_F(CoalescingTest, IdentityLayout) { diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 58325183af50e6..7d8802568a5e2d 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -123,7 +123,7 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForFusion( // operands. For each instruction, tells which elements of the instructions // result will be used to compute one result element of the fusion. auto grouped_fusion_indexing = ComputeGroupedOutputToInputIndexing( - fusion_adaptor, roots[0], mlir_context_); + fusion_adaptor, roots[0], &indexing_context_); int64_t flops = 0; int64_t bytes_read = 0; diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h index 14d7e520a820d3..0f2b66eef4ca07 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h @@ -24,6 +24,7 @@ limitations under the License. 
#include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/gpu/model/hlo_op_profiles.h" +#include "xla/service/gpu/model/indexing_context.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" #include "xla/stream_executor/device_description.h" @@ -42,7 +43,8 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { : hlo_op_profile_(&HloOpProfiles::Singleton().GetProfile(device_info)), device_info_(device_info), shape_size_(shape_size), - mlir_context_(mlir_context) {} + mlir_context_(mlir_context), + indexing_context_(mlir_context_) {} EstimateRunTimeData EstimateRunTimeForFusion( const HloFusionAnalysis& fusion_analysis, bool is_coalesced = true); @@ -68,6 +70,7 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { const se::DeviceDescription* device_info_; HloCostAnalysis::ShapeSizeFunction shape_size_; mlir::MLIRContext* mlir_context_; + IndexingContext indexing_context_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc index a6a14c28ca8161..cc2cc9f2b83519 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc @@ -52,6 +52,7 @@ limitations under the License. #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_context.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -77,22 +78,27 @@ HloInstructionIndexing CreateUnknownIndexing(int64_t count = 1) { return indexing; } -IndexingMap CreateIdentityMap(const Shape& shape, MLIRContext* ctx) { +IndexingMap CreateIdentityMap(const Shape& shape, + IndexingContext* indexing_context) { if (shape.IsTuple()) { // Should happen only for variadic reduce. In that case all tuple shapes are // equal. 
- return CreateIdentityMap(shape.tuple_shapes(0), ctx); + return CreateIdentityMap(shape.tuple_shapes(0), indexing_context); } auto dims = shape.dimensions(); IndexingMap identity_map = IndexingMap::FromTensorSizes( - AffineMap::getMultiDimIdentityMap(dims.size(), ctx), dims, {}); + indexing_context, + AffineMap::getMultiDimIdentityMap(dims.size(), + indexing_context->GetMLIRContext()), + dims, {}); return identity_map; } HloInstructionIndexing ComputeOutputToInputCwiseOpIndexing( - const HloInstruction* instr, MLIRContext* mlir_context) { - IndexingMap identity_map = CreateIdentityMap(instr->shape(), mlir_context); + const HloInstruction* instr, IndexingContext* indexing_context) { + IndexingMap identity_map = + CreateIdentityMap(instr->shape(), indexing_context); HloInstructionIndexing instr_indexing; instr_indexing.indexing_maps.resize(instr->operand_count()); @@ -104,21 +110,24 @@ HloInstructionIndexing ComputeOutputToInputCwiseOpIndexing( } HloInstructionIndexing ComputeInputToOutputCwiseOpIndexing( - const HloInstruction* instr, MLIRContext* mlir_context) { - IndexingMap identity_map = CreateIdentityMap(instr->shape(), mlir_context); + const HloInstruction* instr, IndexingContext* indexing_context) { + IndexingMap identity_map = + CreateIdentityMap(instr->shape(), indexing_context); return HloInstructionIndexing::FromIndexingMaps({identity_map}); } HloInstructionIndexing ComputeOutputToInputBroadcastOpIndexing( - const HloBroadcastInstruction* bcast, MLIRContext* mlir_context) { + const HloBroadcastInstruction* bcast, IndexingContext* indexing_context) { auto output_dims = bcast->shape().dimensions(); + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); std::vector exprs; exprs.reserve(bcast->dimensions().size()); for (int64_t bcast_dim : bcast->dimensions()) { exprs.push_back(getAffineDimExpr(bcast_dim, mlir_context)); } IndexingMap indexing_map = IndexingMap::FromTensorSizes( + indexing_context, AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, mlir_context), output_dims, {}); @@ -126,7 +135,9 @@ HloInstructionIndexing ComputeOutputToInputBroadcastOpIndexing( } HloInstructionIndexing ComputeInputToOutputBroadcastOpIndexing( - const HloBroadcastInstruction* bcast, MLIRContext* mlir_context) { + const HloBroadcastInstruction* bcast, IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); + absl::Span bcast_dims = bcast->dimensions(); const Shape& input_shape = bcast->operand(0)->shape(); @@ -149,6 +160,7 @@ HloInstructionIndexing ComputeInputToOutputBroadcastOpIndexing( std::distance(bcast_dims.begin(), bcast_dim), mlir_context)); } IndexingMap indexing_map = IndexingMap::FromTensorSizes( + indexing_context, AffineMap::get(input_shape.rank(), added_dims_sizes.size(), exprs, mlir_context), input_shape.dimensions(), added_dims_sizes); @@ -166,7 +178,10 @@ std::vector RangesFromUpperBounds(absl::Span bounds) { } HloInstructionIndexing ComputeOutputToInputConcatenateOpIndexing( - const HloConcatenateInstruction* concat, MLIRContext* mlir_context) { + const HloConcatenateInstruction* concat, + IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); + const auto& operand_0_dims = concat->operand(0)->shape().dimensions(); // Initialize affine map and domain. 
Only concat_dim elements of both have to @@ -185,7 +200,7 @@ HloInstructionIndexing ComputeOutputToInputConcatenateOpIndexing( int64_t operand_concat_dim = operand->shape().dimensions()[concat_dim]; dim_ranges[concat_dim] = Interval{offset, offset + operand_concat_dim - 1}; concat_indexing.indexing_maps[operand_id].insert( - IndexingMap(affine_map.getAffineMap(), dim_ranges, + IndexingMap(indexing_context, affine_map.getAffineMap(), dim_ranges, /*symbol_ranges=*/{})); offset += operand_concat_dim; } @@ -194,7 +209,9 @@ HloInstructionIndexing ComputeOutputToInputConcatenateOpIndexing( HloInstructionIndexing ComputeInputToOutputConcatenateOpIndexing( const HloConcatenateInstruction* concat, int input_id, - MLIRContext* mlir_context) { + IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); + int64_t concat_dim = concat->concatenate_dimension(); int64_t offset = 0; for (int64_t operand_id = 0; operand_id < input_id; ++operand_id) { @@ -207,8 +224,8 @@ HloInstructionIndexing ComputeInputToOutputConcatenateOpIndexing( AffineMap::getMultiDimIdentityMap(operand_dims.size(), mlir_context); affine_map.setResult(concat_dim, getAffineDimExpr(concat_dim, mlir_context) + offset); - IndexingMap indexing_map = - IndexingMap::FromTensorSizes(affine_map.getAffineMap(), operand_dims, {}); + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + indexing_context, affine_map.getAffineMap(), operand_dims, {}); return HloInstructionIndexing::FromIndexingMaps({indexing_map}); } @@ -216,10 +233,10 @@ HloInstructionIndexing ComputeInputToOutputConcatenateOpIndexing( // until the HloParameterInstruction is found. HloInstructionIndexing ComputeOutputToInputFusionOpIndexing( const HloFusionInstruction* fusion, int output_id, - MLIRContext* mlir_context) { + IndexingContext* indexing_context) { auto fusion_adaptor = HloFusionAdaptor::ForInstruction(fusion); auto grouped_indexing_maps = ComputeGroupedOutputToInputIndexing( - *fusion_adaptor, fusion_adaptor->GetRoots()[output_id], mlir_context); + *fusion_adaptor, fusion_adaptor->GetRoots()[output_id], indexing_context); // After the traversal, `grouped_indexing_maps` is keyed by // HloParameterInstructions. Convert them back to the operand id and return. 
@@ -232,7 +249,9 @@ HloInstructionIndexing ComputeOutputToInputFusionOpIndexing( } HloInstructionIndexing ComputeOutputToInputDotOpIndexing( - const HloDotInstruction* dot, MLIRContext* mlir_context) { + const HloDotInstruction* dot, IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); + CHECK_NE(dot, nullptr); const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers(); absl::Span lhs_contracting_dims( @@ -297,11 +316,13 @@ HloInstructionIndexing ComputeOutputToInputDotOpIndexing( } IndexingMap lhs_indexing_map = IndexingMap::FromTensorSizes( + indexing_context, AffineMap::get(dot->shape().rank(), input_dim_sizes.size(), lhs_exprs, mlir_context), dot->shape().dimensions(), input_dim_sizes); IndexingMap rhs_indexing_map = IndexingMap::FromTensorSizes( + indexing_context, AffineMap::get(dot->shape().rank(), input_dim_sizes.size(), rhs_exprs, mlir_context), dot->shape().dimensions(), input_dim_sizes); @@ -313,7 +334,10 @@ IndexingMap ComputeOutputToInputPadOpIndexingImpl( absl::Span output_dims, absl::Span padding_low, absl::Span padding_high, - absl::Span padding_interior, MLIRContext* mlir_context) { + absl::Span padding_interior, + IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); + int64_t output_rank = output_dims.size(); std::vector exprs; @@ -338,12 +362,15 @@ IndexingMap ComputeOutputToInputPadOpIndexingImpl( ++output_dim_id; } return IndexingMap{ + indexing_context, AffineMap::get(output_rank, /*symbolCount=*/0, exprs, mlir_context), dimension_ranges, /*symbol_ranges = */ {}, absl::MakeSpan(constraints)}; } HloInstructionIndexing ComputeOutputToInputPadOpIndexing( - const HloPadInstruction* pad, MLIRContext* mlir_context) { + const HloPadInstruction* pad, IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); + const Shape& output_shape = pad->shape(); int64_t rank = output_shape.rank(); SmallVector padding_low, padding_high, padding_interior; @@ -357,8 +384,9 @@ HloInstructionIndexing ComputeOutputToInputPadOpIndexing( } IndexingMap input_indexing_map = ComputeOutputToInputPadOpIndexingImpl( output_shape.dimensions(), padding_low, padding_high, padding_interior, - mlir_context); + indexing_context); IndexingMap padding_value_indexing_map = IndexingMap::FromTensorSizes( + indexing_context, AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), output_shape.dimensions(), /*symbol_upper_bounds=*/{}); return HloInstructionIndexing::FromIndexingMaps( @@ -367,7 +395,9 @@ HloInstructionIndexing ComputeOutputToInputPadOpIndexing( HloInstructionIndexing ComputeOutputToInputReduceOpIndexing( const HloReduceInstruction* reduce, int output_id, - MLIRContext* mlir_context) { + IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); + absl::flat_hash_set reduce_dims_ids(reduce->dimensions().begin(), reduce->dimensions().end()); @@ -389,10 +419,12 @@ HloInstructionIndexing ComputeOutputToInputReduceOpIndexing( exprs.push_back(getAffineDimExpr(output_dim_id++, mlir_context)); } IndexingMap inputs_indexing_map = IndexingMap::FromTensorSizes( + indexing_context, AffineMap::get(output_shape.rank(), reduce_dims_ids.size(), exprs, mlir_context), output_shape.dimensions(), parallel_dims_sizes); IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( + indexing_context, AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), output_shape.dimensions(), {}); 
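The recurring mechanical change in these indexing_analysis.cc hunks is that IndexingMap construction (the constructor and IndexingMap::FromTensorSizes) now takes the IndexingContext* as its first argument, while AffineMaps are still built against the underlying mlir::MLIRContext obtained via GetMLIRContext(). A minimal sketch of the new pattern, assuming an already-constructed IndexingContext and a hypothetical 2-D tensor of shape [100, 32]:

  mlir::MLIRContext* mlir_context = indexing_context->GetMLIRContext();
  // Affine expressions and maps still live in the MLIR context...
  mlir::AffineMap id_map =
      mlir::AffineMap::getMultiDimIdentityMap(/*numDims=*/2, mlir_context);
  // ...but the resulting IndexingMap is now tied to the IndexingContext.
  IndexingMap map = IndexingMap::FromTensorSizes(
      indexing_context, id_map,
      /*dim_upper_bounds=*/{100, 32}, /*symbol_upper_bounds=*/{});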
@@ -409,7 +441,9 @@ HloInstructionIndexing ComputeOutputToInputReduceOpIndexing( HloInstructionIndexing ComputeInputToOutputReduceOpIndexing( const HloReduceInstruction* reduce, int input_id, - MLIRContext* mlir_context) { + IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); + absl::flat_hash_set reduce_dims_ids(reduce->dimensions().begin(), reduce->dimensions().end()); const Shape& input_shape = reduce->operand(input_id)->shape(); @@ -429,10 +463,12 @@ HloInstructionIndexing ComputeInputToOutputReduceOpIndexing( inits_exprs.push_back(getAffineSymbolExpr(output_dim_id++, mlir_context)); } IndexingMap inputs_indexing_map = IndexingMap::FromTensorSizes( + indexing_context, AffineMap::get(input_shape.rank(), /*symbolCount=*/0, inputs_exprs, mlir_context), input_shape.dimensions(), {}); IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( + indexing_context, AffineMap::get(0, /*symbolCount=*/output_rank, inits_exprs, mlir_context), {}, output_shape.dimensions()); @@ -452,7 +488,9 @@ HloInstructionIndexing ComputeInputToOutputReduceOpIndexing( // of bounds. HloInstructionIndexing ComputeOutputToInputReduceWindowOpIndexing( const HloReduceWindowInstruction* reduce_window, int output_id, - MLIRContext* mlir_context) { + IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); + const Shape& input_shape = reduce_window->operand(0)->shape(); const Shape& output_shape = GetOutputShape(reduce_window, 0); int64_t rank = input_shape.rank(); @@ -492,11 +530,11 @@ HloInstructionIndexing ComputeOutputToInputReduceWindowOpIndexing( // Indexing map for pad op that pads the input. IndexingMap padded_input_indexing = ComputeOutputToInputPadOpIndexingImpl( padded_input_dimensions, padding_low, padding_high, padding_interior, - mlir_context); + indexing_context); // Indexing map for reduce-window, that does not do any padding. IndexingMap reduce_window_indexing_no_padding( - AffineMap::get(rank, rank, exprs, mlir_context), dim_ranges, - symbol_ranges); + indexing_context, AffineMap::get(rank, rank, exprs, mlir_context), + dim_ranges, symbol_ranges); // Composed indexing. IndexingMap inputs_indexing = ComposeIndexingMaps( @@ -506,6 +544,7 @@ HloInstructionIndexing ComputeOutputToInputReduceWindowOpIndexing( // Indexing map for the init value. 
IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( + indexing_context, AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), output_shape.dimensions(), /*symbol_upper_bounds=*/{}); @@ -677,30 +716,35 @@ AffineMap ComputeReshapeIndexingMap(const Shape& input, const Shape& output, }; HloInstructionIndexing ComputeOutputToInputReshapeOpIndexing( - const HloReshapeInstruction* reshape, MLIRContext* mlir_context) { + const HloReshapeInstruction* reshape, IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); + const auto& input = reshape->operand(0)->shape(); const auto& output = reshape->shape(); IndexingMap reshape_indexing_map = IndexingMap::FromTensorSizes( - ComputeReshapeIndexingMap(input, output, mlir_context), + indexing_context, ComputeReshapeIndexingMap(input, output, mlir_context), output.dimensions(), {}); reshape_indexing_map.Simplify(); return HloInstructionIndexing::FromIndexingMaps({reshape_indexing_map}); } HloInstructionIndexing ComputeInputToOutputReshapeOpIndexing( - const HloReshapeInstruction* reshape, MLIRContext* mlir_context) { + const HloReshapeInstruction* reshape, IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); + const auto& input = reshape->operand(0)->shape(); const auto& output = reshape->shape(); IndexingMap reshape_indexing_map = IndexingMap::FromTensorSizes( - ComputeReshapeIndexingMap(output, input, mlir_context), + indexing_context, ComputeReshapeIndexingMap(output, input, mlir_context), input.dimensions(), {}); reshape_indexing_map.Simplify(); return HloInstructionIndexing::FromIndexingMaps({reshape_indexing_map}); } HloInstructionIndexing ComputeReverseOpIndexing( - const HloReverseInstruction* reverse, MLIRContext* mlir_context) { + const HloReverseInstruction* reverse, IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); absl::flat_hash_set reverse_dims(reverse->dimensions().begin(), reverse->dimensions().end()); auto output_dims = reverse->shape().dimensions(); @@ -717,6 +761,7 @@ HloInstructionIndexing ComputeReverseOpIndexing( } IndexingMap indexing_map = IndexingMap::FromTensorSizes( + indexing_context, AffineMap::get(output_dims.size(), /*symbolCount=*/0, exprs, mlir_context), output_dims, {}); @@ -725,7 +770,8 @@ HloInstructionIndexing ComputeReverseOpIndexing( } HloInstructionIndexing ComputeOutputToInputSliceOpIndexing( - const HloSliceInstruction* slice, MLIRContext* mlir_context) { + const HloSliceInstruction* slice, IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); auto output_rank = slice->shape().rank(); std::vector exprs; @@ -736,6 +782,7 @@ HloInstructionIndexing ComputeOutputToInputSliceOpIndexing( slice->slice_starts()[dim]); } IndexingMap indexing_map = IndexingMap::FromTensorSizes( + indexing_context, AffineMap::get(output_rank, /*symbolCount=*/0, exprs, mlir_context), slice->shape().dimensions(), {}); return HloInstructionIndexing::FromIndexingMaps({indexing_map}); @@ -749,25 +796,31 @@ AffineMap ComputeTransposeIndexingMap(absl::Span permutation, } HloInstructionIndexing ComputeOutputToInputTransposeOpIndexing( - const HloTransposeInstruction* transpose, MLIRContext* mlir_context) { - AffineMap inverse_permutation = ComputeTransposeIndexingMap( - InversePermutation(transpose->dimensions()), mlir_context); - return HloInstructionIndexing::FromIndexingMaps({IndexingMap::FromTensorSizes( - 
inverse_permutation, transpose->shape().dimensions(), {})}); + const HloTransposeInstruction* transpose, + IndexingContext* indexing_context) { + AffineMap inverse_permutation = + ComputeTransposeIndexingMap(InversePermutation(transpose->dimensions()), + indexing_context->GetMLIRContext()); + return HloInstructionIndexing::FromIndexingMaps( + {IndexingMap::FromTensorSizes(indexing_context, inverse_permutation, + transpose->shape().dimensions(), {})}); } HloInstructionIndexing ComputeInputToOutputTransposeOpIndexing( - const HloTransposeInstruction* transpose, MLIRContext* mlir_context) { - AffineMap forward_permutation = - ComputeTransposeIndexingMap(transpose->dimensions(), mlir_context); + const HloTransposeInstruction* transpose, + IndexingContext* indexing_context) { + AffineMap forward_permutation = ComputeTransposeIndexingMap( + transpose->dimensions(), indexing_context->GetMLIRContext()); return HloInstructionIndexing::FromIndexingMaps({IndexingMap::FromTensorSizes( - forward_permutation, transpose->operand(0)->shape().dimensions(), {})}); + indexing_context, forward_permutation, + transpose->operand(0)->shape().dimensions(), {})}); } } // namespace IndexingMap GetBitcastMap(const Shape& input_shape, const Shape& output_shape, - MLIRContext* ctx) { + IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); ShapeUtil::BitcastDecomposition decomposed_bitcast = ShapeUtil::DecomposeBitcast(input_shape, output_shape); @@ -779,7 +832,8 @@ IndexingMap GetBitcastMap(const Shape& input_shape, const Shape& output_shape, << "Failed to deduce permutation for a bitcast."; return IndexingMap::FromTensorSizes( - ComputeTransposeIndexingMap(permutation.value(), ctx), + indexing_context, + ComputeTransposeIndexingMap(permutation.value(), mlir_context), input_shape.dimensions(), {}); } if (std::holds_alternative( @@ -787,38 +841,39 @@ IndexingMap GetBitcastMap(const Shape& input_shape, const Shape& output_shape, // Note: ComputeReshapeIndexingMap assumes it's computing an output->input // indexing, so input and output are reversed. return IndexingMap::FromTensorSizes( - ComputeReshapeIndexingMap(output_shape, input_shape, ctx), + indexing_context, + ComputeReshapeIndexingMap(output_shape, input_shape, mlir_context), input_shape.dimensions(), {}); } // `trt` stands for transpose-reshape-transpose decomposition of bitcast. 
auto trt = std::get(decomposed_bitcast); - auto transpose_map_1 = ComputeTransposeIndexingMap(trt.transpose1_dims, ctx); - auto reshape_map = - ComputeReshapeIndexingMap(trt.reshape_shape, trt.transpose1_shape, ctx); - auto transpose_map_2 = ComputeTransposeIndexingMap(trt.transpose2_dims, ctx); + auto transpose_map_1 = + ComputeTransposeIndexingMap(trt.transpose1_dims, mlir_context); + auto reshape_map = ComputeReshapeIndexingMap( + trt.reshape_shape, trt.transpose1_shape, mlir_context); + auto transpose_map_2 = + ComputeTransposeIndexingMap(trt.transpose2_dims, mlir_context); auto bitcast_map = transpose_map_2.compose(reshape_map).compose(transpose_map_1); - return IndexingMap::FromTensorSizes(bitcast_map, input_shape.dimensions(), - {}); + return IndexingMap::FromTensorSizes(indexing_context, bitcast_map, + input_shape.dimensions(), {}); } namespace { HloInstructionIndexing ComputeOutputToInputBitcastOpIndexing( - const HloInstruction* bitcast, MLIRContext* mlir_context) { - auto bitcast_map = GetBitcastMap(bitcast->shape(), - bitcast->operand(0)->shape(), mlir_context); + const HloInstruction* bitcast, IndexingContext* indexing_context) { + auto bitcast_map = GetBitcastMap( + bitcast->shape(), bitcast->operand(0)->shape(), indexing_context); bitcast_map.Simplify(); - return HloInstructionIndexing::FromIndexingMaps({bitcast_map}); } HloInstructionIndexing ComputeInputToOutputBitcastOpIndexing( - const HloInstruction* bitcast, MLIRContext* mlir_context) { + const HloInstruction* bitcast, IndexingContext* indexing_context) { auto bitcast_map = GetBitcastMap(bitcast->operand(0)->shape(), - bitcast->shape(), mlir_context); + bitcast->shape(), indexing_context); bitcast_map.Simplify(); - return HloInstructionIndexing::FromIndexingMaps({bitcast_map}); } @@ -867,32 +922,38 @@ llvm::SmallVector DelinearizeInBoundsIndex( return result; } -IndexingMap GetIndexingMapFromPhysicalLayoutToLogical(const Shape& shape, - MLIRContext* ctx) { +IndexingMap GetIndexingMapFromPhysicalLayoutToLogical( + const Shape& shape, IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); if (shape.rank() == 0) { - return IndexingMap(AffineMap::get(ctx), {}, {}); + return IndexingMap(indexing_context, AffineMap::get(mlir_context), {}, {}); } return IndexingMap::FromTensorSizes( + indexing_context, ComputeTransposeIndexingMap( - InversePermutation(ToTransposeDimensions(shape.layout())), ctx), + InversePermutation(ToTransposeDimensions(shape.layout())), + mlir_context), ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(shape) .dimensions(), {}); } -IndexingMap GetIndexingMapFromLogicalToPhysicalLayout(const Shape& shape, - MLIRContext* ctx) { +IndexingMap GetIndexingMapFromLogicalToPhysicalLayout( + const Shape& shape, IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); if (shape.rank() == 0) { - return IndexingMap(AffineMap::get(ctx), {}, {}); + return IndexingMap(indexing_context, AffineMap::get(mlir_context), {}, {}); } return IndexingMap::FromTensorSizes( - ComputeTransposeIndexingMap(ToTransposeDimensions(shape.layout()), ctx), + indexing_context, + ComputeTransposeIndexingMap(ToTransposeDimensions(shape.layout()), + mlir_context), shape.dimensions(), {}); } AffineMap GetBlockOffsetsForTiling(const Tiling& tiling, - mlir::MLIRContext* ctx) { - auto offsets = DelinearizeInBoundsIndex(getAffineDimExpr(3, ctx), + MLIRContext* mlir_context) { + auto offsets = DelinearizeInBoundsIndex(getAffineDimExpr(3, 
mlir_context), tiling.GetBlockCounts(), tiling.GetBlockStrides()); for (auto&& [offset, tile_size] : @@ -903,13 +964,13 @@ AffineMap GetBlockOffsetsForTiling(const Tiling& tiling, } AffineMap GetThreadOffsetsForTiling(const Tiling& tiling, - mlir::MLIRContext* ctx) { - auto offsets = DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), + MLIRContext* mlir_context) { + auto offsets = DelinearizeInBoundsIndex(getAffineDimExpr(0, mlir_context), tiling.GetThreadsPerBlock(), tiling.GetThreadStrides()); for (int dim = 0; dim < tiling.GetShape().size(); ++dim) { if (tiling.GetThreadTileSize()[dim] > 1) { - offsets[dim] = offsets[dim] + getAffineSymbolExpr(dim, ctx) * + offsets[dim] = offsets[dim] + getAffineSymbolExpr(dim, mlir_context) * tiling.GetThreadsPerBlock()[dim]; } } @@ -917,11 +978,13 @@ AffineMap GetThreadOffsetsForTiling(const Tiling& tiling, } IndexingMap GetIndexingMapForTiling(const Tiling& tiling, - mlir::MLIRContext* ctx) { + IndexingContext* indexing_context) { + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); return GetIndexingMapForTiling( - GetBlockOffsetsForTiling(tiling, ctx), - GetThreadOffsetsForTiling(tiling, ctx), tiling.GetNumThreadsPerBlock(), - tiling.GetNumBlocks(), tiling.GetThreadTileSize(), tiling.GetShape()); + GetBlockOffsetsForTiling(tiling, mlir_context), + GetThreadOffsetsForTiling(tiling, mlir_context), + tiling.GetNumThreadsPerBlock(), tiling.GetNumBlocks(), + tiling.GetThreadTileSize(), tiling.GetShape(), indexing_context); } IndexingMap GetIndexingMapForTiling(AffineMap block_offsets, @@ -929,7 +992,8 @@ IndexingMap GetIndexingMapForTiling(AffineMap block_offsets, int64_t threads_per_block, int64_t num_blocks, absl::Span thread_tile_sizes, - absl::Span tiled_shape) { + absl::Span tiled_shape, + IndexingContext* indexing_context) { llvm::SmallVector offsets; offsets.reserve(block_offsets.getNumResults()); for (auto [block, thread] : @@ -941,8 +1005,8 @@ IndexingMap GetIndexingMapForTiling(AffineMap block_offsets, }; auto affine_map = mlir::AffineMap::get(block_offsets.getNumDims(), block_offsets.getNumSymbols(), offsets, - offsets[0].getContext()); - IndexingMap map{affine_map, dimension_ranges, + indexing_context->GetMLIRContext()); + IndexingMap map{indexing_context, affine_map, dimension_ranges, RangesFromUpperBounds(thread_tile_sizes)}; for (int i = 0; i < tiled_shape.size(); ++i) { map.AddConstraint(affine_map.getResult(i), {0, tiled_shape[i] - 1}); @@ -1034,7 +1098,7 @@ GroupedByOpIndexingMap GroupIndexingMapsByProducers( GroupedByOpIndexingMap ComputeGroupedOutputToInputIndexing( const HloFusionAdaptor& fusion_adaptor, HloInstructionAdaptor target_instr, - MLIRContext* ctx) { + IndexingContext* ctx) { auto initial_map = CreateIdentityMap(target_instr.instruction().shape(), ctx); GroupedByOpIndexingMap grouped_indexing_maps; @@ -1088,9 +1152,9 @@ bool FuseProducerConsumerOutputToInputIndexing( const HloInstruction* producer_instr, absl::flat_hash_map* consumer_indexing, - MLIRContext* mlir_context) { + IndexingContext* indexing_context) { auto producer_indexing = ComputeOutputToInputIndexing( - producer_instr, /*output_id=*/0, mlir_context); + producer_instr, /*output_id=*/0, indexing_context); auto consumer_indexing_maps = (*consumer_indexing)[producer_instr]; for (const auto& [producer_operand_id, producer_operand_indexing] : llvm::enumerate(producer_indexing.indexing_maps)) { @@ -1109,7 +1173,7 @@ bool FuseProducerConsumerOutputToInputIndexing( HloInstructionIndexing ComputeOutputToInputIndexing(const HloInstruction* instr, int 
output_id, - MLIRContext* ctx) { + IndexingContext* ctx) { if (HloInstruction::IsOpElementwise(instr->opcode())) { return ComputeOutputToInputCwiseOpIndexing(instr, ctx); } @@ -1163,7 +1227,7 @@ HloInstructionIndexing ComputeOutputToInputIndexing(const HloInstruction* instr, HloInstructionIndexing ComputeInputToOutputIndexing(const HloInstruction* instr, int input_id, - MLIRContext* ctx) { + IndexingContext* ctx) { if (HloInstruction::IsOpElementwise(instr->opcode())) { return ComputeInputToOutputCwiseOpIndexing(instr, ctx); } @@ -1200,15 +1264,15 @@ HloInstructionIndexing ComputeInputToOutputIndexing(const HloInstruction* instr, } IndexingMap ComputeEpilogueInputToOutputIndexing( - const HloInstruction* epilogue_root, mlir::MLIRContext* ctx, + const HloInstruction* epilogue_root, IndexingContext* indexing_context, std::function is_root) { auto* instr = epilogue_root; - auto root_indexing = CreateIdentityMap(instr->shape(), ctx); + auto root_indexing = CreateIdentityMap(instr->shape(), indexing_context); while (!is_root(instr)) { // There can be multiple users, but they must have compatible indexing maps. auto* user = instr->users().front(); - auto user_indexing = - ComputeInputToOutputIndexing(user, user->operand_index(instr), ctx); + auto user_indexing = ComputeInputToOutputIndexing( + user, user->operand_index(instr), indexing_context); root_indexing = root_indexing * *user_indexing.indexing_maps[0].begin(); root_indexing.Simplify(); instr = user; diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.h b/third_party/xla/xla/service/gpu/model/indexing_analysis.h index 59a56ae750a03d..47abac957e0e0e 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.h @@ -67,15 +67,15 @@ std::string ToString(const mlir::AffineMap& affine_map); // Computes indexing maps for all input operands necessary to compute an element // of the `output_id` instruction output. -HloInstructionIndexing ComputeOutputToInputIndexing(const HloInstruction* instr, - int output_id, - mlir::MLIRContext* ctx); +HloInstructionIndexing ComputeOutputToInputIndexing( + const HloInstruction* instr, int output_id, + IndexingContext* indexing_context); // Computes indexing maps for all output operands that the element of the // `input_id` instruction input will participate in. -HloInstructionIndexing ComputeInputToOutputIndexing(const HloInstruction* instr, - int input_id, - mlir::MLIRContext* ctx); +HloInstructionIndexing ComputeInputToOutputIndexing( + const HloInstruction* instr, int input_id, + IndexingContext* indexing_context); // Computes the indexing for `epilogue_parent`'s epilogue. For example, if // `epilogue_parent` is a transpose, computes the input to output indexing for @@ -94,7 +94,7 @@ HloInstructionIndexing ComputeInputToOutputIndexing(const HloInstruction* instr, // FindNonTrivialHero, i.e., each instruction in the epilogue only has a single // user, or the users have identical indexing maps. IndexingMap ComputeEpilogueInputToOutputIndexing( - const HloInstruction* epilogue_root, mlir::MLIRContext* ctx, + const HloInstruction* epilogue_root, IndexingContext* indexing_context, std::function is_root = [](const HloInstruction* instr) { return instr->IsRoot(); }); @@ -105,7 +105,7 @@ using GroupedByOpIndexingMap = // cluster starting with `target_instr` and going from def to use. 
GroupedByOpIndexingMap ComputeGroupedOutputToInputIndexing( const HloFusionAdaptor& fusion_adaptor, HloInstructionAdaptor target_instr, - mlir::MLIRContext* ctx); + IndexingContext* indexing_context); // Groups indexing maps by instructions. absl::flat_hash_map @@ -118,44 +118,45 @@ bool FuseProducerConsumerOutputToInputIndexing( const HloInstruction* producer_instr, absl::flat_hash_map* consumer_indexing, - mlir::MLIRContext* mlir_context); + IndexingContext* mlir_context); // Creates an indexing map for bitcasting from `input_shape` to `output_shape`. // Equivalent to linearizing the input_shape index and then delinearizing it // to output_shape. IndexingMap GetBitcastMap(const Shape& input_shape, const Shape& output_shape, - mlir::MLIRContext* ctx); + IndexingContext* indexing_context); // Creates an indexing map from the physical layout of the tensor to its logical // layout. -IndexingMap GetIndexingMapFromPhysicalLayoutToLogical(const Shape& shape, - mlir::MLIRContext* ctx); +IndexingMap GetIndexingMapFromPhysicalLayoutToLogical( + const Shape& shape, IndexingContext* indexing_context); // Creates an indexing map from the logical layout of the tensor to its physical // layout. -IndexingMap GetIndexingMapFromLogicalToPhysicalLayout(const Shape& shape, - mlir::MLIRContext* ctx); +IndexingMap GetIndexingMapFromLogicalToPhysicalLayout( + const Shape& shape, IndexingContext* indexing_context); // Creates an indexing map from thread and block IDs to elements of the tiled // shape. Uses the same convention as KernelFusionInterface: dimensions 0 to 2 // are thread indices (currently only 0 is used), dimensions 3 to 5 are block // indices (currently only 3 is used). mlir::AffineMap GetBlockOffsetsForTiling(const Tiling& tiling, - mlir::MLIRContext* ctx); + mlir::MLIRContext* mlir_context); mlir::AffineMap GetThreadOffsetsForTiling(const Tiling& tiling, - mlir::MLIRContext* ctx); + mlir::MLIRContext* mlir_context); // Convenience functions for the two functions above // (`GetBlockOffsestsForTiling` + `GetThreadOffsetsForTiling`). Also sets up // the ranges of dimensions and symbols. IndexingMap GetIndexingMapForTiling(const Tiling& tiling, - mlir::MLIRContext* ctx); + IndexingContext* indexing_context); IndexingMap GetIndexingMapForTiling(mlir::AffineMap block_offsets, mlir::AffineMap thread_offsets, int64_t threads_per_block, int64_t num_blocks, absl::Span thread_tile_sizes, - absl::Span tiled_shape); + absl::Span tiled_shape, + IndexingContext* indexing_context); // Returns the shape of the output of the instruction. 
const Shape& GetOutputShape(const HloInstruction* instr, int64_t output_id); diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc index 7388d194d3eaa4..39cade7c560b58 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc @@ -91,7 +91,7 @@ TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing) { auto fusion_adaptor = ProducerConsumerFusion(transpose, root); auto grouped_indexing = ComputeGroupedOutputToInputIndexing( - fusion_adaptor, fusion_adaptor.GetRoots()[0], &mlir_context_); + fusion_adaptor, fusion_adaptor.GetRoots()[0], &indexing_context_); EXPECT_THAT(grouped_indexing, UnorderedElementsAre( Pair(root, ElementsAre(MatchIndexingMap(R"( @@ -148,7 +148,7 @@ TEST_F(IndexingAnalysisTest, auto fusion_adaptor = HloFusionAdaptor::ForInstruction(root); auto grouped_indexing = ComputeGroupedOutputToInputIndexing( - *fusion_adaptor, fusion_adaptor->GetRoots()[0], &mlir_context_); + *fusion_adaptor, fusion_adaptor->GetRoots()[0], &indexing_context_); EXPECT_THAT(grouped_indexing, UnorderedElementsAre( @@ -200,7 +200,7 @@ TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing_SingleOp) { auto fusion_adaptor = HloFusionAdaptor::ForInstruction(exponential); HloInstructionAdaptor parameter_adaptor(*parameter); auto grouped_indexing = ComputeGroupedOutputToInputIndexing( - *fusion_adaptor, parameter_adaptor, &mlir_context_); + *fusion_adaptor, parameter_adaptor, &indexing_context_); EXPECT_THAT(grouped_indexing, UnorderedElementsAre(Pair( parameter, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) @@ -240,7 +240,7 @@ TEST_F(IndexingAnalysisTest, auto parameter_0 = bcast.GetOperand(0); auto grouped_indexing = ComputeGroupedOutputToInputIndexing( - *fusion_adaptor, bcast, &mlir_context_); + *fusion_adaptor, bcast, &indexing_context_); EXPECT_THAT( grouped_indexing, UnorderedElementsAre( @@ -2083,7 +2083,7 @@ TEST_F(IndexingAnalysisTest, TilingIndexing) { Tiling tiling{/*shape=*/{1022, 256, 16}, /*tile_sizes=*/{8, 1, 4}, /*num_threads=*/{1, 4, 4}}; - auto indexing_map = GetIndexingMapForTiling(tiling, &mlir_context_); + auto indexing_map = GetIndexingMapForTiling(tiling, &indexing_context_); indexing_map.Simplify(); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( @@ -2118,7 +2118,7 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing) { ASSERT_TRUE(module.ok()); EXPECT_THAT(ComputeEpilogueInputToOutputIndexing( (*module)->entry_computation()->GetInstructionWithName("t"), - &mlir_context_) + &indexing_context_) .ToString(), MatchIndexingString(R"( (d0, d1) -> (d0 + d1 * 1000) @@ -2139,7 +2139,7 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing_NoEpilogue) { ASSERT_TRUE(module.ok()); EXPECT_THAT(ComputeEpilogueInputToOutputIndexing( (*module)->entry_computation()->GetInstructionWithName("t"), - &mlir_context_) + &indexing_context_) .ToString(), MatchIndexingString(R"( (d0, d1) -> (d0, d1) diff --git a/third_party/xla/xla/service/gpu/model/indexing_context.cc b/third_party/xla/xla/service/gpu/model/indexing_context.cc new file mode 100644 index 00000000000000..f44e4977e41baa --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/indexing_context.cc @@ -0,0 +1,27 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/indexing_context.h" + +namespace xla { +namespace gpu { + +IndexingContext::RTValsID IndexingContext::RegisterRTSymbol( + const HloInstruction* instr, IndexingMap indexing_map) { + return 0; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_context.h b/third_party/xla/xla/service/gpu/model/indexing_context.h new file mode 100644 index 00000000000000..2560cd09ab1864 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/indexing_context.h @@ -0,0 +1,54 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_INDEXING_CONTEXT_H_ +#define XLA_SERVICE_GPU_MODEL_INDEXING_CONTEXT_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +class IndexingContext { + public: + using RTValsID = int64_t; + + explicit IndexingContext(mlir::MLIRContext* mlir_context) + : mlir_context_(mlir_context) {} + + mlir::MLIRContext* GetMLIRContext() const { return mlir_context_; } + + // TBD: This method should behave like a thread-safe counter. It will register + // a new RTSymbol by adding it to `rt_vals_registry_` with the newly generated + // ID. + RTValsID RegisterRTSymbol(const HloInstruction* instr, + IndexingMap indexing_map); + + private: + mlir::MLIRContext* mlir_context_; + absl::flat_hash_map> + rt_vals_registry_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_INDEXING_CONTEXT_H_ diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 27ec08b41c58bc..7d92c7dcb9e6e1 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_context.h" #include "tsl/platform/logging.h" // IWYU pragma: keep namespace xla { @@ -376,7 +377,7 @@ AffineExpr AffineExprSimplifier::SimplifyOnce(AffineExpr expr) { auto rhs = SimplifyOnce(binop.getRHS()); // Rewrite `(x // c) * c + (x % c)` to `x`. 
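   // For instance, with c = 8 and x = 19: (19 floordiv 8) * 8 + (19 mod 8) = 16 + 3 = 19.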
- // TODO(jreiffers): This should also work with (a+b)+c. + // This should also work with (a+b)+c. auto rewrite_add = [&](AffineExpr a, AffineExpr b) -> AffineExpr { if (auto mod = GetConstantRhs(a, AffineExprKind::Mod)) { if (auto mul = GetConstantRhs(b, AffineExprKind::Mul); mod == mul) { @@ -596,12 +597,22 @@ std::vector RangesFromTensorSizes( } IndexingMap IndexingMap::FromTensorSizes( - AffineMap affine_map, absl::Span dim_upper_bounds, + IndexingContext* indexing_context, AffineMap affine_map, + absl::Span dim_upper_bounds, absl::Span symbol_upper_bounds) { - return IndexingMap{affine_map, RangesFromTensorSizes(dim_upper_bounds), + return IndexingMap{indexing_context, affine_map, + RangesFromTensorSizes(dim_upper_bounds), RangesFromTensorSizes(symbol_upper_bounds)}; } +mlir::MLIRContext* IndexingMap::GetMLIRContext() const { + return indexing_context_->GetMLIRContext(); +} + +IndexingContext* IndexingMap::GetIndexingContext() const { + return indexing_context_; +} + void IndexingMap::AddConstraint(mlir::AffineExpr expr, Interval range) { if (auto dim_expr = mlir::dyn_cast(expr)) { Interval& current_range = dim_ranges_[dim_expr.getPosition()]; @@ -1011,7 +1022,9 @@ IndexingMap ComposeIndexingMaps(const IndexingMap& first, combined_symbol_ranges.push_back(symbol_range); } - IndexingMap composed_indexing_map(composed_map, first.GetDimensionRanges(), + IndexingContext* indexing_context = first.GetIndexingContext(); + IndexingMap composed_indexing_map(indexing_context, composed_map, + first.GetDimensionRanges(), std::move(combined_symbol_ranges)); // Add constraints that are already present in the producer_map. We have to // compute consumer_map(producer_constraints). To keep all symbols and diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h index 655d877745860f..e6e84bf82a9f7d 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map.h @@ -36,6 +36,8 @@ limitations under the License. namespace xla { namespace gpu { +class IndexingContext; + // Interval represents a closed interval [lower_bound, upper_bound]. 
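 // For example, Interval{0, 3} contains exactly the values 0, 1, 2, and 3.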
struct Interval { std::string ToString() const; @@ -166,10 +168,11 @@ std::vector RangesFromTensorSizes( class IndexingMap { public: IndexingMap( - mlir::AffineMap affine_map, std::vector dim_ranges, - std::vector symbol_ranges, + IndexingContext* indexing_context, mlir::AffineMap affine_map, + std::vector dim_ranges, std::vector symbol_ranges, absl::Span> constraints = {}) - : affine_map_(affine_map), + : indexing_context_(indexing_context), + affine_map_(affine_map), dim_ranges_(std::move(dim_ranges)), symbol_ranges_(std::move(symbol_ranges)) { for (const auto& [expr, range] : constraints) { @@ -177,10 +180,12 @@ class IndexingMap { } } - IndexingMap(mlir::AffineMap affine_map, std::vector dim_ranges, + IndexingMap(IndexingContext* indexing_context, mlir::AffineMap affine_map, + std::vector dim_ranges, std::vector symbol_ranges, const llvm::DenseMap& constraints) - : affine_map_(affine_map), + : indexing_context_(indexing_context), + affine_map_(affine_map), dim_ranges_(std::move(dim_ranges)), symbol_ranges_(std::move(symbol_ranges)), constraints_(constraints) {} @@ -188,7 +193,8 @@ class IndexingMap { static IndexingMap GetUndefined() { return IndexingMap(); } static IndexingMap FromTensorSizes( - mlir::AffineMap affine_map, absl::Span dim_upper_bounds, + IndexingContext* indexing_context, mlir::AffineMap affine_map, + absl::Span dim_upper_bounds, absl::Span symbol_upper_bounds); std::string ToString( @@ -200,7 +206,10 @@ class IndexingMap { bool Simplify(); // Return MLIRContext. - mlir::MLIRContext* GetMLIRContext() const { return affine_map_.getContext(); } + mlir::MLIRContext* GetMLIRContext() const; + + // Return IndexingContext. + IndexingContext* GetIndexingContext() const; // Returns the affine map. mlir::AffineMap GetAffineMap() const { return affine_map_; } @@ -265,6 +274,7 @@ class IndexingMap { // Returns true if simplification was performed. 
bool SimplifyConstraintRanges(); + IndexingContext* indexing_context_ = nullptr; mlir::AffineMap affine_map_; std::vector dim_ranges_; std::vector symbol_ranges_; diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 2e7cbd309b1489..ffc6743863244d 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -35,12 +35,16 @@ using ::testing::ElementsAre; class IndexingMapTest : public HloTestBase { public: + IndexingMapTest() + : HloTestBase(), mlir_context_(), indexing_context_(&mlir_context_) {} mlir::MLIRContext mlir_context_; + IndexingContext indexing_context_; AffineMapPrinter printer_; }; TEST_F(IndexingMapTest, Evaluation) { IndexingMap indexing_map = IndexingMap::FromTensorSizes( + &indexing_context_, ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), {4, 4}, {2, 2}); @@ -65,10 +69,12 @@ TEST_F(IndexingMapTest, Evaluation) { TEST_F(IndexingMapTest, Composition_Permutation) { IndexingMap producer = IndexingMap::FromTensorSizes( + &indexing_context_, ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), {4, 4}, {2, 2}); IndexingMap consumer = IndexingMap::FromTensorSizes( + &indexing_context_, ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {4}, {4}); auto composed = ComposeIndexingMaps(consumer, producer); @@ -84,10 +90,12 @@ TEST_F(IndexingMapTest, Composition_Permutation) { TEST_F(IndexingMapTest, Composition_RestrictedInterval) { IndexingMap producer = IndexingMap::FromTensorSizes( + &indexing_context_, ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), {5, 6}, {7, 2}); IndexingMap consumer = IndexingMap::FromTensorSizes( + &indexing_context_, ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {10}, {8}); auto composed = ComposeIndexingMaps(consumer, producer); @@ -103,6 +111,7 @@ TEST_F(IndexingMapTest, Composition_RestrictedInterval) { TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { IndexingMap producer = IndexingMap::FromTensorSizes( + &indexing_context_, ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), {50, 60}, {70, 20}); producer.AddConstraint(ParseAffineExpr("d0 mod 8", &mlir_context_), @@ -111,6 +120,7 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { Interval{1, 1}); IndexingMap consumer = IndexingMap::FromTensorSizes( + &indexing_context_, ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {10}, {8}); consumer.AddConstraint(ParseAffineExpr("d0 + s0", &mlir_context_), Interval{0, 20}); @@ -146,6 +156,7 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { IndexingMap indexing_map = IndexingMap::FromTensorSizes( + &indexing_context_, ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), {50, 60}, {70, 20}); // This constraint cannot be removed, because it contains a "used symbol". @@ -168,6 +179,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { IndexingMap indexing_map = IndexingMap::FromTensorSizes( + &indexing_context_, ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), {50, 60}, {70, 20}); // This constraint can be removed, because it contains only the unused symbol. 
@@ -185,6 +197,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { IndexingMap indexing_map = IndexingMap::FromTensorSizes( + &indexing_context_, ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", &mlir_context_), {32}, {1, 2, 3, 4, 5}); @@ -204,7 +217,8 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); + &indexing_context_, ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, + {}); indexing_map.AddConstraint(ParseAffineExpr("(d0 mod 8) + 5", &mlir_context_), Interval{50, 54}); @@ -220,7 +234,8 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { TEST_F(IndexingMapTest, ConstraintIntervalSimplification_FloorDivPositiveDivisorPositiveBounds) { IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); + &indexing_context_, ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, + {}); indexing_map.AddConstraint(ParseAffineExpr("d0 floordiv 8", &mlir_context_), Interval{5, 11}); @@ -233,9 +248,9 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_FloorDivPositiveDivisorNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {Interval{0, 99}}, {Interval{-99, 99}}); + IndexingMap indexing_map = IndexingMap( + &indexing_context_, ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), + {Interval{0, 99}}, {Interval{-99, 99}}); indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv 3", &mlir_context_), Interval{-11, -5}); @@ -249,9 +264,9 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_FloorDivNegativeDivisorNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {Interval{0, 99}}, {Interval{-99, 99}}); + IndexingMap indexing_map = IndexingMap( + &indexing_context_, ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), + {Interval{0, 99}}, {Interval{-99, 99}}); indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv -3", &mlir_context_), Interval{-11, -5}); @@ -266,7 +281,8 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_MulPositiveMultiplierPositiveBounds) { IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); + &indexing_context_, ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, + {}); indexing_map.AddConstraint(ParseAffineExpr("d0 * 8", &mlir_context_), Interval{14, 33}); @@ -279,9 +295,9 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_MulPositiveMultiplierNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {Interval{0, 99}}, {Interval{-99, 99}}); + IndexingMap indexing_map = IndexingMap( + &indexing_context_, ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), + {Interval{0, 99}}, {Interval{-99, 99}}); indexing_map.AddConstraint(ParseAffineExpr("s0 * 3", &mlir_context_), Interval{-11, -5}); @@ -295,9 +311,9 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, ConstraintIntervalSimplification_MulNegativeMultiplierNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", 
&mlir_context_), - {Interval{0, 99}}, {Interval{-99, 99}}); + IndexingMap indexing_map = IndexingMap( + &indexing_context_, ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), + {Interval{0, 99}}, {Interval{-99, 99}}); indexing_map.AddConstraint(ParseAffineExpr("s0 * -3", &mlir_context_), Interval{-11, -5}); @@ -311,7 +327,8 @@ TEST_F(IndexingMapTest, TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { IndexingMap indexing_map = IndexingMap( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {Interval{5, 5}}, {}); + &indexing_context_, ParseAffineMap("(d0) -> (d0)", &mlir_context_), + {Interval{5, 5}}, {}); indexing_map.Simplify(); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (5) @@ -324,7 +341,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsIfSmallerThanDivisor) { auto serialized_map = "(d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {8, 16}, {}); + &indexing_context_, ParseAffineMap(serialized_map, &mlir_context_), + {8, 16}, {}); indexing_map.Simplify(); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0, d1) @@ -341,7 +359,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { "d2 mod 10)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {9, 9, 9}, {}); + &indexing_context_, ParseAffineMap(serialized_map, &mlir_context_), + {9, 9, 9}, {}); indexing_map.Simplify(); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( @@ -360,7 +379,8 @@ TEST_F(IndexingMapTest, " (d0 * 16 + d1 * 4 + d2) mod 8)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {10, 10, 10}, {}); + &indexing_context_, ParseAffineMap(serialized_map, &mlir_context_), + {10, 10, 10}, {}); indexing_map.Simplify(); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1, d2) -> (d0 * 2 + (d1 + d2 floordiv 4) floordiv 2, @@ -377,7 +397,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { "(d0, d1) -> (-((d0 * -11 - d1 + 109) floordiv 11) + 9, " "d0 * 11 + d1 + ((d0 * -11 - d1 + 109) floordiv 11) * 11 - 99)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {8, 9}, {}); + &indexing_context_, ParseAffineMap(serialized_map, &mlir_context_), + {8, 9}, {}); indexing_map.Simplify(); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0, d1) @@ -391,7 +412,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) { auto serialized_map = "()[s0] -> ((s0 * 128) mod 715 + ((s0 * 128) floordiv 715) * 715)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {128}); + &indexing_context_, ParseAffineMap(serialized_map, &mlir_context_), {}, + {128}); indexing_map.Simplify(); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (s0 * 128) @@ -404,7 +426,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) { auto serialized_map = "()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {128}); + &indexing_context_, ParseAffineMap(serialized_map, &mlir_context_), {}, + {128}); indexing_map.Simplify(); 
EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715) @@ -417,7 +440,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { "()[s0] -> (s0 - ((s0 floordiv 2) floordiv 7) * 14 + (s0 floordiv 14) * " "14)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); + &indexing_context_, ParseAffineMap(serialized_map, &mlir_context_), {}, + {1234}); indexing_map.Simplify(); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (s0) @@ -431,7 +455,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivGcdGreater1) { "()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 - ((s0 * 2 + s1 floordiv 64) " "floordiv 3) * 768 + ((s0 * 128 + s1) floordiv 192) * 768)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128, 4}); + &indexing_context_, ParseAffineMap(serialized_map, &mlir_context_), {}, + {1234, 128, 4}); indexing_map.Simplify(); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2) @@ -447,7 +472,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { "()[s0, s1, s2, s3] -> ((s0 * 458752 + s1 + s2 * 4 + s3 * 512) mod " "20000)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {872, 4, 128, 896}); + &indexing_context_, ParseAffineMap(serialized_map, &mlir_context_), {}, + {872, 4, 128, 896}); indexing_map.Simplify(); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1, s2, s3] -> ( @@ -467,7 +493,8 @@ TEST_F(IndexingMapTest, "()[s0, s1] -> ((s0 * 16 - (s1 floordiv 4) floordiv 2 + (s1 floordiv 8) " "* 2) floordiv 4)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {2, 128}); + &indexing_context_, ParseAffineMap(serialized_map, &mlir_context_), {}, + {2, 128}); indexing_map.Simplify(); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1] -> ( diff --git a/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc b/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc index e7b7e39ac71325..55fa6433ea77ba 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc @@ -53,17 +53,17 @@ HloInstruction* IndexingTestBase::ParseAndGetRoot( HloInstructionIndexing IndexingTestBase::GetOutputToInputIndexing( const HloInstruction* instr, int output_id, bool use_physical_layout) { HloInstructionIndexing indexing = - ComputeOutputToInputIndexing(instr, output_id, &mlir_context_); + ComputeOutputToInputIndexing(instr, output_id, &indexing_context_); if (!use_physical_layout) return indexing; IndexingMap output_permutation = GetIndexingMapFromPhysicalLayoutToLogical( - GetOutputShape(instr, output_id), &mlir_context_); + GetOutputShape(instr, output_id), &indexing_context_); for (const auto& [operand_id, indexing_maps] : llvm::enumerate(indexing.indexing_maps)) { IndexingMap operand_permutation = GetIndexingMapFromLogicalToPhysicalLayout( - instr->operand(operand_id)->shape(), &mlir_context_); + instr->operand(operand_id)->shape(), &indexing_context_); absl::flat_hash_set operand_indexing_maps; for (const IndexingMap& indexing_map : indexing_maps) { @@ -86,17 +86,17 @@ HloInstructionIndexing IndexingTestBase::GetOutputToInputIndexing( 
HloInstructionIndexing IndexingTestBase::GetInputToOutputIndexing( const HloInstruction* instr, int input_id, bool use_physical_layout) { HloInstructionIndexing indexing = - ComputeInputToOutputIndexing(instr, input_id, &mlir_context_); + ComputeInputToOutputIndexing(instr, input_id, &indexing_context_); if (!use_physical_layout) return indexing; IndexingMap input_permutation = GetIndexingMapFromPhysicalLayoutToLogical( - instr->operand(input_id)->shape(), &mlir_context_); + instr->operand(input_id)->shape(), &indexing_context_); for (const auto& [output_id, indexing_maps] : llvm::enumerate(indexing.indexing_maps)) { IndexingMap operand_permutation = GetIndexingMapFromLogicalToPhysicalLayout( - GetOutputShape(instr, output_id), &mlir_context_); + GetOutputShape(instr, output_id), &indexing_context_); absl::flat_hash_set operand_indexing_maps; for (const IndexingMap& indexing_map : indexing_maps) { diff --git a/third_party/xla/xla/service/gpu/model/indexing_test_utils.h b/third_party/xla/xla/service/gpu/model/indexing_test_utils.h index 62abd0e5e7fdb4..a0a304b0d43104 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_test_utils.h +++ b/third_party/xla/xla/service/gpu/model/indexing_test_utils.h @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_context.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" @@ -52,6 +53,8 @@ MATCHER_P(MatchIndexingString, indexing_string, "") { class IndexingTestBase : public HloTestBase { public: + IndexingTestBase() + : HloTestBase(), mlir_context_(), indexing_context_(&mlir_context_) {} HloInstruction* ParseAndGetRoot(absl::string_view hlo_string); HloInstructionIndexing GetOutputToInputIndexing( @@ -63,6 +66,7 @@ class IndexingTestBase : public HloTestBase { bool use_physical_layout = false); mlir::MLIRContext mlir_context_; + IndexingContext indexing_context_; std::unique_ptr module_; }; diff --git a/third_party/xla/xla/service/gpu/model/tile_analysis.cc b/third_party/xla/xla/service/gpu/model/tile_analysis.cc index 3599d19575e70e..8560c10de5af0d 100644 --- a/third_party/xla/xla/service/gpu/model/tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/tile_analysis.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_context.h" #include "xla/service/gpu/model/indexing_map.h" namespace xla { @@ -252,7 +253,8 @@ std::optional RawSymbolicTileFromIndexingMap( /*static*/ std::optional SymbolicTile::FromIndexingMap( const IndexingMap& indexing_map) { - MLIRContext* mlir_context = indexing_map.GetAffineMap().getContext(); + IndexingContext* indexing_context = indexing_map.GetIndexingContext(); + MLIRContext* mlir_context = indexing_context->GetMLIRContext(); int64_t num_input_dims = indexing_map.GetDimensionCount(); std::vector exprs; exprs.reserve(num_input_dims); @@ -294,8 +296,8 @@ std::optional RawSymbolicTileFromIndexingMap( mlir_context); IndexingMap composed_indexing_map( - indexing_map.GetAffineMap().compose(producer_map), tile_dimension_ranges, - tile_symbol_ranges); + indexing_context, indexing_map.GetAffineMap().compose(producer_map), + tile_dimension_ranges, tile_symbol_ranges); composed_indexing_map.Simplify(); From b9752df1863297248ca2e6a74cb11125aa520474 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Mar 2024 09:50:41 -0700 Subject: [PATCH 028/670] Include ifrt_proxy in xla_extension.so (so it is propagated to jaxlib). PiperOrigin-RevId: 616858886 --- third_party/xla/xla/python/BUILD | 1 + .../xla/xla/python/ifrt_proxy/client/BUILD | 12 +++++-- .../xla/python/ifrt_proxy/client/py_module.cc | 22 +++++++------ .../xla/python/ifrt_proxy/client/py_module.h | 31 ++++++++++++++++++ .../xla/xla/python/ifrt_proxy/jax/BUILD | 1 - .../ifrt_proxy/jax/ifrt_proxy_internal.py | 8 +++-- third_party/xla/xla/python/xla.cc | 5 +++ .../xla/xla/python/xla_extension/__init__.pyi | 1 + .../xla/python/xla_extension/ifrt_proxy.pyi | 32 +++++++++++++++++++ 9 files changed, 97 insertions(+), 16 deletions(-) create mode 100644 third_party/xla/xla/python/ifrt_proxy/client/py_module.h create mode 100644 third_party/xla/xla/python/xla_extension/ifrt_proxy.pyi diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 508a1a82233fe7..3668ab3ba2ff5a 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -1258,6 +1258,7 @@ cc_library( "//xla/pjrt/distributed:protocol_proto_cc", "//xla/pjrt/distributed:service", "//xla/python/ifrt", + "//xla/python/ifrt_proxy/client:py_module", "//xla/python/pjrt_ifrt", "//xla/service/cpu:collectives_interface", "@local_tsl//tsl/distributed_runtime/preemption:preemption_sync_manager", diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD index 7a6071ad0b2aee..7a989ec885c38a 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD @@ -14,7 +14,6 @@ load("//xla/python/ifrt_proxy/common:ifrt_proxy.bzl", "default_ifrt_proxy_visibility", "ifrt_proxy_cc_test") load("@local_tsl//tsl:tsl.bzl", "if_google") -load("@local_tsl//tsl:tsl.default.bzl", "tsl_pybind_extension") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -494,9 +493,17 @@ ifrt_proxy_cc_test( ], ) -tsl_pybind_extension( +cc_library( name = "py_module", srcs = ["py_module.cc"], + hdrs = ["py_module.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = ["//xla/python:__pkg__"], deps = [ ":grpc_client", 
":registry", @@ -514,6 +521,5 @@ tsl_pybind_extension( "@local_tsl//tsl/platform:statusor", "@pybind11", "@pybind11_abseil//pybind11_abseil:absl_casters", - "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", ], ) diff --git a/third_party/xla/xla/python/ifrt_proxy/client/py_module.cc b/third_party/xla/xla/python/ifrt_proxy/client/py_module.cc index 4b407bb438bb71..c20dc63c4d06d9 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/py_module.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/py_module.cc @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#include "xla/python/ifrt_proxy/client/py_module.h" #include #include @@ -32,7 +33,6 @@ #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // NOLINT // IWYU pragma: keep -#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt_proxy/client/registry.h" @@ -100,20 +100,22 @@ absl::StatusOr> GetClient( } } // namespace -} // namespace proxy -} // namespace ifrt -} // namespace xla -PYBIND11_MODULE(py_module, m) { - pybind11_protobuf::ImportNativeProtoCasters(); +void BuildIfrtProxySubmodule(pybind11::module_& m) { + pybind11::module_ sub_module = m.def_submodule("ifrt_proxy", "IFRT proxy"); - using ::xla::ifrt::proxy::PyClientConnectionOptions; - pybind11::class_(m, "ClientConnectionOptions") + pybind11::class_(sub_module, + "ClientConnectionOptions") .def(pybind11::init<>()) .def_readwrite("on_disconnect", &PyClientConnectionOptions::on_disconnect) .def_readwrite("on_connection_update", &PyClientConnectionOptions::on_connection_update); - m.def("get_client", xla::ValueOrThrowWrapper(xla::ifrt::proxy::GetClient), - pybind11::arg("proxy_server_address"), pybind11::arg("options")); + sub_module.def("get_client", xla::ValueOrThrowWrapper(GetClient), + pybind11::arg("proxy_server_address"), + pybind11::arg("options")); } + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/py_module.h b/third_party/xla/xla/python/ifrt_proxy/client/py_module.h new file mode 100644 index 00000000000000..508d91a0f2d7c5 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/py_module.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_PY_MODULE_H_ +#define XLA_PYTHON_IFRT_PROXY_CLIENT_PY_MODULE_H_ + +#include "pybind11/pybind11.h" // from @pybind11 + +namespace xla { +namespace ifrt { +namespace proxy { + +void BuildIfrtProxySubmodule(pybind11::module_& m); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_CLIENT_PY_MODULE_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/jax/BUILD b/third_party/xla/xla/python/ifrt_proxy/jax/BUILD index b86f65e9c3596a..1a84033e2fa33a 100644 --- a/third_party/xla/xla/python/ifrt_proxy/jax/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/jax/BUILD @@ -31,7 +31,6 @@ pytype_strict_library( # copybara:uncomment_end deps = [ "//xla/python:xla_client", - "//xla/python/ifrt_proxy/client:py_module", "@pybind11_abseil//pybind11_abseil:status", ], ) diff --git a/third_party/xla/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py b/third_party/xla/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py index 746575cdd61135..790c9567e010af 100644 --- a/third_party/xla/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py +++ b/third_party/xla/xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Library to help create a IFRT proxy client.""" +"""Library to help create a IFRT proxy client. + +This library is no longer recommended nor used in OSS; it is used internally +within google code. TODO(madthanu): Remove library. +""" import dataclasses from typing import Callable, Optional from pybind11_abseil import status from xla.python import xla_client -from xla.python.ifrt_proxy.client import py_module @dataclasses.dataclass @@ -47,6 +50,7 @@ def get_client(proxy_server_address: str) -> xla_client.Client: """Creates an IFRT Proxy client for the given server address.""" global _backend_created _backend_created = True + py_module = xla_client._xla.ifrt_proxy # pylint: disable=protected-access cpp_options = py_module.ClientConnectionOptions() cpp_options.on_disconnect = _connection_options.on_disconnect cpp_options.on_connection_update = _connection_options.on_connection_update diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 05752d3ff715e1..12870ab43f93d4 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -65,6 +65,7 @@ limitations under the License. #include "xla/pjrt/distributed/protocol.pb.h" #include "xla/pjrt/distributed/service.h" #include "xla/pjrt/pjrt_compiler.h" +#include "xla/python/ifrt_proxy/client/py_module.h" #include "xla/python/py_client.h" #include "xla/service/cpu/collectives_interface.h" #include "tsl/python/lib/core/numpy.h" //NOLINT @@ -1016,6 +1017,10 @@ static void Init(py::module_& m) { BuildMlirSubmodule(m_nb); BuildCustomCallShardingPybindAPI(m_nb); + // The following uses python bindings for PyClient defined above using + // pybind11, and hence needs pybind11::module_ (not just nanobind::module_). 
+ xla::ifrt::proxy::BuildIfrtProxySubmodule(m); + py::class_> preemption_sync_manager(m, "PreemptionSyncManager"); diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi index 294a62a9136cfb..8fe1300bd94c73 100644 --- a/third_party/xla/xla/python/xla_extension/__init__.pyi +++ b/third_party/xla/xla/python/xla_extension/__init__.pyi @@ -37,6 +37,7 @@ from typing import ( import numpy as np +from . import ifrt_proxy from . import jax_jit from . import mlir from . import ops diff --git a/third_party/xla/xla/python/xla_extension/ifrt_proxy.pyi b/third_party/xla/xla/python/xla_extension/ifrt_proxy.pyi new file mode 100644 index 00000000000000..f65685025e5166 --- /dev/null +++ b/third_party/xla/xla/python/xla_extension/ifrt_proxy.pyi @@ -0,0 +1,32 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Optional, Callable + +from xla.python import xla_extension + +_Status = Any +Client = xla_extension.Client + + +class ClientConnectionOptions: + on_disconnect: Optional[Callable[[_Status], None]] = None + on_connection_update: Optional[Callable[[str], None]] = None + + +def get_client( + proxy_server_address: str, + options: ClientConnectionOptions +) -> Client: ... From a26cc64f94d09388f4fe5f3a21c309b74fed5a03 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Mon, 18 Mar 2024 10:37:00 -0700 Subject: [PATCH 029/670] Rollback of GpuTimer: improve kernel execution time measurement accuracy This breaks the gemm algorithm picker test on V100 PiperOrigin-RevId: 616875670 --- .../xla/service/gpu/conv_algorithm_picker.cc | 12 +- .../xla/service/gpu/gemm_algorithm_picker.cc | 31 +---- .../xla/xla/stream_executor/build_defs.bzl | 8 -- third_party/xla/xla/stream_executor/gpu/BUILD | 42 +------ .../xla/xla/stream_executor/gpu/gpu_timer.cc | 114 ++---------------- .../xla/xla/stream_executor/gpu/gpu_timer.h | 34 +----- .../gpu/gpu_timer_kernel.cu.cc | 52 -------- .../stream_executor/gpu/gpu_timer_kernel.h | 26 ---- .../gpu/gpu_timer_kernel_stub.cc | 22 ---- 9 files changed, 21 insertions(+), 320 deletions(-) delete mode 100644 third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.cu.cc delete mode 100644 third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.h delete mode 100644 third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel_stub.cc diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc index 4c21084f51d48b..54bde3ac33e147 100644 --- a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc @@ -612,6 +612,7 @@ absl::StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( // Use assignment instead of brace-list to make GCC 4.9 happy. 
RunConvOptions options; options.runner_cache = runner; + options.profile_result = &profile_result; // The following plan timing code is based on // https://github.com/NVIDIA/cudnn-frontend/blob/60496f42fdc7a4ccc059f5934e306e728a756755/include/cudnn_frontend_find_plan.h float max_time = 0; @@ -624,20 +625,15 @@ absl::StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( // Dry-run to warmup the plan. launch_status = RunGpuConv(config, operand_buffers, result_buffers, scratch_memory, stream, options); - // It is intentional that the warm-up run does not have a profile result. - // This avoids a timeout and error message if lazy module loading is enabled - // by ensuring that lazy loading happens outside the GpuTimer region. - options.profile_result = &profile_result; constexpr int kMaxIter = 10; // Iterate until the new measurement is within kThreshold of the current // minimum. int num_iters = 0; - for (; num_iters < kMaxIter && launch_status.ok(); ++num_iters) { + for (; + num_iters < kMaxIter && launch_status.ok() && profile_result.is_valid(); + num_iters++) { launch_status = RunGpuConv(config, operand_buffers, result_buffers, scratch_memory, stream, options); - if (!profile_result.is_valid()) { - break; - } float old_min_time = min_time; min_time = std::min(min_time, profile_result.elapsed_time_in_ms()); max_time = std::max(max_time, profile_result.elapsed_time_in_ms()); diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc index a5ff6665031619..9ed90d2ba6eadd 100644 --- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc @@ -25,7 +25,6 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" -#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -241,15 +240,6 @@ class GemmAutotuner { auto tuned_func = [&](const se::blas::AlgorithmType& algorithm) -> absl::StatusOr { - // Do a warm-up run first, without a profile result. This avoids a timeout - // and error message if lazy module loading is enabled by ensuring that - // lazy loading happens outside the GpuTimer. RunGemm swallows error codes - // when profile_result is passed, as it is in the measurement below, but - // not otherwise. It is, therefore, consistent to ignore the error code - // here. - static_cast(RunGemm(gemm_config, lhs_buffer_, rhs_buffer_, - output_buffer_, workspace_buffer, - deterministic_ops_, stream_, algorithm)); se::blas::ProfileResult profile_result; // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail // for all algorithms if we're targeting < sm_50. But because we pass a @@ -421,28 +411,15 @@ absl::StatusOr RunOnInstruction(HloInstruction* gemm, config.GetGpuComputeCapability()); if (update_algorithm) { - int64_t new_algorithm{}; if (algorithm.has_gemm()) { - new_algorithm = algorithm.gemm().algorithm(); + backend_config.set_selected_algorithm(algorithm.gemm().algorithm()); } else { // NOTE: runtime autotuning is no longer available => set to default - new_algorithm = se::blas::kDefaultAlgorithm; + backend_config.set_selected_algorithm(se::blas::kDefaultAlgorithm); } - - if (new_algorithm == old_algorithm && - backend_config.has_selected_algorithm()) { - // We don't need to update the backend config if - // the algorithm hasn't changed unless previously - // the algorithm wasn't set explicitly. 
- return false; - } - - backend_config.set_selected_algorithm(new_algorithm); - TF_RETURN_IF_ERROR(gemm->set_backend_config(gpu_config)); - return true; // We changed `gemm` } - - return false; // No change to `gemm` + TF_RETURN_IF_ERROR(gemm->set_backend_config(gpu_config)); + return old_algorithm != backend_config.selected_algorithm(); } absl::StatusOr RunOnComputation(HloComputation* computation, diff --git a/third_party/xla/xla/stream_executor/build_defs.bzl b/third_party/xla/xla/stream_executor/build_defs.bzl index 4e43fbec8d0c1e..6916574c646edf 100644 --- a/third_party/xla/xla/stream_executor/build_defs.bzl +++ b/third_party/xla/xla/stream_executor/build_defs.bzl @@ -88,11 +88,3 @@ def cuda_only_cc_library(name, tags = [], **kwargs): restricted_to = kwargs.get("restricted_to"), target_compatible_with = kwargs.get("target_compatible_with"), ) - -# TODO(hebecker): Remove this once we've fixed our ARM build -def if_google_arm_build( - if_true, # @unused - if_false = []): - return select({ - "//conditions:default": if_false, - }) diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 8dfb83e2657c98..f75e32f0fe8866 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -20,7 +20,6 @@ load( load( "//xla/stream_executor:build_defs.bzl", "gpu_only_cc_library", - "if_google_arm_build", "if_gpu_is_configured", ) load( @@ -316,47 +315,11 @@ gpu_only_cc_library( ], ) -gpu_only_cc_library( - name = "gpu_timer_kernel_header", - hdrs = ["gpu_timer_kernel.h"], -) - -gpu_kernel_library( - name = "gpu_timer_kernel", - srcs = if_gpu_is_configured(["gpu_timer_kernel.cu.cc"]), - deps = [ - ":gpu_timer_kernel_header", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - ]), -) - -# TODO(hebecker): Remove this once we have fixed our ARM build -gpu_only_cc_library( - name = "gpu_timer_kernel_stub", - srcs = [ - "gpu_timer_kernel_stub.cc", - ], - deps = [":gpu_timer_kernel_header"], -) - -# TODO(hebecker): Remove this once we have fixed our ARM build -cc_library( - name = "gpu_timer_kernel_not_on_google_arm", - deps = if_google_arm_build( - [":gpu_timer_kernel_stub"], - [":gpu_timer_kernel"], - ), -) - gpu_only_cc_library( name = "gpu_timer_header", hdrs = ["gpu_timer.h"], deps = [ ":gpu_executor_header", - ":gpu_timer_kernel_header", ":gpu_types_header", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", @@ -371,7 +334,6 @@ gpu_only_cc_library( ":gpu_driver_header", ":gpu_executor_header", ":gpu_stream", - ":gpu_timer_kernel_header", ":gpu_types_header", "//xla/stream_executor", "//xla/stream_executor:stream_executor_internal", @@ -386,9 +348,7 @@ gpu_only_cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - ] + if_gpu_is_configured([ - ":gpu_timer_kernel_not_on_google_arm", - ]) + if_cuda_is_configured([ + ] + if_cuda_is_configured([ "//xla/stream_executor/cuda:cuda_driver", ]) + if_rocm_is_configured([ "//xla/stream_executor/rocm:rocm_driver", diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc index c0256e8051c719..ecd3f40c6725c9 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc @@ -51,21 +51,10 @@ absl::Duration RandomDuration() { return 
absl::Microseconds(distribution(rng)); } -bool ShouldLaunchDelayKernel() { - // Only launch the delay kernel if CUDA_LAUNCH_BLOCKING is not set to 1. - static bool value = [] { - const char* blocking = std::getenv("CUDA_LAUNCH_BLOCKING"); - return !blocking || std::string_view{blocking} != "1"; - }(); - return value; -} - } // namespace /*deprecated*/ /*static*/ absl::StatusOr GpuTimer::Create( GpuStream* stream) { - // This deprecated factory does not launch the delay kernel and may lead to - // reduced measurement accuracy. GpuExecutor* parent = stream->parent(); GpuContext* context = parent->gpu_context(); GpuEventHandle start_event; @@ -83,8 +72,6 @@ bool ShouldLaunchDelayKernel() { /*deprecated*/ /*static*/ absl::StatusOr> GpuTimer::CreateIfNeeded(GpuStream* stream, bool is_needed) { - // This deprecated factory does not launch the delay kernel and may lead to - // reduced measurement accuracy. if (is_needed) { TF_ASSIGN_OR_RETURN(GpuTimer t, GpuTimer::Create(stream)); return {std::make_optional(std::move(t))}; @@ -92,78 +79,16 @@ GpuTimer::CreateIfNeeded(GpuStream* stream, bool is_needed) { return std::nullopt; } -/*static*/ absl::StatusOr -GpuTimer::GpuSemaphore::Create(StreamExecutor* executor) { - // Allocate the value in pinned host memory that can be read from both - // host and device. - TF_ASSIGN_OR_RETURN(auto alloc, - executor->HostMemoryAllocate(sizeof(GpuSemaphoreState))); - return GpuSemaphore{std::move(alloc)}; +[[deprecated("So it can quietly call a deprecated method")]] /*static*/ absl:: + StatusOr + GpuTimer::Create(Stream* stream) { + return GpuTimer::Create(AsGpuStream(stream)); } -DeviceMemory GpuTimer::GpuSemaphore::device() { - // This assumes unified addressing, as we do not explicitly translate the - // host pointer into a device pointer. - return DeviceMemory::MakeFromByteSize( - ptr_->opaque(), sizeof(GpuSemaphoreState)); -} - -/*static*/ absl::StatusOr GpuTimer::Create(Stream* real_stream) { - StreamExecutor* executor = real_stream->parent(); - GpuStream* stream = AsGpuStream(real_stream); - GpuExecutor* parent = stream->parent(); - GpuContext* context = parent->gpu_context(); - GpuEventHandle start_event; - TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &start_event, - GpuDriver::EventFlags::kDefault)); - GpuEventHandle stop_event; - TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &stop_event, - GpuDriver::EventFlags::kDefault)); - CHECK(start_event != nullptr && stop_event != nullptr); - GpuSemaphore semaphore{}; - if (ShouldLaunchDelayKernel()) { - // Check the assumption that this device supports unified addressing, - // otherwise skip the delay kernel - TF_ASSIGN_OR_RETURN(int status, GpuDriver::GetDeviceAttribute( - CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, - parent->device())); - if (!status) { - LOG(WARNING) << "Skipping the delay kernel because the device does not " - "support unified addressing"; - } else { - // Allocate a semaphore value that will be used to signal to the delay - // kernel that it may exit. - TF_ASSIGN_OR_RETURN(semaphore, GpuSemaphore::Create(executor)); - *semaphore = GpuSemaphoreState::Hold; - // In principle the kernel could be loaded lazily and shared across - // multiple GpuTimer objects. - TF_ASSIGN_OR_RETURN( - auto kernel, - (TypedKernel, - GpuSemaphoreState>::Create(executor, "DelayKernel", - delay_kernel::kernel()))); - // Launch a delay kernel into this stream, which will spin until - // GetElapsedDuration() is called, the timer is destroyed, or the timeout - // in the kernel is reached. 
- TF_RETURN_IF_ERROR(real_stream->ThenLaunch( - ThreadDim(1, 1, 1), BlockDim(1, 1, 1), kernel, semaphore.device(), - GpuSemaphoreState::Release)); - } - } - // The start event goes after the delay kernel in the stream - TF_RETURN_IF_ERROR(GpuDriver::RecordEvent(parent->gpu_context(), start_event, - stream->gpu_stream())); - return absl::StatusOr{absl::in_place, parent, start_event, - stop_event, stream, std::move(semaphore)}; -} - -/*static*/ absl::StatusOr> GpuTimer::CreateIfNeeded( - Stream* stream, bool is_needed) { - if (is_needed) { - TF_ASSIGN_OR_RETURN(GpuTimer t, GpuTimer::Create(stream)); - return {std::make_optional(std::move(t))}; - } - return std::nullopt; +[[deprecated("So it can quietly call a deprecated method")]] /*static*/ absl:: + StatusOr> + GpuTimer::CreateIfNeeded(Stream* stream, bool is_needed) { + return GpuTimer::CreateIfNeeded(AsGpuStream(stream), is_needed); } /*static*/ void GpuTimer::ReturnRandomDurationsForTesting() { @@ -172,17 +97,6 @@ DeviceMemory GpuTimer::GpuSemaphore::device() { GpuTimer::~GpuTimer() { GpuContext* context = parent_->gpu_context(); - if (semaphore_ && !is_stopped_) { - // Signal the delay kernel that it can exit - *semaphore_ = GpuSemaphoreState::Release; - // Wait for the delay kernel to exit before destroying the value that it is - // watching. - absl::Status status = - GpuDriver::SynchronizeStream(context, stream_->gpu_stream()); - if (!status.ok()) { - LOG(ERROR) << status; - } - } if (start_event_ != nullptr) { absl::Status status = GpuDriver::DestroyEvent(context, &start_event_); if (!status.ok()) { @@ -203,18 +117,6 @@ absl::StatusOr GpuTimer::GetElapsedDuration() { } TF_RETURN_IF_ERROR(GpuDriver::RecordEvent(parent_->gpu_context(), stop_event_, stream_->gpu_stream())); - // If we launched the delay kernel then check if it already timed out. - if (semaphore_) { - if (*semaphore_ == GpuSemaphoreState::TimedOut) { - // The delay kernel did not achieve the intended result. - LOG(ERROR) << "Delay kernel timed out: measured time has sub-optimal " - "accuracy. There may be a missing warmup execution, please " - "investigate in Nsight Systems."; - } else { - // Signal that the kernel can exit - *semaphore_ = GpuSemaphoreState::Release; - } - } float elapsed_milliseconds = NAN; if (!GpuDriver::GetEventElapsedTime(parent_->gpu_context(), &elapsed_milliseconds, start_event_, diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h index 251c77ec7ee1ea..8fd83bec6499e3 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h @@ -22,7 +22,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" #include "xla/stream_executor/gpu/gpu_executor.h" -#include "xla/stream_executor/gpu/gpu_timer_kernel.h" #include "xla/stream_executor/gpu/gpu_types.h" namespace xla { @@ -37,29 +36,9 @@ namespace gpu { class GpuExecutor; class GpuStream; -// When a timer is created it launches a delay kernel into the given stream and -// queues a start event immediately afterwards. This delay kernel blocks -// execution on the stream until GetElapsedDuration() is called, at which point -// an end event is queued and the delay kernel exits. This allows the device -// execution time of the tasks queued to the stream while the timer is active -// to be measured more accurately. +// Timer is started once it's created, and is stopped once read. 
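+// Reading it with GetElapsedDuration() records the stop event and returns the
+// time elapsed between the start and stop events.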
class GpuTimer { public: - class GpuSemaphore { - public: - GpuSemaphore() = default; - static absl::StatusOr Create(StreamExecutor* executor); - explicit operator bool() const { return bool{ptr_}; } - GpuSemaphoreState& operator*() { - return *static_cast(ptr_->opaque()); - } - DeviceMemory device(); - - private: - explicit GpuSemaphore(std::unique_ptr alloc) - : ptr_{std::move(alloc)} {} - std::unique_ptr ptr_; - }; static absl::StatusOr Create(Stream* stream); [[deprecated("Pass Stream* not GpuStream*")]] static absl::StatusOr Create(GpuStream* stream); @@ -74,20 +53,17 @@ class GpuTimer { CreateIfNeeded(GpuStream* stream, bool is_needed); explicit GpuTimer(GpuExecutor* parent, GpuEventHandle start_event, - GpuEventHandle stop_event, GpuStream* stream, - GpuSemaphore semaphore = {}) + GpuEventHandle stop_event, GpuStream* stream) : parent_(parent), start_event_(start_event), stop_event_(stop_event), - stream_(stream), - semaphore_(std::move(semaphore)) {} + stream_(stream) {} GpuTimer(GpuTimer&& other) : parent_(other.parent_), start_event_(std::exchange(other.start_event_, nullptr)), stop_event_(std::exchange(other.stop_event_, nullptr)), - stream_(other.stream_), - semaphore_(std::move(other.semaphore_)) {} + stream_(other.stream_) {} GpuTimer& operator=(GpuTimer&& other) { if (this != &other) { @@ -95,7 +71,6 @@ class GpuTimer { start_event_ = std::exchange(other.start_event_, nullptr); stop_event_ = std::exchange(other.stop_event_, nullptr); stream_ = other.stream_; - semaphore_ = std::move(other.semaphore_); } return *this; } @@ -111,7 +86,6 @@ class GpuTimer { GpuEventHandle start_event_ = nullptr; GpuEventHandle stop_event_ = nullptr; GpuStream* stream_; - GpuSemaphore semaphore_; bool is_stopped_ = false; GpuTimer(const GpuTimer&) = delete; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.cu.cc b/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.cu.cc deleted file mode 100644 index 0ce4b1d9fbb323..00000000000000 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.cu.cc +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "xla/stream_executor/gpu/gpu_timer_kernel.h" - -#include - -namespace stream_executor::gpu { -namespace { -// Wait for the value pointed to by `semaphore` to have value `target`, timing -// out after approximately `APPROX_TIMEOUT_SECONDS` seconds if that value is -// not reached. This can happen if, for example, blocking launches are enabled -// via CUDA_LAUNCH_BLOCKING=1. It can also happen if launching a kernel after -// this delay kernel causes synchronisation, e.g. because of lazy loading. 
-__global__ void DelayKernel(volatile GpuSemaphoreState* semaphore, - GpuSemaphoreState target) { - constexpr int64_t WAIT_CYCLES{1024}; - constexpr int64_t TIMEOUT_CYCLES{200000000}; // 100ms at 2GHz - const int64_t tstart{clock64()}; - bool target_not_reached; - while ((target_not_reached = (*semaphore != target)) && - (clock64() - tstart) < TIMEOUT_CYCLES) { - int64_t elapsed{}; - const int64_t t0{clock64()}; - do { - elapsed = clock64() - t0; - } while (elapsed < WAIT_CYCLES); - } - if (target_not_reached) { - // We are exiting due to the timeout. Signal this back to the host so that - // we can emit a warning, as it probably indicates suboptimal usage. - *semaphore = GpuSemaphoreState::TimedOut; - } -} -} // namespace - -namespace delay_kernel { -void* kernel() { return reinterpret_cast(DelayKernel); } -} // namespace delay_kernel - -} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.h b/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.h deleted file mode 100644 index 2ac358b4ee56c5..00000000000000 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.h +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ -#define XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ - -namespace stream_executor::gpu { -enum struct GpuSemaphoreState { Hold, Release, TimedOut }; -namespace delay_kernel { -void* kernel(); // returns a pointer to a CUDA C++ device function -} // namespace delay_kernel -} // namespace stream_executor::gpu - -#endif // XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel_stub.cc b/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel_stub.cc deleted file mode 100644 index 5286b5445b8b56..00000000000000 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel_stub.cc +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xla/stream_executor/gpu/gpu_timer_kernel.h" - -namespace stream_executor::gpu { -namespace delay_kernel { -void* kernel() { return nullptr; } -} // namespace delay_kernel -} // namespace stream_executor::gpu From e37cdcfa8b76ec3ea8fbeb8748b472b77603e42b Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 18 Mar 2024 11:10:33 -0700 Subject: [PATCH 030/670] Cherry-pick https://github.com/llvm/llvm-project/commit/daa350c1995015daac552548c34b87220f21156d into TF This lets MSVC compile MLIR again PiperOrigin-RevId: 616887711 --- ...50c1995015daac552548c34b87220f21156d.patch | 77 +++++++++++++++++++ third_party/llvm/workspace.bzl | 1 + 2 files changed, 78 insertions(+) create mode 100644 third_party/llvm/daa350c1995015daac552548c34b87220f21156d.patch diff --git a/third_party/llvm/daa350c1995015daac552548c34b87220f21156d.patch b/third_party/llvm/daa350c1995015daac552548c34b87220f21156d.patch new file mode 100644 index 00000000000000..541c4e2f3bbbcb --- /dev/null +++ b/third_party/llvm/daa350c1995015daac552548c34b87220f21156d.patch @@ -0,0 +1,77 @@ +commit daa350c1995015daac552548c34b87220f21156d +Author: Benjamin Kramer +Date: Sun Mar 17 14:05:41 2024 +0100 + + [mlir] Work around MSVC bug + + MSVC fails to parse this construct, leading to + MlirTranslateMain.cpp(70): error C2065: 'inputSplitMarker': undeclared identifier + + Just switching to brace init works around the issue + +diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +index 51504ad58282..44c5e9826f3b 100644 +--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp ++++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +@@ -128,7 +128,7 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { + cl::desc("Print the list of registered dialects and exit"), + cl::location(showDialectsFlag), cl::init(false)); + +- static cl::opt splitInputFile( ++ static cl::opt splitInputFile{ + "split-input-file", llvm::cl::ValueOptional, + cl::callback([&](const std::string &str) { + // Implicit value: use default marker if flag was used without value. +@@ -137,7 +137,7 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { + }), + cl::desc("Split the input file into chunks using the given or " + "default marker and process each chunk independently"), +- cl::location(splitInputFileFlag), cl::init("")); ++ cl::location(splitInputFileFlag), cl::init("")}; + + static cl::opt outputSplitMarker( + "output-split-marker", +diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp +index 1aaf8adb50a7..bd9928950ecc 100644 +--- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp ++++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp +@@ -62,7 +62,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv, + llvm::cl::desc("Allow operation with no registered dialects (discouraged: testing only!)"), + llvm::cl::init(false)); + +- static llvm::cl::opt inputSplitMarker( ++ static llvm::cl::opt inputSplitMarker{ + "split-input-file", llvm::cl::ValueOptional, + llvm::cl::callback([&](const std::string &str) { + // Implicit value: use default marker if flag was used without value. 
+@@ -71,7 +71,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv, + }), + llvm::cl::desc("Split the input file into chunks using the given or " + "default marker and process each chunk independently"), +- llvm::cl::init("")); ++ llvm::cl::init("")}; + + static llvm::cl::opt verifyDiagnostics( + "verify-diagnostics", +diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp +index d312765e40b0..c6ad6c361e99 100644 +--- a/mlir/tools/mlir-pdll/mlir-pdll.cpp ++++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp +@@ -136,7 +136,7 @@ int main(int argc, char **argv) { + llvm::cl::desc( + "Print out the parsed ODS information from the input file"), + llvm::cl::init(false)); +- llvm::cl::opt inputSplitMarker( ++ llvm::cl::opt inputSplitMarker{ + "split-input-file", llvm::cl::ValueOptional, + llvm::cl::callback([&](const std::string &str) { + // Implicit value: use default marker if flag was used without value. +@@ -145,7 +145,7 @@ int main(int argc, char **argv) { + }), + llvm::cl::desc("Split the input file into chunks using the given or " + "default marker and process each chunk independently"), +- llvm::cl::init("")); ++ llvm::cl::init("")}; + llvm::cl::opt outputSplitMarker( + "output-split-marker", + llvm::cl::desc("Split marker to use for merging the ouput"), diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index c190989fc46286..67616ae9c97943 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -17,6 +17,7 @@ def repo(name): ], build_file = "//third_party/llvm:llvm.BUILD", patch_file = [ + "//third_party/llvm:daa350c1995015daac552548c34b87220f21156d.patch", "//third_party/llvm:generated.patch", # Autogenerated, don't remove. "//third_party/llvm:build.patch", "//third_party/llvm:mathextras.patch", From fa310c6b644e8859c59fdd01cf6dbb1d85496f8c Mon Sep 17 00:00:00 2001 From: Dragan Mladjenovic Date: Mon, 18 Mar 2024 11:19:07 -0700 Subject: [PATCH 031/670] PR #9873: [ROCm] Don't use CUDA PTX for ROCM in ComputationIdCmd Imported from GitHub PR https://github.com/openxla/xla/pull/9873 Copybara import of the project: -- 818077159230e06ce8e94b3c556d1d68fa125b09 by Dragan Mladjenovic : [ROCm] Don't use CUDA PTX for ROCM in ComputationIdCmd Merging this change closes #9873 PiperOrigin-RevId: 616890560 --- .../xla/xla/service/gpu/runtime/command_buffer_cmd.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc index c1909dc98e437a..1da3c5d4253c9f 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -502,6 +502,7 @@ CommandBufferCmd::BufferUsageVector ComputationIdCmd::buffers() { absl::Status ComputationIdCmd::Initialize(const Thunk::InitializeParams& params, StateManager& state) { +#if defined(GOOGLE_CUDA) { absl::MutexLock lock(&mutex_); if (memset_kernels_.contains(params.executor)) return absl::OkStatus(); @@ -514,6 +515,7 @@ absl::Status ComputationIdCmd::Initialize(const Thunk::InitializeParams& params, absl::MutexLock lock(&mutex_); memset_kernels_.emplace(params.executor, std::move(kernel)); +#endif // GOOGLE_CUDA return absl::OkStatus(); } @@ -540,6 +542,7 @@ absl::Status ComputationIdCmd::Record( << "; execution_scope_id=" << execution_scope_id.value(); VLOG(5) << " Id: " << dest_ << " (" << dst.opaque() << ")"; +#if defined(GOOGLE_CUDA) se::Kernel* memset_kernel = [&] { 
absl::MutexLock lock(&mutex_); return memset_kernels_[execute_params.stream->parent()].get(); @@ -553,6 +556,10 @@ absl::Status ComputationIdCmd::Record( auto args = se::PackKernelArgs(/*shmem_bytes=*/0, int64_t{1}, value, dst); return command_buffer->Launch(execution_scope_id, se::ThreadDim(1), se::BlockDim(1), *memset_kernel, *args); +#else + return command_buffer->Memset(execution_scope_id, &dst, value, + /*num_elements=*/1); +#endif // GOOGLE_CUDA } //===----------------------------------------------------------------------===// From 3f979c5b9e78106530bb425d895251d560cf444a Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Mon, 18 Mar 2024 11:42:46 -0700 Subject: [PATCH 032/670] Fix build breakage for //tensorflow/lite/delegates/flex:util_test. PiperOrigin-RevId: 616898445 --- tensorflow/lite/delegates/flex/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD index 5f126f68124cf8..da620e2e011019 100644 --- a/tensorflow/lite/delegates/flex/BUILD +++ b/tensorflow/lite/delegates/flex/BUILD @@ -342,6 +342,7 @@ tf_cc_test( srcs = ["util_test.cc"], deps = [ ":util", + "//tensorflow/c:tf_datatype", "//tensorflow/core:framework", "//tensorflow/core/protobuf:error_codes_proto_impl_cc", "//tensorflow/lite:string", From d817b4582e48e66b947882c5bfe407ae763902ef Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 18 Mar 2024 13:01:18 -0700 Subject: [PATCH 033/670] [xla::ffi] Forked Pointer to xla/ffi/api/ffi.h It is useful in both "internal" and "external" FFI versions. PiperOrigin-RevId: 616921396 --- third_party/xla/xla/ffi/api/ffi.h | 25 ++++++++++++++++++++++++ third_party/xla/xla/ffi/api/ffi_test.cc | 26 +++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index f6d742656a2a79..6c281e4878f960 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -243,6 +243,31 @@ struct ArgDecoding> { } }; +//===----------------------------------------------------------------------===// +// Attributes decoding +//===----------------------------------------------------------------------===// + +// A type tag to mark i64 attributes as pointers to `T`. +template +struct Pointer {}; + +template +struct AttrDecoding> { + using Type = T*; + + static std::optional Decode(XLA_FFI_AttrType type, void* attr, + DiagnosticEngine& diagnostic) { + if (type != XLA_FFI_AttrType_I64) { + return diagnostic.Emit("Wrong attribute type: ") + << "expected i64 for passing user data but got " << type; + } + + static_assert(sizeof(uintptr_t) == sizeof(int64_t)); + uintptr_t ptr = *reinterpret_cast(attr); + return reinterpret_cast(ptr); + } +}; + //===----------------------------------------------------------------------===// // Result encoding //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index af720f2686370c..53ef32daa7283c 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -231,6 +231,32 @@ TEST(FfiTest, BindingPlatformStreamInference) { (void)Ffi::BindTo(+[](TestStream stream) { return Error::Success(); }); } +TEST(FfiTest, PointerAttr) { + std::string foo = "foo"; + + // Test for convenience attr binding that casts i64 attribute to user-type + // pointers. It's up to the user to guarantee that pointer is valid. 
+ auto ptr = reinterpret_cast(&foo); + static_assert(sizeof(ptr) == sizeof(int64_t)); + + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert("ptr", static_cast(ptr)); + + CallFrameBuilder builder; + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + auto fn = [&](const std::string* str) { + EXPECT_EQ(*str, "foo"); + return Error::Success(); + }; + + auto handler = Ffi::Bind().Attr>("ptr").To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + //===----------------------------------------------------------------------===// // Performance benchmarks are below. //===----------------------------------------------------------------------===// From 013c4759ca4b40cb1281fb19473c1d5a3f6d18cd Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Mon, 18 Mar 2024 13:01:32 -0700 Subject: [PATCH 034/670] Fix conversions between ShardingParam and HloSharding. This change ensures that conversions work correctly for meshes with more than 2 axis, and adds additional tests for conversions between HloSharding to ShardingParam, ShardingParam to OpSharding, and ShardingParam to HloSharding. PiperOrigin-RevId: 616921476 --- third_party/xla/xla/python/ifrt/ir/BUILD | 3 + .../xla/xla/python/ifrt/ir/sharding_param.cc | 77 +++++++--- .../xla/xla/python/ifrt/ir/sharding_param.h | 3 + .../python/ifrt/ir/tests/verify_array.mlir | 2 +- .../python/ifrt/ir/tests/verify_reshard.mlir | 4 +- third_party/xla/xla/python/ifrt/support/BUILD | 1 + .../ifrt/support/sharding_conversions.cc | 27 +++- .../ifrt/support/sharding_conversions_test.cc | 145 +++++++++++++++--- 8 files changed, 210 insertions(+), 52 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/ir/BUILD b/third_party/xla/xla/python/ifrt/ir/BUILD index e81f592e48baff..5f99830ce7703f 100644 --- a/third_party/xla/xla/python/ifrt/ir/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/BUILD @@ -134,10 +134,13 @@ cc_library( ":ifrt_dialect_inc_gen", ":ifrt_interfaces_inc_gen", ":ifrt_ops_inc_gen", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", ], ) diff --git a/third_party/xla/xla/python/ifrt/ir/sharding_param.cc b/third_party/xla/xla/python/ifrt/ir/sharding_param.cc index 68c9e472901385..d8b36fb5d72d87 100644 --- a/third_party/xla/xla/python/ifrt/ir/sharding_param.cc +++ b/third_party/xla/xla/python/ifrt/ir/sharding_param.cc @@ -18,14 +18,20 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tsl/platform/errors.h" namespace xla { namespace ifrt { @@ -66,14 +72,37 @@ void PopulateDevices(llvm::ArrayRef permutation, } // namespace +absl::Status ShardingParam::MinorToMajor::verify() const { + if (permutation.size() != axis_sizes.size() || axis_sizes.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Expect same non-zero size for `permutation` and `axis_sizes`. 
Actual ", + permutation.size(), " vs ", axis_sizes.size())); + } + llvm::DenseSet permutation_set(permutation.begin(), permutation.end()); + if (permutation_set.size() != permutation.size()) { + return absl::InvalidArgumentError( + absl::StrCat("`permutation` [", absl::StrJoin(permutation, ","), + "] has duplicate values")); + } + for (const int index : permutation) { + if (index < 0 || index >= axis_sizes.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Out of range axis ", index, " to the mesh of [", + absl::StrJoin(permutation, ","), "] on ", + absl::StrJoin(axis_sizes, "x"))); + } + } + return absl::OkStatus(); +} + mlir::LogicalResult ShardingParam::MinorToMajor::verify( llvm::function_ref emit_error) const { - if (permutation.size() != axis_sizes.size() || axis_sizes.empty()) { - return emit_error() << "Expect same non-zero size for `permutation` and " - "`axis_sizes`. Actual " - << permutation.size() << " vs " << axis_sizes.size(); + auto status = verify(); + if (status.ok()) { + return mlir::success(); + } else { + return emit_error() << status.message(); } - return mlir::success(); } void ShardingParam::MinorToMajor::ToDeviceList( @@ -120,12 +149,8 @@ mlir::FailureOr ShardingParam::Parse( return ShardingParam(dim_shards, minor_to_major); } -mlir::LogicalResult ShardingParam::verify( - llvm::function_ref emit_error) const { - if (mlir::failed(minor_to_major().verify(emit_error))) { - return mlir::failure(); - } - +absl::Status ShardingParam::verify() const { + TF_RETURN_IF_ERROR(minor_to_major().verify()); int dim_index = 0; int cum_size = 1; for (const int index : minor_to_major().permutation) { @@ -135,17 +160,11 @@ mlir::LogicalResult ShardingParam::verify( if (dim_index == dim_shards().size()) { break; } - if (index < 0 || index >= minor_to_major().axis_sizes.size()) { - return emit_error() << "Out of range axis " << index << " to the mesh of " - << minor_to_major().permutation << " on " - << minor_to_major().axis_sizes; - } - cum_size *= minor_to_major().axis_sizes[index]; if (cum_size > dim_shards()[dim_index]) { - return emit_error() << "Dimension #" << dim_index << " of " - << dim_shards()[dim_index] - << " shards can't be assigned to the axes"; + return absl::InvalidArgumentError(absl::StrCat( + "Dimension #", dim_index, " of ", dim_shards()[dim_index], + " shards can't be assigned to the axes")); } else if (cum_size == dim_shards()[dim_index]) { cum_size = 1; dim_index++; @@ -155,12 +174,22 @@ mlir::LogicalResult ShardingParam::verify( dim_index++; } if (dim_index != dim_shards().size()) { - return emit_error() << "Can't shard the dims " << dim_shards() - << " to the mesh of " << minor_to_major().permutation - << " on " << minor_to_major().axis_sizes; + return absl::InvalidArgumentError(absl::StrCat( + "Can't shard the dims ", absl::StrJoin(dim_shards(), "x"), + " to the mesh of [", absl::StrJoin(minor_to_major().permutation, ","), + "] on ", absl::StrJoin(minor_to_major().axis_sizes, "x"))); } + return absl::OkStatus(); +} - return mlir::success(); +mlir::LogicalResult ShardingParam::verify( + llvm::function_ref emit_error) const { + auto status = verify(); + if (status.ok()) { + return mlir::success(); + } else { + return emit_error() << status.message(); + } } std::string ShardingParam::DebugString() const { diff --git a/third_party/xla/xla/python/ifrt/ir/sharding_param.h b/third_party/xla/xla/python/ifrt/ir/sharding_param.h index 5388f860b53ee7..13de6a96e9dcb5 100644 --- a/third_party/xla/xla/python/ifrt/ir/sharding_param.h +++ 
b/third_party/xla/xla/python/ifrt/ir/sharding_param.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLFunctionalExtras.h" @@ -79,6 +80,7 @@ class ShardingParam { // The size of mesh dimensions before the permutation. llvm::SmallVector axis_sizes; + absl::Status verify() const; mlir::LogicalResult verify( llvm::function_ref emit_error) const; @@ -94,6 +96,7 @@ class ShardingParam { : dim_shards_(dim_shards), minor_to_major_(minor_to_major) {} static mlir::FailureOr Parse(mlir::AsmParser& ods_parser); + absl::Status verify() const; mlir::LogicalResult verify( llvm::function_ref emit_error) const; diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_array.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_array.mlir index 339d351958d3e2..81b557bf28d5e9 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/verify_array.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_array.mlir @@ -64,7 +64,7 @@ func.func @array_requires_same_permutation_and_axis_sizes() { // ----- func.func @array_requires_enough_devices() { - // expected-error@+2 {{Can't shard the dims 2, 2 to the mesh of 0 on 2}} + // expected-error@+2 {{Can't shard the dims 2x2 to the mesh of [0] on 2}} %0 = builtin.unrealized_conversion_cast to !ifrt.array, 2x2 to [0] on 2, [0,1]> return diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_reshard.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_reshard.mlir index cc8370e81f9ad1..a34af467efe6a7 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/verify_reshard.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_reshard.mlir @@ -47,7 +47,7 @@ func.func @reshard_requires_same_global_shape( func.func @reshard_requires_non_negative_axis_index( %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) attributes {ifrt.function} { - // expected-error@+3 {{Out of range axis -1 to the mesh of -1 on 2}} + // expected-error@+3 {{Out of range axis -1 to the mesh of [-1] on 2}} %0 = ifrt.Reshard(%arg0) : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) -> !ifrt.array, 1x2 to [-1] on 2, [2,3]> @@ -59,7 +59,7 @@ func.func @reshard_requires_non_negative_axis_index( func.func @reshard_requires_valid_axis_index( %arg0: !ifrt.array, 1x1 to [0] on 2, [0,1]>) attributes {ifrt.function} { - // expected-error@+3 {{Out of range axis 1234567890 to the mesh of 1234567890 on 2}} + // expected-error@+3 {{Out of range axis 1234567890 to the mesh of [1234567890] on 2}} %0 = ifrt.Reshard(%arg0) : (!ifrt.array, 1x1 to [0] on 2, [0,1]>) -> !ifrt.array, 1x2 to [1234567890] on 2, [2,3]> diff --git a/third_party/xla/xla/python/ifrt/support/BUILD b/third_party/xla/xla/python/ifrt/support/BUILD index f0405ba8ca8783..33907fd34f3d13 100644 --- a/third_party/xla/xla/python/ifrt/support/BUILD +++ b/third_party/xla/xla/python/ifrt/support/BUILD @@ -40,6 +40,7 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/python/ifrt/support/sharding_conversions.cc b/third_party/xla/xla/python/ifrt/support/sharding_conversions.cc index fc315091c8c18e..1c6a2b3f6f5a73 100644 --- a/third_party/xla/xla/python/ifrt/support/sharding_conversions.cc +++ b/third_party/xla/xla/python/ifrt/support/sharding_conversions.cc @@ -106,14 +106,21 @@ absl::StatusOr 
ToHloSharding(const ShardingParam& sharding_param) { cum_size *= dim_shard; dims.push_back(dim_shard); } + // Applies the inverse of the transposes from `ToShardingParam`. + llvm::SmallVector permutation; + int num_axis = sharding_param.minor_to_major().permutation.size(); + permutation.reserve(num_axis); + for (const int axis_id : + llvm::reverse(sharding_param.minor_to_major().permutation)) { + permutation.push_back(num_axis - axis_id - 1); + } if (device_count != cum_size) { // Add the replicated dimension. dims.push_back(device_count / cum_size); - return HloSharding::PartialTile(TileAssignment( - dims, reshape_dims, sharding_param.minor_to_major().permutation)); + return HloSharding::PartialTile( + TileAssignment(dims, reshape_dims, permutation)); } else { - return HloSharding::IotaTile(dims, reshape_dims, - sharding_param.minor_to_major().permutation); + return HloSharding::IotaTile(dims, reshape_dims, permutation); } } @@ -175,8 +182,16 @@ absl::StatusOr ToShardingParam(const HloSharding& hlo_sharding, llvm::reverse(tile_assignment.iota()->reshape_dims())) { minor_to_major.axis_sizes.push_back(reshape_dim); } - for (int axis_id : tile_assignment.iota()->transpose_perm()) { - minor_to_major.permutation.push_back(axis_id); + // The devices generated by HloSharding + // np.arange(ndevices).reshape(reshape_dims).transpose(transpose_perm) + // must be equal to the devices ShardingParam + // np.arange(ndevices).reshape(reverse(axis_size)).T.transpose(perm).T + // Step 1: Compute transpose(transpose_perm).T. + // Step 2: Compute T.transpose(transpose_perm).T. + int num_axis = tile_assignment.iota()->transpose_perm().size(); + for (int axis_id : + llvm::reverse(tile_assignment.iota()->transpose_perm())) { + minor_to_major.permutation.push_back(num_axis - axis_id - 1); } } return ShardingParam(dim_shards, std::move(minor_to_major)); diff --git a/third_party/xla/xla/python/ifrt/support/sharding_conversions_test.cc b/third_party/xla/xla/python/ifrt/support/sharding_conversions_test.cc index 4f9cdd2f6ffe7b..22b213ff7c2d7d 100644 --- a/third_party/xla/xla/python/ifrt/support/sharding_conversions_test.cc +++ b/third_party/xla/xla/python/ifrt/support/sharding_conversions_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/python/ifrt/support/sharding_conversions.h" #include +#include #include #include @@ -35,6 +36,7 @@ limitations under the License. 
#include "xla/python/ifrt/sharding_test_util.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" +#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -44,22 +46,24 @@ namespace support { namespace { using ::tsl::testing::StatusIs; +using xla::HloSharding; -absl::StatusOr ToHloShardingViaOpSharding( +absl::StatusOr ToHloShardingViaOpSharding( const ShardingParam& sharding_param, absl::Span device_list) { TF_ASSIGN_OR_RETURN(xla::OpSharding op_sharding, ToOpSharding(sharding_param, device_list)); - return xla::HloSharding::FromProto(op_sharding); + return HloSharding::FromProto(op_sharding); } TEST(ShardingConversionsTest, Replicated) { ShardingParam expected_sharding_param{ /*dim_shards=*/{1, 1, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_iota_sharding, + TF_EXPECT_OK(expected_sharding_param.verify()); + TF_ASSERT_OK_AND_ASSIGN(const HloSharding hlo_iota_sharding, ToHloSharding(expected_sharding_param)); TF_ASSERT_OK_AND_ASSIGN( - const xla::HloSharding hlo_sharding, + const HloSharding hlo_sharding, ToHloShardingViaOpSharding(expected_sharding_param, {0, 1, 2, 3, 4, 5})); EXPECT_EQ(hlo_sharding.ToString(), "{replicated}"); EXPECT_EQ(hlo_sharding, hlo_iota_sharding); @@ -67,7 +71,7 @@ TEST(ShardingConversionsTest, Replicated) { ToShardingParam(hlo_iota_sharding, 3, 6)); // We do not compare expected_sharding_param and sharding_param because they // haven't been canonicalized (1x1x1 to [0, 1] on 2x3 vs. 1x1x1 to [0] on 6). - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding actual_hlo_sharding, + TF_ASSERT_OK_AND_ASSIGN(const HloSharding actual_hlo_sharding, ToHloSharding(sharding_param)); EXPECT_EQ(hlo_iota_sharding, actual_hlo_sharding); } @@ -75,10 +79,11 @@ TEST(ShardingConversionsTest, Replicated) { TEST(ShardingConversionsTest, SingleDeviceReplicated) { ShardingParam expected_sharding_param{ /*dim_shards=*/{1, 1}, {/*permutation=*/{0}, /*axis_sizes=*/{1}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_iota_sharding, + TF_EXPECT_OK(expected_sharding_param.verify()); + TF_ASSERT_OK_AND_ASSIGN(const HloSharding hlo_iota_sharding, ToHloSharding(expected_sharding_param)); TF_ASSERT_OK_AND_ASSIGN( - const xla::HloSharding hlo_sharding, + const HloSharding hlo_sharding, ToHloShardingViaOpSharding(expected_sharding_param, {0})); EXPECT_EQ(hlo_sharding.ToString(), "{replicated}"); EXPECT_EQ(hlo_sharding, hlo_iota_sharding); @@ -91,10 +96,11 @@ TEST(ShardingConversionsTest, Permutation) { ShardingParam expected_sharding_param{ /*dim_shards=*/{2, 1, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_iota_sharding, + TF_EXPECT_OK(expected_sharding_param.verify()); + TF_ASSERT_OK_AND_ASSIGN(const HloSharding hlo_iota_sharding, ToHloSharding(expected_sharding_param)); TF_ASSERT_OK_AND_ASSIGN( - const xla::HloSharding hlo_sharding, + const HloSharding hlo_sharding, ToHloShardingViaOpSharding(expected_sharding_param, {0, 1, 2, 3, 4, 5})); EXPECT_EQ(hlo_sharding.ToString(), "{devices=[2,1,3]0,3,1,4,2,5}"); EXPECT_EQ(hlo_sharding, hlo_iota_sharding); @@ -106,10 +112,11 @@ TEST(ShardingConversionsTest, Permutation) { TEST(ShardingConversionsTest, Partial) { ShardingParam expected_sharding_param{ /*dim_shards=*/{2, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_iota_sharding, + TF_EXPECT_OK(expected_sharding_param.verify()); + 
TF_ASSERT_OK_AND_ASSIGN(const HloSharding hlo_iota_sharding, ToHloSharding(expected_sharding_param)); TF_ASSERT_OK_AND_ASSIGN( - const xla::HloSharding hlo_sharding, + const HloSharding hlo_sharding, ToHloShardingViaOpSharding(expected_sharding_param, {0, 1, 2, 3, 4, 5})); EXPECT_EQ(hlo_sharding.ToString(), "{devices=[2,1,3]0,1,2,3,4,5 last_tile_dim_replicate}"); @@ -118,7 +125,7 @@ TEST(ShardingConversionsTest, Partial) { ToShardingParam(hlo_iota_sharding, 2, 6)); // We do not compare expected_sharding_param and sharding_param because they // haven't been canonicalized (2x1 to [0, 1] on 2x3 vs. 2x1 to [0] on 6). - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding actual_hlo_sharding, + TF_ASSERT_OK_AND_ASSIGN(const HloSharding actual_hlo_sharding, ToHloSharding(sharding_param)); EXPECT_EQ(hlo_iota_sharding, actual_hlo_sharding); } @@ -126,10 +133,11 @@ TEST(ShardingConversionsTest, Partial) { TEST(ShardingConversionsTest, OneDimToTwoAxes) { ShardingParam expected_sharding_param{ /*dim_shards=*/{4}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{2, 2}}}; - TF_ASSERT_OK_AND_ASSIGN(const xla::HloSharding hlo_iota_sharding, + TF_EXPECT_OK(expected_sharding_param.verify()); + TF_ASSERT_OK_AND_ASSIGN(const HloSharding hlo_iota_sharding, ToHloSharding(expected_sharding_param)); TF_ASSERT_OK_AND_ASSIGN( - const xla::HloSharding hlo_sharding, + const HloSharding hlo_sharding, ToHloShardingViaOpSharding(expected_sharding_param, {0, 1, 2, 3})); EXPECT_EQ(hlo_sharding.ToString(), "{devices=[4]0,2,1,3}"); EXPECT_EQ(hlo_sharding, hlo_iota_sharding); @@ -142,21 +150,116 @@ TEST(ShardingConversionsTest, NonTrivialDeviceAssignment) { ShardingParam expected_sharding_param{ /*dim_shards=*/{2, 1, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; + TF_EXPECT_OK(expected_sharding_param.verify()); TF_ASSERT_OK_AND_ASSIGN( - const xla::HloSharding hlo_sharding, + const HloSharding hlo_sharding, ToHloShardingViaOpSharding(expected_sharding_param, {6, 5, 4, 3, 2, 1})); EXPECT_EQ(hlo_sharding.ToString(), "{devices=[2,1,3]6,3,5,2,4,1}"); } +TEST(ShardingConversionsTest, VerifyIncorrectShardings) { + ShardingParam different_permutation_and_axis{ + /*dim_shards=*/{1, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2}}}; + EXPECT_FALSE(different_permutation_and_axis.verify().ok()); + ShardingParam too_many_slices{/*dim_shards=*/{2, 2}, + {/*permutation=*/{0}, /*axis_sizes=*/{2}}}; + EXPECT_FALSE(too_many_slices.verify().ok()); + ShardingParam cannot_distribute_slices{ + /*dim_shards=*/{1, 2}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{3, 2}}}; + EXPECT_FALSE(cannot_distribute_slices.verify().ok()); + ShardingParam incorrect_permutation{ + /*dim_shards=*/{4, 1}, + {/*permutation=*/{0, 1, 1}, /*axis_sizes=*/{2, 2, 2}}}; + EXPECT_FALSE(incorrect_permutation.verify().ok()); +} + TEST(ShardingConversionsTest, ErrorOnDeviceAssignment) { ShardingParam sharding_param{/*dim_shards=*/{2, 1, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; + TF_EXPECT_OK(sharding_param.verify()); EXPECT_THAT( ToHloShardingViaOpSharding(sharding_param, {6, 5, 4, 3, 2}), StatusIs(absl::StatusCode::kOutOfRange, ::testing::HasSubstr("Can't map device with logical id 5"))); } +struct HloShardingTestStruct { + HloSharding hlo_sharding; + int rank; + int num_devices; +}; + +using HloShardingToShardingParamTest = + ::testing::TestWithParam; + +TEST_P(HloShardingToShardingParamTest, HloShardingToShardingParam) { + const auto& param = GetParam(); + TF_ASSERT_OK_AND_ASSIGN( + auto sharding_param, + ToShardingParam(param.hlo_sharding, param.rank, 
param.num_devices)); + // We cannot verify sharding param because we're losing info about the + // axis_size during these conversions. While strictly some ShardingParam + // are invalid because they have more dims than axis, in practice this is not + // a problem because we can still correctly map the shards to the devices. + TF_ASSERT_OK_AND_ASSIGN(auto actual_hlo_sharding, + ToHloSharding(sharding_param)); + EXPECT_EQ(param.hlo_sharding, actual_hlo_sharding); + // Verify that the conversion to OpSharding is also correct. + std::vector device_ids(param.num_devices); + std::iota(device_ids.begin(), device_ids.end(), 0); + TF_ASSERT_OK_AND_ASSIGN( + auto hlo_via_op_sharding, + ToHloShardingViaOpSharding(sharding_param, device_ids)); + EXPECT_EQ(param.hlo_sharding, hlo_via_op_sharding); +} + +INSTANTIATE_TEST_SUITE_P( + HloShardingConversionTests, HloShardingToShardingParamTest, + testing::ValuesIn({ + {HloSharding::IotaTile({4, 2}), 2, 8}, + {HloSharding::IotaTile({2, 4}, {4, 2}, {1, 0}), 2, 8}, + {HloSharding::IotaTile({8, 1}), 2, 8}, + {HloSharding::IotaTile({8, 1}, {4, 2}, {1, 0}), 2, 8}, + {HloSharding::PartialTile(TileAssignment({4, 1, 2}, {8}, {0})), 2, 8}, + {HloSharding::PartialTile(TileAssignment({2, 1, 4}, {4, 2}, {1, 0})), 2, + 8}, + {HloSharding::PartialTile(TileAssignment({1, 4, 2}, {8}, {0})), 2, 8}, + {HloSharding::PartialTile(TileAssignment({1, 2, 4}, {4, 2}, {1, 0})), 2, + 8}, + {HloSharding::PartialTile(TileAssignment({4, 3, 2}, {2, 3, 4}, + {2, 1, 0})), + 2, 24}, + {HloSharding::PartialTile(TileAssignment({4, 2, 3}, {6, 4}, {1, 0})), 2, + 24}, + {HloSharding::PartialTile(TileAssignment({6, 1, 4}, {24}, {0})), 2, 24}, + {HloSharding::PartialTile(TileAssignment({12, 1, 2}, {2, 12}, {1, 0})), + 2, 24}, + {HloSharding::PartialTile(TileAssignment({8, 1, 3}, {6, 4}, {1, 0})), 2, + 24}, + {HloSharding::PartialTile(TileAssignment({2, 1, 12}, {24}, {0})), 2, + 24}, + {HloSharding::PartialTile(TileAssignment({3, 1, 8}, {2, 3, 4}, + {1, 0, 2})), + 2, 24}, + {HloSharding::PartialTile(TileAssignment({1, 4, 6}, {6, 4}, {1, 0})), 2, + 24}, + {HloSharding::PartialTile(TileAssignment({1, 12, 2}, {2, 12}, {1, 0})), + 2, 24}, + + {HloSharding::PartialTile(TileAssignment({3, 2, 1, 4}, {2, 3, 4}, + {1, 0, 2})), + 3, 24}, + {HloSharding::PartialTile(TileAssignment({2, 4, 1, 3}, {2, 3, 4}, + {0, 2, 1})), + 3, 24}, + {HloSharding::PartialTile(TileAssignment({4, 3, 1, 2}, {2, 3, 4}, + {2, 1, 0})), + 3, 24}, + {HloSharding::PartialTile(TileAssignment({12, 1, 1, 2}, {2, 12}, + {1, 0})), + 3, 24}, + })); + class ShardingConversionsEquivalentTest : public test_util::ShardingTest { public: void AssertSameTiling(const ShardingParam& sharding_param, @@ -187,8 +290,9 @@ class ShardingConversionsEquivalentTest : public test_util::ShardingTest { TEST_P(ShardingConversionsEquivalentTest, ShardingParamFullySharded) { ShardingParam sharding_param{/*dim_shards=*/{2, 3}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; + TF_EXPECT_OK(sharding_param.verify()); TF_ASSERT_OK_AND_ASSIGN( - const xla::HloSharding hlo_sharding, + const HloSharding hlo_sharding, ToHloShardingViaOpSharding(sharding_param, {0, 1, 2, 3, 4, 5})); AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); } @@ -196,8 +300,9 @@ TEST_P(ShardingConversionsEquivalentTest, ShardingParamFullySharded) { TEST_P(ShardingConversionsEquivalentTest, ShardingParamWithPermutation) { ShardingParam sharding_param{/*dim_shards=*/{2, 3}, {/*permutation=*/{1, 0}, /*axis_sizes=*/{3, 2}}}; + TF_EXPECT_OK(sharding_param.verify()); TF_ASSERT_OK_AND_ASSIGN( 
- const xla::HloSharding hlo_sharding, + const HloSharding hlo_sharding, ToHloShardingViaOpSharding(sharding_param, {0, 1, 2, 3, 4, 5})); AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); } @@ -205,8 +310,9 @@ TEST_P(ShardingConversionsEquivalentTest, ShardingParamWithPermutation) { TEST_P(ShardingConversionsEquivalentTest, ShardingParamWithReplication) { ShardingParam sharding_param{/*dim_shards=*/{2, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 3}}}; + TF_EXPECT_OK(sharding_param.verify()); TF_ASSERT_OK_AND_ASSIGN( - const xla::HloSharding hlo_sharding, + const HloSharding hlo_sharding, ToHloShardingViaOpSharding(sharding_param, {0, 1, 2, 3, 4, 5})); AssertSameTiling(sharding_param, hlo_sharding, Shape({6, 6})); } @@ -215,10 +321,11 @@ TEST_P(ShardingConversionsEquivalentTest, OpShardingReplicated) { OpSharding op_sharding; op_sharding.set_type(OpSharding::REPLICATED); TF_ASSERT_OK_AND_ASSIGN(auto hlo_sharding, - xla::HloSharding::FromProto(op_sharding)); + HloSharding::FromProto(op_sharding)); TF_ASSERT_OK_AND_ASSIGN(auto actual, ToShardingParam(hlo_sharding, 2, 6)); ShardingParam expected{/*dim_shards=*/{1, 1}, {/*permutation=*/{0}, /*axis_sizes=*/{6}}}; + TF_EXPECT_OK(expected.verify()); EXPECT_EQ(actual, expected); } From 531dc955df01e8d7b259f336f1a276e92a407871 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 18 Mar 2024 13:35:32 -0700 Subject: [PATCH 035/670] Make mock_nccl_utils.cc compile. PiperOrigin-RevId: 616932122 --- third_party/xla/xla/service/gpu/mock_nccl_utils.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/mock_nccl_utils.cc b/third_party/xla/xla/service/gpu/mock_nccl_utils.cc index bca0bc01bcef3f..56782e9e6d777a 100644 --- a/third_party/xla/xla/service/gpu/mock_nccl_utils.cc +++ b/third_party/xla/xla/service/gpu/mock_nccl_utils.cc @@ -53,7 +53,6 @@ limitations under the License. #include "third_party/gpus/nccl/include/info.h" #include "third_party/gpus/nccl/include/nccl_common.h" #include "third_party/nccl/nccl.h" -#include "third_party/gpus/nccl/src/include/device.h" #include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" #include "xla/primitive_util.h" From ad6df1bba30e0f580006ddef18c2f8e1d81e9412 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 18 Mar 2024 13:37:12 -0700 Subject: [PATCH 036/670] Deduplicate BuildAttributesMap code. PiperOrigin-RevId: 616932606 --- .../xla/xla/service/gpu/fusions/custom.cc | 52 ------------------- .../xla/service/gpu/ir_emitter_unnested.cc | 51 ------------------ third_party/xla/xla/service/gpu/runtime/BUILD | 3 +- .../service/gpu/runtime/custom_call_thunk.cc | 51 ++++++++++++++++++ .../service/gpu/runtime/custom_call_thunk.h | 5 ++ 5 files changed, 58 insertions(+), 104 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index fa910bc2589cf6..8027bd69756a3d 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -79,58 +79,6 @@ absl::StatusOr> BuildCustomKernelThunkForFusion( &fusion, std::move(custom_kernel), std::move(kernel_arguments.args())); } -// TODO(vuson): this is duplicated from ir_emitter_unnested.cc -// Converts MLIR dictionary attribute attached to a custom call operation to a -// custom call thunk attributes that are forwarded to the FFI handler. 
-static absl::StatusOr BuildAttributesMap( - mlir::DictionaryAttr dict) { - CustomCallThunk::AttributesMap attributes; - for (auto& kv : dict) { - std::string_view name = kv.getName().strref(); - - auto integer = [&](mlir::IntegerAttr integer) { - switch (integer.getType().getIntOrFloatBitWidth()) { - case 32: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 64: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported integer attribute bit width for attribute: ", name)); - } - }; - - auto fp = [&](mlir::FloatAttr fp) { - switch (fp.getType().getIntOrFloatBitWidth()) { - case 32: - attributes[name] = static_cast(fp.getValue().convertToFloat()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported float attribute bit width for attribute: ", name)); - } - }; - - auto str = [&](mlir::StringAttr str) { - attributes[name] = str.getValue().str(); - return absl::OkStatus(); - }; - - TF_RETURN_IF_ERROR( - llvm::TypeSwitch(kv.getValue()) - .Case(integer) - .Case(fp) - .Case(str) - .Default([&](mlir::Attribute) { - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported attribute type for attribute: ", name)); - })); - } - return attributes; -} - absl::StatusOr GetSliceWithUpdatedOffsetAndSize( const BufferAssignment& buffer_assignment, const HloFusionAdaptor& fusion, const HloInstruction& fusion_instr, const HloInstruction& start, diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 1826b5723b27f0..79eca88e8f96ea 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -1303,57 +1303,6 @@ absl::Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) { } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -// Converts MLIR dictionary attribute attached to a custom call operation to a -// custom call thunk attributes that are forwarded to the FFI handler. 
-static absl::StatusOr BuildAttributesMap( - mlir::DictionaryAttr dict) { - CustomCallThunk::AttributesMap attributes; - for (auto& kv : dict) { - std::string_view name = kv.getName().strref(); - - auto integer = [&](mlir::IntegerAttr integer) { - switch (integer.getType().getIntOrFloatBitWidth()) { - case 32: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 64: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported integer attribute bit width for attribute: ", name)); - } - }; - - auto fp = [&](mlir::FloatAttr fp) { - switch (fp.getType().getIntOrFloatBitWidth()) { - case 32: - attributes[name] = static_cast(fp.getValue().convertToFloat()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported float attribute bit width for attribute: ", name)); - } - }; - - auto str = [&](mlir::StringAttr str) { - attributes[name] = str.getValue().str(); - return absl::OkStatus(); - }; - - TF_RETURN_IF_ERROR( - llvm::TypeSwitch(kv.getValue()) - .Case(integer) - .Case(fp) - .Case(str) - .Default([&](mlir::Attribute) { - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported attribute type for attribute: ", name)); - })); - } - return attributes; -} - absl::Status IrEmitterUnnested::EmitCustomCallThunk( const HloCustomCallInstruction* instr) { const std::string call_target_name = instr->custom_call_target(); diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index e1ce6bd1e4c04a..df151764e4da79 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -470,10 +470,11 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/gpu:gpu_types_header", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:errors", ], ) diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc index 28a7dcebfc1dfa..0edf3b7c9dced4 100644 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "llvm/ADT/TypeSwitch.h" #include "xla/executable_run_options.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/call_frame.h" @@ -35,6 +36,7 @@ limitations under the License. #include "xla/status.h" #include "xla/stream_executor/device_memory.h" #include "xla/util.h" +#include "tsl/platform/errors.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/stream_executor/gpu/gpu_stream.h" @@ -149,5 +151,54 @@ absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) { return handler_ ? 
ExecuteFfiHandler(params) : ExecuteCustomCall(params); } +absl::StatusOr BuildAttributesMap( + mlir::DictionaryAttr dict) { + CustomCallThunk::AttributesMap attributes; + for (auto& kv : dict) { + std::string_view name = kv.getName().strref(); + + auto integer = [&](mlir::IntegerAttr integer) { + switch (integer.getType().getIntOrFloatBitWidth()) { + case 32: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + case 64: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported integer attribute bit width for attribute: ", name)); + } + }; + + auto fp = [&](mlir::FloatAttr fp) { + switch (fp.getType().getIntOrFloatBitWidth()) { + case 32: + attributes[name] = static_cast(fp.getValue().convertToFloat()); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported float attribute bit width for attribute: ", name)); + } + }; + + auto str = [&](mlir::StringAttr str) { + attributes[name] = str.getValue().str(); + return absl::OkStatus(); + }; + + TF_RETURN_IF_ERROR( + llvm::TypeSwitch(kv.getValue()) + .Case(integer) + .Case(fp) + .Case(str) + .Default([&](mlir::Attribute) { + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported attribute type for attribute: ", name)); + })); + } + return attributes; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h index dd445a248935e0..12d62c67c9af09 100644 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h @@ -120,6 +120,11 @@ class CustomCallThunk : public Thunk { const HloComputation* called_computation_ = nullptr; }; +// Converts MLIR dictionary attribute attached to a custom call operation to a +// custom call thunk attributes that are forwarded to the FFI handler. +absl::StatusOr BuildAttributesMap( + mlir::DictionaryAttr dict); + } // namespace gpu } // namespace xla From b960be62b93ba5f14a7073c0d84d79a11d615fde Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 18 Mar 2024 13:37:15 -0700 Subject: [PATCH 037/670] Call DropAllControlDeps before removing an instruction from the computation during RematerializeInstructions, or RemoveInstruction will fail. 
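HloComputation::RemoveInstruction RET_CHECKs IsSafelyRemovable(), and that check
returns false while an instruction still has control predecessors or successors
attached. Since RematerializeInstructions has already copied the control edges
onto the rematerialized clone, the old instruction's edges can simply be dropped
before removal. Roughly (sketch of the required ordering; `best` is the
instruction being replaced by its remat clone):

    TF_RETURN_IF_ERROR(best->DropAllControlDeps());
    TF_RETURN_IF_ERROR(computation->RemoveInstruction(best));
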
Before this change, hlo-opt was failing with this stack: INTERNAL: RET_CHECK failure (third_party/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc:426) ignore_safety_check || IsSafelyRemovable(instruction) cannot remove instruction: %all-reduce-start.285 = (s32[8,64]{1,0}, s32[], s32[], s32[], s32[], /*index=5*/s32[], s32[], s32[], s32[], s32[], /*index=10*/s32[], s32[], s32[], s32[], s32[], /*index=15*/s32[], s32[], s32[], s32[], s32[]) all-reduce-start(s32[8,64]{1,0} %input_scatter_fusion.16, s32[] %convert.81230.0.remat, s32[] %copy.13646, s32[] %copy.13647, s32[] %copy.13648, /*index=5*/s32[] %copy.13649, s32[] %copy.13650, s32[] %copy.13651, s32[] %copy.13652, s32[] %copy.13653, /*index=10*/s32[] %copy.13654, s32[] %copy.13655, s32[] %copy.13656, s32[] %copy.13657, s32[] %copy.13658, /*index=15*/s32[] %copy.13659, s32[] %copy.13660, s32[] %copy.13661, s32[] %copy.13662, s32[] %copy.13663), channel_id=711, replica_groups={{0,4,8,12,16,20,24,28,32,36,40,44,48,52,56,60,64,68,72,76,80,84,88,92,96,100,104,108,112,116,120,124},{1,5,9,13,17,21,25,29,33,37,41,45,49,53,57,61,65,69,73,77,81,85,89,93,97,101,105,109,113,117,121,125},{2,6,10,14,18,22,26,30,34,38,42,46,50,54,58,62,66,70,74,78,82,86,90,94,98,102,106,110,114,118,122,126},{3,7,11,15,19,23,27,31,35,39,43,47,51,55,59,63,67,71,75,79,83,87,91,95,99,103,107,111,115,119,123,127}}, use_global_device_ids=true, to_apply=%region_597.10276.clone.1, control-predecessors={%copy.14013, %copy.14014, %copy.14015, %copy.14016, %copy.14017, %copy.14018, %copy.14019, %copy.14020, %copy.14021, %copy.14022, %copy.14023, %copy.14024, %copy.14025, %copy.14026, %copy.14027, %copy.14028, %copy.14029, %copy.14030, %copy.14031, %copy.14032, %copy.14033, %copy.14034, %copy.14035, %copy.14036, %copy.14037, %copy.14038, %copy.14039, %copy.14040, %copy.14041, %copy.14042}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":false,"no_parallel_custom_call":false}} === Source Location Trace: === third_party/tensorflow/compiler/xla/status_macros.cc:80 third_party/tensorflow/compiler/xla/service/hlo_rematerialization.cc:2165 third_party/tensorflow/compiler/xla/service/hlo_rematerialization.cc:2564 third_party/tensorflow/compiler/xla/service/hlo_rematerialization.cc:2708 third_party/tensorflow/compiler/xla/service/hlo_rematerialization.cc:2767 third_party/tensorflow/compiler/xla/service/hlo_rematerialization.cc:2905 third_party/tensorflow/compiler/xla/service/hlo_pass_pipeline.h:140 third_party/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc:185 third_party/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:2183 third_party/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:1923 third_party/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:2012 third_party/tensorflow/compiler/xla/tools/hlo_opt/opt_lib.cc:155 third_party/tensorflow/compiler/xla/tools/hlo_opt/opt_lib.cc:108 third_party/tensorflow/compiler/xla/tools/hlo_opt/gpu_opt.cc:80 third_party/tensorflow/compiler/xla/tools/hlo_opt/opt_main.cc:166 third_party/tensorflow/compiler/xla/tools/hlo_opt/opt_main.cc:177 After this change, hlo-opt fails with an expected failure due to no JAX: UNIMPLEMENTED: No registered implementation for custom call to cu_threefry2x32 for platform CUDA === Source Location Trace: === third_party/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc:1391 third_party/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc:3023 third_party/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc:2574 
third_party/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc:2015 third_party/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc:3023 third_party/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc:205 third_party/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:1949 third_party/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:2028 third_party/tensorflow/compiler/xla/tools/hlo_opt/opt_lib.cc:155 third_party/tensorflow/compiler/xla/tools/hlo_opt/opt_lib.cc:108 third_party/tensorflow/compiler/xla/tools/hlo_opt/gpu_opt.cc:80 third_party/tensorflow/compiler/xla/tools/hlo_opt/opt_main.cc:166 third_party/tensorflow/compiler/xla/tools/hlo_opt/opt_main.cc:177 PiperOrigin-RevId: 616932625 --- third_party/xla/xla/service/hlo_rematerialization.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/xla/xla/service/hlo_rematerialization.cc b/third_party/xla/xla/service/hlo_rematerialization.cc index 3197bb094b5f76..a524e3cca559dc 100644 --- a/third_party/xla/xla/service/hlo_rematerialization.cc +++ b/third_party/xla/xla/service/hlo_rematerialization.cc @@ -2162,6 +2162,10 @@ absl::StatusOr RematerializeInstructions( VLOG(2) << "The old instruction " << best->name() << " is an async op. Removing to maintain one start to one done " "invariant to keep the HLO valid."; + // We need to remove all control dependencies from best before removing it + // from the computation. Its control dependencies were previously copied + // to the remat instruction. + TF_RETURN_IF_ERROR(best->DropAllControlDeps()); TF_RETURN_IF_ERROR(computation->RemoveInstruction(best)); } } From 30488e7e5785733ec638fb1839ca75173ac97e5f Mon Sep 17 00:00:00 2001 From: "Jae H. Yoo" Date: Mon, 18 Mar 2024 13:48:08 -0700 Subject: [PATCH 038/670] Add flatbuffer export/import for bfloat16. 
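The exporter maps the MLIR bf16 element type to the BFLOAT16 flatbuffer tensor
type, the importer materializes bf16 constant buffers from their raw 16-bit
words, tfl.cast accepts bf16 operands, and the CAST builtin is bumped to
version 7 when either side is bfloat16. A minimal sketch of the constant path
only, assuming Eigen's bfloat16 helpers (the export direction shown here is an
illustration, not code from this patch):

    uint16_t bit_repr = 0x3F80;  // raw bits from the flatbuffer buffer (1.0 in bf16)
    Eigen::bfloat16 value = Eigen::numext::bit_cast<Eigen::bfloat16>(bit_repr);
    uint16_t back = Eigen::numext::bit_cast<uint16_t>(value);  // what gets serialized
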
PiperOrigin-RevId: 616935733 --- .../compiler/mlir/lite/flatbuffer_export.cc | 2 + tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 4 +- .../lite/tests/flatbuffer2mlir/cast_bf16.mlir | 12 +++ .../compiler/mlir/lite/tests/legalize-tf.mlir | 12 +++ .../lite/tests/mlir2flatbuffer/cast_bf16.mlir | 74 +++++++++++++++++++ .../mlir/lite/utils/const_tensor_utils.cc | 43 ++++++++--- .../lite/tools/versioning/op_version.cc | 7 +- 7 files changed, 138 insertions(+), 16 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/cast_bf16.mlir create mode 100644 tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/cast_bf16.mlir diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 1a9ff8016649ef..dd28efd44eab14 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -181,6 +181,8 @@ static StatusOr GetTFLiteType(Type type, return tflite::TensorType_FLOAT32; } else if (type.isF16()) { return tflite::TensorType_FLOAT16; + } else if (type.isBF16()) { + return tflite::TensorType_BFLOAT16; } else if (type.isF64()) { return tflite::TensorType_FLOAT64; } else if (type.isa()) { diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 55388c86dfc7bf..481f5573058b8c 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -3926,10 +3926,10 @@ def TFL_CastOp : TFL_Op<"cast", [ }]; let arguments = (ins - TFL_TensorOf<[F16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$input + TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$input ); - let results = (outs TFL_TensorOf<[F16, F32, F64, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$output); + let results = (outs TFL_TensorOf<[F16, BF16, F32, F64, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$output); // TFLite's cast op does not utilize CastOptions, instead derives types // from the TfLiteTensors. 
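For context on the comment above: with no CastOptions, the runtime kernel picks
the conversion purely from the tensor types, along the lines of this sketch
(GetInput/GetOutput are TFLite's kernel_util helpers; the surrounding
Prepare/Eval plumbing is omitted and assumed):

    const TfLiteTensor* input = GetInput(context, node, 0);
    TfLiteTensor* output = GetOutput(context, node, 0);
    // e.g. input->type == kTfLiteBFloat16 with output->type == kTfLiteFloat32
    // selects the bf16 -> f32 path; no operator options are consulted.
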
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/cast_bf16.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/cast_bf16.mlir new file mode 100644 index 00000000000000..56068d605016e7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/cast_bf16.mlir @@ -0,0 +1,12 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s +// Ensure cast with bfloat16 roundtrip exactly + +func.func @main(tensor<4x5xbf16>) -> tensor<4x5xbf16> { +^bb0(%arg0: tensor<4x5xbf16>): + // CHECK-LABEL: @main + // CHECK: (tensor<4x5xbf16>) -> tensor<4x5xf32> + // CHECK-NEXT: (tensor<4x5xf32>) -> tensor<4x5xbf16> + %0 = "tfl.cast" (%arg0) : (tensor<4x5xbf16>) -> tensor<4x5xf32> loc("cast1") + %1 = "tfl.cast" (%0) : (tensor<4x5xf32>) -> tensor<4x5xbf16> loc("cast2") + func.return %1 : tensor<4x5xbf16> +} diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 685efd5be0ca2d..a0b9f90a879507 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1875,6 +1875,18 @@ func.func @matmul_batchv3_unknown_dim(%arg0: tensor, %arg1: tensor< // CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor, tensor<15x17xf32>) -> tensor } +func.func @matmul_batchv3_unknown_dim_bf16(%arg0: tensor, %arg1: tensor<5x6xf32>) -> tensor { + %0 = "tf.Cast"(%arg0) : (tensor) -> tensor + %1 = "tf.BatchMatMulV3"(%0, %arg1) {Ta = "tfdtype$DT_FLOAT", Tb = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} : +(tensor, tensor<5x6xf32>) -> tensor + %2 = "tf.Cast"(%1) : (tensor) -> tensor + func.return %2 : tensor +// CHECK-LABEL: matmul_batchv3_unknown_dim_bf16 +// CHECK: [[CST:%.*]] = "tfl.cast"(%arg0) : (tensor) -> tensor +// CHECK: [[BMM:%.*]] = "tfl.batch_matmul"([[CST]], %arg1) {adj_x = false, adj_y = false} : (tensor, tensor<5x6xf32>) -> tensor +// CHECK: "tfl.cast"([[BMM]]) : (tensor) -> tensor +} + // ----- func.func @select_v2_with_6d_broadcasting(%arg0: tensor<1x1x1x1x3x1xi1>, %arg1 : tensor<1x1x1x1x1x4xf32>, %arg2 : tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/cast_bf16.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/cast_bf16.mlir new file mode 100644 index 00000000000000..83255ca39a4472 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/cast_bf16.mlir @@ -0,0 +1,74 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s + +func.func @main(tensor<4x5xbf16>) -> tensor<4x5xbf16> { +^bb0(%arg0: tensor<4x5xbf16>): + +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: deprecated_builtin_code: 53, +// CHECK-NEXT: version: 7, +// CHECK-NEXT: builtin_code: CAST +// CHECK-NEXT: } ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 4, 5 ], +// CHECK-NEXT: type: BFLOAT16, +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "arg0", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: has_rank: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4, 5 ], +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "cast1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: has_rank: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 
4, 5 ], +// CHECK-NEXT: type: BFLOAT16, +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "cast2", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: has_rank: true +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0 ], +// CHECK-NEXT: outputs: [ 2 ], +// CHECK-NEXT: operators: [ { +// CHECK-NEXT: inputs: [ 0 ], +// CHECK-NEXT: outputs: [ 1 ] +// CHECK-NEXT: }, { +// CHECK-NEXT: inputs: [ 1 ], +// CHECK-NEXT: outputs: [ 2 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", +// CHECK-NEXT: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 4 +// CHECK-NEXT: } ], +// CHECK-NEXT: signature_defs: [ ] +// CHECK-NEXT: } + + %0 = "tfl.cast" (%arg0) : (tensor<4x5xbf16>) -> tensor<4x5xf32> loc("cast1") + %1 = "tfl.cast" (%0) : (tensor<4x5xf32>) -> tensor<4x5xbf16> loc("cast2") + func.return %1 : tensor<4x5xbf16> +} diff --git a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc index 5ce7638f4e4da1..96d75cca30a48d 100644 --- a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc @@ -345,22 +345,41 @@ StatusOr ConvertFloatBuffer( switch (elem_type.getIntOrFloatBitWidth()) { case 16: { assert(bytes_len % 2 == 0); - assert(elem_type.isF16()); + // Supports both BF16 and F16. + assert(elem_type.isF16() || elem_type.isBF16()); int elem_count = bytes_len / 2; - std::vector values; - values.reserve(elem_count); - const char* data = reinterpret_cast(buffer.data()); + if (elem_type.isF16()) { + std::vector values; + values.reserve(elem_count); - for (int i = 0; i < elem_count; i++) { - uint16_t bit_repr = - llvm::support::endian::readNext(data); - values.push_back(Eigen::numext::bit_cast(bit_repr)); - } + const char* data = reinterpret_cast(buffer.data()); - return mlir::ElementsAttr( - DenseElementsAttr::get(shaped_type, ArrayRef(values))); + for (int i = 0; i < elem_count; i++) { + uint16_t bit_repr = llvm::support::endian::readNext< + uint16_t, llvm::endianness::native, llvm::support::unaligned>( + data); + values.push_back(Eigen::numext::bit_cast(bit_repr)); + } + + return mlir::ElementsAttr( + DenseElementsAttr::get(shaped_type, ArrayRef(values))); + } else { + std::vector values; + values.reserve(elem_count); + + const char* data = reinterpret_cast(buffer.data()); + + for (int i = 0; i < elem_count; i++) { + uint16_t bit_repr = llvm::support::endian::readNext< + uint16_t, llvm::endianness::native, llvm::support::unaligned>( + data); + values.push_back(Eigen::numext::bit_cast(bit_repr)); + } + + return mlir::ElementsAttr(DenseElementsAttr::get( + shaped_type, ArrayRef(values))); + } } case 32: { assert(bytes_len % 4 == 0); diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index b5d8bb151e7145..e6044fc6881990 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -1045,8 +1045,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { } return 2; case BuiltinOperator_CAST: - if (op_sig.inputs.at(0).type == kTfLiteInt4 && - op_sig.outputs.at(0).type 
== kTfLiteFloat32) { + if (op_sig.inputs.at(0).type == kTfLiteBFloat16 || + op_sig.outputs.at(0).type == kTfLiteBFloat16) { + return 7; + } else if (op_sig.inputs.at(0).type == kTfLiteInt4 && + op_sig.outputs.at(0).type == kTfLiteFloat32) { return 6; } else if (op_sig.inputs.at(0).type == kTfLiteFloat64 || op_sig.outputs.at(0).type == kTfLiteFloat64 || From 5cac1546945d2971ab5c0f1ea9ead1c4c32f0bf0 Mon Sep 17 00:00:00 2001 From: Harshit Monish <143435143+hmonishN@users.noreply.github.com> Date: Mon, 18 Mar 2024 14:20:06 -0700 Subject: [PATCH 039/670] PR #10635: Fix build error from PR 10497 Imported from GitHub PR https://github.com/openxla/xla/pull/10635 Added changes to fix build error that were encountered after merging changes from PR: https://github.com/openxla/xla/pull/10497 Used ::tsl::testing::StatusIs instead of ::testing::status::StatusIs Copybara import of the project: -- d344105a62e5f1749d994c2f1f2a8a76a5880c3d by hmonishN : Adding changes to use tsl::testing:StatusIs Merging this change closes #10635 PiperOrigin-RevId: 616945601 --- third_party/xla/xla/service/gpu/BUILD | 1 + third_party/xla/xla/service/gpu/autotuner_util_test.cc | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 508fd8f638e8a5..b94c1f10f9ce84 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -5968,6 +5968,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:protobuf", ]) + [ diff --git a/third_party/xla/xla/service/gpu/autotuner_util_test.cc b/third_party/xla/xla/service/gpu/autotuner_util_test.cc index b755334876cc26..28ec27c64e8da0 100644 --- a/third_party/xla/xla/service/gpu/autotuner_util_test.cc +++ b/third_party/xla/xla/service/gpu/autotuner_util_test.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "tsl/platform/env.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/status_matchers.h" namespace xla { namespace gpu { @@ -45,7 +46,7 @@ using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::Not; using ::testing::TempDir; -using ::testing::status::StatusIs; +using ::tsl::testing::StatusIs; class AutotunerUtilTest : public HloTestBase { protected: From 98830e91ac878d1ee15b7539dc72bb06adb2fc2f Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 18 Mar 2024 14:59:44 -0700 Subject: [PATCH 040/670] [xla:hlo] Use llvm::BitVector instead of a set when checking reachability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit name old cpu/op new cpu/op delta BM_HloDfsReachabilityBuild/1 109ns ± 4% 111ns ± 4% ~ BM_HloDfsReachabilityBuild/64 1.71µs ± 6% 1.71µs ± 4% ~ BM_HloDfsReachabilityBuild/128 3.38µs ± 3% 3.43µs ± 3% +1.54% BM_HloDfsReachabilityBuild/256 6.80µs ± 4% 6.95µs ± 5% +2.25% BM_HloDfsReachabilityBuild/512 13.8µs ± 4% 14.2µs ± 6% +2.63% BM_HloDfsReachabilityBuild/4096 155µs ± 4% 157µs ± 4% ~ BM_HloDfsReachabilityBuild/32768 1.42ms ± 5% 1.45ms ± 3% +1.94% BM_HloDfsReachabilityBuild/262144 32.2ms ± 4% 32.1ms ± 4% ~ BM_HloDfsReachabilityCheck/1 7.37ns ± 3% 7.41ns ± 4% ~ BM_HloDfsReachabilityCheck/64 295ns ± 5% 139ns ± 8% -52.78% BM_HloDfsReachabilityCheck/128 679ns ± 3% 278ns ± 7% -59.05% BM_HloDfsReachabilityCheck/256 1.53µs ± 5% 0.61µs ± 6% -60.06% BM_HloDfsReachabilityCheck/512 3.06µs ± 5% 1.31µs ± 6% -57.27% BM_HloDfsReachabilityCheck/4096 30.2µs ± 7% 17.9µs ± 4% -40.53% BM_HloDfsReachabilityCheck/32768 532µs ± 4% 327µs ± 5% -38.52% BM_HloDfsReachabilityCheck/262144 8.72ms ± 3% 6.66ms ± 4% -23.59% PiperOrigin-RevId: 616956892 --- .../xla/xla/hlo/ir/hlo_dfs_reachability.cc | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.cc b/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.cc index 2e6bd0e8495369..ae9b25f7453e98 100644 --- a/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.cc +++ b/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" -#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/SmallVector.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -45,16 +45,20 @@ bool HloDfsReachability::IsReachable(const HloInstruction* from, // Note that the DFS goes from the "uses" root towards the "defs", i.e. from // `to` node to `from` node, so the node indices are decreasing. - if (target_node_idx > dfs_root_idx) { + if (dfs_root_idx < target_node_idx) { return false; } - // We use LLVM support library here because it has stack-allocated maps (in - // contrast to absl) which significantly improves performance by avoiding heap - // allocations when instructions are reachable via a short chain. - llvm::SmallDenseSet visited_idxs{dfs_root_idx}; + // We use LLVM support library here because it has stack-allocated bit vector + // which significantly improves performance by avoiding heap allocations when + // instructions are reachable via a short chain. llvm::SmallVector stack{to}; + // We will visit instructions in the [target_node_idx, dfs_root_idx] range, so + // we can construct a smaller bit vector. 
+ llvm::BitVector visited_idxs(1 + (dfs_root_idx - target_node_idx)); + visited_idxs.set(dfs_root_idx - target_node_idx); + auto check_and_enqueue = [&](const HloInstruction* instr) { if (instr == from) { return true; @@ -63,9 +67,11 @@ bool HloDfsReachability::IsReachable(const HloInstruction* from, if (instr_idx < target_node_idx) { return false; } - if (auto [_, inserted] = visited_idxs.insert(instr_idx); !inserted) { + size_t visited_idx = instr_idx - target_node_idx; + if (visited_idxs.test(visited_idx)) { return false; } + visited_idxs.set(visited_idx); stack.push_back(instr); return false; }; From 1e3478bf321872470b63032ac98401905e1f81ad Mon Sep 17 00:00:00 2001 From: prrathi <53785742+prrathi@users.noreply.github.com> Date: Mon, 18 Mar 2024 15:00:04 -0700 Subject: [PATCH 041/670] PR #10642: Make configure.py command visible Imported from GitHub PR https://github.com/openxla/xla/pull/10642 Copybara import of the project: -- d5df8a14b8837c0980a8641a83b9d6d9e33577cc by prrathi <53785742+prrathi@users.noreply.github.com>: Make configure.py command visible Merging this change closes #10642 PiperOrigin-RevId: 616956998 --- third_party/xla/docs/build_from_source.md | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/third_party/xla/docs/build_from_source.md b/third_party/xla/docs/build_from_source.md index 91ef1e49608818..c273f7f3cdf8c0 100644 --- a/third_party/xla/docs/build_from_source.md +++ b/third_party/xla/docs/build_from_source.md @@ -10,13 +10,12 @@ If you did not clone the XLA repository or install Bazel, please check out the ### Configure XLA builds are configured by the `.bazelrc` file in the repository's root -directory. The `./configure.py` script can be used to adjust -common settings. +directory. The `./configure.py` script can be used to adjust common settings. -If you need to change the configuration, run the `./configure.py` script from the -repository's root directory. This script has flags for the location of XLA -dependencies and additional build configuration options (compiler -flags, for example). Refer to the *Sample session* section for details. +If you need to change the configuration, run the `./configure.py` script from +the repository's root directory. This script has flags for the location of XLA +dependencies and additional build configuration options (compiler flags, for +example). Refer to the *Sample session* section for details. ### CPU support @@ -27,26 +26,29 @@ We recommend using a suitable docker container to build/test XLA, such as docker run --name xla -w /xla -it -d --rm -v $PWD:/xla tensorflow/build:latest-python3.9 bash ``` -Using a docker container you can build XLA with CPU support using the following commands: +Using a docker container you can build XLA with CPU support using the following +commands: ``` docker exec xla ./configure.py --backend=CPU docker exec xla bazel build //xla/... --spawn_strategy=sandboxed --test_output=all ``` -If you want to build XLA targets with CPU support without Docker you need to install clang. XLA currently builds on CI with clang-17, but earlier versions should also work: +If you want to build XLA targets with CPU support without Docker you need to +install clang. XLA currently builds on CI with clang-17, but earlier versions +should also work: ``` apt install clang ``` Then configure and build targets using the following commands: -``` ./configure.py --backend=CPU +```sh +./configure.py --backend=CPU bazel build --test_output=all --spawn_strategy=sandboxed //xla/... 
``` - ### GPU support We recommend using the same docker container as above to build XLA with GPU @@ -76,6 +78,5 @@ Then configure and build targets using the following commands: bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` - For more details regarding [TensorFlow's GPU docker images you can check out this document.](https://www.tensorflow.org/install/source#gpu_support_3) From 8c29d81c5b08f8ae7e86f68cc7d8e49a7832b183 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Mon, 18 Mar 2024 15:02:24 -0700 Subject: [PATCH 042/670] PR #10562: Add missing warmup run when autotuning. Imported from GitHub PR https://github.com/openxla/xla/pull/10562 The `xla/service/gpu:gemm_algorithm_picker_test` test, run on V100, was hitting the delay kernel timeout because of this. See #9757 for explanation of why the best practice is to execute a warmup run **without the GpuTimer active**. Copybara import of the project: -- b4ccf2928ee45ec9139db003378095b948bb73d5 by Olli Lupton : Add missing warmup run when autotuning. The xla/service/gpu:gemm_algorithm_picker_test test, run on V100, was hitting the delay kernel timeout because of this. Merging this change closes #10562 PiperOrigin-RevId: 616957791 --- third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc index 9ed90d2ba6eadd..446cde8de272de 100644 --- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc @@ -186,6 +186,12 @@ class GemmAutotuner { -> absl::StatusOr { se::OwningScratchAllocator<> scratch_allocator( stream_->parent()->device_ordinal(), autotune_config_.GetAllocator()); + // Run a warmup iteration without the profiler active. + TF_RETURN_IF_ERROR(plan->ExecuteOnStream( + stream_, lhs_buffer_, rhs_buffer_, output_buffer_, output_buffer_, + bias_buffer, aux_buffer, a_scale_buffer, b_scale_buffer, + c_scale_buffer, d_scale_buffer, d_amax_buffer, algorithm, + scratch_allocator)); se::blas::ProfileResult profile_result; TF_RETURN_IF_ERROR(plan->ExecuteOnStream( stream_, lhs_buffer_, rhs_buffer_, output_buffer_, output_buffer_, From 0ca187b3ced9b97ade8322a26d3a40ea2c9c38bd Mon Sep 17 00:00:00 2001 From: Swachhand Lokhande Date: Mon, 18 Mar 2024 15:10:28 -0700 Subject: [PATCH 043/670] Force creating XlaSharding ops for optimizer slot variables. We do a read_value on the slot variables when creating the update op (eg. ResourceApplyAdagrad) for the optimizer to make sure the XlaSharding op is also generated. We make this a control dependency for the update op so that this appears before it. 
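As a rough sketch of the pattern described above (simplified and not part of the patch itself: the names `apply_update`/`slot`, the `assign_add` stand-in for ResourceApplyAdagrad, and the use of the public `tf` API instead of the internal modules touched by the real change in tensorflow/python/training/optimizer.py below are all hypothetical):

```python
# Sketch only: re-assign a slot variable's own read value under a control
# dependency so the ReadVariableOp (and any attached XlaSharding op) is
# emitted before the update op.
import tensorflow as tf


@tf.function
def apply_update(slot_var, grad):
  reassign = slot_var.assign(slot_var.read_value())  # creates a ReadVariableOp
  with tf.control_dependencies([reassign]):          # reassign appears first
    return slot_var.assign_add(grad)                 # stand-in for the real apply op


slot = tf.Variable(tf.zeros([4]))
apply_update(slot, tf.ones([4]))
```

In the actual change, such reassigns are collected for every XLA-sharded slot variable and attached as control dependencies of the dense resource-apply op.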
PiperOrigin-RevId: 616960034 --- .../python/compiler/xla/experimental/BUILD | 25 ++++ .../resource_variable_xla_sharding_test.py | 136 ++++++++++++++++++ tensorflow/python/training/optimizer.py | 25 +++- 3 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 tensorflow/python/compiler/xla/experimental/resource_variable_xla_sharding_test.py diff --git a/tensorflow/python/compiler/xla/experimental/BUILD b/tensorflow/python/compiler/xla/experimental/BUILD index 8cc63502e0869a..c2e2dd9d45af60 100644 --- a/tensorflow/python/compiler/xla/experimental/BUILD +++ b/tensorflow/python/compiler/xla/experimental/BUILD @@ -1,4 +1,5 @@ load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") +load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_strict_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -40,3 +41,27 @@ py_strict_test( "@absl_py//absl/testing:absltest", ], ) + +tpu_py_strict_test( + name = "resource_variable_xla_sharding_test", + srcs = ["resource_variable_xla_sharding_test.py"], + disable_v3_4chips = False, + python_version = "PY3", + srcs_version = "PY3", + tags = ["requires-net:external"], + deps = [ + ":xla_sharding", + "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/eager:test", + "//tensorflow/python/framework:config", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/tpu:device_assignment", + "//tensorflow/python/tpu:tpu_py", + "//tensorflow/python/training:adagrad", + ], +) diff --git a/tensorflow/python/compiler/xla/experimental/resource_variable_xla_sharding_test.py b/tensorflow/python/compiler/xla/experimental/resource_variable_xla_sharding_test.py new file mode 100644 index 00000000000000..ef7192a4f45807 --- /dev/null +++ b/tensorflow/python/compiler/xla/experimental/resource_variable_xla_sharding_test.py @@ -0,0 +1,136 @@ +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from tensorflow.python.compiler.xla.experimental import xla_sharding +from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.eager import test +from tensorflow.python.framework import config +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.tpu import device_assignment +from tensorflow.python.tpu import tpu +from tensorflow.python.training import adagrad + + +# Gets all the nodes of `op` in graph that have `input_node_name` as one of the +# inputs +def _get_op_nodes_with_input(input_node_name, op, graph): + nodes_with_input = [] + for node in graph.node: + nodes_with_input += [ + node + for input in node.input + if input == input_node_name and node.op == op + ] + return nodes_with_input + + +# Gets XlaSharding ops connected to ReadVariableOp for the given variable_name +def _get_xla_sharding_nodes_for_variable(variable_name, graph): + read_variable_op_nodes = _get_op_nodes_with_input( + variable_name, 'ReadVariableOp', graph + ) + xla_sharding_op_nodes = [] + for read_variable_op_node in read_variable_op_nodes: + xla_sharding_op_nodes += _get_op_nodes_with_input( + read_variable_op_node.name, 'XlaSharding', graph + ) + return xla_sharding_op_nodes + + +def _get_xla_sharding_proto_from_node(node): + sharding_proto = xla_sharding.xla_data_pb2.OpSharding() + sharding_proto.ParseFromString(node.attr['sharding'].s) + return sharding_proto + + +class ResourceVariableXlaShardingTest(test.TestCase): + + def setUp(self) -> None: + super().setUp() + + context.enable_xla_sharding_for_resource_variables() + self.topology = tpu_cluster_resolver.initialize_tpu_system() + if len(config.list_logical_devices('TPU')) != 8: + self.skipTest('All tests require 8 TPUs.') + + self.da = device_assignment.DeviceAssignment.build( + self.topology, computation_shape=[2, 2, 1, 2], num_replicas=1 + ) + + def test_xla_sharding_ops_created_for_optimizer_slot_variables(self): + w = variables.Variable( + initial_value=math_ops.range(8, dtype=dtypes.float32), + name='w', + ) + self.assertIsInstance(w, resource_variable_ops.BaseResourceVariable) + w = xla_sharding.split( + w, + split_dimension=0, + num_devices=8, + ) + sharding_proto = xla_sharding.xla_data_pb2.OpSharding() + sharding_proto.ParseFromString(xla_sharding.get_tensor_sharding(w)) + opt = adagrad.AdagradOptimizer(1.0) + + @def_function.function + def computation(x): + def tpu_fn(x): + y = math_ops.add(w, x) + loss = math_ops.reduce_sum(y) + opt.minimize(loss, None, [w]) + return loss + + output = tpu.replicate(tpu_fn, [[x]], device_assignment=self.da) + return output + + inputs = array_ops.reshape(math_ops.range(16, dtype=dtypes.float32), (2, 8)) + result = computation(inputs) + self.assertSequenceEqual([[176.0]], self.evaluate(result)) + graph = computation.get_concrete_function(inputs).graph.as_graph_def() + + update_op_nodes = [ + node for node in graph.node if node.op == 'ResourceApplyAdagrad' + ] + self.assertLen(update_op_nodes, 1) + update_op_node = update_op_nodes[0] + + var_input_name = update_op_node.input[0] + var_sharding_nodes = _get_xla_sharding_nodes_for_variable( + var_input_name, graph + ) + 
self.assertLen(var_sharding_nodes, 1) + self.assertProtoEquals( + _get_xla_sharding_proto_from_node(var_sharding_nodes[0]), sharding_proto + ) + + slot_var_input_name = update_op_node.input[1] + slot_var_sharding_nodes = _get_xla_sharding_nodes_for_variable( + slot_var_input_name, graph + ) + self.assertLen(slot_var_sharding_nodes, 1) + self.assertProtoEquals( + _get_xla_sharding_proto_from_node(slot_var_sharding_nodes[0]), + sharding_proto, + ) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 7536c6ce90692f..5a438ce2d52d05 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -169,7 +169,30 @@ def update_op(self, optimizer, g): "Cannot use a constraint function on a sparse variable.") return optimizer._resource_apply_sparse_duplicate_indices( g.values, self._v, g.indices) - update_op = optimizer._resource_apply_dense(g, self._v) + + if context.xla_sharding_for_resource_variables_enabled(): + # For each slot variable that is annotated with an XLA sharding, we read + # the variable and assign the value to itself. This is done to trigger the + # creation of an XlaShardingOp when a ReadVariableOp is created upon the + # call to `slot_var.read_value()`. This is needed to ensure that slot + # variables with XLA sharding are sharded correctly. Please see + # b/307541427 for more details. + assign_ops = [] + for variable_dict in optimizer._slots.values(): + for slot_var in variable_dict.values(): + if ( + isinstance(slot_var, resource_variable_ops.BaseResourceVariable) + and slot_var._get_xla_sharding() is not None + ): + assign_ops.append(slot_var.assign(slot_var.read_value())) + + # The assign_ops created above are added as a control dependency for the + # update op to make sure these appear before the update_op. + with ops.control_dependencies(assign_ops): + update_op = optimizer._resource_apply_dense(g, self._v) + else: + update_op = optimizer._resource_apply_dense(g, self._v) + if self._v.constraint is not None: with ops.control_dependencies([update_op]): return self._v.assign(self._v.constraint(self._v)) From cd60a4dd509b300469607267b00a3513796748e6 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 18 Mar 2024 15:11:02 -0700 Subject: [PATCH 044/670] [xla:hlo] Do not compute channel dependencies when building DFS reachability PiperOrigin-RevId: 616960175 --- third_party/xla/xla/hlo/ir/hlo_computation.cc | 4 +++- third_party/xla/xla/hlo/ir/hlo_dfs_reachability.cc | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index 16c9e0ce37b1c8..7d8a080bd3840a 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -506,6 +507,7 @@ void HloComputation::ForEachInstructionPostOrderImpl( absl::FunctionRef func, HloInstruction* root, const ChannelDependencies& channel_dependencies, VisitMap& visited, std::vector* dfs_stack_scratch) const { + bool has_channel_dependencies = !channel_dependencies.empty(); auto* dfs_stack = dfs_stack_scratch; dfs_stack->clear(); dfs_stack->push_back(root); @@ -532,7 +534,7 @@ void HloComputation::ForEachInstructionPostOrderImpl( // Collectives with the same channel ID must be performed together, as these // represent MPMD-partitioned that will later be split into separate modules // and the order must be preserved. - if (¤t != root) { + if (has_channel_dependencies && ¤t != root) { auto it = channel_dependencies.find(¤t); if (it != channel_dependencies.end()) { dfs_stack->insert(dfs_stack->end(), it->second.begin(), diff --git a/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.cc b/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.cc index ae9b25f7453e98..c831f31cec03f1 100644 --- a/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.cc +++ b/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.cc @@ -96,10 +96,11 @@ std::unique_ptr HloDfsReachability::Build( const HloComputation* computation) { auto res = std::make_unique(); - HloComputation::ChannelDependencies channel_dependencies = - computation->ComputeChannelDependencies(); + // For instruction reachability we do not care about correct order of + // collective operations as we only care about use-def chains. + HloComputation::ChannelDependencies empty_channel_dependencies; std::vector instructions = - computation->MakeInstructionPostOrder(channel_dependencies); + computation->MakeInstructionPostOrder(empty_channel_dependencies); res->instruction_to_idx_.reserve(instructions.size()); for (size_t i = 0; i < instructions.size(); ++i) { From 392a5f0d120a82e1bd4c9af486fe786ee04931bc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Mar 2024 15:11:57 -0700 Subject: [PATCH 045/670] Temporarily disables failing shared_batch_scheduler_test on Windows. PiperOrigin-RevId: 616960378 --- tensorflow/core/kernels/batching_util/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index d34bd7331a35d5..828b1c0f60d4fb 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -190,6 +190,7 @@ tf_cc_test( name = "shared_batch_scheduler_test", size = "small", srcs = ["shared_batch_scheduler_test.cc"], + tags = ["no_windows"], deps = [ ":batch_scheduler", ":fake_clock_env", From ff0308108e154928b033b1ab01fb8512b9786c18 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 18 Mar 2024 15:55:04 -0700 Subject: [PATCH 046/670] Fix readability issues in `quantization_driver.h/cc`. 
PiperOrigin-RevId: 616971652 --- .../mlir/lite/quantization/lite/BUILD | 1 + .../common/quantization_lib/BUILD | 1 + .../quantization_lib/quantization_driver.cc | 408 +++++++++--------- .../quantization_lib/quantization_driver.h | 154 +++---- .../quantization_lib/quantization_utils.h | 8 +- 5 files changed, 299 insertions(+), 273 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index ad7c1905440297..a0f55e0408932f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -91,6 +91,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD b/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD index d41a189519fd6d..a0d64569562d38 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD @@ -35,6 +35,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/tools/optimize:quantization_utils", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc index 962c6656f55b65..327d109946e031 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc @@ -26,7 +26,6 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project @@ -47,39 +46,44 @@ limitations under the License. namespace mlir { namespace quant { - namespace { -// This is used to identify an operand or result of an op. The second element -// of this pair is the index of the operand or result. -using OpValue = std::pair; + +constexpr int32_t kBiasMax = std::numeric_limits::max() / 2; // Uses the type of `value` to set the initial state of the index-th result if // `as_result` is true or index-th operand if `as_result` is false. The state // is immutable if the type is a quantized type. Returns the index of this // new state in the state vector. 
-void InitializeStateForValue(Operation* op, const int index, const Value value, - const bool as_result, - std::vector* states, - llvm::DenseMap* value_to_state, - llvm::DenseMap* operand_states, - llvm::DenseMap* result_states) { - const auto [cached, inserted] = value_to_state->insert({value, 0}); +void InitializeStateForValue( + Operation* op, const int index, const Value value, const bool as_result, + std::vector& states, + DenseMap& value_to_state, + DenseMap& operand_states, + DenseMap& result_states) { + const auto [cached, inserted] = value_to_state.try_emplace(value, 0); if (!inserted) { - if (as_result) - (*result_states)[{op, index}] = cached->second; - else - (*operand_states)[{op, index}] = cached->second; + if (as_result) { + result_states[{op, index}] = cached->second; + } else { + operand_states[{op, index}] = cached->second; + } return; } - const QuantParams params = - quant::QuantizedType::getQuantizedElementType(value.getType()); - const bool immutable = !HasQuantParams(params); - const int next_state_index = states->size(); - states->push_back({params, immutable}); - if (as_result) - (*result_states)[{op, index}] = next_state_index; - else - (*operand_states)[{op, index}] = next_state_index; + + const QuantizedType quantized_type = + QuantizedType::getQuantizedElementType(value.getType()); + + const bool immutable = quantized_type != nullptr; + const QuantizationDriver::QuantStateIndex next_state_index = states.size(); + states.push_back({quantized_type, immutable}); + if (as_result) { + result_states[{op, index}] = next_state_index; + } else { + operand_states[{op, index}] = next_state_index; + } + cached->second = next_state_index; } @@ -87,32 +91,31 @@ void InitializeStateForValue(Operation* op, const int index, const Value value, void QuantizationDriver::InitializeArgState(const BlockArgument arg, const Value arg_value) { - const auto [cached, inserted] = value_to_state_.insert({arg_value, 0}); + const auto [cached, inserted] = value_to_state_.try_emplace(arg_value, 0); if (!inserted) { arg_states_[arg] = cached->second; return; } - const QuantParams params = - quant::QuantizedType::getQuantizedElementType(arg_value.getType()); - const bool immutable = !HasQuantParams(params); - const int next_state_index = states_.size(); - states_.push_back({params, immutable}); + + const QuantizedType quantized_type = + QuantizedType::getQuantizedElementType(arg_value.getType()); + const bool immutable = quantized_type != nullptr; + const QuantizationDriver::QuantStateIndex next_state_index = states_.size(); + states_.push_back({quantized_type, immutable}); arg_states_[arg] = next_state_index; cached->second = next_state_index; } void QuantizationDriver::InitializeOperandState(Operation* op, const int index, const Value value) { - ::mlir::quant::InitializeStateForValue(op, index, value, /*as_result=*/false, - &states_, &value_to_state_, - &operand_states_, &result_states_); + InitializeStateForValue(op, index, value, /*as_result=*/false, states_, + value_to_state_, operand_states_, result_states_); } void QuantizationDriver::InitializeResultState(Operation* op, const int index, const Value value) { - ::mlir::quant::InitializeStateForValue(op, index, value, /*as_result=*/true, - &states_, &value_to_state_, - &operand_states_, &result_states_); + InitializeStateForValue(op, index, value, /*as_result=*/true, states_, + value_to_state_, operand_states_, result_states_); } std::unique_ptr QuantizationDriver::GetQuantSpec(Operation* op) { @@ -133,11 +136,11 @@ bool 
QuantizationDriver::IsQuantized(Operation* op) { bool QuantizationDriver::SetConstantResultParams(Operation* op) { DenseFPElementsAttr attr; - const Value res = op->getResult(0); - if (!matchPattern(res, m_Constant(&attr))) { + const Value result = op->getResult(0); + if (!matchPattern(result, m_Constant(&attr))) { return false; } - // TODO(fengliuai): make storage_type_width and narrow_range configurable. + // TODO: b/323478683 - Make storage_type_width and narrow_range configurable. Type final_type; const auto it = optimized_weights_.find(op); const bool is_weight = it != optimized_weights_.end(); @@ -159,42 +162,44 @@ bool QuantizationDriver::SetConstantResultParams(Operation* op) { final_type = GetUniformQuantizedTypeForWeight( attr, /*symmetric=*/is_weight && is_signed_, /*num_bits=*/8, is_signed_, - /*narrow_range_=*/is_weight, legacy_float_scale_); + /*narrow_range=*/is_weight, legacy_float_scale_); } - if (const auto quant_type = - final_type.dyn_cast_or_null()) { - return SetResultParams(op, 0, quant_type); + if (const auto quant_type = final_type.dyn_cast_or_null(); + quant_type != nullptr) { + return SetResultParams(op, /*result_index=*/0, quant_type); } return false; } -bool QuantizationDriver::SetResultParams(Operation* op, const int res_index, - const QuantParams params) { - auto& state = GetResultQuantState(op, res_index); - if (state.params == params) { +bool QuantizationDriver::SetResultParams(Operation* op, const int result_index, + const QuantizedType quantized_type) { + QuantState& state = GetResultQuantState(op, result_index); + if (state.params == quantized_type) { return false; } if (!state.IsEmpty()) { - auto& rescales = GetResultRequantizeStates(op, res_index); + RequantizeStates& rescales = GetResultRequantizeStates(op, result_index); RequantizeState& rescale = rescales.emplace_back(); rescale.pos = RequantizeState::ON_INPUT; - rescale.params = params; + rescale.params = quantized_type; return true; } - state.params = params; - AddUserToList(op, res_index); + state.params = quantized_type; + AddUserToList(op, result_index); return true; } -QuantParams QuantizationDriver::GetBiasParams( - Operation* op, const int bias_index, const std::vector& non_biases, +QuantizedType QuantizationDriver::GetBiasParams( + Operation* op, const int bias_index, + const ArrayRef non_bias_operand_indices, const AccumulatorScaleFunc func) { QuantState& bias_state = GetOperandQuantState(op, bias_index); if (!bias_state.IsEmpty()) { return bias_state.params; } - std::vector op_types; - op_types.reserve(non_biases.size()); + std::vector op_types{}; + op_types.reserve(non_bias_operand_indices.size()); + int adjusted_quant_dim = -1; if (op->getNumOperands() > bias_index) { // Some kernels allow 1D bias, broadcasting it inside the kernel. 
In this @@ -211,68 +216,75 @@ QuantParams QuantizationDriver::GetBiasParams( } } - for (int non_bias : non_biases) { - const QuantState& non_bias_type = GetOperandQuantState(op, non_bias); - op_types.push_back(non_bias_type.params); + for (const int non_bias_operand_index : non_bias_operand_indices) { + const QuantState& non_bias_state = + GetOperandQuantState(op, non_bias_operand_index); + op_types.push_back(non_bias_state.params); } return func(op_types, adjusted_quant_dim, legacy_float_scale_); } -bool QuantizationDriver::SetOperandParams(Operation* op, const int index, - const QuantParams params, +bool QuantizationDriver::SetOperandParams(Operation* op, + const int operand_index, + const QuantizedType quantized_type, const bool override) { - auto& state = GetOperandQuantState(op, index); - if (state.params == params) { + QuantState& state = GetOperandQuantState(op, operand_index); + if (state.params == quantized_type) { return false; } if (!state.IsEmpty() && !override) { - auto& rescales = GetOperandRequantizeStates(op, index); + RequantizeStates& rescales = GetOperandRequantizeStates(op, operand_index); for (RequantizeState& rescale : rescales) { - if (rescale.params == params) { - rescale.users.emplace_back(op, index); + if (rescale.params == quantized_type) { + rescale.users.emplace_back(op, operand_index); return true; } } RequantizeState& rescale = rescales.emplace_back(); rescale.pos = RequantizeState::ON_OUTPUT; - rescale.params = params; - rescale.users.emplace_back(op, index); + rescale.params = quantized_type; + rescale.users.emplace_back(op, operand_index); return true; } - state.params = params; - AddOperandToList(op, index); + state.params = quantized_type; + AddOperandToList(op, operand_index); return true; } -void QuantizationDriver::QuantizeOpResult(Operation* op, const int index, - const QuantParams params) { +void QuantizationDriver::QuantizeOpResult(Operation* op, const int result_index, + const QuantizedType quantized_type) { builder_.setInsertionPointAfter(op); - const Value original_result = op->getResult(index); - QuantizeValue(original_result, params, op->getLoc()); + const Value original_result = op->getResult(result_index); + QuantizeValue(original_result, quantized_type, op->getLoc()); } -void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) { +void QuantizationDriver::QuantizeArg(BlockArgument arg, + const QuantizedType quantized_type) { builder_.setInsertionPointToStart(arg.getOwner()); - QuantizeValue(arg, params, builder_.getUnknownLoc()); + QuantizeValue(arg, quantized_type, builder_.getUnknownLoc()); } -void QuantizationDriver::QuantizeValue(Value value, QuantParams params, - Location loc) { +void QuantizationDriver::QuantizeValue(Value value, + QuantizedType quantized_type, + const Location loc) { const Type expressed_type = value.getType(); - const Type new_type = params.castFromExpressedType(expressed_type); - // This value isn't an expressed type (float), skip. - if (!new_type) return; + const Type new_value_type = + quantized_type.castFromExpressedType(expressed_type); + // Skip if `value` or `value`'s element type doesn't match the expressed type + // of `quantized_type`. + if (new_value_type == nullptr) return; + auto quantize = - builder_.create(loc, new_type, value); + builder_.create(loc, new_value_type, value); auto dequantize = builder_.create( loc, expressed_type, quantize.getResult()); // This attribute is set to distinguish the quantize ops being added by the // quantization pass. 
These ops can be removed without losing original // program accuracy. - // TODO(fengliuai): make the attribute being part of op definition. + // TODO: b/323478683 - Make the attribute being part of op definition. quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr()); // `original_result` has a use to `quantize`, so this will replace that use @@ -281,17 +293,18 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params, quantize.getOperation()->replaceUsesOfWith(dequantize, value); } -void QuantizationDriver::RequantizeOpResult(Operation* op, const int index, - RequantizeStates* states) { - if (states->empty()) return; +void QuantizationDriver::RequantizeOpResult(Operation* op, + const int result_index, + RequantizeStates& states) { + if (states.empty()) return; builder_.setInsertionPointAfter(op); - Value value = op->getResult(index); - RequantizeState::RequantizePosition pos = states->front().pos; + Value value = op->getResult(result_index); + RequantizeState::RequantizePosition pos = states.front().pos; if (pos == RequantizeState::NO_REQUANTIZE) { return; } - for (auto& state : *states) { + for (const RequantizeState& state : states) { // Check that all requantization positions are the same for each state. // Unsure if this check is required. if (state.pos != pos) { @@ -300,7 +313,7 @@ void QuantizationDriver::RequantizeOpResult(Operation* op, const int index, } if (pos == RequantizeState::ON_OUTPUT) { Operation* user = value.getUses().begin().getUser(); - if (llvm::isa(user)) { + if (isa(user)) { // The requantize op is inserted between `quantize` and `dequantize` ops. value = user->getResult(0); builder_.setInsertionPointAfter(user); @@ -310,12 +323,12 @@ void QuantizationDriver::RequantizeOpResult(Operation* op, const int index, } void QuantizationDriver::RequantizeArg(const BlockArgument arg, - RequantizeStates* states) { + RequantizeStates& states) { Value value = arg; builder_.setInsertionPointToStart(arg.getOwner()); if (value.hasOneUse()) { Operation* user = value.use_begin().getUser(); - if (auto q = llvm::dyn_cast(user)) { + if (auto q = dyn_cast(user)) { value = q.getResult(); builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user)); } @@ -323,14 +336,13 @@ void QuantizationDriver::RequantizeArg(const BlockArgument arg, RequantizeValue(value, states, builder_.getUnknownLoc()); } -void QuantizationDriver::RequantizeValue(Value value, RequantizeStates* states, +void QuantizationDriver::RequantizeValue(Value value, RequantizeStates& states, const Location loc) { - if (states->empty() || - states->front().pos == RequantizeState::NO_REQUANTIZE) { + if (states.empty() || states.front().pos == RequantizeState::NO_REQUANTIZE) { return; } - if (states->front().pos == RequantizeState::ON_INPUT) { - auto& state = states->front(); + if (states.front().pos == RequantizeState::ON_INPUT) { + RequantizeState& state = states.front(); const Type expressed_type = value.getType(); // The value needs to be requantized. A Quantize op will be created to use // it as the operand and replace its uses. 
@@ -350,7 +362,7 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeStates* states, if (!value.hasOneUse()) { return; } - auto dequant_op = llvm::dyn_cast_or_null( + auto dequant_op = dyn_cast_or_null( value.use_begin().getUser()); if (!dequant_op) { return; @@ -363,10 +375,9 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeStates* states, // Whether to replace quantization params of the first dequantize op // after the quantized value is produced. // If there is a use other than the requantize states, then we can't clobber. - bool clobber_first = num_uses <= states->size(); - for (auto& state : *states) { - Type expressed_type = - quant::QuantizedType::castToExpressedType(value.getType()); + bool clobber_first = num_uses <= states.size(); + for (RequantizeState& state : states) { + Type expressed_type = QuantizedType::castToExpressedType(value.getType()); if (!expressed_type) continue; // The value needs to be requantized. A Quantize op will be created to use // it as the operand and replace its uses. @@ -384,8 +395,8 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeStates* states, } else { auto new_dequant_op = builder_.create( loc, dequant_op.getResult().getType(), requantize_op.getResult()); - for (auto& op_index : state.users) { - op_index.first->setOperand(op_index.second, new_dequant_op.getResult()); + for (auto [op, operand_idx] : state.users) { + op->setOperand(operand_idx, new_dequant_op.getResult()); } } } @@ -400,12 +411,12 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeStates* states, // - use the single input if it is ready, or, // - use the single output if it is ready, or, // - use the first ready one in the collection. -QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint( +QuantizedType QuantizationDriver::GetQuantParamsForSameScaleConstraint( Operation* op) { // Two vector to collect Non-empty operands and results states. std::vector mutable_states, immutable_states; for (int i = 0; i < op->getNumOperands(); ++i) { - auto& state = GetOperandQuantState(op, i); + QuantState& state = GetOperandQuantState(op, i); if (state.immutable) { immutable_states.push_back(&state); } else if (!state.IsEmpty()) { @@ -422,7 +433,7 @@ QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint( } for (int i = 0; i < op->getNumResults(); ++i) { - auto& state = GetResultQuantState(op, i); + QuantState& state = GetResultQuantState(op, i); if (state.immutable) { immutable_states.push_back(&state); } else if (!state.IsEmpty()) { @@ -476,14 +487,11 @@ void QuantizationDriver::PreprocessConstantOps() { // The following loop will change the value uses, thus we cache all the uses // needs to be changed. - llvm::SmallVector> uses; - for (auto& use : value.getUses()) { + SmallVector> uses; + for (OpOperand& use : value.getUses()) { uses.push_back({use.getOwner(), use.getOperandNumber()}); } - for (const auto& indexed_use : llvm::enumerate(uses)) { - Operation* user = indexed_use.value().first; - const int operand_num = indexed_use.value().second; - + for (const auto [user, operand_num] : uses) { const std::unique_ptr spec = GetQuantSpec(user); const std::unique_ptr scale_spec = GetQuantScaleSpec(user); @@ -493,9 +501,9 @@ void QuantizationDriver::PreprocessConstantOps() { // other values. So any constants which are not bias, an operand of an // op with same scale requirements, and haven't been quantized are // weights. 
- if (biases.find(operand_num) == biases.end() && + if (!biases.contains(operand_num) && !scale_spec->has_same_scale_requirement && - !llvm::dyn_cast(user)) { + !dyn_cast(user)) { // Needs to scan the content of weights to get the quantization // parameters if there are no quantization parameters (FakeQuant ops). // For this case, the weight will not be duplicated. @@ -511,9 +519,9 @@ void QuantizationDriver::PreprocessConstantOps() { // other values. Duplicate this constant in case it is shared by // different users. if (uses.size() > 1) { - auto new_cst = + auto new_constant_op = builder_.create(cst.getLoc(), cst.getValue()); - user->setOperand(operand_num, new_cst); + user->setOperand(operand_num, new_constant_op); } } } @@ -521,13 +529,13 @@ void QuantizationDriver::PreprocessConstantOps() { } void QuantizationDriver::SetupAllStates() { - for (auto arg : fn_.getArguments()) { + for (BlockArgument arg : fn_.getArguments()) { args_.push_back(arg); Value value = arg; // If the argument is quantized, it should only has one user. if (arg.hasOneUse()) { Operation* user = value.use_begin().getUser(); - if (auto q = llvm::dyn_cast(user)) { + if (auto q = dyn_cast(user)) { value = q.getResult(); } } @@ -543,29 +551,29 @@ void QuantizationDriver::SetupAllStates() { for (int i = 0; i < op->getNumOperands(); ++i) { Value operand = op->getOperand(i); - if (auto* inst = operand.getDefiningOp()) { + if (Operation* inst = operand.getDefiningOp()) { // If the operand comes from a `quantfork::DequantizeCastOp`, we use // the quantized input of this `quantfork::DequantizeCastOp` to set the // state. - if (auto dq = llvm::dyn_cast(inst)) { + if (auto dq = dyn_cast(inst)) { operand = dq.getArg(); } } InitializeOperandState(op, i, operand); } - for (int res = 0; res < op->getNumResults(); ++res) { - Value result = op->getResult(res); + for (int i = 0; i < op->getNumResults(); ++i) { + Value result = op->getResult(i); // If the result has been quantized, it should only be used by a // `quantfork::QuantizeCastOp`. For this case, we uses the quantized // result to create the state and mark it immutable. if (result.hasOneUse()) { Operation* user = result.use_begin().getUser(); - if (auto q = llvm::dyn_cast(user)) { + if (auto q = dyn_cast(user)) { result = q.getResult(); } } - InitializeResultState(op, res, result); + InitializeResultState(op, i, result); } }); } @@ -577,7 +585,7 @@ arith::ConstantOp QuantizationDriver::DuplicateConstantOpIfNeeded( } OpBuilder builder(op->getContext()); builder.setInsertionPointAfter(op); - arith::ConstantOp new_op = llvm::cast(builder.clone(*op)); + arith::ConstantOp new_op = cast(builder.clone(*op)); target_op->getOpOperand(operand_index).set(new_op.getResult()); InitializeOperandState(target_op, operand_index, new_op.getResult()); InitializeResultState(new_op, 0, new_op.getResult()); @@ -585,13 +593,13 @@ arith::ConstantOp QuantizationDriver::DuplicateConstantOpIfNeeded( } bool QuantizationDriver::ShouldCheckBiasScale( - Operation* op, const int bias_index, const std::vector& input_indices, - const QuantParams params, int& input_index, int& filter_index) { + Operation* op, const int bias_index, ArrayRef input_indices, + const QuantizedType quantized_type, int& input_index, int& filter_index) { // For now, restrict scale adjustment to ops with affine quantized weights, // and having weights and biases as constants. This currently only applies to // FC and Conv* ops. Restriction for the weight can be relaxed if there are // needs for adjusting scale of variable weights. 
- auto affine_op = llvm::dyn_cast(op); + auto affine_op = dyn_cast(op); auto bias_op = op->getOperand(bias_index).getDefiningOp(); if (!affine_op || !bias_op || input_indices.size() != 2) return false; if (!bias_op.getValue().isa()) return false; @@ -607,22 +615,20 @@ bool QuantizationDriver::ShouldCheckBiasScale( return false; } - const auto input_state = GetOperandQuantState(op, input_index); - const auto filter_state = GetOperandQuantState(op, filter_index); + const QuantState& input_state = GetOperandQuantState(op, input_index); + const QuantState& filter_state = GetOperandQuantState(op, filter_index); // If quantization parameter for the filter is fixed, should return it as-is. // Only checks ops with 8-bit input and weights, and 32-bit biases. - if (!(input_state.params.getStorageTypeIntegralWidth() == 8 && - filter_state.params.getStorageTypeIntegralWidth() == 8 && - params.getStorageTypeIntegralWidth() == 32)) { - return false; - } - return true; + return input_state.params.getStorageTypeIntegralWidth() == 8 && + filter_state.params.getStorageTypeIntegralWidth() == 8 && + quantized_type.getStorageTypeIntegralWidth() == 32; } bool QuantizationDriver::SetBiasParamsWithAdjustments( - Operation* op, const int bias_index, const std::vector& input_indices, - const QuantParams params) { + Operation* op, const int bias_index, ArrayRef input_indices, + const QuantizedType params) { bool changed = false; + int input_index; int filter_index; if (!ShouldCheckBiasScale(op, bias_index, input_indices, params, input_index, @@ -630,8 +636,8 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( return SetOperandParams(op, bias_index, params); } - quant::QuantState input_state = GetOperandQuantState(op, input_index); - quant::QuantState filter_state = GetOperandQuantState(op, filter_index); + QuantState input_state = GetOperandQuantState(op, input_index); + QuantState filter_state = GetOperandQuantState(op, filter_index); auto bias_op = op->getOperand(bias_index).getDefiningOp(); const double input_scale = input_state.params.cast().getScale(); @@ -639,15 +645,15 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( auto bias_values = bias_op.getValue().cast(); // Restrict maximum absolute value of bias within INT_MAX / 2, to make some // room for accumulator. 
- const int32_t kBiasMax = std::numeric_limits::max() / 2; - if (auto bias_params = params.dyn_cast()) { + if (auto bias_quantized_type = params.dyn_cast(); + bias_quantized_type != nullptr) { double bias_half_range = 0.0f; for (auto bias : bias_values.getValues()) { if (bias_half_range < std::abs(bias.convertToFloat())) { bias_half_range = std::abs(bias.convertToFloat()); } } - if (bias_half_range / bias_params.getScale() < kBiasMax) { + if (bias_half_range / bias_quantized_type.getScale() < kBiasMax) { return SetOperandParams(op, bias_index, params); } const double new_bias_scale = @@ -659,30 +665,36 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( bias_op->getLoc(), params.getFlags(), params.getStorageType(), params.getExpressedType(), new_bias_scale, 0, params.getStorageTypeMin(), params.getStorageTypeMax())); - auto filter_op = DuplicateConstantOpIfNeeded( + arith::ConstantOp filter_op = DuplicateConstantOpIfNeeded( op->getOperand(filter_index).getDefiningOp(), op, filter_index); if (!filter_op) { return SetOperandParams(op, bias_index, params); } - const auto filter_param = filter_state.params.cast(); + const auto filter_quantized_type = + filter_state.params.cast(); changed |= SetOperandParams( op, filter_index, UniformQuantizedType::getChecked( - filter_op->getLoc(), filter_param.getFlags(), - filter_param.getStorageType(), filter_param.getExpressedType(), - new_bias_scale / input_scale, 0, filter_param.getStorageTypeMin(), - filter_param.getStorageTypeMax()), + filter_op->getLoc(), filter_quantized_type.getFlags(), + filter_quantized_type.getStorageType(), + filter_quantized_type.getExpressedType(), + new_bias_scale / input_scale, 0, + filter_quantized_type.getStorageTypeMin(), + filter_quantized_type.getStorageTypeMax()), /*override=*/true); - } else if (auto bias_params = - params.dyn_cast()) { - const auto filter_params = + } else if (auto bias_quantized_type = + params.dyn_cast(); + bias_quantized_type != nullptr) { + const auto filter_quantized_type = filter_state.params.cast(); - std::vector new_bias_scales = bias_params.getScales().vec(); - std::vector new_filter_scales = filter_params.getScales().vec(); + std::vector new_bias_scales = bias_quantized_type.getScales().vec(); + std::vector new_filter_scales = + filter_quantized_type.getScales().vec(); + bool needs_adjustment = false; - for (int i = 0; i < bias_params.getScales().size(); ++i) { + for (int i = 0; i < bias_quantized_type.getScales().size(); ++i) { const float abs_bias = std::abs(bias_values.getValues()[i]); if (abs_bias / new_bias_scales[i] > kBiasMax) { new_bias_scales[i] = static_cast(abs_bias) / kBiasMax; @@ -698,21 +710,23 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( quant::UniformQuantizedPerAxisType::getChecked( bias_op->getLoc(), params.getFlags(), params.getStorageType(), params.getExpressedType(), new_bias_scales, - bias_params.getZeroPoints(), bias_params.getQuantizedDimension(), + bias_quantized_type.getZeroPoints(), + bias_quantized_type.getQuantizedDimension(), params.getStorageTypeMin(), params.getStorageTypeMax())); - auto filter_op = DuplicateConstantOpIfNeeded( + arith::ConstantOp filter_op = DuplicateConstantOpIfNeeded( op->getOperand(filter_index).getDefiningOp(), op, filter_index); changed |= SetOperandParams( op, filter_index, quant::UniformQuantizedPerAxisType::getChecked( - filter_op->getLoc(), filter_params.getFlags(), - filter_params.getStorageType(), filter_params.getExpressedType(), - new_filter_scales, filter_params.getZeroPoints(), - 
filter_params.getQuantizedDimension(), - filter_params.getStorageTypeMin(), - filter_params.getStorageTypeMax()), + filter_op->getLoc(), filter_quantized_type.getFlags(), + filter_quantized_type.getStorageType(), + filter_quantized_type.getExpressedType(), new_filter_scales, + filter_quantized_type.getZeroPoints(), + filter_quantized_type.getQuantizedDimension(), + filter_quantized_type.getStorageTypeMin(), + filter_quantized_type.getStorageTypeMax()), /*override=*/true); } return changed; @@ -720,12 +734,12 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( // This method scans the operations in the function to setup the initial // states for quantization parameter propagation. -// TODO(fengliuai): This algorithm assumes there are only one pair of +// TODO: b/323478683 - This algorithm assumes there are only one pair of // `quantfork::QuantizeCastOp` and `quantfork::DequantizeCastOp` ops between two // quantizable ops. A sanity check should be applied. void QuantizationDriver::Initialize() { // Duplicate the bias constant, so the states can be setup correctly. - // TODO(fengliuai): Function definition should also be duplicated if there + // TODO: b/323478683 - Function definition should also be duplicated if there // are multiple call sites. PreprocessConstantOps(); @@ -736,21 +750,21 @@ void QuantizationDriver::Initialize() { // Propagates the quantization parameters to the operands, results, and biases. // TODO: b/323478683 - Do not use while loop to handle this logic. bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { - // TODO(fengliuai): uses a typed indicator instead of a bool value. + // TODO: b/323478683 - Use a typed indicator instead of a bool value. bool changed = false; while (!work_list_.empty()) { Operation* op = work_list_.back(); work_list_.pop_back(); // This op has been quantized, so we should not consider it again. - if (llvm::is_contained(quantized_, op)) continue; + if (quantized_.contains(op)) continue; quantized_.insert(op); - if (auto cst = llvm::dyn_cast(op)) { + if (auto constant_op = dyn_cast(op); constant_op) { // If the workflow requires inferring ranges from the content // (post-training quantization) and it is weight (filter) and hasn't // been quantized, we infer the quantization parameters from the content. - if (infer_tensor_range_ && IsWeight(cst) && !IsQuantized(op)) { + if (infer_tensor_range_ && IsWeight(constant_op) && !IsQuantized(op)) { // The quantization parameters are determined by the content of the // constant. changed |= SetConstantResultParams(op); @@ -761,7 +775,7 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { std::unique_ptr scale_spec = GetQuantScaleSpec(op); if (scale_spec->has_same_scale_requirement) { - const auto params = GetQuantParamsForSameScaleConstraint(op); + const QuantizedType params = GetQuantParamsForSameScaleConstraint(op); // The quantization parameters haven't been propagated to any operands // or results. Skip this node for now. if (!params) { @@ -792,12 +806,13 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { } // Use the final state to set all the results' parameters. - for (int res = 0; res < op->getNumResults(); ++res) - if (auto type = op->getResult(res).getType().dyn_cast()) { + for (int i = 0; i < op->getNumResults(); ++i) + if (auto type = op->getResult(i).getType().dyn_cast(); + type != nullptr) { // Without this check, it will accidentally propagate the quantization // information by the shared non-float-tensors. 
if (type.getElementType().isa()) - changed |= SetResultParams(op, res, params); + changed |= SetResultParams(op, i, params); } } @@ -807,8 +822,8 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { !is_qdq_conversion_) { // Infer ranges from the activation ops. This is usually required for // the post-training quantization workflow. - // TODO(fengliuai): different result can have different fixed range. - const auto params = + // TODO: b/323478683 - Different result can have different fixed range. + const QuantizedType params = scale_spec->fixed_output_range_func(is_signed_, bit_width_); for (auto i = 0; i < op->getNumResults(); ++i) { // The range is null if the result has been quantized. @@ -818,16 +833,20 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { } } - const auto spec = GetQuantSpec(op); - for (auto& it : spec->biases_params) { - const auto params = - GetBiasParams(op, it.first, it.second.first, it.second.second); + const std::unique_ptr spec = GetQuantSpec(op); + for (const auto& [bias_operand_idx, non_bias_params] : + spec->biases_params) { + const auto& [non_bias_operand_indices, accumulator_scale_func] = + non_bias_params; + const QuantizedType params = + GetBiasParams(op, bias_operand_idx, non_bias_operand_indices, + accumulator_scale_func); if (!params) { quantized_.erase(op); continue; } - changed |= - SetBiasParamsWithAdjustments(op, it.first, it.second.first, params); + changed |= SetBiasParamsWithAdjustments(op, bias_operand_idx, + non_bias_operand_indices, params); } } @@ -836,9 +855,9 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { // Finalizes the arguments and result states in the function. void QuantizationDriver::Finalize() { - for (auto arg : args_) { - auto& state = GetArgQuantState(arg); - auto& requantizes = GetArgRequantizeStates(arg); + for (BlockArgument arg : args_) { + const QuantState& state = GetArgQuantState(arg); + RequantizeStates& requantizes = GetArgRequantizeStates(arg); if (state.IsEmpty() || (state.immutable && requantizes.empty())) { continue; } @@ -848,25 +867,24 @@ void QuantizationDriver::Finalize() { } if (!requantizes.empty()) { - RequantizeArg(arg, &requantizes); + RequantizeArg(arg, requantizes); } } - for (auto it : result_states_) { - Operation* op = it.first.first; - const int res_index = it.first.second; - auto& state = GetResultQuantState(op, res_index); - auto& requantizes = GetResultRequantizeStates(op, res_index); + for (const auto& [op_with_result_idx, quant_state_idx] : result_states_) { + const auto [op, result_idx] = op_with_result_idx; + const QuantState& state = GetResultQuantState(op, result_idx); + RequantizeStates& requantizes = GetResultRequantizeStates(op, result_idx); if (state.IsEmpty() || (state.immutable && requantizes.empty())) { continue; } if (!state.immutable) { - QuantizeOpResult(op, res_index, state.params); + QuantizeOpResult(op, result_idx, state.params); } if (!requantizes.empty()) { - RequantizeOpResult(op, res_index, &requantizes); + RequantizeOpResult(op, result_idx, requantizes); } } } @@ -885,7 +903,7 @@ void QuantizationDriver::Run() { } void ApplyQuantizationParamsPropagation( - const mlir::func::FuncOp func, const bool is_signed, const int bit_width, + const func::FuncOp func, const bool is_signed, const int bit_width, const bool disable_per_channel, const OpQuantSpecGetter op_quant_spec_getter, const bool infer_tensor_ranges, const bool legacy_float_scale, @@ -897,7 +915,7 @@ void ApplyQuantizationParamsPropagation( } void 
ApplyQuantizationParamsPropagation( - const mlir::func::FuncOp func, const bool is_signed, const int bit_width, + const func::FuncOp func, const bool is_signed, const int bit_width, const bool disable_per_channel, const OpQuantSpecGetter op_quant_spec_getter, const OpQuantScaleSpecGetter op_quant_scale_spec_getter, diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h index 59741f48307a16..d054e9ed738ce0 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h @@ -17,14 +17,13 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_DRIVER_H_ #include -#include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project @@ -40,20 +39,16 @@ limitations under the License. namespace mlir { namespace quant { -static bool HasQuantParams(QuantParams p) { - return p == quant::QuantizedType(); -} - // The state for each op result during the quantization parameters propagation. struct QuantState { // Quantization parameters propagated to an op result. - QuantParams params; + QuantizedType params; // A flag indicates this state (the params) shouldn't be changed after it is // initialized. This flag will be set to true if the quantization parameters // are from the quantization-aware training. const bool immutable; - bool IsEmpty() { return HasQuantParams(params); } + bool IsEmpty() const { return params == nullptr; } }; // The state for rescaling the propagated quantization parameters. This can be @@ -70,7 +65,7 @@ struct RequantizeState { } pos = NO_REQUANTIZE; // Quantization parameters will be used to add the requantize ops. - QuantParams params; + QuantizedType params; // Avoid clobbering all uses of the value, limit to just these ops. SmallVector> users; @@ -99,15 +94,25 @@ using RequantizeStates = SmallVector; // class QuantizationDriver { public: - explicit QuantizationDriver(func::FuncOp fn, bool is_signed, int bit_width, - bool disable_per_channel, + // Type alias of int used to access `states_`. + using QuantStateIndex = int; + + // (op, operand index) pair. + using OpWithOperandIndex = std::pair; + + // (op, result index) pair. + using OpWithResultIndex = std::pair; + + explicit QuantizationDriver(func::FuncOp func_op, const bool is_signed, + const int bit_width, + const bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, OpQuantScaleSpecGetter op_quant_scale_spec_getter, - bool infer_tensor_range, - bool legacy_float_scale = false, - bool is_qdq_conversion = false) - : fn_(fn), - builder_(fn.getBody()), + const bool infer_tensor_range, + const bool legacy_float_scale = false, + const bool is_qdq_conversion = false) + : fn_(func_op), + builder_(func_op.getBody()), is_signed_(is_signed), bit_width_(bit_width), disable_per_channel_(disable_per_channel), @@ -130,7 +135,7 @@ class QuantizationDriver { // result. 
void Finalize(); - llvm::SmallVector GetArgs() { return args_; } + SmallVector GetArgs() { return args_; } // Returns the state of the block argument. QuantState& GetArgQuantState(BlockArgument arg) { @@ -138,10 +143,6 @@ class QuantizationDriver { } private: - // This is used to identify an operand or result of an op. The second element - // of this pair is the index of the operand or result. - using OpValue = std::pair; - // Duplicates the constant op if it has multiple uses, and replaces // target_op->operand[operand_index] with the newly created op. This also // replaces corresponsing quantization states. @@ -153,13 +154,13 @@ class QuantizationDriver { // prevent overflow of quantized bias values. This also changes quantization // state of other inputs when needed. bool SetBiasParamsWithAdjustments(Operation* op, int bias_index, - const std::vector& input_indices, - QuantParams params); + ArrayRef input_indices, + QuantizedType params); // Checks preconditions to adjust bias scale. bool ShouldCheckBiasScale(Operation* op, int bias_index, - const std::vector& input_indices, - QuantParams params, int& input_index, + ArrayRef input_indices, + QuantizedType quantized_type, int& input_index, int& filter_index); // Preprocesses the constants by doing the following: @@ -187,84 +188,92 @@ class QuantizationDriver { bool IsQuantized(Operation* op); // Adds all the users of index-th result of op to the work list. - void AddUserToList(Operation* op, int index) { + void AddUserToList(Operation* op, const int index) { for (Operation* user : op->getResult(index).getUsers()) { work_list_.push_back(user); } } // Adds the defining op of index-th operand of op to the work list. - void AddOperandToList(Operation* op, int index) { - if (Operation* inst = op->getOperand(index).getDefiningOp()) { - work_list_.push_back(inst); + void AddOperandToList(Operation* op, const int index) { + if (Operation* operand_op = op->getOperand(index).getDefiningOp(); + operand_op != nullptr) { + work_list_.push_back(operand_op); } } // Returns the quantization params for the bias input from the non-bias // operands which have their indexes in the `non_biases` vector. The returned // parameters are calculated by `func`. - QuantParams GetBiasParams(Operation* op, int bias_index, - const std::vector& non_biases, - AccumulatorScaleFunc func); - - // Sets the quantization parameters of the result to a fixed value. If any - // quantization parameters have been propagated, a `requantize` will happen on - // the input of propagated quantization. - bool SetResultParams(Operation* op, int index, QuantParams params); - - // Sets the quantization parameters of the operand to a fixed value. If any + QuantizedType GetBiasParams(Operation* op, int bias_index, + ArrayRef non_bias_operand_indices, + AccumulatorScaleFunc func); + + // Sets the quantization parameters of the result to `quantized_type`. If + // any quantization parameters have been propagated, a requantize will + // happen on the input of propagated quantization. Returns `true` if internal + // state has been modified. + bool SetResultParams(Operation* op, int result_index, + QuantizedType quantized_type); + + // Sets the quantization parameters of the operand to `quantized_type`. If any // quantization parameters have been propagated, a `requantize` will happen on // the output of propagated quantization. When `override` is set, quantization - // state of the value is replaced instead of adding requantization. 
- bool SetOperandParams(Operation* op, int index, QuantParams params, - bool override = false); + // state of the value is replaced instead of adding requantization. Returns + // `true` if internal state has been modified. + bool SetOperandParams(Operation* op, int operand_index, + QuantizedType quantized_type, bool override = false); // Sets the quantization parameters of the constant result according to its // content. bool SetConstantResultParams(Operation* op); - // Inserts the Quantize and Dequantize ops for quantizing the index-th result - // of the op. - void QuantizeOpResult(Operation* op, int index, QuantParams params); + // Inserts the Quantize and Dequantize ops after `op`'s `index`-th result. The + // quantized element type for the result is `quantized_type`. + void QuantizeOpResult(Operation* op, int result_index, + QuantizedType quantized_type); - void QuantizeArg(BlockArgument arg, QuantParams params); + // Inserts the Quantize and Dequantize ops after `arg`. The quantized element + // type for `arg` is `quantized_type`. + void QuantizeArg(BlockArgument arg, QuantizedType quantized_type); - // Inserts the Quantize and Dequantize ops to quantize the value and returns - // the Quantize op. - void QuantizeValue(Value value, QuantParams params, Location loc); + // Inserts the Quantize and Dequantize ops (i.e. QDQ) after `value`. The + // quantized element type for `value` is `quantized_type`. + void QuantizeValue(Value value, QuantizedType quantized_type, Location loc); // Inserts the Quantize ops for requantizing the index-th result of the op. - void RequantizeOpResult(Operation* op, int index, RequantizeStates* states); + void RequantizeOpResult(Operation* op, int result_index, + RequantizeStates& states); // Inserts the Quantize ops for requantizing a block argument. - void RequantizeArg(BlockArgument arg, RequantizeStates* states); + void RequantizeArg(BlockArgument arg, RequantizeStates& states); // Inserts the Quantize and Dequantize ops to quantize the value and returns // the Quantize op. - void RequantizeValue(Value value, RequantizeStates* states, Location loc); + void RequantizeValue(Value value, RequantizeStates& states, Location loc); // Returns the quantization parameter satisfies the same scale // constraints for the op. Returns an empty option if this quantization // parameter doesn't exist. - QuantParams GetQuantParamsForSameScaleConstraint(Operation* op); + QuantizedType GetQuantParamsForSameScaleConstraint(Operation* op); // Returns the state of the index-th operand of the op. - QuantState& GetOperandQuantState(Operation* op, int index) { + QuantState& GetOperandQuantState(Operation* op, const int index) { return states_[operand_states_[{op, index}]]; } // Returns the state of the index-th result of the op. - QuantState& GetResultQuantState(Operation* op, int index) { + QuantState& GetResultQuantState(Operation* op, const int index) { return states_[result_states_[{op, index}]]; } // Returns the states of the index-th operand of the op. - RequantizeStates& GetOperandRequantizeStates(Operation* op, int index) { + RequantizeStates& GetOperandRequantizeStates(Operation* op, const int index) { return rescale_states_[operand_states_[{op, index}]]; } // Returns the states of the index-th result of the op. 
- RequantizeStates& GetResultRequantizeStates(Operation* op, int index) { + RequantizeStates& GetResultRequantizeStates(Operation* op, const int index) { return rescale_states_[result_states_[{op, index}]]; } @@ -278,10 +287,6 @@ class QuantizationDriver { // a new entry in the state vector. void InitializeArgState(BlockArgument arg, Value arg_value); - // Sets the state of index-th operand / result of op. - void InitializeStateForValue(Operation* op, int index, Value value, - bool as_result); - // Sets the state of the index-th operand of the op. If this operand is // cached, uses the cached result without creating new entry in the state // vector. Otherwise, allocate a new entry in the state vector. @@ -301,12 +306,13 @@ class QuantizationDriver { // We should distinguish weights and bias constants. Biases are specified by // the quantization spec or are the operands of ops with same scale spec. The // rest are weights. - llvm::DenseSet weights_; + DenseSet weights_; // The weights require narrow_range quantization. This map collects all the - // weight operands defined by the op quant spec. If the value of the entry is - // positive, per-channel quantization is required. - llvm::DenseMap optimized_weights_; + // weight operands defined by the op quant spec. The value of each entry is + // the quantization dimension. If it is positive, per-channel quantization is + // required. + DenseMap optimized_weights_; // All the ops needs to propagate the quantization parameters to. std::vector work_list_; @@ -319,18 +325,18 @@ class QuantizationDriver { // The map contains all the quantization parameters which are required to // satisfy the same operands and results constraint. The keys of this map are // the values from `operand_states_` and `result_state_`. - std::unordered_map rescale_states_; + absl::flat_hash_map rescale_states_; // Maps of indexes to the propagation state vector from the ops operands, // results and arguments. - llvm::DenseMap operand_states_; - llvm::DenseMap result_states_; - llvm::DenseMap arg_states_; - llvm::DenseMap value_to_state_; + DenseMap operand_states_; + DenseMap result_states_; + DenseMap arg_states_; + DenseMap value_to_state_; // This vector is to preserve the arguments order, so the newly inserted // quantized ops for the arguments are deterministically ordered. - llvm::SmallVector args_; + SmallVector args_; OpQuantSpecGetter op_quant_spec_getter_; OpQuantScaleSpecGetter op_quant_scale_spec_getter_; @@ -357,7 +363,7 @@ class QuantizationDriver { // Setting `infer_tensor_range` to true, to infer quantization parameters from // the activation ops and weight constants. This is only used for post-training // quantization. 
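// A rough usage sketch of the eight-argument overload below; `GetOpQuantSpec`
// stands in for a caller-provided OpQuantSpecGetter and is not defined here:
//
//   ApplyQuantizationParamsPropagation(
//       func, /*is_signed=*/true, /*bit_width=*/8,
//       /*disable_per_channel=*/false, GetOpQuantSpec,
//       /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false,
//       /*is_qdq_conversion=*/false);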
-void ApplyQuantizationParamsPropagation(mlir::func::FuncOp func, bool is_signed, +void ApplyQuantizationParamsPropagation(func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, bool infer_tensor_ranges, @@ -365,8 +371,8 @@ void ApplyQuantizationParamsPropagation(mlir::func::FuncOp func, bool is_signed, bool is_qdq_conversion); void ApplyQuantizationParamsPropagation( - mlir::func::FuncOp func, bool is_signed, int bit_width, - bool disable_per_channel, OpQuantSpecGetter op_quant_spec_getter, + func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges, bool legacy_float_scale, bool is_qdq_conversion); diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h index d95ba49cf8e800..88017117098aca 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h @@ -26,10 +26,10 @@ limitations under the License. #include #include #include -#include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "llvm/ADT/DenseMap.h" @@ -86,11 +86,11 @@ inline constexpr double kNearZeroTolerance = 1.0e-6; using QuantParams = QuantizedType; using QuantSpec = QuantizationSpecs; using SignedInteger = std::pair; // bitwidth and sign -using QuantParamsForResults = llvm::SmallVector; +using QuantParamsForResults = llvm::SmallVector; using AccumulatorScaleFunc = - std::function&, int, bool)>; + std::function&, int, bool)>; using BiasParamsMap = - std::unordered_map, AccumulatorScaleFunc>>; + absl::flat_hash_map, AccumulatorScaleFunc>>; // UniformQuantizedType GetFixedOutputRange(bool sign, int bit_width) using GetFixedOutputRangeFunc = std::function; // bool RequiredSameOperandsAndResultsScale(bool sign, int $bit_width) From 6c62a390ba07f071e52f4c727a62446388e26044 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Mar 2024 15:55:19 -0700 Subject: [PATCH 047/670] 1. Fix flaky test after recently enabling the modelling of resharding memory costs by default in auto-sharding. 2. Also check shapes of parameters in the test instead of sharding annotations as sharding annotations may not be preserved through the compilation. 
PiperOrigin-RevId: 616971712 --- third_party/xla/xla/service/gpu/BUILD | 1 + .../gpu/auto_sharding_gpu_compiler_test.cc | 37 ++++++++++++++++--- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index b94c1f10f9ce84..dd3fc565d28481 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3929,6 +3929,7 @@ xla_cc_test( srcs = ["auto_sharding_gpu_compiler_test.cc"], tags = tf_cuda_tests_tags() + ["no_oss"], # TODO(b/277355322): Make autosharding work in OSS deps = [ + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:gpu_plugin", "//xla/service:hlo_module_config", diff --git a/third_party/xla/xla/service/gpu/auto_sharding_gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/auto_sharding_gpu_compiler_test.cc index 06928aa44a08b1..eab4b0d48e5dbb 100644 --- a/third_party/xla/xla/service/gpu/auto_sharding_gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/auto_sharding_gpu_compiler_test.cc @@ -17,10 +17,12 @@ limitations under the License. #include #include +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/logging.h" @@ -30,6 +32,8 @@ namespace { namespace m = ::xla::match; +using ::testing::Conditional; + class AutoShardingTest : public HloTestBase { protected: const char* const dot_hlo_string_ = R"( @@ -60,14 +64,35 @@ ENTRY matmul { }; TEST_F(AutoShardingTest, MatMulWithAutosharding) { - auto compiled_module = CompileMatMul(true, 4); - auto* instruction = + std::unique_ptr compiled_module = CompileMatMul(true, 4); + const HloInstruction* parameter1 = compiled_module->entry_computation()->parameter_instruction(0); - VLOG(2) << instruction->ToString(); + const HloInstruction* parameter2 = + compiled_module->entry_computation()->parameter_instruction(1); + bool is_parameter1_replicated = ShapeUtil::Equal( + parameter1->shape(), ShapeUtil::MakeShape(PrimitiveType::F32, {32, 64})); + bool is_parameter2_replicated = ShapeUtil::Equal( + parameter2->shape(), ShapeUtil::MakeShape(PrimitiveType::F32, {64, 128})); + + // Check that at least one of the parameters is sharded, thereby telling us + // that the dot is as well. + VLOG(2) << parameter1->ToString(); + EXPECT_THAT( + parameter1, + Conditional( + is_parameter2_replicated, + AnyOf(GmockMatch(m::Op().WithShape(PrimitiveType::F32, {8, 64})), + GmockMatch(m::Op().WithShape(PrimitiveType::F32, {32, 16}))), + GmockMatch(m::Op().WithShape(PrimitiveType::F32, {32, 64})))); + + VLOG(2) << parameter2->ToString(); EXPECT_THAT( - instruction, - AnyOf(GmockMatch(m::Op().WithSharding("{devices=[1,4]0,1,2,3}")), - GmockMatch(m::Op().WithSharding("{devices=[4,1]0,1,2,3}")))); + parameter2, + Conditional( + is_parameter1_replicated, + AnyOf(GmockMatch(m::Op().WithShape(PrimitiveType::F32, {16, 128})), + GmockMatch(m::Op().WithShape(PrimitiveType::F32, {64, 32}))), + GmockMatch(m::Op().WithShape(PrimitiveType::F32, {64, 128})))); } TEST_F(AutoShardingTest, MatMulWithoutAutosharding) { From 400c2b04ab3360d9bfdc5f9eada6a45cd827e8f0 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 18 Mar 2024 15:56:28 -0700 Subject: [PATCH 048/670] Rename `fold_constant_transpose_pass.cc`->`fold_constant_transpose.cc`. Conventionally the `_pass` prefix isn't used. 
PiperOrigin-RevId: 616971947 --- tensorflow/compiler/mlir/quantization/stablehlo/BUILD | 2 +- ...ld_constant_transpose_pass.cc => fold_constant_transpose.cc} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename tensorflow/compiler/mlir/quantization/stablehlo/passes/{fold_constant_transpose_pass.cc => fold_constant_transpose.cc} (100%) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 99c93739949a9c..11b100be5601c1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -48,7 +48,7 @@ cc_library( srcs = [ "passes/convert_func_to_bfloat16.cc", "passes/convert_xla_call_module_op_to_bfloat16.cc", - "passes/fold_constant_transpose_pass.cc", + "passes/fold_constant_transpose.cc", "passes/lift_quantizable_spots_as_functions.cc", "passes/lift_quantizable_spots_as_functions_fusion.inc", "passes/lift_quantizable_spots_as_functions_simple.inc", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose_pass.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc similarity index 100% rename from tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose_pass.cc rename to tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc From 28c0ea47129c47ee10dcf7ee3a47e8a25a2eca42 Mon Sep 17 00:00:00 2001 From: Juhyun Lee Date: Mon, 18 Mar 2024 16:01:41 -0700 Subject: [PATCH 049/670] Add tensor shape check for ADD & MUL. PiperOrigin-RevId: 616973218 --- .../delegates/gpu/common/model_builder.cc | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 548cbcba1afc80..10e3efd3bbbb5a 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -1143,19 +1143,32 @@ class ElementwiseOperationParser : public TFLiteOperationParser { int input_tensor1 = 1; if (operation_type_ == OperationType::MUL || operation_type_ == OperationType::ADD) { - // The "larger" input tensor must be bound to 1st input and the - // "smaller" input tensor must be bound to 2nd input. + // The "larger" input tensor MUST be the 1st argument, and the + // "smaller" input tensor must be the 2nd. BHWC shape0; RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0)); BHWC shape1; RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1)); + if (shape0.b != shape1.b) { + return absl::InvalidArgumentError(absl::StrCat( + "Tensor shape (b) mismatch: ", shape0.b, " vs ", shape1.b)); + } else if (shape0.c != shape1.c) { + return absl::InvalidArgumentError(absl::StrCat( + "Tensor shape (c) mismatch: ", shape0.c, " vs ", shape1.c)); + } else if (!(shape0.h <= shape1.h && shape0.w <= shape1.w) && + !(shape0.h >= shape1.h && shape0.w >= shape1.w)) { + // One input tensor must be consistently larger (or smaller) than or + // as same shaped as the other input tensor in both dimensions. 
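+          // For example, HW shapes (8, 8) vs (1, 1) and (8, 8) vs (8, 8) are
+          // accepted, while (8, 4) vs (4, 8) is rejected: neither tensor is
+          // at least as large as the other in both dimensions.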
+ return absl::InvalidArgumentError(absl::StrCat( + "Tensor shape (h, w) mismatch: (", shape0.h, ", ", shape0.w, + ") vs (", shape1.h, ", ", shape1.w, ")")); + } if (shape0.h <= shape1.h && shape0.w <= shape1.w && shape0.c == shape1.c) { input_tensor0 = 1; input_tensor1 = 0; } } - RETURN_IF_ERROR(reader->AddInput(node, input_tensor0)); RETURN_IF_ERROR(reader->AddInput(node, input_tensor1)); } From 1b85215e88f7254b50ed0552d41195d6ebd6e122 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 18 Mar 2024 16:33:18 -0700 Subject: [PATCH 050/670] Cleanup: `lift_as_function_call.h/cc`. * Remove unnecessary `llvm::` and `mlir::`. * Use `ArrayRef`s where applicable. * Use pass-by-reference where applicable (e.g. OpBuilder). * Use pass-by-value where applicable (e.g. Value). PiperOrigin-RevId: 616981588 --- .../common/lift_as_function_call.cc | 118 +++++++++--------- .../common/lift_as_function_call.h | 27 ++-- .../common/lift_as_function_call_test.cc | 29 ++--- 3 files changed, 89 insertions(+), 85 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc index 9c700ed50bc4d0..86ba98a7ee1139 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc @@ -69,12 +69,10 @@ constexpr int64_t kDefaultVersion = 9; constexpr StringRef kPlatformCpu = "CPU"; // Name of `tf.XlaCallModule`'s dictionary attribute for keeping the // deserialized stablehlo module's attributes. -constexpr llvm::StringRef kStablehloModuleAttrsAttrName = - "_stablehlo_module_attrs"; +constexpr StringRef kStablehloModuleAttrsAttrName = "_stablehlo_module_attrs"; // Attribute required for running shape refinement pass enabled in XlaCallModule // version 8 and above. -constexpr llvm::StringRef kUsesShapePolymorphismAttr = - "jax.uses_shape_polymorphism"; +constexpr StringRef kUsesShapePolymorphismAttr = "jax.uses_shape_polymorphism"; // Checks if the op is inside a lifted function. bool IsInLiftedFunc(Operation& op) { @@ -83,16 +81,16 @@ bool IsInLiftedFunc(Operation& op) { // Inserts the function to the symbol table of the module thread-safely. StringAttr InsertToSymbolTable(Operation& module, Operation& function, - const std::string& func_name) { + const StringRef func_name) { static tensorflow::mutex* mtx = new tensorflow::mutex(); tensorflow::mutex_lock lock(*mtx); SymbolTable symbol_table(&module); - std::string unique_name = func_name; + std::string unique_name = func_name.str(); int32_t uniquing_counter = 0; while (symbol_table.lookup(unique_name) != nullptr) { ++uniquing_counter; - unique_name = func_name + "_" + std::to_string(uniquing_counter); + unique_name = absl::StrCat(func_name.str(), "_", uniquing_counter); } function.setAttr("sym_name", StringAttr::get(module.getContext(), unique_name)); @@ -101,9 +99,11 @@ StringAttr InsertToSymbolTable(Operation& module, Operation& function, // Creates the TF::PartitionedCallOp with the given arguments and output types. // This function call op is for invoking the TF subgraphs. 
-ValueRange createTFPartitionedCallOp(OpBuilder builder, Location location, - StringRef func_name, - TypeRange output_types, ValueRange args) { +ValueRange CreateTFPartitionedCallOp(OpBuilder& builder, + const Location location, + const StringRef func_name, + const TypeRange output_types, + const ValueRange args) { TF::PartitionedCallOp call_op = builder.create( location, output_types, args, FlatSymbolRefAttr::get(builder.getStringAttr(func_name)), @@ -112,7 +112,7 @@ ValueRange createTFPartitionedCallOp(OpBuilder builder, Location location, // Set the attribute to annotate this function call op as a quantizable spot. call_op->setAttr( kQuantTraitAttrName, - builder.getStringAttr(llvm::StringRef( + builder.getStringAttr(StringRef( std::string(QuantTraitValues[QuantizationTrait::FullyQuantizable])))); return call_op.getOutput(); @@ -120,10 +120,11 @@ ValueRange createTFPartitionedCallOp(OpBuilder builder, Location location, // Creates the TF::XlaCallModuleOp with the given arguments and output types. // This function call op is for invoking the StableHLO subgraphs. -ValueRange createTFXlaCallModuleOp(OpBuilder builder, Location location, - StringRef func_name, TypeRange output_types, - ValueRange args) { - auto ctx = builder.getContext(); +ValueRange CreateTFXlaCallModuleOp(OpBuilder& builder, const Location location, + const StringRef func_name, + const TypeRange output_types, + const ValueRange args) { + MLIRContext* ctx = builder.getContext(); // Collect the shapes of the output to fill up the Sout attribute. SmallVector shape_attrs; for (const Type result_type : output_types) { @@ -133,7 +134,7 @@ ValueRange createTFXlaCallModuleOp(OpBuilder builder, Location location, auto empty_array_attr = ArrayAttr::get(ctx, {}); auto platforms = ArrayAttr::get(ctx, {StringAttr::get(ctx, kPlatformCpu)}); - TF::XlaCallModuleOp call_op = builder.create( + auto call_op = builder.create( location, /*output=*/output_types, /*args=*/args, @@ -159,7 +160,7 @@ ValueRange createTFXlaCallModuleOp(OpBuilder builder, Location location, // Set the attribute to annotate this function call op as a quantizable spot. call_op->setAttr( kQuantTraitAttrName, - builder.getStringAttr(llvm::StringRef( + builder.getStringAttr(StringRef( std::string(QuantTraitValues[QuantizationTrait::FullyQuantizable])))); // Set jax.uses_shape_polymorphism=true to enable shape refinement at runtime. @@ -172,27 +173,25 @@ ValueRange createTFXlaCallModuleOp(OpBuilder builder, Location location, } // Creates the function call op based on the given call_op_type argument. -ValueRange createFunctionCallOp(OpBuilder builder, Location location, - FunctionCallOpType call_op_type, - StringRef func_name, TypeRange output_types, - ValueRange args) { +ValueRange CreateFunctionCallOp(OpBuilder& builder, const Location location, + const FunctionCallOpType call_op_type, + const StringRef func_name, + const TypeRange output_types, + const ValueRange args) { switch (call_op_type) { case FunctionCallOpType::TFXlaCallModuleOp: - return createTFXlaCallModuleOp(builder, location, func_name, output_types, + return CreateTFXlaCallModuleOp(builder, location, func_name, output_types, args); case FunctionCallOpType::TFPartitionedCallOp: - return createTFPartitionedCallOp(builder, location, func_name, + return CreateTFPartitionedCallOp(builder, location, func_name, output_types, args); - default: - llvm_unreachable("unhandled call op type"); } } // Finds ops in the paths from arguments to results. 
The ops is listed in an // order that the former ops shouldn't have any dependencies on the later ones. -llvm::SmallVector FindOpsFromArgumentsToResults( - const llvm::SmallVector& arguments, - const llvm::SmallVector& results) { +SmallVector FindOpsFromArgumentsToResults( + const ArrayRef arguments, const ArrayRef results) { std::queue value_queue; for (Value result : results) { value_queue.push(result); @@ -213,7 +212,7 @@ llvm::SmallVector FindOpsFromArgumentsToResults( Operation* defining_node = current_value.getDefiningOp(); if (defining_node == nullptr) continue; op_stack.push(defining_node); - for (const auto& arg : defining_node->getOperands()) { + for (Value arg : defining_node->getOperands()) { if (!argument_set.contains(arg.getImpl())) { value_queue.push(arg); } @@ -221,7 +220,7 @@ llvm::SmallVector FindOpsFromArgumentsToResults( } // Remove duplicate ops from the op stack. - llvm::SmallVector sorted_ops; + SmallVector sorted_ops; absl::flat_hash_set unique_ops; while (!op_stack.empty()) { Operation* current_op = op_stack.top(); @@ -243,9 +242,9 @@ llvm::SmallVector FindOpsFromArgumentsToResults( // "0:transpose_a,1:transpose_b", where 0 and 1 are the respective attribute // identifiers. // This function returns success if all attributes could be found. -LogicalResult SetAttributeMap( - MLIRContext& context, const llvm::SmallVector& attributes, - const llvm::SmallVector& ops) { +LogicalResult SetAttributeMap(MLIRContext& context, + const ArrayRef attributes, + const ArrayRef ops) { // A map to find which operation an attribute belongs to. // The key for this map uses the entire NamedAttribute object, i.e. the // {attribute_name, attribute_value} pair. @@ -270,8 +269,8 @@ LogicalResult SetAttributeMap( attr_to_op_map.begin(), attr_to_op_map.end(), [&](auto attr_op) { return std::get<0>(attr_op).getName() == attribute.getName(); }) == attr_to_op_map.end()) { - mlir::emitError(UnknownLoc::get(&context), - "Could not find attribute: " + attribute.getName().str()); + emitError(UnknownLoc::get(&context), + "Could not find attribute: " + attribute.getName().str()); return failure(); } @@ -293,7 +292,7 @@ LogicalResult SetAttributeMap( // Append ":". Ex) "0:transpose_a". const std::string identifier = std::to_string(idx); - const mlir::StringAttr attribute_name = attribute.getName(); + const StringAttr attribute_name = attribute.getName(); absl::StrAppend(&new_attr_map_str, identifier, ":", attribute_name.str()); owner_op->setAttr(kAttrMapAttribute, StringAttr::get(&context, new_attr_map_str)); @@ -303,14 +302,14 @@ LogicalResult SetAttributeMap( } // Creates a function to wrap the section between arguments and results. 
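// Roughly: the ops found between `arguments` and `results` are cloned into a
// new private func.func named `func_name` (uniqued in the module's symbol
// table), and the results of a call op built by CreateFunctionCallOp (a
// TF::XlaCallModuleOp or TF::PartitionedCallOp, depending on `call_op_type`)
// are returned so the caller can substitute them for the original values.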
-llvm::SmallVector LiftAsFunctionCall( - OpBuilder builder, Location location, FunctionCallOpType call_op_type, - StringRef func_name, const llvm::SmallVector& arguments, - const llvm::SmallVector& results, - const llvm::SmallVector& attributes) { +SmallVector LiftAsFunctionCall( + OpBuilder& builder, const Location location, + const FunctionCallOpType call_op_type, const StringRef func_name, + const ArrayRef arguments, const ArrayRef results, + const ArrayRef attributes) { MLIRContext* context = builder.getContext(); if (results.empty()) { - mlir::emitError(UnknownLoc::get(context), "No result values specified"); + emitError(UnknownLoc::get(context), "No result values specified"); return {}; } Operation* result_op = results[0].getDefiningOp(); @@ -324,10 +323,11 @@ llvm::SmallVector LiftAsFunctionCall( TypeRange result_types{ValueRange{results}}; auto func_type = FunctionType::get(context, arg_types, result_types); - llvm::SmallVector arg_locs; - for (const auto& arg : arguments) { + SmallVector arg_locs; + for (Value arg : arguments) { arg_locs.push_back(arg.getLoc()); } + auto wrap_func = builder.create(location, func_name, func_type); wrap_func.setVisibility(SymbolTable::Visibility::Private); // The callee function for TF::XlaCallModuleOp must have this attribute. @@ -361,34 +361,36 @@ llvm::SmallVector LiftAsFunctionCall( builder.clone(*op, mapping); } - llvm::SmallVector return_values; + SmallVector return_values; for (Value result : results) { return_values.push_back(mapping.lookupOrNull(result)); } - builder.create(location, return_values); + builder.create(location, return_values); // Create a function call to the newly created function. StringAttr new_func_name = - InsertToSymbolTable(*module, *wrap_func, func_name.str()); + InsertToSymbolTable(*module, *wrap_func, func_name); builder.setInsertionPointAfter(result_op); ValueRange new_results = - createFunctionCallOp(builder, call_op_loc, call_op_type, + CreateFunctionCallOp(builder, call_op_loc, call_op_type, new_func_name.getValue(), result_types, arguments); - return llvm::SmallVector(new_results.begin(), new_results.end()); + return SmallVector(new_results.begin(), new_results.end()); } -llvm::SmallVector LiftAsFunctionCall( - OpBuilder builder, Location location, FunctionCallOpType call_op_type, - StringRef func_name, const llvm::SmallVector& arguments, - const llvm::SmallVector& results) { - llvm::SmallVector attributes; +SmallVector LiftAsFunctionCall(OpBuilder& builder, + const Location location, + const FunctionCallOpType call_op_type, + const StringRef func_name, + const ArrayRef arguments, + const ArrayRef results) { + SmallVector attributes; return LiftAsFunctionCall(builder, location, call_op_type, func_name, arguments, results, attributes); } -llvm::SmallVector AppendToVector( - const llvm::SmallVector& arguments, Value append) { - llvm::SmallVector ret(arguments); +SmallVector AppendToVector(const ArrayRef arguments, + Value append) { + SmallVector ret(arguments); ret.push_back(append); return ret; } @@ -402,7 +404,7 @@ llvm::SmallVector AppendToVector( // could process the following equation by setting the attributes properly: // abc,cd->abd. // 4. 
The output should be in the form: [batch dims][lhs dims][rhs dims] -bool IsEinsumSupportedByXlaDotV2(mlir::StringAttr equation_attr) { +bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr) { StringRef equation = equation_attr.getValue(); if (!absl::StrContains(equation, "->") || !absl::StrContains(equation, ",") || diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h index f2edd732f50cc5..db86b56734ab99 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h @@ -48,10 +48,10 @@ inline constexpr StringRef kQuantizationMethodAttr = "_quantization_method"; enum FunctionCallOpType { TFPartitionedCallOp = 0, TFXlaCallModuleOp = 1 }; // Checks if the op is inside a lifted function. -bool IsInLiftedFunc(Operation &op); +bool IsInLiftedFunc(Operation& op); // Checks if the given einsum op is supported for XlaDotV2 quantization. -bool IsEinsumSupportedByXlaDotV2(mlir::StringAttr equation_attr); +bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr); // Gets the quantization method from the given `XlaCallModuleOp`. It is // retrieved from the `kQuantizationMethodAttr` string attribute. Returns @@ -64,23 +64,24 @@ absl::StatusOr<::stablehlo::quantization::Method> GetQuantizationMethod( // The generated function call op type will be decided by the given call_op_type // argument. Currently, it supports TF::XlaCallModuleOp and // TF::PartitionedCallOp function call op generations. -llvm::SmallVector LiftAsFunctionCall( - OpBuilder builder, Location location, FunctionCallOpType call_op_type, - StringRef func_name, const llvm::SmallVector &arguments, - const llvm::SmallVector &results, - const llvm::SmallVector &attributes); +SmallVector LiftAsFunctionCall(OpBuilder& builder, Location location, + FunctionCallOpType call_op_type, + StringRef func_name, + ArrayRef arguments, + ArrayRef results, + ArrayRef attributes); // Same as above but with empty attributes. -llvm::SmallVector LiftAsFunctionCall( - OpBuilder builder, Location location, FunctionCallOpType call_op_type, - StringRef func_name, const llvm::SmallVector &arguments, - const llvm::SmallVector &results); +SmallVector LiftAsFunctionCall(OpBuilder& builder, Location location, + FunctionCallOpType call_op_type, + StringRef func_name, + ArrayRef arguments, + ArrayRef results); // Add the second argument to the first argument, which is expected to be an // argument list. // Used to attach bias to einsum argument list. 
-llvm::SmallVector AppendToVector( - const llvm::SmallVector &arguments, Value append); +SmallVector AppendToVector(ArrayRef arguments, Value append); } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc index 3d1285928f5f18..30c1a342f8d4d5 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc @@ -46,7 +46,7 @@ using ::testing::NotNull; using ::tsl::testing::IsOk; using ::tsl::testing::StatusIs; -using LiftAsFunctionCallTest = ::mlir::quant::QuantizationTestBase; +using LiftAsFunctionCallTest = QuantizationTestBase; constexpr absl::string_view kModuleLifted = R"mlir( module { @@ -65,9 +65,8 @@ TEST_F(LiftAsFunctionCallTest, LiftedFunctionSucceeds) { module_op->lookupSymbol("composite_dot_general_fn_1"); ASSERT_THAT(composite_dot_general_fn, NotNull()); - Operation* dot_general_op = - FindOperationOfType( - composite_dot_general_fn); + auto dot_general_op = FindOperationOfType( + composite_dot_general_fn); EXPECT_TRUE(IsInLiftedFunc(*dot_general_op)); } @@ -87,7 +86,7 @@ TEST_F(LiftAsFunctionCallTest, FunctionLiftedAsXlaCallModuleOp) { func::FuncOp main_fn = FindMainFuncOp(*module_op); ASSERT_THAT(main_fn, NotNull()); - Operation* dot_general_op = + auto dot_general_op = FindOperationOfType(main_fn); const SmallVector& attributes = { @@ -97,19 +96,20 @@ TEST_F(LiftAsFunctionCallTest, FunctionLiftedAsXlaCallModuleOp) { 1, mlir::stablehlo::PrecisionAttr::get( ctx_.get(), mlir::stablehlo::Precision::DEFAULT)))), }; + const SmallVector operands(dot_general_op->getOperands()); + const SmallVector results(dot_general_op->getResults()); Operation* lifted_op = LiftAsFunctionCall(builder_, dot_general_op->getLoc(), FunctionCallOpType::TFXlaCallModuleOp, - "composite_dot_general_fn", - dot_general_op->getOperands(), - dot_general_op->getResults(), attributes)[0] + "composite_dot_general_fn", operands, results, + attributes)[0] .getDefiningOp(); const auto entry_function_symbol_ref = lifted_op->getAttrOfType("_entry_function"); SymbolTable symbol_table(*module_op); auto entry_func = dyn_cast_or_null( symbol_table.lookup(entry_function_symbol_ref.getValue())); - Operation* lifted_dot_general_op = + auto lifted_dot_general_op = FindOperationOfType(entry_func); EXPECT_TRUE(isa(lifted_op)); @@ -129,13 +129,14 @@ TEST_F(LiftAsFunctionCallTest, FunctionNoAttrLiftedAsXlaCallModuleOp) { func::FuncOp main_fn = FindMainFuncOp(*module_op); ASSERT_THAT(main_fn, NotNull()); - Operation* dot_general_op = + auto dot_general_op = FindOperationOfType(main_fn); + const SmallVector operands(dot_general_op->getOperands()); + const SmallVector results(dot_general_op->getResults()); Operation* lifted_op = - LiftAsFunctionCall( - builder_, dot_general_op->getLoc(), - FunctionCallOpType::TFXlaCallModuleOp, "composite_dot_general_fn", - dot_general_op->getOperands(), dot_general_op->getResults())[0] + LiftAsFunctionCall(builder_, dot_general_op->getLoc(), + FunctionCallOpType::TFXlaCallModuleOp, + "composite_dot_general_fn", operands, results)[0] .getDefiningOp(); EXPECT_TRUE(isa(lifted_op)); EXPECT_EQ(lifted_op->getAttr("_original_entry_function").cast(), From 4452743c353854631826fd1793753e0a41181926 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Mon, 18 Mar 2024 16:33:26 -0700 Subject: [PATCH 051/670] [xla:gpu] Dynamic offsets must be read one by one When creating address 
computation thunk, we need to load dynamic offsets from device to host one by one, as dynamic-slice and DUS ops have each offset defined by a separate runtime value. PiperOrigin-RevId: 616981622 --- .../gpu/runtime/address_computation_thunk.cc | 86 +++--- .../gpu/runtime/address_computation_thunk.h | 8 +- .../runtime/address_computation_thunk_test.cc | 246 +++++++++++++----- 3 files changed, 235 insertions(+), 105 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc index 8affba065d2d78..28cf9163774ca5 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc @@ -46,11 +46,11 @@ AddressComputationThunk::AddressComputationThunk( ThunkInfo thunk_info, std::unique_ptr embedded_thunk, std::vector> operands, std::vector> results, - std::vector> + std::vector>> operand_offset_buffer_indices, std::vector> operand_orig_shapes, std::vector> operand_sliced_shapes, - std::vector> + std::vector>> result_offset_buffer_indices, std::vector> result_orig_shapes, std::vector> result_sliced_shapes) @@ -79,6 +79,10 @@ absl::Status AddressComputationThunk::Prepare( TF_RET_CHECK(operand_sliced_shapes_[i]->IsArray()); TF_RET_CHECK(operand_orig_shapes_[i].has_value() && operand_orig_shapes_[i]->IsArray()); + TF_RET_CHECK(operand_sliced_shapes_[i]->rank() == + operand_orig_shapes_[i]->rank()); + TF_RET_CHECK(operand_offset_buffer_indices_[i]->size() == + operand_orig_shapes_[i]->rank()); } } @@ -93,6 +97,10 @@ absl::Status AddressComputationThunk::Prepare( TF_RET_CHECK(result_sliced_shapes_[i]->IsArray()); TF_RET_CHECK(result_orig_shapes_[i].has_value() && result_orig_shapes_[i]->IsArray()); + TF_RET_CHECK(result_sliced_shapes_[i]->rank() == + result_orig_shapes_[i]->rank()); + TF_RET_CHECK(result_offset_buffer_indices_[i]->size() == + result_orig_shapes_[i]->rank()); } } @@ -167,32 +175,37 @@ absl::Status AddressComputationThunk::ExecuteOnStream( continue; } - se::DeviceMemoryBase offset_src = - orig_allocations.GetDeviceAddress(*operand_offset_buffer_indices_[i]); - - // Copy the ith offset from device to host. const Shape& src_shape = *operand_orig_shapes_[i]; const Shape& dst_shape = *operand_sliced_shapes_[i]; - int64_t* offset_dst = &operand_offsets_base[i]; - TF_RETURN_IF_ERROR(stream.Memcpy(offset_dst, offset_src, - dst_shape.rank() * sizeof(int64_t))); - - if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { - return absl::InternalError(absl::StrFormat( - "Failed to retrieve all slice offset values on stream %p: %s", - &stream, blocked.message())); + TF_RET_CHECK(IsContiguousSlice(src_shape, dst_shape)); + + std::vector slice_starts; + slice_starts.reserve(dst_shape.rank()); + + // Get offset for ith operand, which has `dst_shape.rank()` components. + for (auto [idx, offset_slice] : + llvm::enumerate(*operand_offset_buffer_indices_[i])) { + se::DeviceMemoryBase offset_src = + orig_allocations.GetDeviceAddress(offset_slice); + int64_t* offset_dst = &operand_offsets_base[i + idx]; + // Copy the idx-th component of the ith offset from device to host. 
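+        // Each offset component comes from its own single-element buffer, so
+        // only sizeof(int64_t) bytes are copied, and the stream is
+        // synchronized before the host-side value is read below.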
+ TF_RETURN_IF_ERROR( + stream.Memcpy(offset_dst, offset_src, sizeof(int64_t))); + + if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { + return absl::InternalError(absl::StrFormat( + "Failed to retrieve all slice offset values on stream %p: %s", + &stream, blocked.message())); + } + slice_starts.push_back(*offset_dst); } // Compute new slice. No need to copy the content to new buffers as we can // reuse the original buffers since slices are contiguous. - TF_RET_CHECK(IsContiguousSlice(src_shape, dst_shape)); - int64_t new_size = ShapeUtil::ByteSizeOf(dst_shape); BufferAllocation::Slice orig_slice = *embedded_thunk_operands_[i]; int64_t new_offset = orig_slice.offset(); - std::vector slice_starts(offset_dst, - offset_dst + dst_shape.rank()); for (auto [start, stride] : llvm::zip(slice_starts, *ShapeUtil::ByteStrides(src_shape))) { new_offset += start * stride; @@ -221,32 +234,37 @@ absl::Status AddressComputationThunk::ExecuteOnStream( continue; } - se::DeviceMemoryBase offset_src = - orig_allocations.GetDeviceAddress(*result_offset_buffer_indices_[i]); - - // Copy the ith offset from device to host. const Shape& src_shape = *result_orig_shapes_[i]; const Shape& dst_shape = *result_sliced_shapes_[i]; - int64_t* offset_dst = &result_offsets_base[i]; - TF_RETURN_IF_ERROR(stream.Memcpy(offset_dst, offset_src, - dst_shape.rank() * sizeof(int64_t))); - - if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { - return absl::InternalError(absl::StrFormat( - "Failed to retrieve all slice offset values on stream %p: %s", - &stream, blocked.message())); + TF_RET_CHECK(IsContiguousSlice(src_shape, dst_shape)); + + std::vector slice_starts; + slice_starts.reserve(dst_shape.rank()); + + // Get offset for ith result, which has `dst_shape.rank()` components. + for (auto [idx, offset_slice] : + llvm::enumerate(*result_offset_buffer_indices_[i])) { + se::DeviceMemoryBase offset_src = + orig_allocations.GetDeviceAddress(offset_slice); + int64_t* offset_dst = &result_offsets_base[i + idx]; + // Copy the idx-th component of the ith offset from device to host. + TF_RETURN_IF_ERROR( + stream.Memcpy(offset_dst, offset_src, sizeof(int64_t))); + + if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { + return absl::InternalError(absl::StrFormat( + "Failed to retrieve all slice offset values on stream %p: %s", + &stream, blocked.message())); + } + slice_starts.push_back(*offset_dst); } // Compute new slice. No need to copy the content to new buffers as we can // reuse the original buffers since slices are contiguous. 
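    // For example, slicing an f32[2,4] operand (byte strides {16, 4}) with
    // slice_starts {0, 1} adds 0 * 16 + 1 * 4 = 4 bytes to the original
    // allocation's offset.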
- TF_RET_CHECK(IsContiguousSlice(src_shape, dst_shape)); - int64_t new_size = ShapeUtil::ByteSizeOf(dst_shape); BufferAllocation::Slice orig_slice = *embedded_thunk_results_[i]; int64_t new_offset = orig_slice.offset(); - std::vector slice_starts(offset_dst, - offset_dst + dst_shape.rank()); for (auto [start, stride] : llvm::zip(slice_starts, *ShapeUtil::ByteStrides(src_shape))) { new_offset += start * stride; diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h index d4bdbfe287d9b1..b52b5fdfde861e 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h @@ -46,11 +46,11 @@ class AddressComputationThunk : public Thunk { ThunkInfo thunk_info, std::unique_ptr embedded_thunk, std::vector> operands, std::vector> results, - std::vector> + std::vector>> operand_offset_buffer_indices, std::vector> operand_orig_shapes, std::vector> operand_sliced_shapes, - std::vector> + std::vector>> result_offset_buffer_indices, std::vector> result_orig_shapes, std::vector> result_sliced_shapes); @@ -69,11 +69,11 @@ class AddressComputationThunk : public Thunk { embedded_thunk_operands_; std::vector> embedded_thunk_results_; - std::vector> + std::vector>> operand_offset_buffer_indices_; std::vector> operand_orig_shapes_; std::vector> operand_sliced_shapes_; - std::vector> + std::vector>> result_offset_buffer_indices_; std::vector> result_orig_shapes_; std::vector> result_sliced_shapes_; diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc index e783cdea0ba6a3..1167cf18a93c57 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc @@ -77,7 +77,7 @@ TEST(AddressComputationThunkTest, SlicedGemm) { int64_t lhs_length = sizeof(float) * 2 * 4; int64_t rhs_length = sizeof(float) * 3 * 1; int64_t out_length = sizeof(float) * 1 * 1; - int64_t lhs_offset_length = sizeof(int64_t) * 2; + int64_t offset_length = sizeof(int64_t); // Step 1: // Prepare embedded and address computation thunks. @@ -95,10 +95,15 @@ TEST(AddressComputationThunkTest, SlicedGemm) { BufferAllocation alloc_workspace(/*index=*/3, 1024 * 1024, /*color=*/0); BufferAllocation::Slice slice_workspace(&alloc_workspace, 0, 1024 * 1024); - BufferAllocation alloc_lhs_offset(/*index=*/4, lhs_offset_length, - /*color=*/0); - BufferAllocation::Slice slice_lhs_offset(&alloc_lhs_offset, 0, - lhs_offset_length); + BufferAllocation alloc_lhs_offset_0(/*index=*/4, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_0(&alloc_lhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_lhs_offset_1(/*index=*/5, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, + offset_length); BufferAllocation alloc_lhs_fake(/*index=*/0, rhs_length, /*color=*/0); BufferAllocation::Slice slice_lhs_fake(&alloc_lhs_fake, 0, rhs_length); @@ -119,10 +124,12 @@ TEST(AddressComputationThunkTest, SlicedGemm) { slice_out, slice_workspace, /*deterministic=*/true)); // Wrapping address computation thunk around the GEMM thunk. 
+ std::vector lhs_offsets{slice_lhs_offset_0, + slice_lhs_offset_1}; AddressComputationThunk thunk( Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs}, - {slice_out, slice_workspace}, {slice_lhs_offset, std::nullopt}, + {slice_out, slice_workspace}, {lhs_offsets, std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt}, {std::nullopt, std::nullopt}, {std::nullopt, std::nullopt}, @@ -157,15 +164,17 @@ TEST(AddressComputationThunkTest, SlicedGemm) { executor->AllocateArray(1024 * 1024); TF_ASSERT_OK(stream.MemZero(&workspace, 1024 * 1024)); - se::DeviceMemory lhs_offset = executor->AllocateArray(2); + se::DeviceMemory lhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory lhs_offset_1 = executor->AllocateArray(1); std::vector lhs_offset_arr{0, 1}; - TF_ASSERT_OK( - stream.Memcpy(&lhs_offset, lhs_offset_arr.data(), lhs_offset_length)); + TF_ASSERT_OK(stream.Memcpy(&lhs_offset_0, &lhs_offset_arr[0], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&lhs_offset_1, &lhs_offset_arr[1], offset_length)); // Preparing parameters for thunk execution. ServiceExecutableRunOptions run_options; - BufferAllocations allocations({lhs, rhs, out, workspace, lhs_offset}, 0, - executor->GetAllocator()); + BufferAllocations allocations( + {lhs, rhs, out, workspace, lhs_offset_0, lhs_offset_1}, 0, + executor->GetAllocator()); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, &stream, &stream, {}, nullptr, nullptr); @@ -194,7 +203,7 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) { int64_t lhs_length = sizeof(float) * 2 * 4; int64_t rhs_length = sizeof(float) * 4 * 3; int64_t out_length = sizeof(float) * 2 * 2; - int64_t offset_length = sizeof(int64_t) * 2; + int64_t offset_length = sizeof(int64_t); int64_t slice_length = sizeof(float) * 2 * 2; // Step 1: @@ -213,11 +222,25 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) { BufferAllocation alloc_workspace(/*index=*/3, 1024 * 1024, /*color=*/0); BufferAllocation::Slice slice_workspace(&alloc_workspace, 0, 1024 * 1024); - BufferAllocation alloc_lhs_offset(/*index=*/4, offset_length, /*color=*/0); - BufferAllocation::Slice slice_lhs_offset(&alloc_lhs_offset, 0, offset_length); + BufferAllocation alloc_lhs_offset_0(/*index=*/4, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_0(&alloc_lhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_lhs_offset_1(/*index=*/5, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, + offset_length); - BufferAllocation alloc_rhs_offset(/*index=*/5, offset_length, /*color=*/0); - BufferAllocation::Slice slice_rhs_offset(&alloc_rhs_offset, 0, offset_length); + BufferAllocation alloc_rhs_offset_0(/*index=*/6, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_rhs_offset_0(&alloc_rhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_rhs_offset_1(/*index=*/7, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_rhs_offset_1(&alloc_rhs_offset_1, 0, + offset_length); BufferAllocation alloc_lhs_fake(/*index=*/0, slice_length, /*color=*/0); BufferAllocation::Slice slice_lhs_fake(&alloc_lhs_fake, 0, slice_length); @@ -241,10 +264,14 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) { slice_out, slice_workspace, /*deterministic=*/true)); // Wrapping address computation thunk around the GEMM thunk. 
+ std::vector lhs_offsets{slice_lhs_offset_0, + slice_lhs_offset_1}; + std::vector rhs_offsets{slice_rhs_offset_0, + slice_rhs_offset_1}; AddressComputationThunk thunk( Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs}, - {slice_out, slice_workspace}, {slice_lhs_offset, slice_rhs_offset}, + {slice_out, slice_workspace}, {lhs_offsets, rhs_offsets}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), ShapeUtil::MakeShape(PrimitiveType::F32, {4, 3})}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}), @@ -286,21 +313,23 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) { executor->AllocateArray(1024 * 1024); TF_ASSERT_OK(stream.MemZero(&workspace, 1024 * 1024)); - se::DeviceMemory lhs_offset = executor->AllocateArray(2); + se::DeviceMemory lhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory lhs_offset_1 = executor->AllocateArray(1); std::vector lhs_offset_arr{0, 1}; - TF_ASSERT_OK( - stream.Memcpy(&lhs_offset, lhs_offset_arr.data(), offset_length)); + TF_ASSERT_OK(stream.Memcpy(&lhs_offset_0, &lhs_offset_arr[0], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&lhs_offset_1, &lhs_offset_arr[1], offset_length)); - se::DeviceMemory rhs_offset = executor->AllocateArray(2); + se::DeviceMemory rhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory rhs_offset_1 = executor->AllocateArray(1); std::vector rhs_offset_arr{2, 1}; - TF_ASSERT_OK( - stream.Memcpy(&rhs_offset, rhs_offset_arr.data(), offset_length)); + TF_ASSERT_OK(stream.Memcpy(&rhs_offset_0, &rhs_offset_arr[0], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&rhs_offset_1, &rhs_offset_arr[1], offset_length)); // Preparing parameters for thunk execution. ServiceExecutableRunOptions run_options; - BufferAllocations allocations( - {lhs, rhs, out, workspace, lhs_offset, rhs_offset}, 0, - executor->GetAllocator()); + BufferAllocations allocations({lhs, rhs, out, workspace, lhs_offset_0, + lhs_offset_1, rhs_offset_0, rhs_offset_1}, + 0, executor->GetAllocator()); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, &stream, &stream, {}, nullptr, nullptr); @@ -322,7 +351,7 @@ TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) { int64_t length = sizeof(float) * 2 * 4; int64_t out_length = sizeof(float) * 1; - int64_t offset_length = sizeof(int64_t) * 2; + int64_t offset_length = sizeof(int64_t); int64_t slice_length = sizeof(float) * 3; // Step 1: @@ -341,17 +370,31 @@ TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) { BufferAllocation alloc_workspace(/*index=*/3, 1024 * 1024, /*color=*/0); BufferAllocation::Slice slice_workspace(&alloc_workspace, 0, 1024 * 1024); - BufferAllocation alloc_lhs_offset(/*index=*/4, offset_length, /*color=*/0); - BufferAllocation::Slice slice_lhs_offset(&alloc_lhs_offset, 0, offset_length); + BufferAllocation alloc_lhs_offset_0(/*index=*/4, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_0(&alloc_lhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_lhs_offset_1(/*index=*/5, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, + offset_length); - BufferAllocation alloc_rhs_offset(/*index=*/5, offset_length, /*color=*/0); - BufferAllocation::Slice slice_rhs_offset(&alloc_rhs_offset, 0, offset_length); + BufferAllocation alloc_rhs_offset_0(/*index=*/6, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_rhs_offset_0(&alloc_rhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_rhs_offset_1(/*index=*/7, 
offset_length, + /*color=*/0); + BufferAllocation::Slice slice_rhs_offset_1(&alloc_rhs_offset_1, 0, + offset_length); BufferAllocation alloc_lhs_fake(/*index=*/0, slice_length, /*color=*/0); - BufferAllocation::Slice slice_lhs_fake(&alloc_lhs_fake, 0, slice_length); + BufferAllocation::Slice slice_lhs_fake(&alloc_lhs, 0, slice_length); BufferAllocation alloc_rhs_fake(/*index=*/1, slice_length, /*color=*/0); - BufferAllocation::Slice slice_rhs_fake(&alloc_rhs_fake, 0, slice_length); + BufferAllocation::Slice slice_rhs_fake(&alloc_rhs, 0, slice_length); // Preparing config for GEMM thunk. auto config = @@ -369,10 +412,14 @@ TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) { slice_out, slice_workspace, /*deterministic=*/true)); // Wrapping address computation thunk around the GEMM thunk. + std::vector lhs_offsets{slice_lhs_offset_0, + slice_lhs_offset_1}; + std::vector rhs_offsets{slice_rhs_offset_0, + slice_rhs_offset_1}; AddressComputationThunk thunk( Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs}, - {slice_out, slice_workspace}, {slice_lhs_offset, slice_rhs_offset}, + {slice_out, slice_workspace}, {lhs_offsets, rhs_offsets}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), ShapeUtil::MakeShape(PrimitiveType::F32, {8, 1})}, {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), @@ -418,21 +465,23 @@ TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) { executor->AllocateArray(1024 * 1024); TF_ASSERT_OK(stream.MemZero(&workspace, 1024 * 1024)); - se::DeviceMemory lhs_offset = executor->AllocateArray(2); + se::DeviceMemory lhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory lhs_offset_1 = executor->AllocateArray(1); std::vector lhs_offset_arr{0, 1}; - TF_ASSERT_OK( - stream.Memcpy(&lhs_offset, lhs_offset_arr.data(), offset_length)); + TF_ASSERT_OK(stream.Memcpy(&lhs_offset_0, &lhs_offset_arr[0], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&lhs_offset_1, &lhs_offset_arr[1], offset_length)); - se::DeviceMemory rhs_offset = executor->AllocateArray(2); + se::DeviceMemory rhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory rhs_offset_1 = executor->AllocateArray(1); std::vector rhs_offset_arr{2, 0}; - TF_ASSERT_OK( - stream.Memcpy(&rhs_offset, rhs_offset_arr.data(), offset_length)); + TF_ASSERT_OK(stream.Memcpy(&rhs_offset_0, &rhs_offset_arr[0], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&rhs_offset_1, &rhs_offset_arr[1], offset_length)); // Preparing parameters for thunk execution. 
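  // With one allocation per offset value, offset_length is sizeof(int64_t)
  // (a single offset), each offset was copied into its own device array above,
  // and the BufferAllocations built below must list every offset buffer
  // separately, in the same order as the allocation indices declared earlier.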
ServiceExecutableRunOptions run_options; - BufferAllocations allocations( - {lhs, rhs, out, workspace, lhs_offset, rhs_offset}, 0, - executor->GetAllocator()); + BufferAllocations allocations({lhs, rhs, out, workspace, lhs_offset_0, + lhs_offset_1, rhs_offset_0, rhs_offset_1}, + 0, executor->GetAllocator()); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, &stream, &stream, {}, nullptr, nullptr); @@ -480,7 +529,7 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { int64_t dst_count = 8 * 8; int64_t src_length = sizeof(int32_t) * src_count; int64_t dst_length = sizeof(int32_t) * dst_count; - int64_t offset_length = sizeof(int64_t) * 4; + int64_t offset_length = sizeof(int64_t); int64_t slice_length = sizeof(int32_t) * dst_count; // Step 1: @@ -493,8 +542,17 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { BufferAllocation alloc_dst(/*index=*/1, dst_length, /*color=*/0); BufferAllocation::Slice slice_dst(&alloc_dst, 0, dst_length); - BufferAllocation alloc_offset(/*index=*/2, offset_length, /*color=*/0); - BufferAllocation::Slice slice_offset(&alloc_offset, 0, offset_length); + BufferAllocation alloc_offset_0(/*index=*/2, offset_length, /*color=*/0); + BufferAllocation::Slice slice_offset_0(&alloc_offset_0, 0, offset_length); + + BufferAllocation alloc_offset_1(/*index=*/3, offset_length, /*color=*/0); + BufferAllocation::Slice slice_offset_1(&alloc_offset_1, 0, offset_length); + + BufferAllocation alloc_offset_2(/*index=*/4, offset_length, /*color=*/0); + BufferAllocation::Slice slice_offset_2(&alloc_offset_2, 0, offset_length); + + BufferAllocation alloc_offset_3(/*index=*/5, offset_length, /*color=*/0); + BufferAllocation::Slice slice_offset_3(&alloc_offset_3, 0, offset_length); // Fake slices for embedded thunk creation. BufferAllocation alloc_src_fake(/*index=*/0, slice_length, /*color=*/0); @@ -520,10 +578,13 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { /*called_computation=*/nullptr)); // Wrapping address computation thunk around the custom call thunk. + std::vector slice_offsets{ + slice_offset_0, slice_offset_1, slice_offset_2, slice_offset_3}; AddressComputationThunk thunk( Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_src}, {slice_dst}, - {slice_offset}, {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 8})}, + {slice_offsets}, + {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 8})}, // Make sure to pass a dst shape with the same rank as src shape (i.e. // original slice result and not bitcasted one) {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 8, 8})}, {std::nullopt}, @@ -546,14 +607,21 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { se::DeviceMemory dst = executor->AllocateArray(dst_count); TF_ASSERT_OK(stream.MemZero(&dst, dst_length)); - se::DeviceMemory offset = executor->AllocateArray(4); + se::DeviceMemory offset_0 = executor->AllocateArray(1); + se::DeviceMemory offset_1 = executor->AllocateArray(1); + se::DeviceMemory offset_2 = executor->AllocateArray(1); + se::DeviceMemory offset_3 = executor->AllocateArray(1); std::vector offset_arr{3, 5, 2, 0}; - TF_ASSERT_OK(stream.Memcpy(&offset, offset_arr.data(), offset_length)); + TF_ASSERT_OK(stream.Memcpy(&offset_0, &offset_arr[0], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&offset_1, &offset_arr[1], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&offset_2, &offset_arr[2], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&offset_3, &offset_arr[3], offset_length)); // Preparing parameters for thunk execution. 
ServiceExecutableRunOptions run_options; - BufferAllocations allocations({src, dst, offset}, 0, - executor->GetAllocator()); + BufferAllocations allocations( + {src, dst, offset_0, offset_1, offset_2, offset_3}, 0, + executor->GetAllocator()); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, &stream, &stream, {}, nullptr, nullptr); @@ -591,7 +659,7 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { int64_t slice_count = 2 * 2; int64_t src_length = sizeof(int32_t) * src_count; int64_t dst_length = sizeof(int32_t) * dst_count; - int64_t offset_length = sizeof(int64_t) * 4; + int64_t offset_length = sizeof(int64_t); int64_t slice_length = sizeof(int32_t) * slice_count; // Step 1: @@ -604,11 +672,37 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { BufferAllocation alloc_dst(/*index=*/1, dst_length, /*color=*/0); BufferAllocation::Slice slice_dst(&alloc_dst, 0, dst_length); - BufferAllocation alloc_src_offset(/*index=*/2, offset_length, /*color=*/0); - BufferAllocation::Slice slice_src_offset(&alloc_src_offset, 0, offset_length); + BufferAllocation alloc_src_offset_0(/*index=*/2, offset_length, /*color=*/0); + BufferAllocation::Slice slice_src_offset_0(&alloc_src_offset_0, 0, + offset_length); + + BufferAllocation alloc_src_offset_1(/*index=*/3, offset_length, /*color=*/0); + BufferAllocation::Slice slice_src_offset_1(&alloc_src_offset_1, 0, + offset_length); + + BufferAllocation alloc_src_offset_2(/*index=*/4, offset_length, /*color=*/0); + BufferAllocation::Slice slice_src_offset_2(&alloc_src_offset_2, 0, + offset_length); - BufferAllocation alloc_dst_offset(/*index=*/3, offset_length, /*color=*/0); - BufferAllocation::Slice slice_dst_offset(&alloc_dst_offset, 0, offset_length); + BufferAllocation alloc_src_offset_3(/*index=*/5, offset_length, /*color=*/0); + BufferAllocation::Slice slice_src_offset_3(&alloc_src_offset_3, 0, + offset_length); + + BufferAllocation alloc_dst_offset_0(/*index=*/6, offset_length, /*color=*/0); + BufferAllocation::Slice slice_dst_offset_0(&alloc_dst_offset_0, 0, + offset_length); + + BufferAllocation alloc_dst_offset_1(/*index=*/7, offset_length, /*color=*/0); + BufferAllocation::Slice slice_dst_offset_1(&alloc_dst_offset_1, 0, + offset_length); + + BufferAllocation alloc_dst_offset_2(/*index=*/8, offset_length, /*color=*/0); + BufferAllocation::Slice slice_dst_offset_2(&alloc_dst_offset_2, 0, + offset_length); + + BufferAllocation alloc_dst_offset_3(/*index=*/9, offset_length, /*color=*/0); + BufferAllocation::Slice slice_dst_offset_3(&alloc_dst_offset_3, 0, + offset_length); // Fake slices for embedded thunk creation. BufferAllocation alloc_src_fake(/*index=*/0, slice_length, /*color=*/0); @@ -637,15 +731,21 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { /*called_computation=*/nullptr)); // Wrapping address computation thunk around the custom call thunk. + std::vector slice_src_offsets{ + slice_src_offset_0, slice_src_offset_1, slice_src_offset_2, + slice_src_offset_3}; + std::vector slice_dst_offsets{ + slice_dst_offset_0, slice_dst_offset_1, slice_dst_offset_2, + slice_dst_offset_3}; AddressComputationThunk thunk( Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_src}, {slice_dst}, - {slice_src_offset}, + {slice_src_offsets}, {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 2})}, // Make sure to pass a dst shape with the same rank as src shape (i.e. 
// original slice result and not bitcasted one) {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2})}, - {slice_dst_offset}, + {slice_dst_offsets}, {{ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2, 2, 2})}}, {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2})}); @@ -671,20 +771,32 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { se::DeviceMemory dst = executor->AllocateArray(dst_count); TF_ASSERT_OK(stream.MemZero(&dst, dst_length)); - se::DeviceMemory src_offset = executor->AllocateArray(4); + se::DeviceMemory src_offset_0 = executor->AllocateArray(1); + se::DeviceMemory src_offset_1 = executor->AllocateArray(1); + se::DeviceMemory src_offset_2 = executor->AllocateArray(1); + se::DeviceMemory src_offset_3 = executor->AllocateArray(1); std::vector src_offset_arr{3, 5, 2, 0}; - TF_ASSERT_OK( - stream.Memcpy(&src_offset, src_offset_arr.data(), offset_length)); - - se::DeviceMemory dst_offset = executor->AllocateArray(4); + TF_ASSERT_OK(stream.Memcpy(&src_offset_0, &src_offset_arr[0], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&src_offset_1, &src_offset_arr[1], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&src_offset_2, &src_offset_arr[2], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&src_offset_3, &src_offset_arr[3], offset_length)); + + se::DeviceMemory dst_offset_0 = executor->AllocateArray(1); + se::DeviceMemory dst_offset_1 = executor->AllocateArray(1); + se::DeviceMemory dst_offset_2 = executor->AllocateArray(1); + se::DeviceMemory dst_offset_3 = executor->AllocateArray(1); std::vector dst_offset_arr{1, 1, 0, 0}; - TF_ASSERT_OK( - stream.Memcpy(&dst_offset, dst_offset_arr.data(), offset_length)); + TF_ASSERT_OK(stream.Memcpy(&dst_offset_0, &dst_offset_arr[0], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&dst_offset_1, &dst_offset_arr[1], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&dst_offset_2, &dst_offset_arr[2], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&dst_offset_3, &dst_offset_arr[3], offset_length)); // Preparing parameters for thunk execution. 
ServiceExecutableRunOptions run_options; - BufferAllocations allocations({src, dst, src_offset, dst_offset}, 0, - executor->GetAllocator()); + BufferAllocations allocations( + {src, dst, src_offset_0, src_offset_1, src_offset_2, src_offset_3, + dst_offset_0, dst_offset_1, dst_offset_2, dst_offset_3}, + 0, executor->GetAllocator()); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, &stream, &stream, {}, nullptr, nullptr); From befa96da07437b7735c385a6fae7c7ccb8ef1c21 Mon Sep 17 00:00:00 2001 From: Eunjae Kim Date: Mon, 18 Mar 2024 16:41:09 -0700 Subject: [PATCH 052/670] Fix the shared_batch_scheduler_test to avoid using the designated initializer to fix the windows build failure PiperOrigin-RevId: 616983664 --- tensorflow/core/kernels/batching_util/BUILD | 1 - .../shared_batch_scheduler_test.cc | 66 +++++++++---------- 2 files changed, 33 insertions(+), 34 deletions(-) diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index 828b1c0f60d4fb..d34bd7331a35d5 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -190,7 +190,6 @@ tf_cc_test( name = "shared_batch_scheduler_test", size = "small", srcs = ["shared_batch_scheduler_test.cc"], - tags = ["no_windows"], deps = [ ":batch_scheduler", ":fake_clock_env", diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc index 29b79b3bb4b712..680bbb5dd56206 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc @@ -434,39 +434,39 @@ TEST_P( // Create two queues. - const SharedBatchScheduler::QueueOptions - queue_options = { - .input_batch_size_limit = 10, - .batch_timeout_micros = 1000 * 1000, - .max_enqueued_batches = 2, - .enable_large_batch_splitting = enable_input_batch_split(), - .split_input_task_func = - [](std::unique_ptr* input_task, - int open_batch_remaining_slot, int max_batch_size, - std::vector>* - output_tasks) -> Status { - std::unique_ptr owned_input_task = - std::move(*input_task); - const int input_task_size = owned_input_task->size(); - - const internal::InputSplitMetadata input_split_metadata( - input_task_size, open_batch_remaining_slot, max_batch_size); - - const absl::FixedArray task_sizes = - input_split_metadata.task_sizes(); - const int num_batches = task_sizes.size(); - - output_tasks->resize(num_batches); - for (int i = 0; i < num_batches; i++) { - (*output_tasks)[i] = - std::make_unique(task_sizes[i]); - } - - return absl::OkStatus(); - }, - .enable_lazy_split = enable_lazy_split(), - .max_execution_batch_size = 10, - .enable_priority_queue = true}; + SharedBatchScheduler::QueueOptions + queue_options; + queue_options.input_batch_size_limit = 10; + queue_options.batch_timeout_micros = 1000 * 1000; + queue_options.max_enqueued_batches = 2; + queue_options.enable_large_batch_splitting = enable_input_batch_split(); + queue_options.split_input_task_func = + [](std::unique_ptr* input_task, + int open_batch_remaining_slot, int max_batch_size, + std::vector>* + output_tasks) -> Status { + std::unique_ptr owned_input_task = + std::move(*input_task); + const int input_task_size = owned_input_task->size(); + + const internal::InputSplitMetadata input_split_metadata( + input_task_size, open_batch_remaining_slot, max_batch_size); + + const absl::FixedArray task_sizes = + 
input_split_metadata.task_sizes(); + const int num_batches = task_sizes.size(); + + output_tasks->resize(num_batches); + for (int i = 0; i < num_batches; i++) { + (*output_tasks)[i] = + std::make_unique(task_sizes[i]); + } + + return absl::OkStatus(); + }; + queue_options.enable_lazy_split = enable_lazy_split(); + queue_options.max_execution_batch_size = 10; + queue_options.enable_priority_queue = true; std::unique_ptr> queue_0; TF_CHECK_OK(shared_batch_scheduler->AddQueue(queue_options, From dfe2e26f8673f529f43e0f8e20ad8a7afad15aeb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Mar 2024 16:54:49 -0700 Subject: [PATCH 053/670] Support host offloaded values as entry computation output PiperOrigin-RevId: 616987214 --- third_party/xla/xla/service/BUILD | 4 +- third_party/xla/xla/service/host_offloader.cc | 58 ++++++++++++- third_party/xla/xla/service/host_offloader.h | 2 + .../xla/xla/service/host_offloader_test.cc | 83 ++++++++++++++++++- .../xla/xla/service/layout_assignment.cc | 36 +++++--- 5 files changed, 168 insertions(+), 15 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 1ce30da506c7f3..f41fe0f3a7c93c 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4571,7 +4571,6 @@ cc_library( deps = [ ":call_graph", ":computation_layout", - ":hlo_alias_analysis", ":hlo_dce", ":hlo_graph_dumper", ":hlo_pass", @@ -4581,6 +4580,7 @@ cc_library( "//xla:permutation_util", "//xla:shape_layout", "//xla:shape_util", + "//xla:status", "//xla:status_macros", "//xla:statusor", "//xla:types", @@ -4598,8 +4598,8 @@ cc_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/host_offloader.cc b/third_party/xla/xla/service/host_offloader.cc index b484d7fec3418b..9058a9aa48c515 100644 --- a/third_party/xla/xla/service/host_offloader.cc +++ b/third_party/xla/xla/service/host_offloader.cc @@ -166,6 +166,56 @@ HloInstruction* FindDSAnnotation(HloInstruction* hlo) { } // namespace +absl::StatusOr HostOffloader::TryOutputStreaming( + HloInstruction* custom_call) { + const HloBuffer& unique_buffer = + alias_analysis_->GetUniqueBufferAt(custom_call); + bool is_used_as_output_with_host_memory_space = false; + const HloComputation* const entry_computation = + custom_call->GetModule()->entry_computation(); + for (const HloValue* value : unique_buffer.values()) { + // Check if this is memory-only. + if (!AllPositionsAreAllowed(value)) { + // Found a position which is not allowed. + return false; + } + + // Look for a value used as a output. 
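  // The loop below walks every position of the annotated value; a position
  // that is the entry-computation root must already carry host memory space in
  // the entry layout's result shape at that index, otherwise output streaming
  // is rejected with FailedPrecondition. If no such root position is found,
  // the function returns false and the caller falls back to inserting copies.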
+ for (const auto& position : value->positions()) { + const HloInstruction* instruction = position.instruction; + const ShapeIndex& index = position.index; + if (instruction->parent() == entry_computation && instruction->IsRoot()) { + const Shape& output_shape = + ShapeUtil::GetSubshape(entry_computation->parent() + ->entry_computation_layout() + .result_shape(), + index); + CHECK(output_shape.has_layout()); + + if (output_shape.layout().memory_space() != kHostMemorySpaceColor) { + return FailedPrecondition( + "Output buffer is annotated with %s but is not marked with host " + "memory space in the entry computation.", + custom_call->name()); + } + is_used_as_output_with_host_memory_space = true; + } + } + } + if (!is_used_as_output_with_host_memory_space) { + VLOG(1) << "Buffer annotated by " << custom_call->name() + << " is not used as an output with host memory space."; + return false; + } + + VLOG(3) << "Found an output buffer annotated with " << custom_call->name() + << ". Expecting that we'll need to insert copies."; + + annotations_for_copy_to_host_to_insert_.emplace(custom_call); + AddAllPositionsToBeMovedToHostMemory(unique_buffer); + return true; +} + Status HostOffloader::HandleMoveToHostCustomCall(HloInstruction* custom_call) { VLOG(2) << "Found a custom call annotating start-of-host-offload: " << custom_call->ToString(); @@ -195,7 +245,11 @@ Status HostOffloader::HandleMoveToHostCustomCall(HloInstruction* custom_call) { } else if (op_being_annotated->opcode() == HloOpcode::kCopy) { TF_RETURN_IF_ERROR(MemoryOnlyOffloadStartingWithCopy(op_being_annotated)); } else { - TF_RETURN_IF_ERROR(MemoryOnlyOffloadInsertCopies(custom_call)); + TF_ASSIGN_OR_RETURN(bool did_output_streaming, + TryOutputStreaming(custom_call)); + if (!did_output_streaming) { + TF_RETURN_IF_ERROR(MemoryOnlyOffloadInsertCopies(custom_call)); + } } return OkStatus(); } @@ -576,7 +630,7 @@ absl::StatusOr HostOffloader::Run( // Run HloAliasAnalysis on module. TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module)); - // Iterate over all instructions and look for XLA host offload annoations. + // Iterate over all instructions and look for XLA host offload annotations. 
for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instruction : diff --git a/third_party/xla/xla/service/host_offloader.h b/third_party/xla/xla/service/host_offloader.h index 85966a312dc790..8bd2c0fb26598a 100644 --- a/third_party/xla/xla/service/host_offloader.h +++ b/third_party/xla/xla/service/host_offloader.h @@ -21,6 +21,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_pass_interface.h" @@ -67,6 +68,7 @@ class HostOffloader : public HloModulePass { void AddAllPositionsToBeMovedToHostMemory(const HloBuffer& unique_buffer); absl::StatusOr TryParameterStreaming(HloInstruction* custom_call); + absl::StatusOr TryOutputStreaming(HloInstruction* custom_call); Status HandleMoveToHostCustomCall(HloInstruction* custom_call); Status HandleMoveToDeviceCustomCall(HloInstruction* custom_call); diff --git a/third_party/xla/xla/service/host_offloader_test.cc b/third_party/xla/xla/service/host_offloader_test.cc index 4eb459c2e60222..6b367fe53a2f54 100644 --- a/third_party/xla/xla/service/host_offloader_test.cc +++ b/third_party/xla/xla/service/host_offloader_test.cc @@ -1779,7 +1779,7 @@ ENTRY main { TEST_F(HostOffloaderTest, ParameterStreaming) { const std::string& hlo_string = R"( -HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)S(5)}, s32[2,1]{1,0:T(2,128)})->(s32[2,1]{1,0:T(2,128)S(5)}, s32[2,1]{1,0:T(2,128)S(5)})} +HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)S(5)}, s32[2,1]{1,0:T(2,128)})->(s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)})} ENTRY main { param_0 = s32[2,1]{1,0} parameter(0) @@ -1854,6 +1854,87 @@ ENTRY main { EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); } +TEST_F(HostOffloaderTest, OutputStreaming) { + const std::string& hlo_string = R"( +HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)})->(s32[2,1]{1,0:T(2,128)S(5)}, s32[2,1]{1,0:T(2,128)})} + +ENTRY main { + param_0 = s32[2,1]{1,0} parameter(0) + param_1 = s32[2,1]{1,0} parameter(1) + constant_2 = s32[] constant(2) + constant_4 = s32[] constant(4) + broadcast_0 = s32[2,1]{1,0} broadcast(constant_2), dimensions={} + multiply_0 = s32[2,1]{1,0} multiply(param_1, broadcast_0) + multiply_1 = s32[2,1]{1,0} multiply(multiply_0, param_0) + broadcast_1 = s32[2,1]{1,0} broadcast(constant_4), dimensions={} + multiply_2 = s32[2,1]{1,0} multiply(multiply_1, broadcast_1) + custom_call = s32[2,1]{1,0} custom-call(multiply_2), custom_call_target="MoveToHost" + ROOT tuple = (s32[2,1]{1,0}, s32[2,1]{1,0}) tuple(custom_call, multiply_1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // constant + // | + // param1 broadcast param0 + // \ / / + // multiply / + // \ / + // \ / + // multiply constant + // | | | + // | ---+---broadcast + // | / | + // multiply | + // | | + // copy | + // \ | + // tuple + HloInstruction* param_1; + HloInstruction* broadcast_0; + HloInstruction* multiply_0; + HloInstruction* param_0; + HloInstruction* multiply_1; + HloInstruction* broadcast_1; + HloInstruction* multiply_2; + HloInstruction* copy; + HloInstruction* tuple; + auto multiplyPattern = + 
m::Multiply(&multiply_1, + m::Multiply(&multiply_0, m::Parameter(¶m_1), + m::Broadcast(&broadcast_0, m::ConstantScalar(2))), + m::Parameter(¶m_0)); + ASSERT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + &tuple, + m::Copy(©, m::Multiply( + &multiply_2, multiplyPattern, + m::Broadcast(&broadcast_1, m::ConstantScalar(4)))), + multiplyPattern))); + TestShapeHasMemorySpace(param_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(broadcast_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(param_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(broadcast_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_2->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy->shape(), kHostMemorySpaceColor); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {0}), + Layout::kHostMemorySpace); + TestShapeHasMemorySpace(ShapeUtil::GetSubshape(tuple->shape(), {1}), + Layout::kDefaultMemorySpace); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index 67874e38d67a17..c79261e6bb8d6c 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -15,24 +15,23 @@ limitations under the License. #include "xla/service/layout_assignment.h" -#include +#include #include -#include #include #include -#include #include #include #include -#include #include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -40,12 +39,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/map_util.h" #include "xla/permutation_util.h" #include "xla/service/call_graph.h" #include "xla/service/computation_layout.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_dce.h" #include "xla/service/logical_buffer.h" #include "xla/service/tuple_points_to_analysis.h" @@ -53,15 +53,15 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/status_macros.h" #include "xla/statusor.h" -#include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" -#include "tsl/platform/protobuf.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -2019,12 +2019,28 @@ Status LayoutAssignment::PropagateBufferConstraintToUses( Status LayoutAssignment::PropagateResultConstraint( const ComputationLayoutConstraint& layout_constraint, LayoutConstraints* constraints) { + ShapeLayout result_layout = + layout_constraint.computation_layout().result_layout(); + // Clear out memory space in layout for entry computation root. 
Host offloader + // will do the analysis later and add back the memory space for host outputs. + if (constraints->computation()->IsEntryComputation()) { + Shape result_shape = result_layout.shape(); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( + &result_shape, [](Shape* subshape, const ShapeIndex& shape_index) { + if (subshape->has_layout() && subshape->IsArray()) { + subshape->mutable_layout()->set_memory_space( + Layout::kDefaultMemorySpace); + } + return OkStatus(); + })); + TF_RETURN_IF_ERROR(result_layout.CopyLayoutFromShape(result_shape)); + } + // Propagate the use constraint of the root instruction up to the logical // buffers which make up the result. return PropagateUseConstraintToDefs( - layout_constraint.computation_layout().result_layout(), - constraints->computation()->root_instruction(), constraints, - current_priority_); + result_layout, constraints->computation()->root_instruction(), + constraints, current_priority_); } // Infers the layout of the array at the given index in the given instruction's From a141be8dae4222f4fbc23b0ba9919cb8fbca2ac6 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 18 Mar 2024 17:09:51 -0700 Subject: [PATCH 054/670] Implement basic `QuantizationReport`. This is a minimally working version of `QuantizationReport` where a user may add a single `QuantizationResult` manually. In future revisions, it will be able to parse `QuantizationResult`s from `ModuleOp` and populate internal data automatically. PiperOrigin-RevId: 616991325 --- .../mlir/quantization/stablehlo/cc/BUILD | 20 ++++++ .../mlir/quantization/stablehlo/cc/report.cc | 29 +++++++++ .../mlir/quantization/stablehlo/cc/report.h | 48 ++++++++++++++ .../quantization/stablehlo/cc/report_test.cc | 64 +++++++++++++++++++ .../stablehlo/quantization_config.proto | 26 ++++++++ 5 files changed, 187 insertions(+) create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index 7a36ad58dc34a4..2ba0127d2b9c97 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -276,6 +276,26 @@ tf_cc_test( ], ) +cc_library( + name = "report", + srcs = ["report.cc"], + hdrs = ["report.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + ], +) + +tf_cc_test( + name = "report_test", + srcs = ["report_test.cc"], + deps = [ + ":report", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "context", srcs = [], diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc new file mode 100644 index 00000000000000..ef24c16dbf4acc --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc @@ -0,0 +1,29 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h" + +#include + +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::quant::stablehlo { + +using ::stablehlo::quantization::QuantizationResult; + +void QuantizationReport::AddQuantizationResult(QuantizationResult&& result) { + *quantization_results_.add_results() = std::move(result); +} + +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h new file mode 100644 index 00000000000000..94eb47463f16c1 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h @@ -0,0 +1,48 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_REPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_REPORT_H_ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::quant::stablehlo { + +// A class that manages information about `QuantizableUnit`s post-quantization, +// internally in the form of `QuantizationUnits`. It is used to collect +// quantization summary from a quantized `ModuleOp` and emit it in a human- and +// machine-readable format. +class QuantizationReport { + public: + QuantizationReport() = default; + + // Adds a `QuantizationResult` to the report. + void AddQuantizationResult( + ::stablehlo::quantization::QuantizationResult&& result); + + // Returns `QuantizationResults` that are registered in this report. + const ::stablehlo::quantization::QuantizationResults& GetQuantizationResults() + const { + return quantization_results_; + } + + private: + // Quantization results that are registered in this report. A quantization + // result may be added manually by calling `AddQuantizationResult`. + ::stablehlo::quantization::QuantizationResults quantization_results_; +}; + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_REPORT_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc new file mode 100644 index 00000000000000..f6897f7fde401d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report_test.cc @@ -0,0 +1,64 @@ +/* Copyright 2024 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h" + +#include + +#include +#include +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::quant::stablehlo { +namespace { + +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::QuantizableUnit; +using ::stablehlo::quantization::QuantizationResult; +using ::stablehlo::quantization::QuantizationResults; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::StrEq; + +TEST(QuantizationReportTest, GetQuantizationResultsReturnsEmptyResults) { + QuantizationReport report{}; + + const QuantizationResults& results = report.GetQuantizationResults(); + ASSERT_THAT(results.results(), IsEmpty()); +} + +TEST(QuantizationReportTest, AddQuantizationResult) { + // Construct a `QuantizationResult` to add, representing a unit named + // `quantized_my_function` that is not quantized. + QuantizationResult result{}; + QuantizableUnit& quantizable_unit = *result.mutable_quantizable_unit(); + quantizable_unit.set_name("quantized_my_function"); + + Method& method = *result.mutable_method(); + method.mutable_no_quantization(); + + QuantizationReport report{}; + report.AddQuantizationResult(std::move(result)); + + const QuantizationResults& results = report.GetQuantizationResults(); + ASSERT_THAT(results.results(), SizeIs(1)); + + const QuantizationResult& first_result = results.results(0); + EXPECT_THAT(first_result.quantizable_unit().name(), + StrEq("quantized_my_function")); + EXPECT_TRUE(first_result.method().has_no_quantization()); +} + +} // namespace +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto index 81aff6e46d5850..56645d7f3d73ad 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto @@ -63,6 +63,32 @@ message PipelineConfig { optional bool unpack_quantized_types = 1; } +// Represents a single quantizable unit, a (nearly) minimum unit of work when +// applying quantization. It may correspond to a single or multiple ops. +// Next ID: 2 +message QuantizableUnit { + // Name of the `FuncOp` symbol corresponding to the "lifted function", + // representing a single quantizable unit. This value is guaranteed to be + // unique across a single `ModuleOp`. + string name = 1; +} + +// Represents a quantization result of a single `QuantizableUnit`. It is +// essentially a `(QuantizableUnit, Method)` pair, where the `Method` +// corresponds to the quantization method eventually applied to the +// `QuantizableUnit`. +// Next ID: 3 +message QuantizationResult { + QuantizableUnit quantizable_unit = 1; + Method method = 2; +} + +// A series of `QuantizationResult`s. 
See `QuantizationResult` for details. +// Next ID: 2 +message QuantizationResults { + repeated QuantizationResult results = 1; +} + // A quantization method representing "do not quantize". Mostly used for // denylisting quantizable units from quantization. message NoQuantization {} From 858019381d8496e45a40184a705482bc78f34d5f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 18 Mar 2024 17:22:55 -0700 Subject: [PATCH 055/670] [PJRT] Drop mentions of CPU support from the stream_executor client. We never use the stream_executor client on CPU any more, since the TFRT CPU client is better in every way. PiperOrigin-RevId: 616994524 --- .../xla/pjrt/pjrt_stream_executor_client.cc | 65 +------------------ 1 file changed, 2 insertions(+), 63 deletions(-) diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 9275f94492133d..cfe5962915dce0 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -616,9 +616,7 @@ void PjRtStreamExecutorBuffer::ScopedHold::AddToInput( } } -bool PjRtStreamExecutorBuffer::IsOnCpu() const { - return client()->platform_id() == CpuId(); -} +bool PjRtStreamExecutorBuffer::IsOnCpu() const { return false; } StatusOr PjRtStreamExecutorBuffer::logical_on_device_shape() { if (on_device_shape_.is_static()) { @@ -827,59 +825,6 @@ PjRtStreamExecutorClient::BufferFromHostBuffer( ShapeUtil::ByteStrides(device_shape, absl::MakeSpan(shape_strides))); bool host_and_device_strides_equal = (size == 0 || *byte_strides == shape_strides); - // The CPU platform is special because the "host" and the "device" are in the - // same memory space. If the input shape is in the correct layout and we don't - // want to defer the copy onto a thread, we can use the following fast - // path. - bool is_cpu_platform = - local_device->executor()->platform()->id() == se::host::kHostPlatformId; - if (is_cpu_platform) { - // If we are on the host platform and the input buffer is sufficiently - // aligned, we can simply point to the input array's data without any - // further copies. At the time of writing we require a 16-byte alignment - // because XLA may generate code which requires it. - bool can_use_zero_copy = - host_buffer_semantics == HostBufferSemantics::kZeroCopy && - ((absl::bit_cast(data) & - (cpu_function_runtime::MinAlign() - 1)) == 0); - if (host_and_device_strides_equal && - (host_buffer_semantics == - HostBufferSemantics::kImmutableOnlyDuringCall || - can_use_zero_copy)) { - absl::AnyInvocable on_delete_callback; - se::DeviceMemoryBase buffer; - // If we are on the host platform and the input buffer is sufficiently - // aligned, we can simply point to the input array's data without any - // further copies. At the time of writing we require a 16-byte alignment - // because XLA may generate code which requires it. 
- if (can_use_zero_copy) { - on_delete_callback = std::move(on_done_with_host_buffer); - buffer = se::DeviceMemoryBase( - const_cast(static_cast(data)), size); - } else { - void* staging_buffer = host_memory_allocator()->AllocateRaw( - cpu_function_runtime::MinAlign(), size); - buffer = se::DeviceMemoryBase(staging_buffer, size); - std::memcpy(staging_buffer, data, size); - if (on_done_with_host_buffer) { - std::move(on_done_with_host_buffer)(); - } - on_delete_callback = [staging_buffer, host_memory_allocator = - host_memory_allocator()]() { - host_memory_allocator->DeallocateRaw(staging_buffer); - }; - } - absl::Span> - definition_events; - auto device_buffer = std::make_shared( - /*allocator=*/nullptr, local_device->local_device_id().value(), - std::initializer_list{buffer}, - definition_events, std::move(on_delete_callback)); - return std::unique_ptr( - std::make_unique( - device_shape, std::move(device_buffer), this, device)); - } - } TF_ASSIGN_OR_RETURN( std::unique_ptr py_buffer, @@ -1038,13 +983,7 @@ PjRtStreamExecutorClient::BufferFromHostBuffer( } })); }; - if (is_cpu_platform) { - // Using the thread_pool would be a double thread hop; the code - // already defers its work onto a stream (= thread on CPU). - transfer_h2d(); - } else { - thread_pool()->Schedule(transfer_h2d); - } + thread_pool()->Schedule(transfer_h2d); return std::unique_ptr(std::move(py_buffer)); } From e7849b639bb26a4b4f570c8868412989fc228949 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 18 Mar 2024 17:35:25 -0700 Subject: [PATCH 056/670] PR #10612: [GPU] cuDNN GEMM fusions: enable noncontracting dimension transformations. Imported from GitHub PR https://github.com/openxla/xla/pull/10612 This kind of transformations is already in use by the Triton GEMM backend for a while. Copybara import of the project: -- 51eaaf6c8b722ef7c3273825d0585371cd55da26 by Ilia Sergachev : [GPU] Support broadcasts in cuDNN GEMM fusions. -- 40aed83572cd3f09fa8530b6c130bf84be226593 by Ilia Sergachev : [XLA:GPU] Enable noncontracting to batch dimension transformation in cuDNN GEMM fusions. This transformation is already in use by the Triton GEMM backend for a while. Merging this change closes #10612 PiperOrigin-RevId: 616998256 --- .../xla/service/gpu/cudnn_fusion_compiler.cc | 86 +++++++-- .../xla/xla/service/gpu/fusions/cudnn_test.cc | 180 ++++++++++++++++++ third_party/xla/xla/xla.proto | 6 +- 3 files changed, 259 insertions(+), 13 deletions(-) diff --git a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc b/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc index 3b4a5c4cc5b825..1bd93e3ee2a243 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc +++ b/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc @@ -103,6 +103,13 @@ inline std::optional ToCudnnDataType(const PrimitiveType type) { } } +int FusionLevel(const HloInstruction& hlo) { + return hlo.GetModule() + ->config() + .debug_options() + .xla_gpu_cudnn_gemm_fusion_level(); +}; + // Extracts dimensions and strides from HLO tensors in the format expected by // cuDNN. class GemmDimensionAdapter { @@ -139,17 +146,21 @@ class GemmDimensionAdapter { std::vector& strides) { const DotDimensionNumbers& dims = dot_.dot_dimension_numbers(); // GEMM fusions require a specific canonical order of dimensions. 
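  // The canonical order assembled below is {batch, noncontracting, contracting}
  // for the LHS, {batch, contracting, noncontracting} for the RHS, and
  // {batch, M, N} for the output, with -1 marking an absent batch dimension.
  // For illustration only: a plain [m,k] x [k,n] dot ends up described to
  // cuDNN as roughly {1, m, k}, {1, k, n} and {1, m, n}.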
+ constexpr int kBatchDimensionIndex = 0; + constexpr int kOutputLHSNonContractingDimensionIndex = 1; std::vector dim_indices; + int lhs_noncontracting_index = -1; switch (scope) { case TritonFusionAnalysis::Scope::LHS: - dim_indices = {dims.lhs_batch_dimensions().empty() - ? -1 - : dims.lhs_batch_dimensions(0), - GetNonContractingDims(dot_.operand(0)->shape(), - dims.lhs_batch_dimensions(), - dims.lhs_contracting_dimensions()) - .value()[0], - dims.lhs_contracting_dimensions(0)}; + lhs_noncontracting_index = + GetNonContractingDims(dot_.operand(0)->shape(), + dims.lhs_batch_dimensions(), + dims.lhs_contracting_dimensions()) + .value()[0]; + dim_indices = { + dims.lhs_batch_dimensions().empty() ? -1 + : dims.lhs_batch_dimensions(0), + lhs_noncontracting_index, dims.lhs_contracting_dimensions(0)}; break; case TritonFusionAnalysis::Scope::RHS: dim_indices = {dims.rhs_batch_dimensions().empty() @@ -162,8 +173,9 @@ class GemmDimensionAdapter { .value()[0]}; break; case TritonFusionAnalysis::Scope::OUTPUT: + lhs_noncontracting_index = dot_.shape().rank() - 2; dim_indices = {dims.lhs_batch_dimensions().empty() ? -1 : 0, - dot_.shape().rank() - 2, dot_.shape().rank() - 1}; + lhs_noncontracting_index, dot_.shape().rank() - 1}; break; case TritonFusionAnalysis::Scope::META: LOG(FATAL) << "Unsupported scope."; @@ -177,17 +189,67 @@ class GemmDimensionAdapter { strides.push_back(strides.empty() ? 1 : strides.back()); continue; } else { - if (spec->size() != 1) { + if (spec->size() == 1) { + // The dimension is not split, nothing to do. + } else if (spec->size() == 2) { + if (FusionLevel(hlo) < 3) { + return false; + } + if (!dims.lhs_batch_dimensions().empty()) { + VLOG(8) << "Noncontracting dimension split is not compatible with " + "batch dimensions."; + return false; + } + if (index != lhs_noncontracting_index) { + VLOG(8) << "Only LHS noncontracting dimension can be split."; + return false; + } + switch (scope) { + case TritonFusionAnalysis::Scope::LHS: + lhs_noncontracting_split = spec->back().count; + break; + case TritonFusionAnalysis::Scope::OUTPUT: + if (lhs_noncontracting_split != spec->back().count) { + VLOG(8) << "Output non-contracting dimension has to be split " + "the same way as the LHS input one if it is split."; + return false; + } + break; + default: + VLOG(8) << "Only LHS noncontracting dimension can be split."; + return false; + } + // Assign the major part of the noncontracting dimension to the + // unused batch one. + CHECK_EQ(dimensions[kBatchDimensionIndex], 1); + dimensions[kBatchDimensionIndex] = spec->back().count; + strides[kBatchDimensionIndex] = spec->back().stride; + } else { + VLOG(8) << "The dimension is split multiple times."; return false; } dimensions.push_back(spec->front().count); strides.push_back(spec->front().stride); } } + if (lhs_noncontracting_split > 1 && + scope == TritonFusionAnalysis::Scope::OUTPUT && + dimensions[kBatchDimensionIndex] == 1) { + // LHS input noncontracting dimension is split but the corresponding + // output one is not. Assign part of the output one to the unused batch + // dimension. 
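  // Illustration with assumed values: if the LHS noncontracting extent M is
  // seen as a split M = S * M' while the output stays a row-major [M, N]
  // buffer, the fix-up below reports the output as dimensions {S, M', N} with
  // strides {M' * N, N, 1}.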
+ dimensions[kBatchDimensionIndex] = lhs_noncontracting_split; + dimensions[kOutputLHSNonContractingDimensionIndex] /= + lhs_noncontracting_split; + strides[kBatchDimensionIndex] = + strides[kOutputLHSNonContractingDimensionIndex] * + dimensions[kOutputLHSNonContractingDimensionIndex]; + } return true; } private: + int64_t lhs_noncontracting_split = 1; const HloDotInstruction& dot_; }; @@ -254,7 +316,9 @@ absl::StatusOr> HloFusionToCuDnnGraph( } else if (hlo->opcode() == HloOpcode::kReshape || hlo->opcode() == HloOpcode::kBitcast || hlo->opcode() == HloOpcode::kTranspose || - hlo->opcode() == HloOpcode::kCopy) { + hlo->opcode() == HloOpcode::kCopy || + (FusionLevel(fusion) >= 2 && + hlo->opcode() == HloOpcode::kBroadcast)) { // All these are accounted for separately as transformations of strides. hlo_to_cudnn[hlo] = operand(0); } else if (hlo->IsElementwise()) { diff --git a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc index 2d69800a63b69a..40da7bbbc039ff 100644 --- a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc @@ -28,6 +28,7 @@ class CuDnnFusionTest : public GpuCodegenTest { // Let this group of tests just use first available plan skipping // autotuning. debug_options.set_xla_gpu_autotune_level(0); + debug_options.set_xla_gpu_cudnn_gemm_fusion_level(1); return debug_options; } bool IsAtLeastHopperWithCuDnn9() { @@ -291,6 +292,185 @@ ENTRY %e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } +class CuDnnFusionLevel2Test : public CuDnnFusionExecutionTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = + CuDnnFusionExecutionTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_cudnn_gemm_fusion_level(2); + return debug_options; + } +}; + +TEST_F(CuDnnFusionLevel2Test, BroadcastToDim2ExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[16,32] parameter(2) + p2b = f16[16,32,128] broadcast(p2), dimensions={0,1} + a = f16[16,32,128] add(p0, p2b) + ROOT r = f16[16,32,64] dot(a, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[16,32] parameter(2) + ROOT _ = f16[16,32,64] fusion(p0, p1, p2), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionLevel2Test, BroadcastToDim1ExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[16,128] parameter(2) + p2b = f16[16,32,128] broadcast(p2), dimensions={0,2} + a = f16[16,32,128] add(p0, p2b) + ROOT r = f16[16,32,64] dot(a, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[16,128] parameter(2) + ROOT _ = f16[16,32,64] fusion(p0, p1, p2), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionLevel2Test, BroadcastToDim0ExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = bf16[32,128] parameter(0) + p0b = bf16[5,32,128] 
broadcast(p0), dimensions={1,2} + p1 = bf16[5,128,64] parameter(1) + ROOT r = f32[5,32,64] dot(p0b, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = bf16[32,128] parameter(0) + p1 = bf16[5,128,64] parameter(1) + ROOT _ = f32[5,32,64] fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionLevel2Test, BroadcastTo2DimsExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[128] parameter(2) + p2b = f16[16,32,128] broadcast(p2), dimensions={2} + a = f16[16,32,128] add(p0, p2b) + ROOT r = f16[16,32,64] dot(a, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[128] parameter(2) + ROOT _ = f16[16,32,64] fusion(p0, p1, p2), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionLevel2Test, BroadcastTo3DimsExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[] parameter(2) + p2b = f16[16,32,128] broadcast(p2), dimensions={} + a = f16[16,32,128] add(p0, p2b) + ROOT r = f16[16,32,64] dot(a, p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} +} + +ENTRY e { + p0 = f16[16,32,128] parameter(0) + p1 = f16[16,128,64] parameter(1) + p2 = f16[] parameter(2) + ROOT _ = f16[16,32,64] fusion(p0, p1, p2), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +class CuDnnFusionLevel3Test : public CuDnnFusionExecutionTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = + CuDnnFusionExecutionTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_cudnn_gemm_fusion_level(3); + return debug_options; + } +}; + +TEST_F(CuDnnFusionLevel3Test, + DotWithSplitNonContractingInputExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = s8[4,3,16,400]{2,1,3,0} parameter(0) + cp0 = s8[4,3,16,400]{3,2,1,0} copy(p0) + bc0 = s8[192,400]{1,0} bitcast(cp0) + cvt0 = bf16[192,400]{1,0} convert(bc0) + p1 = bf16[1,128,400]{2,1,0} parameter(1) + bc1 = bf16[128,400]{1,0} reshape(p1) + ROOT d = bf16[192,128]{1,0} dot(cvt0, bc1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY r { + p0 = s8[4,3,16,400]{2,1,3,0} parameter(0) + p1 = bf16[1,128,400]{2,1,0} parameter(1) + ROOT r = bf16[192,128]{1,0} fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(CuDnnFusionLevel3Test, + DotWithSplitNonContractingInOutExecutesCorrectly) { + EXPECT_TRUE(RunAndCompare(R"( +fusion1 { + p0 = s8[4,3,16,400]{2,1,3,0} parameter(0) + cp0 = s8[4,3,16,400]{3,2,1,0} copy(p0) + bc0 = s8[192,400]{1,0} bitcast(cp0) + cvt0 = bf16[192,400]{1,0} convert(bc0) + p1 = bf16[1,128,400]{2,1,0} parameter(1) + bc1 = bf16[128,400]{1,0} reshape(p1) + d = bf16[192,128]{1,0} dot(cvt0, bc1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + bc = bf16[4,3,16,128]{3,2,1,0} bitcast(d) + ROOT cp 
= bf16[4,3,16,128]{2,1,3,0} copy(bc) +} + +ENTRY r { + p0 = s8[4,3,16,400]{2,1,3,0} parameter(0) + p1 = bf16[1,128,400]{2,1,0} parameter(1) + ROOT r = bf16[4,3,16,128]{2,1,3,0} fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})", + ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + class CuDnnFusionRewriteTest : public CuDnnFusionTest { public: DebugOptions GetDebugOptionsForTest() override { diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 69a40c9dc09e41..1c84566bebb4aa 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -723,8 +723,10 @@ message DebugOptions { // Let GEMM fusion autotuning probe cuDNN as a backend. // Current levels: - // 0: disabled. - // 1: fusions of GEMM, elementwise, transpose/reshape operations. + // 0: Disabled. + // 1: Fusions of GEMM, elementwise, transpose/reshape operations. + // 2: + Broadcasts. + // 3: + Nontrivial noncontracting dimension reshapes/transposes. int32 xla_gpu_cudnn_gemm_fusion_level = 285; // Next id: 286 From 7c6c9233e5ac4a668fc411bc3153fc05736f6a3c Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Mon, 18 Mar 2024 17:43:26 -0700 Subject: [PATCH 057/670] Only mark ops converted by pattern at illegal PiperOrigin-RevId: 617000313 --- .../stablehlo/tests/odml-to-stablehlo-smuggle-resize.mlir | 2 ++ .../lite/stablehlo/transforms/smuggle_disallowed_ops.cc | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-smuggle-resize.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-smuggle-resize.mlir index ec8ab139054e63..4a0f6a5d5e673b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-smuggle-resize.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-smuggle-resize.mlir @@ -1,10 +1,12 @@ // RUN: odml_to_stablehlo %s -skip-resize -smuggle-disallowed-ops -o - | FileCheck %s +// RUN: odml-to-stablehlo-opt %s --smuggle-disallowed-ops-pass | FileCheck %s --check-prefix=CHECK-OPT // CHECK-LABEL: @main module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 975 : i32}, tf_saved_model.semantics} { func.func @serving_default(%arg0: tensor<1x32x32x128xf32> {tf_saved_model.index_path = ["a"]}) -> (tensor<1x64x64x128xf32> {tf_saved_model.index_path = ["b"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "c:0", outputs = "d:0"}, tf_saved_model.exported_names = ["serving_default"]} { %0 = "tf.Const"() {value = dense<[56, 904]> : tensor<2xi32>} : () -> tensor<2xi32> // CHECK: %1 = stablehlo.custom_call @tf.ResizeBilinear(%arg0, %0) {align_corners = false, device = "", half_pixel_centers = true} : (tensor<1x32x32x128xf32>, tensor<2xi32>) -> tensor<1x64x64x128xf32> + // CHECK-OPT: %0 = stablehlo.custom_call @tf.ResizeBilinear(%arg0, %cst) {align_corners = false, device = "", half_pixel_centers = true} : (tensor<1x32x32x128xf32>, tensor<2xi32>) -> tensor<1x64x64x128xf32> %1 = "tf.ResizeBilinear"(%arg0, %0) { align_corners = false, device = "", half_pixel_centers = true } : (tensor<1x32x32x128xf32>, tensor<2xi32>) -> tensor<1x64x64x128xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc index 033ec78751e6b6..06754ea72b580c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc +++ 
b/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project @@ -70,6 +71,9 @@ class SmuggleDisallowedOpsPass StringRef getDescription() const final { return "Smuggle disallowed ops via stablehlo.custom_calls"; } + void getDependentDialects(DialectRegistry& registry) const final { + registry.insert(); + } void runOnOperation() override { RewritePatternSet patterns(&getContext()); @@ -77,7 +81,7 @@ class SmuggleDisallowedOpsPass patterns.add>(&getContext()); ConversionTarget target(getContext()); - target.addIllegalDialect(); + target.addIllegalOp(); target.addLegalDialect(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { From cd0c17316ee2f49238b10d3c63ec9bc6fce72c98 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 18 Mar 2024 17:51:46 -0700 Subject: [PATCH 058/670] [xla:hlo] Do not add processed instructions to DFS stack PiperOrigin-RevId: 617002781 --- third_party/xla/xla/hlo/ir/hlo_computation.cc | 51 +++++++++---------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index 7d8a080bd3840a..d5418d1a9ad47a 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -45,9 +45,11 @@ limitations under the License. #include "xla/printer.h" #include "xla/service/mapped_ptr_container_sorter.h" #include "xla/service/name_uniquer.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" @@ -510,22 +512,29 @@ void HloComputation::ForEachInstructionPostOrderImpl( bool has_channel_dependencies = !channel_dependencies.empty(); auto* dfs_stack = dfs_stack_scratch; dfs_stack->clear(); - dfs_stack->push_back(root); + + // Pushes instruction to dfs stack only if it was not already processed. 
+ auto dfs_stack_push = [&](HloInstruction* instr) { + VisitState state = visited.GetState(instr->index_in_parent_); + if (state != kVisited) dfs_stack->push_back(instr); + }; + + dfs_stack_push(root); while (!dfs_stack->empty()) { - HloInstruction& current = *dfs_stack->back(); + HloInstruction* current = dfs_stack->back(); + DCHECK_EQ(current->parent(), this) + << "Instruction " << current->name() + << " is not in the current computation (" << name() << ")."; - VisitMap::Handle h = current.index_in_parent_; + VisitMap::Handle h = current->index_in_parent_; VisitState state = visited.GetState(h); if (state == kNew) { visited.SetState(h, kVisiting); } else { dfs_stack->pop_back(); if (state != kVisited) { - DCHECK_EQ(current.parent(), this) - << "Instruction " << current.name() - << " is not in the current computation (" << name() << ")."; - func(¤t); visited.SetState(h, kVisited); + func(current); } continue; } @@ -534,34 +543,22 @@ void HloComputation::ForEachInstructionPostOrderImpl( // Collectives with the same channel ID must be performed together, as these // represent MPMD-partitioned that will later be split into separate modules // and the order must be preserved. - if (has_channel_dependencies && ¤t != root) { - auto it = channel_dependencies.find(¤t); + if (has_channel_dependencies && current != root) { + auto it = channel_dependencies.find(current); if (it != channel_dependencies.end()) { - dfs_stack->insert(dfs_stack->end(), it->second.begin(), - it->second.end()); + absl::c_for_each(it->second, dfs_stack_push); } } // Add the operands to the stack in reverse order so the first operand is // processed first. This will produce a more natural ordering and a nicer // result for things like HLO stringification. - const HloInstruction::InstructionVector& operands = current.operands(); - - for (auto it = operands.rbegin(); it != operands.rend(); ++it) { - HloInstruction* operand = *it; - if (visited.GetState(operand->index_in_parent_) != kVisited) { - dfs_stack->push_back(operand); - } else { - // Already fully visited, so we avoid pushing onto the stack - } - } + const HloInstruction::InstructionVector& operands = current->operands(); + absl::c_for_each(tsl::gtl::make_range(operands.rbegin(), operands.rend()), + dfs_stack_push); - const PtrVec& predecessors = - current.control_predecessors(); - if (!predecessors.empty()) { - dfs_stack->insert(dfs_stack->end(), predecessors.begin(), - predecessors.end()); - } + // Add control predecessors to the stack. + absl::c_for_each(current->control_predecessors(), dfs_stack_push); } } From 9c7585f17c4e8dc497c14c35b02ad3285d7a65be Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 18 Mar 2024 18:11:01 -0700 Subject: [PATCH 059/670] [xla:hlo] NFC: Convert VisitState to enum class PiperOrigin-RevId: 617007972 --- third_party/xla/xla/hlo/ir/hlo_computation.cc | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index d5418d1a9ad47a..449a86db314cc3 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include #include #include +#include #include #include #include @@ -58,7 +59,22 @@ namespace xla { using absl::StrCat; -enum VisitState { kNew = 0, kVisiting = 1, kVisited = 2 }; +enum class VisitState { kNew = 0, kVisiting = 1, kVisited = 2 }; + +static std::ostream& operator<<(std::ostream& os, const VisitState& state) { + switch (state) { + case VisitState::kNew: + os << "new"; + break; + case VisitState::kVisiting: + os << "visiting"; + break; + case VisitState::kVisited: + os << "visited"; + break; + } + return os; +} class HloComputation::VisitMap { public: @@ -516,7 +532,7 @@ void HloComputation::ForEachInstructionPostOrderImpl( // Pushes instruction to dfs stack only if it was not already processed. auto dfs_stack_push = [&](HloInstruction* instr) { VisitState state = visited.GetState(instr->index_in_parent_); - if (state != kVisited) dfs_stack->push_back(instr); + if (state != VisitState::kVisited) dfs_stack->push_back(instr); }; dfs_stack_push(root); @@ -528,12 +544,12 @@ void HloComputation::ForEachInstructionPostOrderImpl( VisitMap::Handle h = current->index_in_parent_; VisitState state = visited.GetState(h); - if (state == kNew) { - visited.SetState(h, kVisiting); + if (state == VisitState::kNew) { + visited.SetState(h, VisitState::kVisiting); } else { dfs_stack->pop_back(); - if (state != kVisited) { - visited.SetState(h, kVisited); + if (state != VisitState::kVisited) { + visited.SetState(h, VisitState::kVisited); func(current); } continue; @@ -1568,16 +1584,16 @@ std::unique_ptr HloComputation::CloneInContext( auto it = visited.find(cur); if (it != visited.end()) { dfs_stack.pop_back(); - if (it->second == kVisited) { + if (it->second == VisitState::kVisited) { continue; } - CHECK_EQ(it->second, kVisiting); + CHECK_EQ(it->second, VisitState::kVisiting); postorder.push_back(cur); - it->second = kVisited; + it->second = VisitState::kVisited; continue; } - visited.insert({cur, kVisiting}); + visited.insert({cur, VisitState::kVisiting}); for (HloInstruction* operand : cur->operands()) { const HloInstruction* new_operand = replace(operand); if (new_operand) { From 44b161ac851da17cbd3efd9f5d6b58d8a96cd1bc Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 18 Mar 2024 19:09:20 -0700 Subject: [PATCH 060/670] Move the logic for populating default calibration options to `PopulateDefaults`. This change populates calibration method as part of `PopulateDefaults`. This is reused for ODML use cases. However, for ODML the value for `unpack_quantized_types` is explicitly set to `False` because ODML use cases require uniform quantized types to be left intact. 
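A minimal sketch of how the new defaulting is expected to behave, assuming the PopulateDefaults API shown in the cc/config.cc and cc/config_test.cc hunks below; all proto and function names are taken from this patch, and the Example() wrapper is purely illustrative:

    // Sketch only: mirrors the behavior exercised by config_test.cc in this patch.
    #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h"

    namespace stablehlo::quantization {

    void Example() {
      QuantizationConfig config;  // user leaves calibration_options unset
      const QuantizationConfig populated = PopulateDefaults(config);
      // Simple min-max calibration is filled in by default:
      //   populated.calibration_options().calibration_method()
      //       == CalibrationOptions::CALIBRATION_METHOD_MIN_MAX

      // An explicitly provided method is not overridden:
      config.mutable_calibration_options()->set_calibration_method(
          CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX);
      //   PopulateDefaults(config).calibration_options().calibration_method()
      //       == CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX
    }

    }  // namespace stablehlo::quantization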
PiperOrigin-RevId: 617019651 --- .../mlir/lite/quantization/stablehlo/BUILD | 1 + .../quantization/stablehlo/quantization.cc | 12 +++---- .../quantization/stablehlo/quantization.h | 2 +- .../mlir/quantization/stablehlo/cc/config.cc | 16 ++++++++++ .../quantization/stablehlo/cc/config_test.cc | 31 +++++++++++++++++++ .../stablehlo/cc/static_range_ptq.cc | 6 +--- .../stablehlo/cc/static_range_ptq.h | 13 +------- .../stablehlo/quantization_config.proto | 2 +- tensorflow/lite/python/lite.py | 6 +++- 9 files changed, 63 insertions(+), 26 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD index df286611f3e356..f469cbc8fddacf 100644 --- a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD @@ -18,6 +18,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo:tf_stablehlo", "//tensorflow/compiler/mlir/quantization/stablehlo:passes", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:config", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc index ccba41d07e103b..0cc946a23d4e25 100644 --- a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" @@ -41,6 +42,7 @@ namespace tensorflow { namespace { using ::mlir::quant::stablehlo::StaticRangePtqComponent; +using ::stablehlo::quantization::PopulateDefaults; using ::stablehlo::quantization::QuantizationConfig; using ::tensorflow::SignatureDef; using ::tensorflow::quantization::PyFunctionLibrary; @@ -79,7 +81,7 @@ absl::StatusOr RunQuantization( const SavedModelBundle* saved_model_bundle, const absl::string_view saved_model_dir, const std::unordered_set& saved_model_tags, - QuantizationConfig& quantization_config, + const QuantizationConfig& quantization_config, const PyFunctionLibrary* quantization_py_function_lib, mlir::ModuleOp module_op) { if (saved_model_bundle == nullptr) { @@ -94,10 +96,8 @@ absl::StatusOr RunQuantization( "be nullptr."); } - if (!quantization_config.has_calibration_options()) { - *quantization_config.mutable_calibration_options() = - mlir::quant::stablehlo::GetDefaultCalibrationOptions(); - } + const QuantizationConfig config_with_defaults = + PopulateDefaults(quantization_config); const absl::flat_hash_map signature_def_map = GetSignatureDefMapFromBundle(*saved_model_bundle); @@ -132,7 +132,7 @@ absl::StatusOr RunQuantization( /*signature_keys=*/exported_names, saved_model_tags, signature_def_map, GetFunctionAliases(*saved_model_bundle)); const 
absl::StatusOr quantized_module_op = - static_range_ptq_component.Run(module_op, quantization_config); + static_range_ptq_component.Run(module_op, config_with_defaults); if (!quantized_module_op.ok()) { return absl::InternalError("Failed to run quantization. Status msg: " + quantized_module_op.status().ToString()); diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h index ef6496315e8e61..c55d59cad0f1a0 100644 --- a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h @@ -50,7 +50,7 @@ absl::StatusOr RunQuantization( const SavedModelBundle* saved_model_bundle, absl::string_view saved_model_dir, const std::unordered_set& saved_model_tags, - stablehlo::quantization::QuantizationConfig& quantization_config, + const stablehlo::quantization::QuantizationConfig& quantization_config, const tensorflow::quantization::PyFunctionLibrary* quantization_py_function_lib, mlir::ModuleOp module_op); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc index 679e1f8754be9b..e8a4aa87bb0619 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc @@ -15,11 +15,27 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" namespace stablehlo::quantization { +namespace { + +// Creates `CalibrationOptions` with default fields. Uses simple min-max +// calibration by default. +CalibrationOptions GetDefaultCalibrationOptions() { + CalibrationOptions options{}; + options.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); + return options; +} + +} // namespace QuantizationConfig PopulateDefaults( const QuantizationConfig& user_provided_config) { QuantizationConfig config = user_provided_config; + if (!config.has_calibration_options()) { + *config.mutable_calibration_options() = GetDefaultCalibrationOptions(); + } + PipelineConfig& pipeline_config = *config.mutable_pipeline_config(); if (!pipeline_config.has_unpack_quantized_types()) { pipeline_config.set_unpack_quantized_types(true); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc index 5912788bddf96b..164cd6bae237f8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc @@ -14,12 +14,15 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" +#include #include #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" namespace stablehlo::quantization { namespace { +using ::testing::Eq; + TEST(PopulateDefaultsTest, PopulateDefaultsForEmptyConfig) { QuantizationConfig config{}; @@ -37,5 +40,33 @@ TEST(PopulateDefaultsTest, PopulateDefaultsForConfigWithUnpackQuantizedTypes) { EXPECT_FALSE(new_config.pipeline_config().unpack_quantized_types()); } +TEST(PopulateDefaultsTest, DefaultCalibrationOptionsPopulated) { + QuantizationConfig config{}; + + const QuantizationConfig new_config = PopulateDefaults(config); + EXPECT_THAT(new_config.calibration_options().calibration_method(), + Eq(CalibrationOptions::CALIBRATION_METHOD_MIN_MAX)); +} + +TEST(PopulateDefaultsTest, ExplicitCalibrationOptionsNotOverridden) { + QuantizationConfig config{}; + CalibrationOptions& calibration_options = + *config.mutable_calibration_options(); + calibration_options.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX); + calibration_options.mutable_calibration_parameters()->set_initial_num_bins( + 512); + + // Test that if the user explicitly provided `calibration_options`, it is not + // overridden. + const QuantizationConfig new_config = PopulateDefaults(config); + EXPECT_THAT(new_config.calibration_options().calibration_method(), + Eq(CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX)); + EXPECT_THAT(new_config.calibration_options() + .calibration_parameters() + .initial_num_bins(), + Eq(512)); +} + } // namespace } // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc index eaafdf1770f7f9..e4b3595ae0f2de 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc @@ -243,17 +243,13 @@ absl::StatusOr StaticRangePtqComponent::Run( absl::Status QuantizeStaticRangePtq( const absl::string_view src_saved_model_path, const absl::string_view dst_saved_model_path, - QuantizationConfig quantization_config, + const QuantizationConfig& quantization_config, const std::vector& signature_keys, const absl::flat_hash_map& signature_def_map, const PyFunctionLibrary& py_function_library) { std::unordered_set tags; tags.insert(quantization_config.tf_saved_model().tags().begin(), quantization_config.tf_saved_model().tags().end()); - if (!quantization_config.has_calibration_options()) { - *quantization_config.mutable_calibration_options() = - GetDefaultCalibrationOptions(); - } std::unique_ptr ctx = CreateMlirContextForQuantization(); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h index e5056418bbae55..69bd9da6733c0c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h @@ -37,17 +37,6 @@ limitations under the License. namespace mlir::quant::stablehlo { -using ::stablehlo::quantization::CalibrationOptions; - -// Create default configuration for the calibration step, which is the min/max -// calibration method. 
-inline CalibrationOptions GetDefaultCalibrationOptions() { - CalibrationOptions options{}; - options.set_calibration_method( - CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); - return options; -} - // Component for static-range post-training quantization (PTQ). // TODO: b/320607042 - Add tests in python level. class StaticRangePtqComponent : public Component { @@ -102,7 +91,7 @@ class StaticRangePtqComponent : public Component { absl::Status QuantizeStaticRangePtq( absl::string_view src_saved_model_path, absl::string_view dst_saved_model_path, - ::stablehlo::quantization::QuantizationConfig quantization_config, + const ::stablehlo::quantization::QuantizationConfig& quantization_config, const std::vector& signature_keys, const absl::flat_hash_map& signature_def_map, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto index 56645d7f3d73ad..b4c4dbdf1f26c8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto @@ -226,7 +226,7 @@ message CalibrationOptions { } // Determines how to calibrate. - // The default calibration method is MIN_MAX. + // Default value: CALIBRATION_METHOD_MIN_MAX CalibrationMethod calibration_method = 1; // Defines the parameters required for calibration. Parameters such as the diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index c6804b0f35ed18..952392dcb8df84 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -864,7 +864,11 @@ def _get_base_converter_args(self): ) ], enable_per_channel_quantized_weight=True, - ) + ), + # For ODML use cases, uniform quantized types should be left intact. + pipeline_config=qc.PipelineConfig( + unpack_quantized_types=False, + ), ) args["quantization_config"] = quantization_config From bfee9197ad47168d103a6963fdaf8c8a8f27cfa1 Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Mon, 18 Mar 2024 19:33:33 -0700 Subject: [PATCH 061/670] [xla:cpu] Remove lmhlo dependency PiperOrigin-RevId: 617023917 --- .../xla/mlir/backends/cpu/transforms/BUILD | 1 - .../cpu/transforms/xla_cpu_to_cpu_runtime.cc | 139 +----------------- 2 files changed, 4 insertions(+), 136 deletions(-) diff --git a/third_party/xla/xla/mlir/backends/cpu/transforms/BUILD b/third_party/xla/xla/mlir/backends/cpu/transforms/BUILD index 9fb60b8442c698..ff81d082104b3a 100644 --- a/third_party/xla/xla/mlir/backends/cpu/transforms/BUILD +++ b/third_party/xla/xla/mlir/backends/cpu/transforms/BUILD @@ -44,7 +44,6 @@ cc_library( "//xla/mlir/runtime/utils:custom_calls", "//xla/mlir/xla_cpu/ir:xla_cpu", "//xla/mlir_hlo", - "//xla/mlir_hlo:lhlo", "//xla/service:hlo_parser", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", diff --git a/third_party/xla/xla/mlir/backends/cpu/transforms/xla_cpu_to_cpu_runtime.cc b/third_party/xla/xla/mlir/backends/cpu/transforms/xla_cpu_to_cpu_runtime.cc index 16223cead19f7c..fb3bb71548c6f9 100644 --- a/third_party/xla/xla/mlir/backends/cpu/transforms/xla_cpu_to_cpu_runtime.cc +++ b/third_party/xla/xla/mlir/backends/cpu/transforms/xla_cpu_to_cpu_runtime.cc @@ -35,7 +35,6 @@ limitations under the License. 
#include "xla/mlir/runtime/transforms/type_converter.h" #include "xla/mlir/runtime/utils/custom_calls.h" #include "xla/mlir/xla_cpu/ir/xla_cpu.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo_parser.h" @@ -48,8 +47,6 @@ namespace { using namespace mlir; // NOLINT -using mlir::lmhlo::CustomCallOp; - using xla_cpu::PartitionIdOp; using xla_cpu::ReplicaIdOp; @@ -115,133 +112,6 @@ func::CallOp CreateCallForDpsCollectiveOp(Operation* op, //===----------------------------------------------------------------------===// -class CustomCallOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.cpu.custom_call"; - - public: - CustomCallOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - // Rewrite custom call with `API_VERSION_TYPED_FFI` version into XLA runtime - // custom calls bypassing custom call adaptor. - LogicalResult rewriteTypedCustomCall(CustomCallOp op, - PatternRewriter& rewriter) const { - // TODO(ezhulenev): Support target arg mapping, or explain why we do not - // need them for typed custom calls. - if (op.getTargetArgMapping()) - return op.emitOpError( - "API_VERSION_TYPED_FFI custom calls do not " - "support target arg mapping"); - - // Create a custom call function declaration. - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = - custom_calls_.GetOrCreate(b, op.getCallTargetName(), op); - callee->setAttr("rt.dynamic", UnitAttr::get(b.getContext())); - - // Forward backend config to the custom call implementation. - auto config = op.getBackendConfig(); - if (!config) return op.emitOpError("Failed to get backend config"); - auto dict = config->cast(); - llvm::SmallVector backend_config(dict.begin(), dict.end()); - - // Call the custom call function forwarding user-defined attributes. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), op.getOperands()); - AppendCustomCallAttrs(call, backend_config); - - return success(); - } - - LogicalResult matchAndRewrite(CustomCallOp op, - PatternRewriter& rewriter) const override { - // Typed custom calls lowered directly to XLA runtime custom calls. - if (op.getApiVersion() == mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI) - return rewriteTypedCustomCall(op, rewriter); - - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // By default all operands passed to the custom call handler. - llvm::SmallVector operands = op.getOperands(); - - // Get the number of outputs from operand_segment_sizes. - int64_t num_results = op->getAttrOfType( - op.getOperandSegmentSizesAttrName())[1]; - - // If custom call has target arguments mapping, then we need to pass empty - // memrefs in place of holes. - if (op.getTargetArgMapping().has_value()) { - auto mapping = *op.getTargetArgMapping(); - int64_t num_args = mapping.getNumArgs(); - num_results = mapping.getNumResults(); - - // Always create an `alloca` in the parent function entry block. - // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas - Value hole = [&]() -> Value { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart( - &op->getParentOfType().front()); - return b.create(MemRefType::get({0}, b.getI8Type())); - }(); - - // We represent holes as empty i8 memrefs. - operands = llvm::SmallVector(num_args + num_results, hole); - - // Update operands to mapped custom call arguments. 
- auto args = mapping.getArgsToTargetArgs(); - for (const auto& indexed : llvm::enumerate(args)) - operands[indexed.value()] = op.getArgs()[indexed.index()]; - - // Update operands to mapped custom call results. - auto res = mapping.getResultsToTargetResults(); - for (const auto& indexed : llvm::enumerate(res)) - operands[num_args + indexed.value()] = op.getOutput()[indexed.index()]; - } - - // TODO(jreiffers): This will break if an output has a non-default layout. - operands = EnsureFlatMemrefs(operands, b); - // Create a custom call function declaration. - func::FuncOp callee = custom_calls_.GetOrCreate( - b, kCustomCallTarget, TypeRange(ValueRange(operands)), TypeRange()); - - // The ABI is different depending on whether the original op was outputting - // a tuple or not. For multiple outputs this is trivial but for a single - // output we rely on the xla_shape attribute to distinguish the ABIs. - bool output_tuple = num_results > 1; - if (auto xla_shape = op->getAttrOfType("xla_shape")) - output_tuple = ParseShape(xla_shape.strref())->IsTuple(); - - // This is not equivalent to op.getApiVersionAttr() - that call returns null - // if the attribute is absent. getApiVersion returns the default. - Attribute api_version = - mhlo::CustomCallApiVersionAttr::get(getContext(), op.getApiVersion()); - llvm::SmallVector custom_call_attrs = { - {b.getStringAttr("num_results"), - b.getI32IntegerAttr(static_cast(num_results))}, - {b.getStringAttr("output_tuple"), b.getBoolAttr(output_tuple)}, - {b.getStringAttr("api_version"), api_version}, - {b.getStringAttr("call_target_name"), op.getCallTargetNameAttr()}}; - - if (auto backend_config = op.getBackendConfigAttr()) { - custom_call_attrs.emplace_back(b.getStringAttr("backend_config"), - op.getBackendConfigAttr()); - } - - // Call the runtime intrinsic with the original operands. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), operands); - AppendCustomCallAttrs(call, custom_call_attrs); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - template class IdOpLowering : public OpRewritePattern { public: @@ -542,11 +412,10 @@ void ConvertXlaCpuToCpuRuntimePass::runOnOperation() { // Convert xla_cpu operations to XLA cpu runtime custom calls. 
RewritePatternSet patterns(ctx); - patterns - .insert( - ctx, custom_calls); + patterns.insert( + ctx, custom_calls); patterns.insert>(ctx, "xla.cpu.partition_id", custom_calls); patterns.insert>(ctx, "xla.cpu.replica_id", From d2f259e64f4035c7beee80af12edcb7eb4f282ef Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 18 Mar 2024 19:45:00 -0700 Subject: [PATCH 062/670] [xla:ffi] Add XLA_FFI_Handler_Traits to capture properties of an FFI handler PiperOrigin-RevId: 617026090 --- third_party/xla/xla/ffi/api/api.h | 30 ++++++++++--------- third_party/xla/xla/ffi/api/c_api.h | 12 +++++++- third_party/xla/xla/ffi/ffi_api.cc | 20 +++++++------ third_party/xla/xla/ffi/ffi_api.h | 13 +++++--- third_party/xla/xla/ffi/ffi_test.cc | 6 +++- .../xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc | 4 +-- third_party/xla/xla/python/xla_compiler.cc | 4 +-- .../address_computation_fusion_rewriter.cc | 4 +-- .../xla/xla/service/gpu/fusions/custom.cc | 11 +++---- .../xla/service/gpu/ir_emitter_unnested.cc | 6 ++-- .../runtime/address_computation_thunk_test.cc | 12 ++++---- 11 files changed, 73 insertions(+), 49 deletions(-) diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index bc21d1856a8f85..7faddec1e350d4 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -107,10 +107,9 @@ class Ffi { // Registers handler with an XLA runtime under the given name on a given // platform. - static inline XLA_FFI_Error* RegisterStaticHandler(const XLA_FFI_Api* api, - std::string_view name, - std::string_view platform, - XLA_FFI_Handler* handler); + static inline XLA_FFI_Error* RegisterStaticHandler( + const XLA_FFI_Api* api, std::string_view name, std::string_view platform, + XLA_FFI_Handler* handler, XLA_FFI_Handler_Traits traits = 0); protected: template @@ -131,7 +130,8 @@ class Ffi { XLA_FFI_Error* Ffi::RegisterStaticHandler(const XLA_FFI_Api* api, std::string_view name, std::string_view platform, - XLA_FFI_Handler* handler) { + XLA_FFI_Handler* handler, + XLA_FFI_Handler_Traits traits) { // Make copies of string views to guarantee they are null terminated. std::string name_str(name); std::string platform_str(platform); @@ -142,6 +142,7 @@ XLA_FFI_Error* Ffi::RegisterStaticHandler(const XLA_FFI_Api* api, args.name = name_str.c_str(); args.platform = platform_str.c_str(); args.handler = handler; + args.traits = traits; return api->XLA_FFI_Handler_Register(&args); } @@ -1294,15 +1295,16 @@ auto DictionaryDecoder(Members... m) { // TODO(ezhulenev): Add a callback so that end users can log registration error // to appropriate logging destination, e.g. LOG(FATAL) for duplicate internal // FFI handlers. -#define XLA_FFI_REGISTER_HANDLER(API, NAME, PLATFORM, FUNC) \ - XLA_FFI_REGISTER_HANDLER_(API, NAME, PLATFORM, FUNC, __COUNTER__) -#define XLA_FFI_REGISTER_HANDLER_(API, NAME, PLATFORM, FUNC, N) \ - XLA_FFI_REGISTER_HANDLER__(API, NAME, PLATFORM, FUNC, N) -#define XLA_FFI_REGISTER_HANDLER__(API, NAME, PLATFORM, FUNC, N) \ - XLA_FFI_ATTRIBUTE_UNUSED static const XLA_FFI_Error* \ - xla_ffi_static_handler_##N##_registered_ = [] { \ - return ::xla::ffi::Ffi::RegisterStaticHandler(API, NAME, PLATFORM, \ - FUNC); \ +#define XLA_FFI_REGISTER_HANDLER(API, NAME, PLATFORM, FUNC, ...) \ + XLA_FFI_REGISTER_HANDLER_(API, NAME, PLATFORM, FUNC, __COUNTER__, \ + ##__VA_ARGS__) +#define XLA_FFI_REGISTER_HANDLER_(API, NAME, PLATFORM, FUNC, N, ...) 
\ + XLA_FFI_REGISTER_HANDLER__(API, NAME, PLATFORM, FUNC, N, ##__VA_ARGS__) +#define XLA_FFI_REGISTER_HANDLER__(API, NAME, PLATFORM, FUNC, N, ...) \ + XLA_FFI_ATTRIBUTE_UNUSED static const XLA_FFI_Error* \ + xla_ffi_static_handler_##N##_registered_ = [] { \ + return ::xla::ffi::Ffi::RegisterStaticHandler(API, NAME, PLATFORM, \ + FUNC, ##__VA_ARGS__); \ }() } // namespace xla::ffi diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h index 114b2b4f6fbf1a..5549c5f3c2a30d 100644 --- a/third_party/xla/xla/ffi/api/c_api.h +++ b/third_party/xla/xla/ffi/api/c_api.h @@ -267,6 +267,15 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_CallFrame, attrs); // External functions registered with XLA as FFI handlers. typedef XLA_FFI_Error* XLA_FFI_Handler(XLA_FFI_CallFrame* call_frame); +enum XLA_FFI_Handler_TraitsBits { + // Calls to FFI handler are safe to trace into the command buffer. It means + // that calls to FFI handler always launch exactly the same device operations + // (can depend on attribute values) that can be captured and then replayed. + XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0, +}; + +typedef uint32_t XLA_FFI_Handler_Traits; + struct XLA_FFI_Handler_Register_Args { size_t struct_size; void* priv; @@ -274,9 +283,10 @@ struct XLA_FFI_Handler_Register_Args { const char* name; // null terminated const char* platform; // null terminated XLA_FFI_Handler* handler; + XLA_FFI_Handler_Traits traits; }; -XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Handler_Register_Args, handler); +XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Handler_Register_Args, traits); typedef XLA_FFI_Error* XLA_FFI_Handler_Register( XLA_FFI_Handler_Register_Args* args); diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc index 3173157a10ba90..75de43e277cfc1 100644 --- a/third_party/xla/xla/ffi/ffi_api.cc +++ b/third_party/xla/xla/ffi/ffi_api.cc @@ -84,7 +84,7 @@ Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame, //===----------------------------------------------------------------------===// using HandlerKey = std::pair; -using HandlerRegistry = absl::flat_hash_map; +using HandlerRegistry = absl::flat_hash_map; static HandlerKey MakeHandlerKey(std::string_view name, std::string_view platform) { @@ -97,9 +97,10 @@ static HandlerRegistry& GetHandlerRegistry() { } static Status RegisterHandler(std::string_view name, std::string_view platform, - XLA_FFI_Handler* handler) { - auto emplaced = - GetHandlerRegistry().try_emplace(MakeHandlerKey(name, platform), handler); + XLA_FFI_Handler* handler, + XLA_FFI_Handler_Traits traits) { + auto emplaced = GetHandlerRegistry().try_emplace( + MakeHandlerKey(name, platform), HandlerRegistration{handler, traits}); if (!emplaced.second) return absl::InvalidArgumentError( absl::StrCat("Duplicate FFI handler registration for ", name, @@ -107,8 +108,8 @@ static Status RegisterHandler(std::string_view name, std::string_view platform, return OkStatus(); } -absl::StatusOr FindHandler(std::string_view name, - std::string_view platform) { +absl::StatusOr FindHandler(std::string_view name, + std::string_view platform) { auto it = GetHandlerRegistry().find(MakeHandlerKey(name, platform)); if (it == GetHandlerRegistry().end()) return absl::NotFoundError(absl::StrCat("No FFI handler registered for ", @@ -116,9 +117,9 @@ absl::StatusOr FindHandler(std::string_view name, return it->second; } -absl::flat_hash_map StaticRegisteredHandlers( +absl::flat_hash_map StaticRegisteredHandlers( std::string_view platform) { - absl::flat_hash_map calls; 
+ absl::flat_hash_map calls; for (const auto& [metadata, handler] : GetHandlerRegistry()) { if (absl::AsciiStrToLower(platform) == metadata.second) { calls[metadata.first] = handler; @@ -236,7 +237,8 @@ static XLA_FFI_Error* XLA_FFI_Handler_Register( "XLA_FFI_Handler_Register", XLA_FFI_Handler_Register_Args_STRUCT_SIZE, args->struct_size)); - if (auto status = RegisterHandler(args->name, args->platform, args->handler); + if (auto status = RegisterHandler(args->name, args->platform, args->handler, + args->traits); !status.ok()) { return new XLA_FFI_Error{std::move(status)}; } diff --git a/third_party/xla/xla/ffi/ffi_api.h b/third_party/xla/xla/ffi/ffi_api.h index eae9eeda0a34c3..d101a8974587b6 100644 --- a/third_party/xla/xla/ffi/ffi_api.h +++ b/third_party/xla/xla/ffi/ffi_api.h @@ -27,7 +27,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/service/service_executable_run_options.h" #include "xla/status.h" -#include "xla/statusor.h" namespace xla::ffi { @@ -62,12 +61,18 @@ Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame, // XLA FFI registry //===----------------------------------------------------------------------===// +struct HandlerRegistration { + XLA_FFI_Handler* handler = nullptr; + XLA_FFI_Handler_Traits traits = 0; +}; + // Returns registered FFI handler for a given name and platform, or an error if // it's not found in the static registry. -absl::StatusOr FindHandler(std::string_view name, - std::string_view platform); +absl::StatusOr FindHandler(std::string_view name, + std::string_view platform); + // Returns all registered calls in the static registry for a given platform. -absl::flat_hash_map StaticRegisteredHandlers( +absl::flat_hash_map StaticRegisteredHandlers( std::string_view platform); //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index c28da195f7ba57..7c4e5fe1e083fb 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -49,7 +49,8 @@ TEST(FfiTest, StaticRegistration) { XLA_FFI_DEFINE_HANDLER(NoOp1, noop); XLA_FFI_REGISTER_HANDLER(GetXlaFfiApi(), "no-op-0", "Host", NoOp0); - XLA_FFI_REGISTER_HANDLER(GetXlaFfiApi(), "no-op-1", "Host", NoOp1); + XLA_FFI_REGISTER_HANDLER(GetXlaFfiApi(), "no-op-1", "Host", NoOp1, + XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE); auto handler0 = FindHandler("no-op-0", "Host"); auto handler1 = FindHandler("no-op-1", "Host"); @@ -57,6 +58,9 @@ TEST(FfiTest, StaticRegistration) { TF_ASSERT_OK(handler0.status()); TF_ASSERT_OK(handler1.status()); + ASSERT_EQ(handler0->traits, 0); + ASSERT_EQ(handler1->traits, XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE); + EXPECT_THAT(StaticRegisteredHandlers("Host"), UnorderedElementsAre(Pair("no-op-0", _), Pair("no-op-1", _))); } diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc index 6d753c5dd4c117..f5583b3878dd12 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -430,8 +430,8 @@ TEST(PjrtCApiGpuExtensionTest, CustomCallTyped) { reinterpret_cast(next)->custom_call(&args); CHECK_EQ(error, nullptr); - auto* custom_call = xla::ffi::FindHandler(function_name, "CUDA").value(); - EXPECT_EQ(reinterpret_cast(custom_call), kNoop); + auto registration = xla::ffi::FindHandler(function_name, "CUDA").value(); + EXPECT_EQ(reinterpret_cast(registration.handler), kNoop); } } 
// namespace diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc index 69bb526a05683d..1ef547777794ae 100644 --- a/third_party/xla/xla/python/xla_compiler.cc +++ b/third_party/xla/xla/python/xla_compiler.cc @@ -943,10 +943,10 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { targets[nb::str(name.data(), name.size())] = nb::capsule(target); } - for (const auto& [name, target] : + for (const auto& [name, registration] : ffi::StaticRegisteredHandlers(platform)) { targets[nb::str(name.data(), name.size())] = - nb::capsule(reinterpret_cast(target)); + nb::capsule(reinterpret_cast(registration.handler)); } return targets; }, diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc index ebee4e06f65d6d..ad124ed3eabde2 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc @@ -80,12 +80,12 @@ bool IsCustomCall(const HloInstruction* hlo, absl::string_view platform_name) { void* call_target = CustomCallTargetRegistry::Global()->Lookup( call_target_name, std::string(platform_name)); - absl::StatusOr handler = + absl::StatusOr handler_registration = ffi::FindHandler(call_target_name, platform_name); // At least one implementation should be available at run time. bool found_custom_call = !is_ffi_custom_call && call_target != nullptr; - bool found_ffi_handler = is_ffi_custom_call && handler.ok(); + bool found_ffi_handler = is_ffi_custom_call && handler_registration.ok(); return found_custom_call || found_ffi_handler; } diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index 8027bd69756a3d..619fe2281611d7 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -187,7 +187,7 @@ absl::StatusOr EmitCustomCall( const BufferAssignment& buffer_assignment = ir_emitter_context.buffer_assignment(); - const std::string call_target_name = custom_call.custom_call_target(); + const std::string& call_target_name = custom_call.custom_call_target(); // Typed FFI custom calls is a replacement for legacy custom calls with // a rich type safe API. It's under construction and not fully supported. @@ -197,12 +197,12 @@ absl::StatusOr EmitCustomCall( void* call_target = CustomCallTargetRegistry::Global()->Lookup( call_target_name, std::string(ir_emitter_context.platform_name())); - absl::StatusOr handler = + absl::StatusOr registration = ffi::FindHandler(call_target_name, ir_emitter_context.platform_name()); // At least one implementation should be available at run time. bool found_custom_call = !is_ffi_custom_call && call_target != nullptr; - bool found_ffi_handler = is_ffi_custom_call && handler.ok(); + bool found_ffi_handler = is_ffi_custom_call && registration.ok(); if (!found_custom_call && !found_ffi_handler) { return absl::InternalError( @@ -323,8 +323,9 @@ absl::StatusOr EmitCustomCall( auto ffi_thunk = [&] { auto& called_computations = custom_call.called_computations(); return std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(&custom_call), *handler, - std::move(operands), std::move(results), std::move(attributes), + Thunk::ThunkInfo::WithProfileAnnotation(&custom_call), + registration->handler, std::move(operands), std::move(results), + std::move(attributes), called_computations.empty() ? 
nullptr : called_computations[0]); }; diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 79eca88e8f96ea..09c0631925406b 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -1315,12 +1315,12 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk( void* call_target = CustomCallTargetRegistry::Global()->Lookup( call_target_name, std::string(platform_name())); - absl::StatusOr handler = + absl::StatusOr registration = ffi::FindHandler(call_target_name, platform_name()); // At least one implementation should be available at run time. bool found_custom_call = !is_ffi_custom_call && call_target != nullptr; - bool found_ffi_handler = is_ffi_custom_call && handler.ok(); + bool found_ffi_handler = is_ffi_custom_call && registration.ok(); if (!found_custom_call && !found_ffi_handler) { auto& debug_options = ir_emitter_context_->debug_options(); @@ -1452,7 +1452,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk( auto ffi_thunk = [&] { auto& called_computations = instr->called_computations(); return std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), *handler, + Thunk::ThunkInfo::WithProfileAnnotation(instr), registration->handler, std::move(operands), std::move(results), std::move(attributes), called_computations.empty() ? nullptr : called_computations[0]); }; diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc index 1167cf18a93c57..dc57a6447922e4 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc @@ -560,8 +560,8 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { // Preparing custom call thunk: setting up call target and operands + results // buffers. - auto handler = xla::ffi::FindHandler("__xla_test$$memcpy", PLATFORM); - ASSERT_TRUE(handler.ok()); + auto registration = xla::ffi::FindHandler("__xla_test$$memcpy", PLATFORM); + ASSERT_TRUE(registration.ok()); std::vector> operands{ CustomCallThunk::Slice{slice_src_fake, @@ -573,7 +573,7 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { // Creating embedded custom call thunk. ThunkSequence seq; seq.emplace_back(std::make_unique( - Thunk::ThunkInfo(nullptr), *handler, operands, results, + Thunk::ThunkInfo(nullptr), registration->handler, operands, results, /*attributes=*/CustomCallThunk::AttributesMap(), /*called_computation=*/nullptr)); @@ -713,8 +713,8 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { // Preparing custom call thunk: setting up call target and operands + results // buffers. - auto handler = xla::ffi::FindHandler("__xla_test$$memcpy", PLATFORM); - ASSERT_TRUE(handler.ok()); + auto registration = xla::ffi::FindHandler("__xla_test$$memcpy", PLATFORM); + ASSERT_TRUE(registration.ok()); std::vector> operands{ CustomCallThunk::Slice{slice_src_fake, @@ -726,7 +726,7 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { // Creating embedded custom call thunk. 
ThunkSequence seq; seq.emplace_back(std::make_unique( - Thunk::ThunkInfo(nullptr), *handler, operands, results, + Thunk::ThunkInfo(nullptr), registration->handler, operands, results, /*attributes=*/CustomCallThunk::AttributesMap(), /*called_computation=*/nullptr)); From a644925a6a62fb08e1d984553bf81bc9daff4a4e Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Mon, 18 Mar 2024 20:15:27 -0700 Subject: [PATCH 063/670] Handle element_size_in_bits in constant folding. Constant folding runs HloEvaluator, which creates Literals. Before, if element_size_in_bits was nonzero in a constant folded op, an error would occur since Literals do not support element_size_in_bits and so CHECKed it was zero in the constructor. Now Literal will silently set element_size_in_bits to zero in the constructor. Because the newly created constant-folded constant op derives its Shape from the constant-folded literal, now HloConstantFolding explicitly sets element_size_in_bits on the newly created constant op since the Literal will always have element_size_in_bits set to zero. This will be needed to support int4 in arbitrary ops on CPUs/GPUs. PiperOrigin-RevId: 617033705 --- third_party/xla/xla/literal.cc | 24 +++++++++++------ third_party/xla/xla/literal.h | 5 ++++ .../xla/xla/service/hlo_constant_folding.cc | 10 +++++++ .../xla/service/hlo_constant_folding_test.cc | 26 +++++++++++++++++++ 4 files changed, 57 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index 7c1ae28e8a9d58..d5364cb848e652 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -249,6 +249,21 @@ Literal::Literal() : Literal(NilShape()) {} Literal::Literal(const Shape& shape) : Literal(shape, /*allocate_arrays=*/true) {} +void Literal::SetShape(const Shape& shape) { + Shape shape_storage; + const Shape* shape_ptr = &shape; + if (LayoutUtil::HasCustomElementSizeInBits(shape)) { + shape_storage = shape; + shape_storage.mutable_layout()->set_element_size_in_bits(0); + shape_ptr = &shape_storage; + } + if (const Shape* intered_shape_ptr = TryInternShape(*shape_ptr)) { + shape_ = intered_shape_ptr; + } else { + shape_ = std::make_unique(*shape_ptr); + } +} + void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays, ArrayValueState leaf_array_value_state) { if (shape.IsTuple()) { @@ -276,16 +291,9 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays, Literal::Literal(const Shape& shape, bool allocate_arrays, ArrayValueState leaf_array_value_state) : MutableLiteralBase() { - if (const Shape* intered_shape_ptr = TryInternShape(shape)) { - shape_ = intered_shape_ptr; - } else { - shape_ = std::make_unique(shape); - } + SetShape(shape); CHECK(leaf_array_value_state != ArrayValueState::kKnown || LayoutUtil::HasLayout(*shape_)); - // Currently we do nibble packing/unpacking in TPU host/device transfer. - CHECK(!LayoutUtil::HasCustomElementSizeInBits(*shape_)) - << "Literal does not support layouts with custom bit size: " << *shape_; root_piece_.set_subshape(shape_.get()); CHECK(&root_piece_.subshape() == shape_.get()); diff --git a/third_party/xla/xla/literal.h b/third_party/xla/xla/literal.h index 8f8894dbc26ea8..a6b4758cf64234 100644 --- a/third_party/xla/xla/literal.h +++ b/third_party/xla/xla/literal.h @@ -1469,6 +1469,11 @@ class Literal : public MutableLiteralBase { // Deallocate the buffers held by this literal. void DeallocateBuffers(); + // Sets the shape_ field from a Shape. 
shape_'s element_size_in_bits field + // on the layout is always set to 0 since Literals do not support packed + // subbyte elements. + void SetShape(const Shape& shape); + // Recursively sets the subshapes and buffers of all subpieces rooted at // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in // the shape. diff --git a/third_party/xla/xla/service/hlo_constant_folding.cc b/third_party/xla/xla/service/hlo_constant_folding.cc index 71f58d9a241232..7afdb75649edc3 100644 --- a/third_party/xla/xla/service/hlo_constant_folding.cc +++ b/third_party/xla/xla/service/hlo_constant_folding.cc @@ -233,6 +233,16 @@ StatusOr HloConstantFolding::Run( dead_instructions.push_back(instruction); HloInstruction* new_constant = computation->AddInstruction( HloInstruction::CreateConstant(std::move(result))); + if (new_constant->shape().has_layout()) { + // Update element_size_in_bits on the new instruction's layout. Literals + // always have element_size_in_bits set to 0, and CreateConstant copies + // the shape/layout from the Literal, so we need to set + // element_size_in_bits here. + new_constant->mutable_shape() + ->mutable_layout() + ->set_element_size_in_bits( + instruction->shape().layout().element_size_in_bits()); + } TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_constant)); } } diff --git a/third_party/xla/xla/service/hlo_constant_folding_test.cc b/third_party/xla/xla/service/hlo_constant_folding_test.cc index 4150b24ead5ee1..4958bee65f54d1 100644 --- a/third_party/xla/xla/service/hlo_constant_folding_test.cc +++ b/third_party/xla/xla/service/hlo_constant_folding_test.cc @@ -346,6 +346,32 @@ TEST_F(HloConstantFoldingTest, FoldOpsWhereOneOperandIsBroadcast) { ))); } +TEST_F(HloConstantFoldingTest, FoldInt4Ops) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY entry { + c0 = s4[2]{0:E(4)} constant({1, 2}) + c1 = s4[2]{0:E(4)} constant({3, 4}) + add1 = s4[2]{0:E(4)} add(c0, c1) + c2 = s4[]{:E(4)} constant(5) + add2 = s4[2]{0:E(4)} add(c0, s4[2]{0:E(4)} broadcast(c2)) + ROOT root = tuple(add1, add2) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + HloConstantFolding constant_folding; + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_folding, module.get())); + EXPECT_TRUE(result); + auto is_4_bit = [](const HloInstruction* instr) { + return instr->shape().layout().element_size_in_bits() == 4; + }; + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Constant().WithPredicate(is_4_bit), + m::Constant().WithPredicate(is_4_bit)))); +} + TEST_F(HloConstantFoldingTest, BigReduceWindow) { constexpr absl::string_view kModuleStr = R"( HloModule test From 9c1b61a664d66b15d07c25e713d8d96f6bf347c8 Mon Sep 17 00:00:00 2001 From: Wilsin Gosti Date: Mon, 18 Mar 2024 20:36:42 -0700 Subject: [PATCH 064/670] #tf-data Set the iterator prefix and `DebugString` of `GlobalShuffleDataset` to `GlobalShuffle` to be consistent with other datasets. 
PiperOrigin-RevId: 617037057 --- .../kernels/data/experimental/global_shuffle_dataset_op.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc b/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc index e0cbd047bc945b..ad0006724bd5ef 100644 --- a/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc @@ -47,6 +47,7 @@ namespace { constexpr int32_t kIndexShuffleRounds = 8; +constexpr const char kDatasetType[] = "GlobalShuffle"; constexpr const char kElementCount[] = "element_count"; constexpr const char kGlobalShuffleDataset[] = "GlobalShuffleDataset"; constexpr const char kReshuffleEachIteration[] = "reshuffle_each_iteration"; @@ -105,7 +106,7 @@ class GlobalShuffleDatasetOp::Dataset : public DatasetBase { } std::string DebugString() const override { - return name_utils::DatasetDebugString(kGlobalShuffleDataset); + return name_utils::DatasetDebugString(kDatasetType); } int64_t CardinalityInternal(CardinalityOptions options) const override { @@ -340,8 +341,7 @@ std::unique_ptr GlobalShuffleDatasetOp::Dataset::MakeIteratorInternal( const std::string& prefix) const { return std::make_unique( - Iterator::Params{ - this, name_utils::IteratorPrefix(kGlobalShuffleDataset, prefix)}, + Iterator::Params{this, name_utils::IteratorPrefix(kDatasetType, prefix)}, seed_generator_->get()); } From a66b17075a05337532a1758cfe81c4f11abe3029 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Mar 2024 21:05:42 -0700 Subject: [PATCH 065/670] When possible, use the device ids provided by the user instead of defaulting to the iota order. PiperOrigin-RevId: 617041894 --- .../auto_sharding/auto_sharding.cc | 13 +++++-- .../auto_sharding/auto_sharding_strategy.cc | 27 ++++++++++++-- .../auto_sharding/auto_sharding_test.cc | 37 +++++++++++++++++++ .../auto_sharding/cluster_environment.h | 11 ++++++ 4 files changed, 81 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index dc7a4eb01edf79..a59f4ee2f335b3 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -3631,9 +3631,16 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( return changed.status(); } } - std::vector device_mesh_ids = std::vector(total_devices); - std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0); - device_mesh.SetValues(device_mesh_ids); + if (option_.device_mesh_ids.size() == total_devices) { + // It is unclear what device order to use for partial meshes. So we only + // use the actual device order only for the final full mesh. + device_mesh.SetValues(option_.device_mesh_ids); + } else { + std::vector device_mesh_ids = + std::vector(total_devices); + std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0); + device_mesh.SetValues(device_mesh_ids); + } // TODO (zhuohan): Include the prof result as an option. 
spmd::ProfilingResult prof_result; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 8dfe3877e3b7b2..4563141e30b67f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -393,13 +393,32 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // Find output shardings. switch (opcode) { case HloOpcode::kSlice: { + // When solve_nd_sharding_iteratively is true, in some cases, we + // can have 1D shardings where the total number of tiles is larger + // than the number of elements in the partial mesh (and is + // actually equal to the number of devices in the original + // mesh). Below, we use the correct mesh depending on the number + // of elements in the 1D sharding. bool is_1d_sharding = VectorGreaterThanOneElementCount( input_spec.tile_assignment().dimensions()) == 1; - output_spec = PropagateDimwiseShardingSlice( - input_spec, operand->shape(), ins->shape(), - is_1d_sharding ? cluster_env.device_mesh_1d_ - : cluster_env.device_mesh_); + if (is_1d_sharding && + input_spec.TotalNumTiles() == + cluster_env.device_mesh_1d_.num_elements()) { + output_spec = PropagateDimwiseShardingSlice( + input_spec, operand->shape(), ins->shape(), + cluster_env.device_mesh_1d_); + } else if (is_1d_sharding) { + CHECK_EQ(input_spec.TotalNumTiles(), + cluster_env.original_device_mesh_1d_.num_elements()); + output_spec = PropagateDimwiseShardingSlice( + input_spec, operand->shape(), ins->shape(), + cluster_env.original_device_mesh_1d_); + } else { + output_spec = PropagateDimwiseShardingSlice( + input_spec, operand->shape(), ins->shape(), + cluster_env.device_mesh_); + } break; } case HloOpcode::kPad: diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index aa1167dfac33bb..27b9df98e3a88c 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -269,6 +269,43 @@ ENTRY %elementwise { op::Sharding("{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}"))); } +TEST_F(AutoShardingTest, SliceMixedUserShardingTest) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +ENTRY %elementwise { + param = s32[512,3084]{1,0} parameter(0), sharding={devices=[4,1]0,2,1,3} + slice = s32[512,2048]{1,0} slice(param), slice={[0:512], [0:2048]} + ROOT copy = s32[512,2048]{1,0} copy(slice) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + AutoSharding( + /* option */ { + .enable = true, + .preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings, + .solve_nd_sharding_iteratively = true, + .device_mesh_shape = {2, 2}, + .device_mesh_ids = {0, 2, 1, 3}, + .device_mesh_alpha = {1.0, 1.0}, + .device_mesh_beta = {0.01, 1.0}}) + .Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + + std::vector instructions = + module->entry_computation()->MakeInstructionPostOrder(); + EXPECT_THAT(instructions, + Each(ResultOf( + [](const HloInstruction* ins) { return ins->has_sharding(); }, + IsTrue()))); + EXPECT_THAT(instructions, Each(op::Sharding("{devices=[4,1]0,2,1,3}"))); +} + TEST_F(AutoShardingTest, 
RngBitGeneratorArrayInput) { constexpr absl::string_view hlo_string = R"( HloModule rng_bit_generator diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h index 7bab542bdbd2e3..19736d19e25f0a 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h @@ -51,6 +51,7 @@ class ClusterEnvironment { prof_result_(prof_result), total_devices_(device_mesh.num_elements()), device_mesh_1d_(device_mesh), + original_device_mesh_1d_(original_device_mesh), auto_sharding_option_(auto_sharding_option) { // Build replica group for each dimension. non_zero_mesh_dims_ = @@ -71,6 +72,12 @@ class ClusterEnvironment { std::vector device_mesh_1d_shape(device_mesh.num_dimensions(), 1); device_mesh_1d_shape[largest_dim_idx] = device_mesh.num_elements(); device_mesh_1d_.Reshape(device_mesh_1d_shape); + + std::vector original_device_mesh_1d_shape( + original_device_mesh.num_dimensions(), 1); + original_device_mesh_1d_shape[largest_dim_idx] = + original_device_mesh.num_elements(); + original_device_mesh_1d_.Reshape(original_device_mesh_1d_shape); } size_t NumDevices() const { return total_devices_; } @@ -171,6 +178,10 @@ class ClusterEnvironment { // Used for mixed mesh shape strategies. Array device_mesh_1d_; + // Cache a flatten 1d version of the original device mesh. + // Used for mixed mesh shape strategies. + Array original_device_mesh_1d_; + // The option may override the cost of communication primitives const AutoShardingOption& auto_sharding_option_; From 165e3288ce537138b6f02f9a71dd23466960df85 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Mon, 18 Mar 2024 21:09:55 -0700 Subject: [PATCH 066/670] [OptimizeFunctionGraph] Prune the function library to reachable functions in the post-optimized graph. Adds an overload `FunctionLibraryDefinition::ReachableDefinitions(const Graph&)` that enables capturing the reachable definitions from a `tensorflow::Graph` (and not just a protobuf-based graph). PiperOrigin-RevId: 617042757 --- .../optimize_function_graph_utils.cc | 7 +- tensorflow/core/framework/function.cc | 90 +++++++++++++------ tensorflow/core/framework/function.h | 1 + 3 files changed, 70 insertions(+), 28 deletions(-) diff --git a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc index 357520c827c393..264067a10a73d5 100644 --- a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc +++ b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/placer.h" #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h" #include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/optimized_function_graph.pb.h" @@ -640,8 +641,12 @@ StatusOr OptimizeFunctionGraph( graph->mutable_flib_def()->set_default_registry(nullptr); graph->mutable_flib_def()->Clear(); + + FunctionLibraryDefinition pruned_lib_def = + reachable_lib_def.ReachableDefinitions(*graph); + return OptimizedFunctionGraphInfo( - function_name, std::move(graph), std::move(reachable_lib_def), + function_name, std::move(graph), std::move(pruned_lib_def), node_name_to_control_ret, ret_types, ret_nodes.size(), env->NowMicros() - graph_optimization_start_time_usecs, optimization_source); diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 5a63dc61f019ac..0b6bacd94af0d9 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1820,9 +1820,12 @@ namespace { constexpr char kApiImplements[] = "api_implements"; -std::set ReachableFunctions( - const FunctionLibraryDefinition& flib, - const protobuf::RepeatedPtrField& nodes) { +template +std::set ReachableFunctions(const FunctionLibraryDefinition& flib, + NodeIter begin, NodeIter end, + OpTypeGetter op_type_getter, + AttrGetter attr_getter) { // Functions that are reachable from the graph. std::set reachable_funcs; @@ -1860,31 +1863,33 @@ std::set ReachableFunctions( } }; + const auto process_attr_value = [&](const AttrValue& attr_value) { + // 1. AttrValue.func + if (attr_value.has_func()) { + add_to_func_queue(attr_value.func().name()); + } + + // 2. AttrValue.ListValue.func + if (attr_value.has_list()) { + for (const auto& func : attr_value.list().func()) { + add_to_func_queue(func.name()); + } + } + }; + // Add all the functions that are reachable from the given node to the queue. - const auto process_node = [&](const NodeDef& node) { + const auto process_node = [&](NodeType node) { // Node itself can be a call to the function. - add_to_func_queue(node.op()); + add_to_func_queue(op_type_getter(node)); // Or node can have an attribute referencing a function. - for (const auto& attr : node.attr()) { - const auto& attr_value = attr.second; - - // 1. AttrValue.func - if (attr_value.has_func()) { - add_to_func_queue(attr_value.func().name()); - } - - // 2. AttrValue.ListValue.func - if (attr_value.has_list()) { - for (const auto& func : attr_value.list().func()) { - add_to_func_queue(func.name()); - } - } + for (const auto& attr : attr_getter(node)) { + process_attr_value(attr.second); } }; // Add all functions that are directly called from the optimized graph. - std::for_each(nodes.begin(), nodes.end(), process_node); + std::for_each(begin, end, process_node); // Process all reachable functions. while (!func_queue.empty()) { @@ -1901,7 +1906,18 @@ std::set ReachableFunctions( // Find all the functions called from the function body. const auto& func_body = func->fdef().node_def(); - std::for_each(func_body.begin(), func_body.end(), process_node); + + const auto process_node_def = [&](const NodeDef node) { + // Node itself can be a call to the function. + add_to_func_queue(node.op()); + + // Or node can have an attribute referencing a function. 
+ for (const auto& attr : node.attr()) { + process_attr_value(attr.second); + } + }; + + std::for_each(func_body.begin(), func_body.end(), process_node_def); // Check if the function has a registered gradient. const string grad_func_name = flib.FindGradient(func_name); @@ -1911,10 +1927,13 @@ std::set ReachableFunctions( return reachable_funcs; } +template FunctionLibraryDefinition ReachableFunctionLibraryDefinition( - const FunctionLibraryDefinition& flib, - const protobuf::RepeatedPtrField& nodes) { - std::set reachable_funcs = ReachableFunctions(flib, nodes); + const FunctionLibraryDefinition& flib, NodeIter begin, NodeIter end, + OpTypeGetter op_type_getter, AttrGetter attr_getter) { + std::set reachable_funcs = ReachableFunctions( + flib, begin, end, op_type_getter, attr_getter); FunctionLibraryDefinition reachable_flib(flib.default_registry(), FunctionDefLibrary()); @@ -1961,12 +1980,26 @@ const char* IsSet(void* ptr) { return ptr == nullptr ? "unset" : "set"; } FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions( const GraphDef& graph) const { - return ReachableFunctionLibraryDefinition(*this, graph.node()); + return ReachableFunctionLibraryDefinition( + *this, graph.node().begin(), graph.node().end(), + [](const NodeDef& ndef) { return ndef.op(); }, + [](const NodeDef& ndef) { return ndef.attr(); }); } FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions( const FunctionDef& func) const { - return ReachableFunctionLibraryDefinition(*this, func.node_def()); + return ReachableFunctionLibraryDefinition( + *this, func.node_def().begin(), func.node_def().end(), + [](const NodeDef& ndef) { return ndef.op(); }, + [](const NodeDef& ndef) { return ndef.attr(); }); +} + +FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions( + const Graph& graph) const { + return ReachableFunctionLibraryDefinition( + *this, graph.nodes().begin(), graph.nodes().end(), + [](const Node* node) { return node->type_string(); }, + [](const Node* node) { return node->attrs(); }); } absl::StatusOr @@ -1975,7 +2008,10 @@ FunctionLibraryDefinition::ReachableDefinitions( auto* func = Find(function_name); if (func) { FunctionLibraryDefinition ret = - ReachableFunctionLibraryDefinition(*this, func->node_def()); + ReachableFunctionLibraryDefinition( + *this, func->node_def().begin(), func->node_def().end(), + [](const NodeDef& ndef) { return ndef.op(); }, + [](const NodeDef& ndef) { return ndef.attr(); }); TF_RETURN_IF_ERROR(ret.CopyFunctionDefFrom(function_name, *this)); return ret; } else { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index af956ac1524427..eb74ea58905405 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -569,6 +569,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // reachable from the nodes of `graph` or `func`. FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const; FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const; + FunctionLibraryDefinition ReachableDefinitions(const Graph& graph) const; absl::StatusOr ReachableDefinitions( const std::string& function_name) const; From a74f5d1d9238d696d4347d28bfee70e45b5dc78c Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Mon, 18 Mar 2024 21:44:15 -0700 Subject: [PATCH 067/670] Implement expanding presets from `QuantizationConfig`. `ExpandPresets` transfers quantization presets and populates other fields in `QuantizationConfig`. 
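A rough sketch of the expansion, for illustration only (the expansion itself is implemented by the C++ `ExpandPresets` added below; the Python proto import path here is an assumption, not part of this change):

```
# Sketch only: the proto module path below is assumed for illustration.
from tensorflow.compiler.mlir.quantization.stablehlo import (
    quantization_config_pb2 as qc_pb2,
)

config = qc_pb2.QuantizationConfig()
dataset = config.static_range_ptq_preset.representative_datasets.add()
dataset.tf_record.path = "/test/path"

# ExpandPresets(config) rewrites the preset into explicit fields:
#
#   specs {
#     specs {
#       matcher { function_name { regex: ".*" } }
#       method { static_range_ptq {} }
#     }
#   }
#   calibration_options {
#     representative_datasets { tf_record { path: "/test/path" } }
#   }
#
# User-provided entries in `specs` are appended after the expanded spec and
# therefore take precedence for the quantizable units they match.
```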
PiperOrigin-RevId: 617049405 --- .../mlir/quantization/stablehlo/cc/config.cc | 46 +++++++++ .../mlir/quantization/stablehlo/cc/config.h | 17 ++++ .../quantization/stablehlo/cc/config_test.cc | 98 +++++++++++++++++++ .../stablehlo/quantization_config.proto | 35 +++++-- 4 files changed, 189 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc index e8a4aa87bb0619..0284c00523f420 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + namespace stablehlo::quantization { namespace { @@ -23,11 +25,55 @@ CalibrationOptions GetDefaultCalibrationOptions() { CalibrationOptions options{}; options.set_calibration_method( CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); + return options; } +void ExpandStaticRangePtqPreset(const StaticRangePtqPreset& preset, + QuantizationConfig& config) { + // Populate with preset's representative dataset configs if the user didn't + // explicitly specify other representative dataset configs to the top-level + // `CalibrationOptions`. + if (config.calibration_options().representative_datasets().empty()) { + auto preset_datasets = preset.representative_datasets(); + config.mutable_calibration_options() + ->mutable_representative_datasets() + ->Add(preset_datasets.begin(), preset_datasets.end()); + } + + // Create a new `QuantizationSpecs` to replace the existing one. The expansion + // from `StaticRangePtqPreset` gets populated first and then user-provided + // explicit `QuantizationSpec`s will be appended. + QuantizationSpecs new_specs{}; + QuantizationSpec& spec = *new_specs.add_specs(); + spec.mutable_matcher()->mutable_function_name()->set_regex(".*"); + spec.mutable_method()->mutable_static_range_ptq(); + + const QuantizationSpecs& previous_specs = config.specs(); + new_specs.mutable_specs()->Add(previous_specs.specs().begin(), + previous_specs.specs().end()); + + config.mutable_specs()->Swap(&new_specs); +} + } // namespace +QuantizationConfig ExpandPresets(const QuantizationConfig& config) { + QuantizationConfig new_config = config; + + // Update the `new_config` with each preset's expansions. + switch (config.preset_case()) { + case QuantizationConfig::kStaticRangePtqPreset: + ExpandStaticRangePtqPreset(config.static_range_ptq_preset(), new_config); + break; + default: + // Preset has not been specified. The expansion is a no-op. + break; + } + + return new_config; +} + QuantizationConfig PopulateDefaults( const QuantizationConfig& user_provided_config) { QuantizationConfig config = user_provided_config; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h index 20b9efa4a60fa0..5dc4554d784c92 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h @@ -24,6 +24,23 @@ namespace stablehlo::quantization { QuantizationConfig PopulateDefaults( const QuantizationConfig& user_provided_config); +// Returns a copy of `QuantizationConfig` where presets are expanded and +// transformed into other fields in `QuantizationConfig`. 
+// +// The expansion rules are as follows: +// * StaticRangePtqPreset +// - The preset's `representative_datasets` field will be transferred to +// `QuantizationConfig.calibration_options.representative_datasets`, unless +// the user explicitly provided representative dataset configs to +// `calibration_options`. In that case, the explicit configs take precedence +// and the preset's configs are ignored. +// - For `QuantizationSpecs`, the expanded `QuantizationSpec`s will be +// populated first and user-provided `QuantizationSpec`s, if any, will be +// appended. This expresses the fact that user-provided specs take precedence. +// * Preset unspecified +// - No-op. +QuantizationConfig ExpandPresets(const QuantizationConfig& config); + } // namespace stablehlo::quantization #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc index 164cd6bae237f8..b606c797819c4b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc @@ -22,6 +22,8 @@ namespace stablehlo::quantization { namespace { using ::testing::Eq; +using ::testing::SizeIs; +using ::testing::StrEq; TEST(PopulateDefaultsTest, PopulateDefaultsForEmptyConfig) { QuantizationConfig config{}; @@ -68,5 +70,101 @@ TEST(PopulateDefaultsTest, ExplicitCalibrationOptionsNotOverridden) { Eq(512)); } +TEST(ExpandPresetsTest, ExpandUnspecifiedPreset) { + QuantizationConfig config{}; + const QuantizationConfig new_config = ExpandPresets(config); + + // Test that nothing has been changed. + EXPECT_FALSE(new_config.has_specs()); + EXPECT_FALSE(new_config.has_calibration_options()); + EXPECT_FALSE(new_config.has_pipeline_config()); +} + +TEST(ExpandPresetsTest, ExpandStaticRangePtqPreset) { + QuantizationConfig config{}; + RepresentativeDatasetConfig& preset_dataset_config = + *config.mutable_static_range_ptq_preset()->add_representative_datasets(); + preset_dataset_config.mutable_tf_record()->set_path("/test/path"); + + const QuantizationConfig new_config = ExpandPresets(config); + ASSERT_THAT(new_config.specs().specs(), SizeIs(1)); + + const QuantizationSpec& spec = new_config.specs().specs(0); + EXPECT_THAT(spec.matcher().function_name().regex(), StrEq(".*")); + EXPECT_TRUE(spec.method().has_static_range_ptq()); + + // Test that representative dataset config has been transferred to the + // `CalibrationOptions`. + ASSERT_THAT(new_config.calibration_options().representative_datasets(), + SizeIs(1)); + EXPECT_THAT(new_config.calibration_options() + .representative_datasets(0) + .tf_record() + .path(), + StrEq("/test/path")); +} + +TEST(ExpandPresetsTest, + ExpandStaticRangePtqPresetWithExplicitRepresentativeDatasetConfigs) { + // Test the scenario where both + // `config.calibration_options.representative_datasets` and + // `config.static_range_ptq_preset.representative_datasets` are both + // specified. In this case, the one set to the `calibration_options` takes + // precedence. 
+ QuantizationConfig config{}; + RepresentativeDatasetConfig& top_level_dataset_config = + *config.mutable_calibration_options()->add_representative_datasets(); + top_level_dataset_config.mutable_tf_record()->set_path("/test/path/1"); + + RepresentativeDatasetConfig& preset_dataset_config = + *config.mutable_static_range_ptq_preset()->add_representative_datasets(); + preset_dataset_config.mutable_tf_record()->set_path("/test/path/2"); + + const QuantizationConfig new_config = ExpandPresets(config); + + // Test that representative dataset config has not been transferred to the + // `CalibrationOptions`. Top-level config takes precedence. + ASSERT_THAT(new_config.calibration_options().representative_datasets(), + SizeIs(1)); + EXPECT_THAT(new_config.calibration_options() + .representative_datasets(0) + .tf_record() + .path(), + StrEq("/test/path/1")); +} + +TEST(ExpandPresetsTest, + ExpandStaticRangePtqPresetWithExplicitSpecsAppendedAfterExpandedSpecs) { + QuantizationConfig config{}; + config.mutable_static_range_ptq_preset(); + + QuantizationSpec& user_provided_spec = *config.mutable_specs()->add_specs(); + user_provided_spec.mutable_matcher()->mutable_function_name()->set_regex( + "composite_dot_general_fn_1"); + user_provided_spec.mutable_method()->mutable_no_quantization(); + + // Test that the expanded `QuantizationSpec`s are populated first and then + // user-provided specs are appended. + // + // It should look like: + // + // specs {matcher {function_name {regex: ".*"}} method {static_range_ptq {}}} + // specs { + // matcher {function_name {regex: "composite_dot_general_fn_1"}} + // method {no_quantization {}} + // } + const QuantizationConfig new_config = ExpandPresets(config); + ASSERT_THAT(new_config.specs().specs(), SizeIs(2)); + + const QuantizationSpec& first_spec = new_config.specs().specs(0); + EXPECT_THAT(first_spec.matcher().function_name().regex(), StrEq(".*")); + EXPECT_TRUE(first_spec.method().has_static_range_ptq()); + + const QuantizationSpec& second_spec = new_config.specs().specs(1); + EXPECT_THAT(second_spec.matcher().function_name().regex(), + StrEq("composite_dot_general_fn_1")); + EXPECT_TRUE(second_spec.method().has_no_quantization()); +} + } // namespace } // namespace stablehlo::quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto index b4c4dbdf1f26c8..36b781a7d28914 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto @@ -28,13 +28,23 @@ message RepresentativeDatasetConfig { } // Preset config for static-range post-training quantization (PTQ). +// // Minimal user input about representative datasets is required. Representative // datasets are required for static-range PTQ to retrieve quantization // statistics via calibration. +// +// This preset is equivalent to the following `QuantizationSpecs`: +// +// ``` +// specs {matcher {function_name {regex: ".*"}} method {static_range_ptq {}}} +// ``` +// // Next ID: 3 message StaticRangePtqPreset { // Configures representative dataset. Each item corresponds to a // representative dataset used to calibrate a function. + // If `QuantizationConfig.calibration_options.representative_datasets` is also + // provided then this field will be ignored. repeated RepresentativeDatasetConfig representative_datasets = 1; // NOTE: This field will be deprecated. 
@@ -93,6 +103,9 @@ message QuantizationResults { // denylisting quantizable units from quantization. message NoQuantization {} +// Configurations for static-range post-training quantization method. +message StaticRangePtq {} + // Represents a matching method that matches quantizable units by lifted // functions' names. message FunctionNameMatcherSpec { @@ -110,7 +123,10 @@ message MatcherSpec { // Specifies how to quantize matched quantizable units. message Method { - NoQuantization no_quantization = 1; + oneof method { + NoQuantization no_quantization = 1; + StaticRangePtq static_range_ptq = 2; + } } // A QuantizationSpec is essentially a (matcher spec, quantization method) pair, @@ -184,9 +200,10 @@ message DebuggerConfig { } // Defines various calibration options. +// Next ID: 4 message CalibrationOptions { // Configurations for calibration methods. - // NEXT ID: 7 + // Next ID: 7 enum CalibrationMethod { CALIBRATION_METHOD_UNSPECIFIED = 0; // Use the min, max values of all sample datasets. @@ -211,7 +228,7 @@ message CalibrationOptions { } // Parameters required for calibration. - // NEXT ID: 4 + // Next ID: 4 message CalibrationParameters { // The number of bins when histogram is initialized. It can be increased // because histogram is dynamically expanded by sample inputs. @@ -234,6 +251,10 @@ message CalibrationOptions { // MIN_MAX and AVERAGE_MIN_MAX don't require this parameter and methods // starting with HISTOGRAM require this parameter. CalibrationParameters calibration_parameters = 2; + + // Configures representative dataset. Each item corresponds to a + // representative dataset used to calibrate a function. + repeated RepresentativeDatasetConfig representative_datasets = 3; } // Quantization configuration for StableHLO Quantizer. This is the primary @@ -242,10 +263,10 @@ message CalibrationOptions { message QuantizationConfig { // Config presets provide predefined popular or common quantization specs. // Lightweight users may choose one of the presets for quick experiments. Each - // preset is completely represented by `QuantizationSpecs`. When extra entries - // in `QuantizationSpecs` are provided along with a preset, then the preset - // will be overridden for the quantizable units matched by those additional - // `QuantizationSpec`s. + // preset is completely represented by other fields in `QuantizationConfig`. + // + // When extra entries in `QuantizationSpecs` are provided along with a preset, + // then those entries will take precedence. oneof preset { // Performs best-effort static-range post-training quantization (PTQ). 
StaticRangePtqPreset static_range_ptq_preset = 1; From ac3c5809b04ad37428d003e489e7651a50eb6b5e Mon Sep 17 00:00:00 2001 From: Deqiang Chen Date: Mon, 18 Mar 2024 21:56:22 -0700 Subject: [PATCH 068/670] Lower tf.IfrtRestoreVariableOp to tf_mlrt.IfrtRestoreVariableOp PiperOrigin-RevId: 617051341 --- .../compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td | 28 ++++++++++++++++++- .../mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir | 21 ++++++++++++++ .../mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc | 23 ++++++++++++++- .../mlir/tfrt/transforms/mlrt/util.cc | 6 ++-- 4 files changed, 73 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td index 7fbc42ad3db93f..72eac197011a6d 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td @@ -427,7 +427,7 @@ def AsyncWhileOp : TensorflowMlrt_Op<"async_while", [Pure]> { }]; } -def IfrtLoadVariableOp: TensorflowMlrt_Op<"ifrt_load_variable", []> { +def IfrtLoadVariableOp: TensorflowMlrt_Op<"ifrt_load_variable", [Pure]> { let summary = "Loads a variable tensor as an IFRT array for mlrt"; let description = [{ @@ -458,5 +458,31 @@ def IfrtLoadVariableOp: TensorflowMlrt_Op<"ifrt_load_variable", []> { ); } +def IfrtRestoreVariableOp: TensorflowMlrt_Op<"ifrt_restore_variable", []> { + let summary = "Restore variable tensors"; + let description = [{ + This is the MLRT version of tf.IfrtRestoreVariableOp. + + This Op is similar to a combination of RestoreV2 and AssignVariable Op, but + this Op's execution is asynchronous. + + This Op is specific to MLRT runtime and is not a stable interface for + serialization. + + This Op will restore the tensors asynchronously and allow the runtime to look + for them. + The runtime shall handle the possibility that the tensors are not ready when requested + because the tensors are loaded asynchronously. 
+ + }]; + + let arguments = (ins + TFTensorType:$prefix, + TFTensorType:$tensor_names, + TFTensorType:$shape_and_slices, + Variadic:$var_handles, + TypeArrayAttr: $restored_dtypes + ); +} #endif diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir index eb2e0587364d6e..4cd2d6f3613a27 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir @@ -476,3 +476,24 @@ func.func @ifrt_load_variable_test() -> () { func.return } +// ----- + +// Test lowering of IfrtRestoreVariableOp + +// CHECK-LABEL: func @ifrt_restore_variable_test +func.func @ifrt_restore_variable_test() -> () { + // CHECK-NEXT: [[PREFIX:%.*]] = tf_mlrt.executeop + %cst = "tf.Const"() {__op_key = 0: i32, value = dense<"restore_ariables"> : tensor} : () -> tensor + // CHECK-NEXT: [[SLICE:%.*]] = tf_mlrt.executeop + %cst_0 = "tf.Const"() {__op_key = 1: i32, value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + // CHECK-NEXT: [[NAME:%.*]] = tf_mlrt.executeop + %cst_1 = "tf.Const"() {__op_key = 2: i32, value = dense<["y"]> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + // CHECK-NEXT: [[HANDLE:%.*]] = tf_mlrt.executeop + %handle = "tf.VarHandleOp"() {__op_key = 3: i32, container = "x", shared_name = "y"} : () -> tensor>> + // CHECK-NEXT: "tf_mlrt.ifrt_restore_variable"([[PREFIX]], [[NAME]], [[SLICE]], [[HANDLE]]) {restored_dtypes = [f32]} + "tf.IfrtRestoreVariableOp"(%cst, %cst_1, %cst_0, %handle) {restored_dtypes = [f32]} : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor>>) -> () + // CHECK-NEXT: return + func.return +} + + diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc index 8271a5c796e5c4..0fb986e567b2f4 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -343,6 +343,26 @@ class IfrtLoadVariableOpConversion } }; +// Convert tf.IfrtRestoreVariableOp to tf_mlrt.IfrtRestoreVariableOp +class IfrtRestoreVariableOpConversion + : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + mlir::TF::IfrtRestoreVariableOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto new_op = rewriter.create( + op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], + adaptor.getOperands()[2], + adaptor.getOperands().slice(3, adaptor.getOperands().size() - 3), + op.getRestoredDtypes()); + rewriter.replaceOp(op, new_op); + + return mlir::success(); + } +}; + std::optional DecodeLongName(mlir::Location loc) { if (auto name_loc = loc.dyn_cast()) { return name_loc.getName().str(); @@ -1189,7 +1209,8 @@ class TfToMlrtConversionPass patterns.add(&context, &type_converter_, &symbol_table); patterns.add(&context); + IfrtRestoreVariableOpConversion, TFAwaitOpConversion, + TFPromiseOpConversion>(&context); patterns.add(type_converter_, &context); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc index d9e1b7f73ac0c8..a1f9d401f5c485 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc @@ -35,9 +35,9 @@ bool UseFallback(mlir::Operation *op) { return !llvm::isa< mlir::TF::_TfrtSetResourceOp, 
mlir::TF::_TfrtGetResourceOp, mlir::TF::BatchFunctionOp, mlir::TF::CaseOp, mlir::TF::IfrtLoadVariableOp, - mlir::TF::StatefulPartitionedCallOp, mlir::TF::PartitionedCallOp, - mlir::TF::LegacyCallOp, mlir::TF::IfOp, mlir::TF::WhileOp, - mlir::TF::TPUCompileMlirAndExecuteOp>(op); + mlir::TF::IfrtRestoreVariableOp, mlir::TF::StatefulPartitionedCallOp, + mlir::TF::PartitionedCallOp, mlir::TF::LegacyCallOp, mlir::TF::IfOp, + mlir::TF::WhileOp, mlir::TF::TPUCompileMlirAndExecuteOp>(op); } } // namespace mlrt_compiler From a0911b4c89e9511d0089a37b0d32fb4c8b4795e6 Mon Sep 17 00:00:00 2001 From: Doyoung Gwak Date: Mon, 18 Mar 2024 23:20:46 -0700 Subject: [PATCH 069/670] Migrate DebuggerOptions to DebuggerConfig PiperOrigin-RevId: 617066168 --- RELEASE.md | 2 ++ .../mlir/quantization/stablehlo/cc/debugger.cc | 8 ++++---- .../mlir/quantization/stablehlo/cc/debugger.h | 2 +- .../integration_test/quantize_model_test.py | 5 ++--- .../tensorflow/python/pywrap_quantize_model.cc | 10 +++++----- .../tensorflow/python/quantize_model.cc | 15 +-------------- .../tensorflow/python/quantize_model.py | 12 ++++++------ .../tensorflow/quantization_options.proto | 17 +---------------- .../quantization/tensorflow/quantize_passes.cc | 6 +++--- ...ion.experimental.-quantization-options.pbtxt | 4 ++-- ...ion.experimental.-quantization-options.pbtxt | 4 ++-- 11 files changed, 29 insertions(+), 56 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 6ff074b10e465d..cd4e3a2cc3bdb8 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -29,6 +29,8 @@ * GPU * Support for NVIDIA GPUs with compute capability 8.9 (e.g. L4 & L40) has been added to TF binary distributions (Python wheels). +* Replace `DebuggerOptions` of TensorFlow Quantizer, and migrate to + `DebuggerConfig` of StableHLO Quantizer. ## Keras diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc index 1ba51790de0ac9..134ce2a5a89ebd 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.cc @@ -30,16 +30,16 @@ limitations under the License. namespace stablehlo::quantization { namespace { +using ::stablehlo::quantization::DebuggerConfig; using ::tensorflow::NodeDef; using ::tensorflow::SignatureDef; -using ::tensorflow::quantization::DebuggerOptions; using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PyFunctionLibrary; } // namespace void EnableDebugging( - ExportedModel& exported_model, const DebuggerOptions& debugger_options, + ExportedModel& exported_model, const DebuggerConfig& debugger_config, const PyFunctionLibrary& py_function_library, const absl::string_view src_saved_model_path, const std::unordered_set& tags, @@ -52,13 +52,13 @@ void EnableDebugging( } }); - if (debugger_options.debugger_type() == + if (debugger_config.debugger_type() == DebuggerConfig::DEBUGGER_TYPE_WHOLE_MODEL) { // TODO: b/295139417 - Remove CustomAggregator op in unquantized dump model. // TODO: b/296916287 - Create a separate function for saving unquantized // dump model. py_function_library.SaveExportedModel( - debugger_options.unquantized_dump_model_path(), exported_model, + debugger_config.unquantized_dump_model_path(), exported_model, src_saved_model_path, tags, signature_def_map); // Update the `DumpTensor` ops' file name in `graph_def`. 
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h index 6bb427ecbdf1fd..4cb1523a7594ee 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h @@ -38,7 +38,7 @@ namespace stablehlo::quantization { // and compare them offline. void EnableDebugging( tensorflow::quantization::ExportedModel& exported_model, - const tensorflow::quantization::DebuggerOptions& debugger_options, + const stablehlo::quantization::DebuggerConfig& debugger_config, const tensorflow::quantization::PyFunctionLibrary& py_function_library, absl::string_view src_saved_model_path, const std::unordered_set& tags, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py index a28d7ebe4bf7f3..18e5a14de44110 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py @@ -83,7 +83,6 @@ 'UniformQuantizedDotHybrid', ) -_DebuggerOptions = quant_opts_pb2.DebuggerOptions _DebuggerConfig = stablehlo_quant_config_pb2.DebuggerConfig # Lists of ops whose channel dimension should be changed if per_channel @@ -5926,7 +5925,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: preset_method=_PresetMethod.METHOD_STATIC_RANGE_INT8 ), op_set=quant_opts_pb2.XLA, - debugger_options=_DebuggerOptions( + debugger_config=_DebuggerConfig( debugger_type=_DebuggerConfig.DebuggerType.DEBUGGER_TYPE_WHOLE_MODEL, unquantized_dump_model_path=unquantized_dump_model_path, log_dir_path=log_dir_path, @@ -6039,7 +6038,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: preset_method=_PresetMethod.METHOD_STATIC_RANGE_INT8 ), op_set=target_opset, - debugger_options=_DebuggerOptions( + debugger_config=_DebuggerConfig( debugger_type=debugger_type, log_dir_path=log_dir_path, ), diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index 8273279df67787..d61cb59905d66f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -89,7 +89,7 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { // Remove the `tpu` tag from the debug quantized saved model as it is // for CPU. Note the 'tpu' value should be the same as `TPU` defined in // tensorflow/python/saved_model/tag_constants.py. - if (quantization_options.has_debugger_options()) { + if (quantization_options.has_debugger_config()) { tags.erase("tpu"); } py_function_library.SaveExportedModel( @@ -138,7 +138,7 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { // Remove the `tpu` tag from the debug quantized saved model as it is // for CPU. Note the 'tpu' value should be the same as `TPU` defined in // tensorflow/python/saved_model/tag_constants.py. 
- if (quantization_options.has_debugger_options()) { + if (quantization_options.has_debugger_config()) { tags.erase("tpu"); } py_function_library.SaveExportedModel( @@ -255,9 +255,9 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { << status; } - if (quantization_options.has_debugger_options()) { + if (quantization_options.has_debugger_config()) { EnableDebugging(*exported_model, - quantization_options.debugger_options(), + quantization_options.debugger_config(), py_function_library, src_saved_model_path, tags, signature_def_map); } @@ -283,7 +283,7 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { // Remove the `tpu` tag from the debug quantized saved model as it is // for CPU. Note the 'tpu' value should be the same as `TPU` defined in // tensorflow/python/saved_model/tag_constants.py. - if (quantization_options.has_debugger_options()) { + if (quantization_options.has_debugger_config()) { tags.erase("tpu"); } py_function_library.SaveExportedModel( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index 08b71190bbb5b5..10bedefb55161d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -79,18 +79,6 @@ using ::stablehlo::quantization::DebuggerConfig; using ::stablehlo::quantization::QuantizationConfig; using ::stablehlo::quantization::io::GetLocalTmpFileName; -// TODO: b/326355110 - Removes `ConvertDebuggerOptionToDebuggerConfig` when -// merging `DebuggingOption` to `DebuggingConfig`. -DebuggerConfig ConvertDebuggerOptionToDebuggerConfig( - const DebuggerOptions &debugger_options) { - DebuggerConfig debugger_config; - debugger_config.set_debugger_type(debugger_options.debugger_type()); - debugger_config.set_unquantized_dump_model_path( - debugger_options.unquantized_dump_model_path()); - debugger_config.set_log_dir_path(debugger_options.log_dir_path()); - return debugger_config; -} - absl::StatusOr> ImportAndPreprocessSavedModel( absl::string_view saved_model_path, const std::vector &signature_keys, @@ -268,8 +256,7 @@ absl::StatusOr QuantizePtqModelPreCalibration( if (is_stablehlo) { QuantizationConfig quantization_config; *quantization_config.mutable_debugger_config() = - ConvertDebuggerOptionToDebuggerConfig( - quantization_options.debugger_options()); + quantization_options.debugger_config(); PreCalibrationComponent pre_calibration_component(context.get()); TF_ASSIGN_OR_RETURN(*module_ref, pre_calibration_component.Run( *module_ref, quantization_config)); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index 1bf3fe81c7d8ba..961db5334e3bbe 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -692,7 +692,7 @@ def _populate_quantization_options_default_values( ' quantization via TF Quantizer.' ) - if quantization_options.HasField('debugger_options'): + if quantization_options.HasField('debugger_config'): # Set `force_graph_mode_calibration` to True to avoid skipping op execution, # which are not connected to return ops, during calibration execution. 
# Setting `force_graph_mode_calibration` to True enables execution of the @@ -704,11 +704,11 @@ def _populate_quantization_options_default_values( ) quantization_options.force_graph_mode_calibration = True - if not quantization_options.debugger_options.log_dir_path: - quantization_options.debugger_options.log_dir_path = '/tmp/dumps' + if not quantization_options.debugger_config.log_dir_path: + quantization_options.debugger_config.log_dir_path = '/tmp/dumps' if ( - quantization_options.debugger_options.debugger_type + quantization_options.debugger_config.debugger_type == stablehlo_quant_config_pb2.DebuggerConfig.DebuggerType.DEBUGGER_TYPE_UNSPECIFIED ): raise ValueError( @@ -716,9 +716,9 @@ def _populate_quantization_options_default_values( ) if ( - quantization_options.debugger_options.debugger_type + quantization_options.debugger_config.debugger_type == stablehlo_quant_config_pb2.DebuggerConfig.DebuggerType.DEBUGGER_TYPE_WHOLE_MODEL - and not quantization_options.debugger_options.unquantized_dump_model_path + and not quantization_options.debugger_config.unquantized_dump_model_path ): raise ValueError( 'Debugger type whole model verify was used but' diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto index 13d3876500fe0d..d2c79b6ce4c668 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto @@ -145,21 +145,6 @@ message RepresentativeDatasetFile { } } -// Configuration for quantization debugger. -// NEXT ID: 4 -message DebuggerOptions { - // Type of quantization debugger. Depending on the type, inputs and outputs - // are wired differently. - stablehlo.quantization.DebuggerConfig.DebuggerType debugger_type = 1; - - // Path to save unquantized model with dump tensor ops attached. - // Used when debugger_type is WHOLE_MODEL. - string unquantized_dump_model_path = 2; - - // Path to save debugger related logs. Defaults to '/tmp/dumps'. - string log_dir_path = 3; -} - // Defines various options to specify and control the behavior of the quantizer. // It consists of // 1) Model-wise quantization configuration as a default configuration. If it is @@ -251,7 +236,7 @@ message QuantizationOptions { stablehlo.quantization.CalibrationOptions calibration_options = 15; // Configuration related to quantization debugger. 
- DebuggerOptions debugger_options = 16; + stablehlo.quantization.DebuggerConfig debugger_config = 16; reserved 3; } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc index 0d5e43cd6f334e..0e756021844a5c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc @@ -149,10 +149,10 @@ void AddQuantizePtqPreCalibrationPasses( pm.addPass(mlir::quant::CreateLiftQuantizableSpotsAsFunctionsPass( quantization_options)); // TODO: b/295140328 - Add debugger support for weight only - if (quantization_options.has_debugger_options()) { + if (quantization_options.has_debugger_config()) { pm.addPass(mlir::quant::CreateAddDumpTensorOpPass( - quantization_options.debugger_options().debugger_type(), - quantization_options.debugger_options().log_dir_path())); + quantization_options.debugger_config().debugger_type(), + quantization_options.debugger_config().log_dir_path())); } pm.addNestedPass( mlir::quant::CreateInsertCustomAggregationOpsPass( diff --git a/tensorflow/tools/api/golden/v1/tensorflow.quantization.experimental.-quantization-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.quantization.experimental.-quantization-options.pbtxt index 7edb5900b4b5f4..e00e4c66e47900 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.quantization.experimental.-quantization-options.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.quantization.experimental.-quantization-options.pbtxt @@ -96,11 +96,11 @@ tf_proto { type_name: ".stablehlo.quantization.CalibrationOptions" } field { - name: "debugger_options" + name: "debugger_config" number: 16 label: LABEL_OPTIONAL type: TYPE_MESSAGE - type_name: ".tensorflow.quantization.DebuggerOptions" + type_name: ".stablehlo.quantization.DebuggerConfig" } nested_type { name: "RepresentativeDatasetsEntry" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.quantization.experimental.-quantization-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.quantization.experimental.-quantization-options.pbtxt index 7edb5900b4b5f4..e00e4c66e47900 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.quantization.experimental.-quantization-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.quantization.experimental.-quantization-options.pbtxt @@ -96,11 +96,11 @@ tf_proto { type_name: ".stablehlo.quantization.CalibrationOptions" } field { - name: "debugger_options" + name: "debugger_config" number: 16 label: LABEL_OPTIONAL type: TYPE_MESSAGE - type_name: ".tensorflow.quantization.DebuggerOptions" + type_name: ".stablehlo.quantization.DebuggerConfig" } nested_type { name: "RepresentativeDatasetsEntry" From a88a81a7db64070b0f66e77f88b1e54c9981106d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 18 Mar 2024 23:23:17 -0700 Subject: [PATCH 070/670] [XLA:Python] Improve error checking for the return value of the to_iterable function of custom pytree nodes. 
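For illustration, a minimal sketch of the failure mode this targets, assuming JAX's `tree_util` as the registration front end (the `Box` class and helper names are hypothetical):

```
import jax

class Box:
  def __init__(self, value):
    self.value = value

# Buggy flatten hook (the `to_iterable` function on the C++ side): it returns
# only the children instead of a (children, aux_data) pair.
def _bad_flatten(box):
  return [box.value]  # Should be ([box.value], None).

def _unflatten(aux_data, children):
  return Box(children[0])

jax.tree_util.register_pytree_node(Box, _bad_flatten, _unflatten)

# Flattening now fails with a descriptive message along the lines of
#   "The to_iterable function for a custom PyTree node should return a
#    (children, aux_data) tuple, got [1.0]"
# rather than an opaque cast failure.
jax.tree_util.tree_leaves(Box(1.0))
```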
PiperOrigin-RevId: 617066587 --- third_party/xla/xla/python/BUILD | 1 + third_party/xla/xla/python/pytree.cc | 48 ++++++++++++++++-------- third_party/xla/xla/python/pytree.h | 8 +++- third_party/xla/xla/python/xla_client.py | 2 +- 4 files changed, 42 insertions(+), 17 deletions(-) diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 3668ab3ba2ff5a..1c58803c966be2 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -863,6 +863,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//third_party/nanobind", + "@local_config_python//:python_headers", # buildcleaner: keep "//xla/pjrt:exceptions", "@local_tsl//tsl/platform:logging", ], diff --git a/third_party/xla/xla/python/pytree.cc b/third_party/xla/xla/python/pytree.cc index edd43c8dd74a31..0c8dcf5fe02e49 100644 --- a/third_party/xla/xla/python/pytree.cc +++ b/third_party/xla/xla/python/pytree.cc @@ -18,7 +18,11 @@ limitations under the License. #include "xla/python/pytree.h" +#include + #include +#include +#include #include #include #include @@ -93,6 +97,28 @@ void PyTreeRegistry::Register(nb::object type, nb::callable to_iterable, } } +std::pair +PyTreeRegistry::Registration::ToIterable(nanobind::handle o) const { + nb::object out = to_iterable(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable leaves; + if (!nb::try_cast(leaves_and_aux_data[0], leaves)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple where 'children' is iterable, " + "got ", + nb::cast(nb::repr(out)))); + } + return std::make_pair(std::move(leaves), nb::object(leaves_and_aux_data[1])); +} + // Computes the node kind of a given Python object. 
PyTreeKind PyTreeRegistry::KindOfObject( nb::handle obj, PyTreeRegistry::Registration const** custom) const { @@ -257,14 +283,10 @@ void PyTreeDef::FlattenImpl(nb::handle handle, T& leaves, break; } case PyTreeKind::kCustom: { - nb::tuple out = nb::cast(node.custom->to_iterable(handle)); - if (out.size() != 2) { - throw xla::XlaRuntimeError( - "PyTree custom to_iterable function should return a pair"); - } - node.node_data = out[1]; + auto [leaves, aux_data] = node.custom->ToIterable(handle); + node.node_data = std::move(aux_data); node.arity = 0; - for (nb::handle entry : nb::cast(out[0])) { + for (nb::handle entry : leaves) { ++node.arity; recurse(entry); } @@ -558,20 +580,16 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { nb::cast(nb::repr(node.custom->type)), nb::cast(nb::repr(object)))); } - nb::tuple out = nb::cast(node.custom->to_iterable(object)); - if (out.size() != 2) { - throw xla::XlaRuntimeError( - "PyTree custom to_iterable function should return a pair"); - } - if (node.node_data.not_equal(out[1])) { + auto [leaves, aux_data] = node.custom->ToIterable(object); + if (node.node_data.not_equal(aux_data)) { throw std::invalid_argument(absl::StrFormat( "Mismatch custom node data: %s != %s; value: %s.", nb::cast(nb::repr(node.node_data)), - nb::cast(nb::repr(out[1])), + nb::cast(nb::repr(aux_data)), nb::cast(nb::repr(object)))); } int arity = 0; - for (nb::handle entry : nb::cast(out[0])) { + for (nb::handle entry : leaves) { ++arity; agenda.push_back(nb::borrow(entry)); } diff --git a/third_party/xla/xla/python/pytree.h b/third_party/xla/xla/python/pytree.h index 266af78b56c552..9a453ad0f17f8f 100644 --- a/third_party/xla/xla/python/pytree.h +++ b/third_party/xla/xla/python/pytree.h @@ -19,9 +19,9 @@ limitations under the License. // See https://jax.readthedocs.io/en/latest/pytrees.html for the documentation // about pytree. +#include #include #include -#include #include #include #include @@ -30,6 +30,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" +#include "absl/types/span.h" #include "third_party/nanobind/include/nanobind/nanobind.h" #include "xla/python/nb_class_ptr.h" #include "xla/python/pytree.pb.h" @@ -67,6 +68,11 @@ class PyTreeRegistry : public std::enable_shared_from_this { nanobind::callable to_iterable; // A function with signature: (aux_data, iterable) -> object nanobind::callable from_iterable; + + // Helper that calls to_iterable and validates that it returns a pair + // of an iterable and an aux_data object + std::pair ToIterable( + nanobind::handle o) const; }; // Registers a new custom type. Objects of `type` will be treated as container diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index d8b24aba09dcb5..ca419694f95d3a 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -48,7 +48,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 246 +_version = 247 # Version number for MLIR:Python components. mlir_api_version = 55 From 752146579214f1ecb13fc9bcea8d221e7da8067f Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Tue, 19 Mar 2024 00:09:52 -0700 Subject: [PATCH 071/670] #tf-data Support global shuffle for the skip dataset. 
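Sketch of what this enables, mirroring the new tests below (`_global_shuffle` is the experimental, internal entry point the tests use; it is not a public API):

```
from tensorflow.python.data.experimental.ops import global_shuffle_op
from tensorflow.python.data.ops import dataset_ops

dataset = dataset_ops.Dataset.range(10).skip(2)
dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
dataset = global_shuffle_op._global_shuffle(dataset, seed=42)
# Yields a permutation of [2, 10): the skipped prefix is excluded before the
# global-shuffle index mapping is applied. A negative skip count (i.e.
# skipping the entire dataset) is rejected with FailedPreconditionError.
```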
PiperOrigin-RevId: 617074621 --- tensorflow/core/kernels/data/BUILD | 6 +- .../core/kernels/data/skip_dataset_op.cc | 61 ++++++++++++- tensorflow/python/data/kernel_tests/BUILD | 1 + .../python/data/kernel_tests/skip_test.py | 88 +++++++++++++++++++ 4 files changed, 151 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 9509792dcd2450..3210941c5d2d7d 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -1228,9 +1228,11 @@ tf_kernel_library( deps = [ "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", + "//tensorflow/core/data:global_shuffle_utils", "//tensorflow/core/data:name_utils", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", ], ) diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc index 2a0c75f4c54b93..c5ccea131b96c2 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op.cc @@ -14,9 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/skip_dataset_op.h" +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/data/global_shuffle_utils.h" #include "tensorflow/core/data/name_utils.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace data { @@ -40,6 +48,14 @@ class SkipDatasetOp::Dataset : public DatasetBase { Dataset(OpKernelContext* ctx, int64_t count, const DatasetBase* input) : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) { input_->Ref(); + if (input_ != nullptr && count >= 0) { + random_indexing_compatible_ = input_->RandomIndexingCompatible(); + } else { + random_indexing_compatible_ = absl::FailedPreconditionError( + absl::StrCat("Global shuffling does not support empty dataset or " + "skipping the entire dataset. Got skip(", + count, ").")); + } } ~Dataset() override { input_->Unref(); } @@ -90,6 +106,10 @@ class SkipDatasetOp::Dataset : public DatasetBase { return input_->Get(ctx, index + count_, out_tensors); } + absl::Status RandomIndexingCompatible() const override { + return random_indexing_compatible_; + } + protected: Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, @@ -156,10 +176,13 @@ class SkipDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } + IteratorContextWithIndexMapper ctx_with_index_mapper(ctx, this); if (i_ < dataset()->count_) { int num_skipped; - TF_RETURN_IF_ERROR(input_impl_->Skip(ctx, dataset()->count_ - i_, + TF_RETURN_IF_ERROR(input_impl_->Skip(ctx_with_index_mapper.Get(), + dataset()->count_ - i_, end_of_sequence, &num_skipped)); + ctx_with_index_mapper.MergeCheckpoint(); i_ += num_skipped; if (*end_of_sequence) { // We reached the end before the count was reached. @@ -169,14 +192,29 @@ class SkipDatasetOp::Dataset : public DatasetBase { } // Return GetNext() on the underlying iterator. 
- TF_RETURN_IF_ERROR( - input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx_with_index_mapper.Get(), + out_tensors, end_of_sequence)); + ctx_with_index_mapper.MergeCheckpoint(); if (*end_of_sequence) { input_impl_.reset(); } return absl::OkStatus(); } + IndexMapperFn GetIndexMapper( + IndexMapperFn parent_index_mapper) const override { + int64_t skip_count = dataset()->count_; + return [parent_index_mapper, + skip_count](size_t element_position) -> size_t { + if (element_position < skip_count) { + // The first `skip_count` elements are to be skipped. + return parent_index_mapper(element_position); + } + // Maps the range [skip_count, cardinality) to a permuted range. + return parent_index_mapper(element_position - skip_count) + skip_count; + }; + } + protected: std::shared_ptr CreateNode( IteratorContext* ctx, model::Node::Args args) const override { @@ -198,6 +236,22 @@ class SkipDatasetOp::Dataset : public DatasetBase { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { + if (ctx->restored_element_count().has_value()) { + mutex_lock l(mu_); + if (*ctx->restored_element_count() > 0) { + i_ = dataset()->count_; + // For upstream iterators, the restored count is the returned element + // count + skipped element count. + IteratorContext::Params params(ctx); + params.restored_element_count = + *ctx->restored_element_count() + dataset()->count_; + IteratorContext ctx_with_restored_count(params); + return RestoreInput(&ctx_with_restored_count, reader, input_impl_); + } + i_ = 0; + return RestoreInput(ctx, reader, input_impl_); + } + mutex_lock l(mu_); TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kCurIndex, &i_)); int64_t input_empty; @@ -219,6 +273,7 @@ class SkipDatasetOp::Dataset : public DatasetBase { const int64_t count_; const DatasetBase* const input_; + absl::Status random_indexing_compatible_; }; SkipDatasetOp::SkipDatasetOp(OpKernelConstruction* ctx) diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 3b63533996613b..bdb05c5950c821 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -1179,6 +1179,7 @@ tf_py_strict_test( deps = [ ":checkpoint_test_base", ":test_base", + "//tensorflow/python/data/experimental/ops:global_shuffle_op", "//tensorflow/python/data/experimental/ops:random_access", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:options", diff --git a/tensorflow/python/data/kernel_tests/skip_test.py b/tensorflow/python/data/kernel_tests/skip_test.py index d117ced2b12222..bba8d1e30ca68d 100644 --- a/tensorflow/python/data/kernel_tests/skip_test.py +++ b/tensorflow/python/data/kernel_tests/skip_test.py @@ -13,9 +13,13 @@ # limitations under the License. 
# ============================================================================== """Tests for `tf.data.Dataset.skip()`.""" + +from typing import Callable, Optional + from absl.testing import parameterized import numpy as np +from tensorflow.python.data.experimental.ops import global_shuffle_op from tensorflow.python.data.experimental.ops import random_access from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base @@ -124,5 +128,89 @@ def testMultipleCombinations(self, elements, skip): self.evaluate(random_access.at(dataset, index=i)), i + skip) +class SkipGlobalShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine( + dataset_range=[10], + count=[0, 2], + repetitions=[1, 2], + seed=[None, 42], + reshuffle_each_iteration=[True, False]))) + def testSkip( + self, + dataset_range: int, + count: int, + repetitions: int, + seed: Optional[int], + reshuffle_each_iteration: bool): + dataset = dataset_ops.Dataset.range(dataset_range) + dataset = dataset.skip(count) + dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE) + if repetitions > 1: + dataset = dataset.repeat(repetitions) + dataset = global_shuffle_op._global_shuffle( + dataset, seed=seed, reshuffle_each_iteration=reshuffle_each_iteration) + + expected = list(range(count, dataset_range)) * repetitions + dataset_output = self.getDatasetOutput( + dataset, requires_initialization=True) + self.assertCountEqual(dataset_output, expected) + self.assertNotEqual(dataset_output, expected) + self.assertLen(dataset_output, self.evaluate(dataset.cardinality())) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(skip=[-2, -1]))) + def testNegativeSkip(self, skip: int): + dataset = dataset_ops.Dataset.range(10).skip(skip) + with self.assertRaises(errors.FailedPreconditionError): + dataset = global_shuffle_op._global_shuffle(dataset) + self.getDatasetOutput(dataset, requires_initialization=True) + + +class SkipGlobalShuffleCheckpointTest( + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + checkpoint_test_base.default_test_combinations(), + combinations.combine( + dataset_range=[10], + count=[0, 2], + repetitions=[1, 2], + reshuffle_each_iteration=[True, False], + symbolic_checkpoint=[True, False]))) + def testSkip( + self, + verify_fn: Callable[..., None], + dataset_range: int, + count: int, + repetitions: int, + reshuffle_each_iteration: bool, + symbolic_checkpoint: bool): + def _build_dataset() -> dataset_ops.Dataset: + dataset = dataset_ops.Dataset.range(dataset_range) + dataset = dataset.skip(count) + dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE) + if repetitions > 1: + dataset = dataset.repeat(repetitions) + dataset = global_shuffle_op._global_shuffle( + dataset, seed=42, reshuffle_each_iteration=reshuffle_each_iteration) + options = options_lib.Options() + options.experimental_symbolic_checkpoint = symbolic_checkpoint + return dataset.with_options(options) + + verify_fn( + self, + _build_dataset, + num_outputs=(dataset_range - count) * repetitions, + assert_items_equal=reshuffle_each_iteration, + ) + + if __name__ == "__main__": test.main() From 65d46501f75ecb6b564c5c156d889dd0a48269e3 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Tue, 19 Mar 2024 00:28:35 -0700 
Subject: [PATCH 072/670] [xla:gpu] No need to distinguish operand vs. result slices for AddressComputationThunk Distinguishing between operand and result is only required when creating the embedded thunk, during `ExecuteOnStream` all we need is the list of buffers. PiperOrigin-RevId: 617077937 --- .../gpu/runtime/address_computation_thunk.cc | 198 ++++-------------- .../gpu/runtime/address_computation_thunk.h | 32 +-- .../runtime/address_computation_thunk_test.cc | 61 +++--- 3 files changed, 86 insertions(+), 205 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc index 28cf9163774ca5..3872683e70a75d 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc @@ -44,63 +44,38 @@ namespace gpu { AddressComputationThunk::AddressComputationThunk( ThunkInfo thunk_info, std::unique_ptr embedded_thunk, - std::vector> operands, - std::vector> results, + std::vector> arguments, std::vector>> - operand_offset_buffer_indices, - std::vector> operand_orig_shapes, - std::vector> operand_sliced_shapes, - std::vector>> - result_offset_buffer_indices, - std::vector> result_orig_shapes, - std::vector> result_sliced_shapes) + offset_buffer_indices, + std::vector> orig_shapes, + std::vector> sliced_shapes) : Thunk(Kind::kAddressComputation, thunk_info), embedded_thunk_(std::make_unique( ThunkInfo(thunk_info.op), std::move(*embedded_thunk))), - embedded_thunk_operands_(std::move(operands)), - embedded_thunk_results_(std::move(results)), - operand_offset_buffer_indices_(std::move(operand_offset_buffer_indices)), - operand_orig_shapes_(std::move(operand_orig_shapes)), - operand_sliced_shapes_(std::move(operand_sliced_shapes)), - result_offset_buffer_indices_(std::move(result_offset_buffer_indices)), - result_orig_shapes_(std::move(result_orig_shapes)), - result_sliced_shapes_(std::move(result_sliced_shapes)) {} + embedded_thunk_arguments_(std::move(arguments)), + offset_buffer_indices_(std::move(offset_buffer_indices)), + orig_shapes_(std::move(orig_shapes)), + sliced_shapes_(std::move(sliced_shapes)) {} absl::Status AddressComputationThunk::Prepare( const PrepareParams& params, ResourceRequests& resource_requests) { - auto num_operands = embedded_thunk_operands_.size(); - TF_RET_CHECK(num_operands == operand_offset_buffer_indices_.size()); - TF_RET_CHECK(num_operands == operand_orig_shapes_.size()); - TF_RET_CHECK(num_operands == operand_sliced_shapes_.size()); - for (unsigned i = 0; i < num_operands; ++i) { - if (operand_sliced_shapes_[i].has_value()) { - TF_RET_CHECK(embedded_thunk_operands_[i].has_value()); - TF_RET_CHECK(operand_offset_buffer_indices_[i].has_value()); - TF_RET_CHECK(operand_sliced_shapes_[i]->IsArray()); - TF_RET_CHECK(operand_orig_shapes_[i].has_value() && - operand_orig_shapes_[i]->IsArray()); - TF_RET_CHECK(operand_sliced_shapes_[i]->rank() == - operand_orig_shapes_[i]->rank()); - TF_RET_CHECK(operand_offset_buffer_indices_[i]->size() == - operand_orig_shapes_[i]->rank()); - } - } - - auto num_results = embedded_thunk_results_.size(); - TF_RET_CHECK(num_results == result_offset_buffer_indices_.size()); - TF_RET_CHECK(num_results == result_orig_shapes_.size()); - TF_RET_CHECK(num_results == result_sliced_shapes_.size()); - for (unsigned i = 0; i < num_results; ++i) { - if (result_sliced_shapes_[i].has_value()) { - TF_RET_CHECK(embedded_thunk_results_[i].has_value()); - 
TF_RET_CHECK(result_offset_buffer_indices_[i].has_value()); - TF_RET_CHECK(result_sliced_shapes_[i]->IsArray()); - TF_RET_CHECK(result_orig_shapes_[i].has_value() && - result_orig_shapes_[i]->IsArray()); - TF_RET_CHECK(result_sliced_shapes_[i]->rank() == - result_orig_shapes_[i]->rank()); - TF_RET_CHECK(result_offset_buffer_indices_[i]->size() == - result_orig_shapes_[i]->rank()); + auto num_arguments = embedded_thunk_arguments_.size(); + TF_RET_CHECK(num_arguments == offset_buffer_indices_.size()); + TF_RET_CHECK(num_arguments == orig_shapes_.size()); + TF_RET_CHECK(num_arguments == sliced_shapes_.size()); + for (auto [argument, offset_slice, orig_shape, sliced_shape] : + llvm::zip(embedded_thunk_arguments_, offset_buffer_indices_, + orig_shapes_, sliced_shapes_)) { + if (offset_slice.has_value()) { + TF_RET_CHECK(argument.has_value()); + TF_RET_CHECK(orig_shape.has_value()); + TF_RET_CHECK(sliced_shape.has_value()); + + TF_RET_CHECK(orig_shape->IsArray()); + TF_RET_CHECK(sliced_shape->IsArray()); + + TF_RET_CHECK(offset_slice->size() == orig_shape->rank()); + TF_RET_CHECK(sliced_shape->rank() == orig_shape->rank()); } } @@ -112,38 +87,17 @@ absl::Status AddressComputationThunk::Initialize( const InitializeParams& params) { TF_RETURN_IF_ERROR(embedded_thunk_->Initialize(params)); - unsigned operand_offset_count = 0; - for (auto maybe_shape : operand_sliced_shapes_) { - operand_offset_count += - (maybe_shape == std::nullopt) ? 1 : maybe_shape->rank(); - } - - { - absl::MutexLock lock(&mutex_); - if (auto it = operand_offsets_.find(params.executor); - it == operand_offsets_.end()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr allocation, - params.executor->HostMemoryAllocate( - operand_offset_count * sizeof(int64_t))); - operand_offsets_.emplace(params.executor, std::move(allocation)); - } - } - - unsigned result_offset_count = 0; - for (auto maybe_shape : result_sliced_shapes_) { - result_offset_count += - (maybe_shape == std::nullopt) ? 1 : maybe_shape->rank(); + unsigned offset_count = 0; + for (auto maybe_shape : sliced_shapes_) { + offset_count += (maybe_shape == std::nullopt) ? 1 : maybe_shape->rank(); } - { - absl::MutexLock lock(&mutex_); - if (auto it = result_offsets_.find(params.executor); - it == result_offsets_.end()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr allocation, - params.executor->HostMemoryAllocate( - result_offset_count * sizeof(int64_t))); - result_offsets_.emplace(params.executor, std::move(allocation)); - } + absl::MutexLock lock(&mutex_); + if (auto it = offsets_.find(params.executor); it == offsets_.end()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr allocation, + params.executor->HostMemoryAllocate(offset_count * sizeof(int64_t))); + offsets_.emplace(params.executor, std::move(allocation)); } return absl::OkStatus(); @@ -155,28 +109,27 @@ absl::Status AddressComputationThunk::ExecuteOnStream( std::vector new_buffers; const BufferAllocations& orig_allocations = *params.buffer_allocations; - // Get memory allocation for copying operand offsets from device. - int64_t* operand_offsets_base = [&] { + // Get memory allocation for copying offsets from device. 
+ int64_t* offsets_base = [&] { absl::MutexLock lock(&mutex_); - return reinterpret_cast( - operand_offsets_.at(stream.parent())->opaque()); + return reinterpret_cast(offsets_.at(stream.parent())->opaque()); }(); - for (unsigned i = 0; i < operand_offset_buffer_indices_.size(); ++i) { - if (embedded_thunk_operands_[i] == std::nullopt) { + for (unsigned i = 0; i < offset_buffer_indices_.size(); ++i) { + if (embedded_thunk_arguments_[i] == std::nullopt) { new_buffers.push_back(se::DeviceMemoryBase()); continue; } se::DeviceMemoryBase orig_operand = - orig_allocations.GetDeviceAddress(*embedded_thunk_operands_[i]); - if (operand_offset_buffer_indices_[i] == std::nullopt) { + orig_allocations.GetDeviceAddress(*embedded_thunk_arguments_[i]); + if (offset_buffer_indices_[i] == std::nullopt) { new_buffers.push_back(orig_operand); continue; } - const Shape& src_shape = *operand_orig_shapes_[i]; - const Shape& dst_shape = *operand_sliced_shapes_[i]; + const Shape& src_shape = *orig_shapes_[i]; + const Shape& dst_shape = *sliced_shapes_[i]; TF_RET_CHECK(IsContiguousSlice(src_shape, dst_shape)); std::vector slice_starts; @@ -184,10 +137,10 @@ absl::Status AddressComputationThunk::ExecuteOnStream( // Get offset for ith operand, which has `dst_shape.rank()` components. for (auto [idx, offset_slice] : - llvm::enumerate(*operand_offset_buffer_indices_[i])) { + llvm::enumerate(*offset_buffer_indices_[i])) { se::DeviceMemoryBase offset_src = orig_allocations.GetDeviceAddress(offset_slice); - int64_t* offset_dst = &operand_offsets_base[i + idx]; + int64_t* offset_dst = &offsets_base[i + idx]; // Copy the idx-th component of the ith offset from device to host. TF_RETURN_IF_ERROR( stream.Memcpy(offset_dst, offset_src, sizeof(int64_t))); @@ -203,7 +156,7 @@ absl::Status AddressComputationThunk::ExecuteOnStream( // Compute new slice. No need to copy the content to new buffers as we can // reuse the original buffers since slices are contiguous. int64_t new_size = ShapeUtil::ByteSizeOf(dst_shape); - BufferAllocation::Slice orig_slice = *embedded_thunk_operands_[i]; + BufferAllocation::Slice orig_slice = *embedded_thunk_arguments_[i]; int64_t new_offset = orig_slice.offset(); for (auto [start, stride] : @@ -214,65 +167,6 @@ absl::Status AddressComputationThunk::ExecuteOnStream( new_buffers.push_back(orig_operand.GetByteSlice(new_offset, new_size)); } - // Get memory allocation for copying result offsets from device. - int64_t* result_offsets_base = [&] { - absl::MutexLock lock(&mutex_); - return reinterpret_cast( - result_offsets_.at(stream.parent())->opaque()); - }(); - - for (unsigned i = 0; i < result_offset_buffer_indices_.size(); ++i) { - if (embedded_thunk_results_[i] == std::nullopt) { - new_buffers.push_back(se::DeviceMemoryBase()); - continue; - } - - se::DeviceMemoryBase orig_result = - orig_allocations.GetDeviceAddress(*embedded_thunk_results_[i]); - if (result_offset_buffer_indices_[i] == std::nullopt) { - new_buffers.push_back(orig_result); - continue; - } - - const Shape& src_shape = *result_orig_shapes_[i]; - const Shape& dst_shape = *result_sliced_shapes_[i]; - TF_RET_CHECK(IsContiguousSlice(src_shape, dst_shape)); - - std::vector slice_starts; - slice_starts.reserve(dst_shape.rank()); - - // Get offset for ith result, which has `dst_shape.rank()` components. 
- for (auto [idx, offset_slice] : - llvm::enumerate(*result_offset_buffer_indices_[i])) { - se::DeviceMemoryBase offset_src = - orig_allocations.GetDeviceAddress(offset_slice); - int64_t* offset_dst = &result_offsets_base[i + idx]; - // Copy the idx-th component of the ith offset from device to host. - TF_RETURN_IF_ERROR( - stream.Memcpy(offset_dst, offset_src, sizeof(int64_t))); - - if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { - return absl::InternalError(absl::StrFormat( - "Failed to retrieve all slice offset values on stream %p: %s", - &stream, blocked.message())); - } - slice_starts.push_back(*offset_dst); - } - - // Compute new slice. No need to copy the content to new buffers as we can - // reuse the original buffers since slices are contiguous. - int64_t new_size = ShapeUtil::ByteSizeOf(dst_shape); - BufferAllocation::Slice orig_slice = *embedded_thunk_results_[i]; - - int64_t new_offset = orig_slice.offset(); - for (auto [start, stride] : - llvm::zip(slice_starts, *ShapeUtil::ByteStrides(src_shape))) { - new_offset += start * stride; - } - - new_buffers.push_back(orig_result.GetByteSlice(new_offset, new_size)); - } - // Safe to create a local BufferAllocations here since buffers are only slices // of bigger ones allocated elsewhere. BufferAllocations new_allocations(new_buffers, diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h index b52b5fdfde861e..a08d5c19d0d47b 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h @@ -44,16 +44,11 @@ class AddressComputationThunk : public Thunk { public: AddressComputationThunk( ThunkInfo thunk_info, std::unique_ptr embedded_thunk, - std::vector> operands, - std::vector> results, + std::vector> arguments, std::vector>> - operand_offset_buffer_indices, - std::vector> operand_orig_shapes, - std::vector> operand_sliced_shapes, - std::vector>> - result_offset_buffer_indices, - std::vector> result_orig_shapes, - std::vector> result_sliced_shapes); + offset_buffer_indices, + std::vector> orig_shapes, + std::vector> sliced_shapes); AddressComputationThunk(const AddressComputationThunk&) = delete; AddressComputationThunk& operator=(const AddressComputationThunk&) = delete; @@ -66,26 +61,17 @@ class AddressComputationThunk : public Thunk { private: std::unique_ptr embedded_thunk_; std::vector> - embedded_thunk_operands_; - std::vector> - embedded_thunk_results_; - std::vector>> - operand_offset_buffer_indices_; - std::vector> operand_orig_shapes_; - std::vector> operand_sliced_shapes_; + embedded_thunk_arguments_; std::vector>> - result_offset_buffer_indices_; - std::vector> result_orig_shapes_; - std::vector> result_sliced_shapes_; + offset_buffer_indices_; + std::vector> orig_shapes_; + std::vector> sliced_shapes_; // Pinned host memory for transferring offset values from device to host. 
absl::Mutex mutex_; absl::flat_hash_map> - operand_offsets_ ABSL_GUARDED_BY(mutex_); - absl::flat_hash_map> - result_offsets_ ABSL_GUARDED_BY(mutex_); + offsets_ ABSL_GUARDED_BY(mutex_); }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc index dc57a6447922e4..d2b2d48262ccc2 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc @@ -128,12 +128,13 @@ TEST(AddressComputationThunkTest, SlicedGemm) { slice_lhs_offset_1}; AddressComputationThunk thunk( Thunk::ThunkInfo(nullptr), - std::make_unique(std::move(seq)), {slice_lhs, slice_rhs}, - {slice_out, slice_workspace}, {lhs_offsets, std::nullopt}, - {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt}, - {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt}, - {std::nullopt, std::nullopt}, {std::nullopt, std::nullopt}, - {std::nullopt, std::nullopt}); + std::make_unique(std::move(seq)), + {slice_lhs, slice_rhs, slice_out, slice_workspace}, + {lhs_offsets, std::nullopt, std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt, + std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt, + std::nullopt, std::nullopt}); // Step 2: // Execute address computation thunk. @@ -270,14 +271,15 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) { slice_rhs_offset_1}; AddressComputationThunk thunk( Thunk::ThunkInfo(nullptr), - std::make_unique(std::move(seq)), {slice_lhs, slice_rhs}, - {slice_out, slice_workspace}, {lhs_offsets, rhs_offsets}, + std::make_unique(std::move(seq)), + {slice_lhs, slice_rhs, slice_out, slice_workspace}, + {lhs_offsets, rhs_offsets, std::nullopt, std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), - ShapeUtil::MakeShape(PrimitiveType::F32, {4, 3})}, + ShapeUtil::MakeShape(PrimitiveType::F32, {4, 3}), std::nullopt, + std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}), - ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2})}, - {std::nullopt, std::nullopt}, {std::nullopt, std::nullopt}, - {std::nullopt, std::nullopt}); + ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}), std::nullopt, + std::nullopt}); // Step 2: // Execute address computation thunk. @@ -418,14 +420,15 @@ TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) { slice_rhs_offset_1}; AddressComputationThunk thunk( Thunk::ThunkInfo(nullptr), - std::make_unique(std::move(seq)), {slice_lhs, slice_rhs}, - {slice_out, slice_workspace}, {lhs_offsets, rhs_offsets}, + std::make_unique(std::move(seq)), + {slice_lhs, slice_rhs, slice_out, slice_workspace}, + {lhs_offsets, rhs_offsets, std::nullopt, std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), - ShapeUtil::MakeShape(PrimitiveType::F32, {8, 1})}, + ShapeUtil::MakeShape(PrimitiveType::F32, {8, 1}), std::nullopt, + std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), - ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1})}, - {std::nullopt, std::nullopt}, {std::nullopt, std::nullopt}, - {std::nullopt, std::nullopt}); + ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), std::nullopt, + std::nullopt}); // Step 2: // Execute address computation thunk. 
@@ -582,13 +585,12 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { slice_offset_0, slice_offset_1, slice_offset_2, slice_offset_3}; AddressComputationThunk thunk( Thunk::ThunkInfo(nullptr), - std::make_unique(std::move(seq)), {slice_src}, {slice_dst}, - {slice_offsets}, - {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 8})}, + std::make_unique(std::move(seq)), {slice_src, slice_dst}, + {slice_offsets, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 8}), std::nullopt}, // Make sure to pass a dst shape with the same rank as src shape (i.e. // original slice result and not bitcasted one) - {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 8, 8})}, {std::nullopt}, - {std::nullopt}, {std::nullopt}); + {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 8, 8}), std::nullopt}); // Step 2: // Execute address computation thunk. @@ -739,15 +741,14 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { slice_dst_offset_3}; AddressComputationThunk thunk( Thunk::ThunkInfo(nullptr), - std::make_unique(std::move(seq)), {slice_src}, {slice_dst}, - {slice_src_offsets}, - {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 2})}, + std::make_unique(std::move(seq)), {slice_src, slice_dst}, + {slice_src_offsets, slice_dst_offsets}, + {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 2}), + ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2, 2, 2})}, // Make sure to pass a dst shape with the same rank as src shape (i.e. // original slice result and not bitcasted one) - {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2})}, - {slice_dst_offsets}, - {{ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2, 2, 2})}}, - {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2})}); + {ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2}), + ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2})}); // Step 2: // Execute address computation thunk. From a5ec72acbf058dfe016f97f49b27e9b4668d48da Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 19 Mar 2024 02:02:27 -0700 Subject: [PATCH 073/670] compat: Update forward compatibility horizon to 2024-03-19 PiperOrigin-RevId: 617095822 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 813819ae0aec8d..ef8e89811f42bb 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 18) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 19) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 5b4e1879bd6799140f91a856082f0f49a3d989fe Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 19 Mar 2024 02:02:35 -0700 Subject: [PATCH 074/670] Update GraphDef version to 1806. PiperOrigin-RevId: 617095844 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index b199c37ee80142..0d3b39573bb785 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. 
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1805 // Updated: 2024/3/18 +#define TF_GRAPH_DEF_VERSION 1806 // Updated: 2024/3/19 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From c9c2f388759e39ea4dd09e90ece46cc6278c336f Mon Sep 17 00:00:00 2001 From: Quentin Khan Date: Tue, 19 Mar 2024 02:43:09 -0700 Subject: [PATCH 075/670] #shlo_ref Add std compatible member functions to `Shape`. PiperOrigin-RevId: 617104178 --- tensorflow/lite/experimental/shlo/shape.h | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tensorflow/lite/experimental/shlo/shape.h b/tensorflow/lite/experimental/shlo/shape.h index 262c7b0bc1e901..72a322299972d1 100644 --- a/tensorflow/lite/experimental/shlo/shape.h +++ b/tensorflow/lite/experimental/shlo/shape.h @@ -66,6 +66,24 @@ class Shape { // and possible confusion with C++ container's usage of size(). DimensionSize NumElements() const; + // The following members are provided for compatibility with the standard + // library. + using value_type = DimensionSize; + + const value_type& operator[](int dim) const { return dims_[dim]; } + value_type& operator[](int dim) { return dims_[dim]; } + + auto cbegin() const { return dims_.begin(); } + auto begin() const { return dims_.begin(); } + auto begin() { return dims_.begin(); } + auto cend() const { return dims_.end(); } + auto end() const { return dims_.end(); } + auto end() { return dims_.end(); } + bool empty() const { return dims_.empty(); } + size_t size() const { return dims_.size(); } + const value_type* data() const { return dims_.data(); } + value_type* data() { return dims_.data(); } + private: absl::InlinedVector dims_; }; From 89ffb3c7f16df4c2e72896465c0200ff6ac9e735 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Tue, 19 Mar 2024 03:01:54 -0700 Subject: [PATCH 076/670] Support all transposes. 
PiperOrigin-RevId: 617107984 --- .../xla/xla/service/gpu/fusions/fusions.cc | 2 +- .../fusions/mlir/computation_partitioner.cc | 18 +++--- .../gpu/fusions/mlir/elemental_hlo_to_mlir.cc | 31 ++++++++-- .../mlir/elemental_hlo_to_mlir_test.cc | 21 +++++++ .../xla/service/gpu/fusions/transpose_mlir.cc | 52 ++++++++-------- .../xla/service/gpu/fusions/transpose_mlir.h | 9 +-- .../gpu/fusions/transpose_mlir_test.cc | 59 ++++++++++++++++++- 7 files changed, 146 insertions(+), 46 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index 6dd5bf20d80152..5037dd676e0171 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -212,7 +212,7 @@ absl::StatusOr> GetFusionEmitter( return std::make_unique(analysis); } case HloFusionAnalysis::EmitterFusionKind::kTranspose: { - if (check_mlir_emitters(MlirTransposeFusion::IsSupported)) { + if (check_mlir_emitters(nullptr)) { return std::make_unique(analysis); } return std::make_unique(analysis); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc index 472efdb0197501..ad46b914a9e484 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc @@ -341,15 +341,19 @@ mlir::func::FuncOp CreateSubgraphMlirFunction( return *ConvertPrimitiveTypeToMlirType(shape.element_type(), b); }; - const xla::Shape* one_root_shape = nullptr; + const xla::Shape* first_root_shape = nullptr; for (auto* root : subgraph.roots) { if (root->shape().IsTuple()) { for (auto& shape : root->shape().tuple_shapes()) { - one_root_shape = &shape; + if (!first_root_shape) { + first_root_shape = &shape; + } result_types.push_back(element_type(shape)); } } else { - one_root_shape = &root->shape(); + if (!first_root_shape) { + first_root_shape = &root->shape(); + } result_types.push_back(element_type(root->shape())); } } @@ -362,13 +366,13 @@ mlir::func::FuncOp CreateSubgraphMlirFunction( parameter_types.push_back(TensorShapeToMlirType(param->shape(), b)); arg_attrs.emplace_back(); } - for (int dim = 0; dim < one_root_shape->rank(); ++dim) { + for (int dim = 0; dim < first_root_shape->rank(); ++dim) { parameter_types.push_back(b.getIndexType()); arg_attrs.emplace_back(mlir::DictionaryAttr::get( b.getContext(), - {b.getNamedAttr( - "xla.range", - b.getIndexArrayAttr({0, one_root_shape->dimensions(dim) - 1}))})); + {b.getNamedAttr("xla.range", + b.getIndexArrayAttr( + {0, first_root_shape->dimensions(dim) - 1}))})); } // Populate arguments for injected parameters (values that are computed diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index 8e2124dbe0e87b..b9169f82207ae2 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -187,12 +187,14 @@ bool IsUnsupportedTuple(const HloInstruction* instr) { return true; } - // All tuple elements must have the same dimensions (element types may - // differ). + // All tuple elements must have bitcast-compatible dimensions (element types + // may differ). 
auto first_shape = instr->shape().tuple_shapes(0); for (int i = 1; i < instr->operand_count(); ++i) { - if (instr->shape().tuple_shapes(i).dimensions() != - first_shape.dimensions()) { + const auto& tuple_shape = instr->shape().tuple_shapes(i); + if (!ShapeUtil::EqualIgnoringElementType(tuple_shape, first_shape) && + !ShapeUtil::IsReshapeOrTransposeBitcast(tuple_shape, first_shape, + /*ignore_element_type=*/true)) { return true; } } @@ -544,6 +546,8 @@ Value ApplyAffineExpr(mlir::AffineExpr expr, ValueRange dims, SmallVector ApplyAffineMap(mlir::AffineMap map, ValueRange dims, ValueRange symbols, ImplicitLocOpBuilder& b) { + CHECK_EQ(map.getNumDims(), dims.size()); + CHECK_EQ(map.getNumSymbols(), symbols.size()); SmallVector result; result.reserve(map.getNumResults()); for (auto expr : map.getResults()) { @@ -606,6 +610,7 @@ absl::StatusOr> HloToMlir( result_element_type = sign_converter.convertType(element_mlir_type); } + IndexingContext indexing_context(builder.getContext()); // Handle ops that aren't elementwise and aren't just indexing // transformations. switch (instr->opcode()) { @@ -648,11 +653,26 @@ absl::StatusOr> HloToMlir( builder); case HloOpcode::kTuple: { CHECK(!IsUnsupportedTuple(instr)); + const auto& first_shape = instr->shape().tuple_shapes(0); + CHECK_EQ(first_shape.rank(), indices.size()) + << "Indices for tuple must be for the first tuple element"; SmallVector operands; for (int i = 0; i < instr->operand_count(); ++i) { + llvm::SmallVector operand_indices; + // The tuple shapes only need to be bitcast compatible, so insert + // bitcasts where necessary. + if (i > 0 && !ShapeUtil::EqualIgnoringElementType( + first_shape, instr->operand(i)->shape())) { + auto operand_map = GetBitcastMap( + first_shape, instr->operand(i)->shape(), &indexing_context); + operand_indices = + ApplyAffineMap(operand_map.GetAffineMap(), indices, {}, builder); + } else { + operand_indices = indices; + } TF_ASSIGN_OR_RETURN( operands.emplace_back(), - GetSingleOperandValue(operand_provider, instr, i, indices)); + GetSingleOperandValue(operand_provider, instr, i, operand_indices)); } return operands; } @@ -675,7 +695,6 @@ absl::StatusOr> HloToMlir( operand->shape().element_type(), builder)); arg_types.push_back(operand_element_type); } - IndexingContext indexing_context(builder.getContext()); auto input_indices = GetInputIndices(ComputeOutputToInputIndexing(instr, 0, &indexing_context), indices, builder); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index cd07c607d11f93..f326b06113e0f5 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -728,6 +728,27 @@ TEST_F(ElementalHloToMlirTest, IotaComplex) { )")); } +TEST_F(ElementalHloToMlirTest, MixedIndexingTuple) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + %p0 = f32[10,10] parameter(0) + %p1 = f32[100] parameter(1) + ROOT tuple = (f32[10,10], f32[100]) tuple(%p0, %p1) + })", + R"( + // CHECK: @main_tuple( + // CHECK-SAME: %[[P0:.*]]: tensor<10x10xf32>, + // CHECK-SAME: %[[P1:.*]]: tensor<100xf32>, + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} + // CHECK: %[[A:.*]] = tensor.extract %[[P0]][%[[X]], %[[Y]]] + // CHECK: %[[IDX:.*]] = affine.apply + // CHECK-SAME: affine_map<()[s0, s1] -> (s0 * 10 + s1)>() + // CHECK-SAME: [%[[X]], %[[Y]]] + // CHECK: %[[B:.*]] = tensor.extract %[[P1]][%[[IDX]]] + // 
CHECK: return %[[A]], %[[B]] + )")); +} + } // namespace } // namespace mlir_converter } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index 4b8a2af5661935..ba41af491180a6 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -138,26 +138,9 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) } } -/*static*/ bool MlirTransposeFusion::IsSupported( - const HloFusionAnalysis& analysis) { - // If there is a hero, which does not have a transpose, the codegen might - // fail because of the incorrect thread ID mapping for that particular case. - return GetShMemTransposes(analysis).size() == analysis.fusion_heroes().size(); -} - std::optional MlirTransposeFusion::ComputeThreadIdToOutputIndexing( int64_t root_index, IndexingContext* indexing_context) const { const auto& hero = *analysis_.fusion_heroes()[root_index]; - const auto& root = *analysis_.fusion_roots()[root_index]; - if (!GetDescriptionForTiledTransposeEmitter(root, hero)) { - // Non-transpose roots are elementwise by definition. - return ComputeThreadIdToInputIndexing(root_index, 0, indexing_context); - } - return ComputeThreadIdToOutputIndexing(hero, indexing_context); -} - -IndexingMap MlirTransposeFusion::ComputeThreadIdToOutputIndexing( - const HloInstruction& hero, IndexingContext* indexing_context) const { // The block offsets are permuted, but the thread offsets remain the same. auto* mlir_context = indexing_context->GetMLIRContext(); auto block_offset = GetBlockOffsetsForTiling(tiling_, mlir_context) @@ -187,6 +170,20 @@ IndexingMap MlirTransposeFusion::ComputeThreadIdToInputIndexing( return map; } +std::optional MlirTransposeFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + IndexingContext* indexing_context) const { + const auto& hero = *analysis_.fusion_heroes()[root_index]; + const auto& root = *analysis_.fusion_roots()[root_index]; + if (!GetDescriptionForTiledTransposeEmitter(root, hero)) { + // Non-transpose roots are elementwise by definition. 
+ return ComputeThreadIdToOutputIndexing(root_index, indexing_context); + } + + return ComputeThreadIdToInputIndexing(*analysis_.fusion_heroes()[root_index], + indexing_context); +} + LaunchDimensions MlirTransposeFusion::launch_dimensions() const { return LaunchDimensions(tiling_.GetNumBlocks(), tiling_.GetNumThreadsPerBlock()); @@ -298,12 +295,11 @@ absl::Status MlirTransposeFusion::EmitReadFromShMemMlir( IndexingContext indexing_context{mlir_context}; ValueRange output_tensor_args = entry_function.getArguments().drop_front(num_inputs); - auto output_indexing = ComputeThreadIdToOutputIndexing( - *shmem_transposes_.front(), &indexing_context); + auto output_indexing = *ComputeThreadIdToOutputIndexing(0, &indexing_context); auto shmem_output_indexing = GetSharedMemoryReadIndexingMap(output_indexing, permutation_[2]); auto epilogue_indexing = ComputeEpilogueInputToOutputIndexing( - shmem_transposes_.front(), &indexing_context); + analysis_.fusion_heroes()[0], &indexing_context); auto root_indexing = ComposeIndexingMaps(output_indexing, epilogue_indexing); auto result_tensors = EmitThreadLoopNest( builder, output_tensor_args, output_indexing, @@ -324,9 +320,19 @@ absl::Status MlirTransposeFusion::EmitReadFromShMemMlir( root_indices, builder); SmallVector results; results.reserve(output_tensor_args.size()); - for (auto [tensor, value] : llvm::zip(output_tensors, result_scalars)) { - results.push_back( - builder.create(value, tensor, root_indices)); + const auto& first_shape = analysis_.fusion_roots().front()->shape(); + for (auto [tensor, value, root] : llvm::zip( + output_tensors, result_scalars, analysis_.fusion_roots())) { + llvm::SmallVector indices; + if (ShapeUtil::EqualIgnoringElementType(first_shape, root->shape())) { + indices = root_indices; + } else { + auto bitcast_map = + GetBitcastMap(first_shape, root->shape(), &indexing_context); + indices = ApplyAffineMap(bitcast_map.GetAffineMap(), root_indices, + {}, builder); + } + results.push_back(builder.create(value, tensor, indices)); } return results; }); diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h index fd9f5863e8260e..3eb6e6fef98a74 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h @@ -51,23 +51,16 @@ class MlirTransposeFusion : public MlirFusionEmitterBase { explicit MlirTransposeFusion(const HloFusionAnalysis& analysis); LaunchDimensions launch_dimensions() const override; - static bool IsSupported(const HloFusionAnalysis& analysis); - std::optional ComputeThreadIdToOutputIndexing( int64_t root_index, IndexingContext* indexing_context) const override; std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - IndexingContext* indexing_context) const override { - return ComputeThreadIdToInputIndexing( - *analysis_.fusion_heroes()[root_index], indexing_context); - } + IndexingContext* indexing_context) const override; protected: IndexingMap ComputeThreadIdToInputIndexing( const HloInstruction& hero, IndexingContext* indexing_context) const; - IndexingMap ComputeThreadIdToOutputIndexing( - const HloInstruction& hero, IndexingContext* indexing_context) const; absl::Status EmitEntryFunction( const mlir_converter::PartitionedComputations& computations, diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc index 38fe0789b8eadf..bbffea39df1042 100644 
--- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc @@ -251,7 +251,6 @@ TEST_F(MlirTransposeFusionTest, Transpose021_NoEpilogue) { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index - // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<1x32x32xf32> // CHECK: %[[SHMEM_WITH_VALS:.*]] = scf.for // CHECK-SAME: %[[C0]] to %[[C8]] step %[[C1]] @@ -285,6 +284,7 @@ TEST_F(MlirTransposeFusionTest, Transpose_4D) { calls=%fused_computation } )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } @@ -303,6 +303,7 @@ TEST_F(MlirTransposeFusionTest, Transpose_2D) { calls=%fused_computation } )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } @@ -328,6 +329,7 @@ TEST_F(MlirTransposeFusionTest, Transpose_2D_2) { ROOT %fusion = f32[2820,17]{1,0} fusion(%p0, %p1), kind=kInput, calls=%fused_computation } )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } @@ -352,6 +354,7 @@ TEST_F(MlirTransposeFusionTest, MultipleRootsForTranspose) { fusion(), kind=kInput, calls=%fused_computation } )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } @@ -369,6 +372,60 @@ TEST_F(MlirTransposeFusionTest, PartialTile) { ROOT %fusion = f64[6,4,2,24] fusion(%p0), kind=kInput, calls=%fused_computation } )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirTransposeFusionTest, MixedIndexing) { + auto kHloString = R"( + HloModule m + + fused_computation { + %p0 = f64[24,2,6,4] parameter(0) + %bc = f64[24,2,24] bitcast(%p0) + %t1 = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0} + %t2 = f64[24,2,24] transpose(%bc), dimensions={2,1,0} + %p1 = f64[] parameter(1) + %bc1 = f64[6,4,2,24] broadcast(%p1), dimensions={} + %bc2 = f64[24,2,24] broadcast(%p1), dimensions={} + %a1 = f64[6,4,2,24] add(%t1, %bc1) + %a2 = f64[24,2,24] add(%t2, %bc2) + ROOT %t = (f64[6,4,2,24], f64[24,2,24]) tuple(%a1, %a2) + } + + ENTRY main { + %p0 = f64[24,2,6,4] parameter(0) + %p1 = f64[] parameter(1) + ROOT %fusion = (f64[6,4,2,24], f64[24,2,24]) fusion(%p0, %p1), + kind=kInput, calls=%fused_computation + } + )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirTransposeFusionTest, SideOutputs) { + auto kHloString = R"( + HloModule m + + fused_computation { + %p0 = f64[24,2,36] parameter(0) + %p1 = f64[36,2,24] parameter(1) + %tr = f64[36,2,24] transpose(%p0), dimensions={2,1,0} + %neg = f64[36,2,24] negate(%p1) + %log = f64[24,2,36] log(%p0) + ROOT %t = (f64[36,2,24], f64[36,2,24], f64[24,2,36]) + tuple(%neg, %tr, %log) + } + + ENTRY main { + %p0 = f64[24,2,36] parameter(0) + %p1 = f64[36,2,24] parameter(1) + ROOT %fusion = (f64[36,2,24], f64[36,2,24], f64[24,2,36]) + fusion(%p0, %p1), kind=kInput, calls=%fused_computation + } + )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); 
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } From 89572357dada17656848bb77824d834ecb225e25 Mon Sep 17 00:00:00 2001 From: Harsha H S Date: Tue, 19 Mar 2024 03:02:49 -0700 Subject: [PATCH 077/670] PR #10261: [ROCm] ConvBfloat16Support HLO pass for AMDGPU Compiler Imported from GitHub PR https://github.com/openxla/xla/pull/10261 Copybara import of the project: -- 0568134a7d3108f1c29794f536f3acbfe238dff1 by Pavel Emeliyanenko : added ConvBfloat16Support HLO pass -- 3106f99a8d75e1d891053f3a2b3ee1a46c29f5db by Harsha HS : fix typo Merging this change closes #10261 PiperOrigin-RevId: 617108219 --- .../xla/xla/service/gpu/amdgpu_compiler.cc | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc index 723585608b5c55..f429e20f27d1ac 100644 --- a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc @@ -28,6 +28,7 @@ limitations under the License. #include "xla/service/call_inliner.h" #include "xla/service/convert_mover.h" #include "xla/service/dot_dimension_merger.h" +#include "xla/service/float_normalization.h" #include "xla/service/gpu/autotuner_util.h" #include "xla/service/gpu/conv_algorithm_picker.h" #include "xla/service/gpu/cublas_pad_for_gemms.h" @@ -61,6 +62,34 @@ limitations under the License. namespace xla { namespace gpu { +namespace { + +struct ConvBfloat16Support : public FloatSupport { + explicit ConvBfloat16Support(const se::RocmComputeCapability& rocm) + : FloatSupport(BF16), + // TODO: MIOpen does not support bf16 convolutions yet + is_conv_bf16_supported_(rocm.has_bf16_dtype_support()) {} + + bool SupportsLowPrecisionOperand(const HloInstruction& hlo, + int64_t operand_index) const override { + return (hlo.opcode() != HloOpcode::kConvolution) || is_conv_bf16_supported_; + } + + bool SupportsLowPrecisionOutput(const HloInstruction& hlo) const override { + return (hlo.opcode() != HloOpcode::kConvolution) || is_conv_bf16_supported_; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + // Skip all HLOs other than convolutions. + return (hlo.opcode() != HloOpcode::kConvolution); + } + + private: + bool is_conv_bf16_supported_; +}; + +} // namespace + absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, @@ -71,6 +100,12 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddInvariantCheckerDebug( /*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + + // Convert unsupported bf16 convolutions to f32. + ConvBfloat16Support conv_bf16_support( + std::get(gpu_version)); + pipeline.AddPass(&conv_bf16_support); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); From eebacf22187ea45c85e8581a0525f0a238ccfeb8 Mon Sep 17 00:00:00 2001 From: Quentin Khan Date: Tue, 19 Mar 2024 03:12:30 -0700 Subject: [PATCH 078/670] #shlo_ref Add typedefs for the tensor (element) type variant. 
PiperOrigin-RevId: 617110499 --- tensorflow/lite/experimental/shlo/tensor.h | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/experimental/shlo/tensor.h b/tensorflow/lite/experimental/shlo/tensor.h index 6904ad92db4689..57029105d1a218 100644 --- a/tensorflow/lite/experimental/shlo/tensor.h +++ b/tensorflow/lite/experimental/shlo/tensor.h @@ -33,8 +33,10 @@ constexpr TensorElementType BaselineType(TensorElementType type) { return type; } -std::variant BaselineType( - const std::variant& type); +using TensorElementTypeVariant = + std::variant; + +TensorElementTypeVariant BaselineType(const TensorElementTypeVariant& type); struct TensorType { Shape shape; @@ -46,6 +48,8 @@ struct QuantizedTensorType { QuantizedTensorElementType element_type; }; +using TensorTypeVariant = std::variant; + struct Tensor { const Shape& shape() const; Shape& shape(); @@ -69,8 +73,7 @@ struct Tensor { const TensorElementType& tensor_element_type() const; const QuantizedTensorElementType& quantized_tensor_element_type() const; - std::variant element_type() - const; + TensorElementTypeVariant element_type() const; template ::Type> T* GetDataAs() { @@ -88,7 +91,7 @@ struct Tensor { static_cast(NumElements())); } - std::variant type; + TensorTypeVariant type; // If type is TensorType, the type should be Storage::Type. // If type is QuantizedTensorType, the type should be From 1bde09fabaa39ecc09beb29aaf706e8adb346ac2 Mon Sep 17 00:00:00 2001 From: Quentin Khan Date: Tue, 19 Mar 2024 03:23:19 -0700 Subject: [PATCH 079/670] #shlo_ref Refactor unary element wise op tests. PiperOrigin-RevId: 617112557 --- tensorflow/lite/experimental/shlo/data_type.h | 9 +- tensorflow/lite/experimental/shlo/ops/BUILD | 29 +- tensorflow/lite/experimental/shlo/ops/abs.cc | 26 +- .../lite/experimental/shlo/ops/abs_test.cc | 85 ++---- tensorflow/lite/experimental/shlo/ops/cbrt.cc | 23 +- .../lite/experimental/shlo/ops/cbrt_test.cc | 66 ++--- tensorflow/lite/experimental/shlo/ops/ceil.cc | 22 +- .../lite/experimental/shlo/ops/ceil_test.cc | 66 ++--- .../lite/experimental/shlo/ops/cosine.cc | 23 +- .../lite/experimental/shlo/ops/cosine_test.cc | 67 ++--- .../lite/experimental/shlo/ops/test_util.h | 174 ++++++++++-- .../shlo/ops/unary_elementwise_test.cc | 2 +- .../shlo/ops/unary_elementwise_test_util.h | 250 ++++++++++++++++++ tensorflow/lite/experimental/shlo/ops/util.cc | 37 +++ tensorflow/lite/experimental/shlo/ops/util.h | 52 ++++ .../lite/experimental/shlo/status_matcher.h | 5 +- 16 files changed, 638 insertions(+), 298 deletions(-) create mode 100644 tensorflow/lite/experimental/shlo/ops/unary_elementwise_test_util.h diff --git a/tensorflow/lite/experimental/shlo/data_type.h b/tensorflow/lite/experimental/shlo/data_type.h index 8e8fe2d6202911..f313fdc175ce58 100644 --- a/tensorflow/lite/experimental/shlo/data_type.h +++ b/tensorflow/lite/experimental/shlo/data_type.h @@ -95,10 +95,17 @@ using StorageType = typename Storage::Type; constexpr bool IsBool(DataType data_type) { return data_type == DataType::kI1; } -constexpr bool IsInteger(DataType data_type) { +constexpr bool IsSignedInteger(DataType data_type) { return data_type == DataType::kSI4 || data_type == DataType::kSI8 || data_type == DataType::kSI16 || data_type == DataType::kSI32; } + +constexpr bool IsUnsignedInteger(DataType data_type) { return false; } + +constexpr bool IsInteger(DataType data_type) { + return IsSignedInteger(data_type) || IsUnsignedInteger(data_type); +} + constexpr bool IsFloat(DataType data_type) { return data_type 
== DataType::kBF16 || data_type == DataType::kF16 || data_type == DataType::kF32; diff --git a/tensorflow/lite/experimental/shlo/ops/BUILD b/tensorflow/lite/experimental/shlo/ops/BUILD index 45eaa7771807e2..15bcab95773b55 100644 --- a/tensorflow/lite/experimental/shlo/ops/BUILD +++ b/tensorflow/lite/experimental/shlo/ops/BUILD @@ -70,7 +70,9 @@ cc_library( srcs = ["util.cc"], hdrs = ["util.h"], deps = [ + "//tensorflow/lite/experimental/shlo:data_type", "//tensorflow/lite/experimental/shlo:shape", + "//tensorflow/lite/experimental/shlo:tensor", "@com_google_absl//absl/status", ], ) @@ -118,7 +120,6 @@ cc_test( "//tensorflow/lite/experimental/shlo:status_matcher", "//tensorflow/lite/experimental/shlo:tensor", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:inlined_vector", "@com_google_googletest//:gtest_main", ], ) @@ -129,13 +130,30 @@ cc_library( hdrs = ["test_util.h"], deps = [ "//tensorflow/lite/experimental/shlo:data_type", + "//tensorflow/lite/experimental/shlo:quantized_tensor_element_type", "//tensorflow/lite/experimental/shlo:shape", + "//tensorflow/lite/experimental/shlo:tensor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_googletest//:gtest", ], ) +cc_library( + name = "unary_elementwise_test_util", + testonly = True, + hdrs = ["unary_elementwise_test_util.h"], + deps = [ + ":test_util", + "//tensorflow/lite/experimental/shlo:data_type", + "//tensorflow/lite/experimental/shlo:shape", + "//tensorflow/lite/experimental/shlo:status_matcher", + "//tensorflow/lite/experimental/shlo:tensor", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "abs", srcs = ["abs.cc"], @@ -156,6 +174,7 @@ cc_test( deps = [ ":abs", ":test_util", + ":unary_elementwise_test_util", "//tensorflow/lite/experimental/shlo:quantize", "//tensorflow/lite/experimental/shlo:quantized_tensor_element_type", "//tensorflow/lite/experimental/shlo:shape", @@ -173,7 +192,6 @@ cc_library( ":unary_elementwise", ":util", "//tensorflow/lite/experimental/shlo:bf16", - "//tensorflow/lite/experimental/shlo:data_type", "//tensorflow/lite/experimental/shlo:dispatch", "//tensorflow/lite/experimental/shlo:f16", "//tensorflow/lite/experimental/shlo:tensor", @@ -188,6 +206,7 @@ cc_test( deps = [ ":cbrt", ":test_util", + ":unary_elementwise_test_util", "//tensorflow/lite/experimental/shlo:bf16", "//tensorflow/lite/experimental/shlo:f16", "//tensorflow/lite/experimental/shlo:quantize", @@ -195,7 +214,6 @@ cc_test( "//tensorflow/lite/experimental/shlo:shape", "//tensorflow/lite/experimental/shlo:status_matcher", "//tensorflow/lite/experimental/shlo:tensor", - "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", ], ) @@ -223,6 +241,7 @@ cc_test( deps = [ ":ceil", ":test_util", + ":unary_elementwise_test_util", "//tensorflow/lite/experimental/shlo:bf16", "//tensorflow/lite/experimental/shlo:f16", "//tensorflow/lite/experimental/shlo:quantize", @@ -230,7 +249,6 @@ cc_test( "//tensorflow/lite/experimental/shlo:shape", "//tensorflow/lite/experimental/shlo:status_matcher", "//tensorflow/lite/experimental/shlo:tensor", - "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", ], ) @@ -244,7 +262,6 @@ cc_library( ":unary_elementwise", ":util", "//tensorflow/lite/experimental/shlo:bf16", - "//tensorflow/lite/experimental/shlo:data_type", "//tensorflow/lite/experimental/shlo:dispatch", "//tensorflow/lite/experimental/shlo:f16", 
"//tensorflow/lite/experimental/shlo:tensor", @@ -258,6 +275,7 @@ cc_test( deps = [ ":cosine", ":test_util", + ":unary_elementwise_test_util", "//tensorflow/lite/experimental/shlo:bf16", "//tensorflow/lite/experimental/shlo:f16", "//tensorflow/lite/experimental/shlo:quantize", @@ -265,7 +283,6 @@ cc_test( "//tensorflow/lite/experimental/shlo:shape", "//tensorflow/lite/experimental/shlo:status_matcher", "//tensorflow/lite/experimental/shlo:tensor", - "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/experimental/shlo/ops/abs.cc b/tensorflow/lite/experimental/shlo/ops/abs.cc index 8f8f7415b3197b..dd92713df972ed 100644 --- a/tensorflow/lite/experimental/shlo/ops/abs.cc +++ b/tensorflow/lite/experimental/shlo/ops/abs.cc @@ -33,32 +33,26 @@ AbsOp Create(typename AbsOp::Attributes) { return AbsOp{}; } absl::Status Prepare(AbsOp& op, const Tensor& input, Tensor& output) { SHLO_REF_RETURN_ON_ERROR(Propagate(input.shape(), output.shape())); - if (BaselineType(input.element_type()) != - BaselineType(output.element_type())) { - return absl::FailedPreconditionError( - "stablehlo.abs constraint (C2) is not satisfied (incompatible baseline " - "types.)."); - } + SHLO_REF_RETURN_ON_ERROR(CheckSupportedTypes(CheckCtx("abs"), input, + IsSignedIntTensor, IsFloatTensor, + IsQuantizedPerTensorTensor)); + SHLO_REF_RETURN_ON_ERROR( + CheckSameBaselineType(CheckCtx("abs"), input, output)); return absl::OkStatus(); } absl::Status Evaluate(AbsOp& op, const Tensor& input, Tensor& output) { Abs abs; - if (input.IsPerAxisQuantized()) { - DISPATCH_QUANTIZED(detail::DequantizeOpQuantizePerChannel, - input.quantized_tensor_element_type().StorageType(), - input.quantized_tensor_element_type().ExpressedType(), - abs, input, output); - } else if (input.IsPerTensorQuantized()) { + if (input.IsPerTensorQuantized()) { DISPATCH_QUANTIZED(detail::DequantizeOpQuantizePerTensor, input.quantized_tensor_element_type().StorageType(), input.quantized_tensor_element_type().ExpressedType(), abs, input, output) - } else { - DISPATCH_BOOL_INT_FLOAT(detail::EvaluateNoQuantization, - input.tensor_element_type(), abs, input, output); + } else if (IsSignedIntTensor(input) || IsFloatTensor(input)) { + DISPATCH_INT_FLOAT(detail::EvaluateNoQuantization, + input.tensor_element_type(), abs, input, output); } - return absl::OkStatus(); + return absl::FailedPreconditionError("Unsupported tensor type."); } } // namespace shlo_ref diff --git a/tensorflow/lite/experimental/shlo/ops/abs_test.cc b/tensorflow/lite/experimental/shlo/ops/abs_test.cc index 66972cabe9a6bf..0e3962825c56d6 100644 --- a/tensorflow/lite/experimental/shlo/ops/abs_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/abs_test.cc @@ -14,12 +14,12 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/experimental/shlo/ops/abs.h" -#include -#include +#include #include #include #include "tensorflow/lite/experimental/shlo/ops/test_util.h" +#include "tensorflow/lite/experimental/shlo/ops/unary_elementwise_test_util.h" #include "tensorflow/lite/experimental/shlo/quantize.h" #include "tensorflow/lite/experimental/shlo/quantized_tensor_element_type.h" #include "tensorflow/lite/experimental/shlo/shape.h" @@ -30,6 +30,11 @@ using testing::ElementsAreArray; namespace shlo_ref { +template <> +struct ParamName { + static std::string Get() { return "Abs"; } +}; + namespace { constexpr struct AbsRef { @@ -39,12 +44,25 @@ constexpr struct AbsRef { } } abs_ref; +INSTANTIATE_TYPED_TEST_SUITE_P(Abs, UnaryElementwiseOpShapePropagationTest, + AbsOp, TestParamNames); + +INSTANTIATE_TYPED_TEST_SUITE_P( + Abs, UnaryElementwiseSameBaselineElementTypeConstraintTest, + UnaryElementwiseConstraint1Types, TestParamNames); + +using UnsupportedTypes = + WithOpTypes>; + +INSTANTIATE_TYPED_TEST_SUITE_P(Abs, UnaryElementwiseUnsupportedTypeTest, + UnsupportedTypes, TestParamNames); + template struct AbsTest : ::testing::Test {}; -TYPED_TEST_SUITE(AbsTest, NonQuantizedTestTypes, TestParamNames); +TYPED_TEST_SUITE(AbsTest, ArithmeticTestTypes, TestParamNames); -TYPED_TEST(AbsTest, NonQuantized) { +TYPED_TEST(AbsTest, ArithmeticTensorsWork) { using StorageT = typename TypeParam::StorageT; const Shape shape({2, 3, 4}); @@ -107,64 +125,5 @@ TYPED_TEST(QuantizedAbsTest, QuantizedPerTensor) { EXPECT_THAT(output_data, ElementsAreArray(expected_data)); } -TYPED_TEST(QuantizedAbsTest, QuantizedPerAxis) { - using StorageT = typename TypeParam::StorageT; - using ExpressedT = typename TypeParam::ExpressedT; - - const Shape shape({4, 3, 2}); - const int quantized_dimension = 2; - const size_t rank = shape.Rank(); - const Axis quantized_dimension_size = shape.Dim(quantized_dimension); - const size_t quantization_stride = [&] { - size_t res = 1; - for (int64_t i = rank - 1; i > quantized_dimension; --i) { - res *= shape.Dim(i); - } - return res; - }(); - Vector input_data = IotaBuffer(shape); - Vector output_data(shape.NumElements()); - Vector zero_points_data = RandomBuffer( - /*shape=*/Shape({shape.Dim(2)}), /*min=*/static_cast(-5), - /*max=*/static_cast(5)); - Vector scales_data = RandomBuffer( - /*shape=*/Shape({shape.Dim(2)}), /*min=*/static_cast(1), - /*max=*/static_cast(3)); - const QuantizedTensorElementType tensor_type = - QuantizedTensorElementType::PerAxis( - scales_data, zero_points_data, quantized_dimension); - Tensor input_tensor{ - .type = QuantizedTensorType{.shape = shape, .element_type = tensor_type}, - .data = input_data.data()}; - Tensor output_tensor{ - .type = QuantizedTensorType{.shape = shape, .element_type = tensor_type}, - .data = output_data.data()}; - - Vector expected_data(shape.NumElements()); - absl::c_transform( - input_data, expected_data.begin(), - [&, element_index = 0ull, quantization_index = 0ull](auto v) mutable { - const StorageT zero_point = zero_points_data[quantization_index]; - const ExpressedT scale = scales_data[quantization_index]; - - if (++element_index >= quantization_stride) { - element_index = 0; - if (++quantization_index >= quantized_dimension_size) { - quantization_index = 0; - } - } - const ExpressedT dequantized_input = Dequantize(v, zero_point, scale); - const ExpressedT dequantized_res = abs_ref(dequantized_input); - return Quantize( - dequantized_res, zero_point, 
ExpressedT(1) / scale); - }); - - auto op = Create(AbsOp::Attributes{}); - ASSERT_OK(Prepare(op, input_tensor, output_tensor)); - ASSERT_OK(Evaluate(op, input_tensor, output_tensor)); - EXPECT_THAT(output_data, ElementsAreArray(expected_data)); -} - } // namespace } // namespace shlo_ref diff --git a/tensorflow/lite/experimental/shlo/ops/cbrt.cc b/tensorflow/lite/experimental/shlo/ops/cbrt.cc index 2a526292829363..2e50c92c2e5998 100644 --- a/tensorflow/lite/experimental/shlo/ops/cbrt.cc +++ b/tensorflow/lite/experimental/shlo/ops/cbrt.cc @@ -19,7 +19,6 @@ limitations under the License. #include "absl/status/status.h" #include "tensorflow/lite/experimental/shlo/bf16.h" -#include "tensorflow/lite/experimental/shlo/data_type.h" #include "tensorflow/lite/experimental/shlo/dispatch.h" #include "tensorflow/lite/experimental/shlo/f16.h" #include "tensorflow/lite/experimental/shlo/ops/unary_elementwise.h" @@ -49,20 +48,10 @@ CbrtOp Create(CbrtOp::Attributes) { return {}; } absl::Status Prepare(CbrtOp& op, const Tensor& input, Tensor& output) { SHLO_REF_RETURN_ON_ERROR(Propagate(input.shape(), output.shape())); - if (!input.IsQuantized() && IsInteger(input.StorageType())) { - return absl::FailedPreconditionError( - "stablehlo.cbrt does not support integer tensor types."); - } - if (input.IsPerAxisQuantized()) { - return absl::FailedPreconditionError( - "stablehlo.cbrt does not support per axis quantization."); - } - if (BaselineType(input.element_type()) != - BaselineType(output.element_type())) { - return absl::FailedPreconditionError( - "stablehlo.cbrt constraint (C1) is not satisfied (incompatible " - "baseline types)."); - } + SHLO_REF_RETURN_ON_ERROR(CheckSupportedTypes( + CheckCtx("cbrt"), input, IsFloatTensor, IsQuantizedPerTensorTensor)); + SHLO_REF_RETURN_ON_ERROR( + CheckSameBaselineType(CheckCtx("cbrt"), input, output)); return absl::OkStatus(); } @@ -73,11 +62,11 @@ absl::Status Evaluate(CbrtOp& op, const Tensor& input, Tensor& output) { input.quantized_tensor_element_type().StorageType(), input.quantized_tensor_element_type().ExpressedType(), cbrt, input, output) - } else { + } else if (IsFloatTensor(input)) { DISPATCH_FLOAT(detail::EvaluateNoQuantization, input.tensor_element_type(), cbrt, input, output); } - return absl::OkStatus(); + return absl::FailedPreconditionError("Unsupported tensor type."); } }; // namespace shlo_ref diff --git a/tensorflow/lite/experimental/shlo/ops/cbrt_test.cc b/tensorflow/lite/experimental/shlo/ops/cbrt_test.cc index 1c8ae75845c0bb..687e3cb7debb15 100644 --- a/tensorflow/lite/experimental/shlo/ops/cbrt_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/cbrt_test.cc @@ -16,26 +16,31 @@ limitations under the License. 
#include "tensorflow/lite/experimental/shlo/ops/cbrt.h" #include +#include #include #include -#include "absl/status/status.h" #include "tensorflow/lite/experimental/shlo/bf16.h" #include "tensorflow/lite/experimental/shlo/f16.h" #include "tensorflow/lite/experimental/shlo/ops/test_util.h" +#include "tensorflow/lite/experimental/shlo/ops/unary_elementwise_test_util.h" #include "tensorflow/lite/experimental/shlo/quantize.h" #include "tensorflow/lite/experimental/shlo/quantized_tensor_element_type.h" #include "tensorflow/lite/experimental/shlo/shape.h" #include "tensorflow/lite/experimental/shlo/status_matcher.h" #include "tensorflow/lite/experimental/shlo/tensor.h" -using shlo_ref::testing::StatusIs; using testing::ElementsAreArray; using testing::NanSensitiveFloatEq; using testing::Pointwise; namespace shlo_ref { +template <> +struct ParamName { + static std::string Get() { return "Cbrt"; } +}; + namespace { struct Cbrt { @@ -55,36 +60,25 @@ struct Cbrt { } } cbrt_ref; -template -struct NonQuantizedIntCbrtTest : ::testing::Test {}; +INSTANTIATE_TYPED_TEST_SUITE_P(Cbrt, UnaryElementwiseOpShapePropagationTest, + CbrtOp, TestParamNames); -TYPED_TEST_SUITE(NonQuantizedIntCbrtTest, NonQuantizedIntTestTypes, - TestParamNames); +INSTANTIATE_TYPED_TEST_SUITE_P( + Cbrt, UnaryElementwiseSameBaselineElementTypeConstraintTest, + UnaryElementwiseConstraint1Types, TestParamNames); -TYPED_TEST(NonQuantizedIntCbrtTest, IntTensorsRaiseAnError) { - using StorageT = typename TypeParam::StorageT; +using UnsupportedTypes = WithOpTypes< + CbrtOp, ConcatTypes>; - const Shape shape({2, 3, 4}); - Vector input_data = RandomBuffer(shape); - Vector output_data(shape.NumElements()); - - Tensor input_tensor{ - .type = TensorType{.shape = shape, .element_type = TypeParam::kStorage}, - .data = nullptr}; - Tensor output_tensor = input_tensor; - - auto op = Create(CbrtOp::Attributes{}); - EXPECT_THAT(Prepare(op, input_tensor, output_tensor), - StatusIs(absl::StatusCode::kFailedPrecondition)); -} +INSTANTIATE_TYPED_TEST_SUITE_P(Cbrt, UnaryElementwiseUnsupportedTypeTest, + UnsupportedTypes, TestParamNames); template -struct NonQuantizedCbrtTest : ::testing::Test {}; +struct CbrtTest : ::testing::Test {}; -TYPED_TEST_SUITE(NonQuantizedCbrtTest, NonQuantizedFloatTestTypes, - TestParamNames); +TYPED_TEST_SUITE(CbrtTest, FloatTestTypes, TestParamNames); -TYPED_TEST(NonQuantizedCbrtTest, FloatTensorsWork) { +TYPED_TEST(CbrtTest, FloatTensorsWork) { using StorageT = typename TypeParam::StorageT; const Shape shape({2, 3, 4}); @@ -147,27 +141,5 @@ TYPED_TEST(QuantizedCbrtTest, PerTensorWorks) { EXPECT_THAT(output_data, ElementsAreArray(expected_data)); } -TYPED_TEST(QuantizedCbrtTest, PerAxisFails) { - using StorageT = typename TypeParam::StorageT; - using ExpressedT = typename TypeParam::ExpressedT; - - const Shape shape({4, 3, 2}); - const int quantized_dimension = 2; - Vector empty_scales; - Vector empty_zero_points; - const QuantizedTensorElementType tensor_type = - QuantizedTensorElementType::PerAxis( - empty_scales, empty_zero_points, quantized_dimension); - Tensor input_tensor{ - .type = QuantizedTensorType{.shape = shape, .element_type = tensor_type}, - .data = nullptr}; - Tensor output_tensor = input_tensor; - - auto op = Create(CbrtOp::Attributes{}); - EXPECT_THAT(Prepare(op, input_tensor, output_tensor), - StatusIs(absl::StatusCode::kFailedPrecondition)); -} - } // namespace } // namespace shlo_ref diff --git a/tensorflow/lite/experimental/shlo/ops/ceil.cc b/tensorflow/lite/experimental/shlo/ops/ceil.cc index 
5a506dc12923af..a6b501131db5f9 100644
--- a/tensorflow/lite/experimental/shlo/ops/ceil.cc
+++ b/tensorflow/lite/experimental/shlo/ops/ceil.cc
@@ -49,20 +49,10 @@ CeilOp Create(CeilOp::Attributes) { return {}; }
 
 absl::Status Prepare(CeilOp& op, const Tensor& input, Tensor& output) {
   SHLO_REF_RETURN_ON_ERROR(Propagate(input.shape(), output.shape()));
-  if (!input.IsQuantized() && IsInteger(input.StorageType())) {
-    return absl::FailedPreconditionError(
-        "stablehlo.ceil does not support integer tensor types.");
-  }
-  if (input.IsPerAxisQuantized()) {
-    return absl::FailedPreconditionError(
-        "stablehlo.ceil does not support per axis quantization.");
-  }
-  if (BaselineType(input.element_type()) !=
-      BaselineType(output.element_type())) {
-    return absl::FailedPreconditionError(
-        "stablehlo.ceil constraint (C1) is not satisfied (incompatible "
-        "baseline types).");
-  }
+  SHLO_REF_RETURN_ON_ERROR(CheckSupportedTypes(
+      CheckCtx("ceil"), input, IsFloatTensor, IsQuantizedPerTensorTensor));
+  SHLO_REF_RETURN_ON_ERROR(
+      CheckSameBaselineType(CheckCtx("ceil"), input, output));
   return absl::OkStatus();
 }
 
@@ -73,11 +63,11 @@ absl::Status Evaluate(CeilOp& op, const Tensor& input, Tensor& output) {
         input.quantized_tensor_element_type().StorageType(),
         input.quantized_tensor_element_type().ExpressedType(), ceil, input,
         output)
-  } else {
+  } else if (IsFloatTensor(input)) {
     DISPATCH_FLOAT(detail::EvaluateNoQuantization, input.tensor_element_type(),
                    ceil, input, output);
   }
-  return absl::OkStatus();
+  return absl::FailedPreconditionError("Unsupported tensor type.");
 }
 
 };  // namespace shlo_ref
diff --git a/tensorflow/lite/experimental/shlo/ops/ceil_test.cc b/tensorflow/lite/experimental/shlo/ops/ceil_test.cc
index 0875a02435e941..4059b19bcca63c 100644
--- a/tensorflow/lite/experimental/shlo/ops/ceil_test.cc
+++ b/tensorflow/lite/experimental/shlo/ops/ceil_test.cc
@@ -16,26 +16,31 @@ limitations under the License.
#include "tensorflow/lite/experimental/shlo/ops/ceil.h" #include +#include #include #include -#include "absl/status/status.h" #include "tensorflow/lite/experimental/shlo/bf16.h" #include "tensorflow/lite/experimental/shlo/f16.h" #include "tensorflow/lite/experimental/shlo/ops/test_util.h" +#include "tensorflow/lite/experimental/shlo/ops/unary_elementwise_test_util.h" #include "tensorflow/lite/experimental/shlo/quantize.h" #include "tensorflow/lite/experimental/shlo/quantized_tensor_element_type.h" #include "tensorflow/lite/experimental/shlo/shape.h" #include "tensorflow/lite/experimental/shlo/status_matcher.h" #include "tensorflow/lite/experimental/shlo/tensor.h" -using shlo_ref::testing::StatusIs; using testing::ElementsAreArray; using testing::NanSensitiveFloatEq; using testing::Pointwise; namespace shlo_ref { +template <> +struct ParamName { + static std::string Get() { return "Ceil"; } +}; + namespace { struct Ceil { @@ -55,36 +60,25 @@ struct Ceil { } } ceil_ref; -template -struct NonQuantizedIntCeilTest : ::testing::Test {}; +INSTANTIATE_TYPED_TEST_SUITE_P(Ceil, UnaryElementwiseOpShapePropagationTest, + CeilOp, TestParamNames); -TYPED_TEST_SUITE(NonQuantizedIntCeilTest, NonQuantizedIntTestTypes, - TestParamNames); +INSTANTIATE_TYPED_TEST_SUITE_P( + Ceil, UnaryElementwiseSameBaselineElementTypeConstraintTest, + UnaryElementwiseConstraint1Types, TestParamNames); -TYPED_TEST(NonQuantizedIntCeilTest, IntTensorsRaiseAnError) { - using StorageT = typename TypeParam::StorageT; +using UnsupportedTypes = WithOpTypes< + CeilOp, ConcatTypes>; - const Shape shape({2, 3, 4}); - Vector input_data = RandomBuffer(shape); - Vector output_data(shape.NumElements()); - - Tensor input_tensor{ - .type = TensorType{.shape = shape, .element_type = TypeParam::kStorage}, - .data = nullptr}; - Tensor output_tensor = input_tensor; - - auto op = Create(CeilOp::Attributes{}); - EXPECT_THAT(Prepare(op, input_tensor, output_tensor), - StatusIs(absl::StatusCode::kFailedPrecondition)); -} +INSTANTIATE_TYPED_TEST_SUITE_P(Ceil, UnaryElementwiseUnsupportedTypeTest, + UnsupportedTypes, TestParamNames); template -struct NonQuantizedCeilTest : ::testing::Test {}; +struct CeilTest : ::testing::Test {}; -TYPED_TEST_SUITE(NonQuantizedCeilTest, NonQuantizedFloatTestTypes, - TestParamNames); +TYPED_TEST_SUITE(CeilTest, FloatTestTypes, TestParamNames); -TYPED_TEST(NonQuantizedCeilTest, FloatTensorsWork) { +TYPED_TEST(CeilTest, FloatTensorsWork) { using StorageT = typename TypeParam::StorageT; const Shape shape({2, 3, 4}); @@ -147,27 +141,5 @@ TYPED_TEST(QuantizedCeilTest, PerTensorWorks) { EXPECT_THAT(output_data, ElementsAreArray(expected_data)); } -TYPED_TEST(QuantizedCeilTest, PerAxisFails) { - using StorageT = typename TypeParam::StorageT; - using ExpressedT = typename TypeParam::ExpressedT; - - const Shape shape({4, 3, 2}); - const int quantized_dimension = 2; - Vector empty_scales; - Vector empty_zero_points; - const QuantizedTensorElementType tensor_type = - QuantizedTensorElementType::PerAxis( - empty_scales, empty_zero_points, quantized_dimension); - Tensor input_tensor{ - .type = QuantizedTensorType{.shape = shape, .element_type = tensor_type}, - .data = nullptr}; - Tensor output_tensor = input_tensor; - - auto op = Create(CeilOp::Attributes{}); - EXPECT_THAT(Prepare(op, input_tensor, output_tensor), - StatusIs(absl::StatusCode::kFailedPrecondition)); -} - } // namespace } // namespace shlo_ref diff --git a/tensorflow/lite/experimental/shlo/ops/cosine.cc b/tensorflow/lite/experimental/shlo/ops/cosine.cc index 
e373708c15f369..8b757f9709ef18 100644
--- a/tensorflow/lite/experimental/shlo/ops/cosine.cc
+++ b/tensorflow/lite/experimental/shlo/ops/cosine.cc
@@ -19,7 +19,6 @@ limitations under the License.
 
 #include "absl/status/status.h"
 #include "tensorflow/lite/experimental/shlo/bf16.h"
-#include "tensorflow/lite/experimental/shlo/data_type.h"
 #include "tensorflow/lite/experimental/shlo/dispatch.h"
 #include "tensorflow/lite/experimental/shlo/f16.h"
 #include "tensorflow/lite/experimental/shlo/ops/unary_elementwise.h"
@@ -49,20 +48,10 @@ CosineOp Create(CosineOp::Attributes) { return {}; }
 
 absl::Status Prepare(CosineOp& op, const Tensor& input, Tensor& output) {
   SHLO_REF_RETURN_ON_ERROR(Propagate(input.shape(), output.shape()));
-  if (!input.IsQuantized() && IsInteger(input.StorageType())) {
-    return absl::FailedPreconditionError(
-        "stablehlo.cosine does not support integer tensor types.");
-  }
-  if (input.IsPerAxisQuantized()) {
-    return absl::FailedPreconditionError(
-        "stablehlo.cosine does not support per axis quantization.");
-  }
-  if (BaselineType(input.element_type()) !=
-      BaselineType(output.element_type())) {
-    return absl::FailedPreconditionError(
-        "stablehlo.cosine constraint (C1) is not satisfied (incompatible "
-        "baseline types).");
-  }
+  SHLO_REF_RETURN_ON_ERROR(CheckSupportedTypes(
+      CheckCtx("cosine"), input, IsFloatTensor, IsQuantizedPerTensorTensor));
+  SHLO_REF_RETURN_ON_ERROR(
+      CheckSameBaselineType(CheckCtx("cosine"), input, output));
   return absl::OkStatus();
 }
 
@@ -73,11 +62,11 @@ absl::Status Evaluate(CosineOp& op, const Tensor& input, Tensor& output) {
         input.quantized_tensor_element_type().StorageType(),
         input.quantized_tensor_element_type().ExpressedType(), cosine, input,
         output)
-  } else {
+  } else if (IsFloatTensor(input)) {
     DISPATCH_FLOAT(detail::EvaluateNoQuantization, input.tensor_element_type(),
                    cosine, input, output);
   }
-  return absl::OkStatus();
+  return absl::FailedPreconditionError("Unsupported tensor type.");
 }
 
 };  // namespace shlo_ref
diff --git a/tensorflow/lite/experimental/shlo/ops/cosine_test.cc b/tensorflow/lite/experimental/shlo/ops/cosine_test.cc
index 7eb8901cbe2aff..41fce8a264dd57 100644
--- a/tensorflow/lite/experimental/shlo/ops/cosine_test.cc
+++ b/tensorflow/lite/experimental/shlo/ops/cosine_test.cc
@@ -16,26 +16,31 @@ limitations under the License.
#include "tensorflow/lite/experimental/shlo/ops/cosine.h" #include +#include #include #include -#include "absl/status/status.h" #include "tensorflow/lite/experimental/shlo/bf16.h" #include "tensorflow/lite/experimental/shlo/f16.h" #include "tensorflow/lite/experimental/shlo/ops/test_util.h" +#include "tensorflow/lite/experimental/shlo/ops/unary_elementwise_test_util.h" #include "tensorflow/lite/experimental/shlo/quantize.h" #include "tensorflow/lite/experimental/shlo/quantized_tensor_element_type.h" #include "tensorflow/lite/experimental/shlo/shape.h" #include "tensorflow/lite/experimental/shlo/status_matcher.h" #include "tensorflow/lite/experimental/shlo/tensor.h" -using shlo_ref::testing::StatusIs; using testing::ElementsAreArray; using testing::NanSensitiveFloatEq; using testing::Pointwise; namespace shlo_ref { +template <> +struct ParamName { + static std::string Get() { return "Cosine"; } +}; + namespace { struct Cosine { @@ -55,36 +60,26 @@ struct Cosine { } } cosine_ref; -template -struct NonQuantizedIntCosineTest : ::testing::Test {}; +INSTANTIATE_TYPED_TEST_SUITE_P(Cosine, UnaryElementwiseOpShapePropagationTest, + CosineOp, TestParamNames); -TYPED_TEST_SUITE(NonQuantizedIntCosineTest, NonQuantizedIntTestTypes, - TestParamNames); +INSTANTIATE_TYPED_TEST_SUITE_P( + Cosine, UnaryElementwiseSameBaselineElementTypeConstraintTest, + UnaryElementwiseConstraint1Types, TestParamNames); -TYPED_TEST(NonQuantizedIntCosineTest, IntTensorsRaiseAnError) { - using StorageT = typename TypeParam::StorageT; +using UnsupportedTypes = + WithOpTypes>; - const Shape shape({2, 3, 4}); - Vector input_data = RandomBuffer(shape); - Vector output_data(shape.NumElements()); - - Tensor input_tensor{ - .type = TensorType{.shape = shape, .element_type = TypeParam::kStorage}, - .data = nullptr}; - Tensor output_tensor = input_tensor; - - auto op = Create(CosineOp::Attributes{}); - EXPECT_THAT(Prepare(op, input_tensor, output_tensor), - StatusIs(absl::StatusCode::kFailedPrecondition)); -} +INSTANTIATE_TYPED_TEST_SUITE_P(Cosine, UnaryElementwiseUnsupportedTypeTest, + UnsupportedTypes, TestParamNames); template -struct NonQuantizedCosineTest : ::testing::Test {}; +struct CosineTest : ::testing::Test {}; -TYPED_TEST_SUITE(NonQuantizedCosineTest, NonQuantizedFloatTestTypes, - TestParamNames); +TYPED_TEST_SUITE(CosineTest, FloatTestTypes, TestParamNames); -TYPED_TEST(NonQuantizedCosineTest, FloatTensorsWork) { +TYPED_TEST(CosineTest, FloatTensorsWork) { using StorageT = typename TypeParam::StorageT; const Shape shape({2, 3, 4}); @@ -147,27 +142,5 @@ TYPED_TEST(QuantizedCosineTest, PerTensorWorks) { EXPECT_THAT(output_data, ElementsAreArray(expected_data)); } -TYPED_TEST(QuantizedCosineTest, PerAxisFails) { - using StorageT = typename TypeParam::StorageT; - using ExpressedT = typename TypeParam::ExpressedT; - - const Shape shape({4, 3, 2}); - const int quantized_dimension = 2; - Vector empty_scales; - Vector empty_zero_points; - const QuantizedTensorElementType tensor_type = - QuantizedTensorElementType::PerAxis( - empty_scales, empty_zero_points, quantized_dimension); - Tensor input_tensor{ - .type = QuantizedTensorType{.shape = shape, .element_type = tensor_type}, - .data = nullptr}; - Tensor output_tensor = input_tensor; - - auto op = Create(CosineOp::Attributes{}); - EXPECT_THAT(Prepare(op, input_tensor, output_tensor), - StatusIs(absl::StatusCode::kFailedPrecondition)); -} - } // namespace } // namespace shlo_ref diff --git a/tensorflow/lite/experimental/shlo/ops/test_util.h 
b/tensorflow/lite/experimental/shlo/ops/test_util.h index 9b64d10c8455f0..9eaab155e5c20f 100644 --- a/tensorflow/lite/experimental/shlo/ops/test_util.h +++ b/tensorflow/lite/experimental/shlo/ops/test_util.h @@ -18,16 +18,21 @@ limitations under the License. #include #include +#include #include #include #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "tensorflow/lite/experimental/shlo/data_type.h" +#include "tensorflow/lite/experimental/shlo/quantized_tensor_element_type.h" #include "tensorflow/lite/experimental/shlo/shape.h" +#include "tensorflow/lite/experimental/shlo/tensor.h" namespace shlo_ref { +// We use a vector class that is different from std::vector to have a consistent +// API when dealing with bool tensors. template using Vector = absl::InlinedVector; @@ -91,6 +96,19 @@ struct TestParam { using ExpressedT = StorageType; }; +// Typed test parameter tag to ask for a per-tensor quantized tensor. +template +struct PerTensor { + using Param = TestParamT; +}; + +// Typed test parameter tag to ask for a per-channel quantized tensor. +template +struct PerAxis { + using Param = TestParamT; + static constexpr Axis axis = kAxis; +}; + constexpr const char* ToString(DataType t) { switch (t) { case DataType::kI1: @@ -133,6 +151,33 @@ struct ParamName> { } }; +template +struct ParamName>> { + static std::string Get() { + std::string name = std::string("PerTensor[") + ToString(T); + ((name += std::string("_") + ToString(Ts)), ...); + return name + "]"; + } +}; + +template +struct ParamName, axis>> { + static std::string Get() { + std::string name = std::string("PerAxis[") + ToString(T); + ((name += std::string("_") + ToString(Ts)), ...); + return name + ":" + std::to_string(axis) + "]"; + } +}; + +template +struct ParamName> { + static std::string Get() { + std::string name = ParamName::Get(); + ((name += std::string(":") + ParamName::Get()), ...); + return name; + } +}; + class TestParamNames { public: template @@ -141,32 +186,123 @@ class TestParamNames { } }; +template