Skip to content

Commit 0b76ee0

Browse files
authored
Fix for dasm2 issue 61 (remove division normalization) (#177)
1 parent 81c50e8 commit 0b76ee0

File tree

2 files changed

+39
-14
lines changed

2 files changed

+39
-14
lines changed

netam/molevol.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from netam.codon_table import (
1515
CODON_AA_INDICATOR_MATRIX,
16+
STOP_CODON_INDICATOR,
1617
STOP_CODON_ZAPPER,
1718
aa_idxs_of_codon_idxs,
1819
)
@@ -155,10 +156,16 @@ def aaprobs_of_codon_probs(codon_probs: Tensor) -> Tensor:
155156
"""Compute the probability of each amino acid from the probability of each codon,
156157
for each parent codon along the sequence.
157158
159+
Note: The output amino acid probabilities will sum to (1 - P(stop codons)).
160+
If inputs are valid probability distributions with small stop codon probability,
161+
outputs will be approximately normalized.
162+
158163
Args:
159164
codon_probs (torch.Tensor): A 4D tensor representing the probability of mutating
160165
to each codon for each parent codon along the sequence.
161166
Shape should be (codon_count, 4, 4, 4).
167+
Must be a valid probability distribution (non-negative,
168+
sums to 1 for each site).
162169
163170
Returns:
164171
torch.Tensor: A 2D tensor with shape (codon_count, 20) where the ij-th entry is the probability
@@ -170,13 +177,9 @@ def aaprobs_of_codon_probs(codon_probs: Tensor) -> Tensor:
170177
# the `codon_count` dimension intact. This prepares the tensor for matrix multiplication.
171178
reshaped_probs = codon_probs.reshape(codon_count, -1)
172179

173-
# Perform matrix multiplication to get unnormalized amino acid probabilities.
180+
# Perform matrix multiplication to get amino acid probabilities.
174181
aaprobs = torch.matmul(reshaped_probs, CODON_AA_INDICATOR_MATRIX)
175182

176-
# Normalize probabilities along the amino acid dimension.
177-
row_sums = aaprobs.sum(dim=1, keepdim=True)
178-
aaprobs /= row_sums
179-
180183
return aaprobs
181184

182185

@@ -190,7 +193,8 @@ def aaprob_of_mut_and_sub(
190193
This function actually isn't used anymore, but there is a good test for it, which
191194
tests other functions, so we keep it.
192195
193-
Stop codons don't appear as part of this calculation.
196+
Stop codons don't appear as part of this calculation. Stop codon probabilities are
197+
zeroed and the parent codon probability is adjusted to make the distribution sum to 1 before computing AA probs.
194198
195199
Args:
196200
parent_codon_idxs (torch.Tensor): A 2D tensor where each row contains indices representing
@@ -210,6 +214,14 @@ def aaprob_of_mut_and_sub(
210214
parent_codon_idxs, codon_mut_probs, codon_csps
211215
)
212216
codon_probs = codon_probs_of_mutation_matrices(mut_matrices)
217+
# Zero out stop codon probabilities and set parent codon prob to make sum = 1
218+
flat_codon_probs = flatten_codons(codon_probs)
219+
flat_codon_probs = zero_stop_codon_probs(flat_codon_probs)
220+
# Need to convert these parent codon indices to the 64-based kind
221+
flat_codon_probs = set_parent_codon_prob(
222+
flat_codon_probs, sequences.flatten_codon_idxs(parent_codon_idxs)
223+
)
224+
codon_probs = unflatten_codons(flat_codon_probs)
213225
return aaprobs_of_codon_probs(codon_probs)
214226

215227

@@ -279,6 +291,9 @@ def aaprobs_of_parent_scaled_rates_and_csps(
279291
"""Calculate per-site amino acid probabilities from per-site nucleotide rates and
280292
substitution probabilities.
281293
294+
Stop codon probabilities are zeroed and the parent codon probability is adjusted
295+
to make the distribution sum to 1 before computing AA probs.
296+
282297
Args:
283298
parent_idxs (torch.Tensor): Parent nucleotide indices. Shape should be (site_count,).
284299
scaled_nt_rates (torch.Tensor): Poisson rates of mutation per site, scaled by branch length.
@@ -290,11 +305,17 @@ def aaprobs_of_parent_scaled_rates_and_csps(
290305
torch.Tensor: A 2D tensor with rows corresponding to sites and columns
291306
corresponding to amino acids.
292307
"""
293-
return aaprobs_of_codon_probs(
294-
codon_probs_of_parent_scaled_nt_rates_and_csps(
295-
parent_idxs, scaled_nt_rates, nt_csps
296-
)
308+
codon_probs = codon_probs_of_parent_scaled_nt_rates_and_csps(
309+
parent_idxs, scaled_nt_rates, nt_csps
297310
)
311+
# Zero out stop codon probabilities and set parent codon prob to make sum = 1
312+
flat_codon_probs = flatten_codons(codon_probs)
313+
flat_codon_probs = zero_stop_codon_probs(flat_codon_probs)
314+
parent_codon_idxs = parent_idxs.reshape(-1, 3)
315+
flat_parent_codon_idxs = sequences.flatten_codon_idxs(parent_codon_idxs)
316+
flat_codon_probs = set_parent_codon_prob(flat_codon_probs, flat_parent_codon_idxs)
317+
codon_probs = unflatten_codons(flat_codon_probs)
318+
return aaprobs_of_codon_probs(codon_probs)
298319

299320

300321
def build_codon_mutsel(

tests/test_molevol.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,16 +151,20 @@ def iterative_aaprob_of_mut_and_sub(parent_codon, mut_probs, csps):
151151

152152
aa_probs[aa] += child_prob
153153

154-
# need renormalization factor so that amino acid probabilities sum to 1,
155-
# since probabilities to STOP codon are dropped
154+
# Instead of renormalizing, add the "missing" probability (from stop codons)
155+
# to the parent AA probability. This matches the behavior of set_parent_codon_prob.
156156
psum = sum(aa_probs.values())
157+
parent_aa = translate_sequence(parent_codon)
158+
aa_probs[parent_aa] += 1.0 - psum
157159

158-
return torch.tensor([aa_probs[aa] / psum for aa in AA_STR_SORTED])
160+
return torch.tensor([aa_probs[aa] for aa in AA_STR_SORTED])
159161

160162

161163
def test_aaprob_of_mut_and_sub():
162164
crepe = pretrained.load("ThriftyHumV0.2-45")
163-
[rates], [subs] = crepe([parent_nt_seq])
165+
[rates], [subs_logits] = crepe([parent_nt_seq])
166+
# Apply softmax to convert logits to valid CSPs (probability distributions)
167+
subs = torch.softmax(subs_logits, dim=-1)
164168
mut_probs = 1.0 - torch.exp(-rates.squeeze().clone().detach())
165169
parent_codon = parent_nt_seq[0:3]
166170
parent_codon_idxs = nt_idx_tensor_of_str(parent_codon)

0 commit comments

Comments
 (0)