diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index 2ecf9a03e..86d930459 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -1,5 +1,6 @@ # syntax=docker/dockerfile:1-labs ARG BASE_IMAGE=ghcr.io/nvidia/jax:base +ARG BUILD_PATH_JAXLIB=/opt/jaxlib ARG URLREF_JAX=https://github.com/google/jax.git#main ARG URLREF_XLA=https://github.com/openxla/xla.git#main ARG URLREF_FLAX=https://github.com/google/flax.git#main @@ -24,6 +25,7 @@ ARG URLREF_XLA ARG SRC_PATH_JAX ARG SRC_PATH_XLA ARG BAZEL_CACHE +ARG BUILD_PATH_JAXLIB ARG GIT_USER_NAME ARG GIT_USER_EMAIL @@ -44,6 +46,7 @@ RUN ARCH="$(dpkg --print-architecture)" && \ ADD xla-arm64-neon.patch /opt RUN build-jax.sh \ --bazel-cache ${BAZEL_CACHE} \ + --build-path-jaxlib ${BUILD_PATH_JAXLIB} \ --src-path-jax ${SRC_PATH_JAX} \ --src-path-xla ${SRC_PATH_XLA} \ --sm all \ @@ -63,6 +66,7 @@ ARG SRC_PATH_XLA ARG SRC_PATH_FLAX ARG SRC_PATH_TRANSFORMER_ENGINE ARG BUILD_DATE +ARG BUILD_PATH_JAXLIB ENV BUILD_DATE=${BUILD_DATE} # The following environment variables tune performance @@ -73,7 +77,7 @@ ENV CUDA_DEVICE_MAX_CONNECTIONS=1 ENV NCCL_NVLS_ENABLE=0 ENV CUDA_MODULE_LOADING=EAGER - +COPY --from=builder ${BUILD_PATH_JAXLIB} ${BUILD_PATH_JAXLIB} COPY --from=builder ${SRC_PATH_JAX} ${SRC_PATH_JAX} COPY --from=builder ${SRC_PATH_XLA} ${SRC_PATH_XLA} COPY --from=builder /usr/local/bin/bazel /usr/local/bin/bazel @@ -82,13 +86,11 @@ COPY --from=builder /opt/manifest.d/git-clone.yaml /opt/manifest.d/git-clone.yam ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/ RUN mkdir -p /opt/pip-tools.d + +## Editable installations of jax and jaxlib RUN <<"EOF" bash -ex -# Encourage a newer numpy so that pip's dependency resolver will allow newer -# versions of other packages that rely on newer numpy, but also include fixes -# for compatibility with newer JAX versions. e.g. chex. -echo "numpy >= 1.24.1" >> /opt/pip-tools.d/requirements-jax.in -echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/requirements-jax.in -echo "jaxlib @ file://$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/requirements-jax.in +echo "-e file://${BUILD_PATH_JAXLIB}" > /opt/pip-tools.d/requirements-jax.in +echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/requirements-jax.in EOF ## Flax @@ -113,5 +115,4 @@ EOF ############################################################################### FROM mealkit as final - RUN pip-finalize.sh diff --git a/.github/container/build-jax.sh b/.github/container/build-jax.sh index 3e1a94c4a..8ef6e7090 100755 --- a/.github/container/build-jax.sh +++ b/.github/container/build-jax.sh @@ -1,5 +1,5 @@ #!/bin/bash - +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) set -e ## Utility methods @@ -20,13 +20,6 @@ supported_compute_capabilities() { fi } -clean() { - bazel clean --expunge || true - rm -rf bazel - rm -rf .jax_configure.bazelrc - rm -rf ${HOME}/.cache/bazel -} - ## Parse command-line arguments usage() { @@ -37,6 +30,7 @@ usage() { echo " OPTIONS DESCRIPTION" echo " --bazel-cache URI Path for local bazel cache or URL of remote bazel cache" echo " --build-param PARAM Param passed to the jaxlib build command. Can be passed many times." + echo " --build-path-jaxlib PATH Editable install location for jaxlib" echo " --clean Delete local configuration and bazel cache" echo " --clean-only Do not build, just cleanup" echo " --cpu-arch Target CPU architecture, e.g. amd64, arm64, etc." @@ -57,6 +51,7 @@ usage() { # Set defaults BAZEL_CACHE="" +BUILD_PATH_JAXLIB="/opt/jaxlib" BUILD_PARAM="" CLEAN=0 CLEANONLY=0 @@ -64,13 +59,12 @@ CPU_ARCH="$(dpkg --print-architecture)" CUDA_COMPUTE_CAPABILITIES="local" DEBUG=0 DRY=0 -EDITABLE=0 JAXLIB_ONLY=0 SRC_PATH_JAX="/opt/jax" SRC_PATH_XLA="/opt/xla" XLA_ARM64_PATCH_LIST="" -args=$(getopt -o h --long bazel-cache:,build-param:,clean,cpu-arch:,debug,jaxlib_only,no-clean,clean-only,dry,help,src-path-jax:,src-path-xla:,sm:,xla-arm64-patch: -- "$@") +args=$(getopt -o h --long bazel-cache:,build-param:,build-path-jaxlib:,clean,cpu-arch:,debug,jaxlib_only,no-clean,clean-only,dry,help,src-path-jax:,src-path-xla:,sm:,xla-arm64-patch: -- "$@") if [[ $? -ne 0 ]]; then exit 1 fi @@ -86,6 +80,10 @@ while [ : ]; do BUILD_PARAM="$BUILD_PARAM $2" shift 2 ;; + --build-path-jaxlib) + BUILD_PATH_JAXLIB="$2" + shift 2 + ;; -h | --help) usage 1 ;; @@ -148,6 +146,15 @@ done SRC_PATH_JAX=$(realpath $SRC_PATH_JAX) SRC_PATH_XLA=$(realpath $SRC_PATH_XLA) +clean() { + pushd "${SRC_PATH_JAX}" + bazel clean --expunge || true + rm -rf bazel + rm -rf .jax_configure.bazelrc + rm -rf ${HOME}/.cache/bazel + popd +} + export DEBIAN_FRONTEND=noninteractive export TZ=America/Los_Angeles @@ -175,7 +182,7 @@ if [[ ! -z "${CUDA_COMPUTE_CAPABILITIES}" ]]; then export TF_CUDA_COMPUTE_CAPABILITIES=$(supported_compute_capabilities ${CPU_ARCH}) if [[ $? -ne 0 ]]; then exit 1; fi elif [[ "$CUDA_COMPUTE_CAPABILITIES" == "local" ]]; then - export TF_CUDA_COMPUTE_CAPABILITIES=$(./local_cuda_arch) + export TF_CUDA_COMPUTE_CAPABILITIES=$("${SCRIPT_DIR}/local_cuda_arch") else export TF_CUDA_COMPUTE_CAPABILITIES="${CUDA_COMPUTE_CAPABILITIES}" fi @@ -199,6 +206,7 @@ echo " Configuration " echo "--------------------------------------------------" print_var BAZEL_CACHE +print_var BUILD_PATH_JAXLIB print_var BUILD_PARAM print_var CLEAN print_var CLEANONLY @@ -215,7 +223,6 @@ print_var TF_CUDNN_VERSION print_var TF_NCCL_VERSION print_var CC_OPT_FLAGS -print_var BUILD_PARAM print_var XLA_ARM64_PATCH_LIST echo "==================================================" @@ -232,18 +239,6 @@ fi set -x -## install tools - -apt-get update -apt-get install -y \ - build-essential \ - checkinstall \ - git \ - wget \ - curl - -pip install wheel pre-commit mypy numpy build - # apply patch for XLA pushd $SRC_PATH_XLA @@ -258,13 +253,9 @@ fi popd ## Build jaxlib - -pushd $SRC_PATH_JAX - -# Delete old wheel if one already exist. -rm -rf dist/j*.whl - -time python build/build.py \ +mkdir -p "${BUILD_PATH_JAXLIB}" +time python "${SRC_PATH_JAX}/build/build.py" \ + --editable \ --use_clang \ --enable_cuda \ --cuda_path=$TF_CUDA_PATHS \ @@ -275,14 +266,21 @@ time python build/build.py \ --enable_nccl=true \ --bazel_options=--linkopt=-fuse-ld=lld \ --bazel_options=--override_repository=xla=$SRC_PATH_XLA \ + --output_path=${BUILD_PATH_JAXLIB} \ $BUILD_PARAM -popd +# Make sure that JAX depends on the local jaxlib installation +# https://jax.readthedocs.io/en/latest/developer.html#specifying-dependencies-on-local-wheels +line="jaxlib @ file://${BUILD_PATH_JAXLIB}" +if ! grep -xF "${line}" "${SRC_PATH_JAX}/build/requirements.in"; then + pushd "${SRC_PATH_JAX}" + echo "${line}" >> build/requirements.in + bazel run //build:requirements_dev.update --repo_env=HERMETIC_PYTHON_VERSION="3.10" + popd +fi ## Install the built packages -pushd $SRC_PATH_JAX - # Uninstall jaxlib in case this script was used before. if [[ "$JAXLIB_ONLY" == "0" ]]; then pip uninstall -y jax jaxlib @@ -291,15 +289,13 @@ else fi # install jaxlib -pip --disable-pip-version-check install dist/*.whl +pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB} # install jax if [[ "$JAXLIB_ONLY" == "0" ]]; then - pip --disable-pip-version-check install . + pip --disable-pip-version-check install -e "${SRC_PATH_JAX}" fi -popd - ## Cleanup pushd $SRC_PATH_JAX diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index 28d219db7..efc38dc9c 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -7,8 +7,8 @@ usage() { echo "" echo "Usage: $0 [OPTION]... TESTS" echo " -b, --battery Specify predefined test batteries to run." - echo " --build-jaxlib Runs the JAX tests using jaxlib built form source." - echo " --cache-test-results yes|no|auto, passes through to bazel `--cache_test_results`" + echo " --build-jaxlib Runs the JAX tests using jaxlib built from source." + echo " --cache-test-results yes|no|auto, passes through to bazel --cache_test_results" echo " --reuse-jaxlib Runs the JAX tests using preinstalled jaxlib. (DEFAULT)" echo " --disable-x64 Disable 64-bit floating point support in JAX (some tests may fail)" echo " --enable-x64 Enable 64-bit floating point support in JAX (DEFAULT, required for some tests)"