Skip to content

Commit fee3843

Browse files
author
Nils Edvin Richard Zimmermann
committed
Fixed failing tests, added link to the issue that describes problem with old semicircle function, and type-hinted functions
1 parent 252efa7 commit fee3843

File tree

2 files changed

+56
-19
lines changed

2 files changed

+56
-19
lines changed

src/pymatgen/analysis/local_env.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,12 @@
1919
import numpy as np
2020
from monty.dev import deprecated, requires
2121
from monty.serialization import loadfn
22-
from ruamel.yaml import YAML
23-
from scipy.spatial import Voronoi
24-
2522
from pymatgen.analysis.bond_valence import BV_PARAMS, BVAnalyzer
2623
from pymatgen.analysis.graphs import MoleculeGraph, StructureGraph
2724
from pymatgen.analysis.molecule_structure_comparator import CovalentRadius
2825
from pymatgen.core import Element, IStructure, PeriodicNeighbor, PeriodicSite, Site, Species, Structure
26+
from ruamel.yaml import YAML
27+
from scipy.spatial import Voronoi
2928

3029
try:
3130
from openbabel import openbabel
@@ -35,10 +34,9 @@
3534
if TYPE_CHECKING:
3635
from typing import Any
3736

38-
from typing_extensions import Self
39-
4037
from pymatgen.core.composition import SpeciesLike
4138
from pymatgen.util.typing import Tuple3Ints
39+
from typing_extensions import Self
4240

4341

4442
__author__ = "Shyue Ping Ong, Geoffroy Hautier, Sai Jayaraman, "
@@ -1161,7 +1159,7 @@ def _is_in_targets(site, targets):
11611159
targets ([Element]) List of elements
11621160
11631161
Returns:
1164-
bool: Whether this site contains a certain list of elements
1162+
boolean: Whether this site contains a certain list of elements
11651163
"""
11661164
elems = _get_elements(site)
11671165
return all(elem in targets for elem in elems)
@@ -1218,7 +1216,7 @@ def __init__(
12181216

12191217
# Update any user preference elemental radii
12201218
if el_radius_updates:
1221-
self.el_radius |= el_radius_updates
1219+
self.el_radius.update(el_radius_updates)
12221220

12231221
@property
12241222
def structures_allowed(self) -> bool:
@@ -1984,7 +1982,7 @@ def get_okeeffe_distance_prediction(el1, el2):
19841982
"""Get an estimate of the bond valence parameter (bond length) using
19851983
the derived parameters from 'Atoms Sizes and Bond Lengths in Molecules
19861984
and Crystals' (O'Keeffe & Brese, 1991). The estimate is based on two
1987-
experimental parameters: r and c. The value for r is based off radius,
1985+
experimental parameters: r and c. The value for r is based off radius,
19881986
while c is (usually) the Allred-Rochow electronegativity. Values used
19891987
are *not* generated from pymatgen, and are found in
19901988
'okeeffe_params.json'.
@@ -2755,7 +2753,7 @@ def get_type(self, index):
27552753
raise ValueError("Index for getting order parameter type out-of-bounds!")
27562754
return self._types[index]
27572755

2758-
def get_parameters(self, index: int) -> list[float]:
2756+
def get_parameters(self, index):
27592757
"""Get list of floats that represents
27602758
the parameters associated
27612759
with calculation of the order
@@ -2764,10 +2762,12 @@ def get_parameters(self, index: int) -> list[float]:
27642762
inputted because of processing out of efficiency reasons.
27652763
27662764
Args:
2767-
index (int): index of order parameter for which to return associated params.
2765+
index (int):
2766+
index of order parameter for which associated parameters
2767+
are to be returned.
27682768
27692769
Returns:
2770-
list[float]: parameters of a given OP.
2770+
[float]: parameters of a given OP.
27712771
"""
27722772
if index < 0 or index >= len(self._types):
27732773
raise ValueError("Index for getting parameters associated with order parameter calculation out-of-bounds!")
@@ -3990,7 +3990,7 @@ def get_nn_data(self, structure: Structure, n: int, length=None):
39903990
nn_info.append(entry)
39913991
cn = len(nn_info)
39923992
cn_nninfo[cn] = nn_info
3993-
cn_weights[cn] = self._semicircle_integral(dist_bins, idx)
3993+
cn_weights[cn] = self._quadrant_integral(dist_bins, idx)
39943994

