3232from chai_lab .data .dataset .msas .colabfold import generate_colabfold_msas
3333from chai_lab .data .dataset .msas .load import get_msa_contexts
3434from chai_lab .data .dataset .msas .msa_context import MSAContext
35+ from chai_lab .data .dataset .msas .utils import subsample_msa_rows
3536from chai_lab .data .dataset .structure .all_atom_structure_context import (
3637 AllAtomStructureContext ,
3738)
@@ -441,6 +442,7 @@ def run_inference(
441442 msa_directory : Path | None = None ,
442443 constraint_path : Path | None = None ,
443444 # expose some params for easy tweaking
445+ recycle_msa_subsample : int = 0 ,
444446 num_trunk_recycles : int = 3 ,
445447 num_diffn_timesteps : int = 200 ,
446448 num_diffn_samples : int = 5 ,
@@ -472,6 +474,7 @@ def run_inference(
472474 num_trunk_recycles = num_trunk_recycles ,
473475 num_diffn_timesteps = num_diffn_timesteps ,
474476 num_diffn_samples = num_diffn_samples ,
477+ recycle_msa_subsample = recycle_msa_subsample ,
475478 seed = seed ,
476479 device = torch_device ,
477480 low_memory = low_memory ,
@@ -488,6 +491,7 @@ def run_folding_on_context(
488491 * ,
489492 output_dir : Path ,
490493 # expose some params for easy tweaking
494+ recycle_msa_subsample : int = 0 ,
491495 num_trunk_recycles : int = 3 ,
492496 num_diffn_timesteps : int = 200 ,
493497 # all diffusion samples come from the same trunk
@@ -647,7 +651,7 @@ def run_folding_on_context(
647651 token_single_trunk_repr = token_single_trunk_repr , # recycled
648652 token_pair_trunk_repr = token_pair_trunk_repr , # recycled
649653 msa_input_feats = msa_input_feats ,
650- msa_mask = msa_mask ,
654+ msa_mask = subsample_msa_rows ( msa_mask , select_n_rows = recycle_msa_subsample ) ,
651655 template_input_feats = template_input_feats ,
652656 template_input_masks = template_input_masks ,
653657 token_single_mask = token_single_mask ,
0 commit comments