Skip to content

Commit e67180e

Browse files
Committed by: maxtext authors
Merge pull request #1778 from AI-Hypercomputer:enable_ckpt_tests
PiperOrigin-RevId: 764797579
2 parents d5ae120 + de0936d commit e67180e

File tree

4 files changed: 154 additions and 71 deletions.

MaxText/tests/globals.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

MaxText/tests/integration_tests/checkpoint_compatibility_test.py

Lines changed: 71 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -14,44 +14,94 @@
1414
limitations under the License.
1515
"""
1616

17-
"""Integration tests for test_checkpointing.sh"""
17+
"""
18+
Integration tests to check compatibility of checkpoints between different input pipelines.
19+
20+
These tests verify that a checkpoint saved during a training run using one
21+
input pipeline (e.g., 'grain') can be successfully restored and continued
22+
by a subsequent training run using a different input pipeline (e.g., 'tfds').
23+
The tests confirm restoration by checking the starting step of the resumed runs.
24+
25+
Note: Make sure to run
26+
`bash setup_gcsfuse.sh DATASET_GCS_BUCKET=gs://maxtext-dataset MOUNT_PATH=/tmp/gcsfuse/`
27+
before running tests locally.
28+
"""
1829

1930
from datetime import datetime
20-
import subprocess
21-
import os.path
31+
import json
2232
import pytest
23-
from MaxText.globals import PKG_DIR
24-
from MaxText.tests.globals import TEST_DISABLE_SUBPROCESS, TEST_DISABLE_SUBPROCESS_STR
33+
from MaxText.train import main as train_main
34+
from MaxText.tests.integration_tests.checkpointing_test import get_checkpointing_command
35+
2536

37+
def check_start_step(metrics_file, start_step_target):
38+
with open(metrics_file, "rt", encoding="utf8") as metrics:
39+
start_step = json.loads(metrics.readlines()[0])["step"]
40+
print(f"Start step is {start_step}, start step target is {start_step_target}")
41+
assert start_step == float(start_step_target)
2642

27-
def run_checkpoint_compatibility(attention_type):
43+
44+
def run_checkpoint_compatibility(hardware, attention_type):
2845
"""Tests checkpoint compatibility."""
2946

3047
run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
31-
script_path = os.path.join(os.path.dirname(PKG_DIR), "end_to_end", "test_checkpoint_compatibility.sh")
32-
if not os.path.isfile(script_path):
33-
raise FileNotFoundError(script_path)
34-
command = [
35-
"bash",
36-
script_path,
37-
f"runner_{run_date}", # run_name
38-
"gs://runner-maxtext-logs", # output_path
39-
"gs://maxtext-dataset", # dataset_path
40-
attention_type,
48+
grain_command = [
49+
"grain_worker_count=0",
50+
"grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*",
4151
]
4252

43-
subprocess.run(command, check=True, cwd=os.path.dirname(PKG_DIR))
53+
# Run training using grain input pipeline
54+
train_main(
55+
get_checkpointing_command(
56+
run_date,
57+
hardware=hardware,
58+
steps=3,
59+
metrics_file="run_1_metrics.txt",
60+
attention_type=attention_type,
61+
dataset_type="grain",
62+
dataset_path="/tmp/gcsfuse",
63+
)
64+
+ grain_command
65+
)
66+
67+
# Resume training using tfds input pipeline
68+
train_main(
69+
get_checkpointing_command(
70+
run_date,
71+
hardware=hardware,
72+
steps=5,
73+
metrics_file="run_2_metrics.txt",
74+
attention_type=attention_type,
75+
dataset_type="tfds",
76+
dataset_path="/tmp/gcsfuse",
77+
)
78+
)
79+
80+
# Resume training again using grain input pipeline
81+
train_main(
82+
get_checkpointing_command(
83+
run_date,
84+
hardware=hardware,
85+
steps=7,
86+
metrics_file="run_3_metrics.txt",
87+
attention_type=attention_type,
88+
dataset_type="grain",
89+
dataset_path="/tmp/gcsfuse",
90+
)
91+
+ grain_command
92+
)
93+
94+
check_start_step("run_2_metrics.txt", 3.0)
95+
check_start_step("run_3_metrics.txt", 5.0)
4496

4597

4698
@pytest.mark.integration_test
4799
@pytest.mark.tpu_only
48-
@pytest.mark.skipif(TEST_DISABLE_SUBPROCESS, reason=TEST_DISABLE_SUBPROCESS_STR)
49100
def test_autoselected_attention():
50-
run_checkpoint_compatibility("autoselected")
101+
run_checkpoint_compatibility("tpu", "autoselected")
51102

52103

53104
@pytest.mark.integration_test
54105
@pytest.mark.gpu_only
55-
@pytest.mark.skipif(TEST_DISABLE_SUBPROCESS, reason=TEST_DISABLE_SUBPROCESS_STR)
56106
def test_with_dot_product():
57-
run_checkpoint_compatibility("dot_product")
107+
run_checkpoint_compatibility("gpu", "dot_product")

MaxText/tests/integration_tests/checkpointing_test.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,11 @@
1515
"""
1616

