From 8ac0d2fc8b049b376f656d1dd1bcbec2ae42aad0 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Mon, 2 Jun 2025 22:37:18 +0000 Subject: [PATCH 1/8] Handle case when no comments are provided --- chai_lab/data/parsing/restraints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chai_lab/data/parsing/restraints.py b/chai_lab/data/parsing/restraints.py index 294c7453..c4820ece 100644 --- a/chai_lab/data/parsing/restraints.py +++ b/chai_lab/data/parsing/restraints.py @@ -166,7 +166,7 @@ def _parse_row(row: pd.Series) -> PairwiseInteraction: min_dist_angstrom=row["min_distance_angstrom"], connection_type=PairwiseInteractionType(row["connection_type"]), confidence=row["confidence"], - comment=row["comment"], + comment="" if pd.isna(row["comment"]) else row["comment"], ) From 176ead9ae3e38312311470913e4cda2d8f9fdf56 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Mon, 2 Jun 2025 23:11:27 +0000 Subject: [PATCH 2/8] Allow subchain entity name as subchain --- chai_lab/chai1.py | 7 ++++++- chai_lab/data/dataset/inference_dataset.py | 18 +++++++++++++++--- tests/test_inference_dataset.py | 17 +++++++++++++++++ 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index eb267ffd..d7254113 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -339,6 +339,7 @@ def make_all_atom_feature_context( fasta_file: Path, *, output_dir: Path, + entity_name_as_subchain: bool = False, use_esm_embeddings: bool = True, use_msa_server: bool = False, msa_server_url: str = "https://api.colabfold.com", @@ -368,7 +369,9 @@ def make_all_atom_feature_context( ) # Load structure context - chains = load_chains_from_raw(fasta_inputs) + chains = load_chains_from_raw( + fasta_inputs, entity_name_as_subchain=entity_name_as_subchain + ) del fasta_inputs # Do not reference inputs after creating chains from them merged_context = AllAtomStructureContext.merge( @@ -525,9 +528,11 @@ def run_inference( torch_device = torch.device(device if device is 
not None else "cuda:0") + # NOTE if fastas are cif chain names, we use this to parse chains as well feature_context = make_all_atom_feature_context( fasta_file=fasta_file, output_dir=output_dir, + entity_name_as_subchain=fasta_names_as_cif_chains, use_esm_embeddings=use_esm_embeddings, use_msa_server=use_msa_server, msa_server_url=msa_server_url, diff --git a/chai_lab/data/dataset/inference_dataset.py b/chai_lab/data/dataset/inference_dataset.py index 7ac2fd7a..094f95ed 100644 --- a/chai_lab/data/dataset/inference_dataset.py +++ b/chai_lab/data/dataset/inference_dataset.py @@ -92,7 +92,9 @@ def _synth_subchain_id(idx: int) -> str: def raw_inputs_to_entitites_data( - inputs: list[Input], identifier: str = "test" + inputs: list[Input], + entity_name_as_subchain: bool = False, + identifier: str = "test", ) -> list[AllAtomEntityData]: """Load an entity for each raw input.""" entities = [] @@ -153,12 +155,20 @@ def raw_inputs_to_entitites_data( resolution=0.0, release_datetime=datetime.now(), pdb_id=identifier, - source_pdb_chain_id=_synth_subchain_id(i), + source_pdb_chain_id=( + _synth_subchain_id(i) + if not entity_name_as_subchain + else input.entity_name + ), entity_name=input.entity_name, entity_id=entity_id, method="none", entity_type=entity_type, - subchain_id=_synth_subchain_id(i), + subchain_id=( + _synth_subchain_id(i) + if not entity_name_as_subchain + else input.entity_name + ), original_record=input.sequence, ) ) @@ -170,6 +180,7 @@ def raw_inputs_to_entitites_data( def load_chains_from_raw( inputs: list[Input], identifier: str = "test", + entity_name_as_subchain: bool = False, tokenizer: AllAtomResidueTokenizer | None = None, ) -> list[Chain]: """ @@ -183,6 +194,7 @@ def load_chains_from_raw( # Extract the entity data from the gemmi structure. 
entities: list[AllAtomEntityData] = raw_inputs_to_entitites_data( inputs, + entity_name_as_subchain=entity_name_as_subchain, identifier=identifier, ) diff --git a/tests/test_inference_dataset.py b/tests/test_inference_dataset.py index c769eb8f..5571bd61 100644 --- a/tests/test_inference_dataset.py +++ b/tests/test_inference_dataset.py @@ -104,3 +104,20 @@ def test_protein_with_smiles(tokenizer: AllAtomResidueTokenizer): example.token_entity_type == EntityType.LIGAND.value ] assert torch.unique(lig_sym_ids).numel() == 2 # Two copies of each ligand + + +def test_entity_names_as_subchain(tokenizer: AllAtomResidueTokenizer): + inputs = [ + Input(sequence="RKDES", entity_type=EntityType.PROTEIN.value, entity_name="X"), + Input(sequence="GHGHGH", entity_type=EntityType.PROTEIN.value, entity_name="Y"), + ] + + chains = load_chains_from_raw( + inputs=inputs, entity_name_as_subchain=False, tokenizer=tokenizer + ) + assert [chain.entity_data.subchain_id for chain in chains] == ["A", "B"] + + chain_manual_named = load_chains_from_raw( + inputs=inputs, entity_name_as_subchain=True, tokenizer=tokenizer + ) + assert [chain.entity_data.subchain_id for chain in chain_manual_named] == ["X", "Y"] From d9f638f8c2dceac6cd8d7589de7d9360613fc281 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Tue, 3 Jun 2025 00:09:20 +0000 Subject: [PATCH 3/8] Add tests --- tests/test_restraints.py | 91 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 90 insertions(+), 1 deletion(-) diff --git a/tests/test_restraints.py b/tests/test_restraints.py index b4db0ef4..337f71e8 100644 --- a/tests/test_restraints.py +++ b/tests/test_restraints.py @@ -2,7 +2,29 @@ # Licensed under the Apache License, Version 2.0. # See the LICENSE file for details. 
-from chai_lab.data.parsing.restraints import parse_pairwise_table +import pytest +import torch + +from chai_lab import chai1 +from chai_lab.data.collate.collate import Collate +from chai_lab.data.dataset.all_atom_feature_context import AllAtomFeatureContext +from chai_lab.data.dataset.constraints.restraint_context import ( + load_manual_restraints_for_chai1, +) +from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext +from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw +from chai_lab.data.dataset.msas.msa_context import MSAContext +from chai_lab.data.dataset.structure.all_atom_structure_context import ( + AllAtomStructureContext, +) +from chai_lab.data.dataset.templates.context import TemplateContext +from chai_lab.data.parsing.msas.data_source import MSADataSource +from chai_lab.data.parsing.restraints import ( + PairwiseInteraction, + PairwiseInteractionType, + parse_pairwise_table, +) +from chai_lab.data.parsing.structure.entity_type import EntityType from chai_lab.utils.paths import repo_root @@ -13,3 +35,70 @@ def test_loading_restraints(): assert len(parse_pairwise_table(contact_path)) > 0 assert len(parse_pairwise_table(pocket_path)) > 0 + + +@pytest.mark.parametrize( + "entity_name_as_subchain", + [True, False], +) +def test_restraints_with_manual_chain_names(entity_name_as_subchain: bool): + """when entity name is used as chain name, restraints are also specified by entity name.""" + inputs = [ + Input("GGGGGG", entity_type=EntityType.PROTEIN.value, entity_name="G"), + Input("HHHHHH", entity_type=EntityType.PROTEIN.value, entity_name="H"), + ] + + restraints = [ + PairwiseInteraction( + chainA="G", + res_idxA="G1", + atom_nameA="", + chainB="H", + res_idxB="H1", + atom_nameB="", + connection_type=PairwiseInteractionType.CONTACT, + ) + ] + + chains = load_chains_from_raw( + inputs=inputs, entity_name_as_subchain=entity_name_as_subchain + ) + assert len(chains) == 2 + + structure_context = 
AllAtomStructureContext.merge( + [c.structure_context for c in chains] + ) + ft_ctx = AllAtomFeatureContext( + chains=chains, + structure_context=structure_context, + msa_context=MSAContext.create_single_seq( + dataset_source=MSADataSource.QUERY, + tokens=structure_context.token_residue_type.to(dtype=torch.uint8), + ), + profile_msa_context=MSAContext.create_single_seq( + dataset_source=MSADataSource.QUERY, + tokens=structure_context.token_residue_type.to(dtype=torch.uint8), + ), + template_context=TemplateContext.empty( + n_templates=1, n_tokens=structure_context.num_tokens + ), + embedding_context=EmbeddingContext.empty(n_tokens=structure_context.num_tokens), + restraint_context=load_manual_restraints_for_chai1(chains, None, restraints), + ) + + collator = Collate( + feature_factory=chai1.feature_factory, num_key_atoms=128, num_query_atoms=32 + ) + + batch = collator([ft_ctx]) + + assert batch + ft = batch["features"] + contact_ft = ft["TokenDistanceRestraint"] + contact_ft_all_null = torch.allclose(contact_ft, torch.tensor(-1).float()) + + if entity_name_as_subchain: + # Loaded correctly, there are + assert not contact_ft_all_null + else: + assert contact_ft_all_null From eef05b9170529a0b5b0e031bc6e2c1a5678b6c80 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Tue, 3 Jun 2025 00:09:49 +0000 Subject: [PATCH 4/8] comments --- tests/test_restraints.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_restraints.py b/tests/test_restraints.py index 337f71e8..f4b97bd5 100644 --- a/tests/test_restraints.py +++ b/tests/test_restraints.py @@ -98,7 +98,8 @@ def test_restraints_with_manual_chain_names(entity_name_as_subchain: bool): contact_ft_all_null = torch.allclose(contact_ft, torch.tensor(-1).float()) if entity_name_as_subchain: - # Loaded correctly, there are + # Loaded correctly, so some should not be null assert not contact_ft_all_null else: + # Did not load; all null assert contact_ft_all_null From f433e90b86ec64e0f23fd3f2e383ee52b4dd6f68 
Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Tue, 3 Jun 2025 00:14:54 +0000 Subject: [PATCH 5/8] Update comment --- chai_lab/chai1.py | 2 +- tests/test_restraints.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index d7254113..e0900f41 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -528,7 +528,7 @@ def run_inference( torch_device = torch.device(device if device is not None else "cuda:0") - # NOTE if fastas are cif chain names, we use this to parse chains as well + # NOTE if fastas are cif chain names, we also use them to parse chains and restraints feature_context = make_all_atom_feature_context( fasta_file=fasta_file, output_dir=output_dir, diff --git a/tests/test_restraints.py b/tests/test_restraints.py index f4b97bd5..486de387 100644 --- a/tests/test_restraints.py +++ b/tests/test_restraints.py @@ -42,7 +42,7 @@ def test_loading_restraints(): [True, False], ) def test_restraints_with_manual_chain_names(entity_name_as_subchain: bool): - """when entity name is used as chain name, restraints are also specified by entity name.""" + """when entity name is used as chain name, restraints are given w.r.t. 
entity name.""" inputs = [ Input("GGGGGG", entity_type=EntityType.PROTEIN.value, entity_name="G"), Input("HHHHHH", entity_type=EntityType.PROTEIN.value, entity_name="H"), From 2e9e67962a4fbb2614c875b670578bc94e027533 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Tue, 3 Jun 2025 00:24:12 +0000 Subject: [PATCH 6/8] Expand test --- tests/test_restraints.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/test_restraints.py b/tests/test_restraints.py index 486de387..787fa581 100644 --- a/tests/test_restraints.py +++ b/tests/test_restraints.py @@ -57,7 +57,16 @@ def test_restraints_with_manual_chain_names(entity_name_as_subchain: bool): res_idxB="H1", atom_nameB="", connection_type=PairwiseInteractionType.CONTACT, - ) + ), + PairwiseInteraction( + chainA="G", + res_idxA="", + atom_nameA="", + chainB="H", + res_idxB="H1", + atom_nameB="", + connection_type=PairwiseInteractionType.POCKET, + ), ] chains = load_chains_from_raw( @@ -96,10 +105,14 @@ def test_restraints_with_manual_chain_names(entity_name_as_subchain: bool): ft = batch["features"] contact_ft = ft["TokenDistanceRestraint"] contact_ft_all_null = torch.allclose(contact_ft, torch.tensor(-1).float()) + pocket_ft = ft["TokenPairPocketRestraint"] + pocket_ft_all_null = torch.allclose(pocket_ft, torch.tensor(-1).float()) if entity_name_as_subchain: # Loaded correctly, so some should not be null assert not contact_ft_all_null + assert not pocket_ft_all_null else: # Did not load; all null assert contact_ft_all_null + assert pocket_ft_all_null From 9fb6532c8422113249430b9fa9da22290fd73b0b Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Tue, 3 Jun 2025 00:32:17 +0000 Subject: [PATCH 7/8] Even more comprehensive tests --- tests/test_restraints.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/tests/test_restraints.py b/tests/test_restraints.py index 787fa581..60eb6af0 100644 --- a/tests/test_restraints.py +++ b/tests/test_restraints.py @@ 
-38,11 +38,18 @@ def test_loading_restraints(): @pytest.mark.parametrize( - "entity_name_as_subchain", - [True, False], + "entity_name_as_subchain,restraints_wrt_entity_names", + [ + (True, True), # Both w.r.t. fasta-derived entity names; should load correctly + (True, False), # Mismatched; should not load + (False, True), # Mismatched; should not load + (False, False), # Both given w.r.t. automatic names; should load correctly + ], ) -def test_restraints_with_manual_chain_names(entity_name_as_subchain: bool): - """when entity name is used as chain name, restraints are given w.r.t. entity name.""" +def test_restraints_with_manual_chain_names( + entity_name_as_subchain: bool, restraints_wrt_entity_names: bool +): + """subchain ID scheme and restraint scheme compatibility""" inputs = [ Input("GGGGGG", entity_type=EntityType.PROTEIN.value, entity_name="G"), Input("HHHHHH", entity_type=EntityType.PROTEIN.value, entity_name="H"), @@ -50,19 +57,19 @@ def test_restraints_with_manual_chain_names(entity_name_as_subchain: bool): restraints = [ PairwiseInteraction( - chainA="G", + chainA="G" if restraints_wrt_entity_names else "A", res_idxA="G1", atom_nameA="", - chainB="H", + chainB="H" if restraints_wrt_entity_names else "B", res_idxB="H1", atom_nameB="", connection_type=PairwiseInteractionType.CONTACT, ), PairwiseInteraction( - chainA="G", + chainA="G" if restraints_wrt_entity_names else "A", res_idxA="", atom_nameA="", - chainB="H", + chainB="H" if restraints_wrt_entity_names else "B", res_idxB="H1", atom_nameB="", connection_type=PairwiseInteractionType.POCKET, @@ -108,7 +115,7 @@ def test_restraints_with_manual_chain_names(entity_name_as_subchain: bool): pocket_ft = ft["TokenPairPocketRestraint"] pocket_ft_all_null = torch.allclose(pocket_ft, torch.tensor(-1).float()) - if entity_name_as_subchain: + if entity_name_as_subchain == restraints_wrt_entity_names: # Loaded correctly, so some should not be null assert not contact_ft_all_null assert not pocket_ft_all_null From 
f775b637064a99eaafac90b44e7cc28446583ad6 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Tue, 3 Jun 2025 00:36:36 +0000 Subject: [PATCH 8/8] Add docstring --- chai_lab/chai1.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index e0900f41..193aa56e 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -520,6 +520,12 @@ def run_inference( # IO options fasta_names_as_cif_chains: bool = False, ) -> StructureCandidates: + """Runs inference on sequences in the provided fasta file. + + Important notes: + - If fasta_names_as_cif_chains is True, fasta entity names are used for parsing + and writing chains. Restraints must ALSO be named w.r.t. fasta names. + """ assert num_trunk_samples > 0 and num_diffn_samples > 0 if output_dir.exists(): assert not any(