From 3f7e238cdb001d9ac9a2c8ee4944db492060e69d Mon Sep 17 00:00:00 2001 From: scxfjiang Date: Fri, 25 Oct 2024 15:09:41 +0000 Subject: [PATCH] gpu_blas_lt_gemm_runner --- tensorflow/compiler/xla/stream_executor/BUILD | 1 + .../compiler/xla/stream_executor/gpu/BUILD | 14 +- .../gpu/gpu_blas_lt_gemm_runner.cc | 341 ++++++++++++++++++ .../gpu/gpu_blas_lt_gemm_runner.h | 260 +++++++++++++ 4 files changed, 609 insertions(+), 7 deletions(-) create mode 100644 tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc create mode 100644 tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h diff --git a/tensorflow/compiler/xla/stream_executor/BUILD b/tensorflow/compiler/xla/stream_executor/BUILD index 0425da4aea423a..34f7034934856a 100644 --- a/tensorflow/compiler/xla/stream_executor/BUILD +++ b/tensorflow/compiler/xla/stream_executor/BUILD @@ -450,6 +450,7 @@ tsl_gpu_library( ":temporary_memory_manager", ":timer", "//tensorflow/compiler/xla/stream_executor/platform", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_blas_lt_gemm_runner", "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:logging", diff --git a/tensorflow/compiler/xla/stream_executor/gpu/BUILD b/tensorflow/compiler/xla/stream_executor/gpu/BUILD index 1c327bb60ca64f..9479d3cae5eb80 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/BUILD +++ b/tensorflow/compiler/xla/stream_executor/gpu/BUILD @@ -87,13 +87,13 @@ cc_library( srcs = if_gpu_is_configured(["gpu_blas_lt_gemm_runner.cc"]), hdrs = if_gpu_is_configured(["gpu_blas_lt_gemm_runner.h"]), deps = if_gpu_is_configured([ - "//tensorflow/core:autotuning_proto_cc", - "//tensorflow/core:autotune_results_proto_cc", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/stream_executor:scratch_allocator", - "//tensorflow/compiler/xla/service/gpu:autotuner_util", - "//tensorflow/compiler/xla:debug_options_flags", - ":gpu_blas_lt", + "//tensorflow/core/protobuf:autotuning_proto_cc", + "//tensorflow/compiler/xla:autotune_results_proto_cc", + # "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/stream_executor:scratch_allocator", + "//tensorflow/compiler/xla/service/gpu:autotuner_util", + "//tensorflow/compiler/xla:debug_options_flags", + ":gpu_blas_lt", ]), ) diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc new file mode 100644 index 00000000000000..8693b0e42300f8 --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.cc @@ -0,0 +1,341 @@ +/* Copyright 2023 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <array>
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/util/env_var.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/service/gpu/autotuner_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/debug_options_flags.h"
+#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h"
+#include "tensorflow/compiler/xla/stream_executor/stream.h"
+#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h"
+
+namespace stream_executor {
+namespace gpu {
+
+bool BlasLtGemmRunner::autotune_enabled_ = true;
+
+bool operator ==(const GroupedGemmConfig& rhs, const GroupedGemmConfig& lhs) {
+  return AsTuple(rhs) == AsTuple(lhs);
+}
+
+bool operator ==(const StridedGemmConfig& rhs, const StridedGemmConfig& lhs) {
+  return AsTuple(rhs) == AsTuple(lhs);
+}
+
+std::ostream& operator <<(std::ostream& os, const StridedGemmConfig& cfg) {
+  return os << "trans_a/b: " << (int)cfg.trans_a << "/" << (int)cfg.trans_b <<
+      " m: " << cfg.m << " n: " << cfg.n << " k: " << cfg.k <<
+      " batch_count: " << cfg.batch_count <<
+      " lda: " << cfg.lda << " ldb: " << cfg.ldb << " ldc: " << cfg.ldc <<
+      " stride_a: " << cfg.stride_a << " stride_b: " << cfg.stride_b <<
+      " stride_c: " << cfg.stride_c <<
+      " type_a: " << (int)cfg.type_a << " type_b: " << (int)cfg.type_b <<
+      " type_c: " << (int)cfg.type_c <<
+      " alpha: " << cfg.alpha << " beta: " << cfg.beta;
+}
+
+BlasLtGemmRunner::BlasLtGemmRunner(StreamExecutor *parent) :
+    mutex_(std::make_unique< absl::Mutex >()),
+    autotune_config_(std::make_unique< xla::gpu::AutotuneConfig >(
+        xla::gpu::DeviceConfig{parent, nullptr},
+        xla::GetDebugOptionsFromFlags())) {}
+
+BlasLtGemmRunner::~BlasLtGemmRunner() {}
+
+/*static*/ BlasLtGemmRunner& BlasLtGemmRunner::i(const Stream *stream) {
+  static absl::Mutex m(absl::kConstInit);
+  // Each GPU gets its own cache instance.
+  static std::vector< std::unique_ptr< BlasLtGemmRunner > > meta(8);
+  absl::MutexLock lock(&m);
+  size_t dev_id = stream->parent()->device_ordinal();
+  if (dev_id >= meta.size()) meta.resize(dev_id + 1);
+  auto& res = meta[dev_id];
+  if (!res) {
+    autotune_enabled_ =
+        xla::GetDebugOptionsFromFlags().xla_gpu_autotune_level() > 0;
+    res.reset(new BlasLtGemmRunner(stream->parent()));
+    xla::gpu::AutotunerUtil::LoadAutotuneResultsFromFileOnce(
+        *res->autotune_config_);
+  }
+  return *res;
+}
+
+template < class TuneFunc >
+xla::StatusOr< gpu::BlasLt::MatmulAlgorithm > BlasLtGemmRunner::Autotune(
+    const std::vector< gpu::BlasLt::MatmulAlgorithm >& algorithms,
+    TuneFunc&& benchmark_func) {
+  gpu::BlasLt::MatmulAlgorithm best_algo;
+  float best_ms = std::numeric_limits< float >::max(), total_ms = 0;
+  uint32_t n_warmups = 1, n_iters = 5, n_total = n_warmups + n_iters, i = 0;
+
+  for (uint32_t j = 0; j < algorithms.size(); j++) {
+    const auto& algo = algorithms[j];
+    if (!benchmark_func(algo, nullptr).ok()) continue;
+
+    blas::ProfileResult profile;
+    for (i = 0, total_ms = 0; i < n_total; i++) {
+      auto res = benchmark_func(algo, &profile);
+      if (!res.ok() || !profile.is_valid()) {
+        VLOG(1) << j << ": gemm algorithm is not valid";
+        break;
+      }
+      if (i >= n_warmups) total_ms += profile.elapsed_time_in_ms();
+    }
+    if (i < n_total) continue;  // skip invalid algorithms
+    total_ms /= n_iters;
+    VLOG(2) << j << ": gemm algorithm " << profile.algorithm() << " took "
+            << total_ms << "ms, workspace: " << algo.workspace_size;
+    if (total_ms < best_ms) {
+      best_ms = total_ms, best_algo = algo;
+    }
+  }  // for algorithms
+  if (!best_algo.opaque_algo.has_value()) {
+    return xla::InternalError("No valid gemm algorithms found!");
+  }
+  return best_algo;
+}
+
+xla::StatusOr< std::array< uint64_t, 3 >> BlasLtGemmRunner::ContiguousStrides(
+    const ArraySlice< DeviceMemoryBase *>& a,
+    const ArraySlice< DeviceMemoryBase *>& b,
+    const ArraySlice< DeviceMemoryBase *>& c, int64 batch_count) {
+
+  uint64_t bsa = 0, bsb = 0, bsc = 0;
+  using CT = const uint8_t;
+  for (int64 i = 0; i < batch_count - 1; i++) {
+    uint64_t da = (CT *)a[i + 1]->opaque() - (CT *)a[i]->opaque(),
+             db = (CT *)b[i + 1]->opaque() - (CT *)b[i]->opaque(),
+             dc = (CT *)c[i + 1]->opaque() - (CT *)c[i]->opaque();
+    if (i == 0) {
+      bsa = da, bsb = db, bsc = dc;
+    } else if (!(bsa == da && bsb == db && bsc == dc)) {  // strides mismatch
+      return xla::InternalError("Strides are not consistent!");
+    }
+  }
+  return std::array< uint64_t, 3 >{ bsa, bsb, bsc };
+}
+
+xla::Status BlasLtGemmRunner::RunBatchedImpl(Stream& stream,
+    blas::Transpose trans_a, blas::Transpose trans_b, int64 m, int64 n, int64 k,
+    const void *alpha, blas::DataType type_a, const void** a, int64 lda,
+    blas::DataType type_b, const void** b, int64 ldb, const void *beta,
+    blas::DataType type_c, void** c, int64 ldc, int64 batch_count,
+    ScratchAllocator* allocator) {
+
+  TF_ASSIGN_OR_RETURN(auto compute_type,
+      gpu::GetBlasComputationType(type_a, type_c, 0));
+
+  GroupedGemmConfig cfg{
+    .m = (int64)m,
+    .n = (int64)n,
+    .k = (int64)k,
+    .batch_count = (int64)batch_count,
+    .trans_a = trans_a,
+    .trans_b = trans_b,
+    .alpha = alpha,
+    .beta = beta,
+    .type_a = type_a,
+    .type_b = type_b,
+    .type_c = type_c,
+    .type_d = type_c,
+    .lda = (int64)lda,
+    .ldb = (int64)ldb,
+    .ldc = (int64)ldc,
+    .ldd = (int64)ldc,
+    .compute_type = compute_type,
+    .a = a,
+    .b = b,
+    .c = const_cast< const void **>(c),
+    .d = c,
+  };
+
+  absl::MutexLock lock(mutex_.get());
+
+  auto res = grouped_gemm_map_.find(cfg);
+  if (res == grouped_gemm_map_.end()) {
+    // NOTE: we assume that the pointers a, b, c come from device memory,
+    // hence we need to block the stream here.
+    TF_ASSIGN_OR_RETURN(auto plan_res,
+        gpu::BlasLt::CreateGroupedMatmulPlan(&stream, cfg));
+    res = grouped_gemm_map_.emplace(cfg, std::move(plan_res)).first;
+
+    size_t num_solutions = autotune_enabled_ ? gpu::BlasLt::kMaxAlgorithms : 1;
+    // Discard solutions that need a workspace if no allocator is given.
+    TF_ASSIGN_OR_RETURN(auto algorithms, res->second->GetAlgorithms(
+        num_solutions, allocator == nullptr ? 0 : 1ull << 32));
+
+    VLOG(1) << stream.parent() << ": new GGemm config: "
+        << grouped_gemm_map_.size() << " #valid algorithms: " << algorithms.size();
+
+    BlasLt::MatmulAlgorithm best_algo;
+    if (!autotune_enabled_) {
+      if (algorithms.empty()) return xla::InternalError("No GG algorithms found!");
+      best_algo = algorithms[0];  // without autotuning, use the default algorithm
+    } else {
+      TF_ASSIGN_OR_RETURN(best_algo, Autotune(algorithms,
+        [&](const gpu::BlasLt::MatmulAlgorithm& algo, blas::ProfileResult *profile) {
+          if (profile == nullptr) {
+            return res->second->SetAlgorithm(algo, allocator);
+          }
+          return res->second->ExecuteOnStream(&stream, cfg, profile);
+        }));
+    }
+    TF_RETURN_IF_ERROR(res->second->SetAlgorithm(best_algo, allocator));
+  }
+  return res->second->ExecuteOnStream(&stream, cfg);
+}
+
+xla::Status BlasLtGemmRunner::RunStridedBatchedImpl(Stream& stream,
+    blas::Transpose trans_a, blas::Transpose trans_b, int64 m, int64 n, int64 k,
+    xla::complex128 alpha,
+    blas::DataType type_a, const DeviceMemoryBase& a, int64 lda, int64 stride_a,
+    blas::DataType type_b, const DeviceMemoryBase& b, int64 ldb, int64 stride_b,
+    double beta,
+    blas::DataType type_c, DeviceMemoryBase *c, int64 ldc, int64 stride_c,
+    int64 batch_count, ScratchAllocator* allocator) {
+
+  StridedGemmConfig scfg{
+    .m = m,
+    .n = n,
+    .k = k,
+    .batch_count = (int64)batch_count,
+    .trans_a = trans_a,
+    .trans_b = trans_b,
+    .alpha = alpha,
+    .beta = beta,
+    .type_a = type_a,
+    .type_b = type_b,
+    .type_c = type_c,
+    .lda = lda,
+    .ldb = ldb,
+    .ldc = ldc,
+    .stride_a = stride_a,
+    .stride_b = stride_b,
+    .stride_c = stride_c,
+  };
+
+  absl::MutexLock lock(mutex_.get());
+
+  auto res = strided_gemm_map_.find(scfg);
+  while (res == strided_gemm_map_.end()) {
+    int64 row_a = m, col_a = k, row_b = k, col_b = n;
+    if (trans_a == blas::Transpose::kTranspose) std::swap(row_a, col_a);
+    if (trans_b == blas::Transpose::kTranspose) std::swap(row_b, col_b);
+
+    auto order = MatrixLayout::Order::kColumnMajor;
+    GemmConfig cfg = {
+      .lhs_layout = MatrixLayout(type_a, row_a, col_a, order, batch_count,
+                                 lda, stride_a, trans_a),
+      .rhs_layout = MatrixLayout(type_b, row_b, col_b, order, batch_count,
+                                 ldb, stride_b, trans_b),
+      .c_layout = MatrixLayout(type_c, m, n, order, batch_count,
+                               ldc, stride_c),
+      .output_layout = MatrixLayout(type_c, m, n, order, batch_count,
+                                    ldc, stride_c),
+      .alpha = alpha,
+      .beta = beta,
+      .compute_precision = -1,
+      .epilogue = gpu::BlasLt::Epilogue::kDefault,
+    };
+
+    TF_ASSIGN_OR_RETURN(auto plan_res,
+        gpu::BlasLt::GetMatmulPlan(&stream, cfg));
+    res = strided_gemm_map_.emplace(scfg, std::move(plan_res)).first;
+
+    size_t num_solutions = autotune_enabled_ ? gpu::BlasLt::kMaxAlgorithms : 1;
+    // Discard solutions that need a workspace if no allocator is given.
+    TF_ASSIGN_OR_RETURN(auto algorithms, res->second->GetAlgorithms(
+        num_solutions, allocator == nullptr ? 0 : 1ull << 32));
+
+    VLOG(1) << &stream << " dev " << stream.parent() << '('
+        << stream.parent()->device_ordinal() << "): new StridedBatched config: "
+        << strided_gemm_map_.size() << " #algorithms: " << algorithms.size();
+
+    if (!autotune_enabled_) {
+      if (algorithms.empty()) return xla::InternalError("No algorithms found!");
+      TF_RETURN_IF_ERROR(res->second->SetAlgorithm(algorithms[0]));
+      break;
+    }
+
+    BlasLt::MatmulAlgorithm best_algo{ .id = blas::kNoAlgorithm };
+    xla::gpu::AutotuneCacheKey key(ToCSVString(cfg, /*full_string*/false));
+    auto opt_res = xla::gpu::AutotunerUtil::TryToFindInInMemoryCache(key);
+    if (opt_res.has_value()) {
+      auto id = *opt_res;
+      for (const auto& algo : algorithms) {
+        if (algo.id == id) best_algo = algo;
+      }
+      if (best_algo.id == blas::kNoAlgorithm) {
+        LOG(WARNING) << "Cached algorithm ID is no longer valid: re-autotuning..";
+      }
+    }
+
+    if (best_algo.id == blas::kNoAlgorithm) {
+      TF_ASSIGN_OR_RETURN(best_algo, Autotune(algorithms,
+        [&](const gpu::BlasLt::MatmulAlgorithm& algo, blas::ProfileResult *profile) {
+          if (profile == nullptr) {
+            return res->second->SetAlgorithm(algo);
+          }
+          return res->second->ExecuteOnStream(
+              &stream, a, b, *c, *c,
+              DeviceMemoryBase{},  // bias
+              DeviceMemoryBase{},  // aux
+              DeviceMemoryBase{},  // a_scale
+              DeviceMemoryBase{},  // b_scale
+              DeviceMemoryBase{},  // c_scale
+              DeviceMemoryBase{},  // d_scale
+              DeviceMemoryBase{},  // d_amax
+              absl::nullopt,       // workspace
+              allocator,           // allocator
+              profile);
+        }));
+      xla::gpu::AutotunerUtil::CacheValue ares = best_algo.id;
+      // Re-read the algorithm ID from the cache: another thread may have
+      // already added a result for this config, and we must use the same ID.
+      auto new_id = xla::gpu::AutotunerUtil::AddResultToInMemoryCache(key, ares,
+          *autotune_config_);
+
+      if (new_id != best_algo.id) {
+        for (const auto& algo : algorithms) {
+          if (algo.id == new_id) best_algo = algo;
+        }
+      }
+    }  // best_algo.id == blas::kNoAlgorithm
+
+    TF_RETURN_IF_ERROR(res->second->SetAlgorithm(best_algo));
+    break;
+  }  // while
+  return res->second->ExecuteOnStream(
+      &stream, a, b, *c, *c,
+      DeviceMemoryBase{},  // bias
+      DeviceMemoryBase{},  // aux
+      DeviceMemoryBase{},  // a_scale
+      DeviceMemoryBase{},  // b_scale
+      DeviceMemoryBase{},  // c_scale
+      DeviceMemoryBase{},  // d_scale
+      DeviceMemoryBase{},  // d_amax
+      absl::nullopt,       // workspace
+      allocator);          // allocator
+}
+
+}  // namespace gpu
+
+}  // namespace stream_executor
diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h
new file mode 100644
index 00000000000000..2cda507bb8f9d1
--- /dev/null
+++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h
@@ -0,0 +1,260 @@
+/* Copyright 2023 The OpenXLA Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h"
+#include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/compiler/xla/util.h"
+
+using tensorflow::gtl::ArraySlice;
+typedef ::std::int64_t int64;
+
+namespace xla {
+namespace gpu {
+class AutotuneConfig;
+}  // namespace gpu
+}  // namespace xla
+
+namespace stream_executor {
+
+namespace gpu {
+
+struct StridedGemmConfig {
+  int64 m, n, k, batch_count;
+  blas::Transpose trans_a, trans_b;
+  xla::complex128 alpha;
+  double beta;
+  blas::DataType type_a, type_b, type_c;
+  int64 lda, ldb, ldc;
+  int64 stride_a, stride_b, stride_c;
+};
+
+namespace {
+
+auto AsTuple(const GroupedGemmConfig& p) {
+  // NOTE: alpha, beta and the data pointers are not part of the cache key!
+  return std::make_tuple(p.m, p.n, p.k, p.batch_count,
+      p.trans_a, p.trans_b,
+      p.type_a, p.type_b, p.type_c, p.type_d,
+      p.lda, p.ldb, p.ldc, p.ldd,
+      p.compute_type);
+}
+
+auto AsTuple(const StridedGemmConfig& p) {
+  return std::make_tuple(p.m, p.n, p.k, p.batch_count,
+      p.trans_a, p.trans_b, p.alpha.real(), p.alpha.imag(), p.beta,
+      p.type_a, p.type_b, p.type_c,
+      p.lda, p.ldb, p.ldc,
+      p.stride_a, p.stride_b, p.stride_c);
+}
+
+}  // namespace
+
+bool operator ==(const GroupedGemmConfig& rhs, const GroupedGemmConfig& lhs);
+bool operator ==(const StridedGemmConfig& rhs, const StridedGemmConfig& lhs);
+
+template < typename H >
+H AbslHashValue(H h, const GroupedGemmConfig& params) {
+  return H::combine(std::move(h), AsTuple(params));
+}
+
+template < typename H >
+H AbslHashValue(H h, const StridedGemmConfig& params) {
+  return H::combine(std::move(h), AsTuple(params));
+}
+
+struct BlasLtGemmRunner {
+
+  static BlasLtGemmRunner& i(const Stream *stream);
+
+  template < class Scalar >
+  xla::complex128 Convert(Scalar x) {
+    if constexpr (std::is_same< Scalar, xla::complex64 >::value ||
+                  std::is_same< Scalar, xla::complex128 >::value) {
+      return static_cast< xla::complex128 >(x);
+    } else {
+      return static_cast< double >(x);
+    }
+  }
+
+  template < class Scalar, class TypeA, class TypeB, class TypeC >
+  xla::Status Run(Stream& stream, blas::Transpose trans_a,
+      blas::Transpose trans_b, int64 m, int64 n, int64 k,
+      Scalar alpha, const DeviceMemory< TypeA >& a, int64 lda,
+      const DeviceMemory< TypeB >& b, int64 ldb,
+      Scalar beta, DeviceMemory< TypeC > *c, int64 ldc,
+      ScratchAllocator* allocator) {
+
+    auto type_a = dnn::ToDataType< TypeA >::value,
+         type_b = dnn::ToDataType< TypeB >::value,
+         type_c = dnn::ToDataType< TypeC >::value;
+    return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k,
+        Convert(alpha),
+        type_a, a, lda, 0, type_b, b, ldb, 0,
+        Convert(beta).real(),  // only real betas are supported!!
+        type_c, c, ldc, 0, 1, allocator);
+  }
+
+  template < class Scalar, class TypeA, class TypeB, class TypeC >
+  xla::Status Run(Stream& stream, blas::Transpose trans_a,
+      blas::Transpose trans_b, int64 m, int64 n, int64 k,
+      Scalar alpha, const TypeA* a, int64 lda,
+      const TypeB *b, int64 ldb,
+      Scalar beta, TypeC *c, int64 ldc,
+      ScratchAllocator* allocator) {
+
+    auto type_a = dnn::ToDataType< TypeA >::value,
+         type_b = dnn::ToDataType< TypeB >::value,
+         type_c = dnn::ToDataType< TypeC >::value;
+
+    DeviceMemoryBase mem_c{c};
+    return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k,
+        Convert(alpha),
+        type_a, DeviceMemoryBase{const_cast< TypeA *>(a)}, lda, 0,
+        type_b, DeviceMemoryBase{const_cast< TypeB *>(b)}, ldb, 0,
+        Convert(beta).real(),  // only real betas are supported!!
+        type_c, &mem_c, ldc, 0, 1, allocator);
+  }
+
+  template < class Scalar, class TypeA, class TypeB, class TypeC >
+  xla::Status RunStridedBatched(Stream& stream, blas::Transpose trans_a,
+      blas::Transpose trans_b, int64 m, int64 n, int64 k,
+      Scalar alpha, const TypeA* a, int64 lda, int64 stride_a,
+      const TypeB* b, int64 ldb, int64 stride_b,
+      Scalar beta, TypeC* c, int64 ldc, int64 stride_c,
+      int64 batch_count, ScratchAllocator* allocator) {
+
+    auto type_a = dnn::ToDataType< TypeA >::value,
+         type_b = dnn::ToDataType< TypeB >::value,
+         type_c = dnn::ToDataType< TypeC >::value;
+    DeviceMemoryBase mem_c{c};
+    return RunStridedBatchedImpl(
+        stream, trans_a, trans_b, m, n, k, Convert(alpha), type_a,
+        DeviceMemoryBase{const_cast< TypeA *>(a)}, lda, stride_a, type_b,
+        DeviceMemoryBase{const_cast< TypeB *>(b)}, ldb, stride_b,
+        Convert(beta).real(),  // only real betas are supported!!
+        type_c, &mem_c, ldc, stride_c, batch_count, allocator);
+  }
+
+  template < class Scalar, class TypeA, class TypeB, class TypeC >
+  xla::Status RunStridedBatched(Stream& stream, blas::Transpose trans_a,
+      blas::Transpose trans_b, int64 m, int64 n, int64 k,
+      Scalar alpha, const DeviceMemory< TypeA >& a, int64 lda, int64 stride_a,
+      const DeviceMemory< TypeB >& b, int64 ldb, int64 stride_b,
+      Scalar beta, DeviceMemory< TypeC > *c, int64 ldc, int64 stride_c,
+      int64 batch_count, ScratchAllocator* allocator) {
+
+    auto type_a = dnn::ToDataType< TypeA >::value,
+         type_b = dnn::ToDataType< TypeB >::value,
+         type_c = dnn::ToDataType< TypeC >::value;
+    return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k,
+        Convert(alpha),
+        type_a, a, lda, stride_a, type_b, b, ldb, stride_b,
+        Convert(beta).real(),  // only real betas are supported!!
+        type_c, c, ldc, stride_c, batch_count, allocator);
+  }
+
+  template < class Scalar, class T >
+  xla::Status RunBatched(Stream& stream, blas::Transpose trans_a,
+      blas::Transpose trans_b, int64 m, int64 n, int64 k, Scalar alpha,
+      const ArraySlice< DeviceMemory< T > *> &a, int64 lda,
+      const ArraySlice< DeviceMemory< T > *> &b, int64 ldb, Scalar beta,
+      const ArraySlice< DeviceMemory< T > *> &c, int64 ldc,
+      int64 batch_count, ScratchAllocator* allocator) {
+
+    // NOTE: the Scalar type should still be validated against T!
+    auto type = dnn::ToDataType< T >::value;
+    auto cvt = [](auto x) {
+      using TT = ArraySlice< DeviceMemoryBase *>;
+      auto ptr = reinterpret_cast< const void *>(&x);
+      return *reinterpret_cast< const TT *>(ptr);
+    };
+
+    auto res = ContiguousStrides(cvt(a), cvt(b), cvt(c), batch_count);
+    if (res.ok()) {
+      auto strides = std::move(res.ValueOrDie());
+      return RunStridedBatchedImpl(stream, trans_a, trans_b, m, n, k,
+          Convert(alpha),
+          type, *a[0], lda, strides[0] / sizeof(T),
+          type, *b[0], ldb, strides[1] / sizeof(T),
+          Convert(beta).real(),  // only real betas are supported!!
+          type, c[0], ldc, strides[2] / sizeof(T), batch_count, allocator);
+    }
+    return xla::InternalError("RunBatched: non-uniform batch strides are not yet supported!");
+  }
+
+  template < class Scalar, class T >
+  xla::Status RunBatched(Stream& stream, blas::Transpose trans_a,
+      blas::Transpose trans_b, uint64 m, uint64 n, uint64 k,
+      Scalar alpha, const T** a, int lda,
+      const T** b, int ldb, Scalar beta,
+      T** c, int64 ldc, int64 batch_count, ScratchAllocator* allocator) {
+
+    auto type = dnn::ToDataType< T >::value;
+    return RunBatchedImpl(stream, trans_a, trans_b, m, n, k,
+        &alpha, type, reinterpret_cast< const void **>(a), lda,
+        type, reinterpret_cast< const void **>(b), ldb, &beta,
+        type, reinterpret_cast< void **>(c), ldc, batch_count, allocator);
+  }
+
+  ~BlasLtGemmRunner();
+  BlasLtGemmRunner& operator=(BlasLtGemmRunner&& rhs) noexcept = default;
+  BlasLtGemmRunner(BlasLtGemmRunner&& rhs) noexcept = default;
+
+ private:
+  explicit BlasLtGemmRunner(StreamExecutor *parent);
+
+  template < class TuneFunc >
+  xla::StatusOr< gpu::BlasLt::MatmulAlgorithm > Autotune(
+      const std::vector< gpu::BlasLt::MatmulAlgorithm >& algorithms,
+      TuneFunc&& benchmark_func);
+
+  xla::Status RunBatchedImpl(Stream& stream, blas::Transpose trans_a,
+      blas::Transpose trans_b, int64 m, int64 n, int64 k,
+      const void *alpha, blas::DataType type_a, const void** a, int64 lda,
+      blas::DataType type_b, const void** b, int64 ldb, const void *beta,
+      blas::DataType type_c, void** c, int64 ldc, int64 batch_count,
+      ScratchAllocator* allocator);
+
+  xla::Status RunStridedBatchedImpl(Stream& stream, blas::Transpose trans_a,
+      blas::Transpose trans_b, int64 m, int64 n, int64 k, xla::complex128 alpha,
+      blas::DataType type_a, const DeviceMemoryBase& a, int64 lda, int64 stride_a,
+      blas::DataType type_b, const DeviceMemoryBase& b, int64 ldb, int64 stride_b,
+      double beta,
+      blas::DataType type_c, DeviceMemoryBase *c, int64 ldc, int64 stride_c,
+      int64 batch_count, ScratchAllocator* allocator);
+
+  xla::StatusOr< std::array< uint64_t, 3 >> ContiguousStrides(
+      const ArraySlice< DeviceMemoryBase *>& a,
+      const ArraySlice< DeviceMemoryBase *>& b,
+      const ArraySlice< DeviceMemoryBase *>& c, int64 batch_count);
+
+  static bool autotune_enabled_;
+  std::unique_ptr< absl::Mutex > mutex_;
+  std::unique_ptr< xla::gpu::AutotuneConfig > autotune_config_;
+  absl::flat_hash_map< GroupedGemmConfig, gpu::BlasLt::GroupedMatmulPlanPtr >
+      grouped_gemm_map_;
+  absl::flat_hash_map< StridedGemmConfig, gpu::BlasLt::MatmulPlanPtr >
+      strided_gemm_map_;
+};
+
+}  // namespace gpu
+
+}  // namespace stream_executor
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_BLAS_LT_GEMM_RUNNER_H_
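
Usage note (reviewer sketch, not part of the patch): BlasLtGemmRunner::i() hands out one
lazily constructed runner per device ordinal, and each Run*/RunBatched* overload builds a
cache key from the problem shape, creates and optionally autotunes a gpu_blas_lt matmul plan
on a cache miss, and then executes the cached plan. A minimal, hypothetical call site could
look like the sketch below; the function name RunGemmViaBlasLt, the float element type and
the column-major, non-transposed leading dimensions are illustrative assumptions, not code
from this change.

    #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h"

    namespace se = stream_executor;

    // C(m x n) = A(m x k) * B(k x n), column-major, on the given stream.
    xla::Status RunGemmViaBlasLt(se::Stream* stream,
                                 const se::DeviceMemory<float>& a,
                                 const se::DeviceMemory<float>& b,
                                 se::DeviceMemory<float>* c,
                                 int64 m, int64 n, int64 k,
                                 se::ScratchAllocator* allocator) {
      // One runner per device ordinal; the first call also loads any autotune
      // results previously dumped to file.
      auto& runner = se::gpu::BlasLtGemmRunner::i(stream);
      // On a cache miss this builds a blas-lt plan for the problem shape and,
      // if autotuning is enabled, benchmarks the returned algorithms before
      // caching the fastest one; later calls reuse the stored plan directly.
      return runner.Run(*stream, se::blas::Transpose::kNoTranspose,
                        se::blas::Transpose::kNoTranspose, m, n, k,
                        /*alpha=*/1.0f, a, /*lda=*/m, b, /*ldb=*/k,
                        /*beta=*/0.0f, c, /*ldc=*/m, allocator);
    }

Passing allocator == nullptr is also valid: in that case the runner only considers
algorithms that require no scratch workspace.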