diff --git a/.github/container/install-cudnn.sh b/.github/container/install-cudnn.sh index 5eac3aeec..fab18edea 100755 --- a/.github/container/install-cudnn.sh +++ b/.github/container/install-cudnn.sh @@ -5,6 +5,8 @@ set -ex export DEBIAN_FRONTEND=noninteractive export TZ=America/Los_Angeles +CUDNN_MAJOR_VERSION=9 + apt-get update # Extract major CUDA version from `nvcc --version` output line @@ -18,8 +20,8 @@ cuda_major_version=$(nvcc --version | sed -n 's/^.*release \([0-9]*\.[0-9]*\).*$ # version of CUDA and cuDNN are compatible. # For example, CUDA 12.3 + cuDNN 8.9.6 (libcudnn8 version: 8.9.6.50-1+cuda12.2) is # considered to be compatible. -libcudnn_version=$(apt-cache show libcudnn8 | sed -n "s/^Version: \(.*+cuda${cuda_major_version}\.[0-9]*\)$/\1/p" | head -n 1) -libcudnn_dev_version=$(apt-cache show libcudnn8-dev | sed -n "s/^Version: \(.*+cuda${cuda_major_version}\.[0-9]\)$/\1/p" | head -n 1) +libcudnn_version=$(apt-cache show libcudnn${CUDNN_MAJOR_VERSION} | sed -n "s/^Version: \(.*+cuda${cuda_major_version}\.[0-9]*\)$/\1/p" | head -n 1) +libcudnn_dev_version=$(apt-cache show libcudnn${CUDNN_MAJOR_VERSION}-dev | sed -n "s/^Version: \(.*+cuda${cuda_major_version}\.[0-9]\)$/\1/p" | head -n 1) if [[ -z "${libcudnn_version}" || -z "${libcudnn_dev_version}" ]]; then echo "Could not find compatible cuDNN version for CUDA ${cuda_version}" exit 1 @@ -27,8 +29,8 @@ fi apt-get update apt-get install -y \ - libcudnn8=${libcudnn_version} \ - libcudnn8-dev=${libcudnn_dev_version} + libcudnn${CUDNN_MAJOR_VERSION}=${libcudnn_version} \ + libcudnn${CUDNN_MAJOR_VERSION}-dev=${libcudnn_dev_version} apt-get clean rm -rf /var/lib/apt/lists/*