|
1 | 1 | import numpy as np |
| 2 | +import pandas as pd |
2 | 3 | import torch |
| 4 | +from typing import Tuple # , Dict, List |
3 | 5 |
|
4 | 6 | from Bio.Data import CodonTable |
5 | | -from netam.sequences import AA_STR_SORTED, CODONS, STOP_CODONS, translate_sequences |
| 7 | +from netam.sequences import ( |
| 8 | + AA_STR_SORTED, |
| 9 | + AMBIGUOUS_CODON_IDX, |
| 10 | + CODONS, |
| 11 | + STOP_CODONS, |
| 12 | + contains_stop_codon, |
| 13 | + idx_of_codon_allowing_ambiguous, |
| 14 | + iter_codons, |
| 15 | + translate_sequences, |
| 16 | +) |
6 | 17 | from netam.common import BIG |
7 | 18 |
|
8 | 19 |
|
@@ -95,3 +106,192 @@ def aa_idxs_of_codon_idxs(codon_idx_tensor): |
95 | 106 | codon_idx_tensor[:, 2], |
96 | 107 | ) |
97 | 108 | ] |
| 109 | + |
| 110 | + |
def generate_codon_neighbor_matrix():
    """Build the codon -> amino acid single-substitution reachability table.

    Returns:
        torch.Tensor: Boolean tensor of shape
            (AMBIGUOUS_CODON_IDX + 1, len(AA_STR_SORTED)). Row i is True at
            the amino acid indices that ``single_mutant_aa_indices`` reports
            for ``CODONS[i]``. Only rows for entries of ``CODONS`` are ever
            written, so the row reserved for the ambiguous codon
            (AMBIGUOUS_CODON_IDX) stays all False.
    """
    # One extra row beyond the standard codons holds the ambiguous codon.
    n_rows = AMBIGUOUS_CODON_IDX + 1
    n_cols = len(AA_STR_SORTED)
    reachable = np.zeros((n_rows, n_cols), dtype=bool)

    # Fill one row per standard codon; the ambiguous row is left untouched.
    for row, codon in enumerate(CODONS):
        reachable[row, single_mutant_aa_indices(codon)] = True

    return torch.tensor(reachable, dtype=torch.bool)
| 130 | + |
| 131 | + |
def generate_codon_single_mutation_map():
    """Generate mapping of codon-to-codon single mutations.

    Returns:
        Dict[int, List[Tuple[int, int, str]]]: Maps each parent codon index
            (its position in ``CODONS``) to a list of
            (child_codon_idx, nt_position, new_base) tuples, one per
            single-nucleotide substitution. Entries are ordered by nucleotide
            position and then by base in "ACGT" order, skipping the parent's
            own base. Only codons present in ``CODONS`` get an entry, so the
            ambiguous codon index is not a key.

    Raises:
        KeyError: If a single-nucleotide neighbor of a codon is missing from
            ``CODONS`` (not expected for the full 64-codon table).
    """
    bases = "ACGT"
    # Build the codon -> index lookup once instead of scanning CODONS with
    # list.index() for every one of the parent/position/base combinations.
    codon_to_idx = {codon: idx for idx, codon in enumerate(CODONS)}

    mutation_map = {}
    for parent_idx, parent_codon in enumerate(CODONS):
        mutations = []
        for nt_pos in range(3):
            prefix = parent_codon[:nt_pos]
            suffix = parent_codon[nt_pos + 1 :]
            for new_base in bases:
                if new_base == parent_codon[nt_pos]:
                    # Not a mutation: same base as the parent.
                    continue
                child_idx = codon_to_idx[prefix + new_base + suffix]
                mutations.append((child_idx, nt_pos, new_base))
        mutation_map[parent_idx] = mutations

    return mutation_map
