From 8bb7d028089b6b740f07524caafeddb74e2fcbc1 Mon Sep 17 00:00:00 2001 From: scxfjiang Date: Fri, 25 Oct 2024 13:20:13 +0000 Subject: [PATCH] compile gpu_blas_lt_gemm_runner successfully --- tensorflow/compiler/xla/stream_executor/BUILD | 1 + .../compiler/xla/stream_executor/gpu/BUILD | 14 +- .../gpu/gpu_blas_lt_gemm_runner.cc | 348 ++++++++++++++++++ .../gpu/gpu_blas_lt_gemm_runner.h | 255 +++++++++++++ 4 files changed, 611 insertions(+), 7 deletions(-) create mode 100644 tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc create mode 100644 tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h diff --git a/tensorflow/compiler/xla/stream_executor/BUILD b/tensorflow/compiler/xla/stream_executor/BUILD index 0425da4aea423a..34f7034934856a 100644 --- a/tensorflow/compiler/xla/stream_executor/BUILD +++ b/tensorflow/compiler/xla/stream_executor/BUILD @@ -450,6 +450,7 @@ tsl_gpu_library( ":temporary_memory_manager", ":timer", "//tensorflow/compiler/xla/stream_executor/platform", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_blas_lt_gemm_runner", "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:logging", diff --git a/tensorflow/compiler/xla/stream_executor/gpu/BUILD b/tensorflow/compiler/xla/stream_executor/gpu/BUILD index 1c327bb60ca64f..9479d3cae5eb80 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/BUILD +++ b/tensorflow/compiler/xla/stream_executor/gpu/BUILD @@ -87,13 +87,13 @@ cc_library( srcs = if_gpu_is_configured(["gpu_blas_lt_gemm_runner.cc"]), hdrs = if_gpu_is_configured(["gpu_blas_lt_gemm_runner.h"]), deps = if_gpu_is_configured([ - "//tensorflow/core:autotuning_proto_cc", - "//tensorflow/core:autotune_results_proto_cc", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/stream_executor:scratch_allocator", - "//tensorflow/compiler/xla/service/gpu:autotuner_util", - "//tensorflow/compiler/xla:debug_options_flags", - ":gpu_blas_lt", + "//tensorflow/core/protobuf:autotuning_proto_cc", + "//tensorflow/compiler/xla:autotune_results_proto_cc", + # "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/stream_executor:scratch_allocator", + "//tensorflow/compiler/xla/service/gpu:autotuner_util", + "//tensorflow/compiler/xla:debug_options_flags", + ":gpu_blas_lt", ]), ) diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc new file mode 100644 index 00000000000000..61e6039012bbdb --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc @@ -0,0 +1,348 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/util/env_var.h" + +namespace stream_executor { +namespace gpu { + +bool BlasLtGemmRunner::autotune_enabled_ = true; + +bool operator==(const GroupedGemmConfig& rhs, const GroupedGemmConfig& lhs) { + return AsTuple(rhs) == AsTuple(lhs); +} + +bool operator==(const StridedGemmConfig& rhs, const StridedGemmConfig& lhs) { + return AsTuple(rhs) == AsTuple(lhs); +} + +std::ostream& operator<<(std::ostream& os, const StridedGemmConfig& cfg) { + return os << "trans_a/b: " << (int)cfg.trans_a << "/" << (int)cfg.trans_b + << " m: " << cfg.m << " n: " << cfg.n << " k: " << cfg.k + << " batch_count: " << cfg.batch_count << " lda: " << cfg.lda + << " ldb: " << cfg.ldb << " ldc: " << cfg.ldc + << " stride_a: " << cfg.stride_a << " stride_b: " << cfg.stride_b + << " stride_c: " << cfg.stride_c << " type_a: " << (int)cfg.type_a + << " type_b: " << (int)cfg.type_b << " type_c: " << (int)cfg.type_c + << " alpha: " << cfg.alpha << " beta: " << cfg.beta; +} + +BlasLtGemmRunner::BlasLtGemmRunner(StreamExecutor* parent) + : mutex_(std::make_unique()), + autotune_config_(std::make_unique( + xla::gpu::DeviceConfig{parent, nullptr}, + xla::GetDebugOptionsFromFlags())) {} + +BlasLtGemmRunner::~BlasLtGemmRunner() {} + +/*static*/ BlasLtGemmRunner& BlasLtGemmRunner::i(const Stream* stream) { + static absl::Mutex m(absl::kConstInit); + // Each GPU gets a different cache instance + static std::vector> meta(8); + absl::MutexLock lock(&m); + size_t dev_id = stream->parent()->device_ordinal(); + if (dev_id >= meta.size()) meta.resize(dev_id + 1); + auto& res = meta[dev_id]; + if (!res) { + autotune_enabled_ = + xla::GetDebugOptionsFromFlags().xla_gpu_autotune_level() > 0; + res.reset(new BlasLtGemmRunner(stream->parent())); + xla::gpu::AutotunerUtil::LoadAutotuneResultsFromFileOnce( + *res->autotune_config_); + } + return *res; +} + +template +xla::StatusOr BlasLtGemmRunner::Autotune( + const std::vector& algorithms, + TuneFunc&& benchmark_func) { + gpu::BlasLt::MatmulAlgorithm best_algo; + float best_ms = std::numeric_limits::max(), total_ms = 0; + uint32_t n_warmups = 1, n_iters = 5, n_total = n_warmups + n_iters, i = 0; + + for (uint32_t j = 0; j < algorithms.size(); j++) { + const auto& algo = algorithms[j]; + if (!benchmark_func(algo, nullptr).ok()) continue; + + blas::ProfileResult profile; + for (i = 0, total_ms = 0; i < n_total; i++) { + auto res = benchmark_func(algo, &profile); + if (!res.ok() || !profile.is_valid()) { + VLOG(1) << j + << ": gemm algorithm is not valid: " /* << res.error_message() */; + break; + } + if (i >= n_warmups) total_ms += profile.elapsed_time_in_ms(); + } + if (i < n_total) continue; // invalid algorithm + total_ms /= n_iters; + VLOG(2) << j << ": gemm algorithm " << profile.algorithm() << " took " + << total_ms << "ms, workspace: " << algo.workspace_size; + if (total_ms < best_ms) { + best_ms = total_ms, best_algo = algo; + } + } // for algorithms + if (!best_algo.opaque_algo.has_value()) { + return xla::InternalError("No valid gemm algorithms found!"); + } + return best_algo; +} + +xla::StatusOr> BlasLtGemmRunner::ContiguousStrides( + const ArraySlice& a, + const ArraySlice& b, + const ArraySlice& c, int64_t batch_count) { + uint64_t bsa = 0, bsb = 0, bsc = 0; + using CT = const uint8_t; + for (int64_t i = 0; i < batch_count - 1; i++) { + uint64_t da = (CT*)a[i + 1]->opaque() - (CT*)a[i]->opaque(), + db = (CT*)b[i + 1]->opaque() - (CT*)b[i]->opaque(), + dc = (CT*)c[i + 1]->opaque() - (CT*)c[i]->opaque(); + if (i == 0) { + bsa = da, bsb = db, bsc = dc; + } else if (!(bsa == da && bsb == db && bsc == dc)) { // strides mismatch + return xla::InternalError("Strides are not consistent!"); + } + } + return std::array{bsa, bsb, bsc}; +} + +xla::Status BlasLtGemmRunner::RunBatchedImpl( + Stream& stream, blas::Transpose trans_a, blas::Transpose trans_b, int64_t m, + int64_t n, int64_t k, const void* alpha, blas::DataType type_a, + const void** a, int64_t lda, blas::DataType type_b, const void** b, + int64_t ldb, const void* beta, blas::DataType type_c, void** c, int64_t ldc, + int64_t batch_count, ScratchAllocator* allocator) { + TF_ASSIGN_OR_RETURN(auto compute_type, + gpu::GetBlasComputationType(type_a, type_c, 0)); + + GroupedGemmConfig cfg{ + .m = (int64_t)m, + .n = (int64_t)n, + .k = (int64_t)k, + .batch_count = (int64_t)batch_count, + .trans_a = trans_a, + .trans_b = trans_b, + .alpha = alpha, + .beta = beta, + .type_a = type_a, + .type_b = type_b, + .type_c = type_c, + .type_d = type_c, + .lda = (int64_t)lda, + .ldb = (int64_t)ldb, + .ldc = (int64_t)ldc, + .ldd = (int64_t)ldc, + .compute_type = compute_type, + .a = a, + .b = b, + .c = const_cast(c), + .d = c, + }; + + absl::MutexLock lock(mutex_.get()); + + auto res = grouped_gemm_map_.find(cfg); + if (res == grouped_gemm_map_.end()) { + // NOTE: we assume that pointers a,b,c come from the device mem + // hence we need to block stream here + TF_ASSIGN_OR_RETURN(auto plan_res, + gpu::BlasLt::CreateGroupedMatmulPlan(&stream, cfg)); + res = grouped_gemm_map_.emplace(cfg, std::move(plan_res)).first; + + size_t num_solutions = autotune_enabled_ ? gpu::BlasLt::kMaxAlgorithms : 1; + // discard solutions with non-zero workspace if allocator is not given + TF_ASSIGN_OR_RETURN( + auto algorithms, + res->second->GetAlgorithms(num_solutions, + allocator == nullptr ? 0 : 1ull << 32)); + + VLOG(1) << stream.parent() + << ": new GGemm config: " << grouped_gemm_map_.size() + << " #valid algorithms: " << algorithms.size(); + + BlasLt::MatmulAlgorithm best_algo; + if (!autotune_enabled_) { + if (algorithms.empty()) + return xla::InternalError("No GG algorithms found!"); + best_algo = algorithms[0]; // otherwise use default algorithm + } else { + TF_ASSIGN_OR_RETURN( + auto best_algo, + Autotune(algorithms, [&](const gpu::BlasLt::MatmulAlgorithm& algo, + blas::ProfileResult* profile) { + if (profile == nullptr) { + return res->second->SetAlgorithm(algo, allocator); + } + return res->second->ExecuteOnStream(&stream, cfg, profile); + })); + } + TF_RETURN_IF_ERROR(res->second->SetAlgorithm(best_algo, allocator)); + } + return res->second->ExecuteOnStream(&stream, cfg); +} + +xla::Status BlasLtGemmRunner::RunStridedBatchedImpl( + Stream& stream, blas::Transpose trans_a, blas::Transpose trans_b, int64_t m, + int64_t n, int64_t k, xla::complex128 alpha, blas::DataType type_a, + const DeviceMemoryBase& a, int64_t lda, int64_t stride_a, + blas::DataType type_b, const DeviceMemoryBase& b, int64_t ldb, + int64_t stride_b, double beta, blas::DataType type_c, DeviceMemoryBase* c, + int64_t ldc, int64_t stride_c, int64_t batch_count, + ScratchAllocator* allocator) { + StridedGemmConfig scfg{ + .m = m, + .n = n, + .k = k, + .batch_count = (int64_t)batch_count, + .trans_a = trans_a, + .trans_b = trans_b, + .alpha = alpha, + .beta = beta, + .type_a = type_a, + .type_b = type_b, + .type_c = type_c, + .lda = lda, + .ldb = ldb, + .ldc = ldc, + .stride_a = stride_a, + .stride_b = stride_b, + .stride_c = stride_c, + }; + + absl::MutexLock lock(mutex_.get()); + + auto res = strided_gemm_map_.find(scfg); + while (res == strided_gemm_map_.end()) { + int64_t row_a = m, col_a = k, row_b = k, col_b = n; + if (trans_a == blas::Transpose::kTranspose) std::swap(row_a, col_a); + if (trans_b == blas::Transpose::kTranspose) std::swap(row_b, col_b); + + auto order = MatrixLayout::Order::kColumnMajor; + GemmConfig cfg = { + .lhs_layout = MatrixLayout(type_a, row_a, col_a, order, batch_count, + lda, stride_a, trans_a), + + .rhs_layout = MatrixLayout(type_b, row_b, col_b, order, batch_count, + ldb, stride_b, trans_b), + + .c_layout = + MatrixLayout(type_c, m, n, order, batch_count, ldc, stride_c), + .output_layout = + MatrixLayout(type_c, m, n, order, batch_count, ldc, stride_c), + .alpha = alpha, + .beta = beta, + .compute_precision = -1, + .epilogue = gpu::BlasLt::Epilogue::kDefault, + }; + + TF_ASSIGN_OR_RETURN(auto plan_res, + gpu::BlasLt::GetMatmulPlan(&stream, cfg)); + res = strided_gemm_map_.emplace(scfg, std::move(plan_res)).first; + + size_t num_solutions = autotune_enabled_ ? gpu::BlasLt::kMaxAlgorithms : 1; + // discard solutions with non-zero workspace if allocator is not given + TF_ASSIGN_OR_RETURN( + auto algorithms, + res->second->GetAlgorithms(num_solutions, + allocator == nullptr ? 0 : 1ull << 32)); + + VLOG(1) << &stream << " dev " << stream.parent() << '(' + << stream.parent()->device_ordinal() + << "): new StridedBatched config: " << strided_gemm_map_.size() + << " #algorithms: " << algorithms.size(); + + if (!autotune_enabled_) { + if (algorithms.empty()) + return xla::InternalError("No algorithms found!"); + res->second->SetAlgorithm(algorithms[0]); + break; + } + + BlasLt::MatmulAlgorithm best_algo{.id = blas::kNoAlgorithm}; + xla::gpu::AutotuneCacheKey key(ToCSVString(cfg, /*full_string*/ false)); + auto opt_res = xla::gpu::AutotunerUtil::TryToFindInInMemoryCache(key); + if (opt_res.has_value()) { + auto id = *opt_res; + for (const auto& algo : algorithms) { + if (algo.id == id) best_algo = algo; + } + if (best_algo.id == blas::kNoAlgorithm) { + LOG(WARNING) << "Best algorithm not valid: need to autotune.."; + } + } + + if (best_algo.id == blas::kNoAlgorithm) { + TF_ASSIGN_OR_RETURN( + best_algo, + Autotune(algorithms, [&](const gpu::BlasLt::MatmulAlgorithm& algo, + blas::ProfileResult* profile) { + if (profile == nullptr) { + return res->second->SetAlgorithm(algo); + } + return res->second->ExecuteOnStream(&stream, a, b, *c, *c, + DeviceMemoryBase{}, // bias + DeviceMemoryBase{}, // aux + DeviceMemoryBase{}, // a_scale + DeviceMemoryBase{}, // b_scale + DeviceMemoryBase{}, // c_scale + DeviceMemoryBase{}, // d_scale + DeviceMemoryBase{}, // d_amax + absl::nullopt, // workspace + allocator, // allocator + profile); + })); + xla::gpu::AutotunerUtil::CacheValue ares = best_algo.id; + // reread algorithm ID from cache again (in case some other thread has + // already added this config to the cache to be sure we use the same ID) + auto new_id = xla::gpu::AutotunerUtil::AddResultToInMemoryCache( + key, ares, *autotune_config_); + + if (new_id != best_algo.id) { + for (const auto& algo : algorithms) { + if (algo.id == new_id) best_algo = algo; + } + } + } // best_algo.id == blas::kNoAlgorithm + + res->second->SetAlgorithm(best_algo); + break; + } // while + return res->second->ExecuteOnStream(&stream, a, b, *c, *c, + DeviceMemoryBase{}, // bias + DeviceMemoryBase{}, // aux + DeviceMemoryBase{}, // a_scale + DeviceMemoryBase{}, // b_scale + DeviceMemoryBase{}, // c_scale + DeviceMemoryBase{}, // d_scale + DeviceMemoryBase{}, // d_amax + absl::nullopt, // workspace + allocator); // allocator +} + +} // namespace gpu + +} // namespace stream_executor diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h new file mode 100644 index 00000000000000..dedf842875ec23 --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h @@ -0,0 +1,255 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_ +#define TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h" +#include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/compiler/xla/util.h" + +using tensorflow::gtl::ArraySlice; + +namespace xla { +namespace gpu { +class AutotuneConfig; +} +} // namespace xla + +namespace stream_executor { + +namespace gpu { + +struct StridedGemmConfig { + int64_t m, n, k, batch_count; + blas::Transpose trans_a, trans_b; + xla::complex128 alpha; + double beta; + blas::DataType type_a, type_b, type_c; + int64_t lda, ldb, ldc; + int64_t stride_a, stride_b, stride_c; +}; + +namespace { + +auto AsTuple(const GroupedGemmConfig& p) { + // NOTE: alpha, beta and data pointers are not included in cache !! + return std::make_tuple(p.m, p.n, p.k, p.batch_count, p.trans_a, p.trans_b, + p.type_a, p.type_b, p.type_c, p.type_d, p.lda, p.ldb, + p.ldc, p.ldd, p.compute_type); +} + +auto AsTuple(const StridedGemmConfig& p) { + return std::make_tuple(p.m, p.n, p.k, p.batch_count, p.trans_a, p.trans_b, + p.alpha.real(), p.alpha.imag(), p.beta, p.type_a, + p.type_b, p.type_c, p.lda, p.ldb, p.ldc, p.stride_a, + p.stride_b, p.stride_c); +} + +} // namespace + +bool operator==(const GroupedGemmConfig& rhs, const GroupedGemmConfig& lhs); +bool operator==(const StridedGemmConfig& rhs, const StridedGemmConfig& lhs); + +template +H AbslHashValue(H h, const GroupedGemmConfig& params) { + return H::combine(std::move(h), AsTuple(params)); +} + +template +H AbslHashValue(H h, const StridedGemmConfig& params) { + return H::combine(std::move(h), AsTuple(params)); +} + +struct BlasLtGemmRunner { + static BlasLtGemmRunner& i(const Stream* stream); + + template + xla::complex128 Convert(Scalar x) { + if constexpr (std::is_same::value || + std::is_same::value) { + return static_cast(x); + } else { + return static_cast(x); + } + } + + template + xla::Status Run(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64_t m, int64_t n, int64_t k, + Scalar alpha, const DeviceMemory& a, int64_t lda, + const DeviceMemory& b, int64_t ldb, Scalar beta, + DeviceMemory* c, int64_t ldc, + ScratchAllocator* allocator) { + auto type_a = dnn::ToDataType::value, + type_b = dnn::ToDataType::value, + type_c = dnn::ToDataType::value; + return RunStridedBatchedImpl( + stream, trans_a, trans_b, m, n, k, Convert(alpha), type_a, a, lda, 0, + type_b, b, ldb, 0, + Convert(beta).real(), // only real betas are supported!! + type_c, c, ldc, 0, 1, allocator); + } + + template + xla::Status Run(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64_t m, int64_t n, int64_t k, + Scalar alpha, const TypeA* a, int64_t lda, const TypeB* b, + int64_t ldb, Scalar beta, TypeC* c, int64_t ldc, + ScratchAllocator* allocator) { + auto type_a = dnn::ToDataType::value, + type_b = dnn::ToDataType::value, + type_c = dnn::ToDataType::value; + + DeviceMemoryBase mem_c{c}; + return RunStridedBatchedImpl( + stream, trans_a, trans_b, m, n, k, Convert(alpha), type_a, + DeviceMemoryBase{const_cast(a)}, lda, 0, type_b, + DeviceMemoryBase{const_cast(b)}, ldb, 0, + Convert(beta).real(), // only real betas are supported!! + type_c, &mem_c, ldc, 0, 1, allocator); + } + + template + xla::Status RunStridedBatched(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64_t m, int64_t n, + int64_t k, Scalar alpha, const TypeA* a, + int64_t lda, int64_t stride_a, const TypeB* b, + int64_t ldb, int64_t stride_b, Scalar beta, + TypeC* c, int64_t ldc, int64_t stride_c, + int64_t batch_count, + ScratchAllocator* allocator) { + auto type_a = dnn::ToDataType::value, + type_b = dnn::ToDataType::value, + type_c = dnn::ToDataType::value; + DeviceMemoryBase mem_c{c}; + return RunStridedBatchedImpl( + stream, trans_a, trans_b, m, n, k, Convert(alpha), type_a, + DeviceMemoryBase{const_cast(a)}, lda, stride_a, type_b, + DeviceMemoryBase{const_cast(a)}, ldb, stride_b, + Convert(beta).real(), // only real betas are supported!! + type_c, &mem_c, ldc, stride_c, batch_count, allocator); + } + + template + xla::Status RunStridedBatched(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64_t m, int64_t n, + int64_t k, Scalar alpha, + const DeviceMemory& a, int64_t lda, + int64_t stride_a, const DeviceMemory& b, + int64_t ldb, int64_t stride_b, Scalar beta, + DeviceMemory* c, int64_t ldc, + int64_t stride_c, int64_t batch_count, + ScratchAllocator* allocator) { + auto type_a = dnn::ToDataType::value, + type_b = dnn::ToDataType::value, + type_c = dnn::ToDataType::value; + return RunStridedBatchedImpl( + stream, trans_a, trans_b, m, n, k, Convert(alpha), type_a, a, lda, + stride_a, type_b, b, ldb, stride_b, + Convert(beta).real(), // only real betas are supported!! + type_c, c, ldc, stride_c, batch_count, allocator); + } + + template + xla::Status RunBatched( + Stream& stream, blas::Transpose trans_a, blas::Transpose trans_b, + int64_t m, int64_t n, int64_t k, Scalar alpha, + const ArraySlice*>& a, int64_t lda, + const ArraySlice*>& b, int64_t ldb, Scalar beta, + const ArraySlice*>& c, int64_t ldc, + int64_t batch_count, ScratchAllocator* allocator) { + // NOTE: Scalar types shall be verified for correctness vs T!! + auto type = dnn::ToDataType::value; + auto cvt = [](auto x) { + using TT = ArraySlice; + auto ptr = reinterpret_cast(&x); + return *reinterpret_cast(ptr); + }; + + auto res = ContiguousStrides(cvt(a), cvt(b), cvt(c), batch_count); + if (res.ok()) { + auto strides = std::move(res.ValueOrDie()); + return RunStridedBatchedImpl( + stream, trans_a, trans_b, m, n, k, Convert(alpha), type, *a[0], lda, + strides[0] / sizeof(T), type, *b[0], ldb, strides[1] / sizeof(T), + Convert(beta).real(), // only real betas are supported!! + type, c[0], ldc, strides[2] / sizeof(T), batch_count, allocator); + } + return xla::InternalError("RunBatched: ArraySlice NYI!"); + } + + template + xla::Status RunBatched(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, uint64 m, uint64 n, uint64 k, + Scalar alpha, const T** a, int lda, const T** b, + int ldb, Scalar beta, T** c, int64_t ldc, + int64_t batch_count, ScratchAllocator* allocator) { + auto type = dnn::ToDataType::value; + return RunBatchedImpl(stream, trans_a, trans_b, m, n, k, &alpha, type, + reinterpret_cast(a), lda, type, + reinterpret_cast(b), ldb, &beta, type, + reinterpret_cast(c), ldc, batch_count, + allocator); + } + + ~BlasLtGemmRunner(); + BlasLtGemmRunner& operator=(BlasLtGemmRunner&& rhs) noexcept = default; + BlasLtGemmRunner(BlasLtGemmRunner&& rhs) noexcept = default; + + private: + explicit BlasLtGemmRunner(StreamExecutor* parent); + + template + xla::StatusOr Autotune( + const std::vector& algorithms, + TuneFunc&& benchmark_func); + + xla::Status RunBatchedImpl(Stream& stream, blas::Transpose trans_a, + blas::Transpose trans_b, int64_t m, int64_t n, + int64_t k, const void* alpha, + blas::DataType type_a, const void** a, int64_t lda, + blas::DataType type_b, const void** b, int64_t ldb, + const void* beta, blas::DataType type_c, void** c, + int64_t ldc, int64_t batch_count, + ScratchAllocator* allocator); + + xla::Status RunStridedBatchedImpl( + Stream& stream, blas::Transpose trans_a, blas::Transpose trans_b, + int64_t m, int64_t n, int64_t k, xla::complex128 alpha, + blas::DataType type_a, const DeviceMemoryBase& a, int64_t lda, + int64_t stride_a, blas::DataType type_b, const DeviceMemoryBase& b, + int64_t ldb, int64_t stride_b, double beta, blas::DataType type_c, + DeviceMemoryBase* c, int64_t ldc, int64_t stride_c, int64_t batch_count, + ScratchAllocator* allocator); + + xla::StatusOr> ContiguousStrides( + const ArraySlice& a, + const ArraySlice& b, + const ArraySlice& c, int64_t batch_count); + + static bool autotune_enabled_; + std::unique_ptr mutex_; + std::unique_ptr autotune_config_; + absl::flat_hash_map + grouped_gemm_map_; + absl::flat_hash_map + strided_gemm_map_; +}; + +} // namespace gpu + +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_