Skip to content
This repository was archived by the owner on Aug 8, 2024. It is now read-only.

Commit 732dead

Browse files
committed
more tests
Signed-off-by: Nathaniel Starkman (@nstarman) <[email protected]>
1 parent 58cb75c commit 732dead

File tree

8 files changed

+268
-56
lines changed

8 files changed

+268
-56
lines changed

sample_scf/base.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
import numpy.typing as npt
2121
from astropy.coordinates import PhysicsSphericalRepresentation
22+
from galpy.potential import SCFPotential
2223
from scipy._lib._util import check_random_state
2324
from scipy.stats import rv_continuous
2425

@@ -102,6 +103,14 @@ class SCFSamplerBase:
102103
pot : `galpy.potential.SCFPotential`
103104
"""
104105

106+
def __init__(
107+
self,
108+
pot: SCFPotential,
109+
):
110+
self._pot = pot
111+
112+
# /def
113+
105114
_rsampler: rv_continuous_modrvs
106115
_thetasampler: rv_continuous_modrvs
107116
_phisampler: rv_continuous_modrvs
@@ -147,7 +156,7 @@ def cdf(
147156
(N, 3) ndarray
148157
"""
149158
R: NDArray64 = self.rsampler.cdf(r)
150-
Theta: NDArray64 = self.thetasampler.cdf(theta=theta, r=r)
159+
Theta: NDArray64 = self.thetasampler.cdf(theta, r=r)
151160
Phi: NDArray64 = self.phisampler.cdf(phi, r=r, theta=theta)
152161

153162
RTP: NDArray64 = np.c_[R, Theta, Phi]

sample_scf/conftest.py

+25-8
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def pytest_configure(config):
7575
# Fixtures
7676

7777

78-
@pytest.fixture(scope="session")
78+
@pytest.fixture(autouse=True, scope="session")
7979
def hernquist_scf_potential():
8080
"""Make a SCF of a Hernquist potential."""
8181
Acos = np.zeros((5, 6, 6))
@@ -90,10 +90,27 @@ def hernquist_scf_potential():
9090
# /def
9191

9292

93-
@pytest.fixture(scope="session")
94-
def nfw_scf_potential():
95-
"""Make a SCF of a triaxial NFW potential."""
96-
raise NotImplementedError("TODO")
97-
98-
99-
# /def
93+
# @pytest.fixture(autouse=True, scope="session")
94+
# def nfw_scf_potential():
95+
# """Make a SCF of a triaxial NFW potential."""
96+
# raise NotImplementedError("TODO")
97+
#
98+
#
99+
# # /def
100+
101+
102+
@pytest.fixture(
103+
# autouse=True,
104+
scope="session",
105+
params=[
106+
"hernquist_scf_potential", # TODO! use hernquist_scf_potential
107+
],
108+
)
109+
def potentials(request):
110+
if request.param == "hernquist_scf_potential":
111+
Acos = np.zeros((5, 6, 6))
112+
Acos_hern = Acos.copy()
113+
Acos_hern[0, 0, 0] = 1
114+
potential = SCFPotential(Acos=Acos_hern)
115+
116+
yield potential

sample_scf/core.py

+37-8
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,23 @@
1919
from galpy.potential import SCFPotential
2020

2121
# LOCAL
22-
from .base import SCFSamplerBase
22+
from .base import SCFSamplerBase, rv_continuous_modrvs
23+
from .sample_exact import SCFSampler as SCFSamplerExact
24+
from .sample_intrp import SCFSampler as SCFSamplerIntrp
2325

2426
__all__: T.List[str] = ["SCFSampler"]
2527

2628

29+
##############################################################################
30+
# Parameters
31+
32+
33+
class MethodsMapping(T.TypedDict):
34+
r: rv_continuous_modrvs
35+
theta: rv_continuous_modrvs
36+
phi: rv_continuous_modrvs
37+
38+
2739
##############################################################################
2840
# CODE
2941
##############################################################################
@@ -85,15 +97,32 @@ class SCFSampler(SCFSamplerBase): # metaclass=SCFSamplerSwitch
8597
def __init__(
8698
self,
8799
pot: SCFPotential,
88-
method: T.Union[T.Literal["interp", "exact"], T.Mapping],
100+
method: T.Union[T.Literal["interp", "exact"], MethodsMapping],
89101
**kwargs: T.Any
90102
) -> None:
91-
if not isinstance(method, Mapping):
92-
raise NotImplementedError
93-
94-
self._rsampler = method["r"](pot, **kwargs)
95-
self._thetasampler = method["theta"](pot, **kwargs)
96-
self._phisampler = method["phi"](pot, **kwargs)
103+
super().__init__(pot)
104+
105+
if isinstance(method, Mapping):
106+
sampler = None
107+
rsampler = method["r"](pot, **kwargs)
108+
thetasampler = method["theta"](pot, **kwargs)
109+
phisampler = method["phi"](pot, **kwargs)
110+
else:
111+
sampler_cls: rv_continuous_modrvs
112+
if method == "interp":
113+
sampler_cls = SCFSamplerIntrp
114+
elif method == "exact":
115+
sampler_cls = SCFSamplerExact
116+
117+
sampler = sampler_cls(pot, **kwargs)
118+
rsampler = self._sampler._rsampler
119+
thetasampler = self._sampler._thetasampler
120+
phisampler = self._sampler._phisampler
121+
122+
self._sampler: T.Optional[SCFSamplerBase] = sampler
123+
self._rsampler = rsampler
124+
self._thetasampler = thetasampler
125+
self._phisampler = phisampler
97126

