Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dockerfile.jax: install JAX for TE build #997

Merged
merged 17 commits into from
Aug 19, 2024
31 changes: 20 additions & 11 deletions .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ ARG BUILD_DATE
## Build JAX
###############################################################################

FROM ${BASE_IMAGE} as builder
FROM ${BASE_IMAGE} AS builder
ARG URLREF_JAX
ARG URLREF_TRANSFORMER_ENGINE
ARG URLREF_XLA
ARG SRC_PATH_JAX
ARG SRC_PATH_TRANSFORMER_ENGINE
ARG SRC_PATH_XLA
ARG BAZEL_CACHE
ARG BUILD_PATH_JAXLIB
Expand Down Expand Up @@ -54,14 +56,23 @@ RUN build-jax.sh \
--xla-arm64-patch /opt/xla-arm64-neon.patch \
--clean

## Transformer engine: check out source and build wheel
RUN <<"EOF" bash -ex -o pipefail
pip install ninja && rm -rf ~/.cache/pip
# TransformerEngine now needs JAX at build time
git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
pushd ${SRC_PATH_TRANSFORMER_ENGINE}
python setup.py bdist_wheel && rm -rf build
ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist"
EOF

###############################################################################
## Pack jaxlib wheel and various source dirs into a pre-installation image
###############################################################################

ARG BASE_IMAGE
FROM ${BASE_IMAGE} as mealkit
FROM ${BASE_IMAGE} AS mealkit
ARG URLREF_FLAX
ARG URLREF_TRANSFORMER_ENGINE
ARG SRC_PATH_JAX
ARG SRC_PATH_XLA
ARG SRC_PATH_FLAX
Expand Down Expand Up @@ -102,20 +113,18 @@ git-clone.sh ${URLREF_FLAX} ${SRC_PATH_FLAX}
echo "-e file://${SRC_PATH_FLAX}" >> /opt/pip-tools.d/requirements-flax.in
EOF

