Commit 9750267

[BIONEMO-2473] Added tests for Evo2 LoRA fine-tuning
Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
1 parent c4f2038 commit 9750267

File tree: 6 files changed, +201 −42 lines

sub-packages/bionemo-evo2/src/bionemo/evo2/run/peft.py

Lines changed: 0 additions & 14 deletions
@@ -13,20 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
 from copy import deepcopy
 from typing import List, Optional
 
sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py

Lines changed: 2 additions & 3 deletions
@@ -38,7 +38,6 @@
 )
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 from nemo.lightning.pytorch import callbacks as nl_callbacks
-from nemo.lightning.pytorch.callbacks import ModelTransform
 from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback
 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
 from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
@@ -492,7 +491,7 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
         help="Disable saving the last checkpoint.",
     )
     parser.add_argument("--lora-finetune", action="store_true", help="Use LoRA fine-tuning", default=False)
-    parser.add_argument("--lora-checkpoint-path", type=Path, default=None, help="LoRA checkpoint path")
+    parser.add_argument("--lora-checkpoint-path", type=str, default=None, help="LoRA checkpoint path")
     parser.add_argument(
         "--no-calculate-per-token-loss",
         action="store_true",
@@ -646,7 +645,7 @@ def train(args: argparse.Namespace) -> nl.Trainer:
     ]
 
     if args.lora_finetune:
-        callbacks.append(ModelTransform())
+        callbacks.append(lora_transform)
     if args.enable_preemption:
         callbacks.append(nl_callbacks.PreemptionCallback())
     if args.debug_ddp_parity_freq > 0:
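
Two details in this diff are worth noting. The --lora-checkpoint-path type changes from Path to str, presumably because the value is later handed to NeMo restore/resume configuration that expects a plain string. And the bare ModelTransform() callback is replaced by a concrete lora_transform, which implies the transform is now constructed earlier in train(), likely from the PEFT utilities in peft.py. Below is a minimal sketch of that wiring, assuming a NeMo-2.0-style PEFT class; the build_lora_transform helper and its constructor arguments are illustrative, not confirmed by this commit.

import argparse

# Assumed import path: NeMo 2.0 exposes PEFT classes under nemo.collections.llm.peft,
# and LoRA there derives from ModelTransform, so a configured instance can serve
# directly as a trainer callback.
from nemo.collections import llm


def build_lora_transform(args: argparse.Namespace):
    """Hypothetical helper: build the transform that train() appends as a callback."""
    if not args.lora_finetune:
        return None
    # Rank (dim) and alpha are illustrative; the commit does not show them.
    return llm.peft.LoRA(dim=8, alpha=16)


# Inside train(), mirroring the diff above:
#     lora_transform = build_lora_transform(args)
#     ...
#     if args.lora_finetune:
#         callbacks.append(lora_transform)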
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024 Arc Institute. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024 Michael Poli. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024 Stanford University. All rights reserved
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def small_training_cmd(path, max_steps, val_check, devices: int = 1, additional_args: str = ""):
+    """Command for training."""
+    cmd = (
+        f"train_evo2 --mock-data --result-dir {path} --devices {devices} "
+        "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1 "
+        "--no-activation-checkpointing --add-bias-output --create-tensorboard-logger --create-tflops-callback "
+        f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
+        f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args}"
+    )
+    return cmd
+
+
+def small_training_finetune_cmd(
+    path,
+    max_steps,
+    val_check,
+    prev_ckpt,
+    devices: int = 1,
+    create_tflops_callback: bool = True,
+    additional_args: str = "",
+):
+    """Command for finetuning."""
+    cmd = (
+        f"train_evo2 --mock-data --result-dir {path} --devices {devices} "
+        "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1 "
+        "--no-activation-checkpointing --add-bias-output --create-tensorboard-logger "
+        f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
+        f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args} --ckpt-dir {prev_ckpt} "
+        f"{'--create-tflops-callback' if create_tflops_callback else ''}"
+    )
+    return cmd
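
For reference, a short usage sketch of these two helpers, mirroring the flow of the new finetuning test below (the tmp_path value and the checkpoint directory name are illustrative placeholders):

from pathlib import Path

tmp_path = Path("/tmp/evo2_smoke")  # illustrative; pytest normally supplies tmp_path

# Step 1: short mock pretraining run that produces a "*-last" checkpoint.
pretrain_cmd = small_training_cmd(tmp_path / "pretrain", max_steps=2, val_check=2)

# Step 2: LoRA finetuning from that checkpoint, with the TFLOPS callback disabled.
finetune_cmd = small_training_finetune_cmd(
    tmp_path / "lora_finetune",
    max_steps=2,
    val_check=2,
    prev_ckpt=tmp_path / "pretrain" / "evo2" / "checkpoints" / "<step=2.0-last>",  # placeholder
    create_tflops_callback=False,
    additional_args="--lora-finetune",
)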
Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024 Arc Institute. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024 Michael Poli. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024 Stanford University. All rights reserved
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+
+from bionemo.testing.subprocess_utils import run_command_in_subprocess
+
+from .common import small_training_cmd, small_training_finetune_cmd
+
+
+@pytest.mark.timeout(512)  # Optional: fail if the test takes too long.
+@pytest.mark.slow
+@pytest.mark.parametrize("with_peft", [True, False])
+def test_train_evo2_finetune_runs_lora(tmp_path, with_peft: bool):
+    """
+    This test runs the `train_evo2` command with mock data in a temporary directory.
+    It uses the temporary directory provided by pytest as the working directory.
+    The command is run in a subshell, and we assert that it returns an exit code of 0.
+    """
+    num_steps = 2
+    # Note: The command assumes that `train_evo2` is in your PATH.
+    command = small_training_cmd(tmp_path / "pretrain", max_steps=num_steps, val_check=num_steps)
+    stdout_pretrain: str = run_command_in_subprocess(command=command, path=str(tmp_path))
+    assert "Restoring model weights from RestoreConfig(path='" not in stdout_pretrain
+
+    log_dir = tmp_path / "pretrain" / "evo2"
+    checkpoints_dir = log_dir / "checkpoints"
+    tensorboard_dir = log_dir / "dev"
+
+    # Check if logs dir exists
+    assert log_dir.exists(), "Logs folder should exist."
+    # Check if checkpoints dir exists
+    assert checkpoints_dir.exists(), "Checkpoints folder does not exist."
+
+    expected_checkpoint_suffix = f"{num_steps}.0-last"
+    # Check if any subfolder ends with the expected suffix
+    matching_subfolders = [
+        p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name)
+    ]
+
+    assert matching_subfolders, (
+        f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}."
+    )
+
+    # Check if directory with tensorboard logs exists
+    assert tensorboard_dir.exists(), "TensorBoard logs folder does not exist."
+    # Recursively search for files with tensorboard logger
+    event_files = list(tensorboard_dir.rglob("events.out.tfevents*"))
+    assert event_files, f"No TensorBoard event files found under {tensorboard_dir}"
+    assert len(matching_subfolders) == 1, "Only one checkpoint subfolder should be found."
+    if with_peft:
+        result_dir = tmp_path / "lora_finetune"
+        additional_args = "--lora-finetune"
+    else:
+        result_dir = tmp_path / "finetune"
+        additional_args = ""
+
+    command_finetune = small_training_finetune_cmd(
+        result_dir,
+        max_steps=num_steps,
+        val_check=num_steps,
+        prev_ckpt=matching_subfolders[0],
+        create_tflops_callback=not with_peft,
+        additional_args=additional_args,
+    )
+    stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path))
+    assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune
+
+    log_dir_ft = result_dir / "evo2"
+    checkpoints_dir_ft = log_dir_ft / "checkpoints"
+    tensorboard_dir_ft = log_dir_ft / "dev"
+
+    # Check if logs dir exists
+    assert log_dir_ft.exists(), "Logs folder should exist."
+    # Check if checkpoints dir exists
+    assert checkpoints_dir_ft.exists(), "Checkpoints folder does not exist."
+
+    expected_checkpoint_suffix = f"{num_steps}.0-last"
+    # Check if any subfolder ends with the expected suffix
+    matching_subfolders_finetune = [
+        p for p in checkpoints_dir_ft.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name)
+    ]
+
+    assert matching_subfolders_finetune, (
+        f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir_ft}."
+    )
+
+    # Check if directory with tensorboard logs exists
+    assert tensorboard_dir_ft.exists(), "TensorBoard logs folder does not exist."
+    # Recursively search for files with tensorboard logger
+    event_files = list(tensorboard_dir_ft.rglob("events.out.tfevents*"))
+    assert event_files, f"No TensorBoard event files found under {tensorboard_dir_ft}"
+
+    assert len(matching_subfolders_finetune) == 1, "Only one checkpoint subfolder should be found."
+
+    # With LoRA, test resuming from a saved LoRA checkpoint
+    if with_peft:
+        result_dir = tmp_path / "lora_finetune_resume"
+
+        # Resume from LoRA checkpoint
+        command_resume_finetune = small_training_finetune_cmd(
+            result_dir,
+            max_steps=num_steps,
+            val_check=num_steps,
+            prev_ckpt=matching_subfolders[0],
+            create_tflops_callback=False,
+            additional_args=f"--lora-finetune --lora-checkpoint-path {matching_subfolders_finetune[0]}",
+        )
+        stdout_finetune: str = run_command_in_subprocess(command=command_resume_finetune, path=str(tmp_path))
+
+        log_dir_ft = result_dir / "evo2"
+        checkpoints_dir_ft = log_dir_ft / "checkpoints"
+        tensorboard_dir_ft = log_dir_ft / "dev"
+
+        # Check if logs dir exists
+        assert log_dir_ft.exists(), "Logs folder should exist."
+        # Check if checkpoints dir exists
+        assert checkpoints_dir_ft.exists(), "Checkpoints folder does not exist."
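
Expanded, the resume command that small_training_finetune_cmd assembles for this last branch looks roughly as follows; the angle-bracket names are illustrative placeholders for the tmp_path subdirectories and checkpoint folders created earlier in the test:

train_evo2 --mock-data --result-dir <tmp>/lora_finetune_resume --devices 1 \
    --model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1 \
    --no-activation-checkpointing --add-bias-output --create-tensorboard-logger \
    --max-steps 2 --warmup-steps 1 --val-check-interval 2 --limit-val-batches 1 \
    --seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 \
    --lora-finetune --lora-checkpoint-path <tmp>/lora_finetune/evo2/checkpoints/<2.0-last> \
    --ckpt-dir <tmp>/pretrain/evo2/checkpoints/<2.0-last>

Note that --ckpt-dir still points at the original pretraining checkpoint (the base weights), while --lora-checkpoint-path points at the adapter checkpoint produced by the first LoRA finetuning run.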

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_train.py

Lines changed: 3 additions & 25 deletions
@@ -30,14 +30,14 @@
 from bionemo.testing.megatron_parallel_state_utils import distributed_model_parallel_state
 from bionemo.testing.subprocess_utils import run_command_in_subprocess
 
+from .common import small_training_cmd, small_training_finetune_cmd
+
 
 fp8_available, reason_for_no_fp8 = check_fp8_support()
 
 
 def run_train_with_std_redirect(args: argparse.Namespace) -> Tuple[str, nl.Trainer]:
-    """
-    Run a function with output capture.
-    """
+    """Run a function with output capture."""
     stdout_buf, stderr_buf = io.StringIO(), io.StringIO()
     with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
         with distributed_model_parallel_state():
@@ -50,28 +50,6 @@ def run_train_with_std_redirect(args: argparse.Namespace) -> Tuple[str, nl.Trainer]:
     return train_stdout, trainer
 
 
-def small_training_cmd(path, max_steps, val_check, devices: int = 1, additional_args: str = ""):
-    cmd = (
-        f"train_evo2 --mock-data --result-dir {path} --devices {devices} "
-        "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1 "
-        "--no-activation-checkpointing --add-bias-output --create-tensorboard-logger --create-tflops-callback "
-        f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
-        f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args}"
-    )
-    return cmd
-
-
-def small_training_finetune_cmd(path, max_steps, val_check, prev_ckpt, devices: int = 1, additional_args: str = ""):
-    cmd = (
-        f"train_evo2 --mock-data --result-dir {path} --devices {devices} "
-        "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1 "
-        "--no-activation-checkpointing --add-bias-output --create-tensorboard-logger --create-tflops-callback "
-        f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
-        f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args} --ckpt-dir {prev_ckpt}"
-    )
-    return cmd
-
-
 def small_training_mamba_cmd(path, max_steps, val_check, devices: int = 1, additional_args: str = ""):
     cmd = (
         f"train_evo2 --mock-data --result-dir {path} --devices {devices} "