-
Notifications
You must be signed in to change notification settings - Fork 262
Expand file tree
/
Copy pathtest_restraints.py
More file actions
125 lines (111 loc) · 4.6 KB
/
test_restraints.py
File metadata and controls
125 lines (111 loc) · 4.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright (c) 2024 Chai Discovery, Inc.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for details.
import pytest
import torch
from chai_lab import chai1
from chai_lab.data.collate.collate import Collate
from chai_lab.data.dataset.all_atom_feature_context import AllAtomFeatureContext
from chai_lab.data.dataset.constraints.restraint_context import (
load_manual_restraints_for_chai1,
)
from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext
from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw
from chai_lab.data.dataset.msas.msa_context import MSAContext
from chai_lab.data.dataset.structure.all_atom_structure_context import (
AllAtomStructureContext,
)
from chai_lab.data.dataset.templates.context import TemplateContext
from chai_lab.data.parsing.msas.data_source import MSADataSource
from chai_lab.data.parsing.restraints import (
PairwiseInteraction,
PairwiseInteractionType,
parse_pairwise_table,
)
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.utils.paths import repo_root
def test_loading_restraints():
    """Smoke test: the bundled example restraint tables parse to non-empty results."""
    restraints_dir = repo_root / "examples" / "restraints"
    # Build both paths up front, then parse each one in turn.
    example_paths = [
        restraints_dir / file_name
        for file_name in ("contact.restraints", "pocket.restraints")
    ]
    for example_path in example_paths:
        assert len(parse_pairwise_table(example_path)) > 0
@pytest.mark.parametrize(
    "entity_name_as_subchain,restraints_wrt_entity_names",
    [
        (True, True),  # Both flags on: restraints name chains "G"/"H"; should load
        (True, False),  # Restraints name chains "A"/"B" instead; should not load
        (False, True),  # Restraints name chains "G"/"H" only; should not load
        (False, False),  # Both flags off: restraints name chains "A"/"B"; should load
    ],
)
def test_restraints_with_manual_chain_names(
    entity_name_as_subchain: bool, restraints_wrt_entity_names: bool
):
    """Restraints load only when chain naming and restraint naming agree.

    Two protein inputs with entity names "G" and "H" are loaded, with
    ``entity_name_as_subchain`` forwarded to ``load_chains_from_raw``. One
    contact and one pocket restraint are written against chains "G"/"H" when
    ``restraints_wrt_entity_names`` is true, and against "A"/"B" otherwise.
    After collation, the restraint features must contain non-null entries when
    the two flags are equal, and must be entirely -1 (null) when they differ.
    """
    protein_type = EntityType.PROTEIN.value
    inputs = [
        Input("GGGGGG", entity_type=protein_type, entity_name="G"),
        Input("HHHHHH", entity_type=protein_type, entity_name="H"),
    ]

    # Chain identifiers the restraints are written against.
    if restraints_wrt_entity_names:
        first_chain, second_chain = "G", "H"
    else:
        first_chain, second_chain = "A", "B"

    # The contact restraint pins a residue on both sides; the pocket restraint
    # leaves the residue on the first chain empty.
    restraints = [
        PairwiseInteraction(
            chainA=first_chain,
            res_idxA=first_residue,
            atom_nameA="",
            chainB=second_chain,
            res_idxB="H1",
            atom_nameB="",
            connection_type=interaction_type,
        )
        for first_residue, interaction_type in (
            ("G1", PairwiseInteractionType.CONTACT),
            ("", PairwiseInteractionType.POCKET),
        )
    ]

    chains = load_chains_from_raw(
        inputs=inputs, entity_name_as_subchain=entity_name_as_subchain
    )
    assert len(chains) == 2

    merged_structure = AllAtomStructureContext.merge(
        [chain.structure_context for chain in chains]
    )

    def single_sequence_msa():
        # Called once per MSA slot so each context gets its own token tensor.
        return MSAContext.create_single_seq(
            dataset_source=MSADataSource.QUERY,
            tokens=merged_structure.token_residue_type.to(dtype=torch.uint8),
        )

    main_msa = single_sequence_msa()
    profile_msa = single_sequence_msa()
    templates = TemplateContext.empty(
        n_templates=1, n_tokens=merged_structure.num_tokens
    )
    embeddings = EmbeddingContext.empty(n_tokens=merged_structure.num_tokens)
    loaded_restraints = load_manual_restraints_for_chai1(chains, None, restraints)

    feature_context = AllAtomFeatureContext(
        chains=chains,
        structure_context=merged_structure,
        msa_context=main_msa,
        profile_msa_context=profile_msa,
        template_context=templates,
        embedding_context=embeddings,
        restraint_context=loaded_restraints,
    )

    collate = Collate(
        feature_factory=chai1.feature_factory, num_key_atoms=128, num_query_atoms=32
    )
    batch = collate([feature_context])
    assert batch

    features = batch["features"]
    # A feature tensor filled entirely with -1 means no restraint was loaded.
    contact_all_null = torch.allclose(
        features["TokenDistanceRestraint"], torch.tensor(-1).float()
    )
    pocket_all_null = torch.allclose(
        features["TokenPairPocketRestraint"], torch.tensor(-1).float()
    )

    # Matching schemes -> restraints present; mismatched -> everything null.
    expect_all_null = entity_name_as_subchain != restraints_wrt_entity_names
    assert contact_all_null == expect_all_null
    assert pocket_all_null == expect_all_null