Skip to content

Add model metadata #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4943949
neutral model and timestamp
willdumm Apr 30, 2025
762d618
clearer metadata naming
willdumm Apr 30, 2025
4d14076
WIP start on multihit in metadata
willdumm Apr 30, 2025
1581999
purge multihit crepe prefix
willdumm May 1, 2025
74de8c3
fix tests
willdumm May 1, 2025
4ffec0b
convert checks to warnings
willdumm May 2, 2025
62b7464
new warning
willdumm May 2, 2025
756d35e
immutable default arguments
willdumm May 2, 2025
f167092
fix integer casting bug
willdumm May 3, 2025
dbc6e27
add new pretrained model
willdumm May 5, 2025
158894f
format
willdumm May 5, 2025
3ec33ce
update backward compat test for correct neutral model
willdumm May 6, 2025
0768cda
enable simulation on ambiguous sequences
willdumm May 7, 2025
052ee0f
Path management better practice
matsen May 8, 2025
21ccdfb
fix zero branch length sampling
willdumm May 8, 2025
abd2fb7
tweaks for numerical stability
willdumm May 9, 2025
1790d07
format
willdumm May 9, 2025
068afba
slight multihit refactor
willdumm May 9, 2025
09a53cf
incremental updates
willdumm May 13, 2025
081a2e4
Woohoo
willdumm May 16, 2025
38f94cc
failing test
willdumm May 18, 2025
58153cf
a failing test
willdumm May 20, 2025
f86299f
more comprehensive test
willdumm May 22, 2025
fb8dae0
better tests and fix sim masking
willdumm May 24, 2025
40f8c6a
experiment with codon masking
willdumm May 24, 2025
e6c91ae
dnsm works as well as dasm
willdumm May 28, 2025
ab20c66
testing ambiguities...
willdumm May 28, 2025
3715be5
fix tests
willdumm May 29, 2025
e17756a
format
willdumm May 29, 2025
cdbec13
cleanup
willdumm May 30, 2025
0152fad
respond to Erick's comments
willdumm Jun 2, 2025
741231d
fix docformatter
willdumm Jun 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ test:
pytest tests

format:
docformatter --in-place --black --recursive netam tests
docformatter --in-place --black --recursive netam tests || echo "Docformatter made changes"
black netam tests

checkformat:
docformatter --check --black --recursive netam tests
black --check netam tests

checktodo:
grep -rq --include=\*.{py,Snakemake} "TODO" . && echo "TODOs found" && exit 1 || echo "No TODOs found" && exit 0
grep -rq --include="*.py" --include="*.Snakemake" "TODO" . && echo "TODOs found" && exit 1 || echo "No TODOs found" && exit 0

lint:
flake8 . --max-complexity=30 --ignore=E731,W503,E402,F541,E501,E203,E266 --statistics --exclude=_ignore
Expand Down
Binary file not shown.
Binary file not shown.
17 changes: 17 additions & 0 deletions netam/codon_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,20 @@ def build_stop_codon_indicator_tensor():
STOP_CODON_INDICATOR = build_stop_codon_indicator_tensor()

STOP_CODON_ZAPPER = STOP_CODON_INDICATOR * -BIG

# We build a table that will allow us to look up the amino acid index
# from the codon indices. Argmax gets the aa index.
AA_IDX_FROM_CODON = CODON_AA_INDICATOR_MATRIX.argmax(dim=1).view(4, 4, 4)


def aa_idxs_of_codon_idxs(codon_idx_tensor):
    """Translate a codon index tensor with trailing dimension 3 to a tensor of
    amino acid indices.

    Args:
        codon_idx_tensor: integer tensor of shape (..., 3) whose last axis
            holds the three nucleotide indices of a codon. Using ``...``
            generalizes the original (L, 3)-only behavior to any number of
            leading batch dimensions, backward-compatibly.

    Returns:
        Integer tensor of shape (...,) giving the amino acid index of each
        codon, looked up via argmax-derived table ``AA_IDX_FROM_CODON``.
    """
    # Advanced integer indexing into the (4, 4, 4) lookup table: one lookup
    # per codon, vectorized over all leading dimensions.
    return AA_IDX_FROM_CODON[
        codon_idx_tensor[..., 0],
        codon_idx_tensor[..., 1],
        codon_idx_tensor[..., 2],
    ]
4 changes: 4 additions & 0 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def clamp_probability(x: Tensor) -> Tensor:
return torch.clamp(x, min=SMALL_PROB, max=(1.0 - SMALL_PROB))


def clamp_probability_above_only(x: Tensor) -> Tensor:
    """Cap probabilities at ``1 - SMALL_PROB`` without imposing a lower bound.

    Unlike ``clamp_probability``, values at or near zero pass through
    unchanged; only the upper end is clamped.
    """
    upper_bound = 1.0 - SMALL_PROB
    return x.clamp(max=upper_bound)


def clamp_log_probability(x: Tensor) -> Tensor:
    """Cap log-probabilities at ``log(1 - SMALL_PROB)``.

    Only an upper bound is applied, mirroring ``clamp_probability``'s maximum
    in log space; no lower bound is imposed.
    """
    log_upper_bound = np.log(1.0 - SMALL_PROB)
    return x.clamp(max=log_upper_bound)

