Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jax-toolbox-triage: improve documentation #1104

Merged
merged 3 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/triage/jax_toolbox_triage/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ def check_container(date: datetime.date) -> bool:
jax_commit = get_commit(worker, "jax")
xla_commit = get_commit(worker, "xla")

logger.debug(result.stdout)
logger.info(f"Ran test case in {date} in {test_time:.1f}s")
test_pass = result.returncode == 0
logger.info(f"Ran test case in {date} in {test_time:.1f}s, pass={test_pass}")
logger.debug(result.stdout)
logger.debug(result.stderr)
add_summary_record(
"container",
{
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,4 @@ Docker has traditionally used Docker Schema V2.2 for multi-arch manifest lists b
* [What's New in JAX | GTC Spring 2023](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51956/)
* [Slurm and OpenMPI zero config integration](https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html)
* [Adding custom GPU ops](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
* [Triaging regressions](docs/triage-tool.md)
249 changes: 249 additions & 0 deletions docs/triage-tool.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
# Triage tool

`jax-toolbox-triage` is a tool to automate the process of attributing regressions to an
individual commit of JAX or XLA.
It takes as input a command that returns an error (non-zero) code when run in "recent"
olupton marked this conversation as resolved.
Show resolved Hide resolved
containers, but which returns a success (zero) code when run in some "older" container.
The command must be executable within the containers, *i.e.* it cannot refer to files
that only exist on the host system.

The tool follows a three-step process:
1. A container-level search backwards from the "recent" container where the test is
known to fail, which identifies an "older" container where the test passes. This
search proceeds with an exponentially increasing step size and is based on the
`YYYY-MM-DD` tags under `ghcr.io/nvidia/jax`.
2. A container-level binary search to refine this to the **latest** available
container where test passes and the **earliest** available container where it
fails.
3. A commit-level binary search, repeatedly building + testing inside the same
container, to identify a single commit of JAX (XLA) that causes the test to start
failing, and a reference commit of XLA (JAX) that can be used to reproduce the
regression.

## Installation

The triage tool can be installed using `pip`:
```bash
pip install git+https://github.com/NVIDIA/JAX-Toolbox.git#subdirectory=.github/triage
```
or directly from a checkout of the JAX-Toolbox repository.
Because the tool needs to orchestrate running commands in multiple containers, it is
most convenient to install it in a virtual environment on the host system, rather than
attempting to install it inside a container.

The tool should be invoked on a machine with `docker` available and whatever GPUs are
needed to execute the test case.

## Usage

To use the tool, there are two compulsory arguments:
* `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container
families to execute the test command in. Example: `jax` for a JAX unit test
failure, `maxtext` for a MaxText model execution failure
* A test command to triage.

The test command will be executed directly in the container, not inside a shell, so be
sure not to add excessive quotation marks (*i.e.* run
`jax-toolbox-triage --container=jax test-jax.sh foo` not
`jax-toolbox-triage --container=jax "test-jax.sh foo"`), and you should aim to make it
as fast and targeted as possible.
The expectation is that the test case will be executed successfully several times as
part of the triage, so you may want to tune some parameters to reduce the execution
time in the successful case.
For example, if `text-maxtext.sh --steps=500 ...` is failing on step 0, you should
probably reduce `--steps` to optimise execution time in the successful case.

A JSON status file and both info-level and debug-level logfiles are written to the
directory given by `--output-prefix`.

### Optimising container-level search performance

By default, the container-level search starts from the most recent available container,
if you already know that the test has been failing for a while, you can pass
`--end-date` to start the search further in the past.
If you are sure that the test is failing on the `--end-date` you have passed, you can
skip verification of that fact by passing `--skip-precondition-checks` (but see below
for other checks that this skips).

By default, the container-level backwards search for a date on which the test passed
tries the containers approximately [1, 2, 4, ...] days before `--end-date`.
This can be tuned by passing `--start-date`, which overrides the "end date minus one"
start value (but leaves the exponential growth of the search range width).
If you are sure that the test is passing on the `--start-date` you have passed, you can
skip verification of that fact by passing `--skip-precondition-checks`.

The combination of `--start-date`, `--end-date` and `--skip-precondition-checks` can be
used to skip the entire first stage of the bisection process.

The second stage of the triage process can be made to abort early using the
`--threshold-days` option; this stage will terminate once the delta between the latest
known-good and earliest known-bad containers is below the threshold.

If you need to re-start the tool for some reason, use of these options can help
bootstrap the tool using the results of a previous (partial) run.

### Optimising commit-level search performance

The third stage of the triage process involves repeatedly building JAX and XLA, which
can be sped up significantly using a Bazel cache.
By default, a local directory on the host machine (where the tool is being executed)
will be used, but it may be more efficient to use a persistent and/or pre-heated cache.
This can be achieved by passing the `--bazel-cache` option, which accepts absolute
paths and `http`/`https`/`grpc` URLs.

If `--skip-precondition-checks` is passed, a sanity check that the failure can be
reproduced after rebuilding the JAX/XLA commits from the first-known-bad container
inside that container will be skipped.

## Example

Here is an example execution for a JAX unit test failure, with some annotation:
```console
user@gpu-machine $ jax-toolbox-triage --container jax test-jax.sh //tests:nn_test_gpu
```
`--end-date` was not passed, and 2024-10-15 is the most recent available container
at the time of execution
```
[INFO] 2024-10-16 00:31:41 Checking end-of-range failure in 2024-10-15
```
`--skip-precondition-checks` was not passed, so the tool checks that the test does, in
fact, fail in the 2024-10-15 container
```
[INFO] 2024-10-16 00:33:36 Ran test case in 2024-10-15 in 114.8s, pass=False
```
`--start-date` was not passed, so the first (backwards search) stage of the triage
process starts with the container 1 day before the end of the range, *i.e.* 2024-10-14
```
[INFO] 2024-10-16 00:33:37 Starting coarse search with 2024-10-14 based on end_date=2024-10-15
[INFO] 2024-10-16 00:35:35 Ran test case in 2024-10-14 in 118.1s, pass=False
```
`end_date - 2 * (end_date - search_date)` = `2024-10-15 - 2 days` = `2024-10-13`
```
[INFO] 2024-10-16 00:38:11 Ran test case in 2024-10-13 in 122.4s, pass=False
```
In principle this would be 4 days before the end date, but the 2024-10-11 container
does not exist, so the tool chooses a nearby container that does exist and is older
than 2024-10-13
```
[INFO] 2024-10-16 00:40:53 Ran test case in 2024-10-12 in 127.7s, pass=False
```
Steps in date start to increase significantly
```
[INFO] 2024-10-16 00:43:28 Ran test case in 2024-10-09 in 119.3s, pass=False
[INFO] 2024-10-16 00:45:29 Ran test case in 2024-10-03 in 120.7s, pass=False
[INFO] 2024-10-16 00:47:27 Ran test case in 2024-09-21 in 116.3s, pass=False
```
The first stage of the triage process successfully identifies an old container where
this test passed
```
[INFO] 2024-10-16 00:51:22 Ran test case in 2024-08-28 in 194.0s, pass=True
[INFO] 2024-10-16 00:51:22 Coarse container-level search yielded [2024-08-28, 2024-09-21]...
```
The second stage of the triage process refines the container-level range by bisection
```
[INFO] 2024-10-16 00:53:19 Ran test case in 2024-09-09 in 115.5s, pass=True
[INFO] 2024-10-16 00:53:19 Refined container-level range to [2024-09-09, 2024-09-21]
[INFO] 2024-10-16 00:56:03 Ran test case in 2024-09-15 in 125.4s, pass=True
[INFO] 2024-10-16 00:56:03 Refined container-level range to [2024-09-15, 2024-09-21]
[INFO] 2024-10-16 00:58:07 Ran test case in 2024-09-18 in 122.9s, pass=True
[INFO] 2024-10-16 00:58:07 Refined container-level range to [2024-09-18, 2024-09-21]
```
The second stage of the triage process converges
```
[INFO] 2024-10-16 01:00:09 Ran test case in 2024-09-19 in 121.2s, pass=False
[INFO] 2024-10-16 01:00:09 Refined container-level range to [2024-09-18, 2024-09-19]
```
The third stage of the triage process begins, using:
- the first-known-bad container 2024-09-19
- first-known-bad commits (JAX 9d2e9... and XLA 42b04...)
- last-known-good commits (JAX 988ed... and XLA 88935...)
```
[INFO] 2024-10-16 01:00:10 Bisecting JAX [988ed2bd75df5fe25b74eaf38075aadff19be207, 9d2e9c688c4e8b733e68467d713091436a672ac0] and XLA [8893550a604fe39aae2eeae49a836e92eed497d1, 42b04a6739dc648a80dd4f3b4e1322f1b2c7f3a7] using ghcr.io/nvidia/jax:jax-2024-09-19
[INFO] 2024-10-16 01:00:10 Building in the range-ending container...
```
Sanity check that re-building the first-known-bad commits in the first-known-bad
container reproduces the failure
```
[INFO] 2024-10-16 01:00:12 Checking out XLA 42b04a6739dc648a80dd4f3b4e1322f1b2c7f3a7 JAX 9d2e9c688c4e8b733e68467d713091436a672ac0
```
No Bazel cache was passed, and this is the first build in the triage session, so it is
slow -- a full rebuild of JAX and XLA was needed
```
[INFO] 2024-10-16 01:13:56 Build completed in 824.9s
[INFO] 2024-10-16 01:15:25 Test completed in 88.5s
[INFO] 2024-10-16 01:15:25 Verified test failure after vanilla rebuild
```
Verification that the last-known-good commits still pass when rebuilt in the
first-known-bad container; this is a bit faster because the Bazel cache is warmer
```
[INFO] 2024-10-16 01:15:25 Checking out XLA 8893550a604fe39aae2eeae49a836e92eed497d1 JAX 988ed2bd75df5fe25b74eaf38075aadff19be207
[INFO] 2024-10-16 01:26:43 Build completed in 677.5s
[INFO] 2024-10-16 01:27:36 Test completed in 53.7s
[INFO] 2024-10-16 01:27:36 Test passed after rebuilding commits from start container in end container
```
Binary search in commits continues, with progressively faster build times
```
[INFO] 2024-10-16 01:27:37 Checking out XLA b976dd94f11ab130c5f718b360fcfb5ac6d6b875 JAX b51c65357f0ae9659e58e2ff0df871542124cddf
[INFO] 2024-10-16 01:32:24 Build completed in 287.7s
[INFO] 2024-10-16 01:33:19 Test completed in 54.4s
[INFO] 2024-10-16 01:33:19 Checking out XLA e291dfe0a12ec5907636a722c545c19d43f04c8b JAX 9dd363da1298e4810b693a918fc2e8199094acdb
[INFO] 2024-10-16 01:34:58 Build completed in 98.9s
[INFO] 2024-10-16 01:35:52 Test completed in 54.1s
[INFO] 2024-10-16 01:35:53 Checking out XLA 6e652a5d91657cfbe9fbcdff4a0ccd1b803675a7 JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
[INFO] 2024-10-16 01:36:54 Build completed in 61.3s
[INFO] 2024-10-16 01:37:47 Test completed in 52.7s
[INFO] 2024-10-16 01:37:47 Checking out XLA a1299f86507c79c8acf877344d545f10329f8515 JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
[INFO] 2024-10-16 01:38:39 Build completed in 52.5s
[INFO] 2024-10-16 01:39:32 Test completed in 52.5s
[INFO] 2024-10-16 01:39:32 Checking out XLA 2d1f7b70740649a57ec4988702ae1dbdfeee6e9c JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
[INFO] 2024-10-16 01:40:24 Build completed in 52.2s
[INFO] 2024-10-16 01:41:17 Test completed in 52.9s
[INFO] 2024-10-16 01:41:17 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX 016c49951f670256ce4750cdfea182e3a2a15325
[INFO] 2024-10-16 01:42:08 Build completed in 50.9s
[INFO] 2024-10-16 01:43:12 Test completed in 64.2s
```
The XLA commit has stopped changing; the initial bisection is XLA-centric (with JAX
kept roughly in sync), but when this converges on a single XLA commit, the tool will
run extra tests to decide whether to blame that XLA commit or a nearby JAX commit
```
[INFO] 2024-10-16 01:43:13 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
[INFO] 2024-10-16 01:44:01 Build completed in 48.8s
[INFO] 2024-10-16 01:45:02 Test completed in 60.8s
[INFO] 2024-10-16 01:45:03 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX cd04d0f32e854aa754e37e4b676725655a94e731
[INFO] 2024-10-16 01:45:52 Build completed in 49.4s
[INFO] 2024-10-16 01:46:53 Test completed in 60.7s
[INFO] 2024-10-16 01:46:53 Bisected failure to JAX cd04d0f32e854aa754e37e4b676725655a94e731..b164d67d4a9bd094426ff450fe1f1335d3071d03 with XLA 662eb45a17c76df93e5a386929653ae4c1f593da
```

Where the final result should be read as saying that the test passes with
[xla@662eb](https://github.com/openxla/xla/commit/662eb45a17c76df93e5a386929653ae4c1f593da)
and
[jax@cd04d](https://github.com/jax-ml/jax/commit/cd04d0f32e854aa754e37e4b676725655a94e731),
but that if JAX is moved forward to include
[jax@b164d](https://github.com/jax-ml/jax/commit/b164d67d4a9bd094426ff450fe1f1335d3071d03)
then the test fails.
This failure is fixed in [jax#24427](https://github.com/jax-ml/jax/pull/24427).

## Limitations

This tool aims to target the common case that regressions are due to commits in JAX or
XLA, so if the root cause is different it may not converge, although the partial results
may still be helpful.

For example, if the regression is due to a new version of some other dependency
`SomeProject` that was first installed in the `2024-10-15` container, then the first
two stages of the triage process will correctly identify that `2024-10-15` is the
critical date, but the third stage will fail because it will try and fail to reproduce
test success by building the JAX/XLA commits from `2024-10-14` in the `2024-10-15`
container.

Other limitations include that only `docker` is supported as a container runtime, which
also implies that it is not currently possible to triage a test that requires a
multi-node or multi-process test.

The tool also does not currently handle skipping commits that do not compile, or test
cases that require copying files (*e.g.* script files) into the container.

If you run into these limitations in real-world usage of this tool, please file a bug
against JAX-Toolbox including details of manual steps you took to root-case the test
regression.
11 changes: 5 additions & 6 deletions docs/triage.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@ There is a Github Action Workflow called [_triage.yaml](../.github/workflows/_tr
be used to help determine if a test failure was due to a change in (t5x or pax) or further-up, e.g., in (Jax or CUDA). This workflow is not the end-all, and further investigation is usually needed,
but this automates the investigation of questions like "what state of library X works with Jax at state Y?"

__Note__: There is also a utility, [triage](../.github/triage/triage), which can be
used for more granular bisection of failures in specific tests. Run it with `--help`
for usage instructions. Given a test expression that can be run inside the nightly
containers (*e.g.* `test-jax.sh jet_test_gpu`), it first identifies the nightly
container where the failure first appeared, and second attributes the failure to a
specific commit of JAX or XLA.
__Note__: There is also a [triage tool](triage-tool.md), which can be used for
more granular bisection of failures in specific tests. Given a test expression that can
be run inside the nightly containers (*e.g.* `test-jax.sh jet_test_gpu`), it first
identifies the nightly container where the failure first appeared, and second attributes
the failure to a specific commit of JAX or XLA.

## Algorithm
The pseudocode for the triaging algorithm is as follows:
Expand Down
Loading