diff --git a/xla/service/gpu/autotuning/BUILD b/xla/service/gpu/autotuning/BUILD index d78886aad3eb2a..213cf3fc570db9 100644 --- a/xla/service/gpu/autotuning/BUILD +++ b/xla/service/gpu/autotuning/BUILD @@ -386,7 +386,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:numbers", "@tsl//tsl/platform:status", diff --git a/xla/service/gpu/autotuning/autotuner_util.cc b/xla/service/gpu/autotuning/autotuner_util.cc index a652c3d8a103e6..237d06a83b6245 100644 --- a/xla/service/gpu/autotuning/autotuner_util.cc +++ b/xla/service/gpu/autotuning/autotuner_util.cc @@ -463,10 +463,13 @@ absl::StatusOr> TryFindInCache( // Cache miss. if (config.should_require_complete_aot_autotune_results()) { - return NotFound( + absl::Status s = NotFound( "Complete XLA AOT autotuning results are required, but no AOT result " "was found for key: %s", key.ToString()); + tsl::errors::InsertPayloads( + s, {{std::string(kAutotuneCacheRequiredErrorPayloadKey), ""}}); + return s; } TF_ASSIGN_OR_RETURN(AutotuneResult autotune_result, autotune_fn()); @@ -593,5 +596,8 @@ AutotunerUtil::CreateRedzoneAllocator(const AutotuneConfig& config, autotune_cache_stats = CacheStats(); } +constexpr absl::string_view kAutotuneCacheRequiredErrorPayloadKey = + "https://openxla.org/gpu/autotune_cache_hit_required/"; + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/autotuning/autotuner_util.h b/xla/service/gpu/autotuning/autotuner_util.h index e70b252abb30a0..48bb3e3b291442 100644 --- a/xla/service/gpu/autotuning/autotuner_util.h +++ b/xla/service/gpu/autotuning/autotuner_util.h @@ -57,6 +57,12 @@ struct DevicelessConfig { se::DeviceDescription device_description; }; +// Status payload key to put errors at when autotune cache hits are required. +// See absl::Status docs for full details, but methods like +// {Get,Set,Clear}Payload allow manipulating it. The value of the payload is not +// specified and individual sources of this error may provide different values. +extern const absl::string_view kAutotuneCacheRequiredErrorPayloadKey; + class AutotuneCacheKey { public: AutotuneCacheKey(const se::DeviceDescription& device_description, diff --git a/xla/service/gpu/autotuning/autotuner_util_test.cc b/xla/service/gpu/autotuning/autotuner_util_test.cc index 9c1d7e016a8f09..c34807e254fd75 100644 --- a/xla/service/gpu/autotuning/autotuner_util_test.cc +++ b/xla/service/gpu/autotuning/autotuner_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/autotuning/autotuner_util.h" #include +#include #include #include @@ -58,6 +59,7 @@ namespace { using ::testing::ElementsAre; using ::testing::HasSubstr; using ::testing::IsEmpty; +using ::testing::Ne; using ::testing::Not; using ::testing::TempDir; using ::testing::UnorderedElementsAre; @@ -221,13 +223,16 @@ TEST_F(AutotunerUtilTest, FailIfRequireCompleteAotAutotuning) { auto options = DebugOptions(); options.set_xla_gpu_require_complete_aot_autotune_results(true); AutotuneConfig config(DeviceConfig{executor}, options); + absl::Status s = AutotunerUtil::Autotune(instruction, config, [&] { + return AutotuneResult(); + }).status(); EXPECT_THAT( - AutotunerUtil::Autotune(instruction, config, - [&] { return AutotuneResult(); }), - StatusIs( - absl::StatusCode::kNotFound, - HasSubstr("Complete XLA AOT autotuning results are required, but " - "no AOT result was found for key: absl::StatusOr { + absl::Status s; if (config_.IsDeviceless()) { - return absl::InternalError(absl::StrCat( + s = absl::InternalError(absl::StrCat( "Expect autotune result cache hit for deviceless " "compilation (HLO: ", fusion_instr->ToString(), ")")); + } else { + s = absl::InternalError("Expect autotune result cache hit."); } - return absl::InternalError("Expect autotune result cache hit."); + tsl::errors::InsertPayloads( + s, {{std::string(kAutotuneCacheRequiredErrorPayloadKey), ""}}); + + return s; })); VLOG(4) << "Autotuning result: " << autotune_result.ShortDebugString();