Skip to content

Commit 3fdfada

Browse files
committed
feat(evo2): Add comprehensive fine-tuning support for Evo2 models
This commit introduces a complete fine-tuning framework for Evo2 models, supporting both Hyena and Mamba architectures. Key features: - Support for sequence-level regression and classification tasks - Support for token-level classification tasks - LoRA (Low-Rank Adaptation) support for parameter-efficient fine-tuning - FP8 and BF16 precision support - Modular architecture with separate components for: * Datasets (InMemorySingleValueDataset, InMemoryPerTokenValueDataset) * Loss functions (Regression, Classification, Token Classification) * Model heads (MLP for sequence-level, CNN for token-level) * Fine-tuned models (Evo2FineTuneSeqModel, MambaFineTuneSeqModel, etc.) * Configuration classes (Evo2FineTuneSeqConfig, Evo2FineTuneTokenConfig) * Data loading (Evo2FineTuneDataModule) The implementation follows the ESM2 fine-tuning pattern for consistency across the BioNeMo framework. The script supports distributed training with tensor, pipeline, and context parallelism. Example usage: python finetune_evo2.py --train-data-path train.csv --valid-data-path val.csv --task-type regression --restore-from-checkpoint-path /path/to/evo2_checkpoint --model-type hyena --model-size 7b Documentation includes comprehensive examples for various use cases and detailed parameter explanations. Signed-off-by: My Le <mvle@nvidia.com>
1 parent c498ede commit 3fdfada

File tree

11 files changed

+1881
-0
lines changed

11 files changed

+1881
-0
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fine-tuning components for Evo2 models.

