From 4db7b7409f768067073934ffdda0eb5676d48d7f Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Tue, 8 Oct 2024 18:09:51 -0700 Subject: [PATCH] Split GpuEvent into platform specific implementations - Removes `GpuEvent` and moves all remaining functionality into `CudaEvent`/`RocmEvent` - Moves `DestroyEvent` and `InitEvent` from `GpuDriver` into `CudaEvent` and `RocmEvent` - Makes `CudaTimer` and `RocmTimer` use `CudaEvent` and `RocmEvent` as a replacement for `GpuEvent`. - Replace `GpuEvent::Init` function by factory functions in `CudaEvent` and `RocmEvent` - Add basic test or `CudaEvent` and `RocmEvent`. PiperOrigin-RevId: 683830727 --- xla/stream_executor/cuda/BUILD | 34 ++++++-- xla/stream_executor/cuda/cuda_driver.cc | 26 ------ xla/stream_executor/cuda/cuda_event.cc | 77 +++++++++++++++++- xla/stream_executor/cuda/cuda_event.h | 30 +++++-- xla/stream_executor/cuda/cuda_event_test.cc | 55 +++++++++++++ xla/stream_executor/cuda/cuda_executor.cc | 12 +-- xla/stream_executor/cuda/cuda_executor.h | 4 +- xla/stream_executor/cuda/cuda_stream.cc | 22 +++-- xla/stream_executor/cuda/cuda_stream.h | 17 ++-- xla/stream_executor/cuda/cuda_timer.cc | 10 +-- xla/stream_executor/cuda/cuda_timer.h | 11 ++- xla/stream_executor/gpu/BUILD | 27 ------- xla/stream_executor/gpu/gpu_driver.h | 19 ----- xla/stream_executor/gpu/gpu_event.cc | 47 ----------- xla/stream_executor/gpu/gpu_event.h | 62 -------------- xla/stream_executor/gpu/gpu_stream.cc | 10 --- xla/stream_executor/gpu/gpu_stream.h | 14 +--- xla/stream_executor/rocm/BUILD | 34 +++++++- xla/stream_executor/rocm/rocm_driver.cc | 52 ------------ xla/stream_executor/rocm/rocm_event.cc | 89 ++++++++++++++++++++- xla/stream_executor/rocm/rocm_event.h | 30 +++++-- xla/stream_executor/rocm/rocm_event_test.cc | 55 +++++++++++++ xla/stream_executor/rocm/rocm_executor.cc | 13 +-- xla/stream_executor/rocm/rocm_executor.h | 4 +- xla/stream_executor/rocm/rocm_stream.cc | 23 ++++-- xla/stream_executor/rocm/rocm_stream.h | 18 +++-- xla/stream_executor/rocm/rocm_timer.cc | 10 +-- xla/stream_executor/rocm/rocm_timer.h | 10 +-- 28 files changed, 470 insertions(+), 345 deletions(-) create mode 100644 xla/stream_executor/cuda/cuda_event_test.cc delete mode 100644 xla/stream_executor/gpu/gpu_event.cc delete mode 100644 xla/stream_executor/gpu/gpu_event.h create mode 100644 xla/stream_executor/rocm/rocm_event_test.cc diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index 2911405ceb2912..289666d1bbb00b 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -674,12 +674,35 @@ cc_library( "//xla/stream_executor:event", "//xla/stream_executor/gpu:context", "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_event", "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/gpu:scoped_activate_context", "@com_google_absl//absl/base", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "cuda_event_test", + srcs = ["cuda_event_test.cc"], + backends = ["gpu"], + tags = ["cuda-only"], + deps = [ + ":cuda_event", + ":cuda_executor", + ":cuda_platform_id", + "//xla/stream_executor:event", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "@com_google_googletest//:gtest_main", "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", ], ) @@ -1029,7 +1052,6 @@ cc_library( "//xla/stream_executor/gpu:context", "//xla/stream_executor/gpu:gpu_command_buffer", "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_event_header", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_kernel_header", "//xla/stream_executor/gpu:gpu_semaphore", @@ -1192,7 +1214,6 @@ cc_library( "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:gpu_event", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_stream", "//xla/stream_executor/gpu:scoped_activate_context", @@ -1207,17 +1228,20 @@ cc_library( cc_library( name = "cuda_timer", - srcs = ["cuda_timer.cc"], + srcs = [ + "cuda_timer.cc", + ], hdrs = ["cuda_timer.h"], tags = [ "cuda-only", "gpu", ], deps = [ + ":cuda_event", ":cuda_status", "//xla/stream_executor:event_based_timer", + "//xla/stream_executor:stream", "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:gpu_event", "//xla/stream_executor/gpu:gpu_semaphore", "//xla/stream_executor/gpu:gpu_stream", "//xla/stream_executor/gpu:scoped_activate_context", diff --git a/xla/stream_executor/cuda/cuda_driver.cc b/xla/stream_executor/cuda/cuda_driver.cc index 5c47b5cf9bbd0e..848245cade5201 100644 --- a/xla/stream_executor/cuda/cuda_driver.cc +++ b/xla/stream_executor/cuda/cuda_driver.cc @@ -965,14 +965,6 @@ bool GpuDriver::HostUnregister(Context* context, void* location) { return true; } -absl::Status GpuDriver::DestroyEvent(Context* context, CUevent* event) { - if (*event == nullptr) { - return absl::InvalidArgumentError("input event cannot be null"); - } - - ScopedActivateContext activated{context}; - return cuda::ToStatus(cuEventDestroy(*event), "Error destroying CUDA event"); -} absl::Status GpuDriver::SynchronizeStream(Context* context, CUstream stream) { ScopedActivateContext activated{context}; @@ -1077,24 +1069,6 @@ absl::Status GpuDriver::AsynchronousMemcpyD2D(Context* context, return absl::OkStatus(); } -absl::Status GpuDriver::InitEvent(Context* context, CUevent* result, - EventFlags flags) { - int cuflags; - switch (flags) { - case EventFlags::kDefault: - cuflags = CU_EVENT_DEFAULT; - break; - case EventFlags::kDisableTiming: - cuflags = CU_EVENT_DISABLE_TIMING; - break; - default: - LOG(FATAL) << "impossible event flags: " << int(flags); - } - - ScopedActivateContext activated{context}; - return cuda::ToStatus(cuEventCreate(result, cuflags)); -} - int GpuDriver::GetDeviceCount() { int device_count = 0; auto status = cuda::ToStatus(cuDeviceGetCount(&device_count)); diff --git a/xla/stream_executor/cuda/cuda_event.cc b/xla/stream_executor/cuda/cuda_event.cc index 6d466bd906cd77..3755fb81374860 100644 --- a/xla/stream_executor/cuda/cuda_event.cc +++ b/xla/stream_executor/cuda/cuda_event.cc @@ -18,12 +18,15 @@ limitations under the License. #include #include "absl/base/casts.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/context.h" #include "xla/stream_executor/gpu/scoped_activate_context.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { @@ -33,11 +36,45 @@ absl::Status WaitStreamOnEvent(Context* context, CUstream stream, ScopedActivateContext activation(context); return cuda::ToStatus(cuStreamWaitEvent(stream, event, 0 /* = flags */)); } + +void DestroyEvent(Context* context, CUevent event) { + if (event == nullptr) { + return; + } + + ScopedActivateContext activated{context}; + auto result = + cuda::ToStatus(cuEventDestroy(event), "Error destroying CUDA event"); + if (!result.ok()) { + LOG(ERROR) << result.message(); + } +} + +enum class EventFlags { kDefault, kDisableTiming }; +absl::StatusOr InitEvent(Context* context, EventFlags flags) { + int cuflags; + switch (flags) { + case EventFlags::kDefault: + cuflags = CU_EVENT_DEFAULT; + break; + case EventFlags::kDisableTiming: + cuflags = CU_EVENT_DISABLE_TIMING; + break; + default: + LOG(FATAL) << "impossible event flags: " << int(flags); + } + + ScopedActivateContext activated{context}; + CUevent event_handle; + TF_RETURN_IF_ERROR(cuda::ToStatus(cuEventCreate(&event_handle, cuflags))); + return event_handle; +} + } // namespace Event::Status CudaEvent::PollForStatus() { - ScopedActivateContext activated(context()); - CUresult res = cuEventQuery(gpu_event()); + ScopedActivateContext activated(context_); + CUresult res = cuEventQuery(handle_); if (res == CUDA_SUCCESS) { return Event::Status::kComplete; } else if (res == CUDA_ERROR_NOT_READY) { @@ -47,8 +84,40 @@ Event::Status CudaEvent::PollForStatus() { } absl::Status CudaEvent::WaitForEventOnExternalStream(std::intptr_t stream) { - return WaitStreamOnEvent(context(), absl::bit_cast(stream), - gpu_event()); + return WaitStreamOnEvent(context_, absl::bit_cast(stream), handle_); +} + +absl::StatusOr CudaEvent::Create(Context* context, + bool allow_timing) { + TF_ASSIGN_OR_RETURN( + CUevent event_handle, + InitEvent(context, allow_timing ? EventFlags::kDefault + : EventFlags::kDisableTiming)); + + return CudaEvent(context, event_handle); +} + +CudaEvent::~CudaEvent() { DestroyEvent(context_, handle_); } + +CudaEvent& CudaEvent::operator=(CudaEvent&& other) { + if (this == &other) { + return *this; + } + + DestroyEvent(context_, handle_); + + context_ = other.context_; + handle_ = other.handle_; + other.context_ = nullptr; + other.handle_ = nullptr; + + return *this; +} + +CudaEvent::CudaEvent(CudaEvent&& other) + : context_(other.context_), handle_(other.handle_) { + other.context_ = nullptr; + other.handle_ = nullptr; } } // namespace gpu diff --git a/xla/stream_executor/cuda/cuda_event.h b/xla/stream_executor/cuda/cuda_event.h index a562ea8be939cf..7f7b8368c73068 100644 --- a/xla/stream_executor/cuda/cuda_event.h +++ b/xla/stream_executor/cuda/cuda_event.h @@ -19,22 +19,42 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_event.h" namespace stream_executor::gpu { class GpuContext; // This class implements Event::PollForStatus for CUDA devices. -class CudaEvent : public GpuEvent { +class CudaEvent : public Event { public: - explicit CudaEvent(Context *context) : GpuEvent(context) {} - Event::Status PollForStatus() override; - absl::Status WaitForEventOnExternalStream(std::intptr_t stream) override; + + // Creates a new CudaEvent. If allow_timing is false, the event will not + // support timing, which is cheaper to create. + static absl::StatusOr Create(Context* context, bool allow_timing); + + CUevent GetHandle() const { return handle_; } + + ~CudaEvent() override; + CudaEvent(const CudaEvent&) = delete; + CudaEvent& operator=(const CudaEvent&) = delete; + CudaEvent(CudaEvent&& other); + CudaEvent& operator=(CudaEvent&& other); + + private: + explicit CudaEvent(Context* context, CUevent handle) + : context_(context), handle_(handle) {} + + // The Context used to which this object and GpuEventHandle are bound. + Context* context_; + + // The underlying CUDA event handle. + CUevent handle_; }; } // namespace stream_executor::gpu diff --git a/xla/stream_executor/cuda/cuda_event_test.cc b/xla/stream_executor/cuda/cuda_event_test.cc new file mode 100644 index 00000000000000..f99375b33fbcd6 --- /dev/null +++ b/xla/stream_executor/cuda/cuda_event_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_event.h" + +#include + +#include +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/cuda/cuda_executor.h" +#include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using ::tsl::testing::IsOk; + +TEST(CudaEventTest, CreateEvent) { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + stream_executor::PlatformManager::PlatformWithId( + stream_executor::cuda::kCudaPlatformId)); + CudaExecutor executor{platform, 0}; + ASSERT_THAT(executor.Init(), IsOk()); + + TF_ASSERT_OK_AND_ASSIGN(CudaEvent event, + CudaEvent::Create(executor.gpu_context(), false)); + + EXPECT_NE(event.GetHandle(), nullptr); + EXPECT_EQ(event.PollForStatus(), Event::Status::kComplete); + + CUevent handle = event.GetHandle(); + CudaEvent event2 = std::move(event); + EXPECT_EQ(event2.GetHandle(), handle); +} + +} // namespace + +} // namespace stream_executor::gpu diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index f2a95547931ba6..4108910d1fc647 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -65,7 +65,6 @@ limitations under the License. #include "xla/stream_executor/gpu/context.h" #include "xla/stream_executor/gpu/gpu_command_buffer.h" #include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_kernel.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" @@ -817,11 +816,11 @@ absl::Status FillBlockDimLimit(GpuDeviceHandle device, return absl::OkStatus(); } -absl::StatusOr> CudaExecutor::CreateGpuEvent( +absl::StatusOr> CudaExecutor::CreateGpuEvent( bool allow_timing) { - auto gpu_event = std::make_unique(gpu_context()); - TF_RETURN_IF_ERROR(gpu_event->Init(allow_timing)); - return std::move(gpu_event); + TF_ASSIGN_OR_RETURN(auto event, + CudaEvent::Create(gpu_context(), allow_timing)); + return std::make_unique(std::move(event)); } absl::StatusOr> CudaExecutor::CreateEvent() { @@ -830,7 +829,8 @@ absl::StatusOr> CudaExecutor::CreateEvent() { absl::StatusOr> CudaExecutor::CreateStream( std::optional> priority) { - TF_ASSIGN_OR_RETURN(auto event, CreateGpuEvent(/*allow_timing=*/false)); + TF_ASSIGN_OR_RETURN(auto event, + CudaEvent::Create(gpu_context(), /*allow_timing=*/false)); TF_ASSIGN_OR_RETURN(auto stream, CudaStream::Create(this, std::move(event), priority)); absl::MutexLock l(&alive_gpu_streams_mu_); diff --git a/xla/stream_executor/cuda/cuda_executor.h b/xla/stream_executor/cuda/cuda_executor.h index 46905a0ccb4cef..1b148a69b32c22 100644 --- a/xla/stream_executor/cuda/cuda_executor.h +++ b/xla/stream_executor/cuda/cuda_executor.h @@ -36,6 +36,7 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/cuda/cuda_collectives.h" +#include "xla/stream_executor/cuda/cuda_event.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" @@ -43,7 +44,6 @@ limitations under the License. #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_kernel.h" #include "xla/stream_executor/gpu/gpu_types.h" @@ -179,7 +179,7 @@ class CudaExecutor : public GpuExecutor { ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); // Creates a GpuEvent for the given stream. - absl::StatusOr> CreateGpuEvent(bool allow_timing); + absl::StatusOr> CreateGpuEvent(bool allow_timing); // Returns true if a delay kernel is supported. absl::StatusOr DelayKernelIsSupported(); diff --git a/xla/stream_executor/cuda/cuda_stream.cc b/xla/stream_executor/cuda/cuda_stream.cc index 0512f7f66963f5..dc71014ae079a9 100644 --- a/xla/stream_executor/cuda/cuda_stream.cc +++ b/xla/stream_executor/cuda/cuda_stream.cc @@ -27,7 +27,6 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/scoped_activate_context.h" #include "xla/stream_executor/platform.h" @@ -90,7 +89,7 @@ absl::StatusOr CreateStream(Context* context, int priority) { } // namespace absl::StatusOr> CudaStream::Create( - GpuExecutor* executor, std::unique_ptr completed_event, + GpuExecutor* executor, CudaEvent completed_event, std::optional> priority) { int stream_priority = [&]() { if (priority.has_value() && std::holds_alternative(priority.value())) { @@ -110,22 +109,29 @@ absl::StatusOr> CudaStream::Create( absl::Status CudaStream::WaitFor(Stream* other) { CudaStream* other_stream = static_cast(other); - GpuEvent* other_completed_event = other_stream->completed_event(); - TF_RETURN_IF_ERROR(other_stream->RecordEvent(other_completed_event)); - + TF_RETURN_IF_ERROR(other_stream->RecordCompletedEvent()); return WaitStreamOnEvent(executor_->gpu_context(), gpu_stream(), - other_completed_event->gpu_event()); + other_stream->completed_event_.GetHandle()); } absl::Status CudaStream::RecordEvent(Event* event) { return stream_executor::gpu::RecordEvent( - executor_->gpu_context(), static_cast(event)->gpu_event(), + executor_->gpu_context(), static_cast(event)->GetHandle(), gpu_stream()); } absl::Status CudaStream::WaitFor(Event* event) { return WaitStreamOnEvent(executor_->gpu_context(), gpu_stream(), - static_cast(event)->gpu_event()); + static_cast(event)->GetHandle()); +} + +absl::Status CudaStream::RecordCompletedEvent() { + return RecordEvent(&completed_event_); +} + +CudaStream::~CudaStream() { + BlockHostUntilDone().IgnoreError(); + executor_->DeallocateStream(this); } } // namespace gpu diff --git a/xla/stream_executor/cuda/cuda_stream.h b/xla/stream_executor/cuda/cuda_stream.h index 8f760bb396687b..61395e5de3be22 100644 --- a/xla/stream_executor/cuda/cuda_stream.h +++ b/xla/stream_executor/cuda/cuda_stream.h @@ -24,8 +24,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/cuda/cuda_event.h" #include "xla/stream_executor/event.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/platform.h" @@ -41,18 +41,23 @@ class CudaStream : public GpuStream { absl::Status WaitFor(Event* event) override; static absl::StatusOr> Create( - GpuExecutor* executor, std::unique_ptr completed_event, + GpuExecutor* executor, CudaEvent completed_event, std::optional> priority); + ~CudaStream() override; + private: - CudaStream(GpuExecutor* executor, std::unique_ptr completed_event, + CudaStream(GpuExecutor* executor, CudaEvent completed_event, std::optional> priority, CUstream stream_handle) - : GpuStream(executor, std::move(completed_event), priority, - stream_handle), - executor_(executor) {} + : GpuStream(executor, priority, stream_handle), + executor_(executor), + completed_event_(std::move(completed_event)) {} + + absl::Status RecordCompletedEvent(); GpuExecutor* executor_; + CudaEvent completed_event_; }; } // namespace gpu diff --git a/xla/stream_executor/cuda/cuda_timer.cc b/xla/stream_executor/cuda/cuda_timer.cc index fba6d7bc83edb8..c9b0f083691b51 100644 --- a/xla/stream_executor/cuda/cuda_timer.cc +++ b/xla/stream_executor/cuda/cuda_timer.cc @@ -23,9 +23,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/cuda/cuda_event.h" #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/scoped_activate_context.h" @@ -56,8 +56,8 @@ absl::StatusOr GetEventElapsedTime(Context* context, CUevent start, } // namespace -CudaTimer::CudaTimer(Context* context, std::unique_ptr start_event, - std::unique_ptr stop_event, GpuStream* stream, +CudaTimer::CudaTimer(Context* context, std::unique_ptr start_event, + std::unique_ptr stop_event, GpuStream* stream, GpuSemaphore semaphore) : semaphore_(std::move(semaphore)), context_(context), @@ -96,8 +96,8 @@ absl::StatusOr CudaTimer::GetElapsedDuration() { } } TF_ASSIGN_OR_RETURN(float elapsed_milliseconds, - GetEventElapsedTime(context_, start_event_->gpu_event(), - stop_event_->gpu_event())); + GetEventElapsedTime(context_, start_event_->GetHandle(), + stop_event_->GetHandle())); is_stopped_ = true; return absl::Milliseconds(elapsed_milliseconds); } diff --git a/xla/stream_executor/cuda/cuda_timer.h b/xla/stream_executor/cuda/cuda_timer.h index dd4785a28dc33e..c8f8f332fa801a 100644 --- a/xla/stream_executor/cuda/cuda_timer.h +++ b/xla/stream_executor/cuda/cuda_timer.h @@ -18,20 +18,19 @@ limitations under the License. #include -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/time/time.h" +#include "xla/stream_executor/cuda/cuda_event.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/gpu/gpu_stream.h" namespace stream_executor::gpu { class CudaTimer : public EventBasedTimer { public: - CudaTimer(Context* context, std::unique_ptr start_event, - std::unique_ptr stop_event, GpuStream* stream, + CudaTimer(Context* context, std::unique_ptr start_event, + std::unique_ptr stop_event, GpuStream* stream, GpuSemaphore semaphore); ~CudaTimer() override; @@ -42,8 +41,8 @@ class CudaTimer : public EventBasedTimer { bool is_stopped_ = false; Context* context_; GpuStream* stream_; - std::unique_ptr start_event_; - std::unique_ptr stop_event_; + std::unique_ptr start_event_; + std::unique_ptr stop_event_; }; } // namespace stream_executor::gpu diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index c77c3bb4019dcc..f5fd3ae390cb89 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -218,31 +218,6 @@ gpu_only_cc_library( ]), ) -gpu_only_cc_library( - name = "gpu_event_header", - hdrs = ["gpu_event.h"], - deps = [ - ":context", - ":gpu_types_header", - "//xla/stream_executor:event", - "@com_google_absl//absl/status", - ], -) - -gpu_only_cc_library( - name = "gpu_event", - srcs = ["gpu_event.cc"], - hdrs = ["gpu_event.h"], - deps = [ - ":context", - ":gpu_driver_header", - ":gpu_types_header", - "//xla/stream_executor:event", - "@com_google_absl//absl/base", - "@com_google_absl//absl/status", - ], -) - cc_library( name = "gpu_executor_header", hdrs = ["gpu_executor.h"], @@ -367,7 +342,6 @@ gpu_only_cc_library( name = "gpu_stream_header", hdrs = ["gpu_stream.h"], deps = [ - ":gpu_event_header", ":gpu_executor_header", ":gpu_types_header", "//xla/stream_executor:device_memory", @@ -390,7 +364,6 @@ gpu_only_cc_library( hdrs = ["gpu_stream.h"], deps = [ ":gpu_driver_header", - ":gpu_event_header", ":gpu_executor_header", ":gpu_kernel_header", ":gpu_types_header", diff --git a/xla/stream_executor/gpu/gpu_driver.h b/xla/stream_executor/gpu/gpu_driver.h index a6d806774d7646..7c329525ece833 100644 --- a/xla/stream_executor/gpu/gpu_driver.h +++ b/xla/stream_executor/gpu/gpu_driver.h @@ -69,25 +69,6 @@ class GpuDriver { // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#stream-management static void DestroyStream(Context* context, GpuStreamHandle stream); - // CUDA/HIP events can explicitly disable event TSC retrieval for some - // presumed performance improvement if timing is unnecessary. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g450687e75f3ff992fe01662a43d9d3db - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#cuda-driver-data-types - enum class EventFlags { kDefault, kDisableTiming }; - - // Creates a new event associated with the given context. - // result is an outparam owned by the caller and must not be null. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g450687e75f3ff992fe01662a43d9d3db - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#cuda-driver-data-types - static absl::Status InitEvent(Context* context, GpuEventHandle* result, - EventFlags flags); - - // Destroys *event and turns it into a nullptr. event may not be null, but - // *event may be, via cuEventDestroy/hipEventDestroy - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g593ec73a8ec5a5fc031311d3e4dca1ef - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#event-management - static absl::Status DestroyEvent(Context* context, GpuEventHandle* event); - // Allocates a GPU memory space of size bytes associated with the given // context via cuMemAlloc/hipMalloc. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gb82d2a09844a58dd9e744dc31e8aa467 diff --git a/xla/stream_executor/gpu/gpu_event.cc b/xla/stream_executor/gpu/gpu_event.cc deleted file mode 100644 index 952377c76ecfaa..00000000000000 --- a/xla/stream_executor/gpu/gpu_event.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/gpu/gpu_event.h" - -#include - -#include "absl/base/casts.h" -#include "absl/status/status.h" -#include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_types.h" - -namespace stream_executor { -namespace gpu { - -GpuEvent::GpuEvent(Context* context) : context_(context), gpu_event_(nullptr) {} - -GpuEvent::~GpuEvent() { Destroy().IgnoreError(); } - -absl::Status GpuEvent::Init(bool allow_timing) { - return GpuDriver::InitEvent(context_, &gpu_event_, - allow_timing - ? GpuDriver::EventFlags::kDefault - : GpuDriver::EventFlags::kDisableTiming); -} - -absl::Status GpuEvent::Destroy() { - return GpuDriver::DestroyEvent(context_, &gpu_event_); -} - -GpuEventHandle GpuEvent::gpu_event() { return gpu_event_; } - -} // namespace gpu -} // namespace stream_executor diff --git a/xla/stream_executor/gpu/gpu_event.h b/xla/stream_executor/gpu/gpu_event.h deleted file mode 100644 index c5fa4390ddc841..00000000000000 --- a/xla/stream_executor/gpu/gpu_event.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_EVENT_H_ -#define XLA_STREAM_EXECUTOR_GPU_GPU_EVENT_H_ - -#include - -#include "absl/status/status.h" -#include "xla/stream_executor/event.h" -#include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_types.h" - -namespace stream_executor { -namespace gpu { - -class GpuContext; - -// GpuEvent wraps a GpuEventHandle in the platform-independent Event interface. -class GpuEvent : public Event { - public: - explicit GpuEvent(Context* context); - - ~GpuEvent() override; - - // Populates the CUDA-platform-specific elements of this object. - absl::Status Init(bool allow_timing); - - // Deallocates any platform-specific elements of this object. This is broken - // out (not part of the destructor) to allow for error reporting. - absl::Status Destroy(); - - // The underlying CUDA event element. - GpuEventHandle gpu_event(); - - protected: - Context* context() const { return context_; } - - private: - // The Executor used to which this object and GpuEventHandle are bound. - Context* context_; - - // The underlying CUDA event element. - GpuEventHandle gpu_event_; -}; - -} // namespace gpu -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_GPU_GPU_EVENT_H_ diff --git a/xla/stream_executor/gpu/gpu_stream.cc b/xla/stream_executor/gpu/gpu_stream.cc index 82751bd5bce36b..21e45e87adceab 100644 --- a/xla/stream_executor/gpu/gpu_stream.cc +++ b/xla/stream_executor/gpu/gpu_stream.cc @@ -19,26 +19,20 @@ limitations under the License. #include #include #include -#include #include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" -#include "absl/strings/str_format.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_kernel.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" -#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/nvtx_utils.h" @@ -120,10 +114,6 @@ absl::Status GpuStream::DoHostCallbackWithStatus( } GpuStream::~GpuStream() { - BlockHostUntilDone().IgnoreError(); - parent()->DeallocateStream(this); - - completed_event_.reset(); GpuDriver::DestroyStream(parent_->gpu_context(), gpu_stream_); } diff --git a/xla/stream_executor/gpu/gpu_stream.h b/xla/stream_executor/gpu/gpu_stream.h index e7130fad590f05..78080f31c09905 100644 --- a/xla/stream_executor/gpu/gpu_stream.h +++ b/xla/stream_executor/gpu/gpu_stream.h @@ -31,7 +31,6 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/kernel.h" @@ -49,13 +48,10 @@ namespace gpu { // Thread-safe post-initialization. class GpuStream : public StreamCommon { public: - GpuStream(GpuExecutor* parent, std::unique_ptr completed_event, + GpuStream(GpuExecutor* parent, std::optional> priority, GpuStreamHandle gpu_stream) - : StreamCommon(parent), - parent_(parent), - gpu_stream_(gpu_stream), - completed_event_(std::move(completed_event)) { + : StreamCommon(parent), parent_(parent), gpu_stream_(gpu_stream) { if (priority.has_value()) { stream_priority_ = priority.value(); } @@ -69,11 +65,6 @@ class GpuStream : public StreamCommon { } PlatformSpecificHandle platform_specific_handle() const override; - // Retrieves an event which indicates that all work enqueued into the stream - // has completed. Ownership of the event is not transferred to the caller, the - // event is owned by this stream. - GpuEvent* completed_event() { return completed_event_.get(); } - // Returns the GpuStreamHandle value for passing to the CUDA API. // // Precond: this GpuStream has been allocated (otherwise passing a nullptr @@ -113,7 +104,6 @@ class GpuStream : public StreamCommon { GpuExecutor* parent_; // Executor that spawned this stream. GpuStreamHandle gpu_stream_; // Wrapped CUDA stream handle. std::variant stream_priority_; - std::unique_ptr completed_event_; }; // Helper functions to simplify extremely common flows. diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index dc1d39dad08f1f..619de91d9ac99d 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -177,12 +177,39 @@ cc_library( ":rocm_status", "//xla/stream_executor:event", "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:gpu_event_header", "//xla/stream_executor/gpu:scoped_activate_context", "@com_google_absl//absl/base", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:rocm_headers", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "rocm_event_test", + srcs = ["rocm_event_test.cc"], + backends = ["gpu"], + tags = ["rocm-only"] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":rocm_event", + ":rocm_executor", + ":rocm_platform_id", + "//xla/stream_executor:event", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "@com_google_googletest//:gtest_main", + "@local_config_rocm//rocm:rocm_headers", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", ], ) @@ -228,7 +255,6 @@ cc_library( "//xla/stream_executor/gpu:gpu_command_buffer", "//xla/stream_executor/gpu:gpu_diagnostics_header", "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_event", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_kernel_header", "//xla/stream_executor/gpu:gpu_stream", @@ -968,12 +994,12 @@ cc_library( ]), deps = [ ":rocm_driver_wrapper", + ":rocm_event", ":rocm_status", "//xla/stream_executor:event", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:gpu_event", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_stream", "//xla/stream_executor/gpu:scoped_activate_context", @@ -1000,10 +1026,10 @@ cc_library( ]), deps = [ ":rocm_driver_wrapper", + ":rocm_event", ":rocm_status", "//xla/stream_executor:event_based_timer", "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:gpu_event", "//xla/stream_executor/gpu:gpu_stream", "//xla/stream_executor/gpu:scoped_activate_context", "@com_google_absl//absl/log", diff --git a/xla/stream_executor/rocm/rocm_driver.cc b/xla/stream_executor/rocm/rocm_driver.cc index 4269dbbed331f7..af09f6f756df69 100644 --- a/xla/stream_executor/rocm/rocm_driver.cc +++ b/xla/stream_executor/rocm/rocm_driver.cc @@ -785,30 +785,6 @@ void GpuDriver::HostDeallocate(Context* context, void* location) { } } -absl::Status GpuDriver::DestroyEvent(Context* context, GpuEventHandle* event) { - if (*event == nullptr) { - return absl::InvalidArgumentError("input event cannot be null"); - } - - ScopedActivateContext activated{context}; - hipError_t res = wrap::hipEventDestroy(*event); - *event = nullptr; - - switch (res) { - case hipSuccess: - return absl::OkStatus(); - case hipErrorDeinitialized: - case hipErrorNotInitialized: - return absl::FailedPreconditionError( - absl::StrFormat("error destroying ROCM event in device %d: %s", - context->device_ordinal(), ToString(res).c_str())); - default: - return absl::InternalError( - absl::StrFormat("error destroying ROCM event in device %d: %s", - context->device_ordinal(), ToString(res).c_str())); - } -} - absl::Status GpuDriver::SynchronizeStream(Context* context, GpuStreamHandle stream) { ScopedActivateContext activated{context}; @@ -909,34 +885,6 @@ absl::Status GpuDriver::AsynchronousMemcpyD2D(Context* context, return absl::OkStatus(); } -absl::Status GpuDriver::InitEvent(Context* context, GpuEventHandle* event, - EventFlags flags) { - int hipflags; - switch (flags) { - case EventFlags::kDefault: - hipflags = hipEventDefault; - break; - case EventFlags::kDisableTiming: - hipflags = hipEventDisableTiming | hipEventReleaseToSystem; - break; - default: - LOG(FATAL) << "impossible event flags: " << int(hipflags); - } - - ScopedActivateContext activated{context}; - hipError_t res = wrap::hipEventCreateWithFlags(event, hipflags); - - if (res == hipSuccess) { - return absl::OkStatus(); - } else if (res == hipErrorMemoryAllocation) { - return absl::ResourceExhaustedError( - "could not create ROCM event: out of device memory"); - } else { - return absl::FailedPreconditionError( - absl::StrCat("could not create ROCM event: ", ToString(res))); - } -} - int GpuDriver::GetDeviceCount() { int device_count = 0; hipError_t res = wrap::hipGetDeviceCount(&device_count); diff --git a/xla/stream_executor/rocm/rocm_event.cc b/xla/stream_executor/rocm/rocm_event.cc index ec1212b0b1350f..7c20568f59cb1e 100644 --- a/xla/stream_executor/rocm/rocm_event.cc +++ b/xla/stream_executor/rocm/rocm_event.cc @@ -18,7 +18,11 @@ limitations under the License. #include #include "absl/base/casts.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "rocm/include/hip/hip_runtime.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/context.h" @@ -26,6 +30,7 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_status.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { @@ -38,11 +43,56 @@ absl::Status WaitStreamOnEvent(Context* context, hipStream_t stream, "could not wait stream on event")); return absl::OkStatus(); } + +enum class EventFlags { kDefault, kDisableTiming }; +absl::StatusOr InitEvent(Context* context, EventFlags flags) { + int hipflags; + switch (flags) { + case EventFlags::kDefault: + hipflags = hipEventDefault; + break; + case EventFlags::kDisableTiming: + hipflags = hipEventDisableTiming | hipEventReleaseToSystem; + break; + default: + LOG(FATAL) << "impossible event flags: " << int(hipflags); + } + + ScopedActivateContext activated{context}; + hipEvent_t event; + hipError_t res = wrap::hipEventCreateWithFlags(&event, hipflags); + + if (res == hipSuccess) { + return event; + } + if (res == hipErrorMemoryAllocation) { + return absl::ResourceExhaustedError( + "could not create ROCM event: out of device memory"); + } + return absl::FailedPreconditionError( + absl::StrCat("could not create ROCM event: ", ToString(res))); +} + +void DestroyEvent(Context* context, hipEvent_t event) { + if (event == nullptr) { + return; + } + + ScopedActivateContext activated{context}; + hipError_t res = wrap::hipEventDestroy(event); + + if (res != hipSuccess) { + LOG(ERROR) << absl::StrFormat( + "error destroying ROCM event in device %d: %s", + context->device_ordinal(), ToString(res)); + } +} + } // namespace Event::Status RocmEvent::PollForStatus() { - ScopedActivateContext activated(context()); - hipError_t res = wrap::hipEventQuery(gpu_event()); + ScopedActivateContext activated(context_); + hipError_t res = wrap::hipEventQuery(handle_); if (res == hipSuccess) { return Event::Status::kComplete; @@ -54,9 +104,40 @@ Event::Status RocmEvent::PollForStatus() { } absl::Status RocmEvent::WaitForEventOnExternalStream(std::intptr_t stream) { - return WaitStreamOnEvent(context(), absl::bit_cast(stream), - gpu_event()); + return WaitStreamOnEvent(context_, absl::bit_cast(stream), + handle_); +} + +absl::StatusOr RocmEvent::Create(Context* context, + bool allow_timing) { + TF_ASSIGN_OR_RETURN( + hipEvent_t event_handle, + InitEvent(context, allow_timing ? EventFlags::kDefault + : EventFlags::kDisableTiming)); + + return RocmEvent(context, event_handle); +} + +RocmEvent::~RocmEvent() { DestroyEvent(context_, handle_); } + +RocmEvent::RocmEvent(RocmEvent&& other) + : context_(other.context_), handle_(other.handle_) { + other.context_ = nullptr; + other.handle_ = nullptr; } +RocmEvent& RocmEvent::operator=(RocmEvent&& other) { + if (this == &other) { + return *this; + } + + DestroyEvent(context_, handle_); + + context_ = other.context_; + handle_ = other.handle_; + other.context_ = nullptr; + other.handle_ = nullptr; + return *this; +} } // namespace gpu } // namespace stream_executor diff --git a/xla/stream_executor/rocm/rocm_event.h b/xla/stream_executor/rocm/rocm_event.h index 0fc5d4284069f6..dfd9091a669956 100644 --- a/xla/stream_executor/rocm/rocm_event.h +++ b/xla/stream_executor/rocm/rocm_event.h @@ -19,20 +19,40 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "rocm/include/hip/hip_runtime.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_event.h" namespace stream_executor::gpu { // This class implements Event::PollForStatus for ROCm devices. -class RocmEvent : public GpuEvent { +class RocmEvent : public Event { public: - explicit RocmEvent(Context *context) : GpuEvent(context) {} - Event::Status PollForStatus() override; - absl::Status WaitForEventOnExternalStream(std::intptr_t stream) override; + + // Creates a new RocmEvent. If allow_timing is false, the event will not + // support timing, which is cheaper to create. + static absl::StatusOr Create(Context* context, bool allow_timing); + + hipEvent_t GetHandle() const { return handle_; } + + ~RocmEvent() override; + RocmEvent(const RocmEvent&) = delete; + RocmEvent& operator=(const RocmEvent&) = delete; + RocmEvent(RocmEvent&& other); + RocmEvent& operator=(RocmEvent&& other); + + private: + explicit RocmEvent(Context* context, hipEvent_t handle) + : context_(context), handle_(handle) {} + + // The Context used to which this object and GpuEventHandle are bound. + Context* context_; + + // The underlying CUDA event handle. + hipEvent_t handle_; }; } // namespace stream_executor::gpu diff --git a/xla/stream_executor/rocm/rocm_event_test.cc b/xla/stream_executor/rocm/rocm_event_test.cc new file mode 100644 index 00000000000000..33550fe7543c97 --- /dev/null +++ b/xla/stream_executor/rocm/rocm_event_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_event.h" + +#include + +#include +#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/rocm/rocm_executor.h" +#include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using ::tsl::testing::IsOk; + +TEST(RocmEventTest, CreateEvent) { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + stream_executor::PlatformManager::PlatformWithId( + stream_executor::rocm::kROCmPlatformId)); + RocmExecutor executor{platform, 0}; + ASSERT_THAT(executor.Init(), IsOk()); + + TF_ASSERT_OK_AND_ASSIGN(RocmEvent event, + RocmEvent::Create(executor.gpu_context(), false)); + + EXPECT_NE(event.GetHandle(), nullptr); + EXPECT_EQ(event.PollForStatus(), Event::Status::kComplete); + + hipEvent_t handle = event.GetHandle(); + RocmEvent event2 = std::move(event); + EXPECT_EQ(event2.GetHandle(), handle); +} + +} // namespace + +} // namespace stream_executor::gpu diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 2bbf1a1f20ca93..14af9afa502cc7 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include "xla/stream_executor/rocm/rocm_executor.h" #include @@ -637,11 +638,12 @@ absl::Status FillBlockDimLimit(GpuDeviceHandle device, return absl::OkStatus(); } -absl::StatusOr> RocmExecutor::CreateGpuEvent( +absl::StatusOr> RocmExecutor::CreateGpuEvent( bool allow_timing) { - auto gpu_event = std::make_unique(gpu_context()); - TF_RETURN_IF_ERROR(gpu_event->Init(allow_timing)); - return std::move(gpu_event); + TF_ASSIGN_OR_RETURN( + auto event, + RocmEvent::Create(gpu_context(), /*allow_timing=*/allow_timing)); + return std::make_unique(std::move(event)); } absl::StatusOr> RocmExecutor::CreateEvent() { @@ -650,7 +652,8 @@ absl::StatusOr> RocmExecutor::CreateEvent() { absl::StatusOr> RocmExecutor::CreateStream( std::optional> priority) { - TF_ASSIGN_OR_RETURN(auto event, CreateGpuEvent(/*allow_timing=*/false)); + TF_ASSIGN_OR_RETURN(auto event, + RocmEvent::Create(gpu_context(), /*allow_timing=*/false)); TF_ASSIGN_OR_RETURN(auto stream, RocmStream::Create(this, std::move(event), priority)); absl::MutexLock l(&alive_gpu_streams_mu_); diff --git a/xla/stream_executor/rocm/rocm_executor.h b/xla/stream_executor/rocm/rocm_executor.h index 7e635fb8dcf1b3..2b6cc2ca0f41db 100644 --- a/xla/stream_executor/rocm/rocm_executor.h +++ b/xla/stream_executor/rocm/rocm_executor.h @@ -42,7 +42,6 @@ limitations under the License. #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_kernel.h" #include "xla/stream_executor/gpu/gpu_types.h" @@ -53,6 +52,7 @@ limitations under the License. #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_context.h" +#include "xla/stream_executor/rocm/rocm_event.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -158,7 +158,7 @@ class RocmExecutor : public GpuExecutor { ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); // Creates a GpuEvent for the given stream. - absl::StatusOr> CreateGpuEvent(bool allow_timing); + absl::StatusOr> CreateGpuEvent(bool allow_timing); // Guards the on-disk-module mapping. absl::Mutex disk_modules_mu_; diff --git a/xla/stream_executor/rocm/rocm_stream.cc b/xla/stream_executor/rocm/rocm_stream.cc index 6473024f39b709..24bf5352f5fa4f 100644 --- a/xla/stream_executor/rocm/rocm_stream.cc +++ b/xla/stream_executor/rocm/rocm_stream.cc @@ -27,11 +27,11 @@ limitations under the License. #include "rocm/include/hip/hip_runtime.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/scoped_activate_context.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_driver_wrapper.h" +#include "xla/stream_executor/rocm/rocm_event.h" #include "xla/stream_executor/rocm/rocm_status.h" #include "xla/stream_executor/stream.h" #include "tsl/platform/errors.h" @@ -105,7 +105,7 @@ absl::Status WaitStreamOnEvent(Context* context, hipStream_t stream, } // namespace absl::StatusOr> RocmStream::Create( - GpuExecutor* executor, std::unique_ptr completed_event, + GpuExecutor* executor, RocmEvent completed_event, std::optional> priority) { int stream_priority = [&]() { if (priority.has_value() && std::holds_alternative(priority.value())) { @@ -125,21 +125,30 @@ absl::StatusOr> RocmStream::Create( absl::Status RocmStream::WaitFor(Stream* other) { RocmStream* other_stream = static_cast(other); - GpuEvent* other_completed_event = other_stream->completed_event(); - TF_RETURN_IF_ERROR(other_stream->RecordEvent(other_completed_event)); + TF_RETURN_IF_ERROR(other_stream->RecordCompletedEvent()); return WaitStreamOnEvent(executor_->gpu_context(), gpu_stream(), - other_completed_event->gpu_event()); + other_stream->completed_event_.GetHandle()); } absl::Status RocmStream::RecordEvent(Event* event) { return stream_executor::gpu::RecordEvent( - executor_->gpu_context(), static_cast(event)->gpu_event(), + executor_->gpu_context(), static_cast(event)->GetHandle(), gpu_stream()); } absl::Status RocmStream::WaitFor(Event* event) { return WaitStreamOnEvent(executor_->gpu_context(), gpu_stream(), - static_cast(event)->gpu_event()); + static_cast(event)->GetHandle()); } + +absl::Status RocmStream::RecordCompletedEvent() { + return RecordEvent(&completed_event_); +} + +RocmStream::~RocmStream() { + BlockHostUntilDone().IgnoreError(); + executor_->DeallocateStream(this); +} + } // namespace stream_executor::gpu diff --git a/xla/stream_executor/rocm/rocm_stream.h b/xla/stream_executor/rocm/rocm_stream.h index 174aa29231fbdc..431c57f2ab817d 100644 --- a/xla/stream_executor/rocm/rocm_stream.h +++ b/xla/stream_executor/rocm/rocm_stream.h @@ -23,11 +23,12 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "rocm/include/hip/hip_runtime.h" #include "xla/stream_executor/event.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/rocm/rocm_event.h" #include "xla/stream_executor/stream.h" namespace stream_executor { @@ -40,18 +41,23 @@ class RocmStream : public GpuStream { absl::Status WaitFor(Event* event) override; static absl::StatusOr> Create( - GpuExecutor* executor, std::unique_ptr completed_event, + GpuExecutor* executor, RocmEvent completed_event, std::optional> priority); + ~RocmStream() override; + private: - RocmStream(GpuExecutor* executor, std::unique_ptr completed_event, + RocmStream(GpuExecutor* executor, RocmEvent completed_event, std::optional> priority, hipStream_t stream_handle) - : GpuStream(executor, std::move(completed_event), priority, - stream_handle), - executor_(executor) {} + : GpuStream(executor, priority, stream_handle), + executor_(executor), + completed_event_(std::move(completed_event)) {} + + absl::Status RecordCompletedEvent(); GpuExecutor* executor_; + RocmEvent completed_event_; }; } // namespace gpu diff --git a/xla/stream_executor/rocm/rocm_timer.cc b/xla/stream_executor/rocm/rocm_timer.cc index 2aa605aaccd5dd..18a5077b82b5ae 100644 --- a/xla/stream_executor/rocm/rocm_timer.cc +++ b/xla/stream_executor/rocm/rocm_timer.cc @@ -24,10 +24,10 @@ limitations under the License. #include "absl/time/time.h" #include "rocm/include/hip/hip_runtime.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/scoped_activate_context.h" #include "xla/stream_executor/rocm/rocm_driver_wrapper.h" +#include "xla/stream_executor/rocm/rocm_event.h" #include "xla/stream_executor/rocm/rocm_status.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -53,8 +53,8 @@ absl::StatusOr GetEventElapsedTime(Context* context, hipEvent_t start, } } // namespace -RocmTimer::RocmTimer(Context* context, std::unique_ptr start_event, - std::unique_ptr stop_event, GpuStream* stream) +RocmTimer::RocmTimer(Context* context, std::unique_ptr start_event, + std::unique_ptr stop_event, GpuStream* stream) : context_(context), stream_(stream), start_event_(std::move(start_event)), @@ -66,8 +66,8 @@ absl::StatusOr RocmTimer::GetElapsedDuration() { } TF_RETURN_IF_ERROR(stream_->RecordEvent(stop_event_.get())); TF_ASSIGN_OR_RETURN(float elapsed_milliseconds, - GetEventElapsedTime(context_, start_event_->gpu_event(), - stop_event_->gpu_event())); + GetEventElapsedTime(context_, start_event_->GetHandle(), + stop_event_->GetHandle())); is_stopped_ = true; return absl::Milliseconds(elapsed_milliseconds); } diff --git a/xla/stream_executor/rocm/rocm_timer.h b/xla/stream_executor/rocm/rocm_timer.h index 1cf25b3a034249..c9dae4f5c04019 100644 --- a/xla/stream_executor/rocm/rocm_timer.h +++ b/xla/stream_executor/rocm/rocm_timer.h @@ -22,15 +22,15 @@ limitations under the License. #include "absl/time/time.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/rocm/rocm_event.h" namespace stream_executor::gpu { class RocmTimer : public EventBasedTimer { public: - RocmTimer(Context* context, std::unique_ptr start_event, - std::unique_ptr stop_event, GpuStream* stream); + RocmTimer(Context* context, std::unique_ptr start_event, + std::unique_ptr stop_event, GpuStream* stream); absl::StatusOr GetElapsedDuration() override; @@ -38,8 +38,8 @@ class RocmTimer : public EventBasedTimer { bool is_stopped_ = false; Context* context_; GpuStream* stream_; - std::unique_ptr start_event_; - std::unique_ptr stop_event_; + std::unique_ptr start_event_; + std::unique_ptr stop_event_; }; } // namespace stream_executor::gpu