| 156 | + |
| 157 | + |
# Global tensors/mappings for efficient lookups. Both are computed once, when
# this module is imported.
# CODON_NEIGHBOR_MATRIX: boolean tensor whose entry (i, j) is True when codon i
# can reach amino acid j via a single nucleotide substitution.
CODON_NEIGHBOR_MATRIX = generate_codon_neighbor_matrix() # (65, 20)
# CODON_SINGLE_MUTATIONS: parent codon index -> list of
# (child_codon_idx, nt_position, new_base) for every single-nucleotide change.
CODON_SINGLE_MUTATIONS = generate_codon_single_mutation_map()
| 161 | + |
| 162 | + |
def encode_codon_mutations(
    nt_parents: pd.Series, nt_children: pd.Series
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert parent/child nucleotide sequences to codon indices and mutation
    indicators.

    Args:
        nt_parents: Parent nucleotide sequences
        nt_children: Child nucleotide sequences, paired with ``nt_parents`` by
            position

    Returns:
        Tuple of:
        - codon_parents_idxss: (N, L_codon) long tensor of parent codon indices
        - codon_children_idxss: (N, L_codon) long tensor of child codon indices
        - codon_mutation_indicators: (N, L_codon) boolean tensor indicating mutation positions

        Codon indices come from ``idx_of_codon_allowing_ambiguous``. The
        mutation indicator is plain string inequality of the parent and child
        codon; it does not treat ambiguous codons specially. When both inputs
        are empty, three tensors of shape (0, 0) are returned.

    Raises:
        ValueError: If the number of parents and children differ, if the
            sequences do not all share one length, or if that length is not a
            multiple of 3.

    Example:
        >>> parents = pd.Series(['ATGAAACCC'])
        >>> children = pd.Series(['ATGAAACCG'])  # CCC->CCG mutation
        >>> p_idx, c_idx, mut = encode_codon_mutations(parents, children)
        >>> mut[0]  # tensor([False, False,  True])
    """
    parent_seqs = nt_parents.tolist()
    child_seqs = nt_children.tolist()

    # zip() below would silently drop unpaired sequences, so reject a count
    # mismatch up front instead of returning partially filled results.
    if len(parent_seqs) != len(child_seqs):
        raise ValueError("Number of parent and child sequences must match")

    n_sequences = len(parent_seqs)
    if n_sequences == 0:
        # Nothing to encode; avoid indexing parent_seqs[0] on empty input.
        return (
            torch.zeros((0, 0), dtype=torch.long),
            torch.zeros((0, 0), dtype=torch.long),
            torch.zeros((0, 0), dtype=torch.bool),
        )

    # Check that all sequences have same length and are multiples of 3
    seq_len = len(parent_seqs[0])
    if any(len(seq) != seq_len for seq in parent_seqs + child_seqs):
        raise ValueError("All sequences must have the same length")
    if seq_len % 3 != 0:
        raise ValueError("Sequence length must be a multiple of 3")
    codon_len = seq_len // 3

    # Single pass over the pairs: collect plain Python rows, then build each
    # tensor once at the end.
    parent_rows = []
    child_rows = []
    mutated_rows = []
    for parent_seq, child_seq in zip(parent_seqs, child_seqs):
        parent_codons = list(iter_codons(parent_seq))
        child_codons = list(iter_codons(child_seq))
        parent_rows.append(
            [idx_of_codon_allowing_ambiguous(codon) for codon in parent_codons]
        )
        child_rows.append(
            [idx_of_codon_allowing_ambiguous(codon) for codon in child_codons]
        )
        mutated_rows.append(
            [p_codon != c_codon for p_codon, c_codon in zip(parent_codons, child_codons)]
        )

    # reshape pins the (N, L_codon) shape and fails loudly if a row came back
    # with an unexpected number of codons.
    shape = (n_sequences, codon_len)
    parent_codon_indices = torch.tensor(parent_rows, dtype=torch.long).reshape(shape)
    child_codon_indices = torch.tensor(child_rows, dtype=torch.long).reshape(shape)
    mutation_indicators = torch.tensor(mutated_rows, dtype=torch.bool).reshape(shape)

    return parent_codon_indices, child_codon_indices, mutation_indicators
| 237 | + |
| 238 | + |
def create_codon_masks(nt_parents: pd.Series, nt_children: pd.Series) -> torch.Tensor:
    """Create masks for valid codon positions, masking ambiguous codons (containing Ns).

    Args:
        nt_parents: Parent nucleotide sequences
        nt_children: Child nucleotide sequences, paired with ``nt_parents`` by
            position

    Returns:
        masks: (N, L_codon) boolean tensor indicating valid codon positions.
            A position is False when the parent codon or the child codon
            contains the character "N", and True otherwise. When both inputs
            are empty, a tensor of shape (0, 0) is returned.

    Example:
        >>> parents = pd.Series(['ATGNNNCCG'])  # Middle codon has Ns
        >>> children = pd.Series(['ATGNNNCCG'])
        >>> masks = create_codon_masks(parents, children)
        >>> masks[0]  # tensor([ True, False,  True])

    Raises:
        ValueError: If the number of parents and children differ, if any
            sequence contains a stop codon, if the sequences do not all share
            one length, or if that length is not a multiple of 3.
    """
    parent_seqs = nt_parents.tolist()
    child_seqs = nt_children.tolist()

    # zip() below would silently drop unpaired sequences, so reject a count
    # mismatch up front instead of returning a partially computed mask.
    if len(parent_seqs) != len(child_seqs):
        raise ValueError("Number of parent and child sequences must match")

    n_sequences = len(parent_seqs)
    if n_sequences == 0:
        # Nothing to mask; avoid indexing parent_seqs[0] on empty input.
        return torch.ones((0, 0), dtype=torch.bool)

    # Check for stop codons in all sequences (parents first, then children)
    for label, seqs in (("Parent", parent_seqs), ("Child", child_seqs)):
        for seq_idx, seq in enumerate(seqs):
            if contains_stop_codon(seq):
                raise ValueError(
                    f"{label} sequence {seq_idx} contains a stop codon: {seq}"
                )

    # Check that all sequences have same length and are multiples of 3
    seq_len = len(parent_seqs[0])
    if any(len(seq) != seq_len for seq in parent_seqs + child_seqs):
        raise ValueError("All sequences must have the same length")
    if seq_len % 3 != 0:
        raise ValueError("Sequence length must be a multiple of 3")
    codon_len = seq_len // 3

    # True = valid, False = masked. Build the rows as Python lists and convert
    # once, rather than writing into a tensor one element at a time.
    mask_rows = [
        [
            "N" not in parent_codon and "N" not in child_codon
            for parent_codon, child_codon in zip(
                iter_codons(parent_seq), iter_codons(child_seq)
            )
        ]
        for parent_seq, child_seq in zip(parent_seqs, child_seqs)
    ]

    # reshape pins the (N, L_codon) shape, including the zero-length case.
    return torch.tensor(mask_rows, dtype=torch.bool).reshape(n_sequences, codon_len)
0 commit comments