-
Notifications
You must be signed in to change notification settings - Fork 0
Add model metadata #135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add model metadata #135
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR extends model metadata to include multihit and neutral model settings and integrates these changes across tests and core model functions.
- Updates tests to load and use multihit models via load_multihit.
- Extends AbstractBinarySelectionModel and SingleValueBinarySelectionModel with new metadata (including model_type, train_timestamp, neutral_model_name, and multihit_model_name) and adjusts hyperparameter defaults.
- Enhances framework functions (including add_shm_model_outputs_to_pcp_df and DXSMBurrito initialization) to verify model metadata consistency.
Reviewed Changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.
Show a summary per file
File | Description |
---|---|
tests/test_simulation.py | Uses load_multihit to retrieve multihit model and adds tolerance in allclose check; reassigns train_dataset to val_dataset. |
tests/test_multihit.py | Updates model instantiation to pass model_type and generate multihit_model_name from model weights. |
tests/test_dnsm.py, test_ddsm.py, test_dasm.py, test_ambiguous.py | Integrates new parameter model_type and multihit_model into model/dataset creation. |
netam/pretrained.py | Introduces load_multihit and name_and_multihit_model_match for multihit model handling. |
netam/models.py | Extends metadata in model constructors and updates reinitialize_weights, to_weights, and from_weights methods. |
netam/framework.py | Adds default hyperparameter values for legacy models and filters sequences in add_shm_model_outputs_to_pcp_df. |
netam/dxsm.py | Implements metadata validation with warnings regarding model_type and multihit model consistency. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a few final todos 👍
@@ -66,13 +63,7 @@ def apply_multihit_correction( | |||
per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs) | |||
corrections = torch.cat([torch.tensor([0.0]), log_hit_class_factors]).exp() | |||
reshaped_corrections = corrections[per_parent_hit_class] | |||
unnormalized_corrected_probs = clamp_probability(codon_probs * reshaped_corrections) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just a refactor -- the forward method of the multihit model still sets the parent codon probability, but this allows the model to expose a method that adjusts codon probs but does not set the parent codon probability.
This PR adds the following values to metadata of saved models:
multihit_model_name
: expected to be a key innetam.pretrained.PRETRAINED_MULTIHIT_MODELS
. Defaults tonetam.models.DEFAULT_MULTIHIT_MODEL
. For crepes saved without this data, defaults toNone
.neutral_model_name
: expected to be a named pretrained neutral model. Defaults tonetam.models.DEFAULT_NEUTRAL_MODEL
. For crepes saved without this data, defaults toThriftyHumV0.2-59
.train_timestamp
: a UTC timestamp taken at the time of model initialization, if not provided explicitly (e.g.2025-05-01T22:05
). For crepes saved without this data, defaults toold
model_type
: eitherdnsm
,dasm
, orddsm
which must be provided at the time of model instantiation. For crepes saved without this data, defaults tounknown
, and will throw warnings.As hinted at above, I added a dictionary containing pretrained multihit models to
netam.pretrained
. These models can be accessed by name usingnetam.pretrained.load_multihit
.Requires companion PR https://github.com/matsengrp/dnsm-experiments-1/pull/132