Skip to content

Commit dd4f626

Browse files
authored
[BIONEMO-2473] Added tests for Evo2 LoRA fine-tuning (#1060)
### Description Fixes and added test for Evo2LoRA ### Type of changes <!-- Mark the relevant option with an [x] --> - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Refactor - [ ] Documentation update - [ ] Other (please describe): ### CI Pipeline Configuration Configure CI behavior by applying the relevant labels: - [SKIP_CI](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#skip_ci) - Skip all continuous integration tests - [INCLUDE_NOTEBOOKS_TESTS](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#include_notebooks_tests) - Execute notebook validation tests in pytest - [INCLUDE_SLOW_TESTS](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#include_slow_tests) - Execute tests labelled as slow in pytest for extensive testing > [!NOTE] > By default, the notebooks validation tests are skipped unless explicitly enabled. #### Authorizing CI Runs We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources. - If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123) - If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit. 
### Usage <!--- How does a user interact with the changed code --> ```python # TODO: Add code snippet ``` ### Pre-submit Checklist <!--- Ensure all items are completed before submitting --> - [x] I have tested these changes locally - [x] I have updated the documentation accordingly - [x] I have added/updated tests as needed - [ ] All existing tests pass successfully <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Expose controls for mock dataset sizes (train/val/test) for training runs. * LoRA finetuning flow simplified; LoRA integration now passes a preconstructed transform and checkpoint paths accept plain strings. * **Tests** * Added end-to-end integration tests for pretraining, finetuning, and LoRA finetuning with artifact and loss validations. * Introduced shared test helpers for constructing small training/finetune commands and consolidated imports. * **Chores** * Updated/cleaned license header boilerplate in tests. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
1 parent 58297ae commit dd4f626

File tree

6 files changed

+285
-43
lines changed

6 files changed

+285
-43
lines changed

sub-packages/bionemo-evo2/src/bionemo/evo2/run/peft.py renamed to sub-packages/bionemo-evo2/src/bionemo/evo2/models/peft.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
17-
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
18-
#
19-
# Licensed under the Apache License, Version 2.0 (the "License");
20-
# you may not use this file except in compliance with the License.
21-
# You may obtain a copy of the License at
22-
#
23-
# http://www.apache.org/licenses/LICENSE-2.0
24-
#
25-
# Unless required by applicable law or agreed to in writing, software
26-
# distributed under the License is distributed on an "AS IS" BASIS,
27-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28-
# See the License for the specific language governing permissions and
29-
# limitations under the License.
3016
from copy import deepcopy
3117
from typing import List, Optional
3218

sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
)
3939
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
4040
from nemo.lightning.pytorch import callbacks as nl_callbacks
41-
from nemo.lightning.pytorch.callbacks import ModelTransform
4241
from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback
4342
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
4443
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
@@ -49,7 +48,7 @@
4948
from bionemo.evo2.data.sharded_eden_dataloader import ShardedEdenDataModule
5049
from bionemo.evo2.models.llama import LLAMA_MODEL_OPTIONS
5150
from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel, mamba_no_weight_decay_cond_with_embeddings
52-
from bionemo.evo2.run.peft import Evo2LoRA
51+
from bionemo.evo2.models.peft import Evo2LoRA
5352
from bionemo.evo2.utils.callbacks import GarbageCollectAtInferenceTime
5453
from bionemo.evo2.utils.config import hyena_no_weight_decay_cond_with_embeddings
5554
from bionemo.evo2.utils.logging.callbacks import TEVCallback
@@ -611,7 +610,7 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
611610
help="Disable saving the last checkpoint.",
612611
)
613612
parser.add_argument("--lora-finetune", action="store_true", help="Use LoRA fine-tuning", default=False)
614-
parser.add_argument("--lora-checkpoint-path", type=Path, default=None, help="LoRA checkpoint path")
613+
parser.add_argument("--lora-checkpoint-path", type=str, default=None, help="LoRA checkpoint path")
615614
parser.add_argument(
616615
"--no-calculate-per-token-loss",
617616
action="store_true",
@@ -669,6 +668,9 @@ def train(args: argparse.Namespace) -> nl.Trainer:
669668
seq_length=args.seq_length,
670669
micro_batch_size=args.micro_batch_size,
671670
global_batch_size=global_batch_size,
671+
num_train_samples=args.max_steps * global_batch_size,
672+
num_val_samples=args.limit_val_batches * global_batch_size,
673+
num_test_samples=1,
672674
num_workers=args.workers,
673675
tokenizer=tokenizer,
674676
)
@@ -823,7 +825,7 @@ def train(args: argparse.Namespace) -> nl.Trainer:
823825
callbacks.append(GarbageCollectAtInferenceTime())
824826

