Use pip-compile to help with consistent Python dependency resolution (#371)

# Summary

- All Python packages, except for a few build dependencies, are now
installed using **pip-tools**.
- The JAX and upstream T5X/PAX containers are now built in a two-stage
procedure:
1. The **'meal kit'** stage: source packages are downloaded and wheels
are built if necessary (for TE, tensorflow-text, lingvo, etc.), but **no**
package is installed. Instead, manifest files are created in the
`/opt/pip-tools.d` folder to tell pip-tools which packages to install.
The stage is named after a meal kit, where the ingredients are prepared
ahead of time and the final cooking step is deferred.
2. The **'final'** (cooking🔥) stage: pip-tools compiles the manifests
from the various container layers together and then sync-installs
everything to exactly match the resolved versions.
- Note that downstream containers **build on top of the meal kit
image of their base container**, which ensures that all packages and
dependencies are installed exactly once, avoiding conflicts and image
bloat.
- The meal kit and final images are published as
- mealkit: `ghcr.io/nvidia/image:mealkit` and
`ghcr.io/nvidia/image:mealkit-YYYY-MM-DD`
- final: `ghcr.io/nvidia/image:latest` and
`ghcr.io/nvidia/image:nightly-YYYY-MM-DD`
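
The two-stage procedure described above can be sketched roughly as follows. This is a minimal illustration, not the actual `pip-finalize.sh` (which is not shown in this excerpt): the function names `add_to_manifest` and `finalize` and the `DRY_RUN` switch are hypothetical, and a temporary directory stands in for `/opt/pip-tools.d`. Only `pip-compile ... --output-file` and `pip-sync` are real pip-tools commands.

```shell
#!/usr/bin/env bash
# Sketch of the meal kit / final split. All names here are illustrative.
set -euo pipefail

# Meal kit stage: record a requirement in a per-layer manifest; install nothing.
add_to_manifest() {
  local manifest_dir="$1" layer="$2" requirement="$3"
  mkdir -p "${manifest_dir}"
  echo "${requirement}" >> "${manifest_dir}/manifest.${layer}"
}

# Final stage: resolve all manifests together, then sync the environment.
# With DRY_RUN=1 (the default here) the commands are printed, not executed.
finalize() {
  local manifest_dir="$1"
  local manifests=("${manifest_dir}"/manifest.*)
  local lock="${manifest_dir}/requirements.txt"
  if [[ "${DRY_RUN:-1}" == "1" ]]; then
    echo "pip-compile ${manifests[*]} --output-file ${lock}"
    echo "pip-sync ${lock}"
  else
    pip-compile "${manifests[@]}" --output-file "${lock}"
    pip-sync "${lock}"
  fi
}

MANIFEST_DIR="$(mktemp -d)"   # stand-in for /opt/pip-tools.d
add_to_manifest "${MANIFEST_DIR}" jax  "-e file:///opt/jax"
add_to_manifest "${MANIFEST_DIR}" flax "-e file:///opt/flax"
finalize "${MANIFEST_DIR}"
```

Because resolution happens once over every manifest, a downstream layer only has to append its own manifest to the meal kit image and rerun the final step.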

# Additional changes to the workflows

- `/opt/jax-source` is renamed to `/opt/jax`. The `-source` suffix is
only added to packages that need compilation, e.g. XLA and TE.
- The CI workflow now runs as a matrix over CPU architectures.
- The reusable `_build_*.yaml` workflows are simplified to build only
one image for a single architecture at a time. The logic for creating
multi-arch images is relocated into the `_publish_container.yaml`
workflows and invoked during the nightly runs only.
- TE is now built as a wheel and shipped in the JAX core meal kit image.
- TE unit tests will be performed using the upstream-pax image due to
the dependency on praxis.
- Build workflows now produce sitreps following the paradigm of #229.
- Removed the various one-off workflows for pinned CUDA/JAX versions.
- Refactored the PAX arm64 Dockerfile in preparation for #338
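
As a rough illustration of the publish step, per-architecture images can be combined into a single multi-arch tag with `docker buildx imagetools create`. The sketch below only prints the command it would run. The `-<arch>` suffix on the source tags and the function name `multiarch_command` are assumptions made for this example, not the naming used by `_publish_container.yaml`.

```shell
#!/usr/bin/env bash
set -euo pipefail

# Print the command that would combine per-arch images into one multi-arch tag.
# The "-<arch>" suffix on the source tags is a hypothetical naming scheme.
multiarch_command() {
  local target="$1"; shift
  local sources=()
  local arch
  for arch in "$@"; do
    sources+=("${target}-${arch}")
  done
  echo "docker buildx imagetools create --tag ${target} ${sources[*]}"
}

multiarch_command "ghcr.io/nvidia/image:nightly-YYYY-MM-DD" amd64 arm64
```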

# What remains to be done

- [ ] Update the Rosetta container build + test process to use the
upstream T5X/PAX mealkit (ghcr.io/nvidia/upstream-t5x:mealkit,
ghcr.io/nvidia/upstream-pax:mealkit) containers

# Reviewing tips

This PR requires many reviewers due to its size and scope. I'd
truly appreciate it if code owners reviewed any changes related to their
previous contributions. An incomplete list of reviewers and their scopes:
- @terrykong, @ashors1, @sharathts, @maanug-nv: Rosetta, TE, T5X and PAX
MGMN tests
- @nouiz: JAX, TE and T5X build
- @joker-eph: PAX arm64 build
- @nluehr: Base image, NCCL, PAX
- @DwarKapex: base/JAX/XLA build, workflow logic

Closes #223
Closes #230 
Closes #231 
Closes #232 
Closes #233 
Closes #271
Fixes #328
Fixes #337 

Co-authored-by: Terry Kong <terryk@nvidia.com>

---------

Co-authored-by: Terry Kong <terryk@nvidia.com>
Co-authored-by: Vladislav Kozlov <vkozlov@nvidia.com>
3 people authored Nov 21, 2023
1 parent 2aa961a commit ca0b396
Showing 49 changed files with 1,697 additions and 2,321 deletions.
22 changes: 21 additions & 1 deletion .github/container/Dockerfile.base
@@ -1,5 +1,10 @@
ARG BASE_IMAGE=nvidia/cuda:12.2.0-devel-ubuntu22.04
ARG GIT_USER_NAME="JAX Toolbox"
ARG GIT_USER_EMAIL=jax@nvidia.com

FROM ${BASE_IMAGE}
ARG GIT_USER_EMAIL
ARG GIT_USER_NAME

###############################################################################
## Install Python and essential tools
@@ -17,13 +22,28 @@ RUN apt-get update && \
git \
lld \
vim \
bat \
curl \
git \
gnupg \
rsync \
python-is-python3 \
python3-pip \
liblzma-dev \
wget \
&& \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade --no-cache-dir pip
RUN <<"EOF" bash -ex
git config --global user.name "${GIT_USER_NAME}"
git config --global user.email "${GIT_USER_EMAIL}"
EOF
RUN pip install --upgrade --no-cache-dir pip pip-tools && rm -rf ~/.cache/*
RUN mkdir -p /opt/pip-tools.d
ADD --chmod=777 \
get-source.sh \
pip-finalize.sh \
/usr/local/bin/

###############################################################################
## Install cuDNN
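
The helper scripts added to `/usr/local/bin/` in this hunk are not shown in this excerpt. Going by the call sites in the other Dockerfiles (`get-source.sh -f <repo> -r <ref> -d <dir> -m <manifest>`), a hypothetical stand-in for `get-source.sh` might look like the sketch below. The flag letters are copied from those call sites. Everything else is an assumption, in particular that `-m` appends an editable `-e file://<dir>` entry to the manifest, and the `DRY_RUN` switch that prints the git commands instead of running them.

```shell
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical stand-in for get-source.sh; not the script from this commit.
get_source() {
  local from="" ref="" dir="" manifest=""
  local OPTIND=1 opt
  while getopts "f:r:d:m:" opt; do
    case "${opt}" in
      f) from="${OPTARG}" ;;      # repository URL
      r) ref="${OPTARG}" ;;       # branch, tag, or commit to check out
      d) dir="${OPTARG}" ;;       # destination directory
      m) manifest="${OPTARG}" ;;  # manifest file to append to (optional)
      *) return 1 ;;
    esac
  done
  if [[ -z "${from}" || -z "${dir}" ]]; then
    echo "usage: get_source -f URL -d DIR [-r REF] [-m MANIFEST]" >&2
    return 1
  fi
  if [[ "${DRY_RUN:-1}" == "1" ]]; then
    echo "git clone ${from} ${dir}"
    if [[ -n "${ref}" ]]; then echo "git -C ${dir} checkout ${ref}"; fi
  else
    git clone "${from}" "${dir}"
    if [[ -n "${ref}" ]]; then git -C "${dir}" checkout "${ref}"; fi
  fi
  # Assumed behavior: register the checkout as an editable install.
  if [[ -n "${manifest}" ]]; then
    echo "-e file://${dir}" >> "${manifest}"
  fi
}

MANIFEST="$(mktemp)"
get_source -f https://github.com/google/flax.git -r main -d /opt/flax -m "${MANIFEST}"
cat "${MANIFEST}"
```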
74 changes: 48 additions & 26 deletions .github/container/Dockerfile.jax
@@ -1,25 +1,36 @@
ARG BASE_IMAGE=ghcr.io/nvidia/jax-toolbox:base
ARG REPO_JAX="https://github.com/google/jax.git"
ARG REPO_XLA="https://github.com/openxla/xla.git"
ARG REPO_FLAX="https://github.com/google/flax.git"
ARG REPO_TE="https://github.com/NVIDIA/TransformerEngine.git"
ARG REF_JAX=main
ARG REF_XLA=main
ARG SRC_PATH_JAX=/opt/jax-source
ARG REF_FLAX=main
ARG REF_TE=main
ARG SRC_PATH_JAX=/opt/jax
ARG SRC_PATH_XLA=/opt/xla-source
ARG SRC_PATH_FLAX=/opt/flax
ARG SRC_PATH_TE=/opt/transformer-engine-source
ARG GIT_USER_NAME="JAX Toolbox"
ARG GIT_USER_EMAIL=jax@nvidia.com

ARG BAZEL_CACHE=/tmp
ARG BUILD_DATE

###############################################################################
## Build JAX
###############################################################################

FROM ${BASE_IMAGE} as jax-builder
FROM ${BASE_IMAGE} as builder
ARG REPO_JAX
ARG REPO_XLA
ARG REF_JAX
ARG REF_XLA
ARG SRC_PATH_JAX
ARG SRC_PATH_XLA
ARG BAZEL_CACHE
ARG GIT_USER_NAME
ARG GIT_USER_EMAIL

RUN git clone "${REPO_JAX}" "${SRC_PATH_JAX}" && cd "${SRC_PATH_JAX}" && git checkout ${REF_JAX}
RUN --mount=type=ssh \
@@ -30,8 +41,8 @@ RUN --mount=type=ssh \
RUN <<EOF bash -ex
cd ${SRC_PATH_XLA}

git config user.name "JAX Toolbox"
git config user.email "jax@nvidia.com"
git config --global user.name "${GIT_USER_NAME}"
git config --global user.email "${GIT_USER_EMAIL}"
git remote add -f ashors1 https://github.com/ashors1/xla
git cherry-pick --allow-empty $(git merge-base ashors/main ashors1/revert-84222)..ashors1/revert-84222
git remote remove ashors1
@@ -47,15 +58,12 @@ RUN build-jax.sh \
--xla-arm64-patch /opt/xla-arm64-neon.patch \
--clean

RUN cp -r ${SRC_PATH_JAX} ${SRC_PATH_JAX}-no-git && rm -rf ${SRC_PATH_JAX}-no-git/.git
RUN cp -r ${SRC_PATH_XLA} ${SRC_PATH_XLA}-no-git && rm -rf ${SRC_PATH_XLA}-no-git/.git

###############################################################################
## Build 'runtime' flavor without the git metadata
## Pack jaxlib wheel and various source dirs into a pre-installation image
###############################################################################

ARG BASE_IMAGE
FROM ${BASE_IMAGE} as runtime-image
FROM ${BASE_IMAGE} as mealkit
ARG SRC_PATH_JAX
ARG SRC_PATH_XLA
ARG BUILD_DATE
@@ -67,29 +75,43 @@ ENV NCCL_IB_SL=1
ENV NCCL_NVLS_ENABLE=0
ENV CUDA_MODULE_LOADING=EAGER

COPY --from=jax-builder ${SRC_PATH_JAX}-no-git ${SRC_PATH_JAX}
COPY --from=jax-builder ${SRC_PATH_XLA}-no-git ${SRC_PATH_XLA}
COPY --from=builder ${SRC_PATH_JAX} ${SRC_PATH_JAX}
COPY --from=builder ${SRC_PATH_XLA} ${SRC_PATH_XLA}
ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/

RUN pip --disable-pip-version-check install ${SRC_PATH_JAX}/dist/*.whl && \
pip --disable-pip-version-check install -e ${SRC_PATH_JAX} && \
rm -rf ~/.cache/pip/
RUN mkdir -p /opt/pip-tools.d
RUN <<"EOF" bash -ex
echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/manifest.jax
echo "jaxlib @ file://$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/manifest.jax
EOF

# Install software stack in JAX ecosystem
# Made this optional since tensorstore cannot build on Ubuntu 20.04 + ARM
RUN { pip install flax || true; } && rm -rf ~/.cache/*
## Flax
ARG REPO_FLAX
ARG REF_FLAX
ARG SRC_PATH_FLAX
RUN get-source.sh -f ${REPO_FLAX} -r ${REF_FLAX} -d ${SRC_PATH_FLAX} -m /opt/pip-tools.d/manifest.flax

## Transformer engine: check out source and build wheel
ARG REPO_TE
ARG REF_TE
ARG SRC_PATH_TE
ENV NVTE_FRAMEWORK=jax
ENV SRC_PATH_TE=${SRC_PATH_TE}
RUN <<"EOF" bash -ex
set -o pipefail
pip install ninja && rm -rf ~/.cache/pip
get-source.sh -f ${REPO_TE} -r ${REF_TE} -d ${SRC_PATH_TE}
pushd ${SRC_PATH_TE}
python setup.py bdist_wheel && rm -rf build
echo "transformer-engine @ file://$(ls ${SRC_PATH_TE}/dist/*.whl)" >> /opt/pip-tools.d/manifest.te
EOF

# TODO: properly configure entrypoint
# COPY entrypoint.d/ /opt/nvidia/entrypoint.d/

###############################################################################
## Build 'devel' image with build scripts and git metadata
## Install primary packages and transitive dependencies
###############################################################################

FROM runtime-image as devel-image
ARG SRC_PATH_JAX
ARG SRC_PATH_XLA

ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/
FROM mealkit as final

COPY --from=jax-builder ${SRC_PATH_JAX}/.git ${SRC_PATH_JAX}/.git
COPY --from=jax-builder ${SRC_PATH_XLA}/.git ${SRC_PATH_XLA}/.git
RUN pip-finalize.sh
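
The jaxlib and TE manifest entries above are built as `<name> @ file://$(ls .../dist/*.whl)`, which yields a valid direct reference only when the glob matches exactly one wheel. A slightly more defensive variant could check that first. This is a sketch only: `wheel_requirement` is a hypothetical helper, and the file created in the demo is an empty placeholder rather than a real wheel.

```shell
#!/usr/bin/env bash
set -euo pipefail

# Emit "<name> @ file://<wheel>" only if exactly one wheel exists in <dist_dir>.
wheel_requirement() {
  local name="$1" dist_dir="$2"
  local wheels=("${dist_dir}"/*.whl)
  # An unmatched glob stays literal, so also check that the path is a file.
  if [[ ${#wheels[@]} -ne 1 || ! -f "${wheels[0]}" ]]; then
    echo "expected exactly one wheel in ${dist_dir}" >&2
    return 1
  fi
  echo "${name} @ file://${wheels[0]}"
}

DIST="$(mktemp -d)"
touch "${DIST}/jaxlib-0.0.0-py3-none-any.whl"   # placeholder, not a real wheel
wheel_requirement jaxlib "${DIST}"
```

Failing early here would surface a stale or duplicate wheel at meal kit build time instead of as a confusing resolver error in the final stage.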
69 changes: 43 additions & 26 deletions .github/container/Dockerfile.pax.amd64
@@ -1,37 +1,54 @@
# syntax=docker/dockerfile:1-labs
###############################################################################
## Pax
###############################################################################

ARG BASE_IMAGE=ghcr.io/nvidia/jax:latest
FROM ${BASE_IMAGE}

ADD install-pax.sh /usr/local/bin
ADD install-flax.sh /usr/local/bin
ADD install-te.sh /usr/local/bin

ENV NVTE_FRAMEWORK=jax
ARG REPO_PAXML=https://github.com/google/paxml.git
ARG REPO_PRAXIS=https://github.com/google/praxis.git
ARG REF_PAXML=main
ARG REF_PRAXIS=main
ARG REPO_TE=https://github.com/NVIDIA/TransformerEngine.git
# TODO: This is a temporary pinning of TE as the API in TE no longer matches the TE patch
# This should be reverted to main ASAP
ARG REF_TE=7976bd003fcf084dd068069b92a9a79b1743316a
ARG SRC_PATH_PAXML=/opt/paxml
ARG SRC_PATH_PRAXIS=/opt/praxis

###############################################################################
## Download source and add auxiliary scripts
###############################################################################

FROM ${BASE_IMAGE} as mealkit
ARG REPO_PAXML
ARG REPO_PRAXIS
ARG REF_PAXML
ARG REF_PRAXIS
ARG SRC_PATH_PAXML
ARG SRC_PATH_PRAXIS

# update TE manifest file to install the [test] extras
RUN sed -i "s/transformer-engine @/transformer-engine[test] @/g" /opt/pip-tools.d/manifest.te

RUN <<"EOF" bash -ex
install-pax.sh --defer --from_paxml ${REPO_PAXML} --from_praxis ${REPO_PRAXIS} --ref_paxml ${REF_PAXML} --ref_praxis ${REF_PRAXIS}
install-flax.sh --defer
install-te.sh --defer --from ${REPO_TE} --ref ${REF_TE}

if [[ -f /opt/requirements-defer.txt ]]; then
# SKIP_HEAD_INSTALLS avoids having to install jax from Github source so that
# we do not overwrite the jax that was already installed.
SKIP_HEAD_INSTALLS=true pip install -r /opt/requirements-defer.txt
fi
if [[ -f /opt/cleanup.sh ]]; then
bash -ex /opt/cleanup.sh
fi
get-source.sh -f ${REPO_PAXML} -r ${REF_PAXML} -d ${SRC_PATH_PAXML}
get-source.sh -f ${REPO_PRAXIS} -r ${REF_PRAXIS} -d ${SRC_PATH_PRAXIS}
echo "-e file://${SRC_PATH_PAXML}[gpu]" >> /opt/pip-tools.d/manifest.pax
echo "-e file://${SRC_PATH_PRAXIS}" >> /opt/pip-tools.d/manifest.pax

for src in ${SRC_PATH_PAXML} ${SRC_PATH_PRAXIS}; do
pushd ${src}
sed -i "s| @ git+https://github.com/google/flax||g" requirements.in
sed -i "s| @ git+https://github.com/google/jax||g" requirements.in
if git diff --quiet; then
echo "URL specs no longer present in select dependencies for ${src}"
exit 1
else
git commit -a -m "remove URL specs from select dependencies for ${src}"
fi
popd
done
EOF

ADD test-pax.sh /usr/local/bin

###############################################################################
## Install accumulated packages from the base image and the previous stage
###############################################################################

FROM mealkit as final

RUN pip-finalize.sh
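
The loop in the mealkit stage above strips the `@ git+https://...` URL specs from `requirements.in` and uses `git diff --quiet` to fail the build once upstream no longer carries them. The same guard can be sketched without git by checking for the pattern before rewriting the file. Note that this variant is stricter than the original, since it fails per URL rather than per source tree. `strip_url_spec` is a hypothetical name used for illustration.

```shell
#!/usr/bin/env bash
set -euo pipefail

# Remove " @ git+<url>" from a requirements file; fail if nothing matches.
strip_url_spec() {
  local file="$1" url="$2"
  if ! grep -qF " @ git+${url}" "${file}"; then
    echo "URL spec for ${url} no longer present in ${file}" >&2
    return 1
  fi
  local tmp
  tmp="$(mktemp)"
  # Write to a temp file and move it back, avoiding non-portable `sed -i`.
  sed "s| @ git+${url}||g" "${file}" > "${tmp}"
  mv "${tmp}" "${file}"
}

REQ="$(mktemp)"
printf '%s\n' "flax @ git+https://github.com/google/flax" "numpy" > "${REQ}"
strip_url_spec "${REQ}" "https://github.com/google/flax"
cat "${REQ}"
```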