1
1
from __future__ import annotations
2
2
3
+ import warnings
3
4
from collections import defaultdict
4
5
5
6
from emmet .core .summary import HasProps , SummaryDoc
6
7
from emmet .core .symmetry import CrystalSystem
7
8
from pymatgen .analysis .magnetism import Ordering
8
9
9
- from mp_api .client .core import BaseRester , MPRestError
10
+ from mp_api .client .core import BaseRester , MPRestError , MPRestWarning
10
11
from mp_api .client .core .utils import validate_ids
11
12
12
13
@@ -15,7 +16,7 @@ class SummaryRester(BaseRester[SummaryDoc]):
15
16
document_model = SummaryDoc # type: ignore
16
17
primary_key = "material_id"
17
18
18
- def search (
19
+ def search ( # noqa: D417
19
20
self ,
20
21
band_gap : tuple [float , float ] | None = None ,
21
22
chemsys : str | list [str ] | None = None ,
@@ -72,6 +73,7 @@ def search(
72
73
chunk_size : int = 1000 ,
73
74
all_fields : bool = True ,
74
75
fields : list [str ] | None = None ,
76
+ ** kwargs ,
75
77
) -> list [SummaryDoc ] | list [dict ]:
76
78
"""Query core data using a variety of search criteria.
77
79
@@ -117,7 +119,8 @@ def search(
117
119
material_ids (str, List[str]): A single Material ID string or list of strings
118
120
(e.g., mp-149, [mp-149, mp-13]).
119
121
n (Tuple[float,float]): Minimum and maximum refractive index to consider.
120
- num_elements (Tuple[int,int]): Minimum and maximum number of elements to consider.
122
+ nelements (Tuple[int,int]): Minimum and maximum number of elements to consider.
123
+ num_elements (Tuple[int,int]): Alias for `nelements`, deprecated. Minimum and maximum number of elements to consider.
121
124
num_sites (Tuple[int,int]): Minimum and maximum number of sites to consider.
122
125
num_magnetic_sites (Tuple[int,int]): Minimum and maximum number of magnetic sites to consider.
123
126
num_unique_magnetic_sites (Tuple[int,int]): Minimum and maximum number of unique magnetic sites to consider.
@@ -153,53 +156,138 @@ def search(
153
156
"""
154
157
query_params = defaultdict (dict ) # type: dict
155
158
159
+ not_aliased_kwargs = [
160
+ "energy_above_hull" ,
161
+ "nsites" ,
162
+ "volume" ,
163
+ "density" ,
164
+ "band_gap" ,
165
+ "efermi" ,
166
+ "total_magnetization" ,
167
+ "total_magnetization_normalized_vol" ,
168
+ "total_magnetization_normalized_formula_units" ,
169
+ "num_magnetic_sites" ,
170
+ "num_unique_magnetic_sites" ,
171
+ "k_voigt" ,
172
+ "k_reuss" ,
173
+ "k_vrh" ,
174
+ "g_voigt" ,
175
+ "g_reuss" ,
176
+ "g_vrh" ,
177
+ "e_total" ,
178
+ "e_ionic" ,
179
+ "e_electronic" ,
180
+ "n" ,
181
+ "weighted_surface_energy" ,
182
+ "weighted_work_function" ,
183
+ "shape_factor" ,
184
+ ]
185
+
156
186
min_max_name_dict = {
157
187
"total_energy" : "energy_per_atom" ,
158
188
"formation_energy" : "formation_energy_per_atom" ,
159
- "energy_above_hull" : "energy_above_hull" ,
160
189
"uncorrected_energy" : "uncorrected_energy_per_atom" ,
161
190
"equilibrium_reaction_energy" : "equilibrium_reaction_energy_per_atom" ,
162
- "nsites" : "nsites" ,
163
- "volume" : "volume" ,
164
- "density" : "density" ,
165
- "band_gap" : "band_gap" ,
166
- "efermi" : "efermi" ,
167
- "total_magnetization" : "total_magnetization" ,
168
- "total_magnetization_normalized_vol" : "total_magnetization_normalized_vol" ,
169
- "total_magnetization_normalized_formula_units" : "total_magnetization_normalized_formula_units" ,
170
- "num_magnetic_sites" : "num_magnetic_sites" ,
171
- "num_unique_magnetic_sites" : "num_unique_magnetic_sites" ,
172
- "k_voigt" : "k_voigt" ,
173
- "k_reuss" : "k_reuss" ,
174
- "k_vrh" : "k_vrh" ,
175
- "g_voigt" : "g_voigt" ,
176
- "g_reuss" : "g_reuss" ,
177
- "g_vrh" : "g_vrh" ,
178
191
"elastic_anisotropy" : "universal_anisotropy" ,
179
192
"poisson_ratio" : "homogeneous_poisson" ,
180
- "e_total" : "e_total" ,
181
- "e_ionic" : "e_ionic" ,
182
- "e_electronic" : "e_electronic" ,
183
- "n" : "n" ,
184
193
"num_sites" : "nsites" ,
185
194
"num_elements" : "nelements" ,
186
195
"piezoelectric_modulus" : "e_ij_max" ,
187
- "weighted_surface_energy" : "weighted_surface_energy" ,
188
- "weighted_work_function" : "weighted_work_function" ,
189
196
"surface_energy_anisotropy" : "surface_anisotropy" ,
190
- "shape_factor" : "shape_factor" ,
191
197
}
192
198
193
- for param , value in locals ().items ():
194
- if param in min_max_name_dict and value :
195
- if isinstance (value , (int , float )):
196
- value = (value , value )
197
- query_params .update (
198
- {
199
- f"{ min_max_name_dict [param ]} _min" : value [0 ],
200
- f"{ min_max_name_dict [param ]} _max" : value [1 ],
201
- }
199
+ min_max_name_dict .update ({k : k for k in not_aliased_kwargs })
200
+ mmnd_inv = {v : k for k , v in min_max_name_dict .items () if k != v }
201
+
202
+ # Set user query params from `locals`
203
+ user_settings = {
204
+ k : v for k , v in locals ().items () if k in min_max_name_dict and v
205
+ }
206
+
207
+ # Check to see if user specified _search fields using **kwargs,
208
+ # or if any of the **kwargs are unparsable
209
+ db_keys = {k : [] for k in ("duplicate" , "warn" , "unknown" )}
210
+ for k , v in kwargs .items ():
211
+ category = "unknown"
212
+ if non_db_k := mmnd_inv .get (k ):
213
+ if user_settings .get (non_db_k ):
214
+ # Both a search and _search equivalent field are specified
215
+ category = "duplicate"
216
+ elif v :
217
+ # Only the _search field is specified
218
+ category = "warn"
219
+ user_settings [non_db_k ] = v
220
+ db_keys [category ].append (non_db_k or k )
221
+
222
+ # If any _search or unknown fields were set, throw warnings/exceptions
223
+ if any (db_keys .values ()):
224
+ warning_strs : list [str ] = []
225
+ exc_strs : list [str ] = []
226
+
227
+ def csrc (x ):
228
+ return f"\x1b [34m{ x } \x1b [39m"
229
+
230
+ def _csrc (x ):
231
+ return f"\x1b [31m{ x } \x1b [39m"
232
+
233
+ # Warn the user if they input any fields from _search without setting equivalent kwargs in search
234
+ if db_keys ["warn" ]:
235
+ warning_strs .extend (
236
+ [
237
+ f"You have specified fields used by { _csrc ('`_search`' )} that can be understood by { csrc ('`search`' )} " ,
238
+ f" { ', ' .join ([_csrc (min_max_name_dict [k ]) for k in db_keys ['warn' ]])} " ,
239
+ f"To ensure long term support, please use their { csrc ('`search`' )} equivalents:" ,
240
+ f" { ', ' .join ([csrc (k ) for k in db_keys ['warn' ]])} " ,
241
+ ]
242
+ )
243
+
244
+ # Throw an exception if the user input a field from _search and its equivalent search kwarg
245
+ if db_keys ["duplicate" ]:
246
+ dupe_pairs = "\n " .join (
247
+ f"{ csrc (k )} and { _csrc (min_max_name_dict [k ])} "
248
+ for k in db_keys ["duplicate" ]
202
249
)
250
+ exc_strs .extend (
251
+ [
252
+ f"You have specified fields known to both { csrc ('`search`' )} and { _csrc ('`_search`' )} " ,
253
+ f" { dupe_pairs } " ,
254
+ f"To avoid query ambiguity, please check your { csrc ('`search`' )} query and only specify" ,
255
+ f" { ', ' .join ([csrc (k ) for k in db_keys ['duplicate' ]])} " ,
256
+ ]
257
+ )
258
+ # Throw an exception if any unknown kwargs were input
259
+ if db_keys ["unknown" ]:
260
+ exc_strs .extend (
261
+ [
262
+ f"You have specified the following kwargs which are unknown to { csrc ('`search`' )} , "
263
+ f"but may be known to { _csrc ('`_search`' )} " ,
264
+ f" \x1b [36m{ ', ' .join (db_keys ['unknown' ])} \x1b [39m" ,
265
+ ]
266
+ )
267
+
268
+ # Always print links to documentation on warning / exception
269
+ warn_ref_strs = [
270
+ "Please see the documentation:" ,
271
+ f" { csrc ('`search`: https://materialsproject.github.io/api/_autosummary/mp_api.client.routes.materials.summary.SummaryRester.html#mp_api.client.routes.materials.summary.SummaryRester.search' )} " ,
272
+ f" { _csrc ('`_search`: https://api.materialsproject.org/redoc#tag/Materials-Summary/operation/search_materials_summary__get' )} " ,
273
+ ]
274
+
275
+ if exc_strs :
276
+ raise MPRestError ("\n " .join ([* warning_strs , * exc_strs , * warn_ref_strs ]))
277
+ if warn_ref_strs :
278
+ warnings .warn (
279
+ "\n " .join ([* warning_strs , * warn_ref_strs ]), category = MPRestWarning
280
+ )
281
+
282
+ for param , value in user_settings .items ():
283
+ if isinstance (value , (int , float )):
284
+ value = (value , value )
285
+ query_params .update (
286
+ {
287
+ f"{ min_max_name_dict [param ]} _min" : value [0 ],
288
+ f"{ min_max_name_dict [param ]} _max" : value [1 ],
289
+ }
290
+ )
203
291
204
292
if material_ids :
205
293
if isinstance (material_ids , str ):
0 commit comments