Skip to content

Commit 2a2d4a9

Browse files
authored
Allow using names specified in fasta file as chain names in cif (#378)
1 parent 79dc5b9 commit 2a2d4a9

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

chai_lab/chai1.py

Lines changed: 26 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,7 @@
9595
from chai_lab.data.features.generators.token_pair_pocket_restraint import (
9696
TokenPairPocketRestraint,
9797
)
98-
from chai_lab.data.io.cif_utils import get_chain_letter, save_to_cif
98+
from chai_lab.data.io.cif_utils import _CHAIN_VOCAB, get_chain_letter, save_to_cif
9999
from chai_lab.data.parsing.restraints import parse_pairwise_table
100100
from chai_lab.data.parsing.structure.entity_type import EntityType
101101
from chai_lab.model.diffusion_schedules import InferenceNoiseSchedule
@@ -514,6 +514,8 @@ def run_inference(
514514
seed: int | None = None,
515515
device: str | None = None,
516516
low_memory: bool = True,
517+
# IO options
518+
fasta_names_as_cif_chains: bool = False,
517519
) -> StructureCandidates:
518520
assert num_trunk_samples > 0 and num_diffn_samples > 0
519521
if output_dir.exists():
@@ -553,6 +555,7 @@ def run_inference(
553555
seed=seed + trunk_idx if seed is not None else None,
554556
device=torch_device,
555557
low_memory=low_memory,
558+
entity_names_as_chain_names_in_output_cif=fasta_names_as_cif_chains,
556559
)
557560
all_candidates.append(cand)
558561
return StructureCandidates.concat(all_candidates)
@@ -573,6 +576,7 @@ def run_folding_on_context(
573576
num_diffn_timesteps: int = 200,
574577
# all diffusion samples come from the same trunk
575578
num_diffn_samples: int = 5,
579+
entity_names_as_chain_names_in_output_cif: bool = False,
576580
seed: int | None = None,
577581
device: torch.device | None = None,
578582
low_memory: bool,
@@ -602,6 +606,19 @@ def run_folding_on_context(
602606
# NOTE profile MSA used only for statistics; no depth check
603607
feature_context.structure_context.report_bonds()
604608

609+
if entity_names_as_chain_names_in_output_cif:
610+
# Ensure that entity names are unique and are valid chain names
611+
entity_names: list[str] = [
612+
chain.entity_data.entity_name for chain in feature_context.chains
613+
]
614+
assert len(set(entity_names)) == len(
615+
entity_names
616+
), f"Using entity names for cif chains, but got duplicates: {entity_names}"
617+
assert all(e in _CHAIN_VOCAB for e in entity_names), (
618+
"Using entity names for cif chains, but got invalid names "
619+
f"{entity_names}; must be in {_CHAIN_VOCAB}"
620+
)
621+
605622
##
606623
## Prepare batch
607624
##
@@ -1004,10 +1021,15 @@ def avg_per_token_1d(x):
10041021
bfactors=scaled_plddt_scores_per_atom,
10051022
output_batch=inputs,
10061023
write_path=cif_out_path,
1007-
# Set asym names to be A, B, C, ...
1024+
# Set asym names to match entity names from fasta if requested;
1025+
# otherwise auto-generate A, B, C, ... sequentially
10081026
asym_entity_names={
1009-
i: get_chain_letter(i)
1010-
for i in range(1, len(feature_context.chains) + 1)
1027+
i: (
1028+
chain.entity_data.entity_name
1029+
if entity_names_as_chain_names_in_output_cif
1030+
else get_chain_letter(i)
1031+
)
1032+
for i, chain in enumerate(feature_context.chains, start=1)
10111033
},
10121034
)
10131035
cif_paths.append(cif_out_path)

0 commit comments

Comments
 (0)