39953995
# add zero coord
39963996
cn0_weight = 1 - sum(cn_weights.values())
@@ -4047,10 +4047,13 @@ def get_cn_dict(self, structure: Structure, n: int, use_weights: bool = False, *
40474047
return super().get_cn_dict(structure, n, use_weights)
40484048

40494049
@staticmethod
4050-
def _semicircle_integral(dist_bins, idx):
4050+
def _semicircle_integral(dist_bins: list, idx: int) -> float:
40514051
"""
40524052
An internal method to get an integral between two bounds of a unit
40534053
semicircle. Used in algorithm to determine bond probabilities.
4054+
This function has an issue, which is detailed at
4055+
https://github.com/materialsproject/pymatgen/issues/3973.
4056+
Therefore, _quadrant_integral is now the method of choice.
40544057
40554058
Args:
40564059
dist_bins: (float) list of all possible bond weights
@@ -4075,6 +4078,35 @@ def _semicircle_integral(dist_bins, idx):
40754078

40764079
return (area1 - area2) / (0.25 * math.pi * radius**2)
40774080

4081+
@staticmethod
4082+
def _quadrant_integral(dist_bins: list, idx: int) -> float:
4083+
"""
4084+
An internal method to get an integral between two bounds of a unit
4085+
quadrant. Used in algorithm to determine bond probabilities.
4086+
4087+
Args:
4088+
dist_bins: (float) list of all possible bond weights
4089+
idx: (float) index of starting bond weight
4090+
4091+
Returns:
4092+
float: integral of portion of unit quadrant
4093+
"""
4094+
radius = 1
4095+
4096+
x1 = dist_bins[idx]
4097+
x2 = dist_bins[idx + 1]
4098+
4099+
areaquarter = 0.25 * math.pi * radius**2
4100+
4101+
area1 = areaquarter - 0.5 * (radius**2 * math.acos(
4102+
1 - x1 / radius) - (radius - x1) * math.sqrt(
4103+
2 * radius * x1 - x1**2))
4104+
area2 = areaquarter - 0.5 * (radius**2 * math.acos(
4105+
1 - x2 / radius) - (radius - x2) * math.sqrt(
4106+
2 * radius * x2 - x2**2))
4107+
4108+
return (area2 - area1) / areaquarter
4109+
40784110
@staticmethod
40794111
def transform_to_length(nn_data, length):
40804112
"""

tests/analysis/test_local_env.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import numpy as np
88
import pytest
99
from numpy.testing import assert_allclose
10-
from pytest import approx
11-
1210
from pymatgen.analysis.graphs import MoleculeGraph, StructureGraph
1311
from pymatgen.analysis.local_env import (
1412
BrunnerNNReal,
@@ -38,6 +36,7 @@
3836
)
3937
from pymatgen.core import Element, Lattice, Molecule, Structure
4038
from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest
39+
from pytest import approx
4140

4241
TEST_DIR = f"{TEST_FILES_DIR}/analysis/local_env/fragmenter_files"
4342

@@ -445,6 +444,10 @@ def test_all_nn_classes(self):
445444
assert voronoi_nn.get_cn(self.cscl, 0) == 8
446445
assert voronoi_nn.get_cn(self.lifepo4, 0) == 6
447446

447+
assert CrystalNN._quadrant_integral([1,0.36], 0) == approx(
448+
0.7551954297486029)
449+
assert CrystalNN._quadrant_integral([1,0.36,0], 1) == approx(
450+
1 - 0.7551954297486029)
448451
crystal_nn = CrystalNN()
449452
assert crystal_nn.get_cn(self.diamond, 0) == 4
450453
assert crystal_nn.get_cn(self.nacl, 0) == 6
@@ -1178,7 +1181,7 @@ def test_sanity(self):
11781181
def test_discrete_cn(self):
11791182
cnn = CrystalNN()
11801183
cn_array = []
1181-
expected_array = 8 * [6] + 20 * [4]
1184+
expected_array = 8 * [6] + 6 * [4] + [1] + 2 * [4] + [1] + 4 * [4] + [1] + + 2 * [4] + [1] + 2 * [4]
11821185
for idx, _ in enumerate(self.lifepo4):
11831186
cn_array.append(cnn.get_cn(self.lifepo4, idx))
11841187

@@ -1202,6 +1205,7 @@ def test_weighted_cn(self):
12021205

12031206
def test_weighted_cn_no_oxid(self):
12041207
cnn = CrystalNN(weighted_cn=True)
1208+
cn_array = []
12051209
# fmt: off
12061210
expected_array = [
12071211
5.8962, 5.8996, 5.8962, 5.8996, 5.7195, 5.7195, 5.7202, 5.7194, 4.0012, 4.0012,
@@ -1210,7 +1214,8 @@ def test_weighted_cn_no_oxid(self):
12101214
]
12111215
# fmt: on
12121216
struct = self.lifepo4.copy().remove_oxidation_states()
1213-
cn_array = [cnn.get_cn(struct, idx, use_weights=True) for idx in range(len(struct))]
1217+
for idx in range(len(struct)):
1218+
cn_array.append(cnn.get_cn(struct, idx, use_weights=True))
12141219

12151220
assert_allclose(expected_array, cn_array, 2)
12161221

@@ -1222,11 +1227,11 @@ def test_fixed_length(self):
12221227

12231228
def test_cation_anion(self):
12241229
cnn = CrystalNN(weighted_cn=True, cation_anion=True)
1225-
assert cnn.get_cn(self.lifepo4, 0, use_weights=True) == approx(5.8630, abs=1e-2)
1230+
assert cnn.get_cn(self.lifepo4, 0, use_weights=True) == approx(5.5426, abs=1e-2)
12261231

12271232
def test_x_diff_weight(self):
12281233
cnn = CrystalNN(weighted_cn=True, x_diff_weight=0)
1229-
assert cnn.get_cn(self.lifepo4, 0, use_weights=True) == approx(5.8630, abs=1e-2)
1234+
assert cnn.get_cn(self.lifepo4, 0, use_weights=True) == approx(5.5426, abs=1e-2)
12301235

12311236
def test_noble_gas_material(self):
12321237
cnn = CrystalNN()

0 commit comments

Comments
 (0)