Skip to content

Commit

Permalink
Merge pull request #2310 from ROCmSoftwarePlatform/hipblaslt_6.0_fixe…
Browse files Browse the repository at this point in the history
…s_dev_upstream

Fixing enums and datatypes for rocm6.0
  • Loading branch information
jayfurmanek authored Nov 29, 2023
2 parents e0ebf23 + 5af3d7f commit 8bbf798
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 57 deletions.
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1594,7 +1594,7 @@ cc_library(
"//xla/stream_executor:host_or_device_scalar",
]) + if_rocm_is_configured([
"//xla/stream_executor/rocm:hipblas_lt_header",
"//xla/stream_executor/rocm:hipblaslt_plugin",
"//xla/stream_executor/rocm:amdhipblaslt_plugin",
"//xla/stream_executor:host_or_device_scalar",
"//xla/stream_executor/platform:dso_loader",
]) + if_static([
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ cc_library(
)

cc_library(
name = "hipblaslt_plugin",
name = "amdhipblaslt_plugin",
srcs = if_rocm_is_configured(["hip_blas_lt.cc"]),
hdrs = if_rocm_is_configured([
"hip_blas_lt.h",
Expand Down Expand Up @@ -560,7 +560,7 @@ cc_library(
":rocm_driver",
":rocm_platform",
":rocm_helpers",
":hipblaslt_plugin",
":amdhipblaslt_plugin",
]),
alwayslink = 1,
)
Expand Down
46 changes: 23 additions & 23 deletions third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -421,31 +421,31 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul(

namespace {

template <hipblasltDatatype_t>
template <hipDataType>
struct HipToNativeT;

template <>
struct HipToNativeT<HIPBLASLT_R_16B> {
struct HipToNativeT<HIP_R_16BF> {
using type = Eigen::bfloat16;
};
template <>
struct HipToNativeT<HIPBLASLT_R_16F> {
struct HipToNativeT<HIP_R_16F> {
using type = Eigen::half;
};
template <>
struct HipToNativeT<HIPBLASLT_R_32F> {
struct HipToNativeT<HIP_R_32F> {
using type = float;
};
template <>
struct HipToNativeT<HIPBLASLT_R_64F> {
struct HipToNativeT<HIP_R_64F> {
using type = double;
};
template <>
struct HipToNativeT<HIPBLASLT_C_32F> {
struct HipToNativeT<HIP_C_32F> {
using type = complex64;
};
template <>
struct HipToNativeT<HIPBLASLT_C_64F> {
struct HipToNativeT<HIP_C_64F> {
using type = complex128;
};

Expand Down Expand Up @@ -476,22 +476,22 @@ tsl::Status BlasLt::MatmulPlan::ExecuteOnStream(
}

// Other data types:
TYPED_MATMUL(float, HIPBLASLT_R_16B, HIPBLASLT_R_16B, HIPBLASLT_R_16B,
HIPBLASLT_R_16B)
TYPED_MATMUL(float, HIPBLASLT_R_16F, HIPBLASLT_R_16F, HIPBLASLT_R_16F,
HIPBLASLT_R_16F)
TYPED_MATMUL(float, HIPBLASLT_R_16B, HIPBLASLT_R_16B, HIPBLASLT_R_32F,
HIPBLASLT_R_32F)
TYPED_MATMUL(float, HIPBLASLT_R_16F, HIPBLASLT_R_16F, HIPBLASLT_R_32F,
HIPBLASLT_R_32F)
TYPED_MATMUL(float, HIPBLASLT_R_32F, HIPBLASLT_R_32F, HIPBLASLT_R_32F,
HIPBLASLT_R_32F)
TYPED_MATMUL(double, HIPBLASLT_R_64F, HIPBLASLT_R_64F, HIPBLASLT_R_64F,
HIPBLASLT_R_64F)
TYPED_MATMUL(complex64, HIPBLASLT_C_32F, HIPBLASLT_C_32F, HIPBLASLT_C_32F,
HIPBLASLT_C_32F)
TYPED_MATMUL(complex128, HIPBLASLT_C_64F, HIPBLASLT_C_64F, HIPBLASLT_C_64F,
HIPBLASLT_C_64F)
TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF,
HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_16F,
HIP_R_16F)
TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_32F,
HIP_R_32F)
TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_32F,
HIP_R_32F)
TYPED_MATMUL(float, HIP_R_32F, HIP_R_32F, HIP_R_32F,
HIP_R_32F)
TYPED_MATMUL(double, HIP_R_64F, HIP_R_64F, HIP_R_64F,
HIP_R_64F)
TYPED_MATMUL(complex64, HIP_C_32F, HIP_C_32F, HIP_C_32F,
HIP_C_32F)
TYPED_MATMUL(complex128, HIP_C_64F, HIP_C_64F, HIP_C_64F,
HIP_C_64F)

#undef TYPED_MATMUL

Expand Down
18 changes: 9 additions & 9 deletions third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,16 @@ class BlasLt : public gpu::BlasLt {
struct MatrixLayout {
static tsl::StatusOr<MatrixLayout> Create(const gpu::MatrixLayout& m);

hipblasltDatatype_t type() const { return datatype_; }
hipDataType type() const { return datatype_; }
hipblasLtMatrixLayout_t get() const { return handle_.get(); }

private:
MatrixLayout(hipblasLtMatrixLayout_t handle, hipblasltDatatype_t datatype)
MatrixLayout(hipblasLtMatrixLayout_t handle, hipDataType datatype)
: handle_(handle, wrap::hipblasLtMatrixLayoutDestroy),
datatype_(datatype) {}

Owned<hipblasLtMatrixLayout_t> handle_;
hipblasltDatatype_t datatype_;
hipDataType datatype_;
};

class MatmulDesc {
Expand All @@ -64,24 +64,24 @@ class BlasLt : public gpu::BlasLt {
Epilogue epilogue = Epilogue::kDefault,
PointerMode pointer_mode = PointerMode::kHost);

hipblasLtComputeType_t compute_type() const { return compute_type_; }
hipblasltDatatype_t scale_type() const { return datatype_; }
hipblasComputeType_t compute_type() const { return compute_type_; }
hipDataType scale_type() const { return datatype_; }
hipblasPointerMode_t pointer_mode() const {
return HIPBLAS_POINTER_MODE_HOST;
}
hipblasLtMatmulDesc_t get() const { return handle_.get(); }

private:
MatmulDesc(hipblasLtMatmulDesc_t handle,
hipblasLtComputeType_t compute_type,
hipblasltDatatype_t datatype)
hipblasComputeType_t compute_type,
hipDataType datatype)
: handle_(handle, wrap::hipblasLtMatmulDescDestroy),
compute_type_(compute_type),
datatype_(datatype) {}

Owned<hipblasLtMatmulDesc_t> handle_;
hipblasLtComputeType_t compute_type_;
hipblasltDatatype_t datatype_;
hipblasComputeType_t compute_type_;
hipDataType datatype_;
};

struct MatmulPlan : public gpu::BlasLt::MatmulPlan {
Expand Down
22 changes: 11 additions & 11 deletions third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,36 +33,36 @@ tsl::Status ToStatus(hipblasStatus_t status, const char* prefix) {
return tsl::OkStatus();
}

hipblasltDatatype_t AsHipblasDataType(blas::DataType type) {
hipDataType AsHipblasDataType(blas::DataType type) {
switch (type) {
case blas::DataType::kF8E5M2:
case blas::DataType::kF8E4M3FN:
LOG(FATAL) << "hipblaslt does not support F8 yet";
case blas::DataType::kHalf:
return HIPBLASLT_R_16F;
return HIP_R_16F;
case blas::DataType::kBF16:
return HIPBLASLT_R_16B;
return HIP_R_16BF;
case blas::DataType::kFloat:
return HIPBLASLT_R_32F;
return HIP_R_32F;
case blas::DataType::kDouble:
return HIPBLASLT_R_64F;
return HIP_R_64F;
case blas::DataType::kInt8:
return HIPBLASLT_R_8I;
return HIP_R_8I;
case blas::DataType::kInt32:
return HIPBLASLT_R_32I;
return HIP_R_32I;
case blas::DataType::kComplexFloat:
return HIPBLASLT_C_32F;
return HIP_C_32F;
case blas::DataType::kComplexDouble:
return HIPBLASLT_C_64F;
return HIP_C_64F;
default:
LOG(FATAL) << "unknown data type";
}
}

hipblasLtComputeType_t AsHipblasComputeType(blas::ComputationType type) {
hipblasComputeType_t AsHipblasComputeType(blas::ComputationType type) {
if (type == blas::ComputationType::kF32 ||
type == blas::ComputationType::kTF32AsF32)
return HIPBLASLT_COMPUTE_F32;
return HIPBLAS_COMPUTE_32F;
else
LOG(FATAL) << "unsupported hipblaslt computation type";
}
Expand Down
27 changes: 16 additions & 11 deletions third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,20 @@ limitations under the License.
#if TF_HIPBLASLT

#if TF_ROCM_VERSION < 60000
#define hipblasltDatatype_t hipblasDatatype_t
#define HIPBLASLT_R_16F HIPBLAS_R_16F
#define HIPBLASLT_R_16B HIPBLAS_R_16B
#define HIPBLASLT_R_32F HIPBLAS_R_32F
#define HIPBLASLT_R_64F HIPBLAS_R_64F
#define HIPBLASLT_R_8I HIPBLAS_R_8I
#define HIPBLASLT_R_32I HIPBLAS_R_32I
#define HIPBLASLT_C_32F HIPBLAS_C_32F
#define HIPBLASLT_C_64F HIPBLAS_C_64F
#define hipDataType hipblasDatatype_t
#define HIP_R_16F HIPBLAS_R_16F
#define HIP_R_16BF HIPBLAS_R_16B
#define HIP_R_32F HIPBLAS_R_32F
#define HIP_R_64F HIPBLAS_R_64F
#define HIP_R_8I HIPBLAS_R_8I
#define HIP_R_32I HIPBLAS_R_32I
#define HIP_C_32F HIPBLAS_C_32F
#define HIP_C_64F HIPBLAS_C_64F

#define hipblasComputeType_t hipblasLtComputeType_t
#define HIPBLAS_COMPUTE_32F HIPBLASLT_COMPUTE_F32
#define HIPBLAS_COMPUTE_64F HIPBLASLT_COMPUTE_F64
#define HIPBLAS_COMPUTE_32I HIPBLASLT_COMPUTE_I32
#endif

namespace stream_executor {
Expand All @@ -46,8 +51,8 @@ namespace rocm {
TF_RETURN_IF_ERROR(::stream_executor::rocm::ToStatus(expr, #expr))

tsl::Status ToStatus(hipblasStatus_t status, const char* prefix);
hipblasltDatatype_t AsHipblasDataType(blas::DataType type);
hipblasLtComputeType_t AsHipblasComputeType(blas::ComputationType type);
hipDataType AsHipblasDataType(blas::DataType type);
hipblasComputeType_t AsHipblasComputeType(blas::ComputationType type);
hipblasOperation_t AsHipblasOperation(blas::Transpose trans);

} // namespace rocm
Expand Down

0 comments on commit 8bbf798

Please sign in to comment.