Adding default XLA/GPU env vars for all JAX-based containers #114

Merged
9 commits merged into main from jax-env-vars on Aug 18, 2023

Conversation

terrykong
Contributor

@terrykong terrykong commented Jul 11, 2023

Addresses #105 and #109. The following flags and environment variables are added as defaults; a sketch of how they might appear in the Dockerfile follows the list.

  • --xla_gpu_enable_latency_hiding_scheduler=true: Allows XLA:GPU to move communication collectives to increase overlap with compute kernels.
  • --xla_gpu_enable_async_all_gather=true: Allows XLA:GPU to run All Gather NCCL kernels on a separate CUDA stream to allow overlap with compute kernels.
  • --xla_gpu_enable_async_reduce_scatter=true: Allows XLA:GPU to run Reduce Scatter NCCL kernels on a separate CUDA stream to allow overlap with compute kernels.
  • --xla_gpu_enable_triton_gemm=false: Disallows Triton GeMM kernels; uses cuBLAS GeMM kernels instead. cuBLAS kernels are currently better tuned for GPUs and thus provide better performance.
  • CUDA_DEVICE_MAX_CONNECTIONS=1: Uses a single hardware work queue for GPU work, which lowers the latency of each stream operation. This is safe because XLA already orders its kernel launches.
  • NCCL_IB_SL=1: Sets the InfiniBand Service Level used by NCCL traffic to 1.
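
For reference, a rough sketch of how these defaults might be expressed in the container Dockerfile (the exact ENV layout in the actual commits may differ; the --xla_gpu_* options are consumed by XLA via the XLA_FLAGS environment variable):

```dockerfile
# Sketch only; the ENV statements in the merged Dockerfile may be laid out differently.
# XLA:GPU picks up the --xla_gpu_* options from the XLA_FLAGS environment variable.
ENV XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false"

# Single hardware work queue per GPU; safe because XLA already orders its launches.
ENV CUDA_DEVICE_MAX_CONNECTIONS=1

# InfiniBand Service Level used by NCCL traffic.
ENV NCCL_IB_SL=1
```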

@nluehr @ashors1

@terrykong terrykong changed the title from "Adding default env vars for all JAX-based containers that should improve perf" to "Adding default XLA/GPU env vars for all JAX-based containers" on Jul 11, 2023
@terrykong terrykong requested review from yhtang and nluehr July 11, 2023 17:01
@yhtang
Collaborator

yhtang commented Jul 17, 2023

Could someone please edit the PR description to document the rationale behind each flag?

@terrykong
Contributor Author

Could someone please edit the PR description to document the rationale behind each flag?

@nluehr @abhinavgoel95 ?

@abhinavgoel95
Contributor

  1. --xla_gpu_enable_latency_hiding_scheduler=true: Allows XLA:GPU to move communication collectives to increase overlap with compute kernels.
  2. --xla_gpu_enable_async_all_gather=true: Allows XLA:GPU to run All Gather NCCL kernels on a separate CUDA stream to allow overlap with compute kernels.
  3. --xla_gpu_enable_async_reduce_scatter=true: Allows XLA:GPU to run Reduce Scatter NCCL kernels on a separate CUDA stream to allow overlap with compute kernels.
  4. --xla_gpu_enable_triton_gemm=false: Disallows Triton GeMM kernels; uses cuBLAS GeMM kernels instead.

@nluehr
Contributor

nluehr commented Jul 17, 2023

@abhinavgoel95 what about the following?
NCCL_PROTO=LL128
NCCL_AVOID_RECORD_STREAMS=1
NCCL_IB_SL=1

@yhtang
Collaborator

yhtang commented Jul 18, 2023

@terrykong terrykong self-assigned this Jul 18, 2023
@nluehr
Contributor

nluehr commented Jul 19, 2023

@abhinavgoel95 How much speedup do we see from setting NCCL_PROTO=LL128 in practice?
The combiner threshold may come into play here, since for larger collectives I think LL128 could reduce effective bandwidth by ~6%.

@nluehr
Contributor

nluehr commented Jul 20, 2023

From my reading of the NCCL code, the only safe way to enable LL128 is to not specify NCCL_PROTO at all.
LL128 is valid only on Volta, Ampere, and Hopper with NVLink; PCIe links can cause packet re-ordering that results in silent corruption.
Since we expect users to run our models on a variety of clusters, I think we shouldn't define NCCL_PROTO in either the JAX dockerfile or the paxml run scripts. If removing it results in worse performance, we should open a bug with the NCCL team.

@ashors1
Contributor

ashors1 commented Jul 21, 2023

We're also sometimes interested in setting XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 to increase the fraction of GPU memory preallocated to XLA. Is this something we'd want to consider adding to the base container as well?
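
For illustration, a minimal sketch of what that setting would look like (where it is set, container vs. run script, and the exact value are workload-dependent):

```dockerfile
# Sketch only: preallocate 90% of GPU memory to XLA rather than the default fraction.
# Workload-dependent; not necessarily appropriate for the base JAX container.
ENV XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
```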

@nluehr
Contributor

nluehr commented Jul 21, 2023

Whether we want to set XLA_PYTHON_CLIENT_MEM_FRACTION depends on the workload (particularly whether any other GPU libraries are being used outside of XLA).

My opinion is that this would be OK to set in the paxml container, but we shouldn't set it in the general JAX container, since we expect users to extend that one for a variety of workflows.

@yhtang
Collaborator

yhtang commented Jul 25, 2023

NCCL_AVOID_RECORD_STREAMS is NOT documented anywhere. Shall we open an NVBug?

@nluehr
Contributor

nluehr commented Jul 26, 2023

As best I can tell, NCCL_AVOID_RECORD_STREAMS is a feature of PyTorch rather than NCCL. (It seems it has since been renamed to TORCH_NCCL_AVOID_RECORD_STREAMS here.)
So I think it's safe to drop it entirely.

Collaborator

@yhtang yhtang left a comment

Did we figure out a source for the undocumented NCCL_AVOID_RECORD_STREAMS variable?

@terrykong
Contributor Author

@abhinavgoel95

@abhinavgoel95
Contributor

@yhtang @terrykong It is safe to drop NCCL_AVOID_RECORD_STREAMS. We do not need to upstream this; it is a PyTorch-specific feature.

@terrykong
Contributor Author

I removed mention of that env var. Can I re-request your review, @yhtang?

@terrykong terrykong requested a review from yhtang August 17, 2023 20:14
@yhtang yhtang merged commit 11987e0 into main Aug 18, 2023
36 of 41 checks passed
@yhtang yhtang deleted the jax-env-vars branch August 18, 2023 21:29