
Commit 9eb9427

Merge branch 'main' into aportnoy/use-jax_test_gpu
2 parents: 02e8499 + 0832fac

File tree

11 files changed: +1346 -17 lines

.github/container/test-maxtext.sh

Lines changed: 53 additions & 17 deletions
@@ -13,7 +13,7 @@ usage() {
 echo "Usage: $0 [OPTIONS]"
 echo ""
 echo " OPTIONS DESCRIPTION"
-echo " -a, --additional-args Additional args to pass to MaxText/train.py"
+echo " -a, --additional-args Additional args to pass to MaxText/train.py. Can be passed many times."
 echo " --mem-fraction Specify the percentage of memory to preallocate for XLA. Example: 0.90, 0.85, 0.65". Default to 0.90, contradicting JAX default of 0.75.
 echo " --model-name Specify the model names to run [Preferred]. If you specify model name then you do not need to specify decoder-block. Currently supported ootb models:
 gemma-2b, gemma-7b, gpt3-175b, gpt3-22b, gpt3-52k, gpt3-6b, llama2-13b, llama2-70b, llama2-7b, llama3-70b, llama3-8b, mistral-7b, mixtral-8x7b"
@@ -34,7 +34,7 @@ usage() {
 1. test-maxtext.sh -b 2 --model-name=gpt3-52k
 2. test-maxtext.sh -b 2 --model-name=gemma-2b --dtype=fp8
 3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess
-4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess -a scan_layers=false max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false
+4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess -a "scan_layers=false max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false"
 5. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --dtype=fp8 --steps=10 --fsdp=8 --output train_output --multiprocess
 6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess
 7. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess
@@ -76,7 +76,7 @@ eval set -- "$args"
 while [ : ]; do
     case "$1" in
         -a | --additional-args)
-            ADDITIONAL_ARGS="$2"
+            ADDITIONAL_ARGS="$ADDITIONAL_ARGS $2"
             shift 2
             ;;
         --mem-fraction)
@@ -245,22 +245,58 @@ RUN_NAME="logdir" ## the RUN_NAME cannot be changed
 if [ -z "$DECODER_BLOCK" ]; then

     # this part could be used to test different model ootb
-    RUN_SETTINGS="MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} model_name=${MODEL}\
-        steps=$STEPS per_device_batch_size=${BATCH_PER_GPU} remat_policy=${REMAT_POLICY} enable_checkpointing=false\
-        base_output_directory=$OUTPUT dataset_path=local dataset_type=synthetic hardware=$HARDWARE\
-        dcn_fsdp_parallelism=$dcn_FSDP ici_fsdp_parallelism=$ici_FSDP\
-        ici_data_parallelism=$ici_DP dcn_data_parallelism=$dcn_DP\
-        ici_tensor_parallelism=$ici_TP dcn_tensor_parallelism=1 ${ADDITIONAL_ARGS}"
-
+    RUN_SETTINGS="MaxText/train.py \
+        MaxText/configs/base.yml \
+        run_name=${RUN_NAME} \
+        model_name=${MODEL} \
+        steps=${STEPS} \
+        per_device_batch_size=${BATCH_PER_GPU} \
+        remat_policy=${REMAT_POLICY} \
+        enable_checkpointing=false \
+        base_output_directory=${OUTPUT} \
+        dataset_path=local \
+        dataset_type=synthetic \
+        hardware=${HARDWARE} \
+        enable_goodput_recording=false \
+        monitor_goodput=false \
+        dcn_fsdp_parallelism=${dcn_FSDP} \
+        ici_fsdp_parallelism=${ici_FSDP} \
+        ici_data_parallelism=${ici_DP} \
+        dcn_data_parallelism=${dcn_DP} \
+        ici_tensor_parallelism=${ici_TP} \
+        dcn_tensor_parallelism=1 \
+        ${ADDITIONAL_ARGS}"
 else
     # this is essentially used for CI run
-    RUN_SETTINGS="MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} logits_via_embedding=true decoder_block=${DECODER_BLOCK} \
-        steps=$STEPS per_device_batch_size=${BATCH_PER_GPU} base_emb_dim=2560 base_mlp_dim=8192 remat_policy=${REMAT_POLICY} attention=${ATTN_TYPE}\
-        base_num_query_heads=8 base_num_kv_heads=8 base_num_decoder_layers=8 head_dim=128 enable_checkpointing=false\
-        base_output_directory=$OUTPUT dataset_path=local dataset_type=synthetic hardware=$HARDWARE\
-        dcn_fsdp_parallelism=$dcn_FSDP ici_fsdp_parallelism=$ici_FSDP\
-        ici_data_parallelism=$ici_DP dcn_data_parallelism=$dcn_DP\
-        ici_tensor_parallelism=$ici_TP dcn_tensor_parallelism=1 ${ADDITIONAL_ARGS}"
+    RUN_SETTINGS="MaxText/train.py \
+        MaxText/configs/base.yml \
+        run_name=${RUN_NAME} \
+        decoder_block=${DECODER_BLOCK} \
+        steps=$STEPS \
+        per_device_batch_size=${BATCH_PER_GPU} \
+        base_emb_dim=2560 \
+        base_mlp_dim=8192 \
+        remat_policy=${REMAT_POLICY} \
+        attention=${ATTN_TYPE} \
+        base_num_query_heads=8 \
+        base_num_kv_heads=8 \
+        base_num_decoder_layers=8 \
+        head_dim=128 \
+        logits_via_embedding=true \
+        enable_checkpointing=false \
+        base_output_directory=${OUTPUT} \
+        dataset_path=local \
+        dataset_type=synthetic \
+        hardware=${HARDWARE} \
+        enable_goodput_recording=false \
+        monitor_goodput=false \
+        dcn_fsdp_parallelism=${dcn_FSDP} \
+        ici_fsdp_parallelism=${ici_FSDP} \
+        ici_data_parallelism=${ici_DP} \
+        dcn_data_parallelism=${dcn_DP} \
+        ici_tensor_parallelism=${ici_TP} \
+        dcn_tensor_parallelism=1 \
+        ${ADDITIONAL_ARGS}"
 fi

 echo "Command: python3 $RUN_SETTINGS"
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from .main import main
+
+__all__ = ["main"]
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+import argparse
+import datetime
+import getpass
+import os
+import pathlib
+import tempfile
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="""
+Triage failures in JAX/XLA-related tests. The expectation is that the given
+test command is failing in recent versions, but that it passed in the past. The
+script first triages the regression with a search of the nightly containers,
+and then refines the search to a particular commit of JAX or XLA.""",
+    )
+
+    container_search_args = parser.add_argument_group(
+        title="Container-level search",
+        description="""
+First, it is verified that the test command fails on the given end date, unless
+both --end-date and --skip-precondition-checks were passed. Then, the program
+searches backwards to find a container when the given test did pass. The
+--start-date option can be used to speed up this search, if you already know a
+date on which the test was passing. The earliest failure is located to within
+--threshold-days days.""",
+    )
+    commit_search_args = parser.add_argument_group(
+        title="Commit-level search",
+        description="""
+Second, the failure is localised to a commit of JAX or XLA by re-building and
+re-testing inside the earliest container that demonstrates the failure. At each
+point, the oldest JAX commit that is newer than XLA is used.""",
+    )
+    parser.add_argument(
+        "--container",
+        help="""
+Container to use. Example: jax, pax, triton. Used to construct the URLs of
+nightly containers, like ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD.""",
+        required=True,
+    )
+    parser.add_argument(
+        "--output-prefix",
+        default=datetime.datetime.now().strftime("triage-%Y-%m-%d-%H-%M-%S"),
+        help="""
+Prefix for output log and JSON files. Default: triage-YYYY-MM-DD-HH-MM-SS.
+An INFO-and-above log is written as PREFIX.log, a DEBUG-and-above log is
+written as PREFIX-debug.log, and a JSON summary is written as
+PREFIX-summary.json""",
+        type=pathlib.Path,
+    )
+    parser.add_argument(
+        "--skip-precondition-checks",
+        action="store_true",
+        help="""
+Skip checks that should pass by construction. This saves time, but may yield
+incorrect results if you are not careful. Specifically this means that the test
+is assumed to fail on --end-date (if specified), pass on --start-date (if
+specified), and fail after recompilation in the earliest-known-failure
+container. Careful use of this option, along with --start-date, --end-date and
+--threshold-days, allows the container-level search to be skipped.""",
+    )
+    parser.add_argument(
+        "test_command",
+        nargs="+",
+        help="""
+Command to execute inside the container. This should be as targeted as
+possible.""",
+    )
+    container_search_args.add_argument(
+        "--end-date",
+        help="""
+Initial estimate of the earliest nightly container date where the test case
+fails. Defaults to the newest available nightly container date. If this and
+--skip-precondition-checks are both set then it will not be verified that the
+test case fails on this date.""",
+        type=lambda s: datetime.date.fromisoformat(s),
+    )
+    container_search_args.add_argument(
+        "--start-date",
+        help="""
+Initial estimate of the latest nightly container date where the test case
+passes. Defaults to the day before --end-date, but setting this to a date
+further in the past may lead to faster convergence of the initial backwards
+search for a date when the test case passed. If this and
+--skip-precondition-checks are both set then the test case *must* pass on
+this date, which will *not* be verified.""",
+        type=lambda s: datetime.date.fromisoformat(s),
+    )
+    container_search_args.add_argument(
+        "--threshold-days",
+        default=1,
+        help="""
+Convergence threshold. Ideally, the container-level search will continue while
+the number of days separating the last known success and first known failure is
+smaller than this value. The minimum, and default, value is 1. Note that in
+case of nightly build failures the search may finish without reaching this
+threshold.""",
+        type=int,
+    )
+    commit_search_args.add_argument(
+        "--bazel-cache",
+        default=os.path.join(
+            tempfile.gettempdir(), f"{getpass.getuser()}-bazel-triage-cache"
+        ),
+        help="""
+Bazel cache to use when [re-]building JAX/XLA during the fine search. This can
+be a remote cache server or a local directory. Using a persistent cache can
+significantly speed up the commit-level search. By default, uses a temporary
+directory including the name of the current user.""",
+    )
+    return parser.parse_args()
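
A minimal sketch of how the parser added above behaves, with sys.argv substituted for demonstration (the container name, dates, and trailing test command are illustrative values, not from the commit):

    import sys

    # Hypothetical command line for the triage tool: all optionals first,
    # then the positional test_command (nargs="+") collects the rest.
    sys.argv = [
        "triage",
        "--container", "jax",
        "--start-date", "2024-04-01",
        "--end-date", "2024-04-15",
        "--threshold-days", "2",
        "pytest", "tests/example_test.py",
    ]
    args = parse_args()
    assert args.container == "jax"
    assert args.threshold_days == 2  # type=int
    assert args.test_command == ["pytest", "tests/example_test.py"]
    # --start-date/--end-date are parsed into datetime.date objects
    print(args.start_date, args.end_date)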
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+import logging
+import pathlib
+import subprocess
+import typing
+
+
+class DockerContainer:
+    def __init__(
+        self,
+        url: str,
+        *,
+        logger: logging.Logger,
+        mounts: typing.List[typing.Tuple[pathlib.Path, pathlib.Path]],
+    ):
+        self._logger = logger
+        self._mount_args = []
+        for src, dst in mounts:
+            self._mount_args += ["-v", f"{src}:{dst}"]
+        self._url = url
+
+    def __enter__(self):
+        result = subprocess.run(
+            [
+                "docker",
+                "run",
+                "--detach",
+                # Otherwise bazel shutdown hangs.
+                "--init",
+                "--gpus=all",
+                "--shm-size=1g",
+            ]
+            + self._mount_args
+            + [
+                self._url,
+                "sleep",
+                "infinity",
+            ],
+            check=True,
+            encoding="utf-8",
+            stderr=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+        )
+        self._id = result.stdout.strip()
+        return self
+
+    def __exit__(self, *exc_info):
+        subprocess.run(
+            ["docker", "stop", self._id],
+            check=True,
+            stderr=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+        )
+
+    def exec(
+        self, command: typing.List[str], workdir=None
+    ) -> subprocess.CompletedProcess:
+        """
+        Run a command inside a persistent container.
+        """
+        workdir = [] if workdir is None else ["--workdir", workdir]
+        return subprocess.run(
+            ["docker", "exec"] + workdir + [self._id] + command,
+            encoding="utf-8",
+            stderr=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+        )
+
+    def check_exec(
+        self, cmd: typing.List[str], **kwargs
+    ) -> subprocess.CompletedProcess:
+        result = self.exec(cmd, **kwargs)
+        if result.returncode != 0:
+            self._logger.fatal(
+                f"{' '.join(cmd)} exited with return code {result.returncode}"
+            )
+            self._logger.fatal(result.stdout)
+            self._logger.fatal(result.stderr)
+            result.check_returncode()
+        return result
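
A short, hypothetical usage sketch for the DockerContainer context manager added above. The nightly tag follows the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD pattern quoted in the --container help text; the specific tag, mount paths, and command are illustrative, and running it requires docker with GPU support:

    import logging
    import pathlib

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("triage")

    # __enter__ starts a detached `sleep infinity` container so that many
    # commands can be exec'd in the same environment; __exit__ stops it.
    with DockerContainer(
        "ghcr.io/nvidia/jax:jax-2024-04-15",  # assumed nightly tag format
        logger=logger,
        mounts=[(pathlib.Path("/tmp/triage-output"), pathlib.Path("/output"))],
    ) as container:
        # check_exec logs stdout/stderr and raises if the command fails.
        result = container.check_exec(["nvidia-smi"], workdir="/")
        print(result.stdout)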
