9595from chai_lab .data .features .generators .token_pair_pocket_restraint import (
9696 TokenPairPocketRestraint ,
9797)
98- from chai_lab .data .io .cif_utils import get_chain_letter , save_to_cif
98+ from chai_lab .data .io .cif_utils import _CHAIN_VOCAB , get_chain_letter , save_to_cif
9999from chai_lab .data .parsing .restraints import parse_pairwise_table
100100from chai_lab .data .parsing .structure .entity_type import EntityType
101101from chai_lab .model .diffusion_schedules import InferenceNoiseSchedule
@@ -514,6 +514,8 @@ def run_inference(
514514 seed : int | None = None ,
515515 device : str | None = None ,
516516 low_memory : bool = True ,
517+ # IO options
518+ fasta_names_as_cif_chains : bool = False ,
517519) -> StructureCandidates :
518520 assert num_trunk_samples > 0 and num_diffn_samples > 0
519521 if output_dir .exists ():
@@ -553,6 +555,7 @@ def run_inference(
553555 seed = seed + trunk_idx if seed is not None else None ,
554556 device = torch_device ,
555557 low_memory = low_memory ,
558+ entity_names_as_chain_names_in_output_cif = fasta_names_as_cif_chains ,
556559 )
557560 all_candidates .append (cand )
558561 return StructureCandidates .concat (all_candidates )
@@ -573,6 +576,7 @@ def run_folding_on_context(
573576 num_diffn_timesteps : int = 200 ,
574577 # all diffusion samples come from the same trunk
575578 num_diffn_samples : int = 5 ,
579+ entity_names_as_chain_names_in_output_cif : bool = False ,
576580 seed : int | None = None ,
577581 device : torch .device | None = None ,
578582 low_memory : bool ,
@@ -602,6 +606,19 @@ def run_folding_on_context(
602606 # NOTE profile MSA used only for statistics; no depth check
603607 feature_context .structure_context .report_bonds ()
604608
609+ if entity_names_as_chain_names_in_output_cif :
610+ # Ensure that entity names are unique and are valid chain names
611+ entity_names : list [str ] = [
612+ chain .entity_data .entity_name for chain in feature_context .chains
613+ ]
614+ assert len (set (entity_names )) == len (
615+ entity_names
616+ ), f"Using entity names for cif chains, but got duplicates: { entity_names } "
617+ assert all (e in _CHAIN_VOCAB for e in entity_names ), (
618+ "Using entity names for cif chains, but got invalid names "
619+ f"{ entity_names } ; must be in { _CHAIN_VOCAB } "
620+ )
621+
605622 ##
606623 ## Prepare batch
607624 ##
@@ -1004,10 +1021,15 @@ def avg_per_token_1d(x):
10041021 bfactors = scaled_plddt_scores_per_atom ,
10051022 output_batch = inputs ,
10061023 write_path = cif_out_path ,
1007- # Set asym names to be A, B, C, ...
1024+ # Set asym names to match entity names from fasta if requested;
1025+ # otherwise auto-generate A, B, C, ... sequentially
10081026 asym_entity_names = {
1009- i : get_chain_letter (i )
1010- for i in range (1 , len (feature_context .chains ) + 1 )
1027+ i : (
1028+ chain .entity_data .entity_name
1029+ if entity_names_as_chain_names_in_output_cif
1030+ else get_chain_letter (i )
1031+ )
1032+ for i , chain in enumerate (feature_context .chains , start = 1 )
10111033 },
10121034 )
10131035 cif_paths .append (cif_out_path )
0 commit comments