7
7
# IMPORTS
8
8
9
9
# THIRD PARTY
10
- # import astropy.units as u
11
- # import numpy as np
10
+ import astropy .coordinates as coord
11
+ import astropy .units as u
12
+ import numpy as np
12
13
import pytest
14
+ from numpy .testing import assert_allclose
13
15
14
16
# LOCAL
15
17
from sample_scf import base
16
18
17
- # from galpy.potential import SCFPotential
18
-
19
-
20
19
##############################################################################
21
20
# TESTS
22
21
##############################################################################
23
22
24
23
25
- class Test_rv_continuous_modrvs :
24
+ class testrvsampler (base .rv_continuous_modrvs ):
25
+ def _cdf (self , x , * args , ** kwargs ):
26
+ return x
27
+
28
+ # /def
29
+
30
+ cdf = _cdf
31
+
32
+ def _rvs (self , * args , size = None , random_state = None ):
33
+ if random_state is None :
34
+ random_state = np .random
35
+
36
+ return np .atleast_1d (random_state .uniform (size = size ))
37
+
38
+ # /def
39
+
40
+
41
+ # /class
42
+
43
+
44
+ class Test_RVContinuousModRVS :
26
45
"""Test `sample_scf.base.rv_continuous_modrvs`."""
27
46
28
- @pytest .mark .skip ("TODO!" )
29
- def test_rvs (self ):
47
+ def setup_class (self ):
48
+ self .sampler = testrvsampler ()
49
+
50
+ # /def
51
+
52
+ # ===============================================================
53
+
54
+ @pytest .mark .parametrize (
55
+ "size, random, expected" ,
56
+ [
57
+ (None , 0 , 0.5488135039273248 ),
58
+ (1 , 2 , 0.43599490214200376 ),
59
+ ((3 , 1 ), 4 , (0.9670298390136767 , 0.5472322491757223 , 0.9726843599648843 )),
60
+ ],
61
+ )
62
+ def test_rvs (self , size , random , expected ):
30
63
"""Test :meth:`sample_scf.base.rv_continuous_modrvs.rvs`."""
31
- assert False
64
+ assert_allclose ( self . sampler . rvs ( size = size , random_state = random ), expected , atol = 1e-16 )
32
65
33
66
# /def
34
67
@@ -42,46 +75,117 @@ def test_rvs(self):
42
75
class Test_SCFSamplerBase :
43
76
"""Test :class:`sample_scf.base.SCFSamplerBase`."""
44
77
45
- _cls = base .SCFSamplerBase
78
+ def setup_class (self ):
79
+ self .cls = base .SCFSamplerBase
80
+ self .cls_args = ()
46
81
47
- @pytest .mark .skip ("TODO!" )
48
- def test_rsampler (self ):
82
+ self .expected_rvs = {
83
+ 0 : dict (r = 0.548813503927 , theta = 1.021982822867 * u .rad , phi = 0.548813503927 * u .rad ),
84
+ 1 : dict (r = 0.548813503927 , theta = 1.021982822867 * u .rad , phi = 0.548813503927 * u .rad ),
85
+ 2 : dict (
86
+ r = [0.9670298390136 , 0.5472322491757 , 0.9726843599648 , 0.7148159936743 ],
87
+ theta = [0.603766487781 , 1.023564077619 , 0.598111966830 , 0.855980333120 ] * u .rad ,
88
+ phi = [0.9670298390136 , 0.547232249175 , 0.9726843599648 , 0.7148159936743 ] * u .rad ,
89
+ ),
90
+ }
91
+
92
+ # /def
93
+
94
+ @pytest .fixture (autouse = True , scope = "class" )
95
+ def sampler (self , potentials ):
96
+ """Set up r, theta, phi sampler."""
97
+ sampler = self .cls (potentials , * self .cls_args )
98
+ sampler ._rsampler = testrvsampler ()
99
+ sampler ._thetasampler = testrvsampler ()
100
+ sampler ._phisampler = testrvsampler ()
101
+
102
+ return sampler
103
+
104
+ # /def
105
+
106
+ # ===============================================================
107
+
108
+ def test_rsampler (self , sampler ):
49
109
"""Test :meth:`sample_scf.base.SCFSamplerBase.rsampler`."""
50
- assert False
110
+ assert isinstance ( sampler . rsampler , base . rv_continuous_modrvs )
51
111
52
112
# /def
53
113
54
- @pytest .mark .skip ("TODO!" )
55
- def test_thetasampler (self ):
114
+ def test_thetasampler (self , sampler ):
56
115
"""Test :meth:`sample_scf.base.SCFSamplerBase.thetasampler`."""
57
- assert False
116
+ assert isinstance ( sampler . thetasampler , base . rv_continuous_modrvs )
58
117
59
118
# /def
60
119
61
- @pytest .mark .skip ("TODO!" )
62
- def test_phisampler (self ):
120
+ def test_phisampler (self , sampler ):
63
121
"""Test :meth:`sample_scf.base.SCFSamplerBase.phisampler`."""
64
- assert False
122
+ assert isinstance ( sampler . phisampler , base . rv_continuous_modrvs )
65
123
66
124
# /def
67
125
68
- @pytest .mark .skip ("TODO!" )
69
- def test_cdf (self ):
126
+ @pytest .mark .parametrize (
127
+ "r, theta, phi, expected" ,
128
+ [
129
+ (0 , 0 , 0 , [0 , 0 , 0 ]),
130
+ (1 , 0 , 0 , [1 , 0 , 0 ]),
131
+ ([0 , 1 ], [0 , 0 ], [0 , 0 ], [[0 , 0 , 0 ], [1 , 0 , 0 ]]),
132
+ ],
133
+ )
134
+ def test_cdf (self , sampler , r , theta , phi , expected ):
70
135
"""Test :meth:`sample_scf.base.SCFSamplerBase.cdf`."""
71
- assert False
136
+ assert np . allclose ( sampler . cdf ( r , theta , phi ), expected , atol = 1e-16 )
72
137
73
138
# /def
74
139
75
- @pytest .mark .skip ("TODO!" )
76
- def test_rvs (self ):
140
+ @pytest .mark .parametrize (
141
+ "id, size, random" ,
142
+ [
143
+ (0 , None , 0 ),
144
+ (1 , 1 , 0 ),
145
+ (2 , 4 , 4 ),
146
+ ],
147
+ )
148
+ def test_rvs (self , sampler , id , size , random ):
77
149
"""Test :meth:`sample_scf.base.SCFSamplerBase.rvs`."""
78
- assert False
150
+ samples = sampler .rvs (size = size , random_state = random )
151
+ sce = coord .PhysicsSphericalRepresentation (** self .expected_rvs [id ])
152
+
153
+ assert_allclose (samples .r , sce .r , atol = 1e-16 )
154
+ assert_allclose (samples .theta .value , sce .theta .value , atol = 1e-16 )
155
+ assert_allclose (samples .phi .value , sce .phi .value , atol = 1e-16 )
79
156
80
157
# /def
81
158
82
159
83
160
# /class
84
161
85
162
163
+ class SCFSamplerTestBase (Test_SCFSamplerBase ):
164
+ def setup_class (self ):
165
+
166
+ self .expected_rvs = {
167
+ 0 : dict (r = 0.548813503927 , theta = 1.021982822867 * u .rad , phi = 0.548813503927 * u .rad ),
168
+ 1 : dict (r = 0.548813503927 , theta = 1.021982822867 * u .rad , phi = 0.548813503927 * u .rad ),
169
+ 2 : dict (
170
+ r = [0.9670298390136 , 0.5472322491757 , 0.9726843599648 , 0.7148159936743 ],
171
+ theta = [0.603766487781 , 1.023564077619 , 0.598111966830 , 0.855980333120 ] * u .rad ,
172
+ phi = [0.9670298390136 , 0.547232249175 , 0.9726843599648 , 0.7148159936743 ] * u .rad ,
173
+ ),
174
+ }
175
+
176
+ # /def
177
+
178
+ @pytest .fixture (autouse = True , scope = "class" )
179
+ def sampler (self , potentials ):
180
+ """Set up r, theta, phi sampler."""
181
+ sampler = self .cls (potentials , * self .cls_args )
182
+
183
+ return sampler
184
+
185
+ # /def
186
+
187
+
188
+ # /class
189
+
86
190
##############################################################################
87
191
# END
0 commit comments