Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify Cuda and RocmTimer #18165

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -511,9 +511,9 @@ cc_library(
cuda_library(
name = "delay_kernel_cuda",
srcs = [
"delay_kernel.h",
"delay_kernel_cuda.cu.cc",
],
hdrs = ["delay_kernel.h"],
# copybara:uncomment compatible_with = ["//buildenv/target:non_prod"],
tags = [
"cuda-only",
Expand Down Expand Up @@ -1014,7 +1014,6 @@ cc_library(
name = "cuda_executor",
srcs = [
"cuda_executor.cc",
"delay_kernel.h",
],
hdrs = [
"cuda_executor.h",
Expand All @@ -1034,7 +1033,6 @@ cc_library(
":cuda_stream",
":cuda_timer",
":cuda_version_parser",
":delay_kernel_cuda",
"//xla/stream_executor",
"//xla/stream_executor:blas",
"//xla/stream_executor:command_buffer",
Expand All @@ -1054,7 +1052,6 @@ cc_library(
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:gpu_kernel_header",
"//xla/stream_executor/gpu:gpu_semaphore",
"//xla/stream_executor/gpu:gpu_stream_header",
"//xla/stream_executor/gpu:gpu_types_header",
"//xla/stream_executor/gpu:read_numa_node",
Expand Down Expand Up @@ -1239,6 +1236,7 @@ cc_library(
deps = [
":cuda_event",
":cuda_status",
":delay_kernel_cuda",
"//xla/stream_executor:event_based_timer",
"//xla/stream_executor:stream",
"//xla/stream_executor/gpu:context",
Expand All @@ -1254,3 +1252,31 @@ cc_library(
"@tsl//tsl/platform:statusor",
],
)

xla_test(
name = "cuda_timer_test",
srcs = ["cuda_timer_test.cc"],
backends = ["gpu"],
tags = ["cuda-only"],
deps = [
":cuda_executor",
":cuda_platform_id",
":cuda_timer",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:kernel",
"//xla/stream_executor:kernel_spec",
"//xla/stream_executor:launch_dim",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:stream",
"//xla/stream_executor:typed_kernel_factory",
"//xla/stream_executor/gpu:gpu_stream",
"//xla/stream_executor/gpu:gpu_test_kernels_cuda",
"@com_google_absl//absl/status",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)
31 changes: 10 additions & 21 deletions xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ limitations under the License.
#include "xla/stream_executor/cuda/cuda_stream.h"
#include "xla/stream_executor/cuda/cuda_timer.h"
#include "xla/stream_executor/cuda/cuda_version_parser.h"
#include "xla/stream_executor/cuda/delay_kernel.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/dnn.h"
Expand All @@ -67,7 +66,6 @@ limitations under the License.
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_kernel.h"
#include "xla/stream_executor/gpu/gpu_semaphore.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/gpu/gpu_types.h"
#include "xla/stream_executor/gpu/read_numa_node.h"
Expand Down Expand Up @@ -426,18 +424,15 @@ absl::StatusOr<std::unique_ptr<Kernel>> CudaExecutor::LoadKernel(

absl::StatusOr<std::unique_ptr<EventBasedTimer>>
CudaExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) {
GpuSemaphore semaphore{};
const CudaTimer::TimerType timer_type =
(use_delay_kernel && ShouldLaunchDelayKernel() &&
delay_kernels_supported_)
? CudaTimer::TimerType::kDelayKernel
: CudaTimer::TimerType::kEventBased;

if (use_delay_kernel && ShouldLaunchDelayKernel() &&
delay_kernels_supported_) {
TF_ASSIGN_OR_RETURN(semaphore, LaunchDelayKernel(stream));
}
TF_ASSIGN_OR_RETURN(auto start_event, CreateGpuEvent(/*allow_timing=*/true));
TF_ASSIGN_OR_RETURN(auto stop_event, CreateGpuEvent(/*allow_timing=*/true));
TF_RETURN_IF_ERROR(stream->RecordEvent(start_event.get()));
return std::make_unique<CudaTimer>(gpu_context(), std::move(start_event),
std::move(stop_event), stream,
std::move(semaphore));
TF_ASSIGN_OR_RETURN(CudaTimer timer,
CudaTimer::Create(gpu_context(), stream, timer_type));
return std::make_unique<CudaTimer>(std::move(timer));
}

bool CudaExecutor::UnloadGpuBinary(const void* gpu_binary) {
Expand Down Expand Up @@ -816,15 +811,9 @@ absl::Status FillBlockDimLimit(GpuDeviceHandle device,
return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<CudaEvent>> CudaExecutor::CreateGpuEvent(
bool allow_timing) {
TF_ASSIGN_OR_RETURN(auto event,
CudaEvent::Create(gpu_context(), allow_timing));
return std::make_unique<CudaEvent>(std::move(event));
}

absl::StatusOr<std::unique_ptr<Event>> CudaExecutor::CreateEvent() {
return CreateGpuEvent(/*allow_timing=*/false);
TF_ASSIGN_OR_RETURN(auto event, CudaEvent::Create(gpu_context(), false));
return std::make_unique<CudaEvent>(std::move(event));
}

absl::StatusOr<std::unique_ptr<Stream>> CudaExecutor::CreateStream(
Expand Down
4 changes: 0 additions & 4 deletions xla/stream_executor/cuda/cuda_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ limitations under the License.
#include "xla/stream_executor/blas.h"
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/cuda/cuda_collectives.h"
#include "xla/stream_executor/cuda/cuda_event.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/dnn.h"
Expand Down Expand Up @@ -178,9 +177,6 @@ class CudaExecutor : public GpuExecutor {
bool UnloadGpuBinary(const void* gpu_binary)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);

// Creates a GpuEvent for the given stream.
absl::StatusOr<std::unique_ptr<CudaEvent>> CreateGpuEvent(bool allow_timing);

// Returns true if a delay kernel is supported.
absl::StatusOr<bool> DelayKernelIsSupported();

Expand Down
38 changes: 27 additions & 11 deletions xla/stream_executor/cuda/cuda_timer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/cuda/cuda_event.h"
#include "xla/stream_executor/cuda/cuda_status.h"
#include "xla/stream_executor/cuda/delay_kernel.h"
#include "xla/stream_executor/gpu/context.h"
#include "xla/stream_executor/gpu/gpu_semaphore.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
Expand All @@ -40,11 +41,7 @@ absl::StatusOr<float> GetEventElapsedTime(Context* context, CUevent start,
ScopedActivateContext activated{context};
// The stop event must have completed in order for cuEventElapsedTime to
// work.
auto status = cuda::ToStatus(cuEventSynchronize(stop));
if (!status.ok()) {
LOG(ERROR) << "failed to synchronize the stop event: " << status;
return false;
}
TF_RETURN_IF_ERROR(cuda::ToStatus(cuEventSynchronize(stop)));

float elapsed_milliseconds;

Expand All @@ -56,8 +53,8 @@ absl::StatusOr<float> GetEventElapsedTime(Context* context, CUevent start,

} // namespace

CudaTimer::CudaTimer(Context* context, std::unique_ptr<CudaEvent> start_event,
std::unique_ptr<CudaEvent> stop_event, GpuStream* stream,
CudaTimer::CudaTimer(Context* context, CudaEvent start_event,
CudaEvent stop_event, GpuStream* stream,
GpuSemaphore semaphore)
: semaphore_(std::move(semaphore)),
context_(context),
Expand All @@ -80,9 +77,9 @@ CudaTimer::~CudaTimer() {

absl::StatusOr<absl::Duration> CudaTimer::GetElapsedDuration() {
if (is_stopped_) {
return absl::InternalError("Measuring inactive timer");
return absl::FailedPreconditionError("Measuring inactive timer");
}
TF_RETURN_IF_ERROR(stream_->RecordEvent(stop_event_.get()));
TF_RETURN_IF_ERROR(stream_->RecordEvent(&stop_event_));
// If we launched the delay kernel then check if it already timed out.
if (semaphore_) {
if (*semaphore_ == GpuSemaphoreState::kTimedOut) {
Expand All @@ -96,10 +93,29 @@ absl::StatusOr<absl::Duration> CudaTimer::GetElapsedDuration() {
}
}
TF_ASSIGN_OR_RETURN(float elapsed_milliseconds,
GetEventElapsedTime(context_, start_event_->GetHandle(),
stop_event_->GetHandle()));
GetEventElapsedTime(context_, start_event_.GetHandle(),
stop_event_.GetHandle()));
is_stopped_ = true;
return absl::Milliseconds(elapsed_milliseconds);
}

absl::StatusOr<CudaTimer> CudaTimer::Create(Context* context, GpuStream* stream,
TimerType timer_type) {
GpuSemaphore semaphore{};

if (timer_type == TimerType::kDelayKernel) {
TF_ASSIGN_OR_RETURN(semaphore, LaunchDelayKernel(stream));
}

TF_ASSIGN_OR_RETURN(CudaEvent start_event,
CudaEvent::Create(context, /*allow_timing=*/true));
TF_ASSIGN_OR_RETURN(CudaEvent stop_event,
CudaEvent::Create(context, /*allow_timing=*/true));

TF_RETURN_IF_ERROR(stream->RecordEvent(&start_event));

return CudaTimer(context, std::move(start_event), std::move(stop_event),
stream, std::move(semaphore));
}

} // namespace stream_executor::gpu
19 changes: 14 additions & 5 deletions xla/stream_executor/cuda/cuda_timer.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,29 @@ limitations under the License.
namespace stream_executor::gpu {
class CudaTimer : public EventBasedTimer {
public:
CudaTimer(Context* context, std::unique_ptr<CudaEvent> start_event,
std::unique_ptr<CudaEvent> stop_event, GpuStream* stream,
GpuSemaphore semaphore);
~CudaTimer() override;
CudaTimer(CudaTimer&&) = default;
CudaTimer& operator=(CudaTimer&&) = default;

absl::StatusOr<absl::Duration> GetElapsedDuration() override;

enum class TimerType {
kDelayKernel,
kEventBased,
};
static absl::StatusOr<CudaTimer> Create(Context* context, GpuStream* stream,
TimerType timer_type);

private:
CudaTimer(Context* context, CudaEvent start_event, CudaEvent stop_event,
GpuStream* stream, GpuSemaphore semaphore);

GpuSemaphore semaphore_;
bool is_stopped_ = false;
Context* context_;
GpuStream* stream_;
std::unique_ptr<CudaEvent> start_event_;
std::unique_ptr<CudaEvent> stop_event_;
CudaEvent start_event_;
CudaEvent stop_event_;
};

} // namespace stream_executor::gpu
Expand Down
111 changes: 111 additions & 0 deletions xla/stream_executor/cuda/cuda_timer_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/cuda/cuda_timer.h"

#include <cstdint>
#include <memory>
#include <optional>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/time/time.h"
#include "xla/stream_executor/cuda/cuda_executor.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/gpu/gpu_test_kernels.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/kernel_spec.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/typed_kernel_factory.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace stream_executor::gpu {
namespace {
using ::testing::Gt;
using ::tsl::testing::IsOk;

class CudaTimerTest : public ::testing::TestWithParam<CudaTimer::TimerType> {
public:
void LaunchSomeKernel(StreamExecutor* executor, Stream* stream) {
using AddI32Kernel =
TypedKernelFactory<DeviceMemory<int32_t>, DeviceMemory<int32_t>,
DeviceMemory<int32_t>>;

MultiKernelLoaderSpec spec(/*arity=*/3);
spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32");
TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec));

int64_t length = 4;
int64_t byte_length = sizeof(int32_t) * length;

// Prepare arguments: a=1, b=2, c=0
DeviceMemory<int32_t> a = executor->AllocateArray<int32_t>(length, 0);
DeviceMemory<int32_t> b = executor->AllocateArray<int32_t>(length, 0);
DeviceMemory<int32_t> c = executor->AllocateArray<int32_t>(length, 0);

ASSERT_THAT(stream->Memset32(&a, 1, byte_length), IsOk());
ASSERT_THAT(stream->Memset32(&b, 2, byte_length), IsOk());
ASSERT_THAT(stream->MemZero(&c, byte_length), IsOk());

ASSERT_THAT(stream->ThenLaunch(ThreadDim(), BlockDim(4), add, a, b, c),
IsOk());
}

std::optional<CudaExecutor> executor_;
std::unique_ptr<Stream> stream_;
GpuStream* gpu_stream_;

private:
void SetUp() override {
TF_ASSERT_OK_AND_ASSIGN(Platform * platform,
stream_executor::PlatformManager::PlatformWithId(
stream_executor::cuda::kCudaPlatformId));
executor_.emplace(platform, 0);
ASSERT_THAT(executor_->Init(), IsOk());
TF_ASSERT_OK_AND_ASSIGN(stream_, executor_->CreateStream(std::nullopt));
gpu_stream_ = AsGpuStream(stream_.get());
}
};

TEST_P(CudaTimerTest, Create) {
TF_ASSERT_OK_AND_ASSIGN(
CudaTimer timer,
CudaTimer::Create(executor_->gpu_context(), gpu_stream_, GetParam()));

// We don't really care what kernel we launch here as long as it takes a
// non-zero amount of time.
LaunchSomeKernel(&executor_.value(), stream_.get());

TF_ASSERT_OK_AND_ASSIGN(absl::Duration timer_result,
timer.GetElapsedDuration());
EXPECT_THAT(timer_result, Gt(absl::ZeroDuration()));
EXPECT_THAT(timer.GetElapsedDuration(),
tsl::testing::StatusIs(absl::StatusCode::kFailedPrecondition));
}

INSTANTIATE_TEST_SUITE_P(CudaTimerTest, CudaTimerTest,
::testing::Values(CudaTimer::TimerType::kEventBased,
CudaTimer::TimerType::kDelayKernel));

} // namespace
} // namespace stream_executor::gpu
3 changes: 3 additions & 0 deletions xla/stream_executor/event_based_timer.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ namespace stream_executor {
class EventBasedTimer {
public:
virtual ~EventBasedTimer() = default;
EventBasedTimer() = default;
EventBasedTimer(EventBasedTimer&&) = default;
EventBasedTimer& operator=(EventBasedTimer&&) = default;

// Stops the timer on the first call and returns the elapsed duration.
// Subsequent calls error out.
Expand Down
Loading
Loading