diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index eb267ff..193aa56 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -339,6 +339,7 @@ def make_all_atom_feature_context( fasta_file: Path, *, output_dir: Path, + entity_name_as_subchain: bool = False, use_esm_embeddings: bool = True, use_msa_server: bool = False, msa_server_url: str = "https://api.colabfold.com", @@ -368,7 +369,9 @@ def make_all_atom_feature_context( ) # Load structure context - chains = load_chains_from_raw(fasta_inputs) + chains = load_chains_from_raw( + fasta_inputs, entity_name_as_subchain=entity_name_as_subchain + ) del fasta_inputs # Do not reference inputs after creating chains from them merged_context = AllAtomStructureContext.merge( @@ -517,6 +520,12 @@ def run_inference( # IO options fasta_names_as_cif_chains: bool = False, ) -> StructureCandidates: + """Runs inference on sequences in the provided fasta file. + + Important notes: + - If fasta_names_as_cif_chains is True, fasta entity names are used for parsing + and writing chains. Restraints must ALSO be named w.r.t. fasta names. 
+ """ assert num_trunk_samples > 0 and num_diffn_samples > 0 if output_dir.exists(): assert not any( @@ -525,9 +534,11 @@ def run_inference( torch_device = torch.device(device if device is not None else "cuda:0") + # NOTE if fastas are cif chain names, we also use them to parse chains and restraints feature_context = make_all_atom_feature_context( fasta_file=fasta_file, output_dir=output_dir, + entity_name_as_subchain=fasta_names_as_cif_chains, use_esm_embeddings=use_esm_embeddings, use_msa_server=use_msa_server, msa_server_url=msa_server_url, diff --git a/chai_lab/data/dataset/inference_dataset.py b/chai_lab/data/dataset/inference_dataset.py index 7ac2fd7..094f95e 100644 --- a/chai_lab/data/dataset/inference_dataset.py +++ b/chai_lab/data/dataset/inference_dataset.py @@ -92,7 +92,9 @@ def _synth_subchain_id(idx: int) -> str: def raw_inputs_to_entitites_data( - inputs: list[Input], identifier: str = "test" + inputs: list[Input], + entity_name_as_subchain: bool = False, + identifier: str = "test", ) -> list[AllAtomEntityData]: """Load an entity for each raw input.""" entities = [] @@ -153,12 +155,20 @@ def raw_inputs_to_entitites_data( resolution=0.0, release_datetime=datetime.now(), pdb_id=identifier, - source_pdb_chain_id=_synth_subchain_id(i), + source_pdb_chain_id=( + _synth_subchain_id(i) + if not entity_name_as_subchain + else input.entity_name + ), entity_name=input.entity_name, entity_id=entity_id, method="none", entity_type=entity_type, - subchain_id=_synth_subchain_id(i), + subchain_id=( + _synth_subchain_id(i) + if not entity_name_as_subchain + else input.entity_name + ), original_record=input.sequence, ) ) @@ -170,6 +180,7 @@ def raw_inputs_to_entitites_data( def load_chains_from_raw( inputs: list[Input], identifier: str = "test", + entity_name_as_subchain: bool = False, tokenizer: AllAtomResidueTokenizer | None = None, ) -> list[Chain]: """ @@ -183,6 +194,7 @@ def load_chains_from_raw( # Extract the entity data from the gemmi structure. 
entities: list[AllAtomEntityData] = raw_inputs_to_entitites_data( inputs, + entity_name_as_subchain=entity_name_as_subchain, identifier=identifier, ) diff --git a/chai_lab/data/parsing/restraints.py b/chai_lab/data/parsing/restraints.py index 294c745..c4820ec 100644 --- a/chai_lab/data/parsing/restraints.py +++ b/chai_lab/data/parsing/restraints.py @@ -166,7 +166,7 @@ def _parse_row(row: pd.Series) -> PairwiseInteraction: min_dist_angstrom=row["min_distance_angstrom"], connection_type=PairwiseInteractionType(row["connection_type"]), confidence=row["confidence"], - comment=row["comment"], + comment="" if pd.isna(row["comment"]) else row["comment"], ) diff --git a/tests/test_inference_dataset.py b/tests/test_inference_dataset.py index c769eb8..5571bd6 100644 --- a/tests/test_inference_dataset.py +++ b/tests/test_inference_dataset.py @@ -104,3 +104,20 @@ def test_protein_with_smiles(tokenizer: AllAtomResidueTokenizer): example.token_entity_type == EntityType.LIGAND.value ] assert torch.unique(lig_sym_ids).numel() == 2 # Two copies of each ligand + + +def test_entity_names_as_subchain(tokenizer: AllAtomResidueTokenizer): + inputs = [ + Input(sequence="RKDES", entity_type=EntityType.PROTEIN.value, entity_name="X"), + Input(sequence="GHGHGH", entity_type=EntityType.PROTEIN.value, entity_name="Y"), + ] + + chains = load_chains_from_raw( + inputs=inputs, entity_name_as_subchain=False, tokenizer=tokenizer + ) + assert [chain.entity_data.subchain_id for chain in chains] == ["A", "B"] + + chain_manual_named = load_chains_from_raw( + inputs=inputs, entity_name_as_subchain=True, tokenizer=tokenizer + ) + assert [chain.entity_data.subchain_id for chain in chain_manual_named] == ["X", "Y"] diff --git a/tests/test_restraints.py b/tests/test_restraints.py index b4db0ef..60eb6af 100644 --- a/tests/test_restraints.py +++ b/tests/test_restraints.py @@ -2,7 +2,29 @@ # Licensed under the Apache License, Version 2.0. # See the LICENSE file for details. 
-from chai_lab.data.parsing.restraints import parse_pairwise_table
+import pytest
+import torch
+
+from chai_lab import chai1
+from chai_lab.data.collate.collate import Collate
+from chai_lab.data.dataset.all_atom_feature_context import AllAtomFeatureContext
+from chai_lab.data.dataset.constraints.restraint_context import (
+    load_manual_restraints_for_chai1,
+)
+from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext
+from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw
+from chai_lab.data.dataset.msas.msa_context import MSAContext
+from chai_lab.data.dataset.structure.all_atom_structure_context import (
+    AllAtomStructureContext,
+)
+from chai_lab.data.dataset.templates.context import TemplateContext
+from chai_lab.data.parsing.msas.data_source import MSADataSource
+from chai_lab.data.parsing.restraints import (
+    PairwiseInteraction,
+    PairwiseInteractionType,
+    parse_pairwise_table,
+)
+from chai_lab.data.parsing.structure.entity_type import EntityType
 from chai_lab.utils.paths import repo_root
