Skip to content

Commit

Permalink
Split RocmContext into its own target.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684523108
  • Loading branch information
klucke authored and Google-ML-Automation committed Oct 10, 2024
1 parent 689662f commit bca2c6a
Show file tree
Hide file tree
Showing 7 changed files with 370 additions and 308 deletions.
16 changes: 0 additions & 16 deletions xla/stream_executor/gpu/gpu_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,22 +152,6 @@ class GpuDriver {
static absl::Status GetDeviceName(GpuDeviceHandle device,
std::string* device_name);

// Given a device to create a context for, returns a context handle into the
// context outparam, which must not be null.
//
// N.B. CUDA contexts are weird. They are implicitly associated with the
// calling thread. Current documentation on contexts and their influence on
// userspace processes is given here:
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g65dc0012348bc84810e2103a40d8e2cf
static absl::Status CreateContext(int device_ordinal, GpuDeviceHandle device,
Context** context);

// Destroys the provided context via cuCtxDestroy.
// Don't do this while clients could still be using the context, per the docs
// bad things will happen.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g27a365aebb0eb548166309f58a1e8b8e
static void DestroyContext(Context* context);

// Launches a CUDA/ROCm kernel via cuLaunchKernel/hipModuleLaunchKernel.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#execution-control
Expand Down
80 changes: 57 additions & 23 deletions xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,56 @@ cc_library(
)

cc_library(
name = "rocm_driver",
srcs = ["rocm_driver.cc"],
hdrs = [
"rocm_driver.h",
"rocm_driver_wrapper.h",
name = "rocm_context",
srcs = ["rocm_context.cc"],
hdrs = ["rocm_context.h"],
tags = [
"gpu",
"rocm-only",
] + if_google([
# TODO(b/360374983): Remove this tag once the target can be built without --config=rocm.
"manual",
]),
deps = [
":rocm_driver_wrapper",
":rocm_status",
"//xla/stream_executor:device_description",
"//xla/stream_executor/gpu:context",
"//xla/stream_executor/gpu:context_map",
"//xla/stream_executor/gpu:scoped_activate_context",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:status",
],
)

cc_library(
name = "rocm_driver_wrapper",
hdrs = ["rocm_driver_wrapper.h"],
defines = {"__HIP_DISABLE_CPP_FUNCTIONS__": "1"},
tags = [
"gpu",
"rocm-only",
] + if_google([
# TODO(klucke): Remove this tag once the target can be built without --config=rocm.
"manual",
]),
deps = [
"@local_config_rocm//rocm:hip", # buildcleaner: keep
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:dso_loader",
"@tsl//tsl/platform:env",
],
)

cc_library(
name = "rocm_driver",
srcs = ["rocm_driver.cc"],
tags = [
"gpu",
"rocm-only",
Expand All @@ -68,6 +111,8 @@ cc_library(
"manual",
]),
deps = [
":rocm_context",
":rocm_driver_wrapper",
":rocm_status",
"//xla/stream_executor",
"//xla/stream_executor/gpu:context",
Expand All @@ -78,17 +123,13 @@ cc_library(
"//xla/stream_executor/gpu:scoped_activate_context",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@local_config_rocm//rocm:hip", # buildcleaner: keep
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:dso_loader",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
Expand All @@ -110,21 +151,13 @@ cc_library(
"manual",
]),
deps = [
":rocm_driver",
":rocm_driver_wrapper",
":rocm_status",
"//xla/stream_executor",
"//xla/stream_executor/gpu:context",
"//xla/stream_executor/gpu:gpu_driver_header",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
)

Expand All @@ -140,7 +173,7 @@ cc_library(
"manual",
]),
deps = [
":rocm_driver",
":rocm_driver_wrapper",
":rocm_status",
"//xla/stream_executor:event",
"//xla/stream_executor/gpu:context",
Expand All @@ -165,8 +198,10 @@ cc_library(
"manual",
]),
deps = [
":rocm_context",
":rocm_diagnostics",
":rocm_driver",
":rocm_driver_wrapper",
":rocm_event",
":rocm_kernel",
":rocm_platform_id",
Expand Down Expand Up @@ -875,14 +910,14 @@ cc_library(
"manual",
]),
deps = [
":rocm_platform_id",
":rocm_rpath",
"//xla/stream_executor",
"//xla/stream_executor:dnn",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:scratch_allocator",
"//xla/stream_executor/cuda:cuda_platform_id",
"//xla/stream_executor/host:host_platform_id",
"//xla/stream_executor/rocm:rocm_platform_id",
] + if_static(
[":all_runtime"],
),
Expand Down Expand Up @@ -932,7 +967,7 @@ cc_library(
"manual",
]),
deps = [
":rocm_driver",
":rocm_driver_wrapper",
":rocm_status",
"//xla/stream_executor:event",
"//xla/stream_executor:platform",
Expand Down Expand Up @@ -964,8 +999,7 @@ cc_library(
"manual",
]),
deps = [
":rocm_driver",
":rocm_event",
":rocm_driver_wrapper",
":rocm_status",
"//xla/stream_executor:event_based_timer",
"//xla/stream_executor/gpu:context",
Expand Down
Loading

0 comments on commit bca2c6a

Please sign in to comment.