-
Notifications
You must be signed in to change notification settings - Fork 261
Expand file tree
/
Copy pathtest_inference_dataset.py
More file actions
123 lines (104 loc) · 5.11 KB
/
test_inference_dataset.py
File metadata and controls
123 lines (104 loc) · 5.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) 2024 Chai Discovery, Inc.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for details.
"""
Tests for inference dataset.
"""
import pytest
import torch
from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw
from chai_lab.data.dataset.structure.all_atom_residue_tokenizer import (
AllAtomResidueTokenizer,
)
from chai_lab.data.dataset.structure.all_atom_structure_context import (
AllAtomStructureContext,
)
from chai_lab.data.dataset.structure.chain import Chain
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.data.sources.rdkit import RefConformerGenerator
@pytest.fixture
def tokenizer() -> AllAtomResidueTokenizer:
    """Provide an ``AllAtomResidueTokenizer`` built on a new ``RefConformerGenerator``."""
    conformer_generator = RefConformerGenerator()
    return AllAtomResidueTokenizer(conformer_generator)
def test_malformed_smiles(tokenizer: AllAtomResidueTokenizer):
    """A ligand with a malformed SMILES string is dropped from the loaded chains."""
    protein = EntityType.PROTEIN.value
    ligand = EntityType.LIGAND.value
    inputs = [
        Input("RKDESES", entity_type=protein, entity_name="foo"),
        # Malformed on purpose: the zinc ion should be written "[Zn+2]".
        Input("Zn", entity_type=ligand, entity_name="bar"),
        Input("RKEEE", entity_type=protein, entity_name="baz"),
        Input("EEEEEEEEEEEE", entity_type=protein, entity_name="boz"),
    ]

    chains = load_chains_from_raw(inputs, identifier="test", tokenizer=tokenizer)

    # Four inputs go in, but only the three protein chains are expected back.
    assert len(chains) == 3

    # Verify that each chain's entity data is paired with its own structure
    # context. NOTE: comparing token count against sequence length is only
    # valid because no residue here is tokenized per-atom.
    for loaded in chains:
        token_count = loaded.structure_context.num_tokens
        sequence_length = len(loaded.entity_data.full_sequence)
        assert token_count == sequence_length
def test_ions_parsing(tokenizer: AllAtomResidueTokenizer):
    """An ion given as a SMILES string keeps its charge and element."""
    magnesium = Input(
        "[Mg+2]", entity_type=EntityType.LIGAND.value, entity_name="foo"
    )
    chains = load_chains_from_raw([magnesium], identifier="foo", tokenizer=tokenizer)
    assert len(chains) == 1

    context = chains[0].structure_context
    # A lone ion: exactly one atom, carrying the +2 charge from the SMILES.
    assert context.num_atoms == 1
    assert context.atom_ref_charge == 2
    # 12 is the atomic number of magnesium.
    assert context.atom_ref_element.item() == 12
def test_protein_with_smiles(tokenizer: AllAtomResidueTokenizer):
    """Duplicated protein chains and duplicated SMILES ligands get correct ids.

    The complex is based on https://www.rcsb.org/structure/1AFS: two copies of
    one protein sequence plus two copies each of two ligands given as SMILES.
    """
    seq = "MDSISLRVALNDGNFIPVLGFGTTVPEKVAKDEVIKATKIAIDNGFRHFDSAYLYEVEEEVGQAIRSKIEDGTVKREDIFYTSKLWSTFHRPELVRTCLEKTLKSTQLDYVDLYIIHFPMALQPGDIFFPRDEHGKLLFETVDICDTWEAMEKCKDAGLAKSIGVSNFNCRQLERILNKPGLKYKPVCNQVECHLYLNQSKMLDYCKSKDIILVSYCTLGSSRDKTWVDQKSPVLLDDPVLCAIAKKYKQTPALVALRYQLQRGVVPLIRSFNAKRIKELTQVFEFQLASEDMKALDGLNRNFRYNNAKYFDDHPNHPFTDEN"
    nap = "NC(=O)c1ccc[n+](c1)[CH]2O[CH](CO[P]([O-])(=O)O[P](O)(=O)OC[CH]3O[CH]([CH](O[P](O)(O)=O)[CH]3O)n4cnc5c(N)ncnc45)[CH](O)[CH]2O"
    tes = "O=C4C=C3C(C2CCC1(C(CCC1O)C2CC3)C)(C)CC4"

    protein = EntityType.PROTEIN.value
    ligand = EntityType.LIGAND.value
    # (entity_name, sequence-or-SMILES, entity_type), in input order.
    specs = (
        ("A", seq, protein),
        ("B", seq, protein),
        ("C", nap, ligand),
        ("D", nap, ligand),
        ("E", tes, ligand),
        ("F", tes, ligand),
    )
    inputs = [
        Input(sequence, entity_type, entity_name=name)
        for name, sequence, entity_type in specs
    ]

    chains: list[Chain] = load_chains_from_raw(inputs, tokenizer=tokenizer)
    assert len(chains) == len(inputs)

    contexts = [chain.structure_context for chain in chains]
    example = AllAtomStructureContext.merge(contexts)

    def n_unique(values) -> int:
        """Count the distinct values in a tensor."""
        return torch.unique(values).numel()

    # Three distinct entities (one protein, two ligands) across six chains.
    assert n_unique(example.token_entity_id) == 3
    assert n_unique(example.token_asym_id) == 6

    # Protein tokens: a single entity present as two symmetric copies.
    is_protein = example.token_entity_type == EntityType.PROTEIN.value
    assert n_unique(example.token_entity_id[is_protein]) == 1
    is_protein = example.token_entity_type == EntityType.PROTEIN.value
    assert n_unique(example.token_sym_id[is_protein]) == 2

    # Ligand tokens: two entities, each present as two copies.
    is_ligand = example.token_entity_type == EntityType.LIGAND.value
    assert n_unique(example.token_entity_id[is_ligand]) == 2
    is_ligand = example.token_entity_type == EntityType.LIGAND.value
    assert n_unique(example.token_sym_id[is_ligand]) == 2
def test_entity_names_as_subchain(tokenizer: AllAtomResidueTokenizer):
    """``entity_name_as_subchain`` decides where subchain ids come from.

    With the flag off the two chains get the ids "A" and "B"; with it on they
    take the entity names supplied on the inputs ("X" and "Y").
    """
    inputs = [
        Input(sequence="RKDES", entity_type=EntityType.PROTEIN.value, entity_name="X"),
        Input(sequence="GHGHGH", entity_type=EntityType.PROTEIN.value, entity_name="Y"),
    ]

    # (flag value, expected subchain ids), checked in the same order as before.
    cases = (
        (False, ["A", "B"]),
        (True, ["X", "Y"]),
    )
    for use_entity_name, expected_ids in cases:
        loaded = load_chains_from_raw(
            inputs=inputs,
            entity_name_as_subchain=use_entity_name,
            tokenizer=tokenizer,
        )
        subchain_ids = [chain.entity_data.subchain_id for chain in loaded]
        assert subchain_ids == expected_ids