Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2aac06d

Browse files
authoredSep 24, 2024
also allow for customizing atom bonds (#292)
* also allow for customizing atom bonds * fix pdbinput and bump version
1 parent dee1e97 commit 2aac06d

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed
 

‎alphafold3_pytorch/inputs.py

+23-12
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,10 @@ def inner(*args, **kwargs):
241241

242242
# get atompair bonds functions
243243

244-
ATOM_BOND_INDEX = {symbol: (idx + 1) for idx, symbol in enumerate(ATOM_BONDS)}
245-
246244
@typecheck
247245
def get_atompair_ids(
248246
mol: Mol,
247+
atom_bonds: List[str],
249248
directed_bonds: bool
250249
) -> Int['m m'] | None:
251250

@@ -258,8 +257,10 @@ def get_atompair_ids(
258257
bonds = mol.GetBonds()
259258
num_bonds = len(bonds)
260259

261-
num_atom_bond_types = len(ATOM_BOND_INDEX)
262-
other_index = len(ATOM_BONDS) + 1
260+
atom_bond_index = {symbol: (idx + 1) for idx, symbol in enumerate(atom_bonds)}
261+
262+
num_atom_bond_types = len(atom_bond_index)
263+
other_index = len(atom_bond_index) + 1
263264

264265
for bond in bonds:
265266
atom_start_index = bond.GetBeginAtomIdx()
@@ -273,7 +274,7 @@ def get_atompair_ids(
273274
)
274275

275276
bond_type = bond.GetBondType()
276-
bond_id = ATOM_BOND_INDEX.get(bond_type, other_index) + 1
277+
bond_id = atom_bond_index.get(bond_type, other_index) + 1
277278

278279
# default to symmetric bond type (undirected atom bonds)
279280

@@ -761,7 +762,8 @@ class MoleculeInput:
761762
directed_bonds: bool = False
762763
extract_atom_feats_fn: Callable[[Atom], Float["m dai"]] = default_extract_atom_feats_fn # type: ignore
763764
extract_atompair_feats_fn: Callable[[Mol], Float["m m dapi"]] = default_extract_atompair_feats_fn # type: ignore
764-
custom_atoms: List[str]| None = None
765+
custom_atoms: List[str] | None = None
766+
custom_bonds: List[str] | None = None
765767

766768
@typecheck
767769
def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
@@ -891,6 +893,8 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
891893
prev_mol = None
892894
prev_src_tgt_atom_indices = None
893895

896+
atom_bonds = default(i.custom_bonds, ATOM_BONDS)
897+
894898
for (
895899
mol,
896900
mol_id,
@@ -914,7 +918,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
914918
should_cache = is_chainable_biomolecule.item()
915919
)
916920

917-
mol_atompair_ids = maybe_cached_get_atompair_ids(mol, directed_bonds = i.directed_bonds)
921+
mol_atompair_ids = maybe_cached_get_atompair_ids(mol, atom_bonds, directed_bonds = i.directed_bonds)
918922

919923
# /einx.set_at
920924

@@ -1103,7 +1107,8 @@ class MoleculeLengthMoleculeInput:
11031107
directed_bonds: bool = False
11041108
extract_atom_feats_fn: Callable[[Atom], Float["m dai"]] = default_extract_atom_feats_fn # type: ignore
11051109
extract_atompair_feats_fn: Callable[[Mol], Float["m m dapi"]] = default_extract_atompair_feats_fn # type: ignore
1106-
custom_atoms: List[str]| None = None
1110+
custom_atoms: List[str] | None = None
1111+
custom_bonds: List[str] | None = None
11071112

11081113

11091114
@typecheck
@@ -1354,6 +1359,8 @@ def molecule_lengthed_molecule_input_to_atom_input(
13541359
prev_mol = None
13551360
prev_src_tgt_atom_indices = None
13561361

1362+
atom_bonds = default(i.custom_bonds, ATOM_BONDS)
1363+
13571364
for (
13581365
mol,
13591366
mol_id,
@@ -1377,7 +1384,7 @@ def molecule_lengthed_molecule_input_to_atom_input(
13771384
should_cache = is_chainable_biomolecule.item()
13781385
)
13791386

1380-
mol_atompair_ids = maybe_cached_get_atompair_ids(mol, directed_bonds = i.directed_bonds)
1387+
mol_atompair_ids = maybe_cached_get_atompair_ids(mol, atom_bonds, directed_bonds = i.directed_bonds)
13811388

13821389
# mol_atompair_ids = einx.set_at("[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates)
13831390

@@ -1553,6 +1560,7 @@ class Alphafold3Input:
15531560
extract_atom_feats_fn: Callable[[Atom], Float["m dai"]] = default_extract_atom_feats_fn # type: ignore
15541561
extract_atompair_feats_fn: Callable[[Mol], Float["m m dapi"]] = default_extract_atompair_feats_fn # type: ignore
15551562
custom_atoms: List[str] | None = None
1563+
custom_bonds: List[str] | None = None
15561564

15571565
@typecheck
15581566
def map_int_or_string_indices_to_mol(
@@ -1999,7 +2007,8 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
19992007
directed_bonds=i.directed_bonds,
20002008
extract_atom_feats_fn=i.extract_atom_feats_fn,
20012009
extract_atompair_feats_fn=i.extract_atompair_feats_fn,
2002-
custom_atoms=i.custom_atoms
2010+
custom_atoms=i.custom_atoms,
2011+
custom_bonds=i.custom_bonds
20032012
)
20042013

20052014
return molecule_input
@@ -2166,7 +2175,8 @@ class PDBInput:
21662175
add_atom_ids: bool = False
21672176
add_atompair_ids: bool = False
21682177
directed_bonds: bool = False
2169-
custom_atoms: List[str]| None = None
2178+
custom_atoms: List[str] | None = None
2179+
custom_bonds: List[str] | None = None
21702180
training: bool = False
21712181
inference: bool = False
21722182
distillation: bool = False
@@ -3982,7 +3992,8 @@ def pdb_input_to_molecule_input(
39823992
directed_bonds=i.directed_bonds,
39833993
extract_atom_feats_fn=i.extract_atom_feats_fn,
39843994
extract_atompair_feats_fn=i.extract_atompair_feats_fn,
3985-
custom_atoms=i.custom_atoms
3995+
custom_atoms=i.custom_atoms,
3996+
custom_bonds=i.custom_bonds
39863997
)
39873998

39883999
return molecule_input

‎pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.5.42"
3+
version = "0.5.43"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" },

0 commit comments

Comments
 (0)