1313
1414from netam .codon_table import (
1515 CODON_AA_INDICATOR_MATRIX ,
16+ STOP_CODON_INDICATOR ,
1617 STOP_CODON_ZAPPER ,
1718 aa_idxs_of_codon_idxs ,
1819)
@@ -155,10 +156,16 @@ def aaprobs_of_codon_probs(codon_probs: Tensor) -> Tensor:
155156 """Compute the probability of each amino acid from the probability of each codon,
156157 for each parent codon along the sequence.
157158
159+ Note: The output amino acid probabilities will sum to (1 - P(stop codons)).
160+ If inputs are valid probability distributions with small stop codon probability,
161+ outputs will be approximately normalized.
162+
158163 Args:
159164 codon_probs (torch.Tensor): A 4D tensor representing the probability of mutating
160165 to each codon for each parent codon along the sequence.
161166 Shape should be (codon_count, 4, 4, 4).
167+ Must be a valid probability distribution (non-negative,
168+ sums to 1 for each site).
162169
163170 Returns:
164171 torch.Tensor: A 2D tensor with shape (codon_count, 20) where the ij-th entry is the probability
@@ -170,13 +177,9 @@ def aaprobs_of_codon_probs(codon_probs: Tensor) -> Tensor:
170177 # the `codon_count` dimension intact. This prepares the tensor for matrix multiplication.
171178 reshaped_probs = codon_probs .reshape (codon_count , - 1 )
172179
173- # Perform matrix multiplication to get unnormalized amino acid probabilities.
180+ # Perform matrix multiplication to get amino acid probabilities.
174181 aaprobs = torch .matmul (reshaped_probs , CODON_AA_INDICATOR_MATRIX )
175182
176- # Normalize probabilities along the amino acid dimension.
177- row_sums = aaprobs .sum (dim = 1 , keepdim = True )
178- aaprobs /= row_sums
179-
180183 return aaprobs
181184
182185
@@ -190,7 +193,8 @@ def aaprob_of_mut_and_sub(
190193 This function actually isn't used anymore, but there is a good test for it, which
191194 tests other functions, so we keep it.
192195
193- Stop codons don't appear as part of this calculation.
196+ Stop codons don't appear as part of this calculation. Stop codon probabilities are
197+ zeroed and the parent codon probability is set so the total is 1 before computing AA probs.
194198
195199 Args:
196200 parent_codon_idxs (torch.Tensor): A 2D tensor where each row contains indices representing
@@ -210,6 +214,14 @@ def aaprob_of_mut_and_sub(
210214 parent_codon_idxs , codon_mut_probs , codon_csps
211215 )
212216 codon_probs = codon_probs_of_mutation_matrices (mut_matrices )
217+ # Zero out stop codon probabilities and set parent codon prob to make sum = 1
218+ flat_codon_probs = flatten_codons (codon_probs )
219+ flat_codon_probs = zero_stop_codon_probs (flat_codon_probs )
220+ # Need to convert these parent codon indices to the 64-based kind
221+ flat_codon_probs = set_parent_codon_prob (
222+ flat_codon_probs , sequences .flatten_codon_idxs (parent_codon_idxs )
223+ )
224+ codon_probs = unflatten_codons (flat_codon_probs )
213225 return aaprobs_of_codon_probs (codon_probs )
214226
215227
@@ -279,6 +291,9 @@ def aaprobs_of_parent_scaled_rates_and_csps(
279291 """Calculate per-site amino acid probabilities from per-site nucleotide rates and
280292 substitution probabilities.
281293
294+ Stop codon probabilities are zeroed and the parent codon probability is adjusted
295+ to make the distribution sum to 1 before computing AA probs.
296+
282297 Args:
283298 parent_idxs (torch.Tensor): Parent nucleotide indices. Shape should be (site_count,).
284299 scaled_nt_rates (torch.Tensor): Poisson rates of mutation per site, scaled by branch length.
@@ -290,11 +305,17 @@ def aaprobs_of_parent_scaled_rates_and_csps(
290305 torch.Tensor: A 2D tensor with rows corresponding to sites and columns
291306 corresponding to amino acids.
292307 """
293- return aaprobs_of_codon_probs (
294- codon_probs_of_parent_scaled_nt_rates_and_csps (
295- parent_idxs , scaled_nt_rates , nt_csps
296- )
308+ codon_probs = codon_probs_of_parent_scaled_nt_rates_and_csps (
309+ parent_idxs , scaled_nt_rates , nt_csps
297310 )
311+ # Zero out stop codon probabilities and set parent codon prob to make sum = 1
312+ flat_codon_probs = flatten_codons (codon_probs )
313+ flat_codon_probs = zero_stop_codon_probs (flat_codon_probs )
314+ parent_codon_idxs = parent_idxs .reshape (- 1 , 3 )
315+ flat_parent_codon_idxs = sequences .flatten_codon_idxs (parent_codon_idxs )
316+ flat_codon_probs = set_parent_codon_prob (flat_codon_probs , flat_parent_codon_idxs )
317+ codon_probs = unflatten_codons (flat_codon_probs )
318+ return aaprobs_of_codon_probs (codon_probs )
298319
299320
300321def build_codon_mutsel (
0 commit comments