Skip to content

Commit

Permalink
Use cached driver version from ComputeCapability in nvptx_compiler instead of re-fetching via GpuDriver::GetDriverVersion.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 689014301
  • Loading branch information
klucke authored and Google-ML-Automation committed Oct 23, 2024
1 parent 4285d97 commit 8476142
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 20 deletions.
7 changes: 5 additions & 2 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2223,9 +2223,12 @@ GpuCompiler::CompileToBackendResult(

// Test whether LinkModules is supported.
bool can_use_link_modules = (executor != nullptr);
se::GpuComputeCapability gpu_compute_capability =
gpu_device_info.gpu_compute_capability();
if (can_use_link_modules) {
TF_ASSIGN_OR_RETURN(can_use_link_modules,
CanUseLinkModules(module->config()));
TF_ASSIGN_OR_RETURN(
can_use_link_modules,
CanUseLinkModules(module->config(), gpu_compute_capability));
}
const bool split_modules =
can_use_link_modules &&
Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/gpu_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ class GpuCompiler : public LLVMCompiler {
}

virtual absl::StatusOr<bool> CanUseLinkModules(
const HloModuleConfig& config) {
const HloModuleConfig& config,
se::GpuComputeCapability& gpu_compute_capability) {
return false;
}

Expand Down
6 changes: 5 additions & 1 deletion xla/service/gpu/gpu_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -785,9 +785,13 @@ class KernelCacheTest : public HloTestBase {
CHECK(tsl::Env::Default()->LocalTempFilename(&cache_file_name_));
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
se::GpuComputeCapability cc = backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability();
TF_ASSERT_OK_AND_ASSIGN(bool can_use_link_modules,
dynamic_cast<GpuCompiler*>(backend().compiler())
->CanUseLinkModules(config));
->CanUseLinkModules(config, cc));
if (!can_use_link_modules) {
GTEST_SKIP() << "Caching compiled kernels requires support of linking.";
}
Expand Down
40 changes: 26 additions & 14 deletions xla/service/gpu/nvptx_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include <string>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>

#include "absl/algorithm/container.h"
Expand Down Expand Up @@ -611,8 +612,10 @@ NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
RecordLlvmPassesAndLlvmToPtxDuration(end_usecs - start_usecs);
}

TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method,
ChooseLinkingMethod(module_config.debug_options()));
TF_ASSIGN_OR_RETURN(
se::PtxLinkingMethod linking_method,
ChooseLinkingMethod(module_config.debug_options(),
std::get<se::CudaComputeCapability>(gpu_version)));

if (linking_method == se::PtxLinkingMethod::kNvJitLink && relocatable) {
VLOG(2) << "Deferring the PTX to CUBIN compilation of the relocatable "
Expand Down Expand Up @@ -897,7 +900,8 @@ static absl::StatusOr<stream_executor::SemanticVersion> GetAsmCompilerVersion(
}

absl::StatusOr<se::PtxLinkingMethod> NVPTXCompiler::ChooseLinkingMethod(
const DebugOptions& debug_options) {
const DebugOptions& debug_options,
se::CudaComputeCapability& compute_capability) {
se::GpuAsmOpts ptxas_config = PtxOptsFromDebugOptions(debug_options);
std::string& preferred_cuda_dir = ptxas_config.preferred_cuda_dir;

Expand All @@ -919,17 +923,17 @@ absl::StatusOr<se::PtxLinkingMethod> NVPTXCompiler::ChooseLinkingMethod(

int ptxas_version =
asm_compiler_version.major() * 1000 + asm_compiler_version.minor() * 10;
TF_ASSIGN_OR_RETURN(int driver_version,
se::gpu::GpuDriver::GetDriverVersion());
int driver_version =
compute_capability.major * 1000 + compute_capability.minor * 10;

if (driver_version >= ptxas_version) {
return LinkingMethod::kDriver;
}

LOG_FIRST_N(WARNING, 1)
<< "The NVIDIA driver's CUDA version is "
<< absl::StrFormat("%d.%d", driver_version / 1000,
(driver_version % 1000) / 10)
<< absl::StrFormat("%d.%d", compute_capability.major,
compute_capability.minor)
<< " which is older than the PTX compiler version "
<< asm_compiler_version
<< ". Because the driver is older than the PTX compiler version, XLA is "
Expand All @@ -941,12 +945,20 @@ absl::StatusOr<se::PtxLinkingMethod> NVPTXCompiler::ChooseLinkingMethod(
}

absl::StatusOr<bool> NVPTXCompiler::CanUseLinkModules(
const HloModuleConfig& hlo_module_config) {
// TODO(phawkins): rather than comparing version numbers, it might be more
// robust if we simply tried to link something the first time we compile.
TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method,
ChooseLinkingMethod(hlo_module_config.debug_options()));
return linking_method != se::PtxLinkingMethod::kNone;
const HloModuleConfig& hlo_module_config,
se::GpuComputeCapability& gpu_compute_capability) {
if (std::holds_alternative<se::CudaComputeCapability>(
gpu_compute_capability)) {
// TODO(phawkins): rather than comparing version numbers, it might be more
// robust if we simply tried to link something the first time we compile.
TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method,
ChooseLinkingMethod(hlo_module_config.debug_options(),
std::get<se::CudaComputeCapability>(
gpu_compute_capability)));
return linking_method != se::PtxLinkingMethod::kNone;
}

return false;
}

absl::StatusOr<std::vector<uint8_t>> NVPTXCompiler::LinkModules(
Expand All @@ -959,7 +971,7 @@ absl::StatusOr<std::vector<uint8_t>> NVPTXCompiler::LinkModules(
std::get<stream_executor::CudaComputeCapability>(compute_capability);

TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method,
ChooseLinkingMethod(debug_options));
ChooseLinkingMethod(debug_options, cc));
VLOG(1) << "Linking " << modules.size()
<< " modules with linking method: " << linking_method;

Expand Down
6 changes: 4 additions & 2 deletions xla/service/gpu/nvptx_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ class NVPTXCompiler : public GpuCompiler {
const HloModule* debug_module, const CompileOptions& options) override;

absl::StatusOr<bool> CanUseLinkModules(
const HloModuleConfig& module_config) override;
const HloModuleConfig& module_config,
se::GpuComputeCapability& gpu_compute_capability) override;

private:
absl::StatusOr<std::vector<uint8_t>> LinkModules(
Expand All @@ -105,7 +106,8 @@ class NVPTXCompiler : public GpuCompiler {
const DebugOptions& debug_options) override;

absl::StatusOr<stream_executor::PtxLinkingMethod> ChooseLinkingMethod(
const DebugOptions& debug_options);
const DebugOptions& debug_options,
se::CudaComputeCapability& compute_capability);

// Tries to compile the given ptx string to cubin. Returns a vector with the
// compiled cubin if compilation succeeded.
Expand Down

0 comments on commit 8476142

Please sign in to comment.