Skip to content

Commit 031c2c4

Browse files
search kwargs (#982)
* add nelements to summary API, indicated num_elements deprecation via warning * port get_stability + test to api * move get_stability to separate pr * generalize checks on _search/search fields and ensure appropriate warnings/exceptions are thrown * minor documentation of code logic --------- Co-authored-by: esoteric-ephemera <[email protected]>
1 parent f617529 commit 031c2c4

File tree

4 files changed

+130
-37
lines changed

4 files changed

+130
-37
lines changed

mp_api/client/core/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from __future__ import annotations
22

3-
from .client import BaseRester, MPRestError
3+
from .client import BaseRester, MPRestError, MPRestWarning
44
from .settings import MAPIClientSettings

mp_api/client/core/client.py

+4
Original file line numberDiff line numberDiff line change
@@ -1335,3 +1335,7 @@ def __str__(self): # pragma: no cover
13351335

13361336
class MPRestError(Exception):
13371337
"""Raised when the query has problems, e.g., bad query format."""
1338+
1339+
1340+
class MPRestWarning(Warning):
1341+
"""Raised when a query is malformed but interpretable."""

mp_api/client/routes/materials/summary.py

+124-36
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

3+
import warnings
34
from collections import defaultdict
45

56
from emmet.core.summary import HasProps, SummaryDoc
67
from emmet.core.symmetry import CrystalSystem
78
from pymatgen.analysis.magnetism import Ordering
89

9-
from mp_api.client.core import BaseRester, MPRestError
10+
from mp_api.client.core import BaseRester, MPRestError, MPRestWarning
1011
from mp_api.client.core.utils import validate_ids
1112

1213

@@ -15,7 +16,7 @@ class SummaryRester(BaseRester[SummaryDoc]):
1516
document_model = SummaryDoc # type: ignore
1617
primary_key = "material_id"
1718

18-
def search(
19+
def search( # noqa: D417
1920
self,
2021
band_gap: tuple[float, float] | None = None,
2122
chemsys: str | list[str] | None = None,
@@ -72,6 +73,7 @@ def search(
7273
chunk_size: int = 1000,
7374
all_fields: bool = True,
7475
fields: list[str] | None = None,
76+
**kwargs,
7577
) -> list[SummaryDoc] | list[dict]:
7678
"""Query core data using a variety of search criteria.
7779
@@ -117,7 +119,8 @@ def search(
117119
material_ids (str, List[str]): A single Material ID string or list of strings
118120
(e.g., mp-149, [mp-149, mp-13]).
119121
n (Tuple[float,float]): Minimum and maximum refractive index to consider.
120-
num_elements (Tuple[int,int]): Minimum and maximum number of elements to consider.
122+
nelements (Tuple[int,int]): Minimum and maximum number of elements to consider.
123+
num_elements (Tuple[int,int]): Alias for `nelements`, deprecated. Minimum and maximum number of elements to consider.
121124
num_sites (Tuple[int,int]): Minimum and maximum number of sites to consider.
122125
num_magnetic_sites (Tuple[int,int]): Minimum and maximum number of magnetic sites to consider.
123126
num_unique_magnetic_sites (Tuple[int,int]): Minimum and maximum number of unique magnetic sites to consider.
@@ -153,53 +156,138 @@ def search(
153156
"""
154157
query_params = defaultdict(dict) # type: dict
155158

159+
not_aliased_kwargs = [
160+
"energy_above_hull",
161+
"nsites",
162+
"volume",
163+
"density",
164+
"band_gap",
165+
"efermi",
166+
"total_magnetization",
167+
"total_magnetization_normalized_vol",
168+
"total_magnetization_normalized_formula_units",
169+
"num_magnetic_sites",
170+
"num_unique_magnetic_sites",
171+
"k_voigt",
172+
"k_reuss",
173+
"k_vrh",
174+
"g_voigt",
175+
"g_reuss",
176+
"g_vrh",
177+
"e_total",
178+
"e_ionic",
179+
"e_electronic",
180+
"n",
181+
"weighted_surface_energy",
182+
"weighted_work_function",
183+
"shape_factor",
184+
]
185+
156186
min_max_name_dict = {
157187
"total_energy": "energy_per_atom",
158188
"formation_energy": "formation_energy_per_atom",
159-
"energy_above_hull": "energy_above_hull",
160189
"uncorrected_energy": "uncorrected_energy_per_atom",
161190
"equilibrium_reaction_energy": "equilibrium_reaction_energy_per_atom",
162-
"nsites": "nsites",
163-
"volume": "volume",
164-
"density": "density",
165-
"band_gap": "band_gap",
166-
"efermi": "efermi",
167-
"total_magnetization": "total_magnetization",
168-
"total_magnetization_normalized_vol": "total_magnetization_normalized_vol",
169-
"total_magnetization_normalized_formula_units": "total_magnetization_normalized_formula_units",
170-
"num_magnetic_sites": "num_magnetic_sites",
171-
"num_unique_magnetic_sites": "num_unique_magnetic_sites",
172-
"k_voigt": "k_voigt",
173-
"k_reuss": "k_reuss",
174-
"k_vrh": "k_vrh",
175-
"g_voigt": "g_voigt",
176-
"g_reuss": "g_reuss",
177-
"g_vrh": "g_vrh",
178191
"elastic_anisotropy": "universal_anisotropy",
179192
"poisson_ratio": "homogeneous_poisson",
180-
"e_total": "e_total",
181-
"e_ionic": "e_ionic",
182-
"e_electronic": "e_electronic",
183-
"n": "n",
184193
"num_sites": "nsites",
185194
"num_elements": "nelements",
186195
"piezoelectric_modulus": "e_ij_max",
187-
"weighted_surface_energy": "weighted_surface_energy",
188-
"weighted_work_function": "weighted_work_function",
189196
"surface_energy_anisotropy": "surface_anisotropy",
190-
"shape_factor": "shape_factor",
191197
}
192198

193-
for param, value in locals().items():
194-
if param in min_max_name_dict and value:
195-
if isinstance(value, (int, float)):
196-
value = (value, value)
197-
query_params.update(
198-
{
199-
f"{min_max_name_dict[param]}_min": value[0],
200-
f"{min_max_name_dict[param]}_max": value[1],
201-
}
199+
min_max_name_dict.update({k: k for k in not_aliased_kwargs})
200+
mmnd_inv = {v: k for k, v in min_max_name_dict.items() if k != v}
201+
202+
# Set user query params from `locals`
203+
user_settings = {
204+
k: v for k, v in locals().items() if k in min_max_name_dict and v
205+
}
206+
207+
# Check to see if user specified _search fields using **kwargs,
208+
# or if any of the **kwargs are unparsable
209+
db_keys = {k: [] for k in ("duplicate", "warn", "unknown")}
210+
for k, v in kwargs.items():
211+
category = "unknown"
212+
if non_db_k := mmnd_inv.get(k):
213+
if user_settings.get(non_db_k):
214+
# Both a search and _search equivalent field are specified
215+
category = "duplicate"
216+
elif v:
217+
# Only the _search field is specified
218+
category = "warn"
219+
user_settings[non_db_k] = v
220+
db_keys[category].append(non_db_k or k)
221+
222+
# If any _search or unknown fields were set, throw warnings/exceptions
223+
if any(db_keys.values()):
224+
warning_strs: list[str] = []
225+
exc_strs: list[str] = []
226+
227+
def csrc(x):
228+
return f"\x1b[34m{x}\x1b[39m"
229+
230+
def _csrc(x):
231+
return f"\x1b[31m{x}\x1b[39m"
232+
233+
# Warn the user if they input any fields from _search without setting equivalent kwargs in search
234+
if db_keys["warn"]:
235+
warning_strs.extend(
236+
[
237+
f"You have specified fields used by {_csrc('`_search`')} that can be understood by {csrc('`search`')}",
238+
f" {', '.join([_csrc(min_max_name_dict[k]) for k in db_keys['warn']])}",
239+
f"To ensure long term support, please use their {csrc('`search`')} equivalents:",
240+
f" {', '.join([csrc(k) for k in db_keys['warn']])}",
241+
]
242+
)
243+
244+
# Throw an exception if the user input a field from _search and its equivalent search kwarg
245+
if db_keys["duplicate"]:
246+
dupe_pairs = "\n".join(
247+
f"{csrc(k)} and {_csrc(min_max_name_dict[k])}"
248+
for k in db_keys["duplicate"]
202249
)
250+
exc_strs.extend(
251+
[
252+
f"You have specified fields known to both {csrc('`search`')} and {_csrc('`_search`')}",
253+
f" {dupe_pairs}",
254+
f"To avoid query ambiguity, please check your {csrc('`search`')} query and only specify",
255+
f" {', '.join([csrc(k) for k in db_keys['duplicate']])}",
256+
]
257+
)
258+
# Throw an exception if any unknown kwargs were input
259+
if db_keys["unknown"]:
260+
exc_strs.extend(
261+
[
262+
f"You have specified the following kwargs which are unknown to {csrc('`search`')}, "
263+
f"but may be known to {_csrc('`_search`')}",
264+
f" \x1b[36m{', '.join(db_keys['unknown'])}\x1b[39m",
265+
]
266+
)
267+
268+
# Always print links to documentation on warning / exception
269+
warn_ref_strs = [
270+
"Please see the documentation:",
271+
f" {csrc('`search`: https://materialsproject.github.io/api/_autosummary/mp_api.client.routes.materials.summary.SummaryRester.html#mp_api.client.routes.materials.summary.SummaryRester.search')}",
272+
f" {_csrc('`_search`: https://api.materialsproject.org/redoc#tag/Materials-Summary/operation/search_materials_summary__get')}",
273+
]
274+
275+
if exc_strs:
276+
raise MPRestError("\n".join([*warning_strs, *exc_strs, *warn_ref_strs]))
277+
if warn_ref_strs:
278+
warnings.warn(
279+
"\n".join([*warning_strs, *warn_ref_strs]), category=MPRestWarning
280+
)
281+
282+
for param, value in user_settings.items():
283+
if isinstance(value, (int, float)):
284+
value = (value, value)
285+
query_params.update(
286+
{
287+
f"{min_max_name_dict[param]}_min": value[0],
288+
f"{min_max_name_dict[param]}_max": value[1],
289+
}
290+
)
203291

204292
if material_ids:
205293
if isinstance(material_ids, str):

tests/materials/test_summary.py

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
"theoretical": True,
5858
"has_reconstructed": False,
5959
"magnetic_ordering": Ordering.FM,
60+
"nelements": (8, 9),
6061
} # type: dict
6162

6263

0 commit comments

Comments
 (0)