From b7f6cc665dbb22c7b24b4f94cba6bdb8db7606a8 Mon Sep 17 00:00:00 2001 From: Umesh Timalsina Date: Fri, 27 Mar 2020 13:28:13 -0500 Subject: [PATCH 1/2] Add get_associations method to topology class adds get_association method to return a set of sites associated with an atomtype or a set of bonds associated with a bond_type and so on. --- gmso/core/topology.py | 71 +++++++++++++++++++++++++++++++++++-- gmso/tests/test_topology.py | 19 ++++++++++ gmso/utils/decorators.py | 2 ++ 3 files changed, 90 insertions(+), 2 deletions(-) diff --git a/gmso/core/topology.py b/gmso/core/topology.py index 7dbf41d6d..f019a1239 100644 --- a/gmso/core/topology.py +++ b/gmso/core/topology.py @@ -137,11 +137,16 @@ def __init__(self, name="Topology", box=None): self._impropers = IndexedSet() self._subtops = IndexedSet() self._atom_types = {} + self._atom_types_associations = {} self._connection_types = {} self._bond_types = {} + self._bond_types_associations = {} self._angle_types = {} + self._angle_type_associations = {} self._dihedral_types = {} + self._dihedral_types_associations = {} self._improper_types = {} + self._improper_types_associations = {} self._combining_rule = 'lorentz' self._set_refs = { ATOM_TYPE_DICT: self._atom_types, @@ -150,6 +155,13 @@ def __init__(self, name="Topology", box=None): DIHEDRAL_TYPE_DICT: self._dihedral_types, IMPROPER_TYPE_DICT: self._improper_types, } + self._association_refs = { + ATOM_TYPE_DICT: self._atom_types_associations, + BOND_TYPE_DICT: self._bond_types_associations, + ANGLE_TYPE_DICT: self._angle_type_associations, + DIHEDRAL_TYPE_DICT: self._dihedral_types_associations, + IMPROPER_TYPE_DICT: self._improper_types_associations + } @property def name(self): @@ -317,7 +329,10 @@ def add_site(self, site, update_types=True): if update_types and site.atom_type: site.atom_type.topology = self site.atom_type = self._atom_types.get(site.atom_type, site.atom_type) + conns = self._atom_types_associations.get(site.atom_type, set()) self._atom_types[site.atom_type] = site.atom_type + conns.add(site) + self._atom_types_associations[site.atom_type] = conns self.is_typed(updated=False) def update_sites(self): @@ -483,21 +498,29 @@ def update_connection_types(self): self._connection_types[c.connection_type] = c.connection_type if isinstance(c.connection_type, BondType): self._bond_types[c.connection_type] = c.connection_type + self._bond_types_associations[c.connection_type] = {c} if isinstance(c.connection_type, AngleType): self._angle_types[c.connection_type] = c.connection_type + self._angle_type_associations[c.connection_type] = {c} if isinstance(c.connection_type, DihedralType): self._dihedral_types[c.connection_type] = c.connection_type + self._dihedral_types_associations[c.connection_type] = {c} if isinstance(c.connection_type, ImproperType): self._improper_types[c.connection_type] = c.connection_type - elif c.connection_type in self.connection_types: + self._improper_types_associations[c.connection_type] = {c} + elif c.connection_type in self._connection_types: if isinstance(c.connection_type, BondType): c.connection_type = self._bond_types[c.connection_type] + self._bond_types_associations[c.connection_type].add(c) if isinstance(c.connection_type, AngleType): c.connection_type = self._angle_types[c.connection_type] + self._angle_type_associations[c.connection_type].add(c) if isinstance(c.connection_type, DihedralType): c.connection_type = self._dihedral_types[c.connection_type] + self._dihedral_types_associations[c.connection_type].add(c) if isinstance(c.connection_type, ImproperType): c.connection_type = self._improper_types[c.connection_type] + self._improper_types_associations[c.connection_type].add(c) def update_atom_types(self): """Update atom types in the topology @@ -519,8 +542,10 @@ def update_atom_types(self): elif site.atom_type not in self._atom_types: site.atom_type.topology = self self._atom_types[site.atom_type] = site.atom_type + self._atom_types_associations[site.atom_type] = {site} elif site.atom_type in self._atom_types: site.atom_type = self._atom_types[site.atom_type] + self._atom_types_associations[site.atom_type].add(site) self.is_typed(updated=True) def add_subtopology(self, subtop): @@ -543,7 +568,7 @@ def add_subtopology(self, subtop): subtop.parent = self self._sites.union(subtop.sites) - def is_typed(self, updated=False): + def is_typed(self, updated=False, ): if not updated: self.update_connection_types() self.update_atom_types() @@ -610,6 +635,48 @@ def update_topology(self): self.update_connection_types() self.is_typed(updated=True) + def get_associations(self, conn_or_atom_type): + """ Return objects associated with a connection or atom type in the topology + + This method takes `conn_or_atom_type` and returns the number of + sites(if `conn_or_atom_type` is of type GMSO.AtomType) or connections + (i.e. bonds, angles, dihedrals or impropers (if `conn_or_atom_type` + is one of gmso.BondType, gmso.AngleType, + gmso.DihedralType gmso.ImproperType respectively)) associated with it. + + Parameters + ---------- + conn_or_atom_type : gmso.AtomType or gmso.BondType or gmso.AngleType or gmso.DihedralType or gmso.ImproperType + The connection_type for which to return the association for + + Returns + ------- + set + A set of sites or connections associated with `conn_or_atom_type` + + """ + if conn_or_atom_type not in self._get_ref(conn_or_atom_type, type_='items'): + raise GMSOError(f'{conn_or_atom_type} is not associated with any items in the topology') + association_dict = self._get_ref(conn_or_atom_type, type_='associations') + return set(association_dict[conn_or_atom_type]) + + def _get_ref(self, conn_or_atom_type, type_='items'): + """Get book keeping reference dictionary for the object""" + assert type_ in ('items', 'associations') + ref_dict = { + AtomType: ATOM_TYPE_DICT, + BondType: BOND_TYPE_DICT, + AngleType: ATOM_TYPE_DICT, + DihedralType: DIHEDRAL_TYPE_DICT, + ImproperType: IMPROPER_TYPE_DICT + } + if type_ == 'items': + _container_dict = self._set_refs + elif type_ == 'associations': + _container_dict = self._association_refs + + return _container_dict[ref_dict[type(conn_or_atom_type)]] + def __repr__(self): descr = list('<') descr.append(self.name + ' ') diff --git a/gmso/tests/test_topology.py b/gmso/tests/test_topology.py index 232dfd619..c438666b0 100644 --- a/gmso/tests/test_topology.py +++ b/gmso/tests/test_topology.py @@ -418,4 +418,23 @@ def test_topology_atom_type_changes(self): assert top.sites[10].atom_type.name == 'atom_type_changed' assert top.is_typed() + def test_get_associations_atom_types(self): + top = Topology() + atom_type = AtomType(name='test_atomtype') + for i in range(10): + top.add_site(Site(name=f'site{i+1}', atom_type=atom_type), update_types=True) + assert len(top.get_associations(atom_type)) == 10 + + def test_get_association_atom_types_after_changes(self, typed_ar_system): + typed_ar_system.atom_types[0].name = 'Typed Ar System' + assert typed_ar_system.get_associations(typed_ar_system.atom_types[0]) == set(typed_ar_system.sites) + + def test_get_association_bond_types(self, typed_water_system): + bond_type = typed_water_system.bond_types[0] + assert len(typed_water_system.get_associations(bond_type)) == 4 + + def test_get_association_bond_types_after_changes(self, typed_water_system): + bond_type = typed_water_system.bond_types[0] + bond_type.name = 'Typed water system' + assert type(typed_water_system.get_associations(bond_type)) == set diff --git a/gmso/utils/decorators.py b/gmso/utils/decorators.py index aff3219d0..24ced9d45 100644 --- a/gmso/utils/decorators.py +++ b/gmso/utils/decorators.py @@ -10,8 +10,10 @@ def confirm_dict_existence(setter_function): def setter_with_dict_removal(self, *args, **kwargs): if self._topology: self._topology._set_refs[self._set_ref].pop(self, None) + prev_associations = self._topology._association_refs[self._set_ref].pop(self, set()) setter_function(self, *args, **kwargs) self._topology._set_refs[self._set_ref][self] = (self) + self._topology._association_refs[self._set_ref][self] = prev_associations else: setter_function(self, *args, **kwargs) return setter_with_dict_removal From 8ea2d2da0ee7d148c131691a0d0e934c30e4ed7d Mon Sep 17 00:00:00 2001 From: Umesh Timalsina Date: Wed, 15 Apr 2020 17:08:05 -0500 Subject: [PATCH 2/2] Address PR comments and fix previous merge --- gmso/core/topology.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/gmso/core/topology.py b/gmso/core/topology.py index 0c5929f68..8402c106e 100644 --- a/gmso/core/topology.py +++ b/gmso/core/topology.py @@ -538,15 +538,16 @@ def update_connection_types(self): elif c.connection_type in self._connection_types: if isinstance(c.connection_type, BondType): + c.connection_type = self._bond_types[c.connection_type] self._bond_types_associations[c.connection_type].add(c) if isinstance(c.connection_type, AngleType): - self._angle_types[c.connection_type] = c.connection_type + c.connection_type = self._angle_types[c.connection_type] self._angle_type_associations[c.connection_type].add(c) if isinstance(c.connection_type, DihedralType): - self._dihedral_types[c.connection_type] = c.connection_type + c.connection_type = self._dihedral_types[c.connection_type] self._dihedral_types_associations[c.connection_type].add(c) if isinstance(c.connection_type, ImproperType): - self._improper_types[c.connection_type] = c.connection_type + c.connection_type = self._improper_types[c.connection_type] self._improper_types_associations[c.connection_type].add(c) def update_atom_types(self): @@ -597,7 +598,7 @@ def add_subtopology(self, subtop): subtop.parent = self self._sites.union(subtop.sites) - def is_typed(self, updated=False, ): + def is_typed(self, updated=False): if not updated: self.update_connection_types() self.update_atom_types()