Ship Transformer Engine in the JAX container #132

Closed
wants to merge 7 commits into from
Changes from 6 commits
27 changes: 23 additions & 4 deletions .github/container/Dockerfile.jax
@@ -1,10 +1,13 @@
ARG BASE_IMAGE=ghcr.io/nvidia/jax-toolbox:base
ARG REPO_JAX="https://github.com/google/jax.git"
ARG REPO_XLA="https://github.com/openxla/xla.git"
+ ARG REPO_TE="https://github.com/NVIDIA/TransformerEngine.git"
ARG REF_JAX=main
ARG REF_XLA=main
+ ARG REF_TE=main
ARG SRC_PATH_JAX=/opt/jax-source
ARG SRC_PATH_XLA=/opt/xla-source
+ ARG SRC_PATH_TE=/opt/transformer-engine
ARG BAZEL_CACHE=/tmp
ARG BUILD_DATE

@@ -15,16 +18,20 @@ ARG BUILD_DATE
FROM ${BASE_IMAGE} as jax-builder
ARG REPO_JAX
ARG REPO_XLA
+ ARG REPO_TE
ARG REF_JAX
ARG REF_XLA
+ ARG REF_TE
ARG SRC_PATH_JAX
ARG SRC_PATH_XLA
+ ARG SRC_PATH_TE
ARG BAZEL_CACHE

RUN git clone "${REPO_JAX}" "${SRC_PATH_JAX}" && cd "${SRC_PATH_JAX}" && git checkout ${REF_JAX}
RUN --mount=type=ssh \
--mount=type=secret,id=SSH_KNOWN_HOSTS,target=/root/.ssh/known_hosts \
git clone "${REPO_XLA}" "${SRC_PATH_XLA}" && cd "${SRC_PATH_XLA}" && git checkout ${REF_XLA}
+ RUN git clone "${REPO_TE}" "${SRC_PATH_TE}" && cd "${SRC_PATH_TE}" && git checkout ${REF_TE} && git submodule init && git submodule update --recursive

ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/
RUN build-jax.sh \
@@ -36,6 +43,7 @@ RUN build-jax.sh \

RUN cp -r ${SRC_PATH_JAX} ${SRC_PATH_JAX}-no-git && rm -rf ${SRC_PATH_JAX}-no-git/.git
RUN cp -r ${SRC_PATH_XLA} ${SRC_PATH_XLA}-no-git && rm -rf ${SRC_PATH_XLA}-no-git/.git
+ RUN cp -r ${SRC_PATH_TE} ${SRC_PATH_TE}-no-git && rm -rf ${SRC_PATH_TE}-no-git/.git

###############################################################################
## Build 'runtime' flavor without the git metadata
@@ -45,15 +53,24 @@ ARG BASE_IMAGE
FROM ${BASE_IMAGE} as runtime-image
ARG SRC_PATH_JAX
ARG SRC_PATH_XLA
+ ARG SRC_PATH_TE
ARG BUILD_DATE
ENV BUILD_DATE=${BUILD_DATE}

COPY --from=jax-builder ${SRC_PATH_JAX}-no-git ${SRC_PATH_JAX}
COPY --from=jax-builder ${SRC_PATH_XLA}-no-git ${SRC_PATH_XLA}
+ COPY --from=jax-builder ${SRC_PATH_TE}-no-git ${SRC_PATH_TE}

- RUN pip --disable-pip-version-check install ${SRC_PATH_JAX}/dist/*.whl && \
- pip --disable-pip-version-check install -e ${SRC_PATH_JAX} && \
- rm -rf ~/.cache/pip/
+ RUN <<EOF
+ # Transformer Engine installation dependencies
+ pip install --no-cache-dir pybind11 ninja packaging
+ # Install JAX + Transformer Engine
+ NVTE_FRAMEWORK=jax pip --disable-pip-version-check --no-cache-dir install -e \
Contributor

What do you think about setting ENV NVTE_FRAMEWORK=jax? Since that environment variable must be set to install TE correctly, having it in the image is friendlier for re-installs. For example, installing t5x next will look something like pip install -e /opt/transformer-engine -e /opt/t5x, and if NVTE_FRAMEWORK is already set in the base container, we save a few characters on each install.

Collaborator Author

Good point.

Collaborator

The env variable isn't needed anymore if JAX is already installed.
By default, setup.py tries to import every framework and builds for those that are installed.
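
The two behaviours discussed in this thread (an explicit NVTE_FRAMEWORK setting versus automatic detection of installed frameworks) can be sketched roughly as below. This is only an illustration, not Transformer Engine's actual setup.py: the function name `frameworks_to_build`, the candidate list, and the comma-separated parsing are all assumptions made for the example.

```python
import importlib.util
import os

# Hypothetical candidate list; the real build script may consider other frameworks.
CANDIDATES = ("jax", "torch", "tensorflow")


def frameworks_to_build(environ=None, is_installed=None):
    """Return the frameworks an installer would target (illustrative only)."""
    environ = os.environ if environ is None else environ
    if is_installed is None:
        # find_spec returns None when a top-level module cannot be found,
        # so this probes for a framework without importing it.
        def is_installed(name):
            return importlib.util.find_spec(name) is not None

    explicit = environ.get("NVTE_FRAMEWORK", "").strip()
    if explicit:
        # An explicit setting wins; a comma-separated list is an assumption here.
        return [fw.strip() for fw in explicit.split(",") if fw.strip()]

    # No override: target whichever candidates are available.
    return [fw for fw in CANDIDATES if is_installed(fw)]


if __name__ == "__main__":
    print(frameworks_to_build())
```

Under this sketch, an image that sets ENV NVTE_FRAMEWORK=jax always takes the first branch on a re-install, while an image without it depends on which frameworks are present at install time.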

+ ${SRC_PATH_JAX}/dist/*.whl \
+ ${SRC_PATH_JAX} \
+ ${SRC_PATH_TE}
+ rm -rf ~/.cache/pip/
+ EOF

# Install software stack in JAX ecosystem
# Made this optional since tensorstore cannot build on Ubuntu 20.04 + ARM
@@ -69,8 +86,10 @@ RUN { pip install flax || true; } && rm -rf ~/.cache/pip
FROM runtime-image as devel-image
ARG SRC_PATH_JAX
ARG SRC_PATH_XLA
+ ARG SRC_PATH_TE

ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/

COPY --from=jax-builder ${SRC_PATH_JAX}/.git ${SRC_PATH_JAX}/.git
- COPY --from=jax-builder ${SRC_PATH_XLA}/.git ${SRC_PATH_XLA}/.git
+ COPY --from=jax-builder ${SRC_PATH_XLA}/.git ${SRC_PATH_XLA}/.git
+ COPY --from=jax-builder ${SRC_PATH_TE}/.git ${SRC_PATH_TE}/.git
12 changes: 0 additions & 12 deletions .github/container/Dockerfile.te

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/_build_pax.yaml
@@ -7,7 +7,7 @@ on:
type: string
description: 'Base docker image that provides JAX'
required: false
- default: ghcr.io/nvidia/jax-te:latest
+ default: ghcr.io/nvidia/jax:latest
BUILD_DATE:
type: string
description: "Build date in YYYY-MM-DD format"
2 changes: 1 addition & 1 deletion .github/workflows/_build_t5x.yaml
@@ -7,7 +7,7 @@ on:
type: string
description: 'Base docker image that provides JAX'
required: false
- default: ghcr.io/nvidia/jax-te:latest
+ default: ghcr.io/nvidia/jax:latest
BUILD_DATE:
type: string
description: "Build date in YYYY-MM-DD format"
90 changes: 0 additions & 90 deletions .github/workflows/_build_te.yaml

This file was deleted.

42 changes: 8 additions & 34 deletions .github/workflows/_sandbox.yaml
@@ -3,39 +3,13 @@ name: "~Sandbox"
on:
workflow_dispatch:

- jobs:
- sandbox:
- runs-on: ubuntu-22.04
- steps:
- - name: Login to GitHub Container Registry
- uses: docker/login-action@v2
- with:
- registry: ghcr.io
- username: ${{ github.repository_owner }}
- password: ${{ secrets.GITHUB_TOKEN }}
+ permissions:
+ contents: read # to fetch code
+ actions: write # to cancel previous workflows
+ packages: write # to upload container

- - name: Print usage
- run: |
- cat << EOF
- This is an empty workflow file located in the main branch of your
- repository. It serves as a testing ground for new GitHub Actions on
- development branches before merging them to the main branch. By
- defining and overloading this workflow on your development branch,
- you can test new actions without affecting your main branch, ensuring
- a smooth integration process once the changes are ready to be merged.
+ jobs:

- Usage:
-
- 1. In your development branch, modify the sandbox.yml workflow file
- to include the new actions you want to test. Make sure to commit
- the changes to the development branch.
- 2. Navigate to the 'Actions' tab in your repository, select the
- '~Sandbox' workflow, and choose your development branch from the
- branch dropdown menu. Click on 'Run workflow' to trigger the
- workflow on your development branch.
- 3. Once you have tested and verified the new actions in the Sandbox
- workflow, you can incorporate them into your main workflow(s) and
- merge the development branch into the main branch. Remember to
- revert the changes to the sandbox.yml file in the main branch to
- keep it empty for future testing.
- EOF
+ build:
+ uses: ./.github/workflows/_build_jax.yaml
+ secrets: inherit
2 changes: 1 addition & 1 deletion .github/workflows/_test_te.yaml
@@ -8,7 +8,7 @@ on:
type: string
description: 'JAX-TE image build by NVIDIA/JAX-Toolbox'
required: true
- default: 'ghcr.io/nvidia/jax-te:latest'
+ default: 'ghcr.io/nvidia/jax:latest'
outputs:
UNIT_TEST_ARTIFACT_NAME:
description: 'Name of the unit test artifact for downstream workflows'
17 changes: 3 additions & 14 deletions .github/workflows/ci.yaml
@@ -119,16 +119,6 @@ jobs:
REF_XLA: ${{ needs.metadata.outputs.REF_XLA }}
secrets: inherit

- build-te:
- needs: [metadata, build-jax]
- uses: ./.github/workflows/_build_te.yaml
- with:
- BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
- BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
- REPO_TE: ${{ needs.metadata.outputs.REPO_TE }}
- REF_TE: ${{ needs.metadata.outputs.REF_TE }}
- secrets: inherit
-
build-t5x:
needs: [metadata, build-jax]
uses: ./.github/workflows/_build_t5x.yaml
@@ -170,7 +160,7 @@ jobs:
secrets: inherit

build-summary:
- needs: [build-base, build-jax, build-te, build-t5x, build-pax, build-rosetta-t5x, build-rosetta-pax]
+ needs: [build-base, build-jax, build-t5x, build-pax, build-rosetta-t5x, build-rosetta-pax]
runs-on: ubuntu-22.04
steps:
- name: Generate job summary for container build
@@ -183,7 +173,6 @@ jobs:
| ------------ | -------------------------------------------------- |
| Base | ${{ needs.build-base.outputs.DOCKER_TAGS }} |
| JAX | ${{ needs.build-jax.outputs.DOCKER_TAGS }} |
- | JAX-TE | ${{ needs.build-te.outputs.DOCKER_TAGS }} |
| T5X | ${{ needs.build-t5x.outputs.DOCKER_TAGS }} |
| PAX | ${{ needs.build-pax.outputs.DOCKER_TAGS }} |
| ROSETTA(t5x) | ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAGS }} |
@@ -198,10 +187,10 @@ jobs:
secrets: inherit

test-te:
- needs: build-te
+ needs: [build-jax, test-jax]
uses: ./.github/workflows/_test_te.yaml
with:
- JAX_TE_IMAGE: ${{ needs.build-te.outputs.DOCKER_TAGS }}
+ JAX_TE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
secrets: inherit

test-t5x:
64 changes: 0 additions & 64 deletions .github/workflows/nightly-te-build.yaml

This file was deleted.
