Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding default XLA/GPU env vars for all JAX-based containers #114

Merged
merged 9 commits into from
Aug 18, 2023
6 changes: 5 additions & 1 deletion .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ ARG SRC_PATH_JAX
ARG SRC_PATH_XLA
ARG BUILD_DATE
ENV BUILD_DATE=${BUILD_DATE}
# The following environment variables tune performance
ENV XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false"
nluehr marked this conversation as resolved.
Show resolved Hide resolved
ENV CUDA_DEVICE_MAX_CONNECTIONS=1
ENV NCCL_IB_SL=1

COPY --from=jax-builder ${SRC_PATH_JAX}-no-git ${SRC_PATH_JAX}
COPY --from=jax-builder ${SRC_PATH_XLA}-no-git ${SRC_PATH_XLA}
Expand All @@ -73,4 +77,4 @@ ARG SRC_PATH_XLA
ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/

COPY --from=jax-builder ${SRC_PATH_JAX}/.git ${SRC_PATH_JAX}/.git
COPY --from=jax-builder ${SRC_PATH_XLA}/.git ${SRC_PATH_XLA}/.git
COPY --from=jax-builder ${SRC_PATH_XLA}/.git ${SRC_PATH_XLA}/.git
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,19 @@ We currently enable training and evaluation for the following models:
| [t5(t5x)](./rosetta/rosetta/projects/t5x) | ✔️ | ✔️ | ✔️ |

We will update this table as new models become available, so stay tuned.

## Environment Variables

The [JAX image](ghcr.io/nvidia/jax) is embedded with the following flags and environment variables for performance tuning:

| XLA Flags | Value | Explanation |
| --------- | ----- | ----------- |
| `--xla_gpu_enable_latency_hiding_scheduler` | `true` | allows XLA to move communication collectives to increase overlap with compute kernels |
| `--xla_gpu_enable_async_all_gather` | `true` | allows XLA to run NCCL [AllGather](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html#allgather) kernels on a separate CUDA stream to allow overlap with compute kernels |
| `--xla_gpu_enable_async_reduce_scatter` | `true` | allows XLA to run NCCL [ReduceScatter](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html#reducescatter) kernels on a separate CUDA stream to allow overlap with compute kernels |
| `--xla_gpu_enable_triton_gemm` | `false` | use cuBLAS instead of Trition GeMM kernels |

| Environment Variable | Value | Explanation |
| -------------------- | ----- | ----------- |
| `CUDA_DEVICE_MAX_CONNECTIONS` | `1` | use a single queue for GPU work to lower latency of stream operations; OK since XLA already orders launches |
| `NCCL_IB_SL` | `1` | defines the InfiniBand Service Level ([1](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-ib-sl)) |
Loading