From 7c91d631c64735c9a2580d3ab01274508f68f9f2 Mon Sep 17 00:00:00 2001 From: Dragan Mladjenovic Date: Sun, 27 Oct 2024 18:48:11 +0000 Subject: [PATCH] Update device_description.h --- tensorflow/core/kernels/matmul_op_fused.cc | 3 +- tensorflow/core/kernels/matmul_op_impl.h | 3 +- .../xla/xla/service/gpu/fusions/reduction.cc | 3 +- .../xla/stream_executor/device_description.h | 70 ++++++++++++------- 4 files changed, 46 insertions(+), 33 deletions(-) diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc index 3908c79885264a..4b9ebe88e22ef7 100644 --- a/tensorflow/core/kernels/matmul_op_fused.cc +++ b/tensorflow/core/kernels/matmul_op_fused.cc @@ -492,8 +492,7 @@ struct LaunchFusedMatMulOp { #if TF_HIPBLASLT auto cap = stream->GetRocmComputeCapability(); - // as of ROCm 5.5, hipblaslt only supports MI200. - if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") use_cudnn = true; + if (!cap.has_hipblaslt()) use_cudnn = true; #endif BlasScratchAllocator scratch_allocator(context); diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h index 39abb98b1eca56..4479eeb5217733 100644 --- a/tensorflow/core/kernels/matmul_op_impl.h +++ b/tensorflow/core/kernels/matmul_op_impl.h @@ -602,8 +602,7 @@ struct LaunchBatchMatMul { #if TF_HIPBLASLT if (!std::is_same_v) bCublasLtSupport = false; auto cap = stream->GetRocmComputeCapability(); - // as of ROCm 5.5, hipblaslt only supports MI200. - if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") bCublasLtSupport = false; + if (!cap.has_hipblaslt()) bCublasLtSupport = false; #endif if (EnableCublasLtGemm() && bCublasLtSupport) { static const int64_t max_scratch_size = diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.cc b/third_party/xla/xla/service/gpu/fusions/reduction.cc index 318e34fea68cdf..8f8e1bc1c5800c 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction.cc @@ -212,8 +212,7 @@ void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context) { auto* module = builder->GetInsertBlock()->getModule(); if (IsAMDGPU(module) && - ir_emitter_context.rocm_compute_capability().gcn_arch_name().substr( - 0, 6) == "gfx90a") { + ir_emitter_context.rocm_compute_capability().fence_before_barrier()) { builder->CreateFence( llvm::AtomicOrdering::SequentiallyConsistent, builder->getContext().getOrInsertSyncScopeID("workgroup")); diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h index 8ed02e972f67d6..8e48e8c4ab74ee 100644 --- a/third_party/xla/xla/stream_executor/device_description.h +++ b/third_party/xla/xla/stream_executor/device_description.h @@ -123,14 +123,13 @@ class RocmComputeCapability { public: // gcn_arch_name example -- gfx90a:sramecc+:xnack- // gfx_version is the "gfx90a" part of the gcn_arch_name - explicit RocmComputeCapability(const std::string &gcn_arch_name) - : gcn_arch_name_(gcn_arch_name) {} + explicit RocmComputeCapability(std::string gcn_arch_name) + : gcn_arch_name_(std::move(gcn_arch_name)) {} explicit RocmComputeCapability(const RocmComputeCapabilityProto &proto) : gcn_arch_name_(proto.gcn_arch_name()) {} RocmComputeCapability() = default; - ~RocmComputeCapability() = default; std::string gcn_arch_name() const { return gcn_arch_name_; } @@ -147,38 +146,57 @@ class RocmComputeCapability { return absl::StrJoin(kSupportedGfxVersions, ", "); } - bool has_nhwc_layout_support() const { - static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940", - "gfx941", "gfx942"}; + bool gfx9_mi100() const { return gfx_version() == "gfx908"; } + + bool gfx9_mi200() const { return gfx_version() == "gfx90a"; } + + bool gfx9_mi300() const { + static constexpr absl::string_view kList[] = {"gfx940", "gfx941", "gfx942"}; return absl::c_count(kList, gfx_version()) != 0; } - bool has_bf16_dtype_support() const { + bool gfx9_mi100_or_later() const { static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940", "gfx941", "gfx942"}; return absl::c_count(kList, gfx_version()) != 0; } - bool has_fast_fp16_support() const { - static constexpr absl::string_view kList[] = {"gfx906", "gfx908", "gfx90a", - "gfx940", "gfx941", "gfx942", - "gfx1030", "gfx1100"}; + bool gfx9_mi200_or_later() const { + static constexpr absl::string_view kList[] = {"gfx90a", "gfx940", "gfx941", + "gfx942"}; return absl::c_count(kList, gfx_version()) != 0; } - bool has_mfma_instr_support() const { - static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940", - "gfx941", "gfx942"}; - return absl::c_count(kList, gfx_version()) != 0; + bool gfx10_rx68xx() const { return gfx_version() == "gfx1030"; } + + bool gfx10_rx69xx() const { return gfx_version() == "gfx1030"; } + + bool gfx11_rx7900() const { return gfx_version() == "gfx1100"; } + + bool has_nhwc_layout_support() const { return gfx9_mi100_or_later(); } + + bool has_bf16_dtype_support() const { return gfx9_mi100_or_later(); } + + bool has_fast_fp16_support() const { + return gfx9_mi100_or_later() || gfx10_rx68xx() || gfx10_rx69xx() || + gfx11_rx7900(); } + bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); } + bool has_fp16_atomics_support() const { // TODO(rocm): Check. This should be the same as has_fast_fp16_support(). - static constexpr absl::string_view kList[] = {"gfx90a", "gfx940", "gfx941", - "gfx942"}; - return absl::c_count(kList, gfx_version()) != 0; + return gfx9_mi200_or_later(); + } + + bool fence_before_barrier() const { + return gfx_version() != "gfx900" && gfx_version() != "gfx906"; } + bool has_hipblaslt() const { return gfx9_mi200_or_later(); } + + bool has_fp8_support() const { return gfx9_mi300(); } + RocmComputeCapabilityProto ToProto() const { RocmComputeCapabilityProto proto; proto.set_gcn_arch_name(gcn_arch_name_); @@ -193,15 +211,13 @@ class RocmComputeCapability { std::string gcn_arch_name_ = "gfx000"; // default to invalid arch. static constexpr absl::string_view kSupportedGfxVersions[]{ - "gfx900", // MI25 - "gfx906", // MI50 / MI60 - "gfx908", // MI100 - "gfx90a", // MI200 - "gfx940", // MI300 - "gfx941", // MI300 - "gfx942", // MI300 - "gfx1030", // RX68xx / RX69xx - "gfx1100" // RX7900 + "gfx900", // MI25 + "gfx906", // MI50 / MI60 + "gfx908", // MI100 + "gfx90a", // MI200 + "gfx940", "gfx941", "gfx942", // MI300 + "gfx1030", // RX68xx / RX69xx + "gfx1100" // RX7900 }; };