@@ -241,11 +241,10 @@ def inner(*args, **kwargs):
241
241
242
242
# get atompair bonds functions
243
243
244
- ATOM_BOND_INDEX = {symbol : (idx + 1 ) for idx , symbol in enumerate (ATOM_BONDS )}
245
-
246
244
@typecheck
247
245
def get_atompair_ids (
248
246
mol : Mol ,
247
+ atom_bonds : List [str ],
249
248
directed_bonds : bool
250
249
) -> Int ['m m' ] | None :
251
250
@@ -258,8 +257,10 @@ def get_atompair_ids(
258
257
bonds = mol .GetBonds ()
259
258
num_bonds = len (bonds )
260
259
261
- num_atom_bond_types = len (ATOM_BOND_INDEX )
262
- other_index = len (ATOM_BONDS ) + 1
260
+ atom_bond_index = {symbol : (idx + 1 ) for idx , symbol in enumerate (atom_bonds )}
261
+
262
+ num_atom_bond_types = len (atom_bond_index )
263
+ other_index = len (atom_bond_index ) + 1
263
264
264
265
for bond in bonds :
265
266
atom_start_index = bond .GetBeginAtomIdx ()
@@ -273,7 +274,7 @@ def get_atompair_ids(
273
274
)
274
275
275
276
bond_type = bond .GetBondType ()
276
- bond_id = ATOM_BOND_INDEX .get (bond_type , other_index ) + 1
277
+ bond_id = atom_bond_index .get (bond_type , other_index ) + 1
277
278
278
279
# default to symmetric bond type (undirected atom bonds)
279
280
@@ -761,7 +762,8 @@ class MoleculeInput:
761
762
directed_bonds : bool = False
762
763
extract_atom_feats_fn : Callable [[Atom ], Float ["m dai" ]] = default_extract_atom_feats_fn # type: ignore
763
764
extract_atompair_feats_fn : Callable [[Mol ], Float ["m m dapi" ]] = default_extract_atompair_feats_fn # type: ignore
764
- custom_atoms : List [str ]| None = None
765
+ custom_atoms : List [str ] | None = None
766
+ custom_bonds : List [str ] | None = None
765
767
766
768
@typecheck
767
769
def molecule_to_atom_input (mol_input : MoleculeInput ) -> AtomInput :
@@ -891,6 +893,8 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
891
893
prev_mol = None
892
894
prev_src_tgt_atom_indices = None
893
895
896
+ atom_bonds = default (i .custom_bonds , ATOM_BONDS )
897
+
894
898
for (
895
899
mol ,
896
900
mol_id ,
@@ -914,7 +918,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
914
918
should_cache = is_chainable_biomolecule .item ()
915
919
)
916
920
917
- mol_atompair_ids = maybe_cached_get_atompair_ids (mol , directed_bonds = i .directed_bonds )
921
+ mol_atompair_ids = maybe_cached_get_atompair_ids (mol , atom_bonds , directed_bonds = i .directed_bonds )
918
922
919
923
# /einx.set_at
920
924
@@ -1103,7 +1107,8 @@ class MoleculeLengthMoleculeInput:
1103
1107
directed_bonds : bool = False
1104
1108
extract_atom_feats_fn : Callable [[Atom ], Float ["m dai" ]] = default_extract_atom_feats_fn # type: ignore
1105
1109
extract_atompair_feats_fn : Callable [[Mol ], Float ["m m dapi" ]] = default_extract_atompair_feats_fn # type: ignore
1106
- custom_atoms : List [str ]| None = None
1110
+ custom_atoms : List [str ] | None = None
1111
+ custom_bonds : List [str ] | None = None
1107
1112
1108
1113
1109
1114
@typecheck
@@ -1354,6 +1359,8 @@ def molecule_lengthed_molecule_input_to_atom_input(
1354
1359
prev_mol = None
1355
1360
prev_src_tgt_atom_indices = None
1356
1361
1362
+ atom_bonds = default (i .custom_bonds , ATOM_BONDS )
1363
+
1357
1364
for (
1358
1365
mol ,
1359
1366
mol_id ,
@@ -1377,7 +1384,7 @@ def molecule_lengthed_molecule_input_to_atom_input(
1377
1384
should_cache = is_chainable_biomolecule .item ()
1378
1385
)
1379
1386
1380
- mol_atompair_ids = maybe_cached_get_atompair_ids (mol , directed_bonds = i .directed_bonds )
1387
+ mol_atompair_ids = maybe_cached_get_atompair_ids (mol , atom_bonds , directed_bonds = i .directed_bonds )
1381
1388
1382
1389
# mol_atompair_ids = einx.set_at("[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates)
1383
1390
@@ -1553,6 +1560,7 @@ class Alphafold3Input:
1553
1560
extract_atom_feats_fn : Callable [[Atom ], Float ["m dai" ]] = default_extract_atom_feats_fn # type: ignore
1554
1561
extract_atompair_feats_fn : Callable [[Mol ], Float ["m m dapi" ]] = default_extract_atompair_feats_fn # type: ignore
1555
1562
custom_atoms : List [str ] | None = None
1563
+ custom_bonds : List [str ] | None = None
1556
1564
1557
1565
@typecheck
1558
1566
def map_int_or_string_indices_to_mol (
@@ -1999,7 +2007,8 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
1999
2007
directed_bonds = i .directed_bonds ,
2000
2008
extract_atom_feats_fn = i .extract_atom_feats_fn ,
2001
2009
extract_atompair_feats_fn = i .extract_atompair_feats_fn ,
2002
- custom_atoms = i .custom_atoms
2010
+ custom_atoms = i .custom_atoms ,
2011
+ custom_bonds = i .custom_bonds
2003
2012
)
2004
2013
2005
2014
return molecule_input
@@ -2166,7 +2175,8 @@ class PDBInput:
2166
2175
add_atom_ids : bool = False
2167
2176
add_atompair_ids : bool = False
2168
2177
directed_bonds : bool = False
2169
- custom_atoms : List [str ]| None = None
2178
+ custom_atoms : List [str ] | None = None
2179
+ custom_bonds : List [str ] | None = None
2170
2180
training : bool = False
2171
2181
inference : bool = False
2172
2182
distillation : bool = False
@@ -3982,7 +3992,8 @@ def pdb_input_to_molecule_input(
3982
3992
directed_bonds = i .directed_bonds ,
3983
3993
extract_atom_feats_fn = i .extract_atom_feats_fn ,
3984
3994
extract_atompair_feats_fn = i .extract_atompair_feats_fn ,
3985
- custom_atoms = i .custom_atoms
3995
+ custom_atoms = i .custom_atoms ,
3996
+ custom_bonds = i .custom_bonds
3986
3997
)
3987
3998
3988
3999
return molecule_input
0 commit comments