@@ -21,9 +21,18 @@ class GreedyDecoder(Decoder):
2121 models that conform to the `Decodable` interface.
2222 """
2323
24- def __init__ (self , model : Decodable , mass_scale : int = MASS_SCALE ):
24+ def __init__ (
25+ self ,
26+ model : Decodable ,
27+ suppressed_residues : list [str ] | None = None ,
28+ mass_scale : int = MASS_SCALE ,
29+ disable_terminal_residues_anywhere : bool = True ,
30+ ):
2531 super ().__init__ (model = model )
2632 self .mass_scale = mass_scale
33+ self .disable_terminal_residues_anywhere = disable_terminal_residues_anywhere
34+
35+ suppressed_residues = suppressed_residues or []
2736
2837 # NOTE: Greedy search requires `residue_set` class in the model, update all methods accordingly.
2938 if not hasattr (model , "residue_set" ):
@@ -37,10 +46,32 @@ def __init__(self, model: Decodable, mass_scale: int = MASS_SCALE):
3746 self .residue_masses = torch .zeros (
3847 (len (self .model .residue_set ),), dtype = torch .float64
3948 )
49+ terminal_residues_idx : list [int ] = []
50+ suppressed_residues_idx : list [int ] = []
4051 for i , residue in enumerate (model .residue_set .vocab ):
4152 if residue in self .model .residue_set .special_tokens :
4253 continue
4354 self .residue_masses [i ] = self .model .residue_set .get_mass (residue )
55+ # If the token does not start with a letter (no amino acid attached), assume it is an N-terminal residue
56+ if not residue [0 ].isalpha ():
57+ terminal_residues_idx .append (i )
58+
59+ # Check if residue is suppressed
60+ if residue in suppressed_residues :
61+ suppressed_residues_idx .append (i )
62+ suppressed_residues .remove (residue )
63+
64+ if len (suppressed_residues ) > 0 :
65+ raise ValueError (
66+ f"Suppressed residues not found in vocabulary: { suppressed_residues } "
67+ )
68+
69+ self .terminal_residue_indices = torch .tensor (
70+ terminal_residues_idx , dtype = torch .long
71+ )
72+ self .suppressed_residue_indices = torch .tensor (
73+ suppressed_residues_idx , dtype = torch .long
74+ )
4475
4576 self .vocab_size = len (self .model .residue_set )
4677
@@ -270,10 +301,53 @@ def decode( # type:ignore
270301 next_token_probabilities_filtered [
271302 :, self .model .residue_set .EOS_INDEX
272303 ] = - float ("inf" )
304+ # PAD is deliberately left unmasked (the masking below is disabled) so the model can still predict PAD when all residues are -inf
305+ # next_token_probabilities_filtered[
306+ # :, self.model.residue_set.PAD_INDEX
307+ # ] = -float("inf")
273308 next_token_probabilities_filtered [
274309 :, self .model .residue_set .SOS_INDEX
275310 ] = - float ("inf" )
276- # TODO set probability of n-terminal modifications to 0 when i > 0, requires n-terms to be specified in residue_set
311+ next_token_probabilities_filtered [
312+ :, self .suppressed_residue_indices
313+ ] = - float ("inf" )
314+ # Set probability of n-terminal modifications to -inf when i > 0
315+ if self .disable_terminal_residues_anywhere :
316+ # Check if adding terminal residues would result in a complete sequence
317+ # First generate remaining mass matrix with isotopes
318+ remaining_mass_incomplete_isotope = remaining_mass_incomplete [
319+ :, None
320+ ].expand (sub_batch_size , max_isotope + 1 ) - CARBON_MASS_DELTA * (
321+ torch .arange (max_isotope + 1 , device = device )
322+ )
323+ # Expand with terminal residues and subtract
324+ remaining_mass_incomplete_isotope_delta = (
325+ remaining_mass_incomplete_isotope [:, :, None ].expand (
326+ sub_batch_size ,
327+ max_isotope + 1 ,
328+ self .terminal_residue_indices .shape [0 ],
329+ )
330+ - self .residue_masses [self .terminal_residue_indices ]
331+ )
332+
333+ # If within target delta, allow these residues to be predicted, otherwise set probability to -inf
334+ allow_terminal = (
335+ remaining_mass_incomplete_isotope_delta .abs ()
336+ < mass_target_incomplete [:, None , None ]
337+ ).any (dim = 1 )
338+ allow_terminal_full = torch .ones (
339+ (sub_batch_size , self .vocab_size ),
340+ device = spectra .device ,
341+ dtype = bool ,
342+ )
343+ allow_terminal_full [:, self .terminal_residue_indices ] = (
344+ allow_terminal
345+ )
346+
347+ # Set to -inf
348+ next_token_probabilities_filtered [~ allow_terminal_full ] = - float (
349+ "inf"
350+ )
277351
278352 # Step 5: Select next token:
279353 next_token = next_token_probabilities_filtered .argmax (- 1 ).unsqueeze (
@@ -362,7 +436,7 @@ def decode( # type:ignore
362436 token_log_probabilities = [
363437 x .cpu ().item ()
364438 for x in all_log_probabilities [i , : len (sequence )]
365- ], # list[float] (sequence_length) excludes EOS
439+ ][:: - 1 ] , # list[float] (sequence_length) excludes EOS
366440 )
367441 )
368442
0 commit comments