Skip to content

Commit a672761

Browse files
authored
Add local runtime support for in-situ unit test bisection (#1501)
Introduces a new `local` mode to the triage utility that allows developers to perform a commit-level bisection for unit tests without requiring container orchestration. This is designed for the case where a user is already inside a failing container and wants to find a culprit commit directly. **Changes:** 1. Adds a new option, **--container-runtime=local**. 2. When this mode is active, all `git` operations and build/test commands are executed directly on the local machine using `subprocess`. 3. The `run_and_log` utility was updated to accept a `cwd` parameter. This was necessary to allow the `LocalContainer` to execute commands like `git checkout` and `bazel` builds within the correct local repository directories. 4. The `local` runtime skips the container search, so it requires the user to provide the bisection range explicitly using: - `--passing-commits="jax:<hash>,xla:<hash>"` - `--failing-commits="jax:<hash>,xla:<hash>"` 5. The argument parser has been updated to enforce these new rules and prevent mixing `local` mode with container-specific arguments (i.e, `--container`, `--start-date`). 6. Unit tests to the new argument parsing logic Now, this command can be run inside a JAX development container: ``` jax-toolbox-triage \ --container-runtime=local \ --passing-commits="jax:<some_good_hash>,xla:<some_good_hash>" \ --failing-commits="jax:<some_bad_hash>,xla:<some_bad_hash>" \ /opt/jax-toolbox/.github/container/test-jax.sh "some_target" ```
1 parent a5956cd commit a672761

File tree

6 files changed

+115
-19
lines changed

6 files changed

+115
-19
lines changed

.github/triage/jax_toolbox_triage/args.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,27 @@ def parse_args(args=None) -> argparse.Namespace:
187187
parser.add_argument(
188188
"--container-runtime",
189189
default="docker",
190-
help="Container runtime used, this can be either docker or pyxis.",
190+
help="Container runtime used, can be docker, pyxis, or local.",
191191
type=lambda s: s.lower(),
192192
)
193193
args = parser.parse_args(args=args)
194-
assert args.container_runtime in {"docker", "pyxis"}, args.container_runtime
194+
assert args.container_runtime in {"docker", "pyxis", "local"}, (
195+
args.container_runtime
196+
)
197+
198+
if args.container_runtime == "local":
199+
assert args.passing_commits is not None and args.failing_commits is not None, (
200+
"For local runtime, --passing-commits and --failing-commits must be provided."
201+
)
202+
assert (
203+
args.container is None
204+
and args.start_date is None
205+
and args.end_date is None
206+
and args.passing_container is None
207+
and args.failing_container is None
208+
), "Container-level search options are not applicable for local runtime."
209+
return args
210+
195211
passing_commits_known = (args.passing_container is not None) or (
196212
args.passing_commits is not None
197213
)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import logging
2+
import subprocess
3+
import typing
4+
5+
from .container import Container
6+
from .utils import run_and_log
7+
8+
9+
class LocalContainer(Container):
10+
def __init__(self, *, logger: logging.Logger):
11+
super().__init__(logger=logger)
12+
13+
def __enter__(self) -> Container:
14+
self._logger.debug("Running local mode inside current container")
15+
return self
16+
17+
def __exit__(self, *exc_info) -> None:
18+
pass
19+
20+
def __repr__(self) -> str:
21+
return "Local"
22+
23+
def exec(
24+
self,
25+
command: typing.List[str],
26+
policy: typing.Literal["once", "once_per_container", "default"] = "default",
27+
stderr: typing.Literal["interleaved", "separate"] = "interleaved",
28+
workdir=None,
29+
) -> subprocess.CompletedProcess:
30+
return run_and_log(command, logger=self._logger, stderr=stderr, cwd=workdir)
31+
32+
def exists(self) -> bool:
33+
"""The local environment always exists."""
34+
return True

.github/triage/jax_toolbox_triage/main.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .args import compulsory_software, optional_software, parse_args
1414
from .container import Container
1515
from .docker import DockerContainer
16+
from .local import LocalContainer
1617
from .logic import commit_search, container_search, TestResult
1718
from .pyxis import PyxisContainer
1819
from .utils import (
@@ -108,6 +109,9 @@ def test_output_directory(
108109
def Container(
109110
url, test_output_host_directory: typing.Optional[pathlib.Path] = None
110111
):
112+
if args.container_runtime == "local":
113+
return LocalContainer(logger=logger)
114+
111115
Imp = DockerContainer if args.container_runtime == "docker" else PyxisContainer
112116
mounts = bazel_cache_mounts + args.container_mount
113117
if test_output_host_directory is not None:
@@ -198,7 +202,10 @@ def check_container(date: datetime.date) -> TestResult:
198202
host_output_directory=out_dir, result=test_pass, stdouterr=result.stdout
199203
)
200204

201-
if args.passing_container is None and args.failing_container is None:
205+
if args.container_runtime == "local":
206+
passing_url = "local"
207+
failing_url = "local"
208+
elif args.passing_container is None and args.failing_container is None:
202209
# Search through the published containers, narrowing down to a pair of dates with
203210
# the property that the test passed on `range_start` and fails on `range_end`.
204211
range_start, range_end = container_search(
@@ -235,7 +242,10 @@ def check_container(date: datetime.date) -> TestResult:
235242

236243
# Choose a container to do the commit-level bisection in; use directory
237244
# names that match it.
238-
if failing_url is not None:
245+
if args.container_runtime == "local":
246+
bisection_url = "local"
247+
package_dirs = failing_package_dirs
248+
elif failing_url is not None:
239249
bisection_url = failing_url
240250
package_dirs = failing_package_dirs
241251
else:

.github/triage/jax_toolbox_triage/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,18 @@ def prepare_bazel_cache_mounts(
8282

8383

8484
def run_and_log(
85-
command, logger: logging.Logger, stderr: typing.Literal["interleaved", "separate"]
85+
command,
86+
logger: logging.Logger,
87+
stderr: typing.Literal["interleaved", "separate"],
88+
cwd: typing.Optional[str] = None,
8689
) -> subprocess.CompletedProcess:
87-
logger.debug(shlex.join(command))
90+
logger.debug(f"Executing in {cwd or '.'}: {shlex.join(command)}")
8891
result = subprocess.Popen(
8992
command,
9093
encoding="utf-8",
9194
stderr=subprocess.STDOUT if stderr == "interleaved" else subprocess.PIPE,
9295
stdout=subprocess.PIPE,
96+
cwd=cwd,
9397
)
9498
assert result.stdout is not None
9599
stdouterr = ""

.github/triage/tests/test_arg_parsing.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,19 @@
4444
["--container", "jax", "--end-date", "2024-10-02"],
4545
["--container", "jax", "--start-date", "2024-10-01", "--end-date", "2024-10-02"],
4646
]
47+
valid_local_args = [
48+
"--container-runtime",
49+
"local",
50+
"--passing-commits",
51+
"jax:0123,xla:4567",
52+
"--failing-commits",
53+
"jax:89ab,xla:cdef",
54+
]
4755

4856

4957
@pytest.mark.parametrize(
5058
"good_args",
51-
[valid_start_end_container]
59+
[valid_start_end_container, valid_local_args]
5260
+ valid_start_end_date_args
5361
+ valid_container_and_commits,
5462
)
@@ -57,6 +65,14 @@ def test_good_container_args(good_args):
5765
assert args.test_command == test_command
5866

5967

68+
def test_good_local_args():
69+
args = parse_args(valid_local_args + test_command)
70+
assert args.test_command == test_command
71+
assert args.container_runtime == "local"
72+
assert "jax" in args.passing_commits
73+
assert "xla" in args.failing_commits
74+
75+
6076
@pytest.mark.parametrize("date_args", valid_start_end_date_args)
6177
def test_bad_container_arg_combinations_across_groups(date_args):
6278
# Can't combine --{start,end}-container with --container/--{start,end}-date
@@ -129,3 +145,20 @@ def test_unparsable_container_args(container_args):
129145
def test_invalid_container_runtime():
130146
with pytest.raises(Exception):
131147
parse_args(["--container-runtime=magic-beans"] + test_command)
148+
149+
150+
@pytest.mark.parametrize(
151+
"bad_local_args",
152+
[
153+
["--container-runtime", "local"],
154+
["--container-runtime", "local", "--passing-commits", "jax:1,xla:2"],
155+
["--container-runtime", "local", "--failing-commits", "jax:1,xla:2"],
156+
valid_local_args + ["--container", "jax"],
157+
valid_local_args + ["--start-date", "2024-01-01"],
158+
valid_local_args + ["--passing-container", "url"],
159+
valid_local_args + ["--failing-container", "url"],
160+
],
161+
)
162+
def test_bad_local_arg_combinations(bad_local_args):
163+
with pytest.raises(Exception):
164+
parse_args(bad_local_args + test_command)

docs/triage-tool.md

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,20 @@
33
`jax-toolbox-triage` is a tool to automate the process of attributing regressions to an
44
individual commit of JAX or XLA.
55
It takes as input a command that returns an error (non-zero) code when run in "recent"
6-
containers, but which returns a success (zero) code when run in some "older" container.
7-
The command must be executable within the containers, *i.e.* it cannot refer to files
6+
environments, but which returns a success (zero) code when run in some "older" environment.
7+
The command must be executable within the triage environment, *i.e.* it cannot refer to files
88
that only exist on the host system, unless those are explicitly mounted in using the
9-
`-v` (`--container-mount`) option.
9+
`-v` (`--container-mount`) option when using a container-based runtime.
1010

11-
The tool follows a three-step process:
11+
The tool follows a process that can include up to three steps:
1212
1. A container-level search backwards from the "recent" container where the test is
1313
known to fail, which identifies an "older" container where the test passes. This
1414
search proceeds with an exponentially increasing step size and is based on the
1515
`YYYY-MM-DD` tags under `ghcr.io/nvidia/jax`.
1616
2. A container-level binary search to refine this to the **latest** available
1717
container where test passes and the **earliest** available container where it
1818
fails.
19-
3. A commit-level binary search, repeatedly building + testing inside the same
20-
container, to identify a single commit of a software package known to the tool
19+
3. A commit-level binary search, repeatedly building + testing, to identify a single commit of a software package known to the tool
2120
(JAX, XLA, Flax, optionally MaxText) that causes the test to start failing, and a
2221
set of reference commits of the other packages that can be used to reproduce the
2322
regression.
@@ -64,20 +63,20 @@ or more machines with appropriate GPUs, *e.g.* inside an `salloc` session.
6463
Appropriate arguments (number of nodes, number of tasks per node, *etc.*) should be
6564
passed to `salloc` or set via `SLURM_` environment variables so that a bare `srun` will
6665
correctly launch the test case.
66+
If `--container-runtime=local` is used, the tool assumes it is already inside a JAX container and will execute all build and test commands directly.
6767

6868
## Usage
6969

7070
To use the tool, there are two compulsory inputs:
7171
* A test command to triage.
72-
* A specification of which containers to triage in. There are two choices here:
73-
* `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container
74-
families to execute the test command in. Example: `jax` for a JAX unit test
75-
failure, `maxtext` for a MaxText model execution failure. The `--start-date` and
72+
* A specification of the triage scope. There are three choices here:
73+
* **Container Search**: Usage `--container` to specify which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container families to search through. Example: `jax` for a JAX unit test failure, `maxtext` for a MaxText model execution failure. The `--start-date` and
7674
`--end-date` options can be combined with `--container` to tune the search; see
7775
below for more details.
78-
* `--passing-container` and `--failing-container`: a pair of URLs to containers to
76+
* **Commit Search between Containers**: `--passing-container` and `--failing-container`: a pair of URLs to containers to
7977
use in the commit-level search; if these are passed then no container-level
8078
search is performed.
79+
* **Local Commit Search**: Use `--container-runtime=local` when you are already inside a JAX container. This mode skips all container orchestration and performs a commit-level search directly in the local container. It requires you to specify the commit range with `--passing-commits` and `--failing-commits`.
8180

8281
The test command will be executed directly in the container, not inside a shell, so be
8382
sure not to add excessive quotation marks (*i.e.* run
@@ -88,7 +87,7 @@ as fast and targeted as possible.
8887
If you want to run multiple commands, you might want to use something like
8988
`jax-toolbox-triage --container=jax sh -c "command1 && command2"`.
9089

91-
Alternatively, you can use `-v` (`--container-mount`) to mount a host directory
90+
Alternatively, when using a container runtime, you can use `-v` (`--container-mount`) to mount a host directory
9291
containing test scripts into the container and execute a script from there, *e.g.*
9392
`-v $PWD:/work /work/test.sh`.
9493

0 commit comments

Comments
 (0)