Reproducible builds #223

Closed · 1 task
terrykong opened this issue Sep 12, 2023 · 6 comments · Fixed by #371
Comments

@terrykong
Contributor

terrykong commented Sep 12, 2023

Starting this discussion with an initial proposal of how to support reproducible builds to freeze the ecosystem state leading up to a release.

Problem

At the moment, we have not implemented a mechanism to pin version X of jax, Y of paxml, and Z of a transitive dependency of paxml in our containers; where X, Y, and Z can be versions on pypi, git-refs, or even a full distribution. Currently, we rely on nightly images to pin state, but these are not reproducible since running the same build script may install more recent libraries or transitive dependencies that break functionality.

This is an issue when we want to rebuild a container for a release, because we become cautious about changing build commands in earlier layers (e.g., jax or pax) for fear that the build would pick up an unintended change.

Requirements

There are two things we need to freeze:

  1. The version of a python package, e.g., pax, jax, transformer-engine
  2. The distribution of a library, e.g., pax + (all of our patches)

Proposal

Here is my initial proposal, which we can iterate on.

Freezing python package versions

There are mature solutions to handle python package dependencies like pip-tools (pip-compile) and we can make use of those tools here. We can have:

  • requirements-jax.in
  • requirements-pax.in
  • requirements-t5x.in

Where the jax images install requirements-jax.in and the t5x images install pip_compiled(requirements-jax.in, requirements-t5x.in).
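
As a rough sketch of how the layering could work (the exact file names and commands below are assumptions, not a final layout):

# Sketch only: compile the jax pins alone, then compile t5x on top of them.
pip-compile --output-file=requirements-jax.txt requirements-jax.in
pip-compile --output-file=requirements-t5x.txt requirements-jax.in requirements-t5x.in
# The jax image installs requirements-jax.txt; the t5x image installs requirements-t5x.txt.
pip install -r requirements-t5x.txt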

Freezing the state of the nightly

Previously, we froze the state of all upstream repos by building them into a docker image. In order to re-build that image, we will need to update our CI to do something like the following:

(get git-sha of all upstream repos nightly) -> (write all SHA's to a manifest) -> (clone repos you need from the SHA-manifest) -> build
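
A minimal sketch of that nightly step (the repo list and manifest format here are placeholders):

# Sketch only: repo URLs and the manifest format are placeholders.
manifest=sha-manifest.txt
: > "$manifest"
for url in https://github.com/google/jax https://github.com/google/paxml https://github.com/google-research/t5x; do
  sha=$(git ls-remote "$url" HEAD | cut -f1)   # tip of the default branch tonight
  echo "$url $sha" >> "$manifest"
done
# Re-building later: clone each repo at the recorded SHA before running the build.
while read -r url sha; do
  dir=$(basename "$url" .git)
  git clone "$url" "$dir" && git -C "$dir" checkout "$sha"
done < "$manifest"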

Freezing the state of distributions

We already use patchlists, so we can continue to do this. One point of improvement though is perhaps we need a way to guarantee that the git-ref in the patchlist has not been force-updated.

Since a git-ref (branch or PR) can be rebased, once an engineer updates the git-ref, you can no longer reproduce the old build. We need to store the patches somewhere. Perhaps we can freeze the diff of the git-ref in some location alongside the requirements-*.in and the "SHA manifest".
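
For example, a minimal way to freeze the diff of a patch ref before it can be rebased away (the ref name and output path here are placeholders):

# Sketch only: "pull/123/head" and the patches/ directory are placeholders.
git fetch origin pull/123/head:patch-123
git diff main...patch-123 > patches/patch-123.diff   # snapshot the patch content
# The frozen diff can later be re-applied with: git apply patches/patch-123.diff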

Nightly CI

We can have the nightly CI freeze all of this metadata (it would also build from this metadata) and write it somewhere in version control so that we can fork it for releases.

Metadata location

There are a couple of options for the metadata location; here are two:

  1. Within the JAX-Toolbox git repo (have nightly CI commit to metadata branch)
    • pros: simple; authoring CI is easier; will not clutter main branch
    • cons: can cause the repo to grow (especially if we store patch diffs here too)
  2. In an auxiliary git repo
    • pros: keeps JAX-Toolbox history clean
    • cons: adds complexity to the process since atomicity is more difficult to guarantee; engineers need to be aware of multiple repos

I'm more in favor of (1) due to its simplicity, but am open to other options. FWIW, with (1) we can merge (main + metadata) to form a release branch.

Installation/Building

It may be obvious, but since some of these repos have more complex build steps, often requiring apt packages or python packages not listed in their install_requires, we'll also need one or more ./install-jax-req.sh scripts to set up or build any wheels we need.

This could potentially live alongside the "SHA manifest".
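
For illustration only, such a setup script might look like the following (the packages and paths are placeholders, not the real build dependencies):

#!/bin/bash
# Sketch only: placeholder steps for a library that needs extra build tooling.
set -euo pipefail
apt-get update && apt-get install -y --no-install-recommends git build-essential
pip install build   # build-time python dep not listed in install_requires
python -m build --wheel --outdir /opt/wheels /opt/jax-source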

Tasks

@terrykong
Contributor Author

@nluehr

@nouiz
Collaborator

nouiz commented Sep 12, 2023

We could also just print the output so the information is in the CI log.
Even easier, but less useful.
If a better solution takes time, this could be a temporary one.
Or even just call pip freeze at the end of the build script.
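
For example (the output path is arbitrary):

# Stop-gap: record the fully resolved environment at the end of the build.
pip freeze > /opt/pip-freeze-$(date +%Y-%m-%d).txt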

@terrykong
Contributor Author

So I've made some progress looking into pip-compile. One benefit is we can represent our requirements like so:

# jax-requirements.in
jax @ file:///opt/jax-source
jaxlib @ file:///opt/jax-source/dist/jaxlib-blah-blah.wheel

and have a "t5x environment" that uses the above as a constraint:

# t5x-requirements.in
-c jax-requirements.in
t5x @ file:///opt/t5x

And this seems to work well.

So for generating the jax requirements we'd do:

pip-compile jax-requirements.in
# produces jax-requirements.txt

and for t5x:

pip-compile t5x-requirements.in

and we'd get a t5x-requirements.txt out, without having to install anything.

Then we can have the "base" container hold all of these *-requirements.txt files as reference, and the downstream builds can install them à la carte.

For installing, we can rely on pip-tools' pip-sync, which does the following:

Now that you have a requirements.txt, you can use pip-sync to update your virtual environment to reflect exactly what's in there. This will install/upgrade/uninstall everything necessary to match the requirements.txt contents.

The only wrinkle in all of this is VCS installs. So for t5x, there are installs like this:

seqio @ git+https://github.com/google/seqio

which are problematic b/c these reference main, and that's a moving target. The issue is that if we specify the following in our requirements:

seqio @ git+https://github.com/google/seqio@30311dd4788cf30ccc1b0b0f79f976592ab17af2

This will cause a dependency conflict b/c pip treats the URLs as different dependencies. I tried a few other variations with no luck.

One way around this is to add a post-processing step on the pip-compile-produced *-requirements.txt file that looks up the head commit for all VCS installs and appends the SHA to them. This also avoids having to surgically search-and-replace VCS python requirements.

My current thinking is the process would look something like this (pseudocode) in the base container:

for lib in t5x pax jax; do
  bash setup-${lib}.sh  # do things like install apt packages or build the jax wheel
  pip-compile ${lib}-requirements.in --output-file ${lib}-requirements.txt
  add_shas_inplace_for_head_installs ${lib}-requirements.txt
done

Then the jax/t5x/pax containers would just look like:

FROM base_container
ARG LIB
RUN pip-sync ${LIB}-requirements.txt
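
The add_shas_inplace_for_head_installs step above is only pseudocode; a minimal sketch of what it could do (the name, regex, and file handling are all assumptions) might be:

# Sketch only: pin any un-pinned VCS requirement in a requirements file to the
# current HEAD of its remote, leaving all other lines untouched.
add_shas_inplace_for_head_installs() {
  local reqfile=$1
  local re='^([A-Za-z0-9_.-]+) @ (git[+][^@[:space:]]+)$'
  while IFS= read -r line; do
    if [[ "$line" =~ $re ]]; then
      local url=${BASH_REMATCH[2]#git+}
      local sha
      sha=$(git ls-remote "$url" HEAD | cut -f1)
      echo "${BASH_REMATCH[1]} @ ${BASH_REMATCH[2]}@${sha}"
    else
      echo "$line"
    fi
  done < "$reqfile" > "${reqfile}.pinned" && mv "${reqfile}.pinned" "$reqfile"
}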

@ko3n1g
Contributor

ko3n1g commented Sep 29, 2023

Working with pip-compile certainly works; just make sure not to constrain against an .in file but to require the compiled .txt instead (the former does not contain versions, so your sub-environments are not guaranteed to end up with the same versions of the packages defined in your base dependencies). So if you have two environments B and C which shall inherit from a base environment A, the workflow would be:

  1. Create the pinned A.txt via pip-compile --output-file=A.txt A.in
  2. Create the pinned B.txt via pip-compile --output-file=B.txt B.in A.txt (which is the same as if B.in required (-r) A.txt).
  3. Same for environment C.

The only problem left now is that B may require a different version of some package defined in A than C does. That requirement doesn't necessarily stem from you but from some third-level deps. If you want to be 100% safe, it is better to create an intermediate.txt as the requirement of all three .in files. You then have versions of all packages that are promised to be compatible among your environments. You can use that as a constraint (not a requirement!) for creating the individual environments as described above (just add -c intermediate.txt to each .in file).
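
Concretely, one reading of that workflow as a sketch (using the A/B/C names from the example above, and assuming each .in file already contains the -c intermediate.txt line):

# Sketch only: resolve one mutually compatible set first, then pin each environment against it.
pip-compile --output-file=intermediate.txt A.in B.in C.in
pip-compile --output-file=A.txt A.in   # A.in contains: -c intermediate.txt
pip-compile --output-file=B.txt B.in   # B.in contains: -c intermediate.txt
pip-compile --output-file=C.txt C.in   # C.in contains: -c intermediate.txt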

Hope that helps!

Btw, Poetry automates a lot of this away. Might be worth checking out too.

@terrykong
Contributor Author

@ko3n1g You're right; the workflow pip-tools recommends is to use the pinned constraint file. But there are some issues with doing it that way: the downstream-requirements.in has been shown to have stricter constraints on transitive deps, so by the time the upstream-requirements.in has been compiled, the versions of some packages are too new and cause a dependency conflict.

I like the idea of having an intermediate.in, but that also isn't possible at the moment, because downstream-1.requirements.in and downstream-2.requirements.in have conflicting transitive dep constraints.

Ideally we'd like the upstream-requirements.txt to match and be a subset of downstream-{1,2}.requirements.in, but in my opinion it's okay for a few packages to be different (notably protobuf, which I've observed does differ).

@ko3n1g
Contributor

ko3n1g commented Sep 30, 2023

You're right, @terrykong: paxml and t5x indeed have conflicting requirements with protobuf==3.20.3 and protobuf==3.19.6, my bad.

Since pinning between all three requirements is just not possible (for now), the closest you can get to your ideal state would be by doing the following:

First, a small change to jax-requirements.in is necessary, since the dependency resolver is not happy about filesystem and VCS requirements of the same package coming from different dependencies.

# Contents of `jax-requirements.in`
jax @ git+https://github.com/google/jax
# TODO: Brittle, may be possible to use --find-links=./wheel_dir instead
#jaxlib @ file:///opt/jax-source/dist/jaxlib-0.4.14-cp310-cp310-linux_x86_64.whl
transformer_engine @ file:///opt/transformer-engine
numpy==1.23.1 # Enforced via t5x -> git+https://github.com/google/CommonLoopUtils#egg=clu

(I hope installing from the filesystem is not a hard constraint; I'm of course not deep enough into the specifics of this project to judge.)

Assuming this works, however, we can proceed with swapping the constraints of the two upstream dependencies simply from -c jax-requirements.in to -c jax-requirements.txt.

We can build all pinned dependencies as you already proposed in your PR:

pip-compile --output-file=jax-requirements.txt jax-requirements.in \
&& pip-compile --output-file=upstream-paxml-requirements.txt upstream-paxml-requirements.in \
&& pip-compile --output-file=upstream-t5x-requirements.txt upstream-t5x-requirements.in
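
For illustration, the swapped constraint in, e.g., upstream-t5x-requirements.in would then read (contents abbreviated):

# upstream-t5x-requirements.in (illustrative)
-c jax-requirements.txt
t5x @ file:///opt/t5x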

Now there's at least a pinned overlap between {jax,paxml} as well as {jax,t5x}. The different versions of (some) packages shared between paxml and t5x cannot, as we know, be addressed until an upstream release resolves the conflict. This approach of course adds maintenance effort, with the benefit of better cross-compatibility between containers. Whether that is something for now or rather for later, I cannot really judge yet.

Regardless of which approach you end up deciding on, I would recommend tracking these files in git and updating/building the pinned requirements via a scheduled workflow that opens a PR (so basically something like Dependabot).

terrykong added a commit that referenced this issue Dec 1, 2023
…371)

# Summary

- All Python packages, except for a few build dependencies, are now
installed using **pip-tools**.
- The JAX and upstream T5X/PAX containers are now built in a two-stage
procedure:
1. The **'meal kit'** stage: source packages are downloaded, wheels
built if necessary (for TE, tensorflow-text, lingvo, etc.), but **no**
package is installed. Instead, manifest files are created in the
`/opt/pip-tools.d` folder to instruct which packages shall be installed
by pip-tools. The stage is so named because of its similarity to how
ingredients in a meal kit are prepared while the final cooking step is
deferred.
2. The **'final'** (cooking🔥) stage: this is when pip-tools collectively
compiles the manifests from the various container layers and then
sync-installs everything to exactly match the resolved versions.
- Note that downstream containers will **build on top of the meal kit
image of their base container**, thus ensuring all packages and
dependencies are installed exactly once to avoid conflicts and image
bloat.
- The meal kit and final images are published as
- mealkit: `ghcr.io/nvidia/image:mealkit` and
`ghcr.io/nvidia/image:mealkit-YYYY-MM-DD`
- final: `ghcr.io/nvidia/image:latest` and
`ghcr.io/nvidia/image:nightly-YYYY-MM-DD`
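
(For illustration only, and not taken from the PR itself: with manifests collected as *.in files under /opt/pip-tools.d, the final stage essentially boils down to compiling all manifests together and sync-installing the result.)

# Illustrative sketch of the 'final' stage; the manifest layout is an assumption.
pip-compile --output-file=/opt/pip-tools.d/requirements.txt /opt/pip-tools.d/*.in
pip-sync /opt/pip-tools.d/requirements.txt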

# Additional changes to the workflows

- `/opt/jax-source` is renamed to `/opt/jax`. The `-source` suffix is
only added to packages that need compilation, e.g. XLA and TE.
- The CI workflow is now matricized against CPU arch.
- The reusable `_build_*.yaml` workflows are simplified to build only
one image for a single architecture at a time. The logic for creating
multi-arch images is relocated into the `_publish_container.yaml`
workflows and is invoked during the nightly runs only.
- TE is now built as a wheel and shipped in the JAX core meal kit image.
- TE unit tests will be performed using the upstream-pax image due to
the dependency on praxis.
- Build workflows now produce sitreps following the paradigm of #229.
- Removed the various one-off workflows for pinned CUDA/JAX versions.
- Refactored the PAX arm64 Dockerfile in preparation for #338

# What remains to be done

- [ ] Update the Rosetta container build + test process to use the
upstream T5X/PAX mealkit (ghcr.io/nvidia/upstream-t5x:mealkit,
ghcr.io/nvidia/upstream-pax:mealkit) containers

# Reviewing tips

This PR requires a multitude of reviewers due to its size and scope. I'd
truly appreciate code owners reviewing any changes related to their
previous contributions. An incomplete list of reviewer scopes is:
- @terrykong, @ashors1, @sharathts, @maanug-nv: Rosetta, TE, T5X and PAX
MGMN tests
- @nouiz: JAX, TE and T5X build
- @joker-eph: PAX arm64 build
- @nluehr: Base image, NCCL, PAX
- @DwarKapex: base/JAX/XLA build, workflow logic

Closes #223
Closes #230 
Closes #231 
Closes #232 
Closes #233 
Closes #271
Fixes #328
Fixes #337 

Co-authored-by: Terry Kong <terryk@nvidia.com>

---------

Co-authored-by: Terry Kong <terryk@nvidia.com>
Co-authored-by: Vladislav Kozlov <vkozlov@nvidia.com>