Skip to content

Commit eb16192

Browse files
sort and format
1 parent 841d8ce commit eb16192

File tree

19 files changed

+1361
-861
lines changed

19 files changed

+1361
-861
lines changed

src/vivarium_helpers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77
__title__,
88
__uri__,
99
)
10-
from vivarium_helpers._version import __version__
10+
from vivarium_helpers._version import __version__

src/vivarium_helpers/hobbs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
def goSledding():
2-
print("wheeee")
2+
print("wheeee")

src/vivarium_helpers/id_helper.py

Lines changed: 115 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -8,82 +8,97 @@
88
# The following list of valid entities was retrieved on 2020-10-12 from the hosted documentation:
99
# https://scicomp-docs.ihme.washington.edu/db_queries/current/get_ids.html
1010
_entities = [
11-
'age_group',
12-
'age_group_set',
13-
'cause',
14-
'cause_set',
15-
'cause_set_version',
16-
'covariate',
17-
'decomp_step',
18-
'gbd_round',
19-
'healthstate',
20-
'indicator_component',
21-
'life_table_parameter',
22-
'location',
23-
'location_set',
24-
'location_set_version',
25-
'measure',
26-
'metric',
27-
'modelable_entity',
28-
'sdg_indicator',
29-
'sequela',
30-
'sequela_set',
31-
'sequela_set_version',
32-
'sex',
33-
'split',
34-
'study_covariate',
35-
'rei',
36-
'rei_set',
37-
'rei_set_version',
38-
'year'
11+
"age_group",
12+
"age_group_set",
13+
"cause",
14+
"cause_set",
15+
"cause_set_version",
16+
"covariate",
17+
"decomp_step",
18+
"gbd_round",
19+
"healthstate",
20+
"indicator_component",
21+
"life_table_parameter",
22+
"location",
23+
"location_set",
24+
"location_set_version",
25+
"measure",
26+
"metric",
27+
"modelable_entity",
28+
"sdg_indicator",
29+
"sequela",
30+
"sequela_set",
31+
"sequela_set_version",
32+
"sex",
33+
"split",
34+
"study_covariate",
35+
"rei",
36+
"rei_set",
37+
"rei_set_version",
38+
"year",
3939
]
4040

41+
4142
def get_entities(source=None):
4243
"""Returns a list the entities that are valid arguments to `get_ids()`. Entities are represented as strings.
43-
44+
4445
If source is None (default):
4546
Returns the 28 entities listed as valid arguments to `get_ids()` in the online documentation on 2020-10-12.
4647
https://scicomp-docs.ihme.washington.edu/db_queries/current/get_ids.html
47-
48+
4849
If source is 'docstring':
4950
Returns the entities listed as valid arguments in the docstring of `get_ids()`.
5051
As of 2020-10-12, there were only 22 entities listed in the docstring, whereas 28 entities were listed
5152
in the online documentation; those are accessible via get_entities() with the default source=None.
52-
53+
5354
Currently no other entity sources are supported (a ValueError will be raised if another value is passed).
5455
"""
5556
if source is None:
5657
entities = _entities
57-
elif source == 'docstring':
58+
elif source == "docstring":
5859
docstring = _get_ids.__doc__
5960
# This simplistic solution works with the current version, but it may need to be updated
6061
# to a more robust solution if the docstring changes...
61-
entities = docstring[docstring.find('[')+1:docstring.find(']')].split()
62+
entities = docstring[docstring.find("[") + 1 : docstring.find("]")].split()
6263
else:
6364
raise ValueError(f"Unknonwn source of valid entities for `get_ids`: {source}")
6465
return entities
6566

67+
6668
def find_anomalous_name_columns(entities):
6769
"""Lists columns of entity tables that do not conatin a column called f'{entity}_name'."""
6870
# Use temporary dict to avoid calling _get_ids() twice in dictionary comprehension (or use a walrus := instead!)
6971
entities_columns = {entity: _get_ids(entity).columns for entity in entities}
70-
return {entity: columns for entity, columns in entities_columns.items() if f'{entity}_name' not in columns}
72+
return {
73+
entity: columns
74+
for entity, columns in entities_columns.items()
75+
if f"{entity}_name" not in columns
76+
}
7177
# Better solution using walrus operator, but requires Python version 3.8 (https://www.python.org/dev/peps/pep-0572/):
78+
79+
7280
# return {entity: columns for entity in entities if f'{entity}_name' not in (columns:=_get_ids(entity).columns)}
7381

82+
7483
def get_name_column(entity):
7584
"""Returns the name column for the entity in the entity id table."""
76-
if entity=='year':
77-
return 'year_id' # Year table has only one column, 'year_id'
78-
elif entity=='life_table_parameter':
79-
return 'parameter_name'
85+
if entity == "year":
86+
return "year_id" # Year table has only one column, 'year_id'
87+
elif entity == "life_table_parameter":
88+
return "parameter_name"
8089
elif entity in [
81-
'cause_set_version', 'gbd_round', 'location_set_version',
82-
'sequela_set_version', 'sex', 'study_covariate', 'rei_set_version'
83-
]:
90+
"cause_set_version",
91+
"gbd_round",
92+
"location_set_version",
93+
"sequela_set_version",
94+
"sex",
95+
"study_covariate",
96+
"rei_set_version",
97+
]:
8498
return entity
8599
else:
86-
return f'{entity}_name'
100+
return f"{entity}_name"
101+
87102

