diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index 8419bbf96..86d930459 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -86,13 +86,10 @@ COPY --from=builder /opt/manifest.d/git-clone.yaml /opt/manifest.d/git-clone.yam ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/ RUN mkdir -p /opt/pip-tools.d -# Use the same requirements as JAX's hermetic Python configuration, but use an -# editable installation of jaxlib -- bazel won't accept that for the hermetic -# test environment + +## Editable installations of jax and jaxlib RUN <<"EOF" bash -ex -cp "${SRC_PATH_JAX}/build/requirements.in" /opt/pip-tools.d/requirements-jax.in -sed -i -e "s#-r test-requirements.txt##" /opt/pip-tools.d/requirements-jax.in -sed -i -e "s#jaxlib @ file://${BUILD_PATH_JAXLIB}#-e ${BUILD_PATH_JAXLIB}#" /opt/pip-tools.d/requirements-jax.in +echo "-e file://${BUILD_PATH_JAXLIB}" > /opt/pip-tools.d/requirements-jax.in echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/requirements-jax.in EOF