diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax
index 86d930459..b94a6a23f 100644
--- a/.github/container/Dockerfile.jax
+++ b/.github/container/Dockerfile.jax
@@ -75,7 +75,6 @@ ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_latency_hiding_scheduler=true"
 ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_triton_gemm=false"
 ENV CUDA_DEVICE_MAX_CONNECTIONS=1
 ENV NCCL_NVLS_ENABLE=0
-ENV CUDA_MODULE_LOADING=EAGER
 
 COPY --from=builder ${BUILD_PATH_JAXLIB} ${BUILD_PATH_JAXLIB}
 COPY --from=builder ${SRC_PATH_JAX} ${SRC_PATH_JAX}
diff --git a/README.md b/README.md
index 6da6861f0..11f9b4ada 100644
--- a/README.md
+++ b/README.md
@@ -306,7 +306,6 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb
 | -------------------- | ----- | ----------- |
 | `CUDA_DEVICE_MAX_CONNECTIONS` | `1` | use a single queue for GPU work to lower latency of stream operations; OK since XLA already orders launches |
 | `NCCL_NVLS_ENABLE` | `0` | Disables NVLink SHARP ([1](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. |
-| `CUDA_MODULE_LOADING` | `EAGER` | Disables lazy-loading ([1](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cuda-environment-variables)) which uses slightly more GPU memory. |
 
 There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33).
 
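Note on the change above: with `ENV CUDA_MODULE_LOADING=EAGER` removed, the container inherits the CUDA toolkit default, which is lazy module loading on recent CUDA releases. A minimal sketch of how a user could restore the previous eager behavior for a single run, assuming a standard `docker run` invocation (the image tag and script name here are illustrative):

    docker run --gpus all -e CUDA_MODULE_LOADING=EAGER ghcr.io/nvidia/jax:latest python my_script.py

Eager loading trades the slightly higher GPU memory use noted in the removed README row for paying module-load cost up front instead of at first kernel launch.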