From e2b918d57b7c0b353bdda56db4809d5a6fd407db Mon Sep 17 00:00:00 2001 From: Pavel Emeliyanenko Date: Mon, 28 Oct 2024 07:39:42 -0500 Subject: [PATCH] buffer init fix and gpu_hlo_runner test --- xla/service/gpu/BUILD | 16 ++- xla/service/gpu/kernels/BUILD | 6 +- xla/service/gpu/stream_executor_util.cc | 2 - xla/service/gpu/tests/BUILD | 33 +++++ xla/service/gpu/tests/gpu_hlo_runner_test.cc | 130 +++++++++++++++++++ 5 files changed, 179 insertions(+), 8 deletions(-) create mode 100644 xla/service/gpu/tests/gpu_hlo_runner_test.cc diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 8997fc9e44cc9..cec998a457f18 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -4408,10 +4408,15 @@ xla_cc_test( ], ) -cuda_library( +gpu_kernel_library( name = "stream_executor_util_kernel", - srcs = if_cuda_is_configured(["stream_executor_util_kernel.cu.cc"]), - deps = ["@local_config_cuda//cuda:cuda_headers"], + srcs = ["stream_executor_util_kernel.cu.cc"], + tags = ["gpu"], + deps = if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), ) cc_library( @@ -4423,7 +4428,6 @@ cc_library( deps = [ ":cublas_cudnn", ":launch_dimensions", - ":stream_executor_util_kernel", "//xla:autotuning_proto_cc", "//xla:shape_util", "//xla:util", @@ -4453,7 +4457,9 @@ cc_library( "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", - ], + ] + if_gpu_is_configured([ + ":stream_executor_util_kernel", + ]), ) xla_cc_test( diff --git a/xla/service/gpu/kernels/BUILD b/xla/service/gpu/kernels/BUILD index 4652f8cafd8cb..43fc69b936945 100644 --- a/xla/service/gpu/kernels/BUILD +++ b/xla/service/gpu/kernels/BUILD @@ -167,7 +167,11 @@ gpu_kernel_library( "//xla:types", "//xla/stream_executor/gpu:gpu_types_header", "@tsl//tsl/lib/math:math_util", - ], + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), ) xla_test( diff --git a/xla/service/gpu/stream_executor_util.cc b/xla/service/gpu/stream_executor_util.cc index c278fbe8b4ff3..df7444656b7db 100644 --- a/xla/service/gpu/stream_executor_util.cc +++ b/xla/service/gpu/stream_executor_util.cc @@ -493,7 +493,6 @@ static void InitializeTypedBuffer(se::Stream* stream, // Nothing more to do return; } -#ifdef GOOGLE_CUDA // Repeat the host_buffer_size elements at the start of `buf` to the end CHECK_EQ(elements_to_fill, buffer.size() / sizeof(T) - host_buffer_size); se::StreamExecutor* executor = stream->parent(); @@ -514,7 +513,6 @@ static void InitializeTypedBuffer(se::Stream* stream, se::BlockDim(blocks_per_grid, 1, 1), *kernel, buffer, host_buffer_bytes, static_cast(buffer.size()))); -#endif } void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, diff --git a/xla/service/gpu/tests/BUILD b/xla/service/gpu/tests/BUILD index 8f8317f88114b..0983f508ab5a4 100644 --- a/xla/service/gpu/tests/BUILD +++ b/xla/service/gpu/tests/BUILD @@ -215,6 +215,39 @@ xla_test( ], ) +xla_test( + name = "gpu_hlo_runner_test", + srcs = ["gpu_hlo_runner_test.cc"], + backends = ["gpu"], + deps = [ + ":gpu_codegen_test", + "//xla:error_spec", + "//xla:test", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service:executable", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:gemm_rewriter", + "//xla/service/gpu:gpu_executable", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tests:filecheck", + "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test_main", + ] +) + xla_test( name = "gemm_rewrite_test", srcs = ["gemm_rewrite_test.cc"], diff --git a/xla/service/gpu/tests/gpu_hlo_runner_test.cc b/xla/service/gpu/tests/gpu_hlo_runner_test.cc new file mode 100644 index 0000000000000..af28627561f95 --- /dev/null +++ b/xla/service/gpu/tests/gpu_hlo_runner_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include "xla/error_spec.h" +#include "xla/literal_comparison.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_utils.h" + +namespace xla { +namespace gpu { + +template +std::vector MakePointerVector(std::vector& input_vec) { + std::vector output_pointers; + output_pointers.reserve(input_vec.size()); + for (auto& input : input_vec) { + output_pointers.push_back(&input); + } + return output_pointers; +} + + +class HloRunnerTest : public GpuCodegenTest {}; + +TEST_F(HloRunnerTest, RunSingle) { + + std::ifstream ifs("input.hlo"); + ASSERT_TRUE(ifs.good()); + + std::stringstream buffer; + buffer << ifs.rdbuf(); + + HloModuleConfig config = GetModuleConfigForTest(); +#if 1 + //config.set_num_partitions(8); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(buffer.str(), + config)); + + auto ref_module = module->Clone(); + TF_ASSERT_OK_AND_ASSIGN(auto exec, test_runner_.CreateExecutable(std::move(module), true)); + + VLOG(0) << "Creating fake args.."; + TF_ASSERT_OK_AND_ASSIGN(auto fake_arguments, xla::MakeFakeArguments(ref_module.get(), + true, /*pseudo-random*/ + false /* use large range*/)); + auto arg_ptrs = MakePointerVector(fake_arguments); + + auto& ref_runner = HloTestBase::reference_runner_; + TF_ASSERT_OK_AND_ASSIGN( + auto ref_exec, ref_runner.CreateExecutable(std::move(ref_module), true)); + + // TF_ASSERT_OK_AND_ASSIGN(auto truth, + // ReadLiteralFromProto("/tf/xla/expected.pb")); + // TF_ASSERT_OK_AND_ASSIGN(auto truth, + // ref_runner.ExecuteWithExecutable(ref_exec.get(), arg_ptrs, nullptr)); + // WriteLiteralToTempFile(truth, "expected"); + //VLOG(0) << "Got expected literal from file.. running test"; + + TF_ASSERT_OK_AND_ASSIGN( + auto test_res, test_runner_.ExecuteWithExecutable(exec.get(), arg_ptrs)); + + VLOG(0) << "Running reference exec.."; + TF_ASSERT_OK_AND_ASSIGN( + auto truth, ref_runner.ExecuteWithExecutable(ref_exec.get(), arg_ptrs)); + + ErrorSpec error_spec{1e-2, 1e-3}; + //ErrorSpec error_spec(1e-5 /*abs*/, 1e-5 /*rel*/); + ASSERT_EQ(literal_comparison::Near(/*expected=*/truth, + /*actual=*/test_res, + /*error=*/error_spec, + /*detailed_message=*/true, {}), absl::OkStatus()); + + // EXPECT_TRUE(RunAndCompare(std::move(module), + // // absl::Span< xla::Literal * const>(arg_ptrs.data(), arg_ptrs.size()), error_spec)); +#else + int NumReplicas = 8, NumParts = 1; + config.set_replica_count(NumReplicas); + config.set_num_partitions(NumParts); + + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(buffer.str(), config)); + DeviceAssignment assn(/*replica_count=*/NumReplicas, + /*computation_count=*/NumParts); + for (int64_t i = 0, k = 0; i < NumReplicas; i++) + for (int64_t j = 0; j < NumParts; j++) { + assn(i, j) = k++; + } + + auto fake_arguments = xla::MakeFakeArguments( + module.get(), + true, /*pseudo-random*/ + false /* use large range*/).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(auto exec, + test_runner_.CreateExecutable(std::move(module), true)); + + for(int i = 0; i < 10; i++) { + VLOG(0) << "Running iteration #" << i; + TF_ASSERT_OK_AND_ASSIGN(std::vector results, + HloTestBase::ExecuteReplicated( + [&](int64_t){ return exec.get(); }, + [&fake_arguments](int64_t replica_id) + { return fake_arguments.size(); }, + [&fake_arguments](int64_t replica_id, int64_t idx) + { return &fake_arguments[idx]; }, + NumReplicas, false /*run hlo*/, &assn)); + ASSERT_EQ(results.size(), NumReplicas); + } +#endif +} + +} // namespace gpu +} // namespace xla + \ No newline at end of file