Expand Down
4 changes: 2 additions & 2 deletions netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import torch.nn.functional as F

from netam.common import BIG
from netam.common import BIG, SMALL_PROB
from netam.dxsm import DXSMDataset, DXSMBurrito
import netam.molevol as molevol

Expand Down Expand Up @@ -77,7 +77,7 @@ def update_neutral_probs(self):
pad_len = self.max_aa_seq_len - neutral_codon_probs.shape[0]
if pad_len > 0:
neutral_codon_probs = F.pad(
neutral_codon_probs, (0, 0, 0, pad_len), value=1e-8
neutral_codon_probs, (0, 0, 0, pad_len), value=SMALL_PROB
)

neutral_codon_probs_l.append(neutral_codon_probs)
Expand Down
25 changes: 9 additions & 16 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ def update_neutral_probs(self):
nt_mask = mask.repeat_interleave(3)[:parent_len]
molevol.check_csps(parent_idxs[nt_mask], nt_csps[:parent_len][nt_mask])

mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
nt_csps = nt_csps[:parent_len, :]

neutral_aa_mut_probs = molevol.neutral_aa_mut_probs(
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
nt_csps.reshape(-1, 3, 4),
neutral_aa_mut_probs = molevol.non_stop_neutral_aa_mut_probs(
parent_idxs,
nt_rates[:parent_len],
nt_csps,
branch_length,
multihit_model=self.multihit_model,
)

Expand Down Expand Up @@ -158,18 +158,11 @@ def _build_selection_matrix_from_selection_factors(

upgrades the provided tensor containing a selection factor per site to a matrix
containing a selection factor per site and amino acid. The wildtype aa selection
factor is set ot 1, and the rest are set to the selection factor.
factor is set to 1, and the rest are set to the selection factor.
"""
selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
# Every "off-diagonal" entry of the selection matrix is set to the selection
# factor, where "diagonal" means keeping the same amino acid.
selection_matrix[:, :] = selection_factors[:, None]
valid_mask = aa_parent_idxs < 20
selection_matrix[
torch.arange(len(aa_parent_idxs))[valid_mask], aa_parent_idxs[valid_mask]
] = 1.0
selection_matrix[~valid_mask] = 1.0
return selection_matrix
return molevol.lift_to_per_aa_selection_factors(
selection_factors, aa_parent_idxs
)

def build_selection_matrix_from_parent_aa(
self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor
Expand Down
44 changes: 44 additions & 0 deletions netam/dxsm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from warnings import warn
from abc import ABC, abstractmethod
import copy

Expand All @@ -16,6 +17,7 @@
stack_heterogeneous,
zap_predictions_along_diagonal,
)
from netam.pretrained import name_and_multihit_model_match
import netam.framework as framework
import netam.molevol as molevol
from netam.sequences import (
Expand Down Expand Up @@ -303,6 +305,48 @@ class DXSMBurrito(framework.Burrito, ABC):
# Not defining model_type here; instead defining it in subclasses.
# This will raise an error if we aren't using a subclass.

def __init__(self, *args, **kwargs):
    """Construct the burrito, then sanity-check model metadata.

    After normal Burrito construction, emit warnings (never errors, for
    backward compatibility with old crepes) when the model's recorded
    metadata disagrees with this burrito's expected model type, or when the
    multihit model named in the model metadata does not match the multihit
    model attached to the validation/training datasets.
    """
    super().__init__(*args, **kwargs)
    # Does model metadata match dataset?

    # For backward compatibility -- it's not possible to determine what an
    # old crepe is from its metadata :(
    if self.model.model_type == "unknown":
        warn(
            "Model type is unknown. This is likely an old model that does not include "
            "its type (dnsm, ddsm, or dasm, etc.) in its metadata. Be sure the model "
            "type matches the Dataset and Burrito type."
        )
    elif self.model.model_type != self.model_type:
        warn(
            f"Model type {self.model.model_type} does not match expected type {self.model_type}. "
            "To avoid this warning, provide `model_type` argument to model constructor."
        )

    # The same multihit-consistency check applies to both datasets; loop to
    # avoid duplicating the warning logic. Skipping None datasets also makes
    # the validation-dataset check robust (a None val_dataset previously
    # raised AttributeError) while preserving the train_dataset guard.
    multihit_model_name = self.model.hyperparameters["multihit_model_name"]
    for label, dataset in (
        ("Validation", self.val_dataset),
        ("Training", self.train_dataset),
    ):
        if dataset is None:
            continue
        if not name_and_multihit_model_match(
            multihit_model_name,
            dataset.multihit_model,
        ):
            warn(
                f"{label} dataset multihit model does not match the one referenced in "
                f"provided model metadata: '{multihit_model_name}'. "
                "To fix this, provide the `multihit_model_name` argument to the model "
                "constructor, or provide the corresponding multihit model instance to the Dataset constructor."
            )

def selection_factors_of_aa_idxs(self, aa_idxs, aa_mask):
"""Get the log selection factors for a batch of amino acid indices.

Expand Down
Loading