825827
if args.lora_finetune:
826-
callbacks.append(ModelTransform())
828+
callbacks.append(lora_transform)
827829
if args.enable_preemption:
828830
callbacks.append(nl_callbacks.PreemptionCallback())
829831
if args.debug_ddp_parity_freq > 0:
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-FileCopyrightText: Copyright (c) 2024 Arc Institute. All rights reserved.
3+
# SPDX-FileCopyrightText: Copyright (c) 2024 Michael Poli. All rights reserved.
4+
# SPDX-FileCopyrightText: Copyright (c) 2024 Stanford University. All rights reserved
5+
# SPDX-License-Identifier: LicenseRef-Apache2
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
19+
20+
def small_training_cmd(
21+
path,
22+
max_steps,
23+
val_check,
24+
global_batch_size: int | None = None,
25+
devices: int = 1,
26+
additional_args: str = "",
27+
):
28+
"""Command for training."""
29+
cmd = (
30+
f"train_evo2 --mock-data --result-dir {path} --devices {devices} "
31+
"--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1 "
32+
"--no-activation-checkpointing --add-bias-output --create-tensorboard-logger --create-tflops-callback "
33+
f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} "
34+
f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args} "
35+
f"{'--global-batch-size ' + str(global_batch_size) if global_batch_size is not None else ''}"
36+
)
37+
return cmd
38+
39+
40+
def small_training_finetune_cmd(
41+
path,
42+
max_steps,
43+
val_check,
44+
prev_ckpt,
45+
devices: int = 1,
46+
global_batch_size: int | None = None,
47+
create_tflops_callback: bool = True,
48+
additional_args: str = "",
49+
):
50+
"""Command for finetuning."""
51+
cmd = (
52+
f"train_evo2 --mock-data --result-dir {path} --devices {devices} "
53+
"--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1 "
54+
"--no-activation-checkpointing --add-bias-output --create-tensorboard-logger "
55+
f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} "
56+
f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args} --ckpt-dir {prev_ckpt} "
57+
f"{'--create-tflops-callback' if create_tflops_callback else ''} "
58+
f"{'--global-batch-size ' + str(global_batch_size) if global_batch_size is not None else ''}"
59+
)
60+
return cmd
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-FileCopyrightText: Copyright (c) 2024 Arc Institute. All rights reserved.
3+
# SPDX-FileCopyrightText: Copyright (c) 2024 Michael Poli. All rights reserved.
4+
# SPDX-FileCopyrightText: Copyright (c) 2024 Stanford University. All rights reserved
5+
# SPDX-License-Identifier: LicenseRef-Apache2
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
19+
import re
20+
21+
import pytest
22+
23+
from bionemo.testing.subprocess_utils import run_command_in_subprocess
24+
25+
from .common import small_training_cmd, small_training_finetune_cmd
26+
27+
28+
def extract_val_losses(log_text: str, n: int):
    """Extract every n-th ``val_loss`` value from captured training log output.

    Note: the first element of each returned tuple is the 0-based *occurrence
    index* of the ``val_loss`` match within the log text, not the trainer's
    step number.

    Args:
        log_text (str): The log output as a string.
        n (int): Interval of occurrences (e.g., n=5 -> occurrences 0, 5, 10...).

    Returns:
        List of tuples: (occurrence_index, validation_loss_value).

    Raises:
        ValueError: If ``n`` is not a positive integer (the previous
            implementation crashed with ``ZeroDivisionError`` for n=0).
    """
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n}")
    # Regex to capture val_loss values
    pattern = re.compile(r"val_loss: ([0-9.]+)")
    return [
        (idx, float(match.group(1)))
        for idx, match in enumerate(pattern.finditer(log_text))
        if idx % n == 0  # take every n-th val_loss occurrence
    ]
