Skip to content

Commit

Permalink
autotuning: Append a Status payload to the autotuning cache miss error.
Browse files Browse the repository at this point in the history
Doing so allows detecting this case from callers that might be out-of-process
to signal them to recompute the autotune DB. The other callers of Autotune() don't
mess with the returned Status, so we don't need to modify them directly.

The value is empty right now, as we don't have a need to actually examine it,
but if one is needed it should be a serialized proto.

PiperOrigin-RevId: 684201190
  • Loading branch information
pizzud authored and Google-ML-Automation committed Oct 15, 2024
1 parent b46d80b commit 5b9bd27
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 11 deletions.
1 change: 0 additions & 1 deletion xla/service/gpu/autotuning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,6 @@ cc_library(
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:numbers",
"@tsl//tsl/platform:status",
Expand Down
8 changes: 7 additions & 1 deletion xla/service/gpu/autotuning/autotuner_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -463,10 +463,13 @@ absl::StatusOr<std::optional<AutotuneResult>> TryFindInCache(

// Cache miss.
if (config.should_require_complete_aot_autotune_results()) {
return NotFound(
absl::Status s = NotFound(
"Complete XLA AOT autotuning results are required, but no AOT result "
"was found for key: %s",
key.ToString());
tsl::errors::InsertPayloads(
s, {{std::string(kAutotuneCacheRequiredErrorPayloadKey), ""}});
return s;
}

TF_ASSIGN_OR_RETURN(AutotuneResult autotune_result, autotune_fn());
Expand Down Expand Up @@ -593,5 +596,8 @@ AutotunerUtil::CreateRedzoneAllocator(const AutotuneConfig& config,
autotune_cache_stats = CacheStats();
}

constexpr absl::string_view kAutotuneCacheRequiredErrorPayloadKey =
"https://openxla.org/gpu/autotune_cache_hit_required/";

} // namespace gpu
} // namespace xla
6 changes: 6 additions & 0 deletions xla/service/gpu/autotuning/autotuner_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ struct DevicelessConfig {
se::DeviceDescription device_description;
};

// Status payload key to put errors at when autotune cache hits are required.
// See absl::Status docs for full details, but methods like
// {Get,Set,Clear}Payload allow manipulating it. The value of the payload is not
// specified and individual sources of this error may provide different values.
extern const absl::string_view kAutotuneCacheRequiredErrorPayloadKey;

class AutotuneCacheKey {
public:
AutotuneCacheKey(const se::DeviceDescription& device_description,
Expand Down
17 changes: 11 additions & 6 deletions xla/service/gpu/autotuning/autotuner_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "xla/service/gpu/autotuning/autotuner_util.h"

#include <memory>
#include <optional>
#include <string>
#include <vector>

Expand Down Expand Up @@ -58,6 +59,7 @@ namespace {
using ::testing::ElementsAre;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::Ne;
using ::testing::Not;
using ::testing::TempDir;
using ::testing::UnorderedElementsAre;
Expand Down Expand Up @@ -221,13 +223,16 @@ TEST_F(AutotunerUtilTest, FailIfRequireCompleteAotAutotuning) {
auto options = DebugOptions();
options.set_xla_gpu_require_complete_aot_autotune_results(true);
AutotuneConfig config(DeviceConfig{executor}, options);
absl::Status s = AutotunerUtil::Autotune(instruction, config, [&] {
return AutotuneResult();
}).status();
EXPECT_THAT(
AutotunerUtil::Autotune(instruction, config,
[&] { return AutotuneResult(); }),
StatusIs(
absl::StatusCode::kNotFound,
HasSubstr("Complete XLA AOT autotuning results are required, but "
"no AOT result was found for key: <key model")));
s, StatusIs(
absl::StatusCode::kNotFound,
HasSubstr("Complete XLA AOT autotuning results are required, but "
"no AOT result was found for key: <key model")));
EXPECT_THAT(s.GetPayload(kAutotuneCacheRequiredErrorPayloadKey),
Ne(std::nullopt));
EXPECT_EQ(AutotunerUtil::GetCacheStats().cache_hits, 0);
EXPECT_EQ(AutotunerUtil::GetCacheStats().cache_misses, 1);
}
Expand Down
1 change: 0 additions & 1 deletion xla/service/gpu/autotuning/conv_algorithm_picker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ limitations under the License.
#include "xla/tsl/util/proto/proto_utils.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/numbers.h"
#include "tsl/platform/status.h"
Expand Down
10 changes: 8 additions & 2 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -626,13 +626,19 @@ absl::Status GemmFusionAutotunerRewriterVisitor::HandleFusion(
AutotuneResult autotune_result,
AutotunerUtil::Autotune(
fusion_instr, config_, [&]() -> absl::StatusOr<AutotuneResult> {
absl::Status s;
if (config_.IsDeviceless()) {
return absl::InternalError(absl::StrCat(
s = absl::InternalError(absl::StrCat(
"Expect autotune result cache hit for deviceless "
"compilation (HLO: ",
fusion_instr->ToString(), ")"));
} else {
s = absl::InternalError("Expect autotune result cache hit.");
}
return absl::InternalError("Expect autotune result cache hit.");
tsl::errors::InsertPayloads(
s, {{std::string(kAutotuneCacheRequiredErrorPayloadKey), ""}});

return s;
}));
VLOG(4) << "Autotuning result: " << autotune_result.ShortDebugString();

Expand Down

0 comments on commit 5b9bd27

Please sign in to comment.