Skip to content

Commit 3545627

Browse files
committed
jax-toolbox-triage: improve documentation
1 parent 22b4eb4 commit 3545627

File tree

3 files changed

+250
-6
lines changed

3 files changed

+250
-6
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,4 @@ Docker has traditionally used Docker Schema V2.2 for multi-arch manifest lists b
407407
* [What's New in JAX | GTC Spring 2023](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51956/)
408408
* [Slurm and OpenMPI zero config integration](https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html)
409409
* [Adding custom GPU ops](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
410+
* [Triaging regressions](docs/triage-tool.md)

docs/triage-tool.md

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
# Triage tool
2+
3+
`jax-toolbox-triage` is a tool to automate the process of attributing regressions to an
4+
individual commit of JAX or XLA.
5+
It takes as input a command that returns an error (non-zero) code when run in "recent"
6+
containers, but which returns a success (zero) code when run in some "older" container.
7+
8+
The tool follows a three-step process:
9+
1. A container-level search backwards from the "recent" container where the test is
10+
known to fail, which identifies an "older" container where the test passes. This
11+
search proceeds with an exponentially increasing step size and is based on the
12+
`YYYY-MM-DD` tags under `ghcr.io/nvidia/jax`.
13+
2. A container-level binary search to refine this to the **latest** available
14+
container where test passes and the **earliest** available container where it
15+
fails.
16+
3. A commit-level binary search, repeatedly building + testing inside the same
17+
container, to identify a single commit of JAX (XLA) that causes the test to start
18+
failing, and a reference commit of XLA (JAX) that can be used to reproduce the
19+
regression.
20+
21+
## Installation
22+
23+
The triage tool can be installed using `pip`:
24+
```bash
25+
pip install git+https://github.com/NVIDIA/JAX-Toolbox.git#subdirectory=.github/triage
26+
```
27+
or directly from a checkout of the JAX-Toolbox repository.
28+
Because the tool needs to orchestrate running commands in multiple containers, it is
29+
most convenient to install it in a virtual environment on the host system, rather than
30+
attempting to install it inside a container.
31+
32+
The tool should be invoked on a machine with `docker` available and whatever GPUs are
33+
needed to execute the test case.
34+
35+
## Usage
36+
37+
To use the tool, there are two compulsory arguments:
38+
* `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container
39+
families to execute the test command in. Example: `jax` for a JAX unit test
40+
failure, `maxtext` for a MaxText model execution failure
41+
* A test command to triage.
42+
43+
The test command will be executed directly in the container, not inside a shell, so be
44+
sure not to add excessive quotation marks (*i.e.* run
45+
`jax-toolbox-triage --container=jax test-jax.sh foo` not
46+
`jax-toolbox-triage --container=jax "test-jax.sh foo"`), and you should aim to make it
47+
as fast and targeted as possible.
48+
The expectation is that the test case will be executed successfully several times as
49+
part of the triage, so you may want to tune some parameters to reduce the execution
50+
time in the successful case.
51+
For example, if `text-maxtext.sh --steps=500 ...` is failing on step 0, you should
52+
probably reduce `--steps` to optimise execution time in the successful case.
53+
54+
A JSON status file and both info-level and debug-level logfiles are written to the
55+
directory given by `--output-prefix`.
56+
57+
### Optimising container-level search performance
58+
59+
By default, the container-level search starts from the most recent available container,
60+
if you already know that the test has been failing for a while, you can pass
61+
`--end-date` to start the search further in the past.
62+
If you are sure that the test is failing on the `--end-date` you have passed, you can
63+
skip verification of that fact by passing `--skip-precondition-checks` (but see below
64+
for other checks that this skips).
65+
66+
By default, the container-level backwards search for a date on which the test passed
67+
tries the containers approximately [1, 2, 4, ...] days before `--end-date`.
68+
This can be tuned by passing `--start-date`, which overrides the "end date minus one"
69+
start value (but leaves the exponential growth of the search range width).
70+
If you are sure that the test is passing on the `--start-date` you have passed, you can
71+
skip verification of that fact by passing `--skip-precondition-checks`.
72+
73+
The combination of `--start-date`, `--end-date` and `--skip-precondition-checks` can be
74+
used to skip the entire first stage of the bisection process.
75+
76+
The second stage of the triage process can be made to abort early using the
77+
`--threshold-days` option; this stage will terminate once the delta between the latest
78+
known-good and earliest known-bad containers is below the threshold.
79+
80+
If you need to re-start the tool for some reason, use of these options can help
81+
bootstrap the tool using the results of a previous (partial) run.
82+
83+
### Optimising commit-level search performance
84+
85+
The third stage of the triage process involves repeatedly building JAX and XLA, which
86+
can be sped up significantly using a Bazel cache.
87+
By default, a local directory on the host machine (where the tool is being executed)
88+
will be used, but it may be more efficient to use a persistent and/or pre-heated cache.
89+
This can be achieved by passing the `--bazel-cache` option, which accepts absolute
90+
paths and `http`/`https`/`grpc` URLs.
91+
92+
If `--skip-precondition-checks` is passed, a sanity check that the failure can be
93+
reproduced after rebuilding the JAX/XLA commits from the first-known-bad container
94+
inside that container will be skipped.
95+
96+
## Example
97+
98+
Here is an example execution for a JAX unit test failure, with some annotation:
99+
```console
100+
user@gpu-machine $ jax-toolbox-triage --container jax test-jax.sh //tests:nn_test_gpu
101+
```
102+
`--end-date` was not passed, and 2024-10-15 is the most recent available container
103+
at the time of execution
104+
```
105+
[INFO] 2024-10-16 00:31:41 Checking end-of-range failure in 2024-10-15
106+
```
107+
`--skip-precondition-checks` was not passed, so the tool checks that the test does, in
108+
fact, fail in the 2024-10-15 container
109+
```
110+
[INFO] 2024-10-16 00:33:36 Ran test case in 2024-10-15 in 114.8s, pass=False
111+
```
112+
`--start-date` was not passed, so the first (backwards search) stage of the triage
113+
process starts with the container 1 day before the end of the range, *i.e.* 2024-10-14
114+
```
115+
[INFO] 2024-10-16 00:33:37 Starting coarse search with 2024-10-14 based on end_date=2024-10-15
116+
[INFO] 2024-10-16 00:35:35 Ran test case in 2024-10-14 in 118.1s, pass=False
117+
```
118+
`end_date - 2 * (end_date - search_date)` = `2024-10-15 - 2 days` = `2024-10-13`
119+
```
120+
[INFO] 2024-10-16 00:38:11 Ran test case in 2024-10-13 in 122.4s, pass=False
121+
```
122+
In principle this would be 4 days before the end date, but the 2024-10-11 container
123+
does not exist, so the tool chooses a nearby container that does exist and is older
124+
than 2024-10-13
125+
```
126+
[INFO] 2024-10-16 00:40:53 Ran test case in 2024-10-12 in 127.7s, pass=False
127+
```
128+
Steps in date start to increase significantly
129+
```
130+
[INFO] 2024-10-16 00:43:28 Ran test case in 2024-10-09 in 119.3s, pass=False
131+
[INFO] 2024-10-16 00:45:29 Ran test case in 2024-10-03 in 120.7s, pass=False
132+
[INFO] 2024-10-16 00:47:27 Ran test case in 2024-09-21 in 116.3s, pass=False
133+
```
134+
The first stage of the triage process successfully identifies an old container where
135+
this test passed
136+
```
137+
[INFO] 2024-10-16 00:51:22 Ran test case in 2024-08-28 in 194.0s, pass=True
138+
[INFO] 2024-10-16 00:51:22 Coarse container-level search yielded [2024-08-28, 2024-09-21]...
139+
```
140+
The second stage of the triage process refines the container-level range by bisection
141+
```
142+
[INFO] 2024-10-16 00:53:19 Ran test case in 2024-09-09 in 115.5s, pass=True
143+
[INFO] 2024-10-16 00:53:19 Refined container-level range to [2024-09-09, 2024-09-21]
144+
[INFO] 2024-10-16 00:56:03 Ran test case in 2024-09-15 in 125.4s, pass=True
145+
[INFO] 2024-10-16 00:56:03 Refined container-level range to [2024-09-15, 2024-09-21]
146+
[INFO] 2024-10-16 00:58:07 Ran test case in 2024-09-18 in 122.9s, pass=True
147+
[INFO] 2024-10-16 00:58:07 Refined container-level range to [2024-09-18, 2024-09-21]
148+
```
149+
The second stage of the triage process converges
150+
```
151+
[INFO] 2024-10-16 01:00:09 Ran test case in 2024-09-19 in 121.2s, pass=False
152+
[INFO] 2024-10-16 01:00:09 Refined container-level range to [2024-09-18, 2024-09-19]
153+
```
154+
The third stage of the triage process begins, using:
155+
- the first-known-bad container 2024-09-19
156+
- first-known-bad commits (JAX 9d2e9... and XLA 42b04...)
157+
- last-known-good commits (JAX 988ed... and XLA 88935...)
158+
```
159+
[INFO] 2024-10-16 01:00:10 Bisecting JAX [988ed2bd75df5fe25b74eaf38075aadff19be207, 9d2e9c688c4e8b733e68467d713091436a672ac0] and XLA [8893550a604fe39aae2eeae49a836e92eed497d1, 42b04a6739dc648a80dd4f3b4e1322f1b2c7f3a7] using ghcr.io/nvidia/jax:jax-2024-09-19
160+
[INFO] 2024-10-16 01:00:10 Building in the range-ending container...
161+
```
162+
Sanity check that re-building the first-known-bad commits in the first-known-bad
163+
container reproduces the failure
164+
```
165+
[INFO] 2024-10-16 01:00:12 Checking out XLA 42b04a6739dc648a80dd4f3b4e1322f1b2c7f3a7 JAX 9d2e9c688c4e8b733e68467d713091436a672ac0
166+
```
167+
No Bazel cache was passed, and this is the first build in the triage session, so it is
168+
slow -- a full rebuild of JAX and XLA was needed
169+
```
170+
[INFO] 2024-10-16 01:13:56 Build completed in 824.9s
171+
[INFO] 2024-10-16 01:15:25 Test completed in 88.5s
172+
[INFO] 2024-10-16 01:15:25 Verified test failure after vanilla rebuild
173+
```
174+
Verification that the last-known-good commits still pass when rebuilt in the
175+
first-known-bad container; this is a bit faster because the Bazel cache is warmer
176+
```
177+
[INFO] 2024-10-16 01:15:25 Checking out XLA 8893550a604fe39aae2eeae49a836e92eed497d1 JAX 988ed2bd75df5fe25b74eaf38075aadff19be207
178+
[INFO] 2024-10-16 01:26:43 Build completed in 677.5s
179+
[INFO] 2024-10-16 01:27:36 Test completed in 53.7s
180+
[INFO] 2024-10-16 01:27:36 Test passed after rebuilding commits from start container in end container
181+
```
182+
Binary search in commits continues, with progressively faster build times
183+
```
184+
[INFO] 2024-10-16 01:27:37 Checking out XLA b976dd94f11ab130c5f718b360fcfb5ac6d6b875 JAX b51c65357f0ae9659e58e2ff0df871542124cddf
185+
[INFO] 2024-10-16 01:32:24 Build completed in 287.7s
186+
[INFO] 2024-10-16 01:33:19 Test completed in 54.4s
187+
[INFO] 2024-10-16 01:33:19 Checking out XLA e291dfe0a12ec5907636a722c545c19d43f04c8b JAX 9dd363da1298e4810b693a918fc2e8199094acdb
188+
[INFO] 2024-10-16 01:34:58 Build completed in 98.9s
189+
[INFO] 2024-10-16 01:35:52 Test completed in 54.1s
190+
[INFO] 2024-10-16 01:35:53 Checking out XLA 6e652a5d91657cfbe9fbcdff4a0ccd1b803675a7 JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
191+
[INFO] 2024-10-16 01:36:54 Build completed in 61.3s
192+
[INFO] 2024-10-16 01:37:47 Test completed in 52.7s
193+
[INFO] 2024-10-16 01:37:47 Checking out XLA a1299f86507c79c8acf877344d545f10329f8515 JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
194+
[INFO] 2024-10-16 01:38:39 Build completed in 52.5s
195+
[INFO] 2024-10-16 01:39:32 Test completed in 52.5s
196+
[INFO] 2024-10-16 01:39:32 Checking out XLA 2d1f7b70740649a57ec4988702ae1dbdfeee6e9c JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
197+
[INFO] 2024-10-16 01:40:24 Build completed in 52.2s
198+
[INFO] 2024-10-16 01:41:17 Test completed in 52.9s
199+
[INFO] 2024-10-16 01:41:17 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX 016c49951f670256ce4750cdfea182e3a2a15325
200+
[INFO] 2024-10-16 01:42:08 Build completed in 50.9s
201+
[INFO] 2024-10-16 01:43:12 Test completed in 64.2s
202+
```
203+
The XLA commit has stopped changing; the initial bisection is XLA-centric (with JAX
204+
kept roughly in sync), but when this converges on a single XLA commit, the tool will
205+
run extra tests to decide whether to blame that XLA commit or a nearby JAX commit
206+
```
207+
[INFO] 2024-10-16 01:43:13 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
208+
[INFO] 2024-10-16 01:44:01 Build completed in 48.8s
209+
[INFO] 2024-10-16 01:45:02 Test completed in 60.8s
210+
[INFO] 2024-10-16 01:45:03 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX cd04d0f32e854aa754e37e4b676725655a94e731
211+
[INFO] 2024-10-16 01:45:52 Build completed in 49.4s
212+
[INFO] 2024-10-16 01:46:53 Test completed in 60.7s
213+
[INFO] 2024-10-16 01:46:53 Bisected failure to JAX cd04d0f32e854aa754e37e4b676725655a94e731..b164d67d4a9bd094426ff450fe1f1335d3071d03 with XLA 662eb45a17c76df93e5a386929653ae4c1f593da
214+
```
215+
216+
Where the final result should be read as saying that the test passes with
217+
[xla@662eb](https://github.com/openxla/xla/commit/662eb45a17c76df93e5a386929653ae4c1f593da)
218+
and
219+
[jax@cd04d](https://github.com/jax-ml/jax/commit/cd04d0f32e854aa754e37e4b676725655a94e731),
220+
but that if JAX is moved forward to include
221+
[jax@b164d](https://github.com/jax-ml/jax/commit/b164d67d4a9bd094426ff450fe1f1335d3071d03)
222+
then the test fails.
223+
This failure is fixed in XXX.
224+
225+
## Limitations
226+
227+
This tool aims to target the common case that regressions are due to commits in JAX or
228+
XLA, so if the root cause is different it may not converge, although the partial results
229+
may still be helpful.
230+
231+
For example, if the regression is due to a new version of some other dependency
232+
`SomeProject` that was first installed in the `2024-10-15` container, then the first
233+
two stages of the triage process will correctly identify that `2024-10-15` is the
234+
critical date, but the third stage will fail because it will try and fail to reproduce
235+
test success by building the JAX/XLA commits from `2024-10-14` in the `2024-10-15`
236+
container.
237+
238+
Other limitations include that only `docker` is supported as a container runtime, which
239+
also implies that it is not currently possible to triage a test that requires a
240+
multi-node or multi-process test.
241+
242+
If you run into these limitations in real-world usage of this tool, please file a bug
243+
against JAX-Toolbox including details of manual steps you took to root-case the test
244+
regression.

docs/triage.md

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@ There is a Github Action Workflow called [_triage.yaml](../.github/workflows/_tr
44
be used to help determine if a test failure was due to a change in (t5x or pax) or further-up, e.g., in (Jax or CUDA). This workflow is not the end-all, and further investigation is usually needed,
55
but this automates the investigation of questions like "what state of library X works with Jax at state Y?"
66

7-
__Note__: There is also a utility, [triage](../.github/triage/triage), which can be
8-
used for more granular bisection of failures in specific tests. Run it with `--help`
9-
for usage instructions. Given a test expression that can be run inside the nightly
10-
containers (*e.g.* `test-jax.sh jet_test_gpu`), it first identifies the nightly
11-
container where the failure first appeared, and second attributes the failure to a
12-
specific commit of JAX or XLA.
7+
__Note__: There is also a [triage tool](triage-tool.md), which can be used for
8+
more granular bisection of failures in specific tests. Given a test expression that can
9+
be run inside the nightl containers (*e.g.* `test-jax.sh jet_test_gpu`), it first
10+
identifies the nightl container where the failure first appeared, and second attributes
11+
the failure to a specific commit of JAX or XLA.
1312

1413
## Algorithm
1514
The pseudocode for the triaging algorithm is as follows:

0 commit comments

Comments
 (0)