Skip to content

Commit

Permalink
Work around upstream hermetic Python changes
Browse files Browse the repository at this point in the history
- Always use editable installations of jax and jaxlib, which avoids
  having to deal with wheel filenames that have dates in them. Using an
  editable installation requires that we disable hashes.
- Do not tackle the problem that this leaves the JAX source repo in a
  dirty state. People will have to be careful not to merge the modified
  requirements files.
  • Loading branch information
olupton committed May 23, 2024
1 parent c7223f7 commit 9c872f8
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 48 deletions.
17 changes: 9 additions & 8 deletions .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# syntax=docker/dockerfile:1-labs
ARG BASE_IMAGE=ghcr.io/nvidia/jax:base
ARG BUILD_PATH_JAXLIB=/opt/jaxlib
ARG URLREF_JAX=https://github.com/google/jax.git#main
ARG URLREF_XLA=https://github.com/openxla/xla.git#main
ARG URLREF_FLAX=https://github.com/google/flax.git#main
Expand All @@ -24,6 +25,7 @@ ARG URLREF_XLA
ARG SRC_PATH_JAX
ARG SRC_PATH_XLA
ARG BAZEL_CACHE
ARG BUILD_PATH_JAXLIB
ARG GIT_USER_NAME
ARG GIT_USER_EMAIL

Expand All @@ -44,6 +46,7 @@ RUN ARCH="$(dpkg --print-architecture)" && \
ADD xla-arm64-neon.patch /opt
RUN build-jax.sh \
--bazel-cache ${BAZEL_CACHE} \
--build-path-jaxlib ${BUILD_PATH_JAXLIB} \
--src-path-jax ${SRC_PATH_JAX} \
--src-path-xla ${SRC_PATH_XLA} \
--sm all \
Expand All @@ -63,6 +66,7 @@ ARG SRC_PATH_XLA
ARG SRC_PATH_FLAX
ARG SRC_PATH_TRANSFORMER_ENGINE
ARG BUILD_DATE
ARG BUILD_PATH_JAXLIB

ENV BUILD_DATE=${BUILD_DATE}
# The following environment variables tune performance
Expand All @@ -73,7 +77,7 @@ ENV CUDA_DEVICE_MAX_CONNECTIONS=1
ENV NCCL_NVLS_ENABLE=0
ENV CUDA_MODULE_LOADING=EAGER


COPY --from=builder ${BUILD_PATH_JAXLIB} ${BUILD_PATH_JAXLIB}
COPY --from=builder ${SRC_PATH_JAX} ${SRC_PATH_JAX}
COPY --from=builder ${SRC_PATH_XLA} ${SRC_PATH_XLA}
COPY --from=builder /usr/local/bin/bazel /usr/local/bin/bazel
Expand All @@ -82,13 +86,11 @@ COPY --from=builder /opt/manifest.d/git-clone.yaml /opt/manifest.d/git-clone.yam
ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/

RUN mkdir -p /opt/pip-tools.d

## Editable installations of jax and jaxlib
RUN <<"EOF" bash -ex
# Encourage a newer numpy so that pip's dependency resolver will allow newer
# versions of other packages that rely on newer numpy, but also include fixes
# for compatibility with newer JAX versions. e.g. chex.
echo "numpy >= 1.24.1" >> /opt/pip-tools.d/requirements-jax.in
echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/requirements-jax.in
echo "jaxlib @ file://$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/requirements-jax.in
echo "-e file://${BUILD_PATH_JAXLIB}" > /opt/pip-tools.d/requirements-jax.in
echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/requirements-jax.in
EOF

## Flax
Expand All @@ -113,5 +115,4 @@ EOF
###############################################################################

FROM mealkit as final

RUN pip-finalize.sh
72 changes: 34 additions & 38 deletions .github/container/build-jax.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
set -e

## Utility methods
Expand All @@ -20,13 +20,6 @@ supported_compute_capabilities() {
fi
}

clean() {
bazel clean --expunge || true
rm -rf bazel
rm -rf .jax_configure.bazelrc
rm -rf ${HOME}/.cache/bazel
}

## Parse command-line arguments

