March backport changes #31
Merged
that have `label_columns` but do not actually use them. This is nonsensical in the new paradigm, but the cruft is already there in the legacy ckpts.
We may need to fix this architecturally, but as long as there are two poolers they should be initialized properly.
This PR significantly simplifies the core modeling architecture by consolidating task-specific models (MLM, Sequence Classification, etc.) into a unified MultiTask architecture. It also introduces a user-friendly, Scanpy-style Python inference API, improves how model configurations are merged from checkpoints, and ensures backward compatibility for older checkpoints.
Key Changes
1. Unified MultiTask Architecture
- `collation_strategy` & `ModelingStrategy`: The codebase no longer relies on tracking separate modeling strategies (`mlm`, `sequence_classification`, `sequence_labeling`). All models (SCBert, Performer, Nystromformer, Llama, ModernBert) now exclusively return and instantiate their `ForMultiTaskModeling` variants.
- `AutoModel` registrations have been updated so that both `AutoModelForMaskedLM` and `AutoModelForSequenceClassification` correctly resolve to the multitask models.
- […] `MultiTaskTrainingModule` (or `SequenceLabelingTrainingModule` if `enable_perturbation_metrics` is set).
2. New Python Inference API
- `bmfm_targets.inference.inference`: allows seamless zero-shot predictions directly on `AnnData` objects.
- Scanpy-style call (`bmfm.inference(adata)`) that handles layer swapping, tokenization, and prediction extraction, and appends the resulting embeddings and metadata labels back into the `adata` object automatically.
3. Robust Config & Checkpoint Merging
- `SCBertMainConfig`: handles config merging between YAML specifications and checkpoint hyperparameters (`_merge_fields`, `_merge_label_columns`, and `_merge_configs_from_checkpoint`).
- […] `label_columns` in predict mode if a checkpoint lacks label decoder weights.
- The `merge_configs` function replaces `merge_trainer_configs` in `task_utils.py`.
4. Backward Compatibility & Migration
- `migrate_checkpoint_if_needed`: dynamically detects and converts legacy checkpoints (e.g., pure MLM or pure SeqCls) into the expected multitask format at runtime, eliminating the need to manually re-train old models.
5. Model & Layer Refinements
- `pooling_method` in `TrainerConfig` changed from `"pooling_layer"` to `"first_token"`, which provides a safer default (especially for pure MLM models, where the pooler is untrained).
- […] `mvc_query_embeddings` across multiple predictive layers, ensuring the correct query tensors are routed to MVC decoders.
- `output_attentions` and `return_dict` arguments added to the forward signatures of model wrappers to preserve compatibility with standard PEFT/LoRA integrations.
6. Cleanup and Documentation
- Removed `collation_strategy` across tutorial notebooks, GitHub Actions CI workflows, and `.yaml` config files.
- Updated `README.md` files to reflect the simplified data module configuration requirements.
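The config merging described under "Robust Config & Checkpoint Merging" can be pictured with a small sketch over plain dictionaries. Everything here is an assumption made for illustration: the function name `sketch_merge_configs`, the precedence rule (explicit YAML values win, the checkpoint fills the rest), and the choice to clear `label_columns` in predict mode are not taken from the PR, whose `_merge_fields` and `_merge_label_columns` may behave differently.

```python
def sketch_merge_configs(yaml_cfg, ckpt_cfg, mode="predict", has_label_decoder=True):
    """Merge two plain-dict configs under an ASSUMED precedence rule.

    Assumption: values set explicitly in the YAML config (not None) win;
    anything left unset falls back to the checkpoint hyperparameters.
    """
    merged = dict(ckpt_cfg)
    merged.update({k: v for k, v in yaml_cfg.items() if v is not None})

    # Assumed policy: in predict mode, label columns cannot be decoded when the
    # checkpoint carries no label decoder weights, so they are cleared.
    if mode == "predict" and not has_label_decoder:
        merged["label_columns"] = []
    return merged


# Hypothetical field names, used only to exercise the function.
yaml_cfg = {"max_length": 512, "label_columns": ["cell_type"], "hidden_size": None}
ckpt_cfg = {"max_length": 1024, "hidden_size": 256}
print(sketch_merge_configs(yaml_cfg, ckpt_cfg, has_label_decoder=False))
```

Keeping the merge as a pure function over dictionaries makes the precedence rule easy to test in isolation, whatever rule the real implementation settles on.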
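The runtime migration described under "Backward Compatibility & Migration" can be thought of as a key-renaming pass over a checkpoint's state dict. The sketch below is only an illustration of that idea: the function names and all key prefixes (`cls.predictions.`, `classifier.`, `multitask_heads.`) are hypothetical, and the actual `migrate_checkpoint_if_needed` may detect and convert legacy checkpoints quite differently.

```python
# Hypothetical sketch: every key prefix below is an assumption for illustration,
# not the layout used by the real migrate_checkpoint_if_needed.
LEGACY_PREFIXES = {
    "cls.predictions.": "multitask_heads.mlm.",   # assumed pure-MLM head
    "classifier.": "multitask_heads.seq_cls.",    # assumed pure-SeqCls head
}


def is_legacy(state_dict):
    """Return True if any key still uses an assumed legacy head prefix."""
    return any(key.startswith(p) for key in state_dict for p in LEGACY_PREFIXES)


def migrate_legacy_state_dict(state_dict):
    """Return a copy of state_dict with legacy head keys renamed.

    Checkpoints already in the (assumed) multitask layout are returned
    unchanged, so the call is safe to apply unconditionally at load time.
    """
    if not is_legacy(state_dict):
        return dict(state_dict)
    migrated = {}
    for key, value in state_dict.items():
        for old, new in LEGACY_PREFIXES.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        migrated[key] = value
    return migrated
```

Because the function is idempotent, it can run on every load without first checking whether a checkpoint has already been converted.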
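To see why `"first_token"` is the safer `pooling_method` when the pooler is untrained, the following toy sketch contrasts the two options in plain Python. It assumes that `"pooling_layer"` refers to a BERT-style pooler (a dense layer plus `tanh` applied to the first token); that reading, and both function names, are assumptions rather than details confirmed by the PR.

```python
import math
import random


def first_token_pool(hidden_states):
    """Return the hidden vector of the first token (e.g. a [CLS]-style token)."""
    return list(hidden_states[0])


def pooling_layer_pool(hidden_states, weight, bias):
    """BERT-style pooler (assumed): tanh(W @ h_first + b)."""
    first = hidden_states[0]
    return [
        math.tanh(sum(w * x for w, x in zip(row, first)) + b)
        for row, b in zip(weight, bias)
    ]


# Toy sequence of 3 tokens with hidden size 2.
hidden = [[0.5, -1.0], [0.1, 0.2], [0.3, 0.4]]

# An untrained pooler is just random weights, so its output depends on the
# initialization rather than on anything the encoder learned.
rng = random.Random(0)
untrained_w = [[rng.uniform(-1, 1) for _ in range(2)] for _ in range(2)]
untrained_b = [0.0, 0.0]

print(first_token_pool(hidden))  # [0.5, -1.0]
print(pooling_layer_pool(hidden, untrained_w, untrained_b))
```

The first-token embedding passes the encoder output through unchanged, while the pooled embedding is only meaningful once the pooler weights have been trained.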