Skip to content

Commit 953e5b5

Browse files
committed
jax-toolbox-triage: improve documentation
1 parent 22b4eb4 commit 953e5b5

File tree

3 files changed

+219
-6
lines changed

3 files changed

+219
-6
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,4 @@ Docker has traditionally used Docker Schema V2.2 for multi-arch manifest lists b
407407
* [What's New in JAX | GTC Spring 2023](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51956/)
408408
* [Slurm and OpenMPI zero config integration](https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html)
409409
* [Adding custom GPU ops](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
410+
* [Triaging regressions](docs/triage-tool.md)

docs/triage-tool.md

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
# Triage tool
2+
3+
`jax-toolbox-triage` is a tool to automate the process of attributing regressions to an
4+
individual commit of JAX or XLA.
5+
It takes as input a command that returns an error (non-zero) code when run in "recent"
6+
containers, but which returns a success (zero) code when run in some "older" container.
7+
8+
The tool follows a three-step process:
9+
1. A container-level search backwards from the "recent" container where the test is
10+
known to fail, which identifies an "older" container where the test passes. This
11+
search proceeds with an exponentially increasing step size and is based on the
12+
`YYYY-MM-DD` tags under `ghcr.io/nvidia/jax`.
13+
2. A container-level binary search to refine this to the **latest** available
14+
container where test passes and the **earliest** available container where it
15+
fails.
16+
3. A commit-level binary search, repeatedly building + testing inside the same
17+
container, to identify a single commit of JAX (XLA) that causes the test to start
18+
failing, and a reference commit of XLA (JAX) that can be used to reproduce the
19+
regression.
20+
21+
## Installation
22+
23+
The triage tool can be installed using `pip`:
24+
```bash
25+
pip install git+https://github.com/NVIDIA/JAX-Toolbox.git#subdirectory=.github/triage
26+
```
27+
or directly from a checkout of the JAX-Toolbox repository.
28+
Because the tool needs to orchestrate running commands in multiple containers, it is
29+
most convenient to install it in a virtual environment on the host system, rather than
30+
attempting to install it inside a container.
31+
32+
The tool should be invoked on a machine with `docker` available and whatever GPUs are
33+
needed to execute the test case.
34+
35+
## Usage
36+
37+
To use the tool, there are two compulsory arguments:
38+
* `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container
39+
families to execute the test command in. Example: `jax` for a JAX unit test
40+
failure, `maxtext` for a MaxText model execution failure
41+
* A test command to triage.
42+
43+
The test command will be executed directly in the container, not inside a shell, so be
44+
sure not to add excessive quotation marks (*i.e.* run
45+
`jax-toolbox-triage --container=jax test-jax.sh foo` not
46+
`jax-toolbox-triage --container=jax "test-jax.sh foo"`), and you should aim to make it
47+
as fast and targeted as possible.
48+
The expectation is that the test case will be executed successfully several times as
49+
part of the triage, so you may want to tune some parameters to reduce the execution
50+
time in the successful case.
51+
For example, if `text-maxtext.sh --steps=500 ...` is failing on step 0, you should
52+
probably reduce `--steps` to optimise execution time in the successful case.
53+
54+
A JSON status file and both info-level and debug-level logfiles are written to the
55+
directory given by `--output-prefix`.
56+
57+
### Optimising container-level search performance
58+
59+
By default, the container-level search starts from the most recent available container,
60+
if you already know that the test has been failing for a while, you can pass
61+
`--end-date` to start the search further in the past.
62+
If you are sure that the test is failing on the `--end-date` you have passed, you can
63+
skip verification of that fact by passing `--skip-precondition-checks` (but see below
64+
for other checks that this skips).
65+
66+
By default, the container-level backwards search for a date on which the test passed
67+
tries the containers approximately [1, 2, 4, ...] days before `--end-date`.
68+
This can be tuned by passing `--start-date`, which overrides the "end date minus one"
69+
start value (but leaves the exponential growth of the search range width).
70+
If you are sure that the test is passing on the `--start-date` you have passed, you can
71+
skip verification of that fact by passing `--skip-precondition-checks`.
72+
73+
The combination of `--start-date`, `--end-date` and `--skip-precondition-checks` can be
74+
used to skip the entire first stage of the bisection process.
75+
76+
The second stage of the triage process can be made to abort early using the
77+
`--threshold-days` option; this stage will terminate once the delta between the latest
78+
known-good and earliest known-bad containers is below the threshold.
79+
80+
If you need to re-start the tool for some reason, use of these options can help
81+
bootstrap the tool using the results of a previous (partial) run.
82+
83+
### Optimising commit-level search performance
84+
85+
The third stage of the triage process involves repeatedly building JAX and XLA, which
86+
can be sped up significantly using a Bazel cache.
87+
By default, a local directory on the host machine (where the tool is being executed)
88+
will be used, but it may be more efficient to use a persistent and/or pre-heated cache.
89+
This can be achieved by passing the `--bazel-cache` option, which accepts absolute
90+
paths and `http`/`https`/`grpc` URLs.
91+
92+
If `--skip-precondition-checks` is passed, a sanity check that the failure can be
93+
reproduced after rebuilding the JAX/XLA commits from the first-known-bad container
94+
inside that container will be skipped.
95+
96+
## Example
97+
98+
Here is an example execution for a JAX unit test failure, with some annotation in
99+
`# comments`:
100+
```ShellSession
101+
$ jax-toolbox-triage --container jax test-jax.sh //tests:nn_test_gpu
102+
# --end-date was not passed, and 2024-10-15 is the most recent available container at
103+
# the time of execution
104+
[INFO] 2024-10-16 00:31:41 Checking end-of-range failure in 2024-10-15
105+
# --skip-precondition-checks was not passed, so the tool checks that the test does, in
106+
# fact, fail in the 2024-10-15 container
107+
[INFO] 2024-10-16 00:33:36 Ran test case in 2024-10-15 in 114.8s, pass=False
108+
# --start-date was not passed, so the first (backwards search) stage of the triage
109+
# process starts with the container 1 day before the end of the range, i.e. 2024-10-14
110+
[INFO] 2024-10-16 00:33:37 Starting coarse search with 2024-10-14 based on end_date=2024-10-15
111+
[INFO] 2024-10-16 00:35:35 Ran test case in 2024-10-14 in 118.1s, pass=False
112+
# end_date - 2 * (end_date - search_date) = 2024-10-15 - 2 days = 2024-10-13
113+
[INFO] 2024-10-16 00:38:11 Ran test case in 2024-10-13 in 122.4s, pass=False
114+
# In principle this would be 4 days before the end date, but the 2024-10-11 container
115+
# does not exist, so the tool chooses a nearby container that does exist and is older
116+
# than 2024-10-13
117+
[INFO] 2024-10-16 00:40:53 Ran test case in 2024-10-12 in 127.7s, pass=False
118+
# Steps in date start to increase significantly
119+
[INFO] 2024-10-16 00:43:28 Ran test case in 2024-10-09 in 119.3s, pass=False
120+
[INFO] 2024-10-16 00:45:29 Ran test case in 2024-10-03 in 120.7s, pass=False
121+
[INFO] 2024-10-16 00:47:27 Ran test case in 2024-09-21 in 116.3s, pass=False
122+
# The first stage of the triage process successfully identifies an old container where
123+
# this test passed
124+
[INFO] 2024-10-16 00:51:22 Ran test case in 2024-08-28 in 194.0s, pass=True
125+
[INFO] 2024-10-16 00:51:22 Coarse container-level search yielded [2024-08-28, 2024-09-21]...
126+
# The second stage of the triage process refines the container-level range by bisection
127+
[INFO] 2024-10-16 00:53:19 Ran test case in 2024-09-09 in 115.5s, pass=True
128+
[INFO] 2024-10-16 00:53:19 Refined container-level range to [2024-09-09, 2024-09-21]
129+
[INFO] 2024-10-16 00:56:03 Ran test case in 2024-09-15 in 125.4s, pass=True
130+
[INFO] 2024-10-16 00:56:03 Refined container-level range to [2024-09-15, 2024-09-21]
131+
[INFO] 2024-10-16 00:58:07 Ran test case in 2024-09-18 in 122.9s, pass=True
132+
[INFO] 2024-10-16 00:58:07 Refined container-level range to [2024-09-18, 2024-09-21]
133+
# The second stage of the triage process converges
134+
[INFO] 2024-10-16 01:00:09 Ran test case in 2024-09-19 in 121.2s, pass=False
135+
[INFO] 2024-10-16 01:00:09 Refined container-level range to [2024-09-18, 2024-09-19]
136+
# The third stage of the triage process begins, using:
137+
# - the first-known-bad container 2024-09-19
138+
# - first-known-bad commits (JAX 9d2e9... and XLA 42b04...)
139+
# - last-known-good commits (JAX 988ed... and XLA 88935...)
140+
[INFO] 2024-10-16 01:00:10 Bisecting JAX [988ed2bd75df5fe25b74eaf38075aadff19be207, 9d2e9c688c4e8b733e68467d713091436a672ac0] and XLA [8893550a604fe39aae2eeae49a836e92eed497d1, 42b04a6739dc648a80dd4f3b4e1322f1b2c7f3a7] using ghcr.io/nvidia/jax:jax-2024-09-19
141+
[INFO] 2024-10-16 01:00:10 Building in the range-ending container...
142+
# Sanity check that re-building the first-known-bad commits in the first-known-bad
143+
# container reproduces the failure
144+
[INFO] 2024-10-16 01:00:12 Checking out XLA 42b04a6739dc648a80dd4f3b4e1322f1b2c7f3a7 JAX 9d2e9c688c4e8b733e68467d713091436a672ac0
145+
# No Bazel cache was passed, and this is the first build in the triage session, so it
146+
# is slow -- a full rebuild of JAX and XLA was needed
147+
[INFO] 2024-10-16 01:13:56 Build completed in 824.9s
148+
[INFO] 2024-10-16 01:15:25 Test completed in 88.5s
149+
[INFO] 2024-10-16 01:15:25 Verified test failure after vanilla rebuild
150+
# Verification that the last-known-good commits still pass when rebuilt in the
151+
# first-known-bad container; this is a bit faster because the Bazel cache is warmer
152+
[INFO] 2024-10-16 01:15:25 Checking out XLA 8893550a604fe39aae2eeae49a836e92eed497d1 JAX 988ed2bd75df5fe25b74eaf38075aadff19be207
153+
[INFO] 2024-10-16 01:26:43 Build completed in 677.5s
154+
[INFO] 2024-10-16 01:27:36 Test completed in 53.7s
155+
[INFO] 2024-10-16 01:27:36 Test passed after rebuilding commits from start container in end container
156+
# Binary search in commits continues, with progressively faster build times
157+
[INFO] 2024-10-16 01:27:37 Checking out XLA b976dd94f11ab130c5f718b360fcfb5ac6d6b875 JAX b51c65357f0ae9659e58e2ff0df871542124cddf
158+
[INFO] 2024-10-16 01:32:24 Build completed in 287.7s
159+
[INFO] 2024-10-16 01:33:19 Test completed in 54.4s
160+
[INFO] 2024-10-16 01:33:19 Checking out XLA e291dfe0a12ec5907636a722c545c19d43f04c8b JAX 9dd363da1298e4810b693a918fc2e8199094acdb
161+
[INFO] 2024-10-16 01:34:58 Build completed in 98.9s
162+
[INFO] 2024-10-16 01:35:52 Test completed in 54.1s
163+
[INFO] 2024-10-16 01:35:53 Checking out XLA 6e652a5d91657cfbe9fbcdff4a0ccd1b803675a7 JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
164+
[INFO] 2024-10-16 01:36:54 Build completed in 61.3s
165+
[INFO] 2024-10-16 01:37:47 Test completed in 52.7s
166+
[INFO] 2024-10-16 01:37:47 Checking out XLA a1299f86507c79c8acf877344d545f10329f8515 JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
167+
[INFO] 2024-10-16 01:38:39 Build completed in 52.5s
168+
[INFO] 2024-10-16 01:39:32 Test completed in 52.5s
169+
[INFO] 2024-10-16 01:39:32 Checking out XLA 2d1f7b70740649a57ec4988702ae1dbdfeee6e9c JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
170+
[INFO] 2024-10-16 01:40:24 Build completed in 52.2s
171+
[INFO] 2024-10-16 01:41:17 Test completed in 52.9s
172+
[INFO] 2024-10-16 01:41:17 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX 016c49951f670256ce4750cdfea182e3a2a15325
173+
[INFO] 2024-10-16 01:42:08 Build completed in 50.9s
174+
[INFO] 2024-10-16 01:43:12 Test completed in 64.2s
175+
# The XLA commit has stopped changing; the initial bisection is XLA-centric (with JAX
176+
# kept roughly in sync), but when this converges on a single XLA commit, the tool will
177+
# run extra tests to decide whether to blame that XLA commit or a nearby JAX commit
178+
[INFO] 2024-10-16 01:43:13 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
179+
[INFO] 2024-10-16 01:44:01 Build completed in 48.8s
180+
[INFO] 2024-10-16 01:45:02 Test completed in 60.8s
181+
[INFO] 2024-10-16 01:45:03 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX cd04d0f32e854aa754e37e4b676725655a94e731
182+
[INFO] 2024-10-16 01:45:52 Build completed in 49.4s
183+
[INFO] 2024-10-16 01:46:53 Test completed in 60.7s
184+
[INFO] 2024-10-16 01:46:53 Bisected failure to JAX cd04d0f32e854aa754e37e4b676725655a94e731..b164d67d4a9bd094426ff450fe1f1335d3071d03 with XLA 662eb45a17c76df93e5a386929653ae4c1f593da
185+
```
186+
Where the final result should be read as saying that the test passes with
187+
https://github.com/openxla/xla/commit/662eb45a17c76df93e5a386929653ae4c1f593da and
188+
https://github.com/jax-ml/jax/commit/cd04d0f32e854aa754e37e4b676725655a94e731, but that
189+
if JAX is moved forward to include
190+
https://github.com/jax-ml/jax/commit/b164d67d4a9bd094426ff450fe1f1335d3071d03 then the
191+
test fails.
192+
This failure is fixed in XXX.
193+
194+
## Limitations
195+
196+
This tool aims to target the common case that regressions are due to commits in JAX or
197+
XLA, so if the root cause is different it may not converge, although the partial results
198+
may still be helpful.
199+
200+
For example, if the regression is due to a new version of some other dependency
201+
`SomeProject` that was first installed in the `2024-10-15` container, then the first
202+
two stages of the triage process will correctly identify that `2024-10-15` is the
203+
critical date, but the third stage will fail because it will try and fail to reproduce
204+
test success by building the JAX/XLA commits from `2024-10-14` in the `2024-10-15`
205+
container.
206+
207+
Other limitations include that only `docker` is supported as a container runtime, which
208+
also implies that it is not currently possible to triage a test that requires a
209+
multi-node or multi-process test.
210+
211+
If you run into these limitations in real-world usage of this tool, please file a bug
212+
against JAX-Toolbox including details of manual steps you took to root-case the test
213+
regression.

docs/triage.md

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@ There is a Github Action Workflow called [_triage.yaml](../.github/workflows/_tr
44
be used to help determine if a test failure was due to a change in (t5x or pax) or further-up, e.g., in (Jax or CUDA). This workflow is not the end-all, and further investigation is usually needed,
55
but this automates the investigation of questions like "what state of library X works with Jax at state Y?"
66

7-
__Note__: There is also a utility, [triage](../.github/triage/triage), which can be
8-
used for more granular bisection of failures in specific tests. Run it with `--help`
9-
for usage instructions. Given a test expression that can be run inside the nightly
10-
containers (*e.g.* `test-jax.sh jet_test_gpu`), it first identifies the nightly
11-
container where the failure first appeared, and second attributes the failure to a
12-
specific commit of JAX or XLA.
7+
__Note__: There is also a [triage tool](triage-tool.md), which can be used for
8+
more granular bisection of failures in specific tests. Given a test expression that can
9+
be run inside the nightl containers (*e.g.* `test-jax.sh jet_test_gpu`), it first
10+
identifies the nightl container where the failure first appeared, and second attributes
11+
the failure to a specific commit of JAX or XLA.
1312

1413
## Algorithm
1514
The pseudocode for the triaging algorithm is as follows:

0 commit comments

Comments
 (0)