88103
# Should this have a parameter to optionally ignore NaN's? See comment for ids_to_names below.
89104
# Other possible parameters:
@@ -97,13 +112,14 @@ def names_to_ids(entity, *entity_names):
97112
"""Returns a pandas Series mapping entity names to entity id's for the specified GBD entity."""
98113
ids = _get_ids(entity)
99114
entity_name_col = get_name_column(entity)
100-
if len(entity_names)>0:
101-
ids = ids.query(f'{entity_name_col} in {entity_names}')
115+
if len(entity_names) > 0:
116+
ids = ids.query(f"{entity_name_col} in {entity_names}")
102117
# Year table only has one column, so we copy it
103-
if entity=='year':
104-
entity_name_col = 'year'
105-
ids[entity_name_col] = ids['year_id']
106-
return ids.set_index(entity_name_col)[f'{entity}_id']
118+
if entity == "year":
119+
entity_name_col = "year"
120+
ids[entity_name_col] = ids["year_id"]
121+
return ids.set_index(entity_name_col)[f"{entity}_id"]
122+
107123

108124
# Should this have a parameter to optionally ignore NaN's?
109125
# I got an error when I tried to pass entity_ids directly from a DataFrame that contained NaN's.
@@ -113,28 +129,30 @@ def names_to_ids(entity, *entity_names):
113129
def ids_to_names(entity, *entity_ids):
114130
"""Returns a pandas Series mapping entity id's to entity names for the specified GBD entity."""
115131
ids = _get_ids(entity)
116-
if len(entity_ids)>0:
132+
if len(entity_ids) > 0:
117133
# I think this raises an exception (KeyError and/or UndefinedVariableError) if entity_ids contains NaN
118-
ids = ids.query(f'{entity}_id in {entity_ids}')
134+
ids = ids.query(f"{entity}_id in {entity_ids}")
119135
entity_name_col = get_name_column(entity)
120136
# Year table only has one column, so we copy it
121-
if entity=='year':
122-
entity_name_col = 'year'
123-
ids[entity_name_col] = ids['year_id']
124-
return ids.set_index(f'{entity}_id')[entity_name_col]
137+
if entity == "year":
138+
entity_name_col = "year"
139+
ids[entity_name_col] = ids["year_id"]
140+
return ids.set_index(f"{entity}_id")[entity_name_col]
141+
125142

126143
def process_singleton_ids(ids, entity):
127144
"""Returns a single id if len(ids)==1. If len(ids)>1, returns ids (assumed to be a list), or raises
128145
a ValueError if the shared functions expect a single id rather than a list for the specified entity.
129146
"""
130-
if len(ids)==1:
147+
if len(ids) == 1:
131148
ids = ids[0]
132-
elif entity=='gbd_round': # Also version id's?
149+
elif entity == "gbd_round": # Also version id's?
133150
# It might be better to just let the shared functions raise an exception
134151
# rather than me doing it for them. In which case this function would be almost pointless...
135152
raise ValueError(f"Only single {entity} id's are allowed in shared functions.")
136153
return ids
137154

