Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions ci/benchmarks/partial-conv/evo2_finetuning.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
scope: partial-conv
time_limit: 14400
key_segments:
  # Map a key to a string to rename it in the run identifier, or to False to exclude it. By default, all args under script_args are included.
dataset_config: False
dataset_dir: False
data_base_path: False
num_workers: False
limit_val_batches: False
val_check_interval: False
experiment_name: False
workspace: False
restore_from_checkpoint_path: False
activation_checkpoint_layers: False
lora_enabled: False
lr: False
min_lr: False
warmup_steps: False
accumulate_grad_batches: False
clip_grad: False
weight_decay: False
attention_dropout: False
hidden_dropout: False
precision: False
seq_length: False
script_args:
# All arguments referenced in the script string must be specified here.
# Arguments not referenced in the script string must have the 'arg' field specified.
# See jet/core/configs.py for the specification of the configuration class
workspace: /workspace/bionemo2
data_base_path: /data/evo2
restore_from_checkpoint_path: checkpoints/nemo2_evo2_1b_8k
nodes: 1
model: evo2
config_name: 1b
num_workers: 1
limit_val_batches: 20
dataset_config: training_data_config.yaml
dataset_dir: preprocessed_data
val_check_interval: 5
seq_length: 8192
warmup_steps: 10
activation_checkpoint_layers: 2
lr: 0.000015
min_lr: 0.0000149
accumulate_grad_batches: 4
max_steps: 1000
gpus: 1
clip_grad: 250
weight_decay: 0.001
attention_dropout: 0.01
hidden_dropout: 0.01
stop_steps: 100
batch_size: 2
variant: finetune
precision: fp8
products:
- variant: finetune
lora_enabled: ""
task: finetune_from_ckpt
experiment_name: evo2-finetune
- variant: lora_finetune
lora_enabled: "--lora-finetune"
task: lora_finetune_from_ckpt
experiment_name: evo2-lora-finetune
script: |-
WANDB_API_KEY=$BIONEMO_WANDB_API_KEY train_${model} \
-d ${data_base_path}/${dataset_config} \
--dataset-dir=${data_base_path}/${dataset_dir} \
--ckpt-dir=${data_base_path}/${restore_from_checkpoint_path} \
${lora_enabled} \
--model-size=${config_name} \
--max-steps=${max_steps} \
--experiment-name=${experiment_name}_${batch_size}bs_${nodes}node_${gpus}gpu_${max_steps}s \
--lr=${lr} \
--min-lr=${min_lr} \
--warmup-steps=${warmup_steps} \
--result-dir=${tensorboard_dir} \
--micro-batch-size=${batch_size} \
--grad-acc-batches=${accumulate_grad_batches} \
--limit-val-batches=${limit_val_batches} \
--seq-length=${seq_length} \
--clip-grad=${clip_grad} \
--wd=${weight_decay} \
--attention-dropout=${attention_dropout} \
--hidden-dropout=${hidden_dropout} \
--num-layers 4 \
--hybrid-override-pattern 'SDH*' \
--devices=${gpus} \
--num-nodes=${nodes} \
--val-check-interval=${val_check_interval} \
--wandb-project=${wandb_project_name} \
--wandb-group=${model}_${variant}_${config_name}_${task}_${target} \
--create-tensorboard-logger \
--activation-checkpoint-recompute-num-layers=${activation_checkpoint_layers} \
--disable-checkpointing \
--early-stop-on-step=${stop_steps} \
--garbage-collect-at-inference;
2 changes: 1 addition & 1 deletion ci/benchmarks/partial-conv/evo2_pretrain.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ script_args:
# See jet/core/configs.py for the specification of the configuration class
workspace: /workspace/bionemo2
data_path: /data/evo2
artefacts_url: https://__token__:${JET_GITLAB_TOKEN}@gitlab-master.nvidia.com/api/v4/projects/180496/packages/pypi/simple
artefacts_url: https://__token__:${{JET_GITLAB_TOKEN}}@gitlab-master.nvidia.com/api/v4/projects/180496/packages/pypi/simple
file_name_wheel: subquadratic-ops
model: evo2
variant: train
Expand Down
10 changes: 10 additions & 0 deletions sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel, mamba_no_weight_decay_cond_with_embeddings
from bionemo.evo2.run.peft import Evo2LoRA
from bionemo.evo2.utils.callbacks import GarbageCollectAtInferenceTime
from bionemo.evo2.utils.config import hyena_no_weight_decay_cond_with_embeddings
from bionemo.evo2.utils.logging.callbacks import TEVCallback
from bionemo.llm.utils.datamodule_utils import infer_global_batch_size
Expand Down Expand Up @@ -506,6 +507,12 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
default=False,
help="Skip checking for NaNs in gradients. Only use this for debugging purposes.",
)
parser.add_argument(
"--garbage-collect-at-inference",
action="store_true",
default=False,
help="Enable CUDA memory cleanup before validation to prevent initialization errors.",
)

recompute_group = parser.add_mutually_exclusive_group(required=False)
recompute_group.add_argument("--no-activation-checkpointing", action="store_true", default=False)
Expand Down Expand Up @@ -645,6 +652,9 @@ def train(args: argparse.Namespace) -> nl.Trainer:
TEVCallback(),
]

if args.garbage_collect_at_inference:
callbacks.append(GarbageCollectAtInferenceTime())

if args.lora_finetune:
callbacks.append(ModelTransform())
if args.enable_preemption:
Expand Down
36 changes: 36 additions & 0 deletions sub-packages/bionemo-evo2/src/bionemo/evo2/utils/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc

import torch
from lightning.pytorch import Callback


class GarbageCollectAtInferenceTime(Callback):
    """Lightning callback that frees CUDA memory right before validation.

    Emptying the CUDA caching allocator and synchronizing the device ahead of
    the validation pass reduces the chance of CUDA initialization/allocation
    errors when validation needs fresh device memory.
    """

    def on_validation_start(self, trainer, pl_module) -> None:
        """Best-effort CUDA cache/GC cleanup; never raises into the trainer."""
        if not torch.cuda.is_available():
            return
        try:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            # NOTE(review): re-selecting the already-current device appears to be
            # a deliberate CUDA-context "touch" before the second sync — confirm.
            device_index = torch.cuda.current_device()
            torch.cuda.set_device(device_index)
            torch.cuda.synchronize()
            gc.collect()
        except Exception as exc:  # best-effort: a cleanup failure must not abort the run
            print(f"Warning: CUDA cleanup failed: {exc}")