Skip to content

Commit 4c6ef6a

Browse files
committed
fix test_numeric: trim arrays to its real data
1 parent 1f9a19e commit 4c6ef6a

File tree

2 files changed

+39
-6
lines changed

2 files changed

+39
-6
lines changed

atomdb/species.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ class DensitySpline:
179179
def __init__(self, x, y, log=False):
180180
r"""Initialize the CubicSpline instance."""
181181
self._log = log
182+
x, y = self._trim_padded_data(x, y)
182183
self._obj = CubicSpline(
183184
x,
184185
# Clip y values to >= ε^2 if using log because they have to be above 0;
@@ -189,6 +190,13 @@ def __init__(self, x, y, log=False):
189190
extrapolate=True,
190191
)
191192

193+
def _trim_padded_data(self, data_x, data_y, tol=1e-10):
194+
"""Trim padded zeros from the end of arrays."""
195+
trimmed_x = trim_padded_array(data_x, tol)
196+
effective_length = len(trimmed_x)
197+
trimmed_y = data_y[:effective_length]
198+
return trimmed_x, trimmed_y
199+
192200
def __call__(self, x, deriv=0):
193201
r"""
194202
Compute the interpolation at some x-values.
@@ -1002,9 +1010,33 @@ def get_species_data(DATASETS_H5FILE, folder_path, elem, DATASET_PROPERTY_CONFIG
10021010
for prop in ("atmass", "cov_radius", "vdw_radius", "at_radius", "polarizability", "dispersion"):
10031011
fields[prop] = get_scalar_data(prop, fields["atnum"], fields["nelec"])
10041012

1013+
if "rs" in fields:
1014+
rs_trimmed_array = trim_padded_array(fields["rs"])
1015+
effective_length = len(rs_trimmed_array)
1016+
for config in DATASET_PROPERTY_CONFIGS:
1017+
if "Carray_property" in config:
1018+
# Trim only if spins == "no" or spins is not specified
1019+
if config.get("spins") == "no" or ("spins" not in config):
1020+
fields[config["Carray_property"]] = fields[config["Carray_property"]][
1021+
:effective_length
1022+
]
10051023
return fields
10061024

10071025

1026+
def trim_padded_array(data, tol=1e-10):
1027+
"""Trim padded zeros from the end of arrays."""
1028+
# Find non-zero elements
1029+
non_zero_mask = np.abs(data) > tol
1030+
non_zero_indices = np.where(non_zero_mask)[0]
1031+
1032+
if len(non_zero_indices) == 0:
1033+
return data[:0]
1034+
1035+
# Get the last non-zero index
1036+
last_non_zero_index = non_zero_indices[-1]
1037+
return data[: (last_non_zero_index + 1)]
1038+
1039+
10081040
def raw_datafile(
10091041
suffix,
10101042
elem,

atomdb/test/test_numeric.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def test_numerical_hf_data_h():
4545
sp = load("H", 0, 2, dataset="numeric", datapath=TEST_DATAPATH)
4646

4747
# check shape radial grid and total density arrays
48-
assert sp._data.rs.shape == (1000,)
49-
assert sp._data.dens_tot.shape == sp._data.rs.shape
48+
npoints = len(sp._data.rs)
49+
assert sp._data.dens_tot.shape == (npoints,)
5050

5151
# check radial grid and total density arrays values
5252
assert all(sp._data.rs >= 0.0)
@@ -81,8 +81,8 @@ def test_numerical_hf_data_h_anion():
8181
assert_almost_equal(sp.energy, -0.487929734301232, decimal=10)
8282

8383
# check shape radial grid and total density arrays
84-
assert sp._data.rs.shape == (1000,)
85-
assert sp._data.dens_tot.shape == sp._data.rs.shape
84+
npoints = len(sp._data.rs)
85+
assert sp._data.dens_tot.shape == (npoints,)
8686

8787
# reference radial values sample and corresponding indices
8888
ref_rs = np.array(
@@ -135,7 +135,7 @@ def test_numerical_hf_energy_especies(atom, mult, energy):
135135

136136

137137
@pytest.mark.parametrize(
138-
"atom, mult, npoints, nelec", [("Be", 1, 1000, 4.0), ("Cl", 2, 1000, 17.0), ("Ne", 1, 1000, 10.0)]
138+
"atom, mult, npoints, nelec", [("Be", 1, 146, 4.0), ("Cl", 2, 164, 17.0), ("Ne", 1, 151, 10.0)]
139139
)
140140
def test_numerical_hf_atomic_density(atom, mult, npoints, nelec):
141141
# load atomic and density data
@@ -230,6 +230,7 @@ def test_numerical_hf_density_laplacian(atom, charge, mult):
230230
ref_lapl = np.load(f"{TEST_DATAPATH}/numeric/db/{fname}")
231231

232232
# check interpolated Laplacian of density values against reference values
233-
assert np.allclose(laplacian_dens, ref_lapl, atol=1e-10)
233+
meaningful_length = len(ref_lapl)
234+
assert np.allclose(laplacian_dens[:meaningful_length], ref_lapl, atol=1e-10)
234235
# for r=0, the Laplacian function in not well defined and is set to zero
235236
assert np.allclose(laplacian_dens[0], [0.0], atol=1e-10)

0 commit comments

Comments
 (0)