Re-exports the public fine-tuning API (configs, data module, datasets, loss
reductions, and fine-tuned model classes) so callers can import everything
from ``bionemo.evo2.models.finetune`` directly.
"""

# Sequence- and token-level fine-tuning configuration classes.
from bionemo.evo2.models.finetune.config import (
    Evo2FineTuneSeqConfig,
    Evo2FineTuneTokenConfig,
)

# Lightning-style data module wrapping the fine-tuning datasets.
from bionemo.evo2.models.finetune.datamodule import Evo2FineTuneDataModule

# In-memory datasets for sequence-level (single value) and token-level labels.
from bionemo.evo2.models.finetune.dataset import (
    InMemoryNucleotideDataset,
    InMemoryPerTokenValueDataset,
    InMemorySingleValueDataset,
)

# Loss reductions for regression, classification, and token classification.
from bionemo.evo2.models.finetune.loss import (
    ClassifierLossReduction,
    RegressorLossReduction,
    TokenClassifierLossReduction,
)

# Sequence-level fine-tuned models (Hyena- and Mamba-backed).
from bionemo.evo2.models.finetune.sequence_model import (
    Evo2FineTuneSeqModel,
    MambaFineTuneSeqModel,
)

# Token-level fine-tuned models (Hyena- and Mamba-backed).
from bionemo.evo2.models.finetune.token_model import (
    Evo2FineTuneTokenModel,
    MambaFineTuneTokenModel,
)


# Public API of this package, kept in alphabetical order.
__all__ = [
    "ClassifierLossReduction",
    "Evo2FineTuneDataModule",
    "Evo2FineTuneSeqConfig",
    "Evo2FineTuneSeqModel",
    "Evo2FineTuneTokenConfig",
    "Evo2FineTuneTokenModel",
    "InMemoryNucleotideDataset",
    "InMemoryPerTokenValueDataset",
    "InMemorySingleValueDataset",
    "MambaFineTuneSeqModel",
    "MambaFineTuneTokenModel",
    "RegressorLossReduction",
    "TokenClassifierLossReduction",
]
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Configuration classes for Evo2 fine-tuning."""
17+
18+
from dataclasses import dataclass, field
19+
from typing import List, Optional, Type
20+
21+
import torch
22+
from megatron.core.models.bert.bert_lm_head import BERTMLMLossWithReduction
23+
from nemo import io as iom
24+
from nemo.collections.llm.gpt.model.hyena import HYENA_MODEL_OPTIONS
25+
26+
from bionemo.evo2.models.finetune.loss import (
27+
ClassifierLossReduction,
28+
RegressorLossReduction,
29+
TokenClassifierLossReduction,
30+
)
31+
from bionemo.evo2.models.finetune.sequence_model import (
32+
Evo2FineTuneSeqModel,
33+
MambaFineTuneSeqModel,
34+
)
35+
from bionemo.evo2.models.finetune.token_model import (
36+
Evo2FineTuneTokenModel,
37+
MambaFineTuneTokenModel,
38+
)
39+
from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS
40+
from bionemo.llm.model.config import TorchmetricsConfig
41+
42+
43+
@dataclass
class Evo2FineTuneSeqConfig(iom.IOMixinWithGettersSetters):
    """Configuration for sequence-level fine-tuning.

    This configuration class sets up the model, loss function, and training
    parameters for sequence-level regression or classification tasks on top
    of a pretrained Evo2 (Hyena) or Mamba backbone.
    """

    # Model configuration
    model_type: str = "hyena"  # "hyena" or "mamba"
    model_size: str = "7b"  # must be a key of the chosen backbone's MODEL_OPTIONS
    model_cls: Optional[Type] = None  # resolved in __post_init__ from model_type
    initial_ckpt_path: Optional[str] = None
    # Head weights are trained from scratch, so they are skipped when restoring
    # the pretrained backbone checkpoint.
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(
        default_factory=lambda: ["regression_head", "classification_head"]
    )

    # Task configuration
    task_type: str = "regression"  # "regression" or "classification"
    encoder_frozen: bool = True  # freeze the backbone; only the head trains

    # MLP head parameters
    mlp_ft_dropout: float = 0.25
    mlp_hidden_size: int = 256
    mlp_target_size: int = 1  # 1 for regression, or number of classes for classification

    # Training parameters
    params_dtype: torch.dtype = torch.bfloat16
    pipeline_dtype: torch.dtype = torch.bfloat16
    autocast_dtype: torch.dtype = torch.bfloat16
    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1

    # Metrics
    train_metric: Optional[TorchmetricsConfig] = None
    valid_metric: Optional[TorchmetricsConfig] = None

    # Derived attributes populated in __post_init__ (not constructor arguments).
    hidden_size: int = field(init=False)
    ft_dropout: float = field(init=False)

    def __post_init__(self):
        """Resolve the model class and derived attributes from model_type/model_size.

        Raises:
            ValueError: If ``model_type`` or ``model_size`` is not recognized.
        """
        if self.model_type == "hyena":
            self.model_cls = Evo2FineTuneSeqModel
            model_options = HYENA_MODEL_OPTIONS
        elif self.model_type == "mamba":
            self.model_cls = MambaFineTuneSeqModel
            model_options = MAMBA_MODEL_OPTIONS
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

        # Fail fast on an unknown size. Previously an unknown size silently
        # skipped setting hidden_size (a field(init=False) with no default),
        # which surfaced much later as an opaque AttributeError.
        if self.model_size not in model_options:
            raise ValueError(
                f"Unknown model size '{self.model_size}' for model type "
                f"'{self.model_type}'. Valid sizes: {sorted(model_options)}"
            )
        model_config = model_options[self.model_size]
        self.hidden_size = model_config.hidden_size

        # Alias used by the generic fine-tuning head code.
        self.ft_dropout = self.mlp_ft_dropout

    def get_loss_reduction_class(self) -> Type[BERTMLMLossWithReduction]:
        """Get the appropriate loss reduction class based on task type.

        Returns:
            Loss reduction class for the specified task type.

        Raises:
            ValueError: If ``task_type`` is not "regression" or "classification".
        """
        if self.task_type == "regression":
            return RegressorLossReduction
        elif self.task_type == "classification":
            return ClassifierLossReduction
        else:
            raise ValueError(f"Unknown task type: {self.task_type}")

    def configure_model(self, tokenizer=None, pre_process=True, post_process=True):
        """Configure and return the model instance.

        Args:
            tokenizer: Tokenizer to use (currently unused here; accepted for
                interface compatibility with the framework's configure hooks).
            pre_process: Whether this is the first stage in pipeline parallelism.
            post_process: Whether this is the last stage in pipeline parallelism.

        Returns:
            Configured model instance.
        """
        # Get base model configuration. __post_init__ already validated that
        # model_type and model_size are known.
        if self.model_type == "hyena":
            base_config = HYENA_MODEL_OPTIONS[self.model_size]
        else:
            base_config = MAMBA_MODEL_OPTIONS[self.model_size]

        # Copy the base config so fine-tuning attributes don't mutate the
        # shared MODEL_OPTIONS entry.
        # NOTE(review): this assumes base_config is an instance whose __dict__
        # round-trips through its constructor (no init=False fields) — confirm
        # against HYENA_MODEL_OPTIONS/MAMBA_MODEL_OPTIONS definitions.
        merged_config = type(base_config)(**base_config.__dict__)

        # Add fine-tuning specific attributes consumed by the head/model.
        merged_config.task_type = self.task_type
        merged_config.encoder_frozen = self.encoder_frozen
        merged_config.mlp_ft_dropout = self.mlp_ft_dropout
        merged_config.mlp_hidden_size = self.mlp_hidden_size
        merged_config.mlp_target_size = self.mlp_target_size
        merged_config.ft_dropout = self.ft_dropout

        # Set parallelism
        merged_config.tensor_model_parallel_size = self.tensor_model_parallel_size
        merged_config.pipeline_model_parallel_size = self.pipeline_model_parallel_size

        # Create model
        return self.model_cls(
            config=merged_config,
            pre_process=pre_process,
            post_process=post_process,
        )