1717
"""
18-
Integration tests for test_checkpointing.sh
18+
Integration tests for checkpointing functionality.
19+
20+
These tests verify that a training run saves a checkpoint,
21+
and then a subsequent training run can correctly restore and
22+
continue from that saved checkpoint.
1923
2024
Note: Make sure to run
2125
`bash setup_gcsfuse.sh DATASET_GCS_BUCKET=gs://maxtext-dataset MOUNT_PATH=/tmp/gcsfuse/`
@@ -31,7 +35,7 @@
3135
from MaxText.train import main as train_main
3236

3337

34-
def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention_type):
38+
def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention_type, dataset_type, dataset_path):
3539
model_params = [
3640
"base_emb_dim=384",
3741
"base_num_query_heads=8",
@@ -51,7 +55,8 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention
5155
f"metrics_file={metrics_file}",
5256
"checkpoint_period=3",
5357
"base_output_directory=gs://runner-maxtext-logs",
54-
"dataset_path=/tmp/gcsfuse/",
58+
f"dataset_path={dataset_path}",
59+
f"dataset_type={dataset_type}",
5560
"async_checkpointing=False",
5661
f"attention={attention_type}",
5762
] + model_params
@@ -80,7 +85,6 @@ def run_checkpointing(hardware, attention_type):
8085
run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
8186
grain_command = [
8287
"grain_worker_count=0",
83-
"dataset_type=grain",
8488
"grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*",
8589
]
8690
train_main(
@@ -90,6 +94,8 @@ def run_checkpointing(hardware, attention_type):
9094
steps=5,
9195
metrics_file="saved_metrics.txt",
9296
attention_type=attention_type,
97+
dataset_type="grain",
98+
dataset_path="/tmp/gcsfuse",
9399
)
94100
+ grain_command
95101
)
@@ -101,6 +107,8 @@ def run_checkpointing(hardware, attention_type):
101107
steps=10,
102108
metrics_file="restored_metrics.txt",
103109
attention_type=attention_type,
110+
dataset_type="grain",
111+
dataset_path="/tmp/gcsfuse",
104112
)
105113
+ grain_command
106114
)

MaxText/tests/integration_tests/generate_param_only_checkpoint_test.py

Lines changed: 71 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -14,49 +14,96 @@
1414
limitations under the License.
1515
"""
1616

17-
"""Integration tests for test_generate_param_only_checkpoint.sh"""
17+
"""
18+
Integration tests for generating a decode-only checkpoint from a training checkpoint
19+
and then running decode with it.
20+
"""
1821
from datetime import datetime
19-
import subprocess
2022
import os.path
2123
import pytest
2224

2325
from MaxText.globals import PKG_DIR
24-
from MaxText.tests.globals import TEST_DISABLE_SUBPROCESS, TEST_DISABLE_SUBPROCESS_STR
26+
from MaxText.train import main as train_main
27+
from MaxText.decode import main as decode_main
28+
from MaxText.generate_param_only_checkpoint import main as generate_param_only_ckpt_main
29+
from MaxText.tests.integration_tests.checkpointing_test import get_checkpointing_command
2530

2631

27-
def run_generate_param_only_checkpoint(attention_type, quantization):
32+
def run_generate_param_only_checkpoint(hardware, attention_type, quantization):
2833
"""Tests generating a parameter-only checkpoint."""
2934

