Skip to content

Commit fe5639a

Browse files
authored
Merge branch 'main' into alechan/upgrade-xpk-v0.13.0
2 parents e506904 + 08ae742 commit fe5639a

File tree

2 files changed

+150
-30
lines changed

2 files changed

+150
-30
lines changed

.github/triage/jax_toolbox_triage/triage_tool.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,17 @@ def _gather_histories(
153153
args=self.args,
154154
)
155155
package_versions[package] = history
156+
if package in self.args.cherry_pick:
157+
# If explicit commits to cherry-pick were given on the commandline,
158+
# make sure they are known to the local working copy. They might not be
159+
# if the fix being cherry-picked is newer, or only lives in a remote
160+
# that is being passed in via --override-remotes.
161+
worker.check_exec(
162+
["git", "fetch", self.args.override_remotes.get(package, "origin")]
163+
+ self.args.cherry_pick[package],
164+
policy="once_per_container",
165+
workdir=self.package_dirs[package],
166+
)
156167
for cherry_pick_range in cherry_pick_ranges:
157168
if package not in self.args.cherry_pick:
158169
self.args.cherry_pick[package] = []
@@ -529,22 +540,6 @@ def gather_version_info(self, passing_url: str, failing_url: str):
529540
passing_url, failing_url, passing_env, failing_env
530541
)
531542

532-
# We only know how to handle software packages that have versions defined at
533-
# both ends of the range.
534-
inconsistent_keys = passing_versions.keys() ^ failing_versions.keys()
535-
if len(inconsistent_keys):
536-
self.logger.warning(
537-
f"Ignoring packages that only have defined versions in one endpoint: {' '.join(inconsistent_keys)}"
538-
)
539-
for k in inconsistent_keys:
540-
for d in [passing_versions, failing_versions]:
541-
d.pop(k, None)
542-
543-
# Which packages have versions that are not always the same?
544-
self.dynamic_packages = {
545-
pkg
546-
for pkg, _ in set(passing_versions.items()) ^ set(failing_versions.items())
547-
}
548543
# Choose an environment to do the version-level bisection in; use directory names that
549544
# match it, and track what the initial versions of the different packages are
550545
if self.args.container_runtime == "local":
@@ -560,6 +555,40 @@ def gather_version_info(self, passing_url: str, failing_url: str):
560555
self.bisection_url = passing_url
561556
self.bisection_versions = original_passing_versions
562557
self.package_dirs = passing_package_dirs
558+
559+
# We only know how to handle software packages that have versions defined at
560+
# both ends of the range.
561+
inconsistent_keys = passing_versions.keys() ^ failing_versions.keys()
562+
if len(inconsistent_keys):
563+
self.logger.warning(
564+
f"Ignoring packages that only have defined versions in one endpoint: {' '.join(inconsistent_keys)}"
565+
)
566+
for k in inconsistent_keys:
567+
for d in [passing_versions, failing_versions]:
568+
d.pop(k, None)
569+
570+
# Not sure how to handle a package that does not have a defined version in the
571+
# bisection environment but that is expected to be included in the bisection...
572+
assert passing_versions.keys() == failing_versions.keys()
573+
unknown_initial_packages = (
574+
passing_versions.keys() - self.bisection_versions.keys()
575+
)
576+
assert len(unknown_initial_packages) == 0, (
577+
passing_versions.keys(),
578+
self.bisection_versions.keys(),
579+
)
580+
581+
# Which packages have versions that are not always the same? There are three
582+
# relevant sets of versions: the starting values in the bisection environment,
583+
# the start/passing value for the bisection, and the end/failing value for the
584+
# bisection.
585+
static_packages = {
586+
pkg
587+
for pkg, _ in set(passing_versions.items())
588+
& set(failing_versions.items())
589+
& set(self.bisection_versions.items())
590+
}
591+
self.dynamic_packages = passing_versions.keys() - static_packages
563592
self.logger.info(f"Using {self.bisection_url} for version-level bisection...")
564593
assert self.package_dirs is not None
565594
# This is the set of versions that are already installed