154+
155+
156+
@dataclass
class Evo2FineTuneTokenConfig(Evo2FineTuneSeqConfig):
    """Configuration for token-level fine-tuning.

    Extends the sequence-level configuration with parameters for the CNN head
    used by token-level classification tasks.
    """

    # CNN head parameters
    cnn_dropout: float = 0.25
    cnn_hidden_size: int = 32
    cnn_num_classes: int = 3

    # Token-level heads are trained from scratch, so skip them when restoring
    # a pretrained checkpoint (overrides the sequence-level prefixes).
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(
        default_factory=lambda: ["token_classification_head"]
    )

    def __post_init__(self):
        """Resolve the token-level model class after base initialization."""
        super().__post_init__()
        # Swap in the token-level model class for the backbone chosen by the base config.
        token_model_by_type = {
            "hyena": Evo2FineTuneTokenModel,
            "mamba": MambaFineTuneTokenModel,
        }
        if self.model_type in token_model_by_type:
            self.model_cls = token_model_by_type[self.model_type]

    def get_loss_reduction_class(self) -> Type[BERTMLMLossWithReduction]:
        """Get the token classification loss reduction class.

        Returns:
            TokenClassifierLossReduction class.
        """
        return TokenClassifierLossReduction

    def configure_model(self, tokenizer=None, pre_process=True, post_process=True):
        """Configure and return the token-level model instance.

        Args:
            tokenizer: Tokenizer to use.
            pre_process: Whether this is the first stage in pipeline parallelism.
            post_process: Whether this is the last stage in pipeline parallelism.

        Returns:
            Configured model instance.
        """
        # Build the model via the sequence-level configuration first.
        model = super().configure_model(tokenizer, pre_process, post_process)

        if hasattr(model, "config"):
            # Propagate CNN head hyperparameters onto the model's config so the
            # token classification head can read them.
            for attr_name in ("cnn_dropout", "cnn_hidden_size", "cnn_num_classes"):
                setattr(model.config, attr_name, getattr(self, attr_name))

        return model

0 commit comments

Comments
 (0)