Skip to content

Commit

Permalink
Split GpuEvent into platform specific implementations
Browse files Browse the repository at this point in the history
- Removes `GpuEvent` and moves all remaining functionality into `CudaEvent`/`RocmEvent`
- Moves `DestroyEvent` and `InitEvent` from `GpuDriver` into `CudaEvent` and `RocmEvent`
- Makes `CudaTimer` and `RocmTimer` use `CudaEvent` and `RocmEvent` as a replacement for `GpuEvent`.
- Replace `GpuEvent::Init` function by factory functions in `CudaEvent` and `RocmEvent`
- Add basic test for `CudaEvent` and `RocmEvent`.

PiperOrigin-RevId: 683830727
  • Loading branch information
beckerhe authored and Google-ML-Automation committed Oct 11, 2024
1 parent 2b60c15 commit 4db7b74
Show file tree
Hide file tree
Showing 28 changed files with 470 additions and 345 deletions.
34 changes: 29 additions & 5 deletions xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -674,12 +674,35 @@ cc_library(
"//xla/stream_executor:event",
"//xla/stream_executor/gpu:context",
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/gpu:gpu_event",
"//xla/stream_executor/gpu:gpu_types_header",
"//xla/stream_executor/gpu:scoped_activate_context",
"@com_google_absl//absl/base",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@local_config_cuda//cuda:cuda_headers",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

xla_test(
name = "cuda_event_test",
srcs = ["cuda_event_test.cc"],
backends = ["gpu"],
tags = ["cuda-only"],
deps = [
":cuda_event",
":cuda_executor",
":cuda_platform_id",
"//xla/stream_executor:event",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
"@com_google_googletest//:gtest_main",
"@local_config_cuda//cuda:cuda_headers",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)

Expand Down Expand Up @@ -1029,7 +1052,6 @@ cc_library(
"//xla/stream_executor/gpu:context",
"//xla/stream_executor/gpu:gpu_command_buffer",
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/gpu:gpu_event_header",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:gpu_kernel_header",
"//xla/stream_executor/gpu:gpu_semaphore",
Expand Down Expand Up @@ -1192,7 +1214,6 @@ cc_library(
"//xla/stream_executor:platform",
"//xla/stream_executor:stream",
"//xla/stream_executor/gpu:context",
"//xla/stream_executor/gpu:gpu_event",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:gpu_stream",
"//xla/stream_executor/gpu:scoped_activate_context",
Expand All @@ -1207,17 +1228,20 @@ cc_library(

cc_library(
name = "cuda_timer",
srcs = ["cuda_timer.cc"],
srcs = [
"cuda_timer.cc",
],
hdrs = ["cuda_timer.h"],
tags = [
"cuda-only",
"gpu",
],
deps = [
":cuda_event",
":cuda_status",
"//xla/stream_executor:event_based_timer",
"//xla/stream_executor:stream",
"//xla/stream_executor/gpu:context",
"//xla/stream_executor/gpu:gpu_event",
"//xla/stream_executor/gpu:gpu_semaphore",
"//xla/stream_executor/gpu:gpu_stream",
"//xla/stream_executor/gpu:scoped_activate_context",
Expand Down
26 changes: 0 additions & 26 deletions xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -965,14 +965,6 @@ bool GpuDriver::HostUnregister(Context* context, void* location) {
return true;
}

// Destroys the CUDA event that `event` points to, with `context` activated
// for the duration of the driver call.
//
// Returns InvalidArgumentError when `*event` is null; otherwise returns the
// cuEventDestroy result converted by cuda::ToStatus.
absl::Status GpuDriver::DestroyEvent(Context* context, CUevent* event) {
  CUevent const handle = *event;
  if (handle == nullptr) {
    return absl::InvalidArgumentError("input event cannot be null");
  }

  ScopedActivateContext activation{context};
  const auto destroy_result = cuEventDestroy(handle);
  return cuda::ToStatus(destroy_result, "Error destroying CUDA event");
}

absl::Status GpuDriver::SynchronizeStream(Context* context, CUstream stream) {
ScopedActivateContext activated{context};
Expand Down Expand Up @@ -1077,24 +1069,6 @@ absl::Status GpuDriver::AsynchronousMemcpyD2D(Context* context,
return absl::OkStatus();
}

// Creates a CUDA event into `*result` with `context` activated.
//
// `flags` selects CU_EVENT_DEFAULT or CU_EVENT_DISABLE_TIMING; any other value
// reaches the LOG(FATAL) branch. Returns the cuEventCreate result converted by
// cuda::ToStatus.
absl::Status GpuDriver::InitEvent(Context* context, CUevent* result,
                                  EventFlags flags) {
  // cuEventCreate takes its flags as `unsigned int`; use that type directly and
  // start from a defined value instead of leaving the local uninitialized.
  unsigned int cuflags = CU_EVENT_DEFAULT;
  switch (flags) {
    case EventFlags::kDefault:
      cuflags = CU_EVENT_DEFAULT;
      break;
    case EventFlags::kDisableTiming:
      cuflags = CU_EVENT_DISABLE_TIMING;
      break;
    default:
      LOG(FATAL) << "impossible event flags: " << static_cast<int>(flags);
  }

  ScopedActivateContext activated{context};
  return cuda::ToStatus(cuEventCreate(result, cuflags));
}

int GpuDriver::GetDeviceCount() {
int device_count = 0;
auto status = cuda::ToStatus(cuDeviceGetCount(&device_count));
Expand Down
77 changes: 73 additions & 4 deletions xla/stream_executor/cuda/cuda_event.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ limitations under the License.
#include <cstdint>

#include "absl/base/casts.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/cuda/cuda_status.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/gpu/context.h"
#include "xla/stream_executor/gpu/scoped_activate_context.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace stream_executor {
namespace gpu {
Expand All @@ -33,11 +36,45 @@ absl::Status WaitStreamOnEvent(Context* context, CUstream stream,
ScopedActivateContext activation(context);
return cuda::ToStatus(cuStreamWaitEvent(stream, event, 0 /* = flags */));
}

// Releases `event` through cuEventDestroy with `context` activated.
//
// A null `event` is ignored. A failed destroy is only reported through
// LOG(ERROR); nothing is returned to the caller.
void DestroyEvent(Context* context, CUevent event) {
  if (event != nullptr) {
    ScopedActivateContext activation{context};
    const auto status =
        cuda::ToStatus(cuEventDestroy(event), "Error destroying CUDA event");
    if (!status.ok()) {
      LOG(ERROR) << status.message();
    }
  }
}

// Selects which CUDA event creation flag InitEvent passes to cuEventCreate.
enum class EventFlags { kDefault, kDisableTiming };

// Creates a CUDA event with `context` activated and returns its handle.
//
// `flags` maps to CU_EVENT_DEFAULT or CU_EVENT_DISABLE_TIMING; any other value
// reaches the LOG(FATAL) branch. If cuEventCreate fails, the error status
// produced by cuda::ToStatus is returned instead of a handle.
absl::StatusOr<CUevent> InitEvent(Context* context, EventFlags flags) {
  // cuEventCreate takes its flags as `unsigned int`; use that type directly and
  // start from a defined value instead of leaving the local uninitialized.
  unsigned int cuflags = CU_EVENT_DEFAULT;
  switch (flags) {
    case EventFlags::kDefault:
      cuflags = CU_EVENT_DEFAULT;
      break;
    case EventFlags::kDisableTiming:
      cuflags = CU_EVENT_DISABLE_TIMING;
      break;
    default:
      LOG(FATAL) << "impossible event flags: " << static_cast<int>(flags);
  }

  ScopedActivateContext activated{context};
  // Null-initialized so the handle never holds an indeterminate value.
  CUevent event_handle = nullptr;
  TF_RETURN_IF_ERROR(cuda::ToStatus(cuEventCreate(&event_handle, cuflags)));
  return event_handle;
}

} // namespace

Event::Status CudaEvent::PollForStatus() {
ScopedActivateContext activated(context());
CUresult res = cuEventQuery(gpu_event());
ScopedActivateContext activated(context_);
CUresult res = cuEventQuery(handle_);
if (res == CUDA_SUCCESS) {
return Event::Status::kComplete;
} else if (res == CUDA_ERROR_NOT_READY) {
Expand All @@ -47,8 +84,40 @@ Event::Status CudaEvent::PollForStatus() {
}

// Makes the external stream identified by `stream` wait on this event.
// `stream` is bit-cast to a CUstream and passed, together with this event's
// context and handle, to WaitStreamOnEvent; that call's status is returned.
// NOTE(review): the two return statements look like the before/after lines of
// a diff (context()/gpu_event() vs. context_/handle_); only the first one is
// reachable as written — confirm which one belongs here.
absl::Status CudaEvent::WaitForEventOnExternalStream(std::intptr_t stream) {
return WaitStreamOnEvent(context(), absl::bit_cast<CUstream>(stream),
gpu_event());
return WaitStreamOnEvent(context_, absl::bit_cast<CUstream>(stream), handle_);
}

// Factory for CudaEvent: creates the underlying CUDA event in `context` and
// wraps it. `allow_timing` picks EventFlags::kDefault (true) or
// EventFlags::kDisableTiming (false). An InitEvent failure is returned as the
// error status.
absl::StatusOr<CudaEvent> CudaEvent::Create(Context* context,
                                            bool allow_timing) {
  const EventFlags flags =
      allow_timing ? EventFlags::kDefault : EventFlags::kDisableTiming;
  TF_ASSIGN_OR_RETURN(CUevent event_handle, InitEvent(context, flags));

  return CudaEvent(context, event_handle);
}

// Destroys the held CUDA event. DestroyEvent returns early on a null handle,
// so destroying a moved-from CudaEvent does not call into the driver.
CudaEvent::~CudaEvent() {
  DestroyEvent(context_, handle_);
}

// Move-assigns from `other`: destroys the event currently held, adopts
// `other`'s context and handle, and leaves `other` holding null pointers.
// Self-assignment is a no-op.
CudaEvent& CudaEvent::operator=(CudaEvent&& other) {
  if (this != &other) {
    // Release our own event first so its handle is not lost when overwritten.
    DestroyEvent(context_, handle_);

    handle_ = other.handle_;
    context_ = other.context_;

    // `other` must no longer refer to the handle we now hold.
    other.handle_ = nullptr;
    other.context_ = nullptr;
  }

  return *this;
}

CudaEvent::CudaEvent(CudaEvent&& other)
: context_(other.context_), handle_(other.handle_) {
other.context_ = nullptr;
other.handle_ = nullptr;
}

} // namespace gpu
Expand Down
30 changes: 25 additions & 5 deletions xla/stream_executor/cuda/cuda_event.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,42 @@ limitations under the License.
#include <cstdint>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/gpu/context.h"
#include "xla/stream_executor/gpu/gpu_event.h"

namespace stream_executor::gpu {

class GpuContext;

// Event implementation for CUDA devices. Each instance holds a CUevent handle
// together with the Context it was created in, and destroys that handle in its
// destructor. Copying is disabled; instances can only be moved.
// NOTE(review): this span appears to interleave pre- and post-change diff
// lines (two class heads and a GpuEvent-based constructor) — confirm against
// the real header.
class CudaEvent : public GpuEvent {
class CudaEvent : public Event {
public:
explicit CudaEvent(Context *context) : GpuEvent(context) {}

// Queries the underlying CUDA event and reports its status.
Event::Status PollForStatus() override;

// Makes the external stream identified by `stream` wait on this event.
absl::Status WaitForEventOnExternalStream(std::intptr_t stream) override;

// Creates a new CudaEvent. If allow_timing is false, the event will not
// support timing, which is cheaper to create.
static absl::StatusOr<CudaEvent> Create(Context* context, bool allow_timing);

// Returns the underlying CUDA event handle without giving up ownership.
CUevent GetHandle() const { return handle_; }

// Move-only type: a moved-from event is left with null context and handle.
~CudaEvent() override;
CudaEvent(const CudaEvent&) = delete;
CudaEvent& operator=(const CudaEvent&) = delete;
CudaEvent(CudaEvent&& other);
CudaEvent& operator=(CudaEvent&& other);

private:
// Wraps an already-created handle; used by Create().
explicit CudaEvent(Context* context, CUevent handle)
: context_(context), handle_(handle) {}

// The Context to which this object and its CUDA event handle are bound.
Context* context_;

// The underlying CUDA event handle.
CUevent handle_;
};

} // namespace stream_executor::gpu
Expand Down
55 changes: 55 additions & 0 deletions xla/stream_executor/cuda/cuda_event_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/cuda/cuda_event.h"

#include <utility>

#include <gtest/gtest.h>
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/cuda/cuda_executor.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace stream_executor::gpu {
namespace {
using ::tsl::testing::IsOk;

// Checks that CudaEvent::Create yields an event with a non-null handle that
// polls as complete, and that move-constructing from it carries the same
// handle over to the new object.
TEST(CudaEventTest, CreateEvent) {
  TF_ASSERT_OK_AND_ASSIGN(Platform * cuda_platform,
                          stream_executor::PlatformManager::PlatformWithId(
                              stream_executor::cuda::kCudaPlatformId));
  CudaExecutor cuda_executor{cuda_platform, 0};
  ASSERT_THAT(cuda_executor.Init(), IsOk());

  TF_ASSERT_OK_AND_ASSIGN(
      CudaEvent original,
      CudaEvent::Create(cuda_executor.gpu_context(), /*allow_timing=*/false));

  EXPECT_NE(original.GetHandle(), nullptr);
  EXPECT_EQ(original.PollForStatus(), Event::Status::kComplete);

  // Remember the handle before the move so the destination can be compared.
  const CUevent expected_handle = original.GetHandle();
  CudaEvent moved_to = std::move(original);
  EXPECT_EQ(moved_to.GetHandle(), expected_handle);
}

} // namespace

} // namespace stream_executor::gpu
12 changes: 6 additions & 6 deletions xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ limitations under the License.
#include "xla/stream_executor/gpu/context.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/gpu/gpu_event.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_kernel.h"
#include "xla/stream_executor/gpu/gpu_semaphore.h"
Expand Down Expand Up @@ -817,11 +816,11 @@ absl::Status FillBlockDimLimit(GpuDeviceHandle device,
return absl::OkStatus();
}

// Creates a heap-allocated event on this executor's GPU context.
// `allow_timing` is forwarded to the event's initialization; a failure there
// is returned as the error status.
// NOTE(review): this span appears to interleave pre- and post-change diff
// lines (two signatures and two bodies) — confirm against the real file.
absl::StatusOr<std::unique_ptr<GpuEvent>> CudaExecutor::CreateGpuEvent(
absl::StatusOr<std::unique_ptr<CudaEvent>> CudaExecutor::CreateGpuEvent(
bool allow_timing) {
auto gpu_event = std::make_unique<CudaEvent>(gpu_context());
TF_RETURN_IF_ERROR(gpu_event->Init(allow_timing));
return std::move(gpu_event);
TF_ASSIGN_OR_RETURN(auto event,
CudaEvent::Create(gpu_context(), allow_timing));
return std::make_unique<CudaEvent>(std::move(event));
}

absl::StatusOr<std::unique_ptr<Event>> CudaExecutor::CreateEvent() {
Expand All @@ -830,7 +829,8 @@ absl::StatusOr<std::unique_ptr<Event>> CudaExecutor::CreateEvent() {

absl::StatusOr<std::unique_ptr<Stream>> CudaExecutor::CreateStream(
std::optional<std::variant<StreamPriority, int>> priority) {
TF_ASSIGN_OR_RETURN(auto event, CreateGpuEvent(/*allow_timing=*/false));
TF_ASSIGN_OR_RETURN(auto event,
CudaEvent::Create(gpu_context(), /*allow_timing=*/false));
TF_ASSIGN_OR_RETURN(auto stream,
CudaStream::Create(this, std::move(event), priority));
absl::MutexLock l(&alive_gpu_streams_mu_);
Expand Down
4 changes: 2 additions & 2 deletions xla/stream_executor/cuda/cuda_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ limitations under the License.
#include "xla/stream_executor/blas.h"
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/cuda/cuda_collectives.h"
#include "xla/stream_executor/cuda/cuda_event.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/dnn.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/event_based_timer.h"
#include "xla/stream_executor/fft.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/gpu/gpu_event.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_kernel.h"
#include "xla/stream_executor/gpu/gpu_types.h"
Expand Down Expand Up @@ -179,7 +179,7 @@ class CudaExecutor : public GpuExecutor {
ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);

// Creates a CudaEvent on this executor's GPU context; the event supports
// timing only if `allow_timing` is true.
absl::StatusOr<std::unique_ptr<GpuEvent>> CreateGpuEvent(bool allow_timing);
absl::StatusOr<std::unique_ptr<CudaEvent>> CreateGpuEvent(bool allow_timing);

// Returns true if a delay kernel is supported.
absl::StatusOr<bool> DelayKernelIsSupported();
Expand Down
Loading

0 comments on commit 4db7b74

Please sign in to comment.