From f95289371b3860007e40a7474deefe11fd214a25 Mon Sep 17 00:00:00 2001 From: pizzud Date: Tue, 15 Oct 2024 08:09:09 -0700 Subject: [PATCH] autotuning: Append a Status payload to the autotuning cache miss error. Doing so allows detecting this case from callers that might be out-of-process to signal them to recompute the autotune DB. The other callers of Autotune() don't mess with the returned Status, so we don't need to modify them directly. The value is empty right now, as we don't have a need to actually examine it, but if one is needed it should be a serialized proto. PiperOrigin-RevId: 686109156 --- xla/service/gpu/autotuning/BUILD | 1 - xla/service/gpu/autotuning/autotuner_util.cc | 8 +++++++- xla/service/gpu/autotuning/autotuner_util.h | 6 ++++++ .../gpu/autotuning/autotuner_util_test.cc | 17 +++++++++++------ .../gpu/autotuning/conv_algorithm_picker.cc | 1 - .../gpu/autotuning/gemm_fusion_autotuner.cc | 10 ++++++++-- 6 files changed, 32 insertions(+), 11 deletions(-) diff --git a/xla/service/gpu/autotuning/BUILD b/xla/service/gpu/autotuning/BUILD index d78886aad3eb2..213cf3fc570db 100644 --- a/xla/service/gpu/autotuning/BUILD +++ b/xla/service/gpu/autotuning/BUILD @@ -386,7 +386,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:numbers", "@tsl//tsl/platform:status", diff --git a/xla/service/gpu/autotuning/autotuner_util.cc b/xla/service/gpu/autotuning/autotuner_util.cc index a652c3d8a103e..237d06a83b624 100644 --- a/xla/service/gpu/autotuning/autotuner_util.cc +++ b/xla/service/gpu/autotuning/autotuner_util.cc @@ -463,10 +463,13 @@ absl::StatusOr> TryFindInCache( // Cache miss. if (config.should_require_complete_aot_autotune_results()) { - return NotFound( + absl::Status s = NotFound( "Complete XLA AOT autotuning results are required, but no AOT result " "was found for key: %s", key.ToString()); + tsl::errors::InsertPayloads( + s, {{std::string(kAutotuneCacheRequiredErrorPayloadKey), ""}}); + return s; } TF_ASSIGN_OR_RETURN(AutotuneResult autotune_result, autotune_fn()); @@ -593,5 +596,8 @@ AutotunerUtil::CreateRedzoneAllocator(const AutotuneConfig& config, autotune_cache_stats = CacheStats(); } +constexpr absl::string_view kAutotuneCacheRequiredErrorPayloadKey = + "https://openxla.org/gpu/autotune_cache_hit_required/"; + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/autotuning/autotuner_util.h b/xla/service/gpu/autotuning/autotuner_util.h index e70b252abb30a..48bb3e3b29144 100644 --- a/xla/service/gpu/autotuning/autotuner_util.h +++ b/xla/service/gpu/autotuning/autotuner_util.h @@ -57,6 +57,12 @@ struct DevicelessConfig { se::DeviceDescription device_description; }; +// Status payload key to put errors at when autotune cache hits are required. +// See absl::Status docs for full details, but methods like +// {Get,Set,Clear}Payload allow manipulating it. The value of the payload is not +// specified and individual sources of this error may provide different values. +extern const absl::string_view kAutotuneCacheRequiredErrorPayloadKey; + class AutotuneCacheKey { public: AutotuneCacheKey(const se::DeviceDescription& device_description, diff --git a/xla/service/gpu/autotuning/autotuner_util_test.cc b/xla/service/gpu/autotuning/autotuner_util_test.cc index 9c1d7e016a8f0..c34807e254fd7 100644 --- a/xla/service/gpu/autotuning/autotuner_util_test.cc +++ b/xla/service/gpu/autotuning/autotuner_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/autotuning/autotuner_util.h" #include +#include #include #include @@ -58,6 +59,7 @@ namespace { using ::testing::ElementsAre; using ::testing::HasSubstr; using ::testing::IsEmpty; +using ::testing::Ne; using ::testing::Not; using ::testing::TempDir; using ::testing::UnorderedElementsAre; @@ -221,13 +223,16 @@ TEST_F(AutotunerUtilTest, FailIfRequireCompleteAotAutotuning) { auto options = DebugOptions(); options.set_xla_gpu_require_complete_aot_autotune_results(true); AutotuneConfig config(DeviceConfig{executor}, options); + absl::Status s = AutotunerUtil::Autotune(instruction, config, [&] { + return AutotuneResult(); + }).status(); EXPECT_THAT( - AutotunerUtil::Autotune(instruction, config, - [&] { return AutotuneResult(); }), - StatusIs( - absl::StatusCode::kNotFound, - HasSubstr("Complete XLA AOT autotuning results are required, but " - "no AOT result was found for key: absl::StatusOr { + absl::Status s; if (config_.IsDeviceless()) { - return absl::InternalError(absl::StrCat( + s = absl::InternalError(absl::StrCat( "Expect autotune result cache hit for deviceless " "compilation (HLO: ", fusion_instr->ToString(), ")")); + } else { + s = absl::InternalError("Expect autotune result cache hit."); } - return absl::InternalError("Expect autotune result cache hit."); + tsl::errors::InsertPayloads( + s, {{std::string(kAutotuneCacheRequiredErrorPayloadKey), ""}}); + + return s; })); VLOG(4) << "Autotuning result: " << autotune_result.ShortDebugString();