98127
# /def
99128

sample_scf/sample_exact.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import numpy as np
2020
import numpy.typing as npt
2121
from galpy.potential import SCFPotential
22-
from scipy.stats import rv_continuous
2322

2423
# LOCAL
2524
from ._typing import NDArray64, RandomLike
@@ -72,7 +71,7 @@ def __init__(self, pot: SCFPotential, **kw: T.Any) -> None:
7271
# radial sampler
7372

7473

75-
class SCFRSampler(rv_continuous):
74+
class SCFRSampler(rv_continuous_modrvs):
7675
"""Sample radial coordinate from an SCF potential.
7776
7877
Parameters

sample_scf/sample_intrp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def __init__(
270270
raise ValueError(f"Qls must be shape ({len(rgrid)}, {lmax})")
271271

272272
# l = 0 : spherical symmetry
273-
term0 = T.cast(npt.NDArray, 0.5 * (xs + 1)) # (T,)
273+
term0 = T.cast(NDArray64, 0.5 * (xs + 1)) # (T,)
274274
# l = 1+ : non-symmetry
275275
factor = 1.0 / (2.0 * Qls[:, 0]) # (R,)
276276
term1p = np.sum(

sample_scf/tests/test_base.py

+129-25
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,61 @@
77
# IMPORTS
88

99
# THIRD PARTY
10-
# import astropy.units as u
11-
# import numpy as np
10+
import astropy.coordinates as coord
11+
import astropy.units as u
12+
import numpy as np
1213
import pytest
14+
from numpy.testing import assert_allclose
1315

1416
# LOCAL
1517
from sample_scf import base
1618

17-
# from galpy.potential import SCFPotential
18-
19-
2019
##############################################################################
2120
# TESTS
2221
##############################################################################
2322

2423

25-
class Test_rv_continuous_modrvs:
24+
class testrvsampler(base.rv_continuous_modrvs):
25+
def _cdf(self, x, *args, **kwargs):
26+
return x
27+
28+
# /def
29+
30+
cdf = _cdf
31+
32+
def _rvs(self, *args, size=None, random_state=None):
33+
if random_state is None:
34+
random_state = np.random
35+
36+
return np.atleast_1d(random_state.uniform(size=size))
37+
38+
# /def
39+
40+
41+
# /class
42+
43+
44+
class Test_RVContinuousModRVS:
2645
"""Test `sample_scf.base.rv_continuous_modrvs`."""
2746

28-
@pytest.mark.skip("TODO!")
29-
def test_rvs(self):
47+
def setup_class(self):
48+
self.sampler = testrvsampler()
49+
50+
# /def
51+
52+
# ===============================================================
53+
54+
@pytest.mark.parametrize(
55+
"size, random, expected",
56+
[
57+
(None, 0, 0.5488135039273248),
58+
(1, 2, 0.43599490214200376),
59+
((3, 1), 4, (0.9670298390136767, 0.5472322491757223, 0.9726843599648843)),
60+
],
61+
)
62+
def test_rvs(self, size, random, expected):
3063
"""Test :meth:`sample_scf.base.rv_continuous_modrvs.rvs`."""
31-
assert False
64+
assert_allclose(self.sampler.rvs(size=size, random_state=random), expected, atol=1e-16)
3265

3366
# /def
3467

@@ -42,46 +75,117 @@ def test_rvs(self):
4275
class Test_SCFSamplerBase:
4376
"""Test :class:`sample_scf.base.SCFSamplerBase`."""
4477

45-
_cls = base.SCFSamplerBase
78+
def setup_class(self):
79+
self.cls = base.SCFSamplerBase
80+
self.cls_args = ()
4681

47-
@pytest.mark.skip("TODO!")
48-
def test_rsampler(self):
82+
self.expected_rvs = {
83+
0: dict(r=0.548813503927, theta=1.021982822867 * u.rad, phi=0.548813503927 * u.rad),
84+
1: dict(r=0.548813503927, theta=1.021982822867 * u.rad, phi=0.548813503927 * u.rad),
85+
2: dict(
86+
r=[0.9670298390136, 0.5472322491757, 0.9726843599648, 0.7148159936743],
87+
theta=[0.603766487781, 1.023564077619, 0.598111966830, 0.855980333120] * u.rad,
88+
phi=[0.9670298390136, 0.547232249175, 0.9726843599648, 0.7148159936743] * u.rad,
89+
),
90+
}
91+
92+
# /def
93+
94+
@pytest.fixture(autouse=True, scope="class")
95+
def sampler(self, potentials):
96+
"""Set up r, theta, phi sampler."""
97+
sampler = self.cls(potentials, *self.cls_args)
98+
sampler._rsampler = testrvsampler()
99+
sampler._thetasampler = testrvsampler()
100+
sampler._phisampler = testrvsampler()
101+
102+
return sampler
103+
104+
# /def
105+
106+
# ===============================================================
107+
108+
def test_rsampler(self, sampler):
49109
"""Test :meth:`sample_scf.base.SCFSamplerBase.rsampler`."""
50-
assert False
110+
assert isinstance(sampler.rsampler, base.rv_continuous_modrvs)
51111

52112
# /def
53113

54-
@pytest.mark.skip("TODO!")
55-
def test_thetasampler(self):
114+
def test_thetasampler(self, sampler):
56115
"""Test :meth:`sample_scf.base.SCFSamplerBase.thetasampler`."""
57-
assert False
116+
assert isinstance(sampler.thetasampler, base.rv_continuous_modrvs)
58117

59118
# /def
60119

61-
@pytest.mark.skip("TODO!")
62-
def test_phisampler(self):
120+
def test_phisampler(self, sampler):
63121
"""Test :meth:`sample_scf.base.SCFSamplerBase.phisampler`."""
64-
assert False
122+
assert isinstance(sampler.phisampler, base.rv_continuous_modrvs)
65123

66124
# /def
67125

68-
@pytest.mark.skip("TODO!")
69-
def test_cdf(self):
126+
@pytest.mark.parametrize(
127+
"r, theta, phi, expected",
128+
[
129+
(0, 0, 0, [0, 0, 0]),
130+
(1, 0, 0, [1, 0, 0]),
131+
([0, 1], [0, 0], [0, 0], [[0, 0, 0], [1, 0, 0]]),
132+
],
133+
)
134+
def test_cdf(self, sampler, r, theta, phi, expected):
70135
"""Test :meth:`sample_scf.base.SCFSamplerBase.cdf`."""
71-
assert False
136+
assert np.allclose(sampler.cdf(r, theta, phi), expected, atol=1e-16)
72137

73138
# /def
74139

75-
@pytest.mark.skip("TODO!")
76-
def test_rvs(self):
140+
@pytest.mark.parametrize(
141+
"id, size, random",
142+
[
143+
(0, None, 0),
144+
(1, 1, 0),
145+
(2, 4, 4),
146+
],
147+
)
148+
def test_rvs(self, sampler, id, size, random):
77149
"""Test :meth:`sample_scf.base.SCFSamplerBase.rvs`."""
78-
assert False
150+
samples = sampler.rvs(size=size, random_state=random)
151+
sce = coord.PhysicsSphericalRepresentation(**self.expected_rvs[id])
152+
153+
assert_allclose(samples.r, sce.r, atol=1e-16)
154+
assert_allclose(samples.theta.value, sce.theta.value, atol=1e-16)
155+
assert_allclose(samples.phi.value, sce.phi.value, atol=1e-16)
79156

80157
# /def
81158

82159

83160
# /class
84161

85162

163+
class SCFSamplerTestBase(Test_SCFSamplerBase):
164+
def setup_class(self):
165+
166+
self.expected_rvs = {
167+
0: dict(r=0.548813503927, theta=1.021982822867 * u.rad, phi=0.548813503927 * u.rad),
168+
1: dict(r=0.548813503927, theta=1.021982822867 * u.rad, phi=0.548813503927 * u.rad),
169+
2: dict(
170+
r=[0.9670298390136, 0.5472322491757, 0.9726843599648, 0.7148159936743],
171+
theta=[0.603766487781, 1.023564077619, 0.598111966830, 0.855980333120] * u.rad,
172+
phi=[0.9670298390136, 0.547232249175, 0.9726843599648, 0.7148159936743] * u.rad,
173+
),
174+
}
175+
176+
# /def
177+
178+
@pytest.fixture(autouse=True, scope="class")
179+
def sampler(self, potentials):
180+
"""Set up r, theta, phi sampler."""
181+
sampler = self.cls(potentials, *self.cls_args)
182+
183+
return sampler
184+
185+
# /def
186+
187+
188+
# /class
189+
86190
##############################################################################
87191
# END

0 commit comments

Comments
 (0)