Skip to content

Commit

Permalink
Merge pull request #2732 from ROCm/r2.15-rocm-enhanced-dd
Browse files Browse the repository at this point in the history
Update device_description.h
  • Loading branch information
i-chaochen authored Oct 28, 2024
2 parents 26c72d0 + 7c91d63 commit c33aa29
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 33 deletions.
3 changes: 1 addition & 2 deletions tensorflow/core/kernels/matmul_op_fused.cc
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,7 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {

#if TF_HIPBLASLT
auto cap = stream->GetRocmComputeCapability();
// as of ROCm 5.5, hipblaslt only supports MI200.
if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") use_cudnn = true;
if (!cap.has_hipblaslt()) use_cudnn = true;
#endif

BlasScratchAllocator scratch_allocator(context);
Expand Down
3 changes: 1 addition & 2 deletions tensorflow/core/kernels/matmul_op_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,8 +602,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
#if TF_HIPBLASLT
if (!std::is_same_v<Scalar, float>) bCublasLtSupport = false;
auto cap = stream->GetRocmComputeCapability();
// as of ROCm 5.5, hipblaslt only supports MI200.
if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") bCublasLtSupport = false;
if (!cap.has_hipblaslt()) bCublasLtSupport = false;
#endif
if (EnableCublasLtGemm() && bCublasLtSupport) {
static const int64_t max_scratch_size =
Expand Down
3 changes: 1 addition & 2 deletions third_party/xla/xla/service/gpu/fusions/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,7 @@ void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder,
IrEmitterContext& ir_emitter_context) {
auto* module = builder->GetInsertBlock()->getModule();
if (IsAMDGPU(module) &&
ir_emitter_context.rocm_compute_capability().gcn_arch_name().substr(
0, 6) == "gfx90a") {
ir_emitter_context.rocm_compute_capability().fence_before_barrier()) {
builder->CreateFence(
llvm::AtomicOrdering::SequentiallyConsistent,
builder->getContext().getOrInsertSyncScopeID("workgroup"));
Expand Down
70 changes: 43 additions & 27 deletions third_party/xla/xla/stream_executor/device_description.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,13 @@ class RocmComputeCapability {
public:
// gcn_arch_name example -- gfx90a:sramecc+:xnack-
// gfx_version is the "gfx90a" part of the gcn_arch_name
explicit RocmComputeCapability(const std::string &gcn_arch_name)
: gcn_arch_name_(gcn_arch_name) {}
explicit RocmComputeCapability(std::string gcn_arch_name)
: gcn_arch_name_(std::move(gcn_arch_name)) {}

explicit RocmComputeCapability(const RocmComputeCapabilityProto &proto)
: gcn_arch_name_(proto.gcn_arch_name()) {}

RocmComputeCapability() = default;
~RocmComputeCapability() = default;

std::string gcn_arch_name() const { return gcn_arch_name_; }

Expand All @@ -147,38 +146,57 @@ class RocmComputeCapability {
return absl::StrJoin(kSupportedGfxVersions, ", ");
}

bool has_nhwc_layout_support() const {
static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940",
"gfx941", "gfx942"};
bool gfx9_mi100() const { return gfx_version() == "gfx908"; }

bool gfx9_mi200() const { return gfx_version() == "gfx90a"; }

bool gfx9_mi300() const {
static constexpr absl::string_view kList[] = {"gfx940", "gfx941", "gfx942"};
return absl::c_count(kList, gfx_version()) != 0;
}

bool has_bf16_dtype_support() const {
bool gfx9_mi100_or_later() const {
static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940",
"gfx941", "gfx942"};
return absl::c_count(kList, gfx_version()) != 0;
}

bool has_fast_fp16_support() const {
static constexpr absl::string_view kList[] = {"gfx906", "gfx908", "gfx90a",
"gfx940", "gfx941", "gfx942",
"gfx1030", "gfx1100"};
bool gfx9_mi200_or_later() const {
static constexpr absl::string_view kList[] = {"gfx90a", "gfx940", "gfx941",
"gfx942"};
return absl::c_count(kList, gfx_version()) != 0;
}

bool has_mfma_instr_support() const {
static constexpr absl::string_view kList[] = {"gfx908", "gfx90a", "gfx940",
"gfx941", "gfx942"};
return absl::c_count(kList, gfx_version()) != 0;
bool gfx10_rx68xx() const { return gfx_version() == "gfx1030"; }

bool gfx10_rx69xx() const { return gfx_version() == "gfx1030"; }

bool gfx11_rx7900() const { return gfx_version() == "gfx1100"; }

bool has_nhwc_layout_support() const { return gfx9_mi100_or_later(); }

bool has_bf16_dtype_support() const { return gfx9_mi100_or_later(); }

bool has_fast_fp16_support() const {
return gfx9_mi100_or_later() || gfx10_rx68xx() || gfx10_rx69xx() ||
gfx11_rx7900();
}

bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); }

bool has_fp16_atomics_support() const {
// TODO(rocm): Check. This should be the same as has_fast_fp16_support().
static constexpr absl::string_view kList[] = {"gfx90a", "gfx940", "gfx941",
"gfx942"};
return absl::c_count(kList, gfx_version()) != 0;
return gfx9_mi200_or_later();
}

bool fence_before_barrier() const {
return gfx_version() != "gfx900" && gfx_version() != "gfx906";
}

bool has_hipblaslt() const { return gfx9_mi200_or_later(); }

bool has_fp8_support() const { return gfx9_mi300(); }

RocmComputeCapabilityProto ToProto() const {
RocmComputeCapabilityProto proto;
proto.set_gcn_arch_name(gcn_arch_name_);
Expand All @@ -193,15 +211,13 @@ class RocmComputeCapability {
std::string gcn_arch_name_ = "gfx000"; // default to invalid arch.

static constexpr absl::string_view kSupportedGfxVersions[]{
"gfx900", // MI25
"gfx906", // MI50 / MI60
"gfx908", // MI100
"gfx90a", // MI200
"gfx940", // MI300
"gfx941", // MI300
"gfx942", // MI300
"gfx1030", // RX68xx / RX69xx
"gfx1100" // RX7900
"gfx900", // MI25
"gfx906", // MI50 / MI60
"gfx908", // MI100
"gfx90a", // MI200
"gfx940", "gfx941", "gfx942", // MI300
"gfx1030", // RX68xx / RX69xx
"gfx1100" // RX7900
};
};

Expand Down

0 comments on commit c33aa29

Please sign in to comment.