.github/triage/tests/test_pyxis_backend.py

Lines changed: 105 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@
1313
mock_scripts_path = pathlib.Path(__file__).parent / "mock_scripts"
1414

1515

16+
pyxis_args = [
17+
# Currently no way to use the pyxis backend without a cache
18+
"--bazel-cache=https://example.com/does-not-exist",
19+
"--container-runtime=pyxis",
20+
]
21+
22+
1623
def git_cmd(*args, cwd=None):
1724
return subprocess.run(
1825
["git"] + list(args),
@@ -73,9 +80,17 @@ def failing_container(passing_container):
7380
git("commit", "--allow-empty", "-m", "C4")
7481
git("commit", "--allow-empty", "-m", "C5")
7582
git("commit", "--allow-empty", "-m", "C6")
83+
c6 = git("rev-parse", "HEAD")
7684
if scenario == "non-linear":
85+
# cherry-pick the feature on top of the good commit
86+
git("checkout", metadata[f"{project}_good"])
87+
git("cherry-pick", passing_container[f"{project}_feature_commit"])
88+
metadata[f"{project}_good_with_feature"] = git("rev-parse", "HEAD")
7789
git("checkout", "-b", "feature-2")
90+
git("reset", "--hard", c6)
7891
git("cherry-pick", passing_container[f"{project}_feature_commit"])
92+
else:
93+
metadata[f"{project}_good_with_feature"] = metadata[f"{project}_good"]
7994
metadata[f"{project}_failing_container"] = git("rev-parse", "HEAD")
8095
yield metadata
8196

@@ -99,10 +114,7 @@ def test_mock_containers(
99114
# Ensure bazel, build-jax.sh, srun etc. stubs can be found.
100115
monkeypatch.setenv("PATH", str(mock_scripts_path), prepend=":")
101116
with tempfile.TemporaryDirectory() as output_prefix:
102-
arg_list = [
103-
# Currently no way to use the pyxis backend without a cache
104-
"--bazel-cache=https://example.com/does-not-exist",
105-
"--container-runtime=pyxis",
117+
arg_list = pyxis_args + [
106118
"--output-prefix",
107119
output_prefix,
108120
"--passing-container",
@@ -144,6 +156,93 @@ def test_mock_containers(
144156
)
145157

146158

159+
def test_mock_containers_with_explicit_version_override(
160+
monkeypatch,
161+
passing_container,
162+
failing_container,
163+
):
164+
# The point of this test is that if you pass
165+
# --passing-container=<has X=a> --failing-container=<has X=c>,
166+
# --passing-versions=X=b --failing-versions=X=b
167+
# then it is important to check out version b of package X in the triage
168+
# environment, even though package X is not actually included in the triage.
169+
170+
# fixed_package is forced to its good value by --{passing,failing}-versions, and
171+
# the test command only passes if it has *exactly* that value -- which it doesn't
172+
# initially have in either container. `_with_feature` is needed to make sure the
173+
# fake build succeeds.
174+
triage_package, fixed_package = compulsory_software[:2]
175+
fixed_good_commit = failing_container[f"{fixed_package}_good_with_feature"]
176+
# Tell the mock `srun` how to behave
177+
monkeypatch.setenv("JAX_TOOLBOX_TRIAGE_MOCK_SRUN_NODES", str(1))
178+
monkeypatch.setenv("JAX_TOOLBOX_TRIAGE_MOCK_SRUN_PROCS_PER_NODE", str(1))
179+
# Ensure bazel, build-jax.sh, srun etc. stubs can be found.
180+
monkeypatch.setenv("PATH", str(mock_scripts_path), prepend=":")
181+
with tempfile.TemporaryDirectory() as output_prefix:
182+
arg_list = pyxis_args + [
183+
"--output-prefix",
184+
output_prefix,
185+
"--passing-container",
186+
str(passing_container["prefix"]),
187+
"--passing-versions",
188+
f"{fixed_package}:{fixed_good_commit}",
189+
"--failing-container",
190+
str(failing_container["prefix"]),
191+
"--failing-versions",
192+
f"{fixed_package}:{fixed_good_commit}",
193+
"--",
194+
"sh",
195+
"-c",
196+
" && ".join(
197+
[
198+
f'[ $(cd ${{JAX_TOOLBOX_TRIAGE_PREFIX}}/opt/{fixed_package} && git rev-parse HEAD) = "{fixed_good_commit}" ]',
199+
f"test-case.sh /opt/{triage_package} {failing_container[f'{triage_package}_bad']}",
200+
]
201+
),
202+
]
203+
args = parse_args(arg_list)
204+
logger = logging.getLogger()
205+
logger.setLevel(logging.DEBUG)
206+
tool = TriageTool(args, logger)
207+
# Check the correct versions are extracted from the two pseudocontainers
208+
passing_versions, failing_versions = tool.gather_version_info(
209+
args.passing_container, args.failing_container
210+
)
211+
# These are not overridden, they are read from the containers
212+
assert (
213+
passing_versions[triage_package]
214+
== passing_container[f"{triage_package}_passing_container"]
215+
)
216+
assert (
217+
failing_versions[triage_package]
218+
== failing_container[f"{triage_package}_failing_container"]
219+
)
220+
# These are overridden by --passing-versions and --failing-versions
221+
assert passing_versions[fixed_package] == fixed_good_commit
222+
assert failing_versions[fixed_package] == fixed_good_commit
223+
# The starting value is not the value it is fixed to
224+
assert tool.bisection_versions[fixed_package] != fixed_good_commit
225+
assert tool.bisection_versions[fixed_package] in {
226+
passing_container[f"{fixed_package}_passing_container"],
227+
failing_container[f"{fixed_package}_failing_container"],
228+
}
229+
# fixed_package is dynamic, because its version needs to be changed from its starting value
230+
assert fixed_package in tool.dynamic_packages
231+
# triage_package is dynamic, because it is being triaged
232+
assert triage_package in tool.dynamic_packages
233+
# Run the bisection
234+
summary_data = tool.run_version_bisection(passing_versions, failing_versions)
235+
assert "result" in summary_data, summary_data
236+
assert (
237+
summary_data["result"][f"{triage_package}_good"]
238+
== failing_container[f"{triage_package}_good"]
239+
)
240+
assert (
241+
summary_data["result"][f"{triage_package}_bad"]
242+
== failing_container[f"{triage_package}_bad"]
243+
)
244+
245+
147246
@pytest.fixture
148247
def passing_container_with_bad_library(passing_container):
149248
scenario = passing_container["scenario"]
@@ -186,13 +285,9 @@ def test_triage_with_missing_installation_script_dir(
186285
monkeypatch.setenv("JAX_TOOLBOX_TRIAGE_MOCK_SRUN_PROCS_PER_NODE", "1")
187286
# Ensure the srun stub can be found
188287
monkeypatch.setenv("PATH", str(mock_scripts_path), prepend=":")
189-
arg_list = [
190-
# Currently no way to use the pyxis backend without a cache
191-
"--bazel-cache=https://example.com/does-not-exist",
288+
arg_list = pyxis_args + [
192289
"--build-scripts",
193290
"/path-does-not-exist",
194-
"--container-runtime",
195-
"pyxis",
196291
"--passing-container",
197292
str(passing_container["prefix"]),
198293
"--failing-container",
@@ -226,13 +321,9 @@ def test_triage_with_installation_scripts(
226321
# Ensure bazel, build-jax.sh, srun etc. stubs can be found.
227322
monkeypatch.setenv("PATH", str(mock_scripts_path), prepend=":")
228323
with tempfile.TemporaryDirectory() as output_prefix:
229-
arg_list = [
230-
# Currently no way to use the pyxis backend without a cache
231-
"--bazel-cache=https://example.com/does-not-exist",
324+
arg_list = pyxis_args + [
232325
"--build-scripts",
233326
"/build-scripts",
234-
"--container-runtime",
235-
"pyxis",
236327
"--output-prefix",
237328
output_prefix,
238329
"--passing-container",

0 commit comments

Comments
 (0)