Skip to content

Commit 4126ef0

Browse files
Push coverage to ~97% headroom above the 95% nightly gate
After the parser fix the nightly cleared 95% with little margin. This commit closes the largest remaining gaps so a small incidental change does not immediately drop us back under the gate. New / updated test files: * tests/test_entropy_solver_option_z_smoke.py: end-to-end EntropySolver.solve() with use_jax_jacobian=true and a registered JAX CVODE factory. Exercises the wire-up at entropy_solver.py:2417-2424 (factory invocation + Option-Z-active log) and _solve_cvode lines 1900-1905 + 2002-2005 (JAX RHS / Jacobian install path: dense linsolver, banded keys popped). * tests/test_cvode_jax_exception_paths.py: forces the rhs_fn / jacfn except branches by passing wrong-shape ydot / J buffers so the in-place numpy assignment raises (cvode_jax.py:217-237). Plus a happy-path test of the radio_isotope_params 5-tuple branch (line 179). * tests/test_entropy_solver_remaining_gaps.py: smoke + unit tests for the small scattered branches: energy_balance with inner_bc_kind=1 dispatch (line 1450), set_initial_entropy fallback paths when called pre-initialize (1143-1144, 1157-1158), and the Phi_global mean fallback when mass_total = 0 (line 2885). * tests/test_parser_toml_dispatch.py: extends with two new regression tests covering the bad-radionuclide TypeError wrap and the IC-method-2 init_temperature loadtxt path. * tests/test_cli_more_branches.py: adds tests for the ImportError fallback in _version_message (lines 60-61) and the vnv module-spec-load failure (line 921). * tests/test_entropy_solver_eos_method2_smoke.py: now also calls get_state() so the UserDefinedEOS branch in entropy_solver.py:2848-2850 (no get_mass_within_radii) is exercised end-to-end. No production-code changes. All 33 affected tests pass; full suite (565 tests) verified locally before this commit.
1 parent 90f9a21 commit 4126ef0

6 files changed

Lines changed: 916 additions & 0 deletions

tests/test_cli_more_branches.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,71 @@ def read_text(self, encoding=None):
262262
# ──────────────────────────────────────────────────────────────────────
263263

264264

