diff --git a/.github/container/build-jax.sh b/.github/container/build-jax.sh index 6dbf5d58a..fd3da48ab 100755 --- a/.github/container/build-jax.sh +++ b/.github/container/build-jax.sh @@ -166,7 +166,7 @@ export TF_CUDNN_PATHS=/usr/lib/$(uname -p)-linux-gnu export TF_CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*.*.* | cut -d . -f 3-4) export TF_CUDA_MAJOR_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*.*.* | cut -d . -f 3) export TF_CUBLAS_VERSION=$(ls /usr/local/cuda/lib64/libcublas.so.*.*.* | cut -d . -f 3) -export TF_CUDNN_VERSION=$(echo "${NV_CUDNN_VERSION}" | cut -d . -f 1) +export TF_CUDNN_VERSION="9.3.0" #$(echo "${NV_CUDNN_VERSION}" | cut -d . -f 1) export TF_NCCL_VERSION=$(echo "${NCCL_VERSION}" | cut -d . -f 1) case "${CPU_ARCH}" in @@ -270,8 +270,6 @@ time python "${SRC_PATH_JAX}/build/build.py" \ --bazel_options=--override_repository=xla=$SRC_PATH_XLA \ --bazel_options=--repo_env=LOCAL_CUDNN_PATH="/opt/nvidia/cudnn" \ --bazel_options=--repo_env=LOCAL_NCCL_PATH="/usr/local/lib/python3.10/dist-packages/nvidia/nccl" \ - --bazel_options=--repo_env=HERMETIC_CUDA_VERSION=$TF_CUDA_VERSION \ - --bazel_options=--repo_env=HERMETIC_CUDNN_VERSION=$TF_CUDNN_VERSION \ --output_path=${BUILD_PATH_JAXLIB} \ $BUILD_PARAM