49+
50+
51+
@pytest.mark.timeout(2048)  # Optional: fail if the test takes too long.
@pytest.mark.slow
@pytest.mark.parametrize("with_peft", [True, False])
def test_train_evo2_finetune_runs(tmp_path, with_peft: bool):
    """End-to-end pretrain -> finetune (-> LoRA resume) smoke test for `train_evo2`.

    Runs the `train_evo2` command with mock data in a temporary directory
    (provided by pytest) via a subshell, and asserts that each stage produces
    the expected artifacts (logs, checkpoints, TensorBoard events) and a
    non-increasing validation loss. Assumes `train_evo2` is on PATH.
    """
    num_steps = 25
    val_steps = 10
    global_batch_size = 128
    # Checkpoint folder names embed the number of consumed samples.
    expected_checkpoint_suffix = f"{num_steps * global_batch_size}.0-last"

    # --- Stage 1: pretraining from scratch ---
    command = small_training_cmd(
        tmp_path / "pretrain",
        max_steps=num_steps,
        val_check=val_steps,
        global_batch_size=global_batch_size,
        additional_args=" --lr 0.1 ",
    )
    stdout_pretrain: str = run_command_in_subprocess(command=command, path=str(tmp_path))
    # Pretraining must start from scratch, not restore any weights.
    assert "Restoring model weights from RestoreConfig(path='" not in stdout_pretrain

    matching_subfolders = _assert_training_artifacts(tmp_path / "pretrain", expected_checkpoint_suffix)
    _assert_val_losses_non_increasing(stdout_pretrain, val_steps)

    # --- Stage 2: finetuning from the pretraining checkpoint ---
    if with_peft:
        result_dir = tmp_path / "lora_finetune"
        additional_args = "--lora-finetune --lr 0.1 "
    else:
        result_dir = tmp_path / "finetune"
        additional_args = " --lr 0.1 "

    command_finetune = small_training_finetune_cmd(
        result_dir,
        max_steps=num_steps,
        val_check=val_steps,
        global_batch_size=global_batch_size,
        prev_ckpt=matching_subfolders[0],
        create_tflops_callback=not with_peft,
        additional_args=additional_args,
    )
    stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path))
    # Finetuning must restore the pretrained weights.
    assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune

    matching_subfolders_finetune = _assert_training_artifacts(result_dir, expected_checkpoint_suffix)
    _assert_val_losses_non_increasing(stdout_finetune, val_steps)

    # --- Stage 3 (LoRA only): resume from the saved LoRA checkpoint ---
    if with_peft:
        result_dir = tmp_path / "lora_finetune_resume"
        command_resume_finetune = small_training_finetune_cmd(
            result_dir,
            max_steps=num_steps,
            val_check=val_steps,
            global_batch_size=global_batch_size,
            prev_ckpt=matching_subfolders[0],
            create_tflops_callback=False,
            additional_args=f"--lora-finetune --lora-checkpoint-path {matching_subfolders_finetune[0]} --lr 0.1 ",
        )
        stdout_resume: str = run_command_in_subprocess(command=command_resume_finetune, path=str(tmp_path))

        # The resume run checks artifacts but no checkpoint-name suffix.
        _assert_training_artifacts(result_dir, expected_checkpoint_suffix=None)
        _assert_val_losses_non_increasing(stdout_resume, val_steps)


