Add model metadata #135

Status: Open. willdumm wants to merge 27 commits into main.

willdumm
Contributor

@willdumm willdumm commented May 2, 2025

This PR adds the following values to metadata of saved models:

  • multihit_model_name: expected to be a key in netam.pretrained.PRETRAINED_MULTIHIT_MODELS. Defaults to netam.models.DEFAULT_MULTIHIT_MODEL. For crepes saved without this field, defaults to None.
  • neutral_model_name: expected to be the name of a pretrained neutral model. Defaults to netam.models.DEFAULT_NEUTRAL_MODEL. For crepes saved without this field, defaults to ThriftyHumV0.2-59.
  • train_timestamp: a UTC timestamp (e.g. 2025-05-01T22:05), taken at model initialization unless provided explicitly. For crepes saved without this field, defaults to old.
  • model_type: one of dnsm, dasm, or ddsm; must be provided at model instantiation. For crepes saved without this field, defaults to unknown and emits warnings.
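
To make the legacy fallbacks listed above concrete, here is a minimal sketch of how missing fields could be filled in when loading an older crepe. The helper names fill_legacy_metadata and utc_timestamp are hypothetical and are not taken from netam; only the field names and fallback values come from this description.

```python
import warnings
from datetime import datetime, timezone

# Fallback values described in this PR for crepes saved without the new fields.
LEGACY_DEFAULTS = {
    "multihit_model_name": None,
    "neutral_model_name": "ThriftyHumV0.2-59",
    "train_timestamp": "old",
    "model_type": "unknown",
}


def fill_legacy_metadata(saved):
    """Hypothetical helper: return a copy of `saved` with missing fields filled in."""
    filled = {**LEGACY_DEFAULTS, **saved}
    if filled["model_type"] == "unknown":
        # The PR says an unknown model_type produces warnings.
        warnings.warn("model_type is unknown; metadata checks cannot be performed")
    return filled


def utc_timestamp(now=None):
    """Hypothetical helper: format a UTC timestamp like 2025-05-01T22:05."""
    if now is None:
        now = datetime.now(timezone.utc)
    return now.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M")
```

Values already present in the saved metadata take precedence over the fallbacks, so a crepe saved after this change is returned unchanged.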

As hinted at above, I added a dictionary containing pretrained multihit models to netam.pretrained. These models can be accessed by name using netam.pretrained.load_multihit.
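
As a rough illustration of the name-based lookup, the sketch below uses a stand-in dictionary. The registry contents, the model name example-multihit-v0, and the function load_multihit_sketch are all hypothetical; the real signature and return type of netam.pretrained.load_multihit are not shown in this PR description. Returning None for a None name is also an assumption, chosen to mirror the legacy default for multihit_model_name.

```python
# Hypothetical stand-in for netam.pretrained.PRETRAINED_MULTIHIT_MODELS.
# The name and the values below are invented for illustration only.
EXAMPLE_MULTIHIT_MODELS = {
    "example-multihit-v0": (0.0, -0.5, -1.0),
}


def load_multihit_sketch(name):
    """Look up a multihit model by name; None means no multihit model (assumed)."""
    if name is None:
        return None
    try:
        return EXAMPLE_MULTIHIT_MODELS[name]
    except KeyError:
        known = ", ".join(sorted(EXAMPLE_MULTIHIT_MODELS))
        raise KeyError(f"unknown multihit model {name!r}; known names: {known}") from None
```

Storing only the name in the saved metadata keeps the crepe small and lets the loader fail early with a readable message when the name is not in the registry.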

Requires companion PR https://github.com/matsengrp/dnsm-experiments-1/pull/132

@willdumm willdumm requested a review from Copilot May 2, 2025 18:55

@Copilot Copilot AI left a comment

Pull Request Overview

This PR extends model metadata to include multihit and neutral model settings and integrates these changes across tests and core model functions.

  • Updates tests to load and use multihit models via load_multihit.
  • Extends AbstractBinarySelectionModel and SingleValueBinarySelectionModel with new metadata (including model_type, train_timestamp, neutral_model_name, and multihit_model_name) and adjusts hyperparameter defaults.
  • Enhances framework functions (including add_shm_model_outputs_to_pcp_df and DXSMBurrito initialization) to verify model metadata consistency.

Reviewed Changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| tests/test_simulation.py | Uses load_multihit to retrieve multihit model and adds tolerance in allclose check; reassigns train_dataset to val_dataset. |
| tests/test_multihit.py | Updates model instantiation to pass model_type and generate multihit_model_name from model weights. |
| tests/test_dnsm.py, test_ddsm.py, test_dasm.py, test_ambiguous.py | Integrates new parameter model_type and multihit_model into model/dataset creation. |
| netam/pretrained.py | Introduces load_multihit and name_and_multihit_model_match for multihit model handling. |
| netam/models.py | Extends metadata in model constructors and updates reinitialize_weights, to_weights, and from_weights methods. |
| netam/framework.py | Adds default hyperparameter values for legacy models and filters sequences in add_shm_model_outputs_to_pcp_df. |
| netam/dxsm.py | Implements metadata validation with warnings regarding model_type and multihit model consistency. |
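
The metadata validation described for netam/dxsm.py could look roughly like the sketch below. This is an assumption about the shape of the check, not the code in the PR: the function check_model_metadata and its parameters are hypothetical, and only the idea of warning on model_type and multihit model mismatches comes from the summary above.

```python
import warnings

KNOWN_MODEL_TYPES = {"dnsm", "dasm", "ddsm"}


def check_model_metadata(metadata, expected_model_type, dataset_multihit_name):
    """Hypothetical check: warn about each mismatch and return the messages."""
    problems = []
    model_type = metadata.get("model_type", "unknown")
    if model_type not in KNOWN_MODEL_TYPES:
        problems.append(f"model_type is {model_type!r}; cannot verify it matches")
    elif model_type != expected_model_type:
        problems.append(
            f"model_type {model_type!r} does not match expected {expected_model_type!r}"
        )
    model_multihit_name = metadata.get("multihit_model_name")
    if model_multihit_name != dataset_multihit_name:
        problems.append(
            f"model multihit model {model_multihit_name!r} differs from "
            f"dataset multihit model {dataset_multihit_name!r}"
        )
    for message in problems:
        warnings.warn(message)
    return problems
```

Warning rather than raising keeps legacy crepes, which carry model_type unknown, usable while still flagging that their provenance cannot be confirmed.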

@willdumm willdumm requested a review from matsen May 2, 2025 19:15
Contributor

@matsen matsen left a comment

Just a few final todos 👍

@@ -66,13 +63,7 @@ def apply_multihit_correction(
per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs)
corrections = torch.cat([torch.tensor([0.0]), log_hit_class_factors]).exp()
reshaped_corrections = corrections[per_parent_hit_class]
unnormalized_corrected_probs = clamp_probability(codon_probs * reshaped_corrections)
Contributor Author

This is just a refactor: the multihit model's forward method still sets the parent codon probability, but the model can now also expose a method that adjusts codon probabilities without setting the parent codon probability.
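
A pure-Python sketch of that split, under stated assumptions: the function names below are hypothetical, probabilities are held in a small dict rather than a tensor, the hit class is taken to be the number of nucleotide differences from the parent codon, and the clamp_probability step from the diff is omitted for brevity.

```python
import math


def hit_class(parent, child):
    """Number of nucleotide positions at which the child codon differs from the parent."""
    return sum(p != c for p, c in zip(parent, child))


def adjust_codon_probs(parent, codon_probs, log_hit_class_factors):
    """Scale each codon probability by its hit-class factor, without renormalizing."""
    # Hit class 0 (the parent itself) gets factor exp(0) = 1, as in the diff.
    corrections = [1.0] + [math.exp(f) for f in log_hit_class_factors]
    return {
        codon: prob * corrections[hit_class(parent, codon)]
        for codon, prob in codon_probs.items()
    }


def set_parent_codon_prob(parent, codon_probs):
    """Set the parent codon probability to one minus the sum of the others."""
    others = sum(prob for codon, prob in codon_probs.items() if codon != parent)
    result = dict(codon_probs)
    result[parent] = 1.0 - others
    return result


def forward(parent, codon_probs, log_hit_class_factors):
    """Adjust the probabilities, then set the parent codon probability."""
    adjusted = adjust_codon_probs(parent, codon_probs, log_hit_class_factors)
    return set_parent_codon_prob(parent, adjusted)
```

Keeping the adjustment separate lets a caller apply the correction on its own, while forward composes the two steps to produce a distribution that sums to one.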
