Commit 867dc8e

jax-toolbox-triage: improve documentation (#1104)
Improve documentation of the triage tool added in #793.
1 parent 48eb66d commit 867dc8e

4 files changed: 258 additions & 8 deletions

.github/triage/jax_toolbox_triage/main.py

Lines changed: 3 additions & 2 deletions
@@ -86,9 +86,10 @@ def check_container(date: datetime.date) -> bool:
8686
jax_commit = get_commit(worker, "jax")
8787
xla_commit = get_commit(worker, "xla")
8888

89-
logger.debug(result.stdout)
90-
logger.info(f"Ran test case in {date} in {test_time:.1f}s")
9189
test_pass = result.returncode == 0
90+
logger.info(f"Ran test case in {date} in {test_time:.1f}s, pass={test_pass}")
91+
logger.debug(result.stdout)
92+
logger.debug(result.stderr)
9293
add_summary_record(
9394
"container",
9495
{

README.md

Lines changed: 1 addition & 0 deletions
@@ -407,3 +407,4 @@ Docker has traditionally used Docker Schema V2.2 for multi-arch manifest lists b
407407
* [What's New in JAX | GTC Spring 2023](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51956/)
408408
* [Slurm and OpenMPI zero config integration](https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html)
409409
* [Adding custom GPU ops](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
410+
* [Triaging regressions](docs/triage-tool.md)

docs/triage-tool.md

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
1+
# Triage tool
2+
3+
`jax-toolbox-triage` is a tool to automate the process of attributing regressions to an
4+
individual commit of JAX or XLA.
5+
It takes as input a command that returns an error (non-zero) code when run in "recent"
6+
containers, but which returns a success (zero) code when run in some "older" container.
7+
The command must be executable within the containers, *i.e.* it cannot refer to files
8+
that only exist on the host system.
9+
10+
The tool follows a three-step process:
11+
1. A container-level search backwards from the "recent" container where the test is
12+
known to fail, which identifies an "older" container where the test passes. This
13+
search proceeds with an exponentially increasing step size and is based on the
14+
`YYYY-MM-DD` tags under `ghcr.io/nvidia/jax`.
15+
2. A container-level binary search to refine this to the **latest** available
16+
container where the test passes and the **earliest** available container where it
17+
fails.
18+
3. A commit-level binary search, repeatedly building + testing inside the same
19+
container, to identify a single commit of JAX (XLA) that causes the test to start
20+
failing, and a reference commit of XLA (JAX) that can be used to reproduce the
21+
regression.
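
Steps 2 and 3 are both binary searches over an ordered sequence (container dates, then commits) bracketed by a known-good start and a known-bad end. The following is a minimal sketch of that idea, not the tool's actual implementation: `bisect_regression` and the `passes` callable are hypothetical names, and the sketch assumes the test flips from passing to failing exactly once within the range.

```python
from typing import Callable, Sequence, Tuple, TypeVar

T = TypeVar("T")


def bisect_regression(
    candidates: Sequence[T], passes: Callable[[T], bool]
) -> Tuple[T, T]:
    """Return (last_good, first_bad) from an ordered candidate list.

    Assumes candidates[0] is known to pass and candidates[-1] is known to fail.
    """
    good, bad = 0, len(candidates) - 1
    while bad - good > 1:
        mid = (good + bad) // 2
        if passes(candidates[mid]):
            good = mid  # the regression was introduced after mid
        else:
            bad = mid  # the regression was introduced at or before mid
    return candidates[good], candidates[bad]


# Hypothetical example: containers tagged by date, failing from 2024-09-19 on.
dates = ["2024-09-09", "2024-09-15", "2024-09-18", "2024-09-19", "2024-09-21"]
last_good, first_bad = bisect_regression(dates, lambda d: d < "2024-09-19")
print(last_good, first_bad)  # 2024-09-18 2024-09-19
```

In the real tool, each call to `passes` corresponds to running the test command in a container (step 2) or rebuilding and re-testing inside one container (step 3), so the number of probes grows only logarithmically with the size of the range.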
22+
23+
## Installation
24+
25+
The triage tool can be installed using `pip`:
26+
```bash
27+
pip install git+https://github.com/NVIDIA/JAX-Toolbox.git#subdirectory=.github/triage
28+
```
29+
or directly from a checkout of the JAX-Toolbox repository.
30+
Because the tool needs to orchestrate running commands in multiple containers, it is
31+
most convenient to install it in a virtual environment on the host system, rather than
32+
attempting to install it inside a container.
33+
34+
The tool should be invoked on a machine with `docker` available and whatever GPUs are
35+
needed to execute the test case.
36+
37+
## Usage
38+
39+
To use the tool, there are two compulsory arguments:
40+
* `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container
41+
families to execute the test command in. Example: `jax` for a JAX unit test
42+
failure, `maxtext` for a MaxText model execution failure.
43+
* A test command to triage.
44+
45+
The test command will be executed directly in the container, not inside a shell, so be
46+
sure not to add excessive quotation marks (*i.e.* run
47+
`jax-toolbox-triage --container=jax test-jax.sh foo` not
48+
`jax-toolbox-triage --container=jax "test-jax.sh foo"`), and you should aim to make it
49+
as fast and targeted as possible.
50+
The expectation is that the test case will be executed successfully several times as
51+
part of the triage, so you may want to tune some parameters to reduce the execution
52+
time in the successful case.
53+
For example, if `test-maxtext.sh --steps=500 ...` is failing on step 0, you should
54+
probably reduce `--steps` to optimise execution time in the successful case.
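
The quoting advice can be illustrated with Python's standard `shlex` module, which splits a command line in the same way a POSIX shell on the host would. This is only an illustration of argument splitting and does not call the triage tool itself.

```python
import shlex

# How a POSIX shell on the host splits the two invocations into arguments.
unquoted = shlex.split("jax-toolbox-triage --container=jax test-jax.sh foo")
quoted = shlex.split('jax-toolbox-triage --container=jax "test-jax.sh foo"')

# Everything after the option is the test command's argument vector.
print(unquoted[2:])  # ['test-jax.sh', 'foo']
print(quoted[2:])    # ['test-jax.sh foo']
```

In the quoted form the test command arrives as a single argument, `test-jax.sh foo`. Because no shell inside the container splits it again, that whole string would presumably be treated as the name of the program to run.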
55+
56+
A JSON status file and both info-level and debug-level logfiles are written to the
57+
directory given by `--output-prefix`.
58+
59+
### Optimising container-level search performance
60+
61+
By default, the container-level search starts from the most recent available container;
62+
if you already know that the test has been failing for a while, you can pass
63+
`--end-date` to start the search further in the past.
64+
If you are sure that the test is failing on the `--end-date` you have passed, you can
65+
skip verification of that fact by passing `--skip-precondition-checks` (but see below
66+
for other checks that this skips).
67+
68+
By default, the container-level backwards search for a date on which the test passed
69+
tries the containers approximately [1, 2, 4, ...] days before `--end-date`.
70+
This can be tuned by passing `--start-date`, which overrides the "end date minus one"
71+
start value (but leaves the exponential growth of the search range width unchanged).
72+
If you are sure that the test is passing on the `--start-date` you have passed, you can
73+
skip verification of that fact by passing `--skip-precondition-checks`.
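
As a rough sketch of the doubling rule, the snippet below uses the expression `end_date - 2 * (end_date - search_date)` quoted in the annotated example in this document. The function name `next_search_date` is hypothetical, and the sketch assumes a container exists for every date; the real tool picks a nearby container that does exist, which is why the offsets are only approximately [1, 2, 4, ...].

```python
import datetime


def next_search_date(
    end_date: datetime.date, search_date: datetime.date
) -> datetime.date:
    """Double the distance between the end date and the date just tested."""
    return end_date - 2 * (end_date - search_date)


end = datetime.date(2024, 10, 15)
candidate = end - datetime.timedelta(days=1)  # default "end date minus one"
ideal_dates = []
for _ in range(4):
    ideal_dates.append(candidate)
    candidate = next_search_date(end, candidate)
print([d.isoformat() for d in ideal_dates])
# ['2024-10-14', '2024-10-13', '2024-10-11', '2024-10-07']
```

Passing `--start-date` corresponds to replacing the initial `candidate` with a date of your choice; the doubling then continues from there.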
74+
75+
The combination of `--start-date`, `--end-date` and `--skip-precondition-checks` can be
76+
used to skip the entire first stage of the bisection process.
77+
78+
The second stage of the triage process can be made to abort early using the
79+
`--threshold-days` option; this stage will terminate once the delta between the latest
80+
known-good and earliest known-bad containers is below the threshold.
81+
82+
If you need to re-start the tool for some reason, use of these options can help
83+
bootstrap the tool using the results of a previous (partial) run.
84+
85+
### Optimising commit-level search performance
86+
87+
The third stage of the triage process involves repeatedly building JAX and XLA, which
88+
can be sped up significantly using a Bazel cache.
89+
By default, a local directory on the host machine (where the tool is being executed)
90+
will be used, but it may be more efficient to use a persistent and/or pre-heated cache.
91+
This can be achieved by passing the `--bazel-cache` option, which accepts absolute
92+
paths and `http`/`https`/`grpc` URLs.
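
To make the accepted forms concrete, here is a hypothetical helper that distinguishes the two kinds of value described above. It is an illustrative sketch only: `classify_bazel_cache` is not part of the tool, the example URL is made up, and how the tool itself validates `--bazel-cache` is not shown in this document.

```python
from pathlib import PurePosixPath
from urllib.parse import urlparse


def classify_bazel_cache(value: str) -> str:
    """Hypothetical: "remote" for http/https/grpc URLs, "local" for absolute paths."""
    if urlparse(value).scheme in {"http", "https", "grpc"}:
        return "remote"
    if PurePosixPath(value).is_absolute():
        return "local"
    raise ValueError(f"expected an absolute path or http/https/grpc URL: {value!r}")


print(classify_bazel_cache("/home/user/bazel-cache"))         # local
print(classify_bazel_cache("grpc://cache.example.com:9092"))  # remote
```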
93+
94+
If `--skip-precondition-checks` is passed, a sanity check that the failure can be
95+
reproduced after rebuilding the JAX/XLA commits from the first-known-bad container
96+
inside that container will be skipped.
97+
98+
## Example
99+
100+
Here is an example execution for a JAX unit test failure, with some annotation:
101+
```console
102+
user@gpu-machine $ jax-toolbox-triage --container jax test-jax.sh //tests:nn_test_gpu
103+
```
104+
`--end-date` was not passed, and 2024-10-15 is the most recent available container
105+
at the time of execution
106+
```
107+
[INFO] 2024-10-16 00:31:41 Checking end-of-range failure in 2024-10-15
108+
```
109+
`--skip-precondition-checks` was not passed, so the tool checks that the test does, in
110+
fact, fail in the 2024-10-15 container
111+
```
112+
[INFO] 2024-10-16 00:33:36 Ran test case in 2024-10-15 in 114.8s, pass=False
113+
```
114+
`--start-date` was not passed, so the first (backwards search) stage of the triage
115+
process starts with the container 1 day before the end of the range, *i.e.* 2024-10-14
116+
```
117+
[INFO] 2024-10-16 00:33:37 Starting coarse search with 2024-10-14 based on end_date=2024-10-15
118+
[INFO] 2024-10-16 00:35:35 Ran test case in 2024-10-14 in 118.1s, pass=False
119+
```
120+
`end_date - 2 * (end_date - search_date)` = `2024-10-15 - 2 days` = `2024-10-13`
121+
```
122+
[INFO] 2024-10-16 00:38:11 Ran test case in 2024-10-13 in 122.4s, pass=False
123+
```
124+
In principle this would be 4 days before the end date, but the 2024-10-11 container
125+
does not exist, so the tool chooses a nearby container that does exist and is older
126+
than 2024-10-13
127+
```
128+
[INFO] 2024-10-16 00:40:53 Ran test case in 2024-10-12 in 127.7s, pass=False
129+
```
130+
Steps in date start to increase significantly
131+
```
132+
[INFO] 2024-10-16 00:43:28 Ran test case in 2024-10-09 in 119.3s, pass=False
133+
[INFO] 2024-10-16 00:45:29 Ran test case in 2024-10-03 in 120.7s, pass=False
134+
[INFO] 2024-10-16 00:47:27 Ran test case in 2024-09-21 in 116.3s, pass=False
135+
```
136+
The first stage of the triage process successfully identifies an old container where
137+
this test passed
138+
```
139+
[INFO] 2024-10-16 00:51:22 Ran test case in 2024-08-28 in 194.0s, pass=True
140+
[INFO] 2024-10-16 00:51:22 Coarse container-level search yielded [2024-08-28, 2024-09-21]...
141+
```
142+
The second stage of the triage process refines the container-level range by bisection
143+
```
144+
[INFO] 2024-10-16 00:53:19 Ran test case in 2024-09-09 in 115.5s, pass=True
145+
[INFO] 2024-10-16 00:53:19 Refined container-level range to [2024-09-09, 2024-09-21]
146+
[INFO] 2024-10-16 00:56:03 Ran test case in 2024-09-15 in 125.4s, pass=True
147+
[INFO] 2024-10-16 00:56:03 Refined container-level range to [2024-09-15, 2024-09-21]
148+
[INFO] 2024-10-16 00:58:07 Ran test case in 2024-09-18 in 122.9s, pass=True
149+
[INFO] 2024-10-16 00:58:07 Refined container-level range to [2024-09-18, 2024-09-21]
150+
```
151+
The second stage of the triage process converges
152+
```
153+
[INFO] 2024-10-16 01:00:09 Ran test case in 2024-09-19 in 121.2s, pass=False
154+
[INFO] 2024-10-16 01:00:09 Refined container-level range to [2024-09-18, 2024-09-19]
155+
```
156+
The third stage of the triage process begins, using:
157+
- the first-known-bad container 2024-09-19
158+
- first-known-bad commits (JAX 9d2e9... and XLA 42b04...)
159+
- last-known-good commits (JAX 988ed... and XLA 88935...)
160+
```
161+
[INFO] 2024-10-16 01:00:10 Bisecting JAX [988ed2bd75df5fe25b74eaf38075aadff19be207, 9d2e9c688c4e8b733e68467d713091436a672ac0] and XLA [8893550a604fe39aae2eeae49a836e92eed497d1, 42b04a6739dc648a80dd4f3b4e1322f1b2c7f3a7] using ghcr.io/nvidia/jax:jax-2024-09-19
162+
[INFO] 2024-10-16 01:00:10 Building in the range-ending container...
163+
```
164+
Sanity check that re-building the first-known-bad commits in the first-known-bad
165+
container reproduces the failure
166+
```
167+
[INFO] 2024-10-16 01:00:12 Checking out XLA 42b04a6739dc648a80dd4f3b4e1322f1b2c7f3a7 JAX 9d2e9c688c4e8b733e68467d713091436a672ac0
168+
```
169+
No Bazel cache was passed, and this is the first build in the triage session, so it is
170+
slow -- a full rebuild of JAX and XLA was needed
171+
```
172+
[INFO] 2024-10-16 01:13:56 Build completed in 824.9s
173+
[INFO] 2024-10-16 01:15:25 Test completed in 88.5s
174+
[INFO] 2024-10-16 01:15:25 Verified test failure after vanilla rebuild
175+
```
176+
Verification that the last-known-good commits still pass when rebuilt in the
177+
first-known-bad container; this is a bit faster because the Bazel cache is warmer
178+
```
179+
[INFO] 2024-10-16 01:15:25 Checking out XLA 8893550a604fe39aae2eeae49a836e92eed497d1 JAX 988ed2bd75df5fe25b74eaf38075aadff19be207
180+
[INFO] 2024-10-16 01:26:43 Build completed in 677.5s
181+
[INFO] 2024-10-16 01:27:36 Test completed in 53.7s
182+
[INFO] 2024-10-16 01:27:36 Test passed after rebuilding commits from start container in end container
183+
```
184+
Binary search in commits continues, with progressively faster build times
185+
```
186+
[INFO] 2024-10-16 01:27:37 Checking out XLA b976dd94f11ab130c5f718b360fcfb5ac6d6b875 JAX b51c65357f0ae9659e58e2ff0df871542124cddf
187+
[INFO] 2024-10-16 01:32:24 Build completed in 287.7s
188+
[INFO] 2024-10-16 01:33:19 Test completed in 54.4s
189+
[INFO] 2024-10-16 01:33:19 Checking out XLA e291dfe0a12ec5907636a722c545c19d43f04c8b JAX 9dd363da1298e4810b693a918fc2e8199094acdb
190+
[INFO] 2024-10-16 01:34:58 Build completed in 98.9s
191+
[INFO] 2024-10-16 01:35:52 Test completed in 54.1s
192+
[INFO] 2024-10-16 01:35:53 Checking out XLA 6e652a5d91657cfbe9fbcdff4a0ccd1b803675a7 JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
193+
[INFO] 2024-10-16 01:36:54 Build completed in 61.3s
194+
[INFO] 2024-10-16 01:37:47 Test completed in 52.7s
195+
[INFO] 2024-10-16 01:37:47 Checking out XLA a1299f86507c79c8acf877344d545f10329f8515 JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
196+
[INFO] 2024-10-16 01:38:39 Build completed in 52.5s
197+
[INFO] 2024-10-16 01:39:32 Test completed in 52.5s
198+
[INFO] 2024-10-16 01:39:32 Checking out XLA 2d1f7b70740649a57ec4988702ae1dbdfeee6e9c JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
199+
[INFO] 2024-10-16 01:40:24 Build completed in 52.2s
200+
[INFO] 2024-10-16 01:41:17 Test completed in 52.9s
201+
[INFO] 2024-10-16 01:41:17 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX 016c49951f670256ce4750cdfea182e3a2a15325
202+
[INFO] 2024-10-16 01:42:08 Build completed in 50.9s
203+
[INFO] 2024-10-16 01:43:12 Test completed in 64.2s
204+
```
205+
The XLA commit has stopped changing; the initial bisection is XLA-centric (with JAX
206+
kept roughly in sync), but when this converges on a single XLA commit, the tool will
207+
run extra tests to decide whether to blame that XLA commit or a nearby JAX commit
208+
```
209+
[INFO] 2024-10-16 01:43:13 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX b164d67d4a9bd094426ff450fe1f1335d3071d03
210+
[INFO] 2024-10-16 01:44:01 Build completed in 48.8s
211+
[INFO] 2024-10-16 01:45:02 Test completed in 60.8s
212+
[INFO] 2024-10-16 01:45:03 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX cd04d0f32e854aa754e37e4b676725655a94e731
213+
[INFO] 2024-10-16 01:45:52 Build completed in 49.4s
214+
[INFO] 2024-10-16 01:46:53 Test completed in 60.7s
215+
[INFO] 2024-10-16 01:46:53 Bisected failure to JAX cd04d0f32e854aa754e37e4b676725655a94e731..b164d67d4a9bd094426ff450fe1f1335d3071d03 with XLA 662eb45a17c76df93e5a386929653ae4c1f593da
216+
```
217+
218+
The final result should be read as saying that the test passes with
219+
[xla@662eb](https://github.com/openxla/xla/commit/662eb45a17c76df93e5a386929653ae4c1f593da)
220+
and
221+
[jax@cd04d](https://github.com/jax-ml/jax/commit/cd04d0f32e854aa754e37e4b676725655a94e731),
222+
but that if JAX is moved forward to include
223+
[jax@b164d](https://github.com/jax-ml/jax/commit/b164d67d4a9bd094426ff450fe1f1335d3071d03)
224+
then the test fails.
225+
This failure is fixed in [jax#24427](https://github.com/jax-ml/jax/pull/24427).
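
When reviewing a long run, the `Ran test case` lines from a log like the one above can be condensed into a pass/fail table. The sketch below assumes the message format shown in this example (`Ran test case in DATE in N.Ns, pass=BOOL`); `container_results` is a hypothetical helper, not something the tool provides.

```python
import re

# Matches the "Ran test case" messages shown in the example log.
RAN_TEST = re.compile(
    r"Ran test case in (?P<date>\d{4}-\d{2}-\d{2}) in (?P<secs>[\d.]+)s, "
    r"pass=(?P<passed>True|False)"
)


def container_results(log_text: str) -> dict:
    """Map container date -> pass/fail for every matching log line."""
    return {m["date"]: m["passed"] == "True" for m in RAN_TEST.finditer(log_text)}


sample = """\
[INFO] 2024-10-16 00:58:07 Ran test case in 2024-09-18 in 122.9s, pass=True
[INFO] 2024-10-16 01:00:09 Ran test case in 2024-09-19 in 121.2s, pass=False
"""
print(container_results(sample))
# {'2024-09-18': True, '2024-09-19': False}
```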
226+
227+
## Limitations
228+
229+
This tool aims to target the common case that regressions are due to commits in JAX or
230+
XLA, so if the root cause is different it may not converge, although the partial results
231+
may still be helpful.
232+
233+
For example, if the regression is due to a new version of some other dependency
234+
`SomeProject` that was first installed in the `2024-10-15` container, then the first
235+
two stages of the triage process will correctly identify that `2024-10-15` is the
236+
critical date, but the third stage will fail because it cannot reproduce
237+
test success by building the JAX/XLA commits from `2024-10-14` in the `2024-10-15`
238+
container.
239+
240+
Other limitations include that only `docker` is supported as a container runtime, which
241+
also implies that it is not currently possible to triage a test that requires a
242+
multi-node or multi-process setup.
243+
244+
The tool also does not currently handle skipping commits that do not compile, or test
245+
cases that require copying files (*e.g.* script files) into the container.
246+
247+
If you run into these limitations in real-world usage of this tool, please file a bug
248+
against JAX-Toolbox including details of manual steps you took to root-cause the test
249+
regression.

docs/triage.md

Lines changed: 5 additions & 6 deletions
@@ -4,12 +4,11 @@ There is a Github Action Workflow called [_triage.yaml](../.github/workflows/_tr
44
be used to help determine if a test failure was due to a change in (t5x or pax) or further-up, e.g., in (Jax or CUDA). This workflow is not the end-all, and further investigation is usually needed,
55
but this automates the investigation of questions like "what state of library X works with Jax at state Y?"
66

7-
__Note__: There is also a utility, [triage](../.github/triage/triage), which can be
8-
used for more granular bisection of failures in specific tests. Run it with `--help`
9-
for usage instructions. Given a test expression that can be run inside the nightly
10-
containers (*e.g.* `test-jax.sh jet_test_gpu`), it first identifies the nightly
11-
container where the failure first appeared, and second attributes the failure to a
12-
specific commit of JAX or XLA.
7+
__Note__: There is also a [triage tool](triage-tool.md), which can be used for
8+
more granular bisection of failures in specific tests. Given a test expression that can
9+
be run inside the nightly containers (*e.g.* `test-jax.sh jet_test_gpu`), it first
10+
identifies the nightly container where the failure first appeared, and second attributes
11+
the failure to a specific commit of JAX or XLA.
1312

1413
## Algorithm
1514
The pseudocode for the triaging algorithm is as follows:
