Skip to content

Summary aliasing of search fields #978

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mp_api/client/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import annotations

from .client import BaseRester, MPRestError
from .client import BaseRester, MPRestError, MPRestWarning
from .settings import MAPIClientSettings
4 changes: 4 additions & 0 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,3 +1335,7 @@ def __str__(self): # pragma: no cover

class MPRestError(Exception):
    """Exception signalling a problem with a query, such as a badly formatted one."""


class MPRestWarning(Warning):
    """Warning category for queries that are malformed yet can still be interpreted."""
160 changes: 124 additions & 36 deletions mp_api/client/routes/materials/summary.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import warnings
from collections import defaultdict

from emmet.core.summary import HasProps, SummaryDoc
from emmet.core.symmetry import CrystalSystem
from pymatgen.analysis.magnetism import Ordering

from mp_api.client.core import BaseRester, MPRestError
from mp_api.client.core import BaseRester, MPRestError, MPRestWarning
from mp_api.client.core.utils import validate_ids


Expand All @@ -15,7 +16,7 @@ class SummaryRester(BaseRester[SummaryDoc]):
document_model = SummaryDoc # type: ignore
primary_key = "material_id"

def search(
def search( # noqa: D417
self,
band_gap: tuple[float, float] | None = None,
chemsys: str | list[str] | None = None,
Expand Down Expand Up @@ -72,6 +73,7 @@ def search(
chunk_size: int = 1000,
all_fields: bool = True,
fields: list[str] | None = None,
**kwargs,
) -> list[SummaryDoc] | list[dict]:
"""Query core data using a variety of search criteria.

Expand Down Expand Up @@ -117,7 +119,8 @@ def search(
material_ids (str, List[str]): A single Material ID string or list of strings
(e.g., mp-149, [mp-149, mp-13]).
n (Tuple[float,float]): Minimum and maximum refractive index to consider.
num_elements (Tuple[int,int]): Minimum and maximum number of elements to consider.
nelements (Tuple[int,int]): Minimum and maximum number of elements to consider.
num_elements (Tuple[int,int]): Alias for `nelements`, deprecated. Minimum and maximum number of elements to consider.
num_sites (Tuple[int,int]): Minimum and maximum number of sites to consider.
num_magnetic_sites (Tuple[int,int]): Minimum and maximum number of magnetic sites to consider.
num_unique_magnetic_sites (Tuple[int,int]): Minimum and maximum number of unique magnetic sites to consider.
Expand Down Expand Up @@ -153,53 +156,138 @@ def search(
"""
query_params = defaultdict(dict) # type: dict

not_aliased_kwargs = [
"energy_above_hull",
"nsites",
"volume",
"density",
"band_gap",
"efermi",
"total_magnetization",
"total_magnetization_normalized_vol",
"total_magnetization_normalized_formula_units",
"num_magnetic_sites",
"num_unique_magnetic_sites",
"k_voigt",
"k_reuss",
"k_vrh",
"g_voigt",
"g_reuss",
"g_vrh",
"e_total",
"e_ionic",
"e_electronic",
"n",
"weighted_surface_energy",
"weighted_work_function",
"shape_factor",
]

min_max_name_dict = {
"total_energy": "energy_per_atom",
"formation_energy": "formation_energy_per_atom",
"energy_above_hull": "energy_above_hull",
"uncorrected_energy": "uncorrected_energy_per_atom",
"equilibrium_reaction_energy": "equilibrium_reaction_energy_per_atom",
"nsites": "nsites",
"volume": "volume",
"density": "density",
"band_gap": "band_gap",
"efermi": "efermi",
"total_magnetization": "total_magnetization",
"total_magnetization_normalized_vol": "total_magnetization_normalized_vol",
"total_magnetization_normalized_formula_units": "total_magnetization_normalized_formula_units",
"num_magnetic_sites": "num_magnetic_sites",
"num_unique_magnetic_sites": "num_unique_magnetic_sites",
"k_voigt": "k_voigt",
"k_reuss": "k_reuss",
"k_vrh": "k_vrh",
"g_voigt": "g_voigt",
"g_reuss": "g_reuss",
"g_vrh": "g_vrh",
"elastic_anisotropy": "universal_anisotropy",
"poisson_ratio": "homogeneous_poisson",
"e_total": "e_total",
"e_ionic": "e_ionic",
"e_electronic": "e_electronic",
"n": "n",
"num_sites": "nsites",
"num_elements": "nelements",
"piezoelectric_modulus": "e_ij_max",
"weighted_surface_energy": "weighted_surface_energy",
"weighted_work_function": "weighted_work_function",
"surface_energy_anisotropy": "surface_anisotropy",
"shape_factor": "shape_factor",
}

for param, value in locals().items():
if param in min_max_name_dict and value:
if isinstance(value, (int, float)):
value = (value, value)
query_params.update(
{
f"{min_max_name_dict[param]}_min": value[0],
f"{min_max_name_dict[param]}_max": value[1],
}
min_max_name_dict.update({k: k for k in not_aliased_kwargs})
mmnd_inv = {v: k for k, v in min_max_name_dict.items() if k != v}

# Set user query params from `locals`
user_settings = {
k: v for k, v in locals().items() if k in min_max_name_dict and v
}

# Check to see if user specified _search fields using **kwargs,
# or if any of the **kwargs are unparsable
db_keys = {k: [] for k in ("duplicate", "warn", "unknown")}
for k, v in kwargs.items():
category = "unknown"
if non_db_k := mmnd_inv.get(k):
if user_settings.get(non_db_k):
# Both a search and _search equivalent field are specified
category = "duplicate"
elif v:
# Only the _search field is specified
category = "warn"
user_settings[non_db_k] = v
db_keys[category].append(non_db_k or k)

# If any _search or unknown fields were set, throw warnings/exceptions
if any(db_keys.values()):
warning_strs: list[str] = []
exc_strs: list[str] = []

def csrc(x):
return f"\x1b[34m{x}\x1b[39m"

def _csrc(x):
return f"\x1b[31m{x}\x1b[39m"

# Warn the user if they input any fields from _search without setting equivalent kwargs in search
if db_keys["warn"]:
warning_strs.extend(
[
f"You have specified fields used by {_csrc('`_search`')} that can be understood by {csrc('`search`')}",
f" {', '.join([_csrc(min_max_name_dict[k]) for k in db_keys['warn']])}",
f"To ensure long term support, please use their {csrc('`search`')} equivalents:",
f" {', '.join([csrc(k) for k in db_keys['warn']])}",
]
)

# Throw an exception if the user input a field from _search and its equivalent search kwarg
if db_keys["duplicate"]:
dupe_pairs = "\n".join(
f"{csrc(k)} and {_csrc(min_max_name_dict[k])}"
for k in db_keys["duplicate"]
)
exc_strs.extend(
[
f"You have specified fields known to both {csrc('`search`')} and {_csrc('`_search`')}",
f" {dupe_pairs}",
f"To avoid query ambiguity, please check your {csrc('`search`')} query and only specify",
f" {', '.join([csrc(k) for k in db_keys['duplicate']])}",
]
)
# Throw an exception if any unknown kwargs were input
if db_keys["unknown"]:
exc_strs.extend(
[
f"You have specified the following kwargs which are unknown to {csrc('`search`')}, "
f"but may be known to {_csrc('`_search`')}",
f" \x1b[36m{', '.join(db_keys['unknown'])}\x1b[39m",
]
)

# Always print links to documentation on warning / exception
warn_ref_strs = [
"Please see the documentation:",
f" {csrc('`search`: https://materialsproject.github.io/api/_autosummary/mp_api.client.routes.materials.summary.SummaryRester.html#mp_api.client.routes.materials.summary.SummaryRester.search')}",
f" {_csrc('`_search`: https://api.materialsproject.org/redoc#tag/Materials-Summary/operation/search_materials_summary__get')}",
]

if exc_strs:
raise MPRestError("\n".join([*warning_strs, *exc_strs, *warn_ref_strs]))
if warn_ref_strs:
warnings.warn(
"\n".join([*warning_strs, *warn_ref_strs]), category=MPRestWarning
)

for param, value in user_settings.items():
if isinstance(value, (int, float)):
value = (value, value)
query_params.update(
{
f"{min_max_name_dict[param]}_min": value[0],
f"{min_max_name_dict[param]}_max": value[1],
}
)

if material_ids:
if isinstance(material_ids, str):
Expand Down
1 change: 1 addition & 0 deletions tests/materials/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"theoretical": True,
"has_reconstructed": False,
"magnetic_ordering": Ordering.FM,
"nelements": (8, 9),
} # type: dict


Expand Down
Loading