def _assert_training_artifacts(result_dir, expected_checkpoint_suffix):
    """Assert the standard `train_evo2` artifacts exist under *result_dir*.

    Checks the logs folder, the checkpoints folder, and that exactly one
    TensorBoard event file was written. When *expected_checkpoint_suffix* is
    not None, also asserts exactly one checkpoint subfolder matches it.

    Returns:
        The list of matching checkpoint subfolders (empty when
        *expected_checkpoint_suffix* is None).
    """
    log_dir = result_dir / "evo2"
    checkpoints_dir = log_dir / "checkpoints"
    tensorboard_dir = log_dir / "dev"

    # Check if logs dir exists
    assert log_dir.exists(), "Logs folder should exist."
    # Check if checkpoints dir exists
    assert checkpoints_dir.exists(), "Checkpoints folder does not exist."

    matching_subfolders = []
    if expected_checkpoint_suffix is not None:
        # Check if any subfolder ends with the expected suffix
        matching_subfolders = [
            p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name)
        ]
        assert matching_subfolders, (
            f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}."
        )
        assert len(matching_subfolders) == 1, "Only one checkpoint subfolder should be found."

    # Recursively search for files with tensorboard logger; exactly one
    # event file is expected per run.
    event_files = list(tensorboard_dir.rglob("events.out.tfevents*"))
    assert len(event_files) == 1, f"No or multiple TensorBoard event files found under {tensorboard_dir}"
    return matching_subfolders


def _assert_val_losses_non_increasing(stdout: str, val_steps: int):
    """Assert that each sampled validation loss is <= the previous one."""
    val_losses = extract_val_losses(stdout, val_steps)
    for i in range(1, len(val_losses)):
        assert val_losses[i][1] <= val_losses[i - 1][1], (
            f"Validation loss increased at step {val_losses[i][0]}: {val_losses[i][1]} > {val_losses[i - 1][1]}"
        )

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_train.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@
3030
from bionemo.testing.megatron_parallel_state_utils import distributed_model_parallel_state
3131
from bionemo.testing.subprocess_utils import run_command_in_subprocess
3232

33+
from .common import small_training_cmd, small_training_finetune_cmd
34+
3335

3436
fp8_available, reason_for_no_fp8 = check_fp8_support()
3537

3638

3739
def run_train_with_std_redirect(args: argparse.Namespace) -> Tuple[str, nl.Trainer]:
38-
"""
39-
Run a function with output capture.
40-
"""
40+
"""Run a function with output capture."""
4141
stdout_buf, stderr_buf = io.StringIO(), io.StringIO()
4242
with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
4343
with distributed_model_parallel_state():
@@ -50,28 +50,6 @@ def run_train_with_std_redirect(args: argparse.Namespace) -> Tuple[str, nl.Train
5050
return train_stdout, trainer
5151

5252

53-
def small_training_cmd(path, max_steps, val_check, devices: int = 1, additional_args: str = ""):
54-
cmd = (
55-
f"train_evo2 --mock-data --result-dir {path} --devices {devices} "
56-
"--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1 "
57-
"--no-activation-checkpointing --add-bias-output --create-tensorboard-logger --create-tflops-callback "
58-
f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
59-
f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args}"
60-
)
61-
return cmd
62-
63-
64-
def small_training_finetune_cmd(path, max_steps, val_check, prev_ckpt, devices: int = 1, additional_args: str = ""):
65-
cmd = (
66-
f"train_evo2 --mock-data --result-dir {path} --devices {devices} "
67-
"--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1 "
68-
"--no-activation-checkpointing --add-bias-output --create-tensorboard-logger --create-tflops-callback "
69-
f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
70-
f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args} --ckpt-dir {prev_ckpt}"
71-
)
72-
return cmd
73-
74-
7553
def small_training_mamba_cmd(path, max_steps, val_check, devices: int = 1, additional_args: str = ""):
7654
cmd = (
7755
f"train_evo2 --mock-data --result-dir {path} --devices {devices} "

0 commit comments

Comments
 (0)