155+
138156
def list_ids(entity, *entity_names):
139157
"""Returns a list of ids (or a single id) for the specified entity names,
140158
suitable for passing to GBD shared functions.
@@ -145,29 +163,33 @@ def list_ids(entity, *entity_names):
145163
ids = process_singleton_ids(ids, entity)
146164
return ids
147165

166+
148167
def get_entity_and_id_colname(table):
149168
"""Returns the entity and entity id column name from an id table,
150169
assuming the entity id column name is f'{entity}_id',
151170
and that this is the first (or only) column ending in '_id'.
152171
"""
153-
# id_colname = table.columns[table.columns.str.contains(r'\w+_id$')][0]
154-
id_colname = table.filter(regex=r'\w+_id$').columns[0]
172+
# id_colname = table.columns[table.columns.str.contains(r'\w+_id$')][0]
173+
id_colname = table.filter(regex=r"\w+_id$").columns[0]
155174
entity = id_colname[:-3]
156175
return entity, id_colname
157176

177+
158178
def get_entity(table):
159179
"""Returns the entity represented by a given id table,
160180
assuming the id column name is f'{entity}_id',
161181
and that this is the first (or only) column ending in '_id'.
162182
"""
163183
return get_entity_and_id_colname(table)[0]
164184

185+
165186
def get_id_colname(table):
166187
"""Returns the entity id column name in the given id table,
167188
assuming it is the first (or only) column name that ends with '_id'.
168189
"""
169190
return get_entity_and_id_colname(table)[1]
170191

192+
171193
def ids_in(table):
172194
"""Returns the ids in the given dataframe, either as a list of ints or a single int."""
173195
entity, id_colname = get_entity_and_id_colname(table)
@@ -177,7 +199,10 @@ def ids_in(table):
177199
ids = process_singleton_ids(ids, entity)
178200
return ids
179201

180-
def search_id_table(table_or_entity, pattern, search_col=None, return_all_columns=False, **kwargs_for_contains):
202+
203+
def search_id_table(
204+
table_or_entity, pattern, search_col=None, return_all_columns=False, **kwargs_for_contains
205+
):
181206
"""Searches an entity id table for entity names matching the specified pattern, using pandas.Series.str.contains()."""
182207
if isinstance(table_or_entity, _DataFrame):
183208
df = table_or_entity
@@ -186,55 +211,72 @@ def search_id_table(table_or_entity, pattern, search_col=None, return_all_column
186211
entity = table_or_entity
187212
df = _get_ids(entity, return_all_columns)
188213
else:
189-
raise TypeError(f'Expecting type {_DataFrame} or {str} for `table_or_entity`. Got type {type(table_or_entity)}.')
214+
raise TypeError(
215+
f"Expecting type {_DataFrame} or {str} for `table_or_entity`. Got type {type(table_or_entity)}."
216+
)
190217

191218
if search_col is None:
192219
search_col = get_name_column(entity)
193220

194221
return df[df[search_col].str.contains(pattern, **kwargs_for_contains)]
195222

196-
def find_ids(table_or_entity, pattern, search_col=None, return_all_columns=False, **kwargs_for_contains):
223+
224+
def find_ids(
225+
table_or_entity, pattern, search_col=None, return_all_columns=False, **kwargs_for_contains
226+
):
197227
"""Searches an entity id table for entity names matching the specified pattern, using pandas.Series.str.contains(),
198228
and returns a list of ids (or a single id) for the specified entity names, suitable for passing to GBD shared functions.
199229
"""
200-
df = search_id_table(table_or_entity, pattern, search_col=None, return_all_columns=False, **kwargs_for_contains)
230+
df = search_id_table(
231+
table_or_entity,
232+
pattern,
233+
search_col=None,
234+
return_all_columns=False,
235+
**kwargs_for_contains,
236+
)
201237
return ids_in(df)
202238

239+
203240
def add_entity_names(df, *entities):
204241
"""Adds a name column for each specified entity in the dataframe df, by merging on the entity_id column.
205242
If no entities are passed, a name column is added for each entity id column found in the dataframe.
206243
Intended to be called on dataframes returned by the shared functions.
207244
Returns a new object (does not modify df in place).
208245
"""
209246
if len(entities) == 0:
210-
entities = df.filter(regex=r'\w+_id$').columns.str.replace('_id', '')
247+
entities = df.filter(regex=r"\w+_id$").columns.str.replace("_id", "")
211248
for entity in entities:
212-
if entity == 'year': # Avoid error from trying to merge duplicate year_id columns
213-
df = df.assign(year=df['year_id'])
249+
if entity == "year": # Avoid error from trying to merge duplicate year_id columns
250+
df = df.assign(year=df["year_id"])
214251
else:
215-
df = df.merge(_get_ids(entity)[[f'{entity}_id', get_name_column(entity)]])
252+
df = df.merge(_get_ids(entity)[[f"{entity}_id", get_name_column(entity)]])
216253
return df
217254

218-
def drop_id_columns(df, *entities, keep=False, errors='raise'):
255+
256+
def drop_id_columns(df, *entities, keep=False, errors="raise"):
219257
"""If `keep` is False (default), drops the id column for each specified entity in the dataframe df.
220258
Drops all id columns if no entities are passed (you should probably only do this if you have added
221259
the corresponding entity name for the relevant id columns).
222260
If `keep` is set to True, the passed entities are those ids to keep, and all others will be dropped.
223261
Intended to be called on dataframes returned by the shared functions.
224262
Returns a new object (does not modify df in place).
225263
"""
226-
id_colnames = [f'{entity}_id' for entity in entities]
264+
id_colnames = [f"{entity}_id" for entity in entities]
227265
if len(entities) == 0 or keep:
228-
all_id_colnames = df.filter(regex=r'\w+_id$').columns # If no entities passed, drop all id columns
229-
if len(entities) > 0: # entities are those to keep, not drop
266+
all_id_colnames = df.filter(
267+
regex=r"\w+_id$"
268+
).columns # If no entities passed, drop all id columns
269+
if len(entities) > 0: # entities are those to keep, not drop
230270
id_colnames = all_id_colnames.difference(id_colnames)
231271
return df.drop(columns=id_colnames, errors=errors)
232272

233-
def replace_ids_with_names(df, *entities, invert=False, errors='raise'):
273+
274+
def replace_ids_with_names(df, *entities, invert=False, errors="raise"):
234275
"""Replaces entity id columns in df with corresponding entity name columns."""
235276
if invert:
236-
entities = df.filter(regex=r'\w+_id$').columns.difference([f'{entity}_id' for entity in entities])
277+
entities = df.filter(regex=r"\w+_id$").columns.difference(
278+
[f"{entity}_id" for entity in entities]
279+
)
237280
df = add_entity_names(df, *entities)
238281
df = drop_id_columns(df, *entities, errors=errors)
239282
return df
240-

0 commit comments

Comments
 (0)