Skip to content

Add get_associations method to topology class (Closes #375) #378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 80 additions & 6 deletions gmso/core/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,27 @@ def __init__(self, name="Topology", box=None):
self._impropers = IndexedSet()
self._subtops = IndexedSet()
self._atom_types = {}
self._atom_types_associations = {}
self._atom_types_idx = {}

self._connection_types = {}

self._bond_types = {}
self._bond_types_associations = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really small nit but types is plural in _atom_types_associations and _bond_types_associations but is singular in _angle_type_associations. I think I would actually prefer these all be singular type (_atom_type_associations, _bond_type_associations, etc.) because it reads a little bit better to me, but it's worth getting opinions from the rest of @mosdef-hub/mosdef-contributors.

Either way the naming should be consistent.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for naming it with _****_types_associations was because it is a dictionary of all the connection types in the topology, which will contain a set. I have no preference one way or other.

Copy link
Collaborator

@rmatsum836 rmatsum836 Apr 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, I'm fine with _****_types_associations then.

self._bond_types_idx = {}

self._angle_types = {}
self._angle_type_associations = {}
self._angle_types_idx = {}

self._dihedral_types = {}
self._dihedral_types_associations = {}
self._dihedral_types_idx = {}

self._improper_types = {}
self._improper_types_associations = {}
self._improper_types_idx = {}

self._combining_rule = 'lorentz'
self._set_refs = {
ATOM_TYPE_DICT: self._atom_types,
Expand All @@ -156,6 +167,13 @@ def __init__(self, name="Topology", box=None):
DIHEDRAL_TYPE_DICT: self._dihedral_types,
IMPROPER_TYPE_DICT: self._improper_types,
}
self._association_refs = {
ATOM_TYPE_DICT: self._atom_types_associations,
BOND_TYPE_DICT: self._bond_types_associations,
ANGLE_TYPE_DICT: self._angle_type_associations,
DIHEDRAL_TYPE_DICT: self._dihedral_types_associations,
IMPROPER_TYPE_DICT: self._improper_types_associations
}

self._index_refs = {
ATOM_TYPE_DICT: self._atom_types_idx,
Expand Down Expand Up @@ -331,11 +349,13 @@ def add_site(self, site, update_types=True):
self._sites.add(site)
if update_types and site.atom_type:
site.atom_type.topology = self
if site.atom_type in self._atom_types:
site.atom_type = self._atom_types[site.atom_type]
else:
self._atom_types[site.atom_type] = site.atom_type
self._atom_types_idx[site.atom_type] = len(self._atom_types) - 1
site.atom_type = self._atom_types.get(site.atom_type, site.atom_type)
conns = self._atom_types_associations.get(site.atom_type, set())
self._atom_types[site.atom_type] = site.atom_type
self._atom_types_idx[site.atom_type] = len(self._atom_types) - 1
conns.add(site)
self._atom_types_associations[site.atom_type] = conns

self.is_typed(updated=False)

def update_sites(self):
Expand Down Expand Up @@ -501,25 +521,34 @@ def update_connection_types(self):
self._connection_types[c.connection_type] = c.connection_type
if isinstance(c.connection_type, BondType):
self._bond_types[c.connection_type] = c.connection_type
self._bond_types_associations[c.connection_type] = {c}
self._bond_types_idx[c.connection_type] = len(self._bond_types) - 1
if isinstance(c.connection_type, AngleType):
self._angle_types[c.connection_type] = c.connection_type
self._angle_type_associations[c.connection_type] = {c}
self._angle_types_idx[c.connection_type] = len(self._bond_types) - 1
if isinstance(c.connection_type, DihedralType):
self._dihedral_types[c.connection_type] = c.connection_type
self._dihedral_types_associations[c.connection_type] = {c}
self._dihedral_types_idx[c.connection_type] = len(self._bond_types) - 1
if isinstance(c.connection_type, ImproperType):
self._improper_types[c.connection_type] = c.connection_type
self._improper_types_associations[c.connection_type] = {c}
self._improper_types_idx[c.connection_type] = len(self._bond_types) - 1
elif c.connection_type in self.connection_types:

elif c.connection_type in self._connection_types:
if isinstance(c.connection_type, BondType):
c.connection_type = self._bond_types[c.connection_type]
self._bond_types_associations[c.connection_type].add(c)
if isinstance(c.connection_type, AngleType):
c.connection_type = self._angle_types[c.connection_type]
self._angle_type_associations[c.connection_type].add(c)
if isinstance(c.connection_type, DihedralType):
c.connection_type = self._dihedral_types[c.connection_type]
self._dihedral_types_associations[c.connection_type].add(c)
if isinstance(c.connection_type, ImproperType):
c.connection_type = self._improper_types[c.connection_type]
self._improper_types_associations[c.connection_type].add(c)

def update_atom_types(self):
"""Update atom types in the topology
Expand All @@ -541,9 +570,12 @@ def update_atom_types(self):
elif site.atom_type not in self._atom_types:
site.atom_type.topology = self
self._atom_types[site.atom_type] = site.atom_type
self._atom_types_associations[site.atom_type] = {site}
self._atom_types_idx[site.atom_type] = len(self._atom_types) - 1

elif site.atom_type in self._atom_types:
site.atom_type = self._atom_types[site.atom_type]
self._atom_types_associations[site.atom_type].add(site)
self.is_typed(updated=True)

def add_subtopology(self, subtop):
Expand Down Expand Up @@ -633,6 +665,48 @@ def update_topology(self):
self.update_connection_types()
self.is_typed(updated=True)

def get_associations(self, conn_or_atom_type):
""" Return objects associated with a connection or atom type in the topology

This method takes `conn_or_atom_type` and returns the number of
sites(if `conn_or_atom_type` is of type GMSO.AtomType) or connections
(i.e. bonds, angles, dihedrals or impropers (if `conn_or_atom_type`
is one of gmso.BondType, gmso.AngleType,
gmso.DihedralType gmso.ImproperType respectively)) associated with it.

Parameters
----------
conn_or_atom_type : gmso.AtomType or gmso.BondType or gmso.AngleType or gmso.DihedralType or gmso.ImproperType
The connection_type for which to return the association for

Returns
-------
set
A set of sites or connections associated with `conn_or_atom_type`

"""
if conn_or_atom_type not in self._get_ref(conn_or_atom_type, type_='items'):
raise GMSOError(f'{conn_or_atom_type} is not associated with any items in the topology')
association_dict = self._get_ref(conn_or_atom_type, type_='associations')
return set(association_dict[conn_or_atom_type])

def _get_ref(self, conn_or_atom_type, type_='items'):
"""Get book keeping reference dictionary for the object"""
assert type_ in ('items', 'associations')
ref_dict = {
AtomType: ATOM_TYPE_DICT,
BondType: BOND_TYPE_DICT,
AngleType: ATOM_TYPE_DICT,
DihedralType: DIHEDRAL_TYPE_DICT,
ImproperType: IMPROPER_TYPE_DICT
}
if type_ == 'items':
_container_dict = self._set_refs
elif type_ == 'associations':
_container_dict = self._association_refs

return _container_dict[ref_dict[type(conn_or_atom_type)]]

def get_index(self, member):
"""Get index of a member in the topology

Expand Down
21 changes: 21 additions & 0 deletions gmso/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,26 @@ def test_topology_atom_type_changes(self):
assert top.sites[10].atom_type.name == 'atom_type_changed'
assert top.is_typed()

def test_get_associations_atom_types(self):
top = Topology()
atom_type = AtomType(name='test_atomtype')
for i in range(10):
top.add_site(Site(name=f'site{i+1}', atom_type=atom_type), update_types=True)
assert len(top.get_associations(atom_type)) == 10

def test_get_association_atom_types_after_changes(self, typed_ar_system):
typed_ar_system.atom_types[0].name = 'Typed Ar System'
assert typed_ar_system.get_associations(typed_ar_system.atom_types[0]) == set(typed_ar_system.sites)

def test_get_association_bond_types(self, typed_water_system):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if this fits here, but what if we had a system with bonds (water would be an example: H-O-H)

In this case, bondtype should have a set that contains all 3 sites. If we then removed a bond between one of the H and O (H O-H)

We expect that the len(get_associations(broken_water) is now 2.

Would be a good test to have. and this same example should work for angletype too.

bond_type = typed_water_system.bond_types[0]
assert len(typed_water_system.get_associations(bond_type)) == 4

def test_get_association_bond_types_after_changes(self, typed_water_system):
bond_type = typed_water_system.bond_types[0]
bond_type.name = 'Typed water system'
assert type(typed_water_system.get_associations(bond_type)) == set

def test_topology_get_index(self):
top = Topology()
conn_members = [Site(), Site(), Site(), Site()]
Expand Down Expand Up @@ -487,3 +507,4 @@ def test_topology_get_index_angle_type_after_change(self, typed_methylnitroanili
prev_idx = typed_methylnitroaniline.get_index(angle_type_to_test)
typed_methylnitroaniline.angles[0].connection_type.name = 'changed name'
assert typed_methylnitroaniline.get_index(angle_type_to_test) != prev_idx

2 changes: 2 additions & 0 deletions gmso/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ def confirm_dict_existence(setter_function):
def setter_with_dict_removal(self, *args, **kwargs):
if self._topology:
self._topology._set_refs[self._set_ref].pop(self, None)
prev_associations = self._topology._association_refs[self._set_ref].pop(self, set())
setter_function(self, *args, **kwargs)
self._topology._set_refs[self._set_ref][self] = (self)
self._topology._association_refs[self._set_ref][self] = prev_associations
self._topology._reindex_connection_types(self._set_ref)
else:
setter_function(self, *args, **kwargs)
Expand Down