Adding default XLA/GPU env vars for all JAX-based containers #114
Conversation
Could someone please edit the PR description to document the rationale behind each flag?

@abhinavgoel95 what about the following?
@abhinavgoel95 How much speedup do we see from setting …?
From my reading of the NCCL code, the only safe way to enable LL128 is to not specify NCCL_PROTO at all.
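To illustrate that point, here is a minimal sketch (not part of this PR) of how a launcher could guard against an inherited override, so NCCL is free to pick LL128 itself where the hardware supports it:

```python
import os

# Per the comment above: leaving NCCL_PROTO unset lets NCCL choose among
# its protocols (LL, LL128, Simple) based on the platform, whereas an
# explicit NCCL_PROTO value can force LL128 onto hardware where it is
# unsafe. This sketch simply drops any inherited override before NCCL
# initializes.
os.environ.pop("NCCL_PROTO", None)
```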
We're also sometimes interested in setting …
Whether we want to set XLA_PYTHON_CLIENT_MEM_FRACTION depends on the workload (particularly on whether any other GPU libraries are being used outside of XLA). My opinion is that it would be OK to set in the paxml container, but we shouldn't set it in the general JAX container, because we expect users to extend that image in varied workflows.
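For context, a minimal sketch of how a downstream workflow could tune this itself (the 0.9 value is illustrative only, not a recommendation from this PR):

```python
import os

# JAX's XLA client preallocates a fraction of GPU memory at startup
# (75% by default). Raising or lowering XLA_PYTHON_CLIENT_MEM_FRACTION
# trades XLA throughput against headroom for non-XLA GPU libraries in
# the same process. The 0.9 below is illustrative only.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

import jax  # must come after the env var is set, before the backend initializes

print(jax.devices())
```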
As best I can tell, NCCL_AVOID_RECORD_STREAMS is a feature of PyTorch rather than NCCL. (It seems it has since been renamed to TORCH_NCCL_AVOID_RECORD_STREAMS.)
Did we figure out a source for the undocumented NCCL_AVOID_RECORD_STREAMS variable?
@yhtang @terrykong It is safe to drop NCCL_AVOID_RECORD_STREAMS. We do not need to upstream this; it is a PyTorch-specific feature.
I've removed mention of that env var. Can I re-request your review, @yhtang?
Addresses #105 and #109
- --xla_gpu_enable_latency_hiding_scheduler=true: Allows XLA:GPU to move communication collectives to increase overlap with compute kernels.
- --xla_gpu_enable_async_all_gather=true: Allows XLA:GPU to run All-Gather NCCL kernels on a separate CUDA stream to allow overlap with compute kernels.
- --xla_gpu_enable_async_reduce_scatter=true: Allows XLA:GPU to run Reduce-Scatter NCCL kernels on a separate CUDA stream to allow overlap with compute kernels.
- --xla_gpu_enable_triton_gemm=false: Disallows Triton GeMM kernels and uses cuBLAS GeMM kernels instead. cuBLAS kernels are currently better tuned for GPUs and thus provide better performance.
- CUDA_DEVICE_MAX_CONNECTIONS=1: Uses a single queue for GPU work, which lowers the latency of each stream operation. This is safe because XLA already orders kernel launches.
- NCCL_IB_SL=1: Defines the InfiniBand Service Level (1).

@nluehr @ashors1
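For reference, a sketch (not part of this PR) of how a JAX program run outside the container could apply the same defaults; the values mirror the list above, and XLA_FLAGS is the standard mechanism for passing --xla_* flags to JAX:

```python
import os

# Stand-in for the container defaults: in the shipped image these would
# be baked in as ENV vars, so none of this is needed there.
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
os.environ["NCCL_IB_SL"] = "1"
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_async_all_gather=true",
    "--xla_gpu_enable_async_reduce_scatter=true",
    "--xla_gpu_enable_triton_gemm=false",
])

import jax  # env vars must be set before JAX initializes its XLA backend
```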