Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/dna_test-cli.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ jobs:
- name: Test MLM+SNPGENOME finetune
run: bmfm-targets-run -cn dna_finetune_train_and_test_config -cd run working_dir=/tmp/dna/ input_directory=$MY_DATA_DIR checkpoint=ibm-research/biomed.dna.snp.modernbert.113m.v1 accelerator=cpu && rm -rf /tmp/dna/*
- name: Test MLM+REFGENOME inference
run: bmfm-targets-run -cn dna_predict -cd run working_dir=/tmp/dna/ input_directory=$MY_DATA_DIR input_filename=test.csv data_module.collation_strategy=language_modeling checkpoint=ibm-research/biomed.dna.ref.modernbert.113m.v1 accelerator=cpu && rm -rf /tmp/dna/*
run: bmfm-targets-run -cn dna_predict -cd run working_dir=/tmp/dna/ input_directory=$MY_DATA_DIR input_filename=test.csv checkpoint=ibm-research/biomed.dna.ref.modernbert.113m.v1 accelerator=cpu && rm -rf /tmp/dna/*
- name: Test MLM+REFGENOME inference on whole dataset
run: bmfm-targets-run -cn dna_predict -cd run working_dir=/tmp/dna/ input_directory=$MY_DATA_DIR data_module.collation_strategy=language_modeling checkpoint=ibm-research/biomed.dna.ref.modernbert.113m.v1 accelerator=cpu && rm -rf /tmp/dna/*
run: bmfm-targets-run -cn dna_predict -cd run working_dir=/tmp/dna/ input_directory=$MY_DATA_DIR checkpoint=ibm-research/biomed.dna.ref.modernbert.113m.v1 accelerator=cpu && rm -rf /tmp/dna/*
# - name: Test MLM+SNPGENOME inference
# run: bmfm-targets-run -cn dna_predict working_dir=/tmp/dna/ input_directory=$MY_DATA_FILE data_module.collation_strategy=sequence_classification checkpoint=ibm-research/biomed.dna.snp.modernbert.113m.v1 accelerator=cpu && rm -rf /tmp/dna/*
# run: bmfm-targets-run -cn dna_predict working_dir=/tmp/dna/ input_directory=$MY_DATA_FILE checkpoint=ibm-research/biomed.dna.snp.modernbert.113m.v1 accelerator=cpu && rm -rf /tmp/dna/*
4 changes: 2 additions & 2 deletions .github/workflows/test-cli.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ jobs:
- name: Test MLM+RDA (inference, finetune, pretrain)
run: |
# inference
bmfm-targets-run -cd run -cn predict input_file=$MY_DATA_FILE working_dir=/tmp/runs data_module.collation_strategy=language_modeling data_module.log_normalize_transform=false data_module.max_length=256 checkpoint=ibm-research/biomed.rna.bert.110m.mlm.rda.v1 task.accelerator=cpu task.precision=32
bmfm-targets-run -cd run -cn predict input_file=$MY_DATA_FILE working_dir=/tmp/runs data_module.log_normalize_transform=false data_module.max_length=256 checkpoint=ibm-research/biomed.rna.bert.110m.mlm.rda.v1 task.accelerator=cpu task.precision=32
rm -rf /tmp/runs/*
# finetune
bmfm-targets-run -cd run -cn finetune label_column_name=celltype split_column_name=null input_file=$MY_DATA_FILE working_dir=/tmp/runs data_module.log_normalize_transform=false checkpoint=ibm-research/biomed.rna.bert.110m.mlm.rda.v1 accelerator=cpu data_module.max_length=256 max_epochs=2 val_check_interval=null data_module.num_workers=0
Expand Down Expand Up @@ -100,7 +100,7 @@ jobs:
- name: Test WCED (inference, finetune)
run: |
# inference
bmfm-targets-run -cd run -cn predict input_file=$MY_DATA_FILE working_dir=/tmp/runs data_module.collation_strategy=language_modeling data_module.max_length=256 checkpoint=ibm-research/biomed.rna.bert.110m.wced.v1 task.accelerator=cpu task.precision=32 data_module.log_normalize_transform=false
bmfm-targets-run -cd run -cn predict input_file=$MY_DATA_FILE working_dir=/tmp/runs data_module.max_length=256 checkpoint=ibm-research/biomed.rna.bert.110m.wced.v1 task.accelerator=cpu task.precision=32 data_module.log_normalize_transform=false
rm -rf /tmp/runs/*
# finetune
bmfm-targets-run -cd run -cn finetune label_column_name=celltype split_column_name=null input_file=$MY_DATA_FILE working_dir=/tmp/runs checkpoint=ibm-research/biomed.rna.bert.110m.wced.v1 accelerator=cpu data_module.max_length=256 max_epochs=2 val_check_interval=null data_module.num_workers=0 data_module.log_normalize_transform=false
Expand Down
10 changes: 0 additions & 10 deletions bmfm_targets/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,6 @@ fields:

The data module settings select which data to use and what kind of task will be executed. The dataset can refer to one of the packaged datasets or to a generic h5ad-file-based dataset.

By selecting a `collation_strategy`, the same dataset can be used for different tasks.
The recommended option is `"multitask"`, which supports masked language modeling and sequence classification/regression simultaneously.
To perform only MLM, use `"language_modeling"`; to perform only sequence classification, use `"sequence_classification"`.
In `"multitask"` mode, any combination of field and label_column losses is supported.

The `"sequence_labeling"` option is only supported for datasets with paired "control" and perturbed cells.

Full documentation can be found in [data_module.py](./training/data_module.py)

#### Package datasets
Expand All @@ -71,7 +64,6 @@ data_module:
_target_: bmfm_targets.datasets.zheng68k.Zheng68kDataModule
_partial_: true
num_workers: 8
collation_strategy: "multitask"
batch_size: 20
max_length: 512
pad_to_multiple_of: 16
Expand All @@ -95,7 +87,6 @@ data_module:
_partial_: true
num_workers: 8
mlm: true
collation_strategy: "language_modeling"
batch_size: 2
transform_datasets: false
dataset_kwargs:
Expand All @@ -117,7 +108,6 @@ data_module:
_target_: bmfm_targets.training.data_module.DataModule
_partial_: true
num_workers: 8
collation_strategy: "sequence_classification"
batch_size: 2
transform_datasets: true
num_workers: 8
Expand Down
4 changes: 4 additions & 0 deletions bmfm_targets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@
(C) Copyright IBM Corp. 2023
ALL RIGHTS RESERVED
"""

from bmfm_targets.inference import inference

__all__ = ["inference"]
2 changes: 0 additions & 2 deletions bmfm_targets/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
SCNystromformerConfig,
SCModelConfigBase,
SCModernBertConfig,
ModelingStrategy,
)
from .main_config import SCBertMainConfig, SCBertMainHydraConfigSchema
from .dataset_config import SplitEnum
Expand All @@ -48,5 +47,4 @@
"PreTrainedEmbeddingConfig",
"LoraConfigWrapper",
"SplitEnum",
"ModelingStrategy",
]
143 changes: 128 additions & 15 deletions bmfm_targets/config/main_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from bmfm_targets.config import (
FieldInfo,
LabelColumnInfo,
TestTaskConfig,
PredictTaskConfig,
TokenizerConfig,
TrainerConfig,
TrainingTaskConfig,
Expand Down Expand Up @@ -106,7 +106,7 @@ def _instantiate_recursive(cls, val):
return val

def complete_config(self):
self._load_missing_configs_from_ckpt()
self._merge_configs_from_checkpoint()
tokenizer = self._load_tokenizer_from_cfg()

if isinstance(self.data_module, functools.partial):
Expand Down Expand Up @@ -159,29 +159,142 @@ def _update_label_columns(self):
):
lc.update_n_unique_values(self.data_module.label_dict)

def _load_missing_configs_from_ckpt(self):
def _merge_fields(
    self,
    ckpt_fields: list[FieldInfo],
    yaml_fields: list[FieldInfo] | None,
    is_training: bool,
) -> list[FieldInfo]:
    """
    Combine the fields stored in a checkpoint with the fields from the YAML config.

    Args:
    ----
        ckpt_fields: fields recorded in the checkpoint's model config.
        yaml_fields: fields supplied by the YAML config, or None if not given.
        is_training: True for training tasks, False for test/predict tasks.

    Returns:
    -------
        - the YAML fields (or an empty list) when the checkpoint has no fields;
        - the checkpoint fields, untouched, in test/predict mode or when the
          YAML config supplies no fields;
        - otherwise a new list holding every checkpoint field followed by the
          YAML fields whose ``field_name`` is not present in the checkpoint.
    """
    if not ckpt_fields:
        return yaml_fields if yaml_fields else []

    # Test/predict: checkpoint only
    if not is_training:
        if yaml_fields and yaml_fields != ckpt_fields:
            logger.warning(
                "fields: Ignoring YAML config in test/predict mode. "
                "Using checkpoint fields to match trained model."
            )
        return ckpt_fields

    # If no YAML fields during training, just use checkpoint fields
    if not yaml_fields:
        logger.info("fields: No YAML fields specified, using checkpoint fields")
        return ckpt_fields

    # Training: checkpoint definitions win on a name clash; YAML-only fields
    # are appended after them in their YAML order.
    ckpt_field_names = {f.field_name for f in ckpt_fields}
    new_fields = []
    conflicting = set()
    for field in yaml_fields:
        if field.field_name in ckpt_field_names:
            conflicting.add(field.field_name)
        else:
            new_fields.append(field)

    merged = list(ckpt_fields)
    if new_fields:
        logger.info(
            "fields: Adding %d new fields from YAML: %s",
            len(new_fields),
            [f.field_name for f in new_fields],
        )
        merged.extend(new_fields)

    if conflicting:
        logger.info(
            "fields: Using checkpoint config for existing fields: %s",
            sorted(conflicting),
        )

    return merged

def _merge_label_columns(
    self,
    ckpt_cols: list[LabelColumnInfo],
    yaml_cols: list[LabelColumnInfo] | None,
    is_training: bool,
) -> list[LabelColumnInfo] | None:
    """
    Choose between the checkpoint's label columns and those from the YAML config.

    Args:
    ----
        ckpt_cols: label columns recorded in the checkpoint's model config.
        yaml_cols: label columns from the YAML config; may be None (not
            specified) or an empty list (explicitly no label columns).
        is_training: True for training tasks, False for test/predict tasks.

    Returns:
    -------
        - ``yaml_cols`` unchanged (possibly None) when the checkpoint has none;
        - in test/predict mode: ``yaml_cols`` if it is an explicit empty list,
          otherwise ``ckpt_cols``;
        - in training mode: ``yaml_cols`` when it is non-empty, otherwise
          ``ckpt_cols``.
    """
    if not ckpt_cols:
        return yaml_cols

    if is_training:
        # Training: YAML-authoritative, falling back to the checkpoint when
        # the YAML config gives nothing (None or an empty list).
        return yaml_cols if yaml_cols else ckpt_cols

    # Explicit empty list in YAML = embedding-only mode (cross-dataset)
    if yaml_cols is not None and len(yaml_cols) == 0:
        logger.info("label_columns: Using empty list from YAML (embedding-only mode)")
        return yaml_cols

    # Use checkpoint columns (normal same-dataset prediction)
    if yaml_cols and yaml_cols != ckpt_cols:
        logger.warning(
            "label_columns: Ignoring YAML config in test/predict mode. "
            "Using checkpoint label_columns to match trained model."
        )
    return ckpt_cols

def _merge_configs_from_checkpoint(self):
    """
    Merge configs from checkpoint with YAML configs based on task type.

    Does nothing when ``self._get_checkpoint()`` returns a falsy value.
    Otherwise the checkpoint is loaded (downloaded first when it is not a
    local file) and its hyperparameters are merged into this config:

    - ``fields`` are always merged via ``_merge_fields``.
    - ``label_columns`` are merged via ``_merge_label_columns`` for every task
      except predict. In predict mode they are reset to an empty list when
      the checkpoint's state dict has no ``label_predictions`` weights.
    - ``trainer`` is merged via ``TrainerConfig.merge_from_checkpoint`` when
      both the checkpoint and this config have a trainer config.
    """

    def _lookup(container, key):
        # Hyperparameters may be stored as objects or as mappings: try the
        # attribute first, then fall back to ``.get`` when it is available.
        value = getattr(container, key, None)
        if value is None and hasattr(container, "get"):
            value = container.get(key)
        return value

    checkpoint = self._get_checkpoint()
    if not checkpoint:
        return

    if not os.path.isfile(checkpoint):
        # Imported lazily: only needed when the checkpoint is not a local file.
        from bmfm_targets.models import download_ckpt_from_huggingface

        checkpoint = download_ckpt_from_huggingface(checkpoint)

    # NOTE(review): weights_only=False unpickles arbitrary objects, so this
    # must only ever be pointed at trusted checkpoints.
    ckpt_dict = torch.load(checkpoint, map_location="cpu", weights_only=False)
    ckpt_hyper = ckpt_dict["hyper_parameters"]

    is_training = isinstance(self.task, TrainingTaskConfig)
    is_predict = isinstance(self.task, PredictTaskConfig)

    ckpt_model_config = _lookup(ckpt_hyper, "model_config")
    ckpt_trainer_config = _lookup(ckpt_hyper, "trainer_config")

    if ckpt_model_config:
        # Merge fields (always merge, needed for model instantiation)
        ckpt_fields = _lookup(ckpt_model_config, "fields")
        if ckpt_fields:
            self.fields = self._merge_fields(ckpt_fields, self.fields, is_training)

        ckpt_label_columns = _lookup(ckpt_model_config, "label_columns")
        if not is_predict:
            # Merge label_columns (skipped in predict mode for cross-dataset prediction)
            if ckpt_label_columns:
                self.label_columns = self._merge_label_columns(
                    ckpt_label_columns, self.label_columns, is_training
                )
        else:
            # In predict mode, check if checkpoint has label decoder weights.
            # If not, explicitly clear label_columns to prevent model
            # instantiation with label heads.
            has_label_weights = any(
                "label_predictions" in k for k in ckpt_dict["state_dict"]
            )
            if not has_label_weights:
                if ckpt_label_columns:
                    logger.warning(
                        f"label_columns: Checkpoint has {len(ckpt_label_columns)} label_columns in config "
                        "but no label decoder weights. Clearing label_columns for predict mode."
                    )
                self.label_columns = []

    # Merge trainer config
    if ckpt_trainer_config and self.trainer:
        self.trainer = self.trainer.merge_from_checkpoint(ckpt_trainer_config)

@staticmethod
def _instantiate_and_setup_data_module(
Expand Down
17 changes: 0 additions & 17 deletions bmfm_targets/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,6 @@
)


class ModelingStrategy(str, Enum):
    """
    Enumeration of all modeling strategies.

    Each member is also a ``str`` whose value is the lowercase strategy name.

    Members:
        SEQUENCE_LABELING: token-level labeling
        MLM: masked language modeling
        MULTITASK: multitask learning
        SEQUENCE_CLASSIFICATION: sequence-level classification
    """

    SEQUENCE_LABELING = "sequence_labeling"
    MLM = "mlm"
    MULTITASK = "multitask"
    SEQUENCE_CLASSIFICATION = "sequence_classification"


class SCModelConfigBase(PretrainedConfig):
def to_dict(self):
"""Serializes class to a Python dictionary."""
Expand Down
36 changes: 30 additions & 6 deletions bmfm_targets/config/training_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from typing import Any, Literal

import pytorch_lightning.callbacks
from peft import (
LoraConfig as PeftLoraConfig,
)
from peft import LoraConfig as PeftLoraConfig


def default_callbacks():
Expand Down Expand Up @@ -184,8 +182,10 @@ class TrainerConfig:
focal_gamma (float) defaults to 2.0
pooling_method: which pooling method to use to generate embeddings (if requested)
the options are "pooling_layer", "first_token" or "mean_pooling".
Note that "pooling_layer" will only give meaningful output with sequence
classification.
Default is "first_token" which is always valid.
"pooling_layer" should only be used when the pooler has been trained
(e.g., from sequence classification checkpoints with label prediction tasks).
Pure MLM checkpoints have untrained poolers and should use "first_token".
batch_prediction_behavior (str|int|None): whether to track batch_predictions, track and dump or
do not track at all.
"dump" - dump every batch to disk (uses lots of hd space and lots of memory)
Expand All @@ -197,21 +197,45 @@ class TrainerConfig:

"""

def merge_from_checkpoint(
    self, checkpoint_trainer: "TrainerConfig | None"
) -> "TrainerConfig":
    """
    Merge trainer config with checkpoint trainer config.

    Only ``losses`` is inherited: when this config's ``losses`` is None and
    the checkpoint trainer config carries a non-empty ``losses``, that value
    is adopted. Every other setting on this config is left as is.

    Args:
    ----
        checkpoint_trainer: trainer config restored from a checkpoint. It may
            be None/empty, an object with a ``losses`` attribute, or a plain
            dict with a ``"losses"`` key; a missing ``losses`` is treated as
            "nothing to inherit".

    Returns:
    -------
        This config (``self``), updated in place when losses were inherited.
    """
    if not checkpoint_trainer:
        return self

    # The checkpoint side may be a dict or an object that lacks ``losses``;
    # read it defensively rather than assuming attribute access works.
    if isinstance(checkpoint_trainer, dict):
        ckpt_losses = checkpoint_trainer.get("losses")
    else:
        ckpt_losses = getattr(checkpoint_trainer, "losses", None)

    if self.losses is None and ckpt_losses:
        from transformers.utils import logging

        logger = logging.get_logger(__name__)
        logger.info(f"Inheriting {len(ckpt_losses)} losses from checkpoint")
        self.losses = ckpt_losses

    return self

betas: tuple[float, float] = (0.9, 0.99)
epsilon: float = 1e-8
learning_rate: float = 1e-4
losses: list[Any] | None = None
lr_decay_steps: int | None = None
warmup_steps: int = 0
weight_decay: float | None = None
pooling_method: str | int = "pooling_layer"
pooling_method: str | int = "first_token"
batch_prediction_behavior: str | int | None = None
lora_config: Any = None
enable_perturbation_metrics: bool = False

def __setstate__(self, state):
    """
    Restore the instance from ``state``, tolerating old checkpoint layouts.

    Keys for fields that no longer exist (``metrics``, ``batch_size``) are
    discarded, and ``enable_perturbation_metrics`` defaults to False when the
    state does not contain it.
    """
    # Work on a copy so the caller's dict is left untouched.
    restored = dict(state)
    # Handle removed fields from old checkpoints
    for removed_field in ("metrics", "batch_size"):
        restored.pop(removed_field, None)
    restored.setdefault("enable_perturbation_metrics", False)
    self.__dict__.update(restored)

def get_lora_config(self) -> LoraConfigWrapper:
Expand Down
2 changes: 0 additions & 2 deletions bmfm_targets/datasets/SNPdb/streaming_snp_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ def collate_fn(self):
max_length=self.max_length,
padding=self.padding,
truncation=self.truncation,
collation_strategy=self.collation_strategy,
)


Expand Down Expand Up @@ -280,7 +279,6 @@ def collate_fn(self):
max_length=self.max_length,
padding=self.padding,
truncation=self.truncation,
collation_strategy=self.collation_strategy,
)


Expand Down
Loading
Loading