Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def make_all_atom_feature_context(
fasta_file: Path,
*,
output_dir: Path,
entity_name_as_subchain: bool = False,
use_esm_embeddings: bool = True,
use_msa_server: bool = False,
msa_server_url: str = "https://api.colabfold.com",
Expand Down Expand Up @@ -368,7 +369,9 @@ def make_all_atom_feature_context(
)

# Load structure context
chains = load_chains_from_raw(fasta_inputs)
chains = load_chains_from_raw(
fasta_inputs, entity_name_as_subchain=entity_name_as_subchain
)
del fasta_inputs # Do not reference inputs after creating chains from them

merged_context = AllAtomStructureContext.merge(
Expand Down Expand Up @@ -517,6 +520,12 @@ def run_inference(
# IO options
fasta_names_as_cif_chains: bool = False,
) -> StructureCandidates:
"""Runs inference on sequences in the provided fasta file.

Important notes:
- If fasta_names_as_cif_chains is True, fasta entity names are used for parsing
and writing chains. Restraints must ALSO be named w.r.t. fasta names.
"""
assert num_trunk_samples > 0 and num_diffn_samples > 0
if output_dir.exists():
assert not any(
Expand All @@ -525,9 +534,11 @@ def run_inference(

torch_device = torch.device(device if device is not None else "cuda:0")

# NOTE if fastas are cif chain names, we also use them to parse chains and restraints
feature_context = make_all_atom_feature_context(
fasta_file=fasta_file,
output_dir=output_dir,
entity_name_as_subchain=fasta_names_as_cif_chains,
use_esm_embeddings=use_esm_embeddings,
use_msa_server=use_msa_server,
msa_server_url=msa_server_url,
Expand Down
18 changes: 15 additions & 3 deletions chai_lab/data/dataset/inference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def _synth_subchain_id(idx: int) -> str:


def raw_inputs_to_entitites_data(
inputs: list[Input], identifier: str = "test"
inputs: list[Input],
entity_name_as_subchain: bool = False,
identifier: str = "test",
) -> list[AllAtomEntityData]:
"""Load an entity for each raw input."""
entities = []
Expand Down Expand Up @@ -153,12 +155,20 @@ def raw_inputs_to_entitites_data(
resolution=0.0,
release_datetime=datetime.now(),
pdb_id=identifier,
source_pdb_chain_id=_synth_subchain_id(i),
source_pdb_chain_id=(
_synth_subchain_id(i)
if not entity_name_as_subchain
else input.entity_name
),
entity_name=input.entity_name,
entity_id=entity_id,
method="none",
entity_type=entity_type,
subchain_id=_synth_subchain_id(i),
subchain_id=(
_synth_subchain_id(i)
if not entity_name_as_subchain
else input.entity_name
),
original_record=input.sequence,
)
)
Expand All @@ -170,6 +180,7 @@ def raw_inputs_to_entitites_data(
def load_chains_from_raw(
inputs: list[Input],
identifier: str = "test",
entity_name_as_subchain: bool = False,
tokenizer: AllAtomResidueTokenizer | None = None,
) -> list[Chain]:
"""
Expand All @@ -183,6 +194,7 @@ def load_chains_from_raw(
# Extract the entity data from the gemmi structure.
entities: list[AllAtomEntityData] = raw_inputs_to_entitites_data(
inputs,
entity_name_as_subchain=entity_name_as_subchain,
identifier=identifier,
)

Expand Down
2 changes: 1 addition & 1 deletion chai_lab/data/parsing/restraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _parse_row(row: pd.Series) -> PairwiseInteraction:
min_dist_angstrom=row["min_distance_angstrom"],
connection_type=PairwiseInteractionType(row["connection_type"]),
confidence=row["confidence"],
comment=row["comment"],
comment="" if pd.isna(row["comment"]) else row["comment"],
)


Expand Down
17 changes: 17 additions & 0 deletions tests/test_inference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,20 @@ def test_protein_with_smiles(tokenizer: AllAtomResidueTokenizer):
example.token_entity_type == EntityType.LIGAND.value
]
assert torch.unique(lig_sym_ids).numel() == 2 # Two copies of each ligand


def test_entity_names_as_subchain(tokenizer: AllAtomResidueTokenizer):
    """Subchain IDs are synthesized letters by default, and are taken from the
    fasta entity names when entity_name_as_subchain is enabled."""
    raw_inputs = [
        Input(sequence="RKDES", entity_type=EntityType.PROTEIN.value, entity_name="X"),
        Input(sequence="GHGHGH", entity_type=EntityType.PROTEIN.value, entity_name="Y"),
    ]

    def subchain_ids(use_entity_names: bool) -> list[str]:
        # Load the same inputs under the requested naming scheme and collect
        # the resulting per-chain subchain IDs.
        loaded = load_chains_from_raw(
            inputs=raw_inputs,
            entity_name_as_subchain=use_entity_names,
            tokenizer=tokenizer,
        )
        return [c.entity_data.subchain_id for c in loaded]

    # Default scheme: automatically synthesized subchain letters.
    assert subchain_ids(False) == ["A", "B"]
    # Opt-in scheme: IDs come from the fasta entity names.
    assert subchain_ids(True) == ["X", "Y"]
112 changes: 111 additions & 1 deletion tests/test_restraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,29 @@
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for details.

from chai_lab.data.parsing.restraints import parse_pairwise_table
import pytest
import torch

from chai_lab import chai1
from chai_lab.data.collate.collate import Collate
from chai_lab.data.dataset.all_atom_feature_context import AllAtomFeatureContext
from chai_lab.data.dataset.constraints.restraint_context import (
load_manual_restraints_for_chai1,
)
from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext
from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw
from chai_lab.data.dataset.msas.msa_context import MSAContext
from chai_lab.data.dataset.structure.all_atom_structure_context import (
AllAtomStructureContext,
)
from chai_lab.data.dataset.templates.context import TemplateContext
from chai_lab.data.parsing.msas.data_source import MSADataSource
from chai_lab.data.parsing.restraints import (
PairwiseInteraction,
PairwiseInteractionType,
parse_pairwise_table,
)
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.utils.paths import repo_root


Expand All @@ -13,3 +35,91 @@ def test_loading_restraints():

assert len(parse_pairwise_table(contact_path)) > 0
assert len(parse_pairwise_table(pocket_path)) > 0


@pytest.mark.parametrize(
    "entity_name_as_subchain,restraints_wrt_entity_names",
    [
        (True, True),  # Both w.r.t. fasta entity names; should load correctly
        (True, False),  # Mismatched; should not load
        (False, True),  # Mismatched; should not load
        (False, False),  # Both w.r.t. automatic subchain names; should load correctly
    ],
)
def test_restraints_with_manual_chain_names(
    entity_name_as_subchain: bool, restraints_wrt_entity_names: bool
):
    """subchain ID scheme and restraint scheme compatibility

    Chains can be identified either by automatically synthesized subchain IDs
    ("A", "B", ...) or, when entity_name_as_subchain is set, by their fasta
    entity names. Restraints reference chains by ID, so they resolve only when
    both sides use the same naming scheme; a mismatch should yield all-null
    restraint features rather than an error.
    """
    # Two entities whose fasta names ("G", "H") differ from the synthesized
    # subchain IDs ("A", "B"), so the two schemes are distinguishable.
    inputs = [
        Input("GGGGGG", entity_type=EntityType.PROTEIN.value, entity_name="G"),
        Input("HHHHHH", entity_type=EntityType.PROTEIN.value, entity_name="H"),
    ]

    # One contact and one pocket restraint, named per the scheme under test.
    restraints = [
        PairwiseInteraction(
            chainA="G" if restraints_wrt_entity_names else "A",
            res_idxA="G1",
            atom_nameA="",
            chainB="H" if restraints_wrt_entity_names else "B",
            res_idxB="H1",
            atom_nameB="",
            connection_type=PairwiseInteractionType.CONTACT,
        ),
        PairwiseInteraction(
            chainA="G" if restraints_wrt_entity_names else "A",
            res_idxA="",
            atom_nameA="",
            chainB="H" if restraints_wrt_entity_names else "B",
            res_idxB="H1",
            atom_nameB="",
            connection_type=PairwiseInteractionType.POCKET,
        ),
    ]

    chains = load_chains_from_raw(
        inputs=inputs, entity_name_as_subchain=entity_name_as_subchain
    )
    assert len(chains) == 2

    # Build a minimal feature context: merged structure, single-sequence MSAs,
    # empty templates/embeddings, and the restraints loaded above.
    structure_context = AllAtomStructureContext.merge(
        [c.structure_context for c in chains]
    )
    ft_ctx = AllAtomFeatureContext(
        chains=chains,
        structure_context=structure_context,
        msa_context=MSAContext.create_single_seq(
            dataset_source=MSADataSource.QUERY,
            tokens=structure_context.token_residue_type.to(dtype=torch.uint8),
        ),
        profile_msa_context=MSAContext.create_single_seq(
            dataset_source=MSADataSource.QUERY,
            tokens=structure_context.token_residue_type.to(dtype=torch.uint8),
        ),
        template_context=TemplateContext.empty(
            n_templates=1, n_tokens=structure_context.num_tokens
        ),
        embedding_context=EmbeddingContext.empty(n_tokens=structure_context.num_tokens),
        restraint_context=load_manual_restraints_for_chai1(chains, None, restraints),
    )

    collator = Collate(
        feature_factory=chai1.feature_factory, num_key_atoms=128, num_query_atoms=32
    )

    batch = collator([ft_ctx])

    assert batch
    # NOTE(review): -1 appears to be the "no restraint" fill value for these
    # features — an all -1 tensor means no restraint was applied.
    ft = batch["features"]
    contact_ft = ft["TokenDistanceRestraint"]
    contact_ft_all_null = torch.allclose(contact_ft, torch.tensor(-1).float())
    pocket_ft = ft["TokenPairPocketRestraint"]
    pocket_ft_all_null = torch.allclose(pocket_ft, torch.tensor(-1).float())

    if entity_name_as_subchain == restraints_wrt_entity_names:
        # Loaded correctly, so some should not be null
        assert not contact_ft_all_null
        assert not pocket_ft_all_null
    else:
        # Did not load; all null
        assert contact_ft_all_null
        assert pocket_ft_all_null