265+
def test_version_message_handles_missing_dependency(monkeypatch):
266+
"""Lines 60-61 of cli.py: ``_version_message`` builds the
267+
``aragog --versions`` block by importing each dependency by name;
268+
when one raises ImportError or other exceptions, the line must
269+
fall back to ``"not installed"`` rather than aborting the whole
270+
listing.
271+
272+
Discriminator: a regression that let the exception propagate
273+
would make ``aragog --versions`` (the bug-report-friendly mode)
274+
crash on any environment that lacks one of the optional deps.
275+
"""
276+
import importlib
277+
278+
import aragog.cli as cli_mod
279+
280+
real_import = importlib.import_module
281+
282+
def _faulty_import(name, package=None):
283+
if name == 'jax':
284+
raise ImportError('intentional test fault')
285+
return real_import(name, package=package)
286+
287+
monkeypatch.setattr(importlib, 'import_module', _faulty_import)
288+
msg = cli_mod._version_message()
289+
assert 'aragog ' in msg
290+
assert 'jax: not installed' in msg, (
291+
f'expected "jax: not installed" line on ImportError; got {msg!r}'
292+
)
293+
294+
295+
def test_vnv_handles_module_spec_load_failure(monkeypatch, tmp_path):
296+
"""Line 921 of cli.py: when ``importlib.util.spec_from_file_location``
297+
cannot build a spec / loader, ``aragog vnv <topic>`` must raise
298+
ClickException with a "could not load" message.
299+
300+
Discriminator: forced spec=None via monkeypatch makes the loader
301+
branch fire. A regression that swallowed the failure or let the
302+
None-spec into ``module_from_spec`` would crash with AttributeError
303+
deeper in the stack instead of the user-facing "could not load".
304+
"""
305+
import importlib.util
306+
307+
from click.testing import CliRunner
308+
309+
import aragog.cli as cli_mod
310+
311+
figures_dir = tmp_path / 'figs'
312+
figures_dir.mkdir()
313+
(figures_dir / 'verify_demo.py').write_text('def main(): pass\n', encoding='utf-8')
314+
monkeypatch.setattr(cli_mod, '_vnv_figures_dir', lambda: figures_dir)
315+
316+
real_spec_from_file = importlib.util.spec_from_file_location
317+
318+
def _none_spec(*args, **kwargs):
319+
return None
320+
321+
monkeypatch.setattr(importlib.util, 'spec_from_file_location', _none_spec)
322+
runner = CliRunner()
323+
result = runner.invoke(cli_mod.cli, ['vnv', 'demo'])
324+
assert result.exit_code != 0
325+
assert 'could not load' in (result.output or '').lower()
326+
# Restore (autouse cleanup happens via monkeypatch but assert).
327+
monkeypatch.setattr(importlib.util, 'spec_from_file_location', real_spec_from_file)
328+
329+
265330
def test_serialize_params_walks_radionuclide_list_and_unboxes_np_scalars():
266331
"""Lines 732, 736: ``_serialize_params`` recurses into lists and
267332
unboxes numpy scalars via ``.item()``.
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
"""Exception paths inside ``build_jax_rhs_and_jacobian``'s rhs_fn /
2+
jacfn callbacks.
3+
4+
The factory's contract is that any uncaught exception inside the
5+
JAX-traced RHS or Jacobian must be swallowed (logged + return 1) so
6+
CVODE can fall back to its own finite-difference Jacobian rather
7+
than aborting the integrator partway through a long solve.
8+
9+
The contract tests in test_cvode_jax_factory.py and the invocation
10+
tests in test_cvode_jax_factory_invocation.py exercise the happy
11+
paths but not these except branches (lines 221-223, 235-237 of
12+
solver/cvode_jax.py). Force the failures by giving the in-place
13+
output buffers the wrong shape so the trailing ``ydot[:] = ...`` /
14+
``J[...] = ...`` write raises during numpy broadcast.
15+
"""
16+
17+
from __future__ import annotations
18+
19+
import logging
20+
import os
21+
from pathlib import Path
22+
23+
import numpy as np
24+
import pytest
25+
26+
jax = pytest.importorskip('jax')
27+
jnp = pytest.importorskip('jax.numpy')
28+
eqx = pytest.importorskip('equinox')
29+
30+
jax.config.update('jax_enable_x64', True)
31+
32+
pytestmark = pytest.mark.unit
33+
34+
35+
_REPO_ROOT = Path(__file__).resolve().parent.parent
36+
_FWL_DATA = os.environ.get('FWL_DATA')
37+
_CANDIDATES = [
38+
os.environ.get('ARAGOG_TEST_EOS_DIR'),
39+
f'{_FWL_DATA}/aragog/spider_eos' if _FWL_DATA else None,
40+
str(_REPO_ROOT.parent / 'output' / 'coupled_parity' / 'spider' / 'data' / 'spider_eos'),
41+
]
42+
EOS_DIR = next(
43+
(Path(p) for p in _CANDIDATES if p and Path(p).exists()),
44+
Path(_CANDIDATES[-1]),
45+
)
46+
needs_eos = pytest.mark.skipif(
47+
not EOS_DIR.exists(),
48+
reason=f'SPIDER P-S tables not found at {EOS_DIR}.',
49+
)
50+
51+
52+
def _make_const_property_mesh(N: int = 8):
53+
"""Minimal MeshArrays for cheap factory tests."""
54+
from aragog.jax.phase import MeshArrays
55+
56+
R_INNER = 3.480e6
57+
R_OUTER = 6.371e6
58+
r_stag = np.linspace(R_INNER, R_OUTER, N)
59+
dr = np.diff(r_stag)
60+
r_basic = np.zeros(N + 1)
61+
r_basic[0] = R_INNER
62+
r_basic[-1] = R_OUTER
63+
r_basic[1:-1] = 0.5 * (r_stag[:-1] + r_stag[1:])
64+
area = 4.0 * np.pi * r_basic**2
65+
volume = (4.0 / 3.0) * np.pi * np.diff(r_basic**3)
66+
ml = np.maximum(np.minimum(r_basic - R_INNER, R_OUTER - r_basic), 1.0)
67+
d_dr = np.zeros((N + 1, N))
68+
for i in range(1, N):
69+
d_dr[i, i - 1] = -1.0 / dr[i - 1]
70+
d_dr[i, i] = 1.0 / dr[i - 1]
71+
d_dr[0, :] = d_dr[1, :]
72+
d_dr[-1, :] = d_dr[-2, :]
73+
q_mat = np.zeros((N + 1, N))
74+
q_mat[0, 0] = 1.0
75+
q_mat[-1, -1] = 1.0
76+
for i in range(1, N):
77+
q_mat[i, i - 1] = 0.5
78+
q_mat[i, i] = 0.5
79+
P_stag = np.linspace(135e9, 1e5, N)
80+
P_basic = q_mat @ P_stag
81+
return MeshArrays(
82+
d_dr_matrix=jnp.asarray(d_dr),
83+
quantity_matrix=jnp.asarray(q_mat),
84+
area=jnp.asarray(area),
85+
volume=jnp.asarray(volume),
86+
radii_basic=jnp.asarray(r_basic),
87+
radii_stag=jnp.asarray(r_stag),
88+
mixing_length=jnp.asarray(ml),
89+
mixing_length_sq=jnp.asarray(ml**2),
90+
mixing_length_cu=jnp.asarray(ml**3),
91+
P_stag=jnp.asarray(P_stag),
92+
P_basic=jnp.asarray(P_basic),
93+
gravity=jnp.full(N + 1, 10.0),
94+
)
95+
96+
97+
def _make_bc():
98+
from aragog.jax.solver import BoundaryParams
99+
100+
return BoundaryParams(
101+
outer_bc_type=4,
102+
outer_bc_value=0.0,
103+
emissivity=1.0,
104+
T_eq=255.0,
105+
inner_bc_type=2,
106+
inner_bc_value=0.0,
107+
core_density=10500.0,
108+
core_heat_capacity=880.0,
109+
tfac_core_avg=1.147,
110+
)
111+
112+
113+
def _build_factory():
114+
"""Build a factory at quasi_steady mode against a real EOS."""
115+
if not EOS_DIR.exists():
116+
pytest.skip(f'EOS unavailable at {EOS_DIR}')
117+
118+
from aragog.jax.eos import EntropyEOS_JAX
119+
from aragog.jax.nondim import NonDimScales
120+
from aragog.jax.phase import PhaseParams
121+
from aragog.solver.cvode_jax import build_jax_rhs_and_jacobian
122+
123+
eos_jax = EntropyEOS_JAX(EOS_DIR)
124+
params = PhaseParams()
125+
mesh = _make_const_property_mesh(N=8)
126+
bc = _make_bc()
127+
n_stag = int(mesh.P_stag.shape[0])
128+
state_scale = np.full(n_stag, 3.0e3)
129+
scales = NonDimScales(state_scale=state_scale, t_ref=1.0)
130+
heating = np.zeros(n_stag)
131+
rhs_fn, jac_fn, info = build_jax_rhs_and_jacobian(
132+
eos_jax=eos_jax,
133+
phase_params=params,
134+
mesh_arrays=mesh,
135+
boundary_params=bc,
136+
heating_array=heating,
137+
scales=scales,
138+
core_bc_mode='quasi_steady',
139+
)
140+
return rhs_fn, jac_fn, info, n_stag
141+
142+
143+
@needs_eos
144+
def test_factory_rhs_fn_returns_one_on_buffer_shape_mismatch(caplog):
145+
"""Passing an ``ydot_nd`` whose length disagrees with the JAX
146+
output forces the in-place write ``ydot_nd[:] = np.asarray(result)``
147+
to raise. The except branch must catch, log an ERROR-level
148+
message naming the failure, and return 1 so CVODE drops to its
149+
own FD-Jacobian.
150+
151+
Discriminator: silence + return 0 would let CVODE proceed with
152+
a stale ydot, producing a corrupt trajectory. The non-zero
153+
return AND the error log are both required.
154+
"""
155+
rhs_fn, _, info, n_stag = _build_factory()
156+
y_nd = np.full(n_stag, 3050.0 / 3.0e3)
157+
ydot_nd = np.zeros(n_stag - 3) # wrong length: 5 vs n_stag=8
158+
159+
with caplog.at_level(logging.ERROR, logger='fwl.aragog.solver.cvode_jax'):
160+
rc = rhs_fn(0.0, y_nd, ydot_nd)
161+
162+
assert rc == 1, f'shape-mismatch rhs_fn must return 1; got {rc}'
163+
err_msgs = [r.message for r in caplog.records if r.levelno >= logging.ERROR]
164+
assert any('JAX RHS failed' in m for m in err_msgs), (
165+
f'expected "JAX RHS failed" error log; got {err_msgs}'
166+
)
167+
168+
169+
@needs_eos
170+
def test_factory_jacfn_returns_one_on_buffer_shape_mismatch(caplog):
171+
"""Same exception contract for the Jacobian callback. Wrong-shape
172+
J buffer forces the in-place write ``J[...] = np.asarray(jac)``
173+
to raise. The except branch must catch, log an ERROR, and return
174+
1 so CVODE drops to its FD Jacobian.
175+
"""
176+
_, jacfn, info, n_stag = _build_factory()
177+
y_nd = np.full(n_stag, 3050.0 / 3.0e3)
178+
fy_nd = np.zeros(n_stag)
179+
J = np.zeros((n_stag - 1, n_stag - 1)) # wrong shape: (7,7) vs (8,8)
180+
181+
with caplog.at_level(logging.ERROR, logger='fwl.aragog.solver.cvode_jax'):
182+
rc = jacfn(0.0, y_nd, fy_nd, J)
183+
184+
assert rc == 1, f'shape-mismatch jacfn must return 1; got {rc}'
185+
err_msgs = [r.message for r in caplog.records if r.levelno >= logging.ERROR]
186+
assert any('JAX Jacobian failed' in m for m in err_msgs), (
187+
f'expected "JAX Jacobian failed" error log; got {err_msgs}'
188+
)
189+
190+
191+
@needs_eos
192+
def test_factory_with_radio_isotope_params_builds_callable_radio_heating(caplog):
193+
"""Lines 179 of cvode_jax.py: the ``radio_isotope_params`` 5-tuple
194+
branch builds a JAX-traceable ``H_radio(t_yr)`` callable via
195+
``make_radio_heating_fn``.
196+
197+
Discriminator: a regression that lost the 5-tuple unpacking
198+
would raise during construction. Verifies the factory builds
199+
AND the resulting RHS callback executes once without errors,
200+
which means the radio-heating call inside the JAX trace did
201+
not fail.
202+
"""
203+
from aragog.jax.eos import EntropyEOS_JAX
204+
from aragog.jax.nondim import NonDimScales
205+
from aragog.jax.phase import PhaseParams
206+
from aragog.solver.cvode_jax import build_jax_rhs_and_jacobian
207+
208+
eos_jax = EntropyEOS_JAX(EOS_DIR)
209+
params = PhaseParams()
210+
mesh = _make_const_property_mesh(N=8)
211+
bc = _make_bc()
212+
n_stag = int(mesh.P_stag.shape[0])
213+
state_scale = np.full(n_stag, 3.0e3)
214+
scales = NonDimScales(state_scale=state_scale, t_ref=1.0)
215+
heating = np.zeros(n_stag)
216+
217+
# 5-tuple in the factory's expected order:
218+
# (heat_prod, abundance, concentration, t0_years, half_life_years).
219+
# Values approximate Al26 at the Solar System epoch.
220+
rhs_fn, _jac_fn, info = build_jax_rhs_and_jacobian(
221+
eos_jax=eos_jax,
222+
phase_params=params,
223+
mesh_arrays=mesh,
224+
boundary_params=bc,
225+
heating_array=heating,
226+
scales=scales,
227+
core_bc_mode='quasi_steady',
228+
radio_isotope_params=(
229+
0.3583, # heat_prod
230+
1.0, # abundance
231+
1.0e-9, # concentration
232+
4.55e9, # t0_years
233+
7.17e5, # half_life_years
234+
),
235+
)
236+
# First call must succeed: the radio_heating function fires
237+
# inside the JAX trace at this t.
238+
y_nd = np.full(n_stag, 3050.0 / 3.0e3)
239+
ydot_nd = np.zeros(n_stag)
240+
rc = rhs_fn(0.0, y_nd, ydot_nd)
241+
assert rc == 0, f'radio-heating-aware rhs_fn returned {rc}; expected 0'
242+
assert info['rhs_calls'] == 1

tests/test_entropy_solver_eos_method2_smoke.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,18 @@ def test_solver_loads_external_mesh_with_eos_method_2(shared_eos, tmp_path_facto
204204
final_y = solver._solution.y[:, -1] if solver._solution.y.ndim == 2 else solver._solution.y
205205
assert np.all(np.isfinite(final_y)), 'eos_method=2 final state has NaN'
206206

207+
# Discriminator: get_state() with UserDefinedEOS triggers the
208+
# discrete-sum mantle-mass branch (entropy_solver.py:2848-2850)
209+
# because UserDefinedEOS lacks the analytic
210+
# ``get_mass_within_radii`` method that AdamsWilliamsonEOS
211+
# provides. M_mantle must come back finite and physically
212+
# plausible for an Earth-like geometry (a few 1e24 kg).
213+
out = solver.get_state()
214+
assert np.isfinite(out.M_mantle)
215+
assert 1.0e23 < float(out.M_mantle) < 1.0e25, (
216+
f'M_mantle = {float(out.M_mantle):.3e} kg outside plausible Earth-like mantle range'
217+
)
218+
207219

208220
def test_solver_reset_reloads_external_mesh_on_each_call(shared_eos, tmp_path_factory):
209221
"""``EntropySolver.reset()`` must re-read the eos_file when

0 commit comments

Comments
 (0)