usage() {
Expand All @@ -37,6 +30,7 @@ usage() {
echo " OPTIONS DESCRIPTION"
echo " --bazel-cache URI Path for local bazel cache or URL of remote bazel cache"
echo " --build-param PARAM Param passed to the jaxlib build command. Can be passed many times."
echo " --build-path-jaxlib PATH Editable install location for jaxlib"
echo " --clean Delete local configuration and bazel cache"
echo " --clean-only Do not build, just cleanup"
echo " --cpu-arch Target CPU architecture, e.g. amd64, arm64, etc."
Expand All @@ -57,20 +51,20 @@ usage() {

# Set defaults
BAZEL_CACHE=""
BUILD_PATH_JAXLIB="/opt/jaxlib"
BUILD_PARAM=""
CLEAN=0
CLEANONLY=0
CPU_ARCH="$(dpkg --print-architecture)"
CUDA_COMPUTE_CAPABILITIES="local"
DEBUG=0
DRY=0
EDITABLE=0
JAXLIB_ONLY=0
SRC_PATH_JAX="/opt/jax"
SRC_PATH_XLA="/opt/xla"
XLA_ARM64_PATCH_LIST=""

args=$(getopt -o h --long bazel-cache:,build-param:,clean,cpu-arch:,debug,jaxlib_only,no-clean,clean-only,dry,help,src-path-jax:,src-path-xla:,sm:,xla-arm64-patch: -- "$@")
args=$(getopt -o h --long bazel-cache:,build-param:,build-path-jaxlib:,clean,cpu-arch:,debug,jaxlib_only,no-clean,clean-only,dry,help,src-path-jax:,src-path-xla:,sm:,xla-arm64-patch: -- "$@")
if [[ $? -ne 0 ]]; then
exit 1
fi
Expand All @@ -86,6 +80,10 @@ while [ : ]; do
BUILD_PARAM="$BUILD_PARAM $2"
shift 2
;;
--build-path-jaxlib)
BUILD_PATH_JAXLIB="$2"
shift 2
;;
-h | --help)
usage 1
;;
Expand Down Expand Up @@ -148,6 +146,15 @@ done
SRC_PATH_JAX=$(realpath $SRC_PATH_JAX)
SRC_PATH_XLA=$(realpath $SRC_PATH_XLA)

clean() {
pushd "${SRC_PATH_JAX}"
bazel clean --expunge || true
rm -rf bazel
rm -rf .jax_configure.bazelrc
rm -rf ${HOME}/.cache/bazel
popd
}

export DEBIAN_FRONTEND=noninteractive
export TZ=America/Los_Angeles

Expand Down Expand Up @@ -175,7 +182,7 @@ if [[ ! -z "${CUDA_COMPUTE_CAPABILITIES}" ]]; then
export TF_CUDA_COMPUTE_CAPABILITIES=$(supported_compute_capabilities ${CPU_ARCH})
if [[ $? -ne 0 ]]; then exit 1; fi
elif [[ "$CUDA_COMPUTE_CAPABILITIES" == "local" ]]; then
export TF_CUDA_COMPUTE_CAPABILITIES=$(./local_cuda_arch)
export TF_CUDA_COMPUTE_CAPABILITIES=$("${SCRIPT_DIR}/local_cuda_arch")
else
export TF_CUDA_COMPUTE_CAPABILITIES="${CUDA_COMPUTE_CAPABILITIES}"
fi
Expand All @@ -199,6 +206,7 @@ echo " Configuration "
echo "--------------------------------------------------"

print_var BAZEL_CACHE
print_var BUILD_PATH_JAXLIB
print_var BUILD_PARAM
print_var CLEAN
print_var CLEANONLY
Expand All @@ -215,7 +223,6 @@ print_var TF_CUDNN_VERSION
print_var TF_NCCL_VERSION
print_var CC_OPT_FLAGS

print_var BUILD_PARAM
print_var XLA_ARM64_PATCH_LIST

echo "=================================================="
Expand All @@ -232,18 +239,6 @@ fi

set -x

## install tools

apt-get update
apt-get install -y \
build-essential \
checkinstall \
git \
wget \
curl

pip install wheel pre-commit mypy numpy build

# apply patch for XLA
pushd $SRC_PATH_XLA

Expand All @@ -258,13 +253,9 @@ fi
popd

## Build jaxlib

pushd $SRC_PATH_JAX

# Delete old wheel if one already exist.
rm -rf dist/j*.whl

time python build/build.py \
mkdir -p "${BUILD_PATH_JAXLIB}"
time python "${SRC_PATH_JAX}/build/build.py" \
--editable \
--use_clang \
--enable_cuda \
--cuda_path=$TF_CUDA_PATHS \
Expand All @@ -275,14 +266,21 @@ time python build/build.py \
--enable_nccl=true \
--bazel_options=--linkopt=-fuse-ld=lld \
--bazel_options=--override_repository=xla=$SRC_PATH_XLA \
--output_path=${BUILD_PATH_JAXLIB} \
$BUILD_PARAM

popd
# Make sure that JAX depends on the local jaxlib installation
# https://jax.readthedocs.io/en/latest/developer.html#specifying-dependencies-on-local-wheels
line="jaxlib @ file://${BUILD_PATH_JAXLIB}"
if ! grep -xF "${line}" "${SRC_PATH_JAX}/build/requirements.in"; then
pushd "${SRC_PATH_JAX}"
echo "${line}" >> build/requirements.in
bazel run //build:requirements_dev.update --repo_env=HERMETIC_PYTHON_VERSION="3.10"
popd
fi

## Install the built packages

pushd $SRC_PATH_JAX

# Uninstall jaxlib in case this script was used before.
if [[ "$JAXLIB_ONLY" == "0" ]]; then
pip uninstall -y jax jaxlib
Expand All @@ -291,15 +289,13 @@ else
fi

# install jaxlib
pip --disable-pip-version-check install dist/*.whl
pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}

# install jax
if [[ "$JAXLIB_ONLY" == "0" ]]; then
pip --disable-pip-version-check install .
pip --disable-pip-version-check install -e "${SRC_PATH_JAX}"
fi

popd

## Cleanup

pushd $SRC_PATH_JAX
Expand Down
4 changes: 2 additions & 2 deletions .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ usage() {
echo ""
echo "Usage: $0 [OPTION]... TESTS"
echo " -b, --battery Specify predefined test batteries to run."
echo " --build-jaxlib Runs the JAX tests using jaxlib built form source."
echo " --cache-test-results yes|no|auto, passes through to bazel `--cache_test_results`"
echo " --build-jaxlib Runs the JAX tests using jaxlib built from source."
echo " --cache-test-results yes|no|auto, passes through to bazel --cache_test_results"
echo " --reuse-jaxlib Runs the JAX tests using preinstalled jaxlib. (DEFAULT)"
echo " --disable-x64 Disable 64-bit floating point support in JAX (some tests may fail)"
echo " --enable-x64 Enable 64-bit floating point support in JAX (DEFAULT, required for some tests)"
Expand Down

0 comments on commit 9c872f8

Please sign in to comment.