|
14 | 14 | limitations under the License. |
15 | 15 | """ |
16 | 16 |
|
17 | | -"""Integration tests for test_generate_param_only_checkpoint.sh""" |
| 17 | +""" |
| 18 | +Integration tests for generating a decode-only checkpoint from a training checkpoint |
| 19 | +and then running decode with it. |
| 20 | +""" |
18 | 21 | from datetime import datetime |
19 | | -import subprocess |
20 | 22 | import os.path |
21 | 23 | import pytest |
22 | 24 |
|
23 | 25 | from MaxText.globals import PKG_DIR |
24 | | -from MaxText.tests.globals import TEST_DISABLE_SUBPROCESS, TEST_DISABLE_SUBPROCESS_STR |
| 26 | +from MaxText.train import main as train_main |
| 27 | +from MaxText.decode import main as decode_main |
| 28 | +from MaxText.generate_param_only_checkpoint import main as generate_param_only_ckpt_main |
| 29 | +from MaxText.tests.integration_tests.checkpointing_test import get_checkpointing_command |
25 | 30 |
|
26 | 31 |
|
def run_generate_param_only_checkpoint(hardware, attention_type, quantization):
  """Train briefly, export a parameter-only checkpoint, then decode from it.

  Three stages run in-process, in this order: `train_main`,
  `generate_param_only_ckpt_main`, and `decode_main`. Every stage is given the
  same reduced model dimensions (`model_params`).

  Args:
    hardware: Forwarded to each stage as `hardware=<value>`; the tests in this
      file pass "tpu" or "gpu".
    attention_type: Forwarded to each stage as the attention setting.
    quantization: Forwarded to each stage as `quantization=<value>`; the tests
      in this file pass "" or "int8".
  """

  run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
  base_config = os.path.join(PKG_DIR, "configs", "base.yml")

  # Model shape overrides appended to the argument list of every stage.
  model_params = [
      f"quantization={quantization}",
      "base_emb_dim=384",
      "base_num_query_heads=8",
      "base_num_kv_heads=8",
      "base_mlp_dim=192",
      "base_num_decoder_layers=8",
      "head_dim=128",
  ]

  # Stage 1: a short training run that produces the full-state checkpoint.
  train_main(
      get_checkpointing_command(
          run_date,
          hardware=hardware,
          steps=5,
          metrics_file="run_metrics.txt",
          attention_type=attention_type,
          dataset_type="tfds",
          dataset_path="gs://maxtext-dataset",
      )
      + model_params
  )

  # Stage 2: convert the full training state into a parameter-only checkpoint.
  # NOTE(review): this path assumes the training run above is named
  # runner_<run_date>, writes under gs://runner-maxtext-logs, and saved a
  # checkpoint at step 4 (the last of steps=5) - confirm against
  # get_checkpointing_command.
  state_path = f"gs://runner-maxtext-logs/runner_{run_date}/checkpoints/4/items"
  generate_param_only_ckpt_main(
      [
          None,  # presumably the argv[0] (program name) placeholder - TODO confirm
          base_config,
          f"hardware={hardware}",
          f"run_name=generate_param_{run_date}",
          "base_output_directory=gs://runner-maxtext-logs",
          "dataset_path=gs://maxtext-dataset",
          "async_checkpointing=False",
          f"attention={attention_type}",
          f"load_full_state_path={state_path}",
      ]
      + model_params
  )

  # Stage 3: decode using the parameter-only checkpoint written by stage 2
  # (expected under that stage's run_name, at checkpoints/0/items).
  decode_ckpt_path = f"gs://runner-maxtext-logs/generate_param_{run_date}/checkpoints/0/items"
  decode_main(
      [
          None,  # presumably the argv[0] (program name) placeholder - TODO confirm
          base_config,
          f"hardware={hardware}",
          f"run_name=decode_{run_date}",
          "base_output_directory=gs://runner-maxtext-logs",
          "dataset_path=gs://maxtext-dataset",
          f"load_parameters_path={decode_ckpt_path}",
          f"attention={attention_type}",
          "max_target_length=128",
      ]
      + model_params
  )
47 | 90 |
|
48 | 91 |
|
@pytest.mark.integration_test
@pytest.mark.tpu_only
@pytest.mark.parametrize("quantization", ["", "int8"])
def test_autoselected_attention(quantization, capsys):
  """Run the train -> param-only checkpoint -> decode flow on TPU with autoselected attention."""
  run_generate_param_only_checkpoint("tpu", "autoselected", quantization)
  # The stages run in-process, so their stdout is captured by capsys.
  captured = capsys.readouterr()
  # NOTE(review): presumably decode prints its prompt in this form - confirm.
  expected_output = "Input `I love to`"
  assert expected_output in captured.out, f"{expected_output!r} not found in captured stdout"
55 | 100 |
|
56 | 101 |
|
@pytest.mark.integration_test
@pytest.mark.gpu_only
@pytest.mark.parametrize("quantization", ["", "int8"])
def test_with_dot_product(quantization, capsys):
  """Run the train -> param-only checkpoint -> decode flow on GPU with dot_product attention."""
  run_generate_param_only_checkpoint("gpu", "dot_product", quantization)
  # The stages run in-process, so their stdout is captured by capsys.
  captured = capsys.readouterr()
  # NOTE(review): presumably decode prints its prompt in this form - confirm.
  expected_output = "Input `I love to`"
  assert expected_output in captured.out, f"{expected_output!r} not found in captured stdout"
0 commit comments