 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
 import runpy
-
 import sys
 from pathlib import Path
 
@@ -113,3 +113,89 @@ def test_loss(
         torch.testing.assert_close(
             loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
         )
+
+    @pytest.mark.integration_test
+    @pytest.mark.parametrize(
+        "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd",
+        [
+            ("llama3/8B_full", "llama3", "tune", 1, 4, False),
+        ],
+    )
+    @gpu_test(gpu_count=2)
+    def test_training_state_on_resume(
+        self,
+        micro_batch_size,
+        gradient_accumulation_steps,
+        config,
+        model_type,
+        ckpt_type,
+        optim_in_bwd,
+        tmpdir,
+        monkeypatch,
+    ):
+        ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
+        ckpt = model_type + "_" + ckpt_type
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        tokenizer_path = Path(TOKENIZER_PATHS[model_type])
+        ckpt_dir = ckpt_path.parent
+        log_file = gen_log_file_name(tmpdir)
+
+        # Config file needed for model conversion.
+        # Create a second copy for training resume
+        write_hf_ckpt_config(ckpt_dir)
+        write_hf_ckpt_config(tmpdir)
+
+        # Train for two epochs
+        cmd_1 = f"""
+        tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \
+            --config {config} \
+            batch_size={micro_batch_size} \
+            gradient_accumulation_steps={gradient_accumulation_steps} \
+            output_dir={tmpdir} \
+            checkpointer._component_={ckpt_component} \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type={model_type.upper()} \
+            tokenizer.path='{tokenizer_path}' \
+            tokenizer.prompt_template=null \
+            clip_grad_norm=100 \
+        """.split()
+
+        model_config = MODEL_TEST_CONFIGS[model_type]
+        cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
+
+        monkeypatch.setattr(sys, "argv", cmd_1)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        # Resume training
+        cmd_2 = f"""
+        tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \
+            --config {config} \
+            batch_size={micro_batch_size} \
+            gradient_accumulation_steps={gradient_accumulation_steps} \
+            output_dir={tmpdir} \
+            checkpointer._component_={ckpt_component} \
+            checkpointer.checkpoint_dir='{tmpdir}' \
+            checkpointer.checkpoint_files=[{os.path.join(tmpdir, "torchtune_model_0.pt")}]\
+            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type={model_type.upper()} \
+            tokenizer.path='{tokenizer_path}' \
+            tokenizer.prompt_template=null \
+            resume_from_checkpoint=True \
+            metric_logger.filename={log_file} \
+            clip_grad_norm=100 \
+        """.split()
+
+        cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config
+
+        monkeypatch.setattr(sys, "argv", cmd_2)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        expected_loss_values = self._fetch_expected_loss_values(model_type)[2:]
+
+        loss_values = get_loss_values_from_metric_logger(log_file)
+        torch.testing.assert_close(
+            loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
+        )