rCk working truth #1622

Draft: wants to merge 15 commits into base branch rocm_gemm_ck
4 changes: 4 additions & 0 deletions .gitmodules
@@ -128,3 +128,7 @@
path = third_party/cpp-httplib
url = https://github.com/yhirose/cpp-httplib.git
branch = v0.15.3
[submodule "third_party/composable_kernel"]
path = third_party/composable_kernel
url = https://github.com/ROCm/composable_kernel.git
branch = develop
13 changes: 11 additions & 2 deletions CMakeLists.txt
@@ -17,6 +17,10 @@ cmake_policy(SET CMP0069 NEW)
# and it's possible on our Windows configs.
cmake_policy(SET CMP0092 NEW)

include(CMakePrintHelpers)



# Prohibit in-source builds
if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_BINARY_DIR})
message(FATAL_ERROR "In-source builds are not supported")
@@ -773,7 +777,7 @@ set(CAFFE2_ALLOWLIST
if(NOT CMAKE_BUILD_TYPE)
message(STATUS "Build type not set - defaulting to Release")
set(CMAKE_BUILD_TYPE
"Release"
"Debug"
CACHE
STRING
"Choose the type of build from: Debug Release RelWithDebInfo MinSizeRel Coverage."
@@ -851,6 +855,8 @@ endif()
# aotriton build decision later.

include(cmake/Dependencies.cmake)
message("BEFORE USE_FLASH ATTENTION IS SUPPOSEDLY CREATED")
cmake_print_variables(USE_FLASH_ATTENTION)

cmake_dependent_option(
USE_FLASH_ATTENTION
@@ -860,6 +866,9 @@ cmake_dependent_option(
"USE_CUDA OR USE_ROCM;NOT MSVC"
OFF)

message("AFTER USE_FLASH_ATTENTION IS CREATED")


# We are currently not using ALiBi attention for Flash, so we disable this
# feature by default. We don't currently document this feature because we
# don't suspect users building from source will need it.
@@ -871,7 +880,7 @@ cmake_dependent_option(
USE_MEM_EFF_ATTENTION
"Enable memory-efficient attention for scaled dot product attention.\
Will be disabled if not supported by the platform" ON
"USE_CUDA OR USE_ROCM" OFF)
"USE_CUDA" OFF)

if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
4 changes: 3 additions & 1 deletion aten/src/ATen/BlasBackend.h
@@ -7,14 +7,16 @@

namespace at {

-enum class BlasBackend : int8_t { Cublas, Cublaslt };
+enum class BlasBackend : int8_t { Cublas, Cublaslt, Ck };

inline std::string BlasBackendToString(at::BlasBackend backend) {
switch (backend) {
case BlasBackend::Cublas:
return "at::BlasBackend::Cublas";
case BlasBackend::Cublaslt:
return "at::BlasBackend::Cublaslt";
case BlasBackend::Ck:
return "at::BlasBackend::Ck";
default:
TORCH_CHECK(false, "Unknown blas backend");
}
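For reference, a minimal, self-contained sketch of the pattern this hunk extends: the new Ck enumerator only needs a matching case in the string helper, with the default branch still catching unknown values. TORCH_CHECK is swapped for a plain exception here so the sketch compiles outside ATen.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>

enum class BlasBackend : int8_t { Cublas, Cublaslt, Ck };

std::string BlasBackendToString(BlasBackend backend) {
  switch (backend) {
    case BlasBackend::Cublas:
      return "at::BlasBackend::Cublas";
    case BlasBackend::Cublaslt:
      return "at::BlasBackend::Cublaslt";
    case BlasBackend::Ck:
      return "at::BlasBackend::Ck";
    default:
      // stands in for TORCH_CHECK(false, "Unknown blas backend")
      throw std::runtime_error("Unknown blas backend");
  }
}

int main() {
  std::cout << BlasBackendToString(BlasBackend::Ck) << '\n';  // prints at::BlasBackend::Ck
}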
26 changes: 24 additions & 2 deletions aten/src/ATen/CMakeLists.txt
@@ -166,9 +166,22 @@ file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")

# if USE_CK_FLASH_ATTENTION
# this should generate the files before anything else
add_subdirectory(native/transformers/hip/flash_attn)
#set_source_files_properties(
# "native/transformers/hip/flash_attn/*.cpp"
# DIRECTORY "native/transformers/hip/flash_attn"
# PROPERTIES
# COMPILE_FLAGS "-Wno-undefined-func-template"
# )

#endif

# flash_attention sources
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
file(GLOB flash_attention_hip_cpp "native/transformers/hip/flash_attn/*.cpp")

#Mem_eff attention sources
file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu")
@@ -184,6 +197,14 @@ if(USE_FLASH_ATTENTION)

list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip})
list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip})
# add_subdirectory(native/transformers/hip/flash_attn)
# set_source_files_properties(
# "native/transformers/hip/flash_attn/*.cpp"
# DIRECTORY "native/transformers/hip/flash_attn"
# PROPERTIES
# COMPILE_FLAGS "-Wno-undefined-func-template"
# )
list(APPEND native_transformers_hip_cpp ${flash_attention_hip_cpp})
endif()

if(USE_MEM_EFF_ATTENTION)
@@ -307,8 +328,9 @@ endif()

if(USE_ROCM)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
+  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include/)
+  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include/)
+  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha/)
list(APPEND ATen_HIP_SRCS
${ATen_HIP_SRCS}
${hip_hip}
2 changes: 2 additions & 0 deletions aten/src/ATen/Context.cpp
@@ -284,6 +284,8 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
#else
TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(),
"Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(),
"Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm.");
if (b != at::BlasBackend::Cublas) {
TORCH_WARN_ONCE(
"torch.backends.cuda.preferred_blas_library is an experimental feature. "
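The added TORCH_CHECK makes the setter fail fast when Ck is requested in a non-ROCm build, mirroring the existing cuBLASLt guard. A hedged, standalone sketch of that guard pattern follows; the capability queries are stubbed here, whereas ATen routes them through the CUDA hooks.

#include <stdexcept>

enum class BlasBackend { Cublas, Cublaslt, Ck };

// Stubs: in ATen these are Context::hasCuBLASLt() / Context::hasROCM().
bool hasCuBLASLt() { return true; }
bool hasROCM() { return false; }

void setBlasPreferredBackend(BlasBackend b) {
  if (b == BlasBackend::Cublaslt && !hasCuBLASLt()) {
    throw std::runtime_error(
        "Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
  }
  if (b == BlasBackend::Ck && !hasROCM()) {
    throw std::runtime_error(
        "Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm.");
  }
  // ... record b as the preferred backend ...
}

int main() {
  setBlasPreferredBackend(BlasBackend::Cublas);  // fine
  // setBlasPreferredBackend(BlasBackend::Ck);   // would throw in this stubbed non-ROCm build
}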
3 changes: 3 additions & 0 deletions aten/src/ATen/Context.h
@@ -126,6 +126,9 @@ class TORCH_API Context {
static bool hasCuBLASLt() {
return detail::getCUDAHooks().hasCuBLASLt();
}
static bool hasROCM() {
return detail::getCUDAHooks().hasROCM();
}
static bool hasHIP() {
return detail::getHIPHooks().hasHIP();
}
22 changes: 22 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -18,6 +18,7 @@
// until hipblas has an API to accept flags, we must use rocblas here
#include <hipblas/hipblas.h>
#include <rocblas/rocblas.h>
#include <ATen/native/hip/ck_gemm.h>
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
// needed to work around calling rocblas API instead of hipblas API
@@ -792,6 +793,7 @@ inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
AT_ERROR("at::cuda::blas::gemm_internal_cublas: not implemented for ", typeid(Dtype).name());
}


template <>
void gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
@@ -1000,6 +1002,11 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
#endif
}
#ifdef USE_ROCM
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
}
#endif
else {
gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
}
@@ -1011,6 +1018,11 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
}
#ifdef USE_ROCM
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
}
#endif
else {
gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGS(float));
}
@@ -1054,6 +1066,11 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
#ifdef USE_ROCM
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
#endif
else {
gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
@@ -1065,6 +1082,11 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
#ifdef USE_ROCM
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
#endif
else {
gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
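All four dtype specializations gain the same routing, so the control flow can be summarized once. Below is a condensed, compilable sketch with GEMM arguments elided and the three backends stubbed with prints; note how wrapping the else-if in #ifdef USE_ROCM keeps the if/else chain well-formed in CUDA-only builds as well.

#include <cstdio>

#define USE_ROCM  // assume a ROCm build for this sketch

enum class BlasBackend { Cublas, Cublaslt, Ck };

// Stub for at::globalContext().blasPreferredBackend().
BlasBackend preferredBackend() { return BlasBackend::Ck; }

template <typename Dtype> void gemm_internal_cublaslt() { std::puts("cublaslt/hipblaslt path"); }
template <typename Dtype> void gemm_internal_cublas()   { std::puts("cublas/rocblas fallback"); }
template <typename Dtype> void gemm_internal_ck()       { std::puts("composable_kernel path"); }

template <typename Dtype>
void gemm_internal() {
  if (preferredBackend() == BlasBackend::Cublaslt) {
    gemm_internal_cublaslt<Dtype>();
  }
#ifdef USE_ROCM
  else if (preferredBackend() == BlasBackend::Ck) {
    gemm_internal_ck<Dtype>();
  }
#endif
  else {
    gemm_internal_cublas<Dtype>();
  }
}

int main() {
  gemm_internal<float>();  // prints "composable_kernel path"
}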
1 change: 1 addition & 0 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -79,6 +79,7 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
transpose_tensor = tensor.is_contiguous();
return resolve_conj_if_indicated(tensor, true);
}

IntArrayRef tensor_strides = tensor.strides();
IntArrayRef tensor_sizes = tensor.sizes();
if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
Expand Down
24 changes: 24 additions & 0 deletions aten/src/ATen/native/hip/ck_gemm.h
@@ -0,0 +1,24 @@
#pragma once

#include <ATen/OpMathType.h>
#include <ATen/hip/HIPBlas.h>
namespace at::native {


template <typename Dtype>
inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
static_assert(false && sizeof(Dtype), "at::cuda::blas_gemm_internal_ck: not implemented");
}

template <>
void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double));
template <>
void gemm_internal_ck<float>(CUDABLAS_GEMM_ARGTYPES(float));
template <>
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));



} // namespace at::native
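The primary template in this header uses the dependent-false idiom: because false && sizeof(Dtype) depends on the template parameter, the static_assert fires only when the primary template is actually instantiated, i.e. for a dtype with no explicit specialization. A small standalone illustration (the real specializations are defined in a separate .hip translation unit):

#include <iostream>

template <typename Dtype>
void gemm_internal_ck() {
  // The condition depends on Dtype, so the compiler defers the check
  // until this primary template is instantiated.
  static_assert(false && sizeof(Dtype), "gemm_internal_ck: not implemented for this dtype");
}

template <>
void gemm_internal_ck<float>() {  // a supported dtype; real kernel elsewhere
  std::cout << "ck gemm<float>\n";
}

int main() {
  gemm_internal_ck<float>();   // OK: resolves to the specialization
  // gemm_internal_ck<int>();  // would trigger the static_assert at compile time
}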