Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def make_all_atom_feature_context(
fasta_file: Path,
*,
output_dir: Path,
entity_name_as_subchain: bool = False,
use_esm_embeddings: bool = True,
use_msa_server: bool = False,
msa_server_url: str = "https://api.colabfold.com",
Expand Down Expand Up @@ -368,7 +369,9 @@ def make_all_atom_feature_context(
)

# Load structure context
chains = load_chains_from_raw(fasta_inputs)
chains = load_chains_from_raw(
fasta_inputs, entity_name_as_subchain=entity_name_as_subchain
)
del fasta_inputs # Do not reference inputs after creating chains from them

merged_context = AllAtomStructureContext.merge(
Expand Down Expand Up @@ -525,9 +528,11 @@ def run_inference(

torch_device = torch.device(device if device is not None else "cuda:0")

# NOTE if fastas are cif chain names, we use this to parse chains as well
feature_context = make_all_atom_feature_context(
fasta_file=fasta_file,
output_dir=output_dir,
entity_name_as_subchain=fasta_names_as_cif_chains,
use_esm_embeddings=use_esm_embeddings,
use_msa_server=use_msa_server,
msa_server_url=msa_server_url,
Expand Down
18 changes: 15 additions & 3 deletions chai_lab/data/dataset/inference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def _synth_subchain_id(idx: int) -> str:


def raw_inputs_to_entitites_data(
inputs: list[Input], identifier: str = "test"
inputs: list[Input],
entity_name_as_subchain: bool = False,
identifier: str = "test",
) -> list[AllAtomEntityData]:
"""Load an entity for each raw input."""
entities = []
Expand Down Expand Up @@ -153,12 +155,20 @@ def raw_inputs_to_entitites_data(
resolution=0.0,
release_datetime=datetime.now(),
pdb_id=identifier,
source_pdb_chain_id=_synth_subchain_id(i),
source_pdb_chain_id=(
_synth_subchain_id(i)
if not entity_name_as_subchain
else input.entity_name
),
entity_name=input.entity_name,
entity_id=entity_id,
method="none",
entity_type=entity_type,
subchain_id=_synth_subchain_id(i),
subchain_id=(
_synth_subchain_id(i)
if not entity_name_as_subchain
else input.entity_name
),
original_record=input.sequence,
)
)
Expand All @@ -170,6 +180,7 @@ def raw_inputs_to_entitites_data(
def load_chains_from_raw(
inputs: list[Input],
identifier: str = "test",
entity_name_as_subchain: bool = False,
tokenizer: AllAtomResidueTokenizer | None = None,
) -> list[Chain]:
"""
Expand All @@ -183,6 +194,7 @@ def load_chains_from_raw(
# Extract the entity data from the gemmi structure.
entities: list[AllAtomEntityData] = raw_inputs_to_entitites_data(
inputs,
entity_name_as_subchain=entity_name_as_subchain,
identifier=identifier,
)

Expand Down
2 changes: 1 addition & 1 deletion chai_lab/data/parsing/restraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _parse_row(row: pd.Series) -> PairwiseInteraction:
min_dist_angstrom=row["min_distance_angstrom"],
connection_type=PairwiseInteractionType(row["connection_type"]),
confidence=row["confidence"],
comment=row["comment"],
comment="" if pd.isna(row["comment"]) else row["comment"],
)


Expand Down
17 changes: 17 additions & 0 deletions tests/test_inference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,20 @@ def test_protein_with_smiles(tokenizer: AllAtomResidueTokenizer):
example.token_entity_type == EntityType.LIGAND.value
]
assert torch.unique(lig_sym_ids).numel() == 2 # Two copies of each ligand


def test_entity_names_as_subchain(tokenizer: AllAtomResidueTokenizer):
    """Subchain ids come from entity names only when entity_name_as_subchain is set."""
    protein_inputs = [
        Input(sequence="RKDES", entity_type=EntityType.PROTEIN.value, entity_name="X"),
        Input(sequence="GHGHGH", entity_type=EntityType.PROTEIN.value, entity_name="Y"),
    ]

    def subchain_ids(entity_name_as_subchain: bool) -> list[str]:
        # Load chains from the shared inputs and collect their subchain identifiers.
        loaded = load_chains_from_raw(
            inputs=protein_inputs,
            entity_name_as_subchain=entity_name_as_subchain,
            tokenizer=tokenizer,
        )
        return [chain.entity_data.subchain_id for chain in loaded]

    # Default behavior: synthetic ids are assigned in order ("A", "B", ...).
    assert subchain_ids(entity_name_as_subchain=False) == ["A", "B"]
    # Opt-in behavior: caller-provided entity names are used verbatim.
    assert subchain_ids(entity_name_as_subchain=True) == ["X", "Y"]
92 changes: 91 additions & 1 deletion tests/test_restraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,29 @@
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for details.

from chai_lab.data.parsing.restraints import parse_pairwise_table
import pytest
import torch

from chai_lab import chai1
from chai_lab.data.collate.collate import Collate
from chai_lab.data.dataset.all_atom_feature_context import AllAtomFeatureContext
from chai_lab.data.dataset.constraints.restraint_context import (
load_manual_restraints_for_chai1,
)
from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext
from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw
from chai_lab.data.dataset.msas.msa_context import MSAContext
from chai_lab.data.dataset.structure.all_atom_structure_context import (
AllAtomStructureContext,
)
from chai_lab.data.dataset.templates.context import TemplateContext
from chai_lab.data.parsing.msas.data_source import MSADataSource
from chai_lab.data.parsing.restraints import (
PairwiseInteraction,
PairwiseInteractionType,
parse_pairwise_table,
)
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.utils.paths import repo_root


Expand All @@ -13,3 +35,71 @@ def test_loading_restraints():

assert len(parse_pairwise_table(contact_path)) > 0
assert len(parse_pairwise_table(pocket_path)) > 0


@pytest.mark.parametrize(
    "entity_name_as_subchain",
    [True, False],
)
def test_restraints_with_manual_chain_names(entity_name_as_subchain: bool):
    """Restraints keyed by entity name resolve only when entity names are used as subchain ids."""
    raw_inputs = [
        Input("GGGGGG", entity_type=EntityType.PROTEIN.value, entity_name="G"),
        Input("HHHHHH", entity_type=EntityType.PROTEIN.value, entity_name="H"),
    ]

    # A single contact restraint that names the chains by their entity names.
    manual_restraints = [
        PairwiseInteraction(
            chainA="G",
            res_idxA="G1",
            atom_nameA="",
            chainB="H",
            res_idxB="H1",
            atom_nameB="",
            connection_type=PairwiseInteractionType.CONTACT,
        )
    ]

    loaded_chains = load_chains_from_raw(
        inputs=raw_inputs, entity_name_as_subchain=entity_name_as_subchain
    )
    assert len(loaded_chains) == 2

    merged_structure = AllAtomStructureContext.merge(
        [chain.structure_context for chain in loaded_chains]
    )
    # Single-sequence MSA contexts are built from the query tokens themselves.
    query_tokens = merged_structure.token_residue_type.to(dtype=torch.uint8)
    n_tokens = merged_structure.num_tokens

    feature_ctx = AllAtomFeatureContext(
        chains=loaded_chains,
        structure_context=merged_structure,
        msa_context=MSAContext.create_single_seq(
            dataset_source=MSADataSource.QUERY, tokens=query_tokens
        ),
        profile_msa_context=MSAContext.create_single_seq(
            dataset_source=MSADataSource.QUERY, tokens=query_tokens
        ),
        template_context=TemplateContext.empty(n_templates=1, n_tokens=n_tokens),
        embedding_context=EmbeddingContext.empty(n_tokens=n_tokens),
        restraint_context=load_manual_restraints_for_chai1(
            loaded_chains, None, manual_restraints
        ),
    )

    batch = Collate(
        feature_factory=chai1.feature_factory, num_key_atoms=128, num_query_atoms=32
    )([feature_ctx])
    assert batch

    contact_feature = batch["features"]["TokenDistanceRestraint"]
    feature_is_all_null = torch.allclose(contact_feature, torch.tensor(-1).float())

    if entity_name_as_subchain:
        # Restraint chain names match the subchain ids, so the feature is populated.
        assert not feature_is_all_null
    else:
        # Synthetic subchain ids never match "G"/"H", so nothing is loaded.
        assert feature_is_all_null