diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 266f6fa99f67ec..7b493c1a78a9f9 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1594,7 +1594,7 @@ cc_library( "//xla/stream_executor:host_or_device_scalar", ]) + if_rocm_is_configured([ "//xla/stream_executor/rocm:hipblas_lt_header", - "//xla/stream_executor/rocm:hipblaslt_plugin", + "//xla/stream_executor/rocm:amdhipblaslt_plugin", "//xla/stream_executor:host_or_device_scalar", "//xla/stream_executor/platform:dso_loader", ]) + if_static([ diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index da37c1851e9635..a8b396811a9f51 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -440,7 +440,7 @@ cc_library( ) cc_library( - name = "hipblaslt_plugin", + name = "amdhipblaslt_plugin", srcs = if_rocm_is_configured(["hip_blas_lt.cc"]), hdrs = if_rocm_is_configured([ "hip_blas_lt.h", @@ -560,7 +560,7 @@ cc_library( ":rocm_driver", ":rocm_platform", ":rocm_helpers", - ":hipblaslt_plugin", + ":amdhipblaslt_plugin", ]), alwayslink = 1, ) diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc index 6a20995b438a3f..793a5e26ec34b0 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc @@ -421,31 +421,31 @@ tsl::Status BlasLt::MatmulPlan::DoMatmul( namespace { -template +template struct HipToNativeT; template <> -struct HipToNativeT { +struct HipToNativeT { using type = Eigen::bfloat16; }; template <> -struct HipToNativeT { +struct HipToNativeT { using type = Eigen::half; }; template <> -struct HipToNativeT { +struct HipToNativeT { using type = float; }; template <> -struct HipToNativeT { +struct HipToNativeT { using type = double; }; template <> -struct HipToNativeT { +struct HipToNativeT { using type = complex64; }; template <> -struct HipToNativeT { +struct HipToNativeT { using type = complex128; }; @@ -476,22 +476,22 @@ tsl::Status BlasLt::MatmulPlan::ExecuteOnStream( } // Other data types: - TYPED_MATMUL(float, HIPBLASLT_R_16B, HIPBLASLT_R_16B, HIPBLASLT_R_16B, - HIPBLASLT_R_16B) - TYPED_MATMUL(float, HIPBLASLT_R_16F, HIPBLASLT_R_16F, HIPBLASLT_R_16F, - HIPBLASLT_R_16F) - TYPED_MATMUL(float, HIPBLASLT_R_16B, HIPBLASLT_R_16B, HIPBLASLT_R_32F, - HIPBLASLT_R_32F) - TYPED_MATMUL(float, HIPBLASLT_R_16F, HIPBLASLT_R_16F, HIPBLASLT_R_32F, - HIPBLASLT_R_32F) - TYPED_MATMUL(float, HIPBLASLT_R_32F, HIPBLASLT_R_32F, HIPBLASLT_R_32F, - HIPBLASLT_R_32F) - TYPED_MATMUL(double, HIPBLASLT_R_64F, HIPBLASLT_R_64F, HIPBLASLT_R_64F, - HIPBLASLT_R_64F) - TYPED_MATMUL(complex64, HIPBLASLT_C_32F, HIPBLASLT_C_32F, HIPBLASLT_C_32F, - HIPBLASLT_C_32F) - TYPED_MATMUL(complex128, HIPBLASLT_C_64F, HIPBLASLT_C_64F, HIPBLASLT_C_64F, - HIPBLASLT_C_64F) + TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF, + HIP_R_16BF) + TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_16F, + HIP_R_16F) + TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_32F, + HIP_R_32F) + TYPED_MATMUL(float, HIP_R_16F, HIP_R_16F, HIP_R_32F, + HIP_R_32F) + TYPED_MATMUL(float, HIP_R_32F, HIP_R_32F, HIP_R_32F, + HIP_R_32F) + TYPED_MATMUL(double, HIP_R_64F, HIP_R_64F, HIP_R_64F, + HIP_R_64F) + TYPED_MATMUL(complex64, HIP_C_32F, HIP_C_32F, HIP_C_32F, + HIP_C_32F) + TYPED_MATMUL(complex128, HIP_C_64F, HIP_C_64F, HIP_C_64F, + HIP_C_64F) #undef TYPED_MATMUL diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h index 54a7eec4fbec74..b1a562895163ce 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h @@ -43,16 +43,16 @@ class BlasLt : public gpu::BlasLt { struct MatrixLayout { static tsl::StatusOr Create(const gpu::MatrixLayout& m); - hipblasltDatatype_t type() const { return datatype_; } + hipDataType type() const { return datatype_; } hipblasLtMatrixLayout_t get() const { return handle_.get(); } private: - MatrixLayout(hipblasLtMatrixLayout_t handle, hipblasltDatatype_t datatype) + MatrixLayout(hipblasLtMatrixLayout_t handle, hipDataType datatype) : handle_(handle, wrap::hipblasLtMatrixLayoutDestroy), datatype_(datatype) {} Owned handle_; - hipblasltDatatype_t datatype_; + hipDataType datatype_; }; class MatmulDesc { @@ -64,8 +64,8 @@ class BlasLt : public gpu::BlasLt { Epilogue epilogue = Epilogue::kDefault, PointerMode pointer_mode = PointerMode::kHost); - hipblasLtComputeType_t compute_type() const { return compute_type_; } - hipblasltDatatype_t scale_type() const { return datatype_; } + hipblasComputeType_t compute_type() const { return compute_type_; } + hipDataType scale_type() const { return datatype_; } hipblasPointerMode_t pointer_mode() const { return HIPBLAS_POINTER_MODE_HOST; } @@ -73,15 +73,15 @@ class BlasLt : public gpu::BlasLt { private: MatmulDesc(hipblasLtMatmulDesc_t handle, - hipblasLtComputeType_t compute_type, - hipblasltDatatype_t datatype) + hipblasComputeType_t compute_type, + hipDataType datatype) : handle_(handle, wrap::hipblasLtMatmulDescDestroy), compute_type_(compute_type), datatype_(datatype) {} Owned handle_; - hipblasLtComputeType_t compute_type_; - hipblasltDatatype_t datatype_; + hipblasComputeType_t compute_type_; + hipDataType datatype_; }; struct MatmulPlan : public gpu::BlasLt::MatmulPlan { diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc index 96fb1998cb88be..44607363c56404 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc @@ -33,36 +33,36 @@ tsl::Status ToStatus(hipblasStatus_t status, const char* prefix) { return tsl::OkStatus(); } -hipblasltDatatype_t AsHipblasDataType(blas::DataType type) { +hipDataType AsHipblasDataType(blas::DataType type) { switch (type) { case blas::DataType::kF8E5M2: case blas::DataType::kF8E4M3FN: LOG(FATAL) << "hipblaslt does not support F8 yet"; case blas::DataType::kHalf: - return HIPBLASLT_R_16F; + return HIP_R_16F; case blas::DataType::kBF16: - return HIPBLASLT_R_16B; + return HIP_R_16BF; case blas::DataType::kFloat: - return HIPBLASLT_R_32F; + return HIP_R_32F; case blas::DataType::kDouble: - return HIPBLASLT_R_64F; + return HIP_R_64F; case blas::DataType::kInt8: - return HIPBLASLT_R_8I; + return HIP_R_8I; case blas::DataType::kInt32: - return HIPBLASLT_R_32I; + return HIP_R_32I; case blas::DataType::kComplexFloat: - return HIPBLASLT_C_32F; + return HIP_C_32F; case blas::DataType::kComplexDouble: - return HIPBLASLT_C_64F; + return HIP_C_64F; default: LOG(FATAL) << "unknown data type"; } } -hipblasLtComputeType_t AsHipblasComputeType(blas::ComputationType type) { +hipblasComputeType_t AsHipblasComputeType(blas::ComputationType type) { if (type == blas::ComputationType::kF32 || type == blas::ComputationType::kTF32AsF32) - return HIPBLASLT_COMPUTE_F32; + return HIPBLAS_COMPUTE_32F; else LOG(FATAL) << "unsupported hipblaslt computation type"; } diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h index 00e015b7aabfc4..85c67915288a85 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h @@ -28,15 +28,20 @@ limitations under the License. #if TF_HIPBLASLT #if TF_ROCM_VERSION < 60000 -#define hipblasltDatatype_t hipblasDatatype_t -#define HIPBLASLT_R_16F HIPBLAS_R_16F -#define HIPBLASLT_R_16B HIPBLAS_R_16B -#define HIPBLASLT_R_32F HIPBLAS_R_32F -#define HIPBLASLT_R_64F HIPBLAS_R_64F -#define HIPBLASLT_R_8I HIPBLAS_R_8I -#define HIPBLASLT_R_32I HIPBLAS_R_32I -#define HIPBLASLT_C_32F HIPBLAS_C_32F -#define HIPBLASLT_C_64F HIPBLAS_C_64F +#define hipDataType hipblasDatatype_t +#define HIP_R_16F HIPBLAS_R_16F +#define HIP_R_16BF HIPBLAS_R_16B +#define HIP_R_32F HIPBLAS_R_32F +#define HIP_R_64F HIPBLAS_R_64F +#define HIP_R_8I HIPBLAS_R_8I +#define HIP_R_32I HIPBLAS_R_32I +#define HIP_C_32F HIPBLAS_C_32F +#define HIP_C_64F HIPBLAS_C_64F + +#define hipblasComputeType_t hipblasLtComputeType_t +#define HIPBLAS_COMPUTE_32F HIPBLASLT_COMPUTE_F32 +#define HIPBLAS_COMPUTE_64F HIPBLASLT_COMPUTE_F64 +#define HIPBLAS_COMPUTE_32I HIPBLASLT_COMPUTE_I32 #endif namespace stream_executor { @@ -46,8 +51,8 @@ namespace rocm { TF_RETURN_IF_ERROR(::stream_executor::rocm::ToStatus(expr, #expr)) tsl::Status ToStatus(hipblasStatus_t status, const char* prefix); -hipblasltDatatype_t AsHipblasDataType(blas::DataType type); -hipblasLtComputeType_t AsHipblasComputeType(blas::ComputationType type); +hipDataType AsHipblasDataType(blas::DataType type); +hipblasComputeType_t AsHipblasComputeType(blas::ComputationType type); hipblasOperation_t AsHipblasOperation(blas::Transpose trans); } // namespace rocm