3035
run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
31-
script_path = os.path.join(os.path.dirname(PKG_DIR), "end_to_end", "test_generate_param_only_checkpoint.sh")
32-
if not os.path.isfile(script_path):
33-
raise FileNotFoundError(script_path)
34-
# fmt: off
35-
command = [
36-
"bash",
37-
script_path,
38-
"-r", f"runner_{run_date}",
39-
"-o", "gs://runner-maxtext-logs",
40-
"-d", "gs://maxtext-dataset",
41-
"-i", "4",
42-
"-a", attention_type,
43-
"-q", quantization,
36+
model_params = [
37+
f"quantization={quantization}",
38+
"base_emb_dim=384",
39+
"base_num_query_heads=8",
40+
"base_num_kv_heads=8",
41+
"base_mlp_dim=192",
42+
"base_num_decoder_layers=8",
43+
"head_dim=128",
4444
]
4545

46-
subprocess.run(command, check=True, cwd=os.path.dirname(PKG_DIR))
46+
train_main(
47+
get_checkpointing_command(
48+
run_date,
49+
hardware=hardware,
50+
steps=5,
51+
metrics_file="run_metrics.txt",
52+
attention_type=attention_type,
53+
dataset_type="tfds",
54+
dataset_path="gs://maxtext-dataset",
55+
)
56+
+ model_params
57+
)
58+
59+
state_path = f"gs://runner-maxtext-logs/runner_{run_date}/checkpoints/4/items"
60+
generate_param_only_ckpt_main(
61+
[
62+
None,
63+
os.path.join(PKG_DIR, "configs", "base.yml"),
64+
f"hardware={hardware}",
65+
f"run_name=generate_param_{run_date}",
66+
"base_output_directory=gs://runner-maxtext-logs",
67+
"dataset_path=gs://maxtext-dataset",
68+
"async_checkpointing=False",
69+
f"attention={attention_type}",
70+
f"load_full_state_path={state_path}",
71+
]
72+
+ model_params
73+
)
74+
75+
decode_ckpt_path = f"gs://runner-maxtext-logs/generate_param_{run_date}/checkpoints/0/items"
76+
decode_main(
77+
[
78+
None,
79+
os.path.join(PKG_DIR, "configs", "base.yml"),
80+
f"hardware={hardware}",
81+
f"run_name=decode_{run_date}",
82+
"base_output_directory=gs://runner-maxtext-logs",
83+
"dataset_path=gs://maxtext-dataset",
84+
f"load_parameters_path={decode_ckpt_path}",
85+
f"attention={attention_type}",
86+
"max_target_length=128",
87+
]
88+
+ model_params
89+
)
4790

4891

4992
@pytest.mark.integration_test
5093
@pytest.mark.tpu_only
5194
@pytest.mark.parametrize("quantization", [(""), ("int8")])
52-
@pytest.mark.skipif(TEST_DISABLE_SUBPROCESS, reason=TEST_DISABLE_SUBPROCESS_STR)
53-
def test_autoselected_attention(quantization):
54-
run_generate_param_only_checkpoint("autoselected", quantization)
95+
def test_autoselected_attention(quantization, capsys):
96+
run_generate_param_only_checkpoint("tpu", "autoselected", quantization)
97+
captured = capsys.readouterr()
98+
expected_output = "Input `I love to`"
99+
assert expected_output in captured.out
55100

56101

57102
@pytest.mark.integration_test
58103
@pytest.mark.gpu_only
59104
@pytest.mark.parametrize("quantization", [(""), ("int8")])
60-
@pytest.mark.skipif(TEST_DISABLE_SUBPROCESS, reason=TEST_DISABLE_SUBPROCESS_STR)
61-
def test_with_dot_product(quantization):
62-
run_generate_param_only_checkpoint("dot_product", quantization)
105+
def test_with_dot_product(quantization, capsys):
106+
run_generate_param_only_checkpoint("gpu", "dot_product", quantization)
107+
captured = capsys.readouterr()
108+
expected_output = "Input `I love to`"
109+
assert expected_output in captured.out

0 commit comments

Comments
 (0)