Skip to content

March backport changes#31

Merged
mmdanziger merged 16 commits into
mainfrom
march-backport
Mar 26, 2026
Merged

March backport changes#31
mmdanziger merged 16 commits into
mainfrom
march-backport

Conversation

@mmdanziger
Copy link
Copy Markdown
Collaborator

@mmdanziger mmdanziger commented Mar 8, 2026

This PR significantly simplifies the core modeling architecture by consolidating task-specific models (MLM, Sequence Classification, etc.) into a unified MultiTask architecture. It also introduces a user-friendly, Scanpy-style Python inference API, improves how model configurations are merged from checkpoints, and ensures backward compatibility for older checkpoints.

Key Changes

1. Unified MultiTask Architecture

  • **Removed collation_strategy & ModelingStrategy**: The codebase no longer relies on tracking separate modeling strategies (mlm, sequence_classification, sequence_labeling). All models (SCBert, Performer, Nystromformer, Llama, ModernBert) now exclusively return and instantiate their ForMultiTaskModeling variants.
  • HuggingFace AutoModel registrations have been updated so that both AutoModelForMaskedLM and AutoModelForSequenceClassification correctly resolve to the multitask models.
  • Replaced dynamic training module resolution with direct usage of MultiTaskTrainingModule (or SequenceLabelingTrainingModule if enable_perturbation_metrics is set).

2. New Python Inference API

  • Introduced bmfm_targets.inference.inference to allow seamless zero-shot predictions directly on AnnData objects.
  • This feature acts similarly to a Scanpy tool (bmfm.inference(adata)), handling layer swapping, tokenization, prediction extraction, and appending the resulting embeddings and metadata labels back into the adata object automatically.

3. Robust Config & Checkpoint Merging

  • Refactored SCBertMainConfig to handle config merging intelligently between YAML specifications and checkpoint hyperparameters (_merge_fields, _merge_label_columns, and _merge_configs_from_checkpoint).
  • Clarified precedence: Checkpoint configs are largely authoritative during inference/prediction to match trained weights, while YAML configs take precedence/augment during training.
  • Handles edge cases gracefully, such as clearing label_columns in predict mode if a checkpoint lacks label decoder weights.
  • A generalized merge_configs function replaces merge_trainer_configs in task_utils.py.

4. Backward Compatibility & Migration

  • Added migrate_checkpoint_if_needed to dynamically detect and convert legacy checkpoints (e.g., pure MLM or pure SeqCls) into the expected multitask format at runtime, eliminating the need to manually re-train old models.

5. Model & Layer Refinements

  • Pooling Defaults: Changed the default pooling_method in TrainerConfig from "pooling_layer" to "first_token", which provides a safer default (especially for pure MLM models where the pooler is untrained).
  • MVC Embeddings: Fixed missing dictionary unpacking for mvc_query_embeddings across multiple predictive layers, ensuring the correct query tensors are routed to MVC decoders.
  • PEFT Support: Added output_attentions and return_dict arguments to forward signatures of model wrappers to preserve compatibility with standard PEFT/LoRA integrations.

6. Cleanup and Documentation

  • Removed all references to collation_strategy across tutorial notebooks, GitHub Actions CI workflows, and .yaml config files.
  • Updated README.md files to reflect the simplified data module configuration requirements.

@mmdanziger mmdanziger marked this pull request as ready for review March 8, 2026 15:16
@mmdanziger mmdanziger merged commit 0a944f5 into main Mar 26, 2026
8 checks passed
@mmdanziger mmdanziger deleted the march-backport branch March 26, 2026 15:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant