|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: LicenseRef-Apache2 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""Configuration classes for Evo2 fine-tuning.""" |
| 17 | + |
| 18 | +from dataclasses import dataclass, field |
| 19 | +from typing import List, Optional, Type |
| 20 | + |
| 21 | +import torch |
| 22 | +from megatron.core.models.bert.bert_lm_head import BERTMLMLossWithReduction |
| 23 | +from nemo import io as iom |
| 24 | +from nemo.collections.llm.gpt.model.hyena import HYENA_MODEL_OPTIONS |
| 25 | + |
| 26 | +from bionemo.evo2.models.finetune.loss import ( |
| 27 | + ClassifierLossReduction, |
| 28 | + RegressorLossReduction, |
| 29 | + TokenClassifierLossReduction, |
| 30 | +) |
| 31 | +from bionemo.evo2.models.finetune.sequence_model import ( |
| 32 | + Evo2FineTuneSeqModel, |
| 33 | + MambaFineTuneSeqModel, |
| 34 | +) |
| 35 | +from bionemo.evo2.models.finetune.token_model import ( |
| 36 | + Evo2FineTuneTokenModel, |
| 37 | + MambaFineTuneTokenModel, |
| 38 | +) |
| 39 | +from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS |
| 40 | +from bionemo.llm.model.config import TorchmetricsConfig |
| 41 | + |
| 42 | + |
@dataclass
class Evo2FineTuneSeqConfig(iom.IOMixinWithGettersSetters):
    """Configuration for sequence-level fine-tuning.

    Sets up the model class, loss function, and training parameters for
    sequence-level regression or classification tasks on Evo2 (Hyena) or
    Mamba backbones.
    """

    # Model configuration
    model_type: str = "hyena"  # "hyena" or "mamba"
    model_size: str = "7b"  # key into HYENA_MODEL_OPTIONS / MAMBA_MODEL_OPTIONS
    model_cls: Optional[Type] = None  # resolved from model_type in __post_init__
    initial_ckpt_path: Optional[str] = None
    # Head weights are freshly initialized, so skip them when loading the checkpoint.
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(
        default_factory=lambda: ["regression_head", "classification_head"]
    )

    # Task configuration
    task_type: str = "regression"  # "regression" or "classification"
    encoder_frozen: bool = True

    # MLP head parameters
    mlp_ft_dropout: float = 0.25
    mlp_hidden_size: int = 256
    mlp_target_size: int = 1  # 1 for regression, or the number of classes

    # Training parameters
    params_dtype: torch.dtype = torch.bfloat16
    pipeline_dtype: torch.dtype = torch.bfloat16
    autocast_dtype: torch.dtype = torch.bfloat16
    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1

    # Metrics
    train_metric: Optional[TorchmetricsConfig] = None
    valid_metric: Optional[TorchmetricsConfig] = None

    # Derived transformer-config attributes, populated in __post_init__.
    hidden_size: int = field(init=False)
    ft_dropout: float = field(init=False)

    def __post_init__(self):
        """Resolve the model class and derived attributes from model_type/model_size.

        Raises:
            ValueError: If ``model_type`` or ``model_size`` is unknown.
        """
        if self.model_type == "hyena":
            self.model_cls = Evo2FineTuneSeqModel
            model_options = HYENA_MODEL_OPTIONS
        elif self.model_type == "mamba":
            self.model_cls = MambaFineTuneSeqModel
            model_options = MAMBA_MODEL_OPTIONS
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

        # Fail fast on an unknown model size. Previously this silently left
        # `hidden_size` (an init=False field with no default) unset, which
        # surfaced later as a confusing AttributeError far from the cause.
        if self.model_size not in model_options:
            raise ValueError(
                f"Unknown model size '{self.model_size}' for model type "
                f"'{self.model_type}'. Available sizes: {sorted(model_options)}"
            )
        self.hidden_size = model_options[self.model_size].hidden_size

        self.ft_dropout = self.mlp_ft_dropout

    def get_loss_reduction_class(self) -> Type[BERTMLMLossWithReduction]:
        """Get the appropriate loss reduction class based on task type.

        Returns:
            ``RegressorLossReduction`` for regression tasks, or
            ``ClassifierLossReduction`` for classification tasks.

        Raises:
            ValueError: If ``task_type`` is unknown.
        """
        if self.task_type == "regression":
            return RegressorLossReduction
        elif self.task_type == "classification":
            return ClassifierLossReduction
        else:
            raise ValueError(f"Unknown task type: {self.task_type}")

    def configure_model(self, tokenizer=None, pre_process=True, post_process=True):
        """Configure and return the model instance.

        Args:
            tokenizer: Tokenizer to use (accepted for API compatibility; not
                consumed directly here).
            pre_process: Whether this is the first stage in pipeline parallelism.
            post_process: Whether this is the last stage in pipeline parallelism.

        Returns:
            Configured model instance of ``self.model_cls``.
        """
        # Select the backbone config; __post_init__ has already validated
        # both model_type and model_size.
        if self.model_type == "hyena":
            base_config = HYENA_MODEL_OPTIONS[self.model_size]
        else:
            base_config = MAMBA_MODEL_OPTIONS[self.model_size]

        # Copy the base config so fine-tuning attributes do not leak into the
        # shared registry entry.
        merged_config = type(base_config)(**base_config.__dict__)

        # Fine-tuning specific attributes consumed by the task heads.
        merged_config.task_type = self.task_type
        merged_config.encoder_frozen = self.encoder_frozen
        merged_config.mlp_ft_dropout = self.mlp_ft_dropout
        merged_config.mlp_hidden_size = self.mlp_hidden_size
        merged_config.mlp_target_size = self.mlp_target_size
        merged_config.ft_dropout = self.ft_dropout

        # Parallelism layout.
        merged_config.tensor_model_parallel_size = self.tensor_model_parallel_size
        merged_config.pipeline_model_parallel_size = self.pipeline_model_parallel_size

        return self.model_cls(
            config=merged_config,
            pre_process=pre_process,
            post_process=post_process,
        )
| 154 | + |
| 155 | + |
@dataclass
class Evo2FineTuneTokenConfig(Evo2FineTuneSeqConfig):
    """Configuration for token-level fine-tuning.

    Extends the sequence-level configuration with parameters for the CNN
    head used in token-level classification tasks.
    """

    # CNN head parameters
    cnn_dropout: float = 0.25
    cnn_hidden_size: int = 32
    cnn_num_classes: int = 3

    # Token-level heads are freshly initialized, so skip them on checkpoint load.
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(
        default_factory=lambda: ["token_classification_head"]
    )

    def __post_init__(self):
        """Run the base setup, then swap in the token-level model class."""
        super().__post_init__()
        token_model_by_type = {
            "hyena": Evo2FineTuneTokenModel,
            "mamba": MambaFineTuneTokenModel,
        }
        if self.model_type in token_model_by_type:
            self.model_cls = token_model_by_type[self.model_type]

    def get_loss_reduction_class(self) -> Type[BERTMLMLossWithReduction]:
        """Return the token-classification loss reduction class.

        Returns:
            TokenClassifierLossReduction class.
        """
        return TokenClassifierLossReduction

    def configure_model(self, tokenizer=None, pre_process=True, post_process=True):
        """Build the token-level model and attach the CNN head settings.

        Args:
            tokenizer: Tokenizer to use.
            pre_process: Whether this is the first stage in pipeline parallelism.
            post_process: Whether this is the last stage in pipeline parallelism.

        Returns:
            Configured model instance.
        """
        model = super().configure_model(tokenizer, pre_process, post_process)

        # NOTE(review): these attributes are assigned after model construction;
        # confirm the CNN head reads them lazily rather than in __init__.
        if hasattr(model, "config"):
            for attr in ("cnn_dropout", "cnn_hidden_size", "cnn_num_classes"):
                setattr(model.config, attr, getattr(self, attr))

        return model
0 commit comments