@@ -13,3 +35,91 @@ def test_loading_restraints():
 
     assert len(parse_pairwise_table(contact_path)) > 0
     assert len(parse_pairwise_table(pocket_path)) > 0
+
+
+@pytest.mark.parametrize(
+    "entity_name_as_subchain,restraints_wrt_entity_names",
+    [
+        (True, True),  # Both given w.r.t. fasta-derived entity names; should load correctly
+        (True, False),  # Mismatched; should not load
+        (False, True),  # Mismatched; should not load
+        (False, False),  # Both w.r.t. automatic names; should load correctly
+    ],
+)
+def test_restraints_with_manual_chain_names(
+    entity_name_as_subchain: bool, restraints_wrt_entity_names: bool
+):
+    """subchain ID scheme and restraint scheme compatibility"""
+    inputs = [
+        Input("GGGGGG", entity_type=EntityType.PROTEIN.value, entity_name="G"),
+        Input("HHHHHH", entity_type=EntityType.PROTEIN.value, entity_name="H"),
+    ]
+
+    restraints = [
+        PairwiseInteraction(
+            chainA="G" if restraints_wrt_entity_names else "A",
+            res_idxA="G1",
+            atom_nameA="",
+            chainB="H" if restraints_wrt_entity_names else "B",
+            res_idxB="H1",
+            atom_nameB="",
+            connection_type=PairwiseInteractionType.CONTACT,
+        ),
+        PairwiseInteraction(
+            chainA="G" if restraints_wrt_entity_names else "A",
+            res_idxA="",
+            atom_nameA="",
+            chainB="H" if restraints_wrt_entity_names else "B",
+            res_idxB="H1",
+            atom_nameB="",
+            connection_type=PairwiseInteractionType.POCKET,
+        ),
+    ]
+
+    chains = load_chains_from_raw(
+        inputs=inputs, entity_name_as_subchain=entity_name_as_subchain
+    )
+    assert len(chains) == 2
+
+    structure_context = AllAtomStructureContext.merge(
+        [c.structure_context for c in chains]
+    )
+    ft_ctx = AllAtomFeatureContext(
+        chains=chains,
+        structure_context=structure_context,
+        msa_context=MSAContext.create_single_seq(
+            dataset_source=MSADataSource.QUERY,
+            tokens=structure_context.token_residue_type.to(dtype=torch.uint8),
+        ),
+        profile_msa_context=MSAContext.create_single_seq(
+            dataset_source=MSADataSource.QUERY,
+            tokens=structure_context.token_residue_type.to(dtype=torch.uint8),
+        ),
+        template_context=TemplateContext.empty(
+            n_templates=1, n_tokens=structure_context.num_tokens
+        ),
+        embedding_context=EmbeddingContext.empty(n_tokens=structure_context.num_tokens),
+        restraint_context=load_manual_restraints_for_chai1(chains, None, restraints),
+    )
+
+    collator = Collate(
+        feature_factory=chai1.feature_factory, num_key_atoms=128, num_query_atoms=32
+    )
+
+    batch = collator([ft_ctx])
+
+    assert batch
+    ft = batch["features"]
+    contact_ft = ft["TokenDistanceRestraint"]
+    contact_ft_all_null = torch.allclose(contact_ft, torch.tensor(-1).float())
+    pocket_ft = ft["TokenPairPocketRestraint"]
+    pocket_ft_all_null = torch.allclose(pocket_ft, torch.tensor(-1).float())
+
+    if entity_name_as_subchain == restraints_wrt_entity_names:
+        # Loaded correctly, so some should not be null
+        assert not contact_ft_all_null
+        assert not pocket_ft_all_null
+    else:
+        # Did not load; all null
+        assert contact_ft_all_null
+        assert pocket_ft_all_null