diff --git a/gmso/core/topology.py b/gmso/core/topology.py index 582b584a9..8402c106e 100644 --- a/gmso/core/topology.py +++ b/gmso/core/topology.py @@ -138,16 +138,27 @@ def __init__(self, name="Topology", box=None): self._impropers = IndexedSet() self._subtops = IndexedSet() self._atom_types = {} + self._atom_types_associations = {} self._atom_types_idx = {} + self._connection_types = {} + self._bond_types = {} + self._bond_types_associations = {} self._bond_types_idx = {} + self._angle_types = {} + self._angle_type_associations = {} self._angle_types_idx = {} + self._dihedral_types = {} + self._dihedral_types_associations = {} self._dihedral_types_idx = {} + self._improper_types = {} + self._improper_types_associations = {} self._improper_types_idx = {} + self._combining_rule = 'lorentz' self._set_refs = { ATOM_TYPE_DICT: self._atom_types, @@ -156,6 +167,13 @@ def __init__(self, name="Topology", box=None): DIHEDRAL_TYPE_DICT: self._dihedral_types, IMPROPER_TYPE_DICT: self._improper_types, } + self._association_refs = { + ATOM_TYPE_DICT: self._atom_types_associations, + BOND_TYPE_DICT: self._bond_types_associations, + ANGLE_TYPE_DICT: self._angle_type_associations, + DIHEDRAL_TYPE_DICT: self._dihedral_types_associations, + IMPROPER_TYPE_DICT: self._improper_types_associations + } self._index_refs = { ATOM_TYPE_DICT: self._atom_types_idx, @@ -331,11 +349,13 @@ def add_site(self, site, update_types=True): self._sites.add(site) if update_types and site.atom_type: site.atom_type.topology = self - if site.atom_type in self._atom_types: - site.atom_type = self._atom_types[site.atom_type] - else: - self._atom_types[site.atom_type] = site.atom_type - self._atom_types_idx[site.atom_type] = len(self._atom_types) - 1 + site.atom_type = self._atom_types.get(site.atom_type, site.atom_type) + conns = self._atom_types_associations.get(site.atom_type, set()) + self._atom_types[site.atom_type] = site.atom_type + self._atom_types_idx[site.atom_type] = len(self._atom_types) - 1 + conns.add(site) + self._atom_types_associations[site.atom_type] = conns + self.is_typed(updated=False) def update_sites(self): @@ -501,25 +521,34 @@ def update_connection_types(self): self._connection_types[c.connection_type] = c.connection_type if isinstance(c.connection_type, BondType): self._bond_types[c.connection_type] = c.connection_type + self._bond_types_associations[c.connection_type] = {c} self._bond_types_idx[c.connection_type] = len(self._bond_types) - 1 if isinstance(c.connection_type, AngleType): self._angle_types[c.connection_type] = c.connection_type + self._angle_type_associations[c.connection_type] = {c} self._angle_types_idx[c.connection_type] = len(self._bond_types) - 1 if isinstance(c.connection_type, DihedralType): self._dihedral_types[c.connection_type] = c.connection_type + self._dihedral_types_associations[c.connection_type] = {c} self._dihedral_types_idx[c.connection_type] = len(self._bond_types) - 1 if isinstance(c.connection_type, ImproperType): self._improper_types[c.connection_type] = c.connection_type + self._improper_types_associations[c.connection_type] = {c} self._improper_types_idx[c.connection_type] = len(self._bond_types) - 1 - elif c.connection_type in self.connection_types: + + elif c.connection_type in self._connection_types: if isinstance(c.connection_type, BondType): c.connection_type = self._bond_types[c.connection_type] + self._bond_types_associations[c.connection_type].add(c) if isinstance(c.connection_type, AngleType): c.connection_type = self._angle_types[c.connection_type] + self._angle_type_associations[c.connection_type].add(c) if isinstance(c.connection_type, DihedralType): c.connection_type = self._dihedral_types[c.connection_type] + self._dihedral_types_associations[c.connection_type].add(c) if isinstance(c.connection_type, ImproperType): c.connection_type = self._improper_types[c.connection_type] + self._improper_types_associations[c.connection_type].add(c) def update_atom_types(self): """Update atom types in the topology @@ -541,9 +570,12 @@ def update_atom_types(self): elif site.atom_type not in self._atom_types: site.atom_type.topology = self self._atom_types[site.atom_type] = site.atom_type + self._atom_types_associations[site.atom_type] = {site} self._atom_types_idx[site.atom_type] = len(self._atom_types) - 1 + elif site.atom_type in self._atom_types: site.atom_type = self._atom_types[site.atom_type] + self._atom_types_associations[site.atom_type].add(site) self.is_typed(updated=True) def add_subtopology(self, subtop): @@ -633,6 +665,48 @@ def update_topology(self): self.update_connection_types() self.is_typed(updated=True) + def get_associations(self, conn_or_atom_type): + """ Return objects associated with a connection or atom type in the topology + + This method takes `conn_or_atom_type` and returns the number of + sites(if `conn_or_atom_type` is of type GMSO.AtomType) or connections + (i.e. bonds, angles, dihedrals or impropers (if `conn_or_atom_type` + is one of gmso.BondType, gmso.AngleType, + gmso.DihedralType gmso.ImproperType respectively)) associated with it. + + Parameters + ---------- + conn_or_atom_type : gmso.AtomType or gmso.BondType or gmso.AngleType or gmso.DihedralType or gmso.ImproperType + The connection_type for which to return the association for + + Returns + ------- + set + A set of sites or connections associated with `conn_or_atom_type` + + """ + if conn_or_atom_type not in self._get_ref(conn_or_atom_type, type_='items'): + raise GMSOError(f'{conn_or_atom_type} is not associated with any items in the topology') + association_dict = self._get_ref(conn_or_atom_type, type_='associations') + return set(association_dict[conn_or_atom_type]) + + def _get_ref(self, conn_or_atom_type, type_='items'): + """Get book keeping reference dictionary for the object""" + assert type_ in ('items', 'associations') + ref_dict = { + AtomType: ATOM_TYPE_DICT, + BondType: BOND_TYPE_DICT, + AngleType: ATOM_TYPE_DICT, + DihedralType: DIHEDRAL_TYPE_DICT, + ImproperType: IMPROPER_TYPE_DICT + } + if type_ == 'items': + _container_dict = self._set_refs + elif type_ == 'associations': + _container_dict = self._association_refs + + return _container_dict[ref_dict[type(conn_or_atom_type)]] + def get_index(self, member): """Get index of a member in the topology diff --git a/gmso/tests/test_topology.py b/gmso/tests/test_topology.py index 374a07aaa..da2cc7f0a 100644 --- a/gmso/tests/test_topology.py +++ b/gmso/tests/test_topology.py @@ -419,6 +419,26 @@ def test_topology_atom_type_changes(self): assert top.sites[10].atom_type.name == 'atom_type_changed' assert top.is_typed() + def test_get_associations_atom_types(self): + top = Topology() + atom_type = AtomType(name='test_atomtype') + for i in range(10): + top.add_site(Site(name=f'site{i+1}', atom_type=atom_type), update_types=True) + assert len(top.get_associations(atom_type)) == 10 + + def test_get_association_atom_types_after_changes(self, typed_ar_system): + typed_ar_system.atom_types[0].name = 'Typed Ar System' + assert typed_ar_system.get_associations(typed_ar_system.atom_types[0]) == set(typed_ar_system.sites) + + def test_get_association_bond_types(self, typed_water_system): + bond_type = typed_water_system.bond_types[0] + assert len(typed_water_system.get_associations(bond_type)) == 4 + + def test_get_association_bond_types_after_changes(self, typed_water_system): + bond_type = typed_water_system.bond_types[0] + bond_type.name = 'Typed water system' + assert type(typed_water_system.get_associations(bond_type)) == set + def test_topology_get_index(self): top = Topology() conn_members = [Site(), Site(), Site(), Site()] @@ -487,3 +507,4 @@ def test_topology_get_index_angle_type_after_change(self, typed_methylnitroanili prev_idx = typed_methylnitroaniline.get_index(angle_type_to_test) typed_methylnitroaniline.angles[0].connection_type.name = 'changed name' assert typed_methylnitroaniline.get_index(angle_type_to_test) != prev_idx + diff --git a/gmso/utils/decorators.py b/gmso/utils/decorators.py index 5265de203..782517b21 100644 --- a/gmso/utils/decorators.py +++ b/gmso/utils/decorators.py @@ -10,8 +10,10 @@ def confirm_dict_existence(setter_function): def setter_with_dict_removal(self, *args, **kwargs): if self._topology: self._topology._set_refs[self._set_ref].pop(self, None) + prev_associations = self._topology._association_refs[self._set_ref].pop(self, set()) setter_function(self, *args, **kwargs) self._topology._set_refs[self._set_ref][self] = (self) + self._topology._association_refs[self._set_ref][self] = prev_associations self._topology._reindex_connection_types(self._set_ref) else: setter_function(self, *args, **kwargs)