## Transformer engine: check out source and build wheel
# Copy TransformerEngine wheel from the builder stage
ENV NVTE_FRAMEWORK=jax
ENV SRC_PATH_TRANSFORMER_ENGINE=${SRC_PATH_TRANSFORMER_ENGINE}
RUN <<"EOF" bash -ex -o pipefail
pip install ninja && rm -rf ~/.cache/pip
git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
pushd ${SRC_PATH_TRANSFORMER_ENGINE}
python setup.py bdist_wheel && rm -rf build
echo "transformer-engine @ file://$(ls ${SRC_PATH_TRANSFORMER_ENGINE}/dist/*.whl)" >> /opt/pip-tools.d/requirements-te.in
COPY --from=builder ${SRC_PATH_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
RUN <<"EOF" bash -ex
ls ${SRC_PATH_TRANSFORMER_ENGINE}/dist/*.whl
echo "transformer-engine @ file://$(ls ${SRC_PATH_TRANSFORMER_ENGINE}/dist/*.whl)" > /opt/pip-tools.d/requirements-te.in
EOF

###############################################################################
## Install primary packages and transitive dependencies
###############################################################################

FROM mealkit as final
FROM mealkit AS final
RUN pip-finalize.sh
1 change: 1 addition & 0 deletions .github/container/Dockerfile.triton
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ EOF
RUN <<"EOF" bash -ex
mkdir /opt/llvm-build
pushd /opt/llvm-build
pip install ninja && rm -rf ~/.cache/pip
cmake -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_C_COMPILER=clang \
Expand Down
14 changes: 12 additions & 2 deletions .github/container/build-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,24 @@ popd

## Build jaxlib
mkdir -p "${BUILD_PATH_JAXLIB}"
if [[ ! -e "/usr/local/cuda/lib" ]]; then
ln -s /usr/local/cuda/lib64 /usr/local/cuda/lib
fi

if ! grep 'try-import %workspace%/.local_cuda.bazelrc' "${SRC_PATH_JAX}/.bazelrc"; then
echo 'try-import %workspace%/.local_cuda.bazelrc' >> "${SRC_PATH_JAX}/.bazelrc"
fi
cat > "${SRC_PATH_JAX}/.local_cuda.bazelrc" << EOF
build:cuda --repo_env=LOCAL_CUDA_PATH="/usr/local/cuda"
build:cuda --repo_env=LOCAL_CUDNN_PATH="/opt/nvidia/cudnn"
build:cuda --repo_env=LOCAL_NCCL_PATH="/opt/nvidia/nccl"
EOF
time python "${SRC_PATH_JAX}/build/build.py" \
--editable \
--use_clang \
--enable_cuda \
--build_gpu_plugin \
--gpu_plugin_cuda_version=$TF_CUDA_MAJOR_VERSION \
--bazel_options=--repo_env=HERMETIC_CUDA_VERSION=$CUDA_VERSION \
--bazel_options=--repo_env=HERMETIC_CUDNN_VERSION=$TF_CUDNN_VERSION \
--cuda_compute_capabilities=$TF_CUDA_COMPUTE_CAPABILITIES \
--enable_nccl=true \
--bazel_options=--linkopt=-fuse-ld=lld \
Expand Down
31 changes: 29 additions & 2 deletions .github/container/install-cudnn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,37 @@ if [[ -z "${libcudnn_version}" || -z "${libcudnn_dev_version}" ]]; then
exit 1
fi

apt-get update
apt-get install -y \
${libcudnn_name}=${libcudnn_version} \
${libcudnn_dev_name}=${libcudnn_dev_version}

apt-get clean
rm -rf /var/lib/apt/lists/*

# Create a prefix with include/ and lib/ directories containing symlinks to the cuDNN
# version that was just installed; this is useful to pass to XLA to avoid it fetching
# its own copy of cuDNN.
prefix=/opt/nvidia/cudnn
if [[ -d "${prefix}" ]]; then
echo "Skipping link farm creation"
exit 1
fi
arch=$(uname -m)-linux-gnu
for cudnn_file in $(dpkg -L ${libcudnn_name} ${libcudnn_dev_name} | sort -u); do
# Real files and symlinks are linked into $prefix
if [[ -f "${cudnn_file}" || -h "${cudnn_file}" ]]; then
# Replace /usr with $prefix
nosysprefix="${cudnn_file#"/usr/"}"
# include/x86_64-linux-gpu -> include/
noarchinclude="${nosysprefix/#"include/${arch}"/include}"
# cudnn_v9.h -> cudnn.h
noverheader="${noarchinclude/%"_v${CUDNN_MAJOR_VERSION}.h"/.h}"
# lib/x86_64-linux-gnu -> lib/
noarchlib="${noverheader/#"lib/${arch}"/lib}"
link_name="${prefix}/${noarchlib}"
link_dir=$(dirname "${link_name}")
mkdir -p "${link_dir}"
ln -s "${cudnn_file}" "${link_name}"
else
echo "Skipping ${cudnn_file}"
fi
done
67 changes: 45 additions & 22 deletions .github/container/install-nccl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,51 @@ export TZ=America/Los_Angeles
# If NCCL is already installed, don't reinstall it. Print a message and exit
if dpkg -s libnccl2 libnccl-dev &> /dev/null; then
echo "NCCL is already installed. Skipping installation."
exit 0
else
apt-get update

# Extract CUDA version from `nvcc --version` output line
# Input: "Cuda compilation tools, release X.Y, VX.Y.Z"
# Output: X.Y
cuda_version=$(nvcc --version | sed -n 's/^.*release \([0-9]*\.[0-9]*\).*$/\1/p')

# Find latest NCCL version compatible with existing CUDA by matching
# ${cuda_version} in the package version string
libnccl2_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1)
libnccl_dev_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1)
if [[ -z "${libnccl2_version}" || -z "${libnccl_dev_version}" ]]; then
echo "Could not find compatible NCCL version for CUDA ${cuda_version}"
exit 1
fi

apt-get install -y \
libnccl2=${libnccl2_version} \
libnccl-dev=${libnccl_dev_version}

apt-get clean
rm -rf /var/lib/apt/lists/*
fi

apt-get update

# Extract CUDA version from `nvcc --version` output line
# Input: "Cuda compilation tools, release X.Y, VX.Y.Z"
# Output: X.Y
cuda_version=$(nvcc --version | sed -n 's/^.*release \([0-9]*\.[0-9]*\).*$/\1/p')

# Find latest NCCL version compatible with existing CUDA by matching
# ${cuda_version} in the package version string
libnccl2_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1)
libnccl_dev_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1)
if [[ -z "${libnccl2_version}" || -z "${libnccl_dev_version}" ]]; then
echo "Could not find compatible NCCL version for CUDA ${cuda_version}"
exit 1
# Create a prefix with include/ and lib/ directories containing symlinks to the NCCL
# version installed at the system level; this is useful to pass to XLA to avoid it
# fetching its own copy.
prefix=/opt/nvidia/nccl
if [[ -d "${prefix}" ]]; then
echo "Skipping link farm creation"
exit 1
fi

apt-get install -y \
libnccl2=${libnccl2_version} \
libnccl-dev=${libnccl_dev_version}

apt-get clean
rm -rf /var/lib/apt/lists/*
arch=$(uname -m)-linux-gnu
for nccl_file in $(dpkg -L libnccl2 libnccl-dev | sort -u); do
# Real files and symlinks are linked into $prefix
if [[ -f "${nccl_file}" || -h "${nccl_file}" ]]; then
# Replace /usr with $prefix and remove arch-specific lib directories
nosysprefix="${nccl_file#"/usr/"}"
noarchlib="${nosysprefix/#"lib/${arch}"/lib}"
link_name="${prefix}/${noarchlib}"
link_dir=$(dirname "${link_name}")
mkdir -p "${link_dir}"
ln -s "${nccl_file}" "${link_name}"
else
echo "Skipping ${nccl_file}"
fi
done
Loading