Skip to content
This repository was archived by the owner on Aug 8, 2024. It is now read-only.

Commit 690a3f0

Browse files
committed
more image tests
Signed-off-by: Nathaniel Starkman (@nstarman) <[email protected]>
1 parent 1773aff commit 690a3f0

File tree

2 files changed

+79
-78
lines changed

2 files changed

+79
-78
lines changed

sample_scf/tests/test_sample_exact.py

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,63 +12,55 @@
1212
##############################################################################
1313
# IMPORTS
1414

15-
# BUILT-IN
16-
import pathlib
17-
1815
# THIRD PARTY
19-
# import matplotlib.pyplot as plt
2016
import numpy as np
2117

22-
# import pytest
23-
# from astropy.utils.misc import NumpyRNGContext
24-
# from galpy.df import isotropicHernquistdf
25-
from galpy.potential import HernquistPotential, SCFPotential
26-
27-
# LOCAL
28-
# from sample_scf.sample_exact import SCFPhiSampler, SCFRSampler, SCFSampler, SCFThetaSampler
29-
from sample_scf.utils import zeta_of_r # x_of_theta
30-
31-
# from .test_base import SCFSamplerTestBase
32-
# from sample_scf import sample_exact
33-
34-
3518
##############################################################################
3619
# PARAMETERS
3720

38-
# hernpot = TriaxialHernquistPotential(b=0.8, c=1.2)
39-
hernpot = HernquistPotential()
40-
coeffs = np.load(pathlib.Path(__file__).parent / "scf_coeffs.npz")
41-
Acos, Asin = coeffs["Acos"], coeffs["Asin"]
42-
43-
pot = SCFPotential(Acos=Acos, Asin=Asin)
44-
pot.turn_physical_off()
21+
rgrid = np.concatenate(([0], np.geomspace(1e-1, 1e3, 100)))
22+
tgrid = np.linspace(-np.pi / 2, np.pi / 2, 30)
23+
pgrid = np.linspace(0, 2 * np.pi, 30)
4524

46-
# r sampling
47-
r = np.unique(np.concatenate([[0], np.geomspace(1e-7, 1e3, 100), [np.inf]]))
48-
zeta = zeta_of_r(r)
49-
m = [pot._mass(x) for x in r]
50-
m[0] = 0
51-
m[-1] = 1
52-
53-
# theta sampling
54-
theta = np.linspace(-np.pi / 2, np.pi / 2, 30)
55-
56-
# phi sampling
57-
phi = np.linspace(0, 2 * np.pi, 30)
5825

5926
##############################################################################
6027
# CODE
6128
##############################################################################
6229

6330

6431
# class Test_SCFSampler(SCFSamplerTestBase):
65-
# """Test :class:`sample_scf.sample_intrp.SCFSampler`."""
32+
# """Test :class:`sample_scf.sample_exact.SCFSampler`."""
33+
#
34+
# self.cls = sample_intrp.SCFSampler
35+
# self.cls_args = (rgrid, tgrid, pgrid)
36+
# self.cls_kwargs = {}
37+
#
38+
# self.expected_rvs = {
39+
# 0: dict(r=2.8583146808697, theta=1.473013568997 * u.rad, phi=3.4482969442579 * u.rad),
40+
# 1: dict(r=2.8583146808697, theta=1.473013568997 * u.rad, phi=3.4482969442579 * u.rad),
41+
# 2: dict(
42+
# r=[59.15672032022, 2.842480998054, 71.71466505664, 5.471148006362],
43+
# theta=[0.36517953566424, 1.4761907683040, 0.33207251545636, 1.1267111320704]
44+
# * u.rad,
45+
# phi=[6.076027676095, 3.438361627636, 6.11155607905, 4.491321348792] * u.rad,
46+
# ),
47+
# }
6648
#
67-
# def setup_class(self):
68-
# super().setup_class()
49+
# # ===============================================================
50+
# # Method Tests
6951
#
70-
# self.cls = sample_exact.SCFSampler
71-
# self.cls_args = ()
52+
# # TODO! make sure these are correct
53+
# @pytest.mark.parametrize(
54+
# "r, theta, phi, expected",
55+
# [
56+
# (0, 0, 0, [0, 0.5, 0]),
57+
# (1, 0, 0, [0.25, 0.5, 0]),
58+
# ([0, 1], [0, 0], [0, 0], [[0, 0.5, 0], [0.25, 0.5, 0]]),
59+
# ],
60+
# )
61+
# def test_cdf(self, sampler, r, theta, phi, expected):
62+
# """Test :meth:`sample_scf.base.SCFSamplerBase.cdf`."""
63+
# assert np.allclose(sampler.cdf(r, theta, phi), expected, atol=1e-16)
7264
#
7365
# # /def
7466
#

sample_scf/tests/test_sample_intrp.py

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,12 @@ def test_interp_theta_sampling_plot(self, request, sampler):
517517
theory = self.theory[kind].sample(n=int(1e6)).theta() - np.pi / 2
518518

519519
fig = plt.figure(figsize=(10, 3))
520-
ax = fig.add_subplot(121, title="SCF vs theory sampling", xlabel="r", ylabel="frequency")
520+
ax = fig.add_subplot(
521+
121,
522+
title="SCF vs theory sampling",
523+
xlabel=r"$\theta$",
524+
ylabel="frequency",
525+
)
521526
_, bins, *_ = ax.hist(sample, bins=30, log=True, alpha=0.5, label="SCF sample")
522527
# Comparing to expected
523528
ax.hist(
@@ -662,43 +667,47 @@ def test_interp_phi_cdf_plot(self, sampler):
662667

663668
# /def
664669

670+
@pytest.mark.mpl_image_compare(
671+
baseline_dir="baseline_images",
672+
# hash_library="baseline_images/path_to_file.json",
673+
)
674+
def test_interp_phi_sampling_plot(self, request, sampler):
675+
"""Test sampling."""
676+
# fiqure out theory sampler
677+
options = request.fixturenames[0]
678+
if "hernquist" in options:
679+
kind = "hernquist"
680+
else:
681+
raise ValueError
682+
683+
with NumpyRNGContext(0): # control the random numbers
684+
sample = sampler.rvs(size=int(1e6), r=10, theta=np.pi / 6)
685+
sample = sample[sample < 1e4]
686+
687+
theory = self.theory[kind].sample(n=int(1e6)).phi()
665688

666-
# @pytest.mark.mpl_image_compare(
667-
# baseline_dir="baseline_images",
668-
# # hash_library="baseline_images/path_to_file.json",
669-
# )
670-
# def test_interp_phi_sampling_plot(self, request, sampler):
671-
# """Test sampling."""
672-
# # fiqure out theory sampler
673-
# options = request.fixturenames[0]
674-
# if "hernquist" in options:
675-
# kind = "hernquist"
676-
# else:
677-
# raise ValueError
678-
#
679-
# with NumpyRNGContext(0): # control the random numbers
680-
# sample = sampler.rvs(size=int(1e6), r=10)
681-
# sample = sample[sample < 1e4]
682-
#
683-
# theory = self.theory[kind].sample(n=int(1e6)).theta() - np.pi / 2
684-
#
685-
# fig = plt.figure(figsize=(10, 3))
686-
# ax = fig.add_subplot(121, title="SCF vs theory sampling", xlabel="r", ylabel="frequency")
687-
# _, bins, *_ = ax.hist(sample, bins=30, log=True, alpha=0.5, label="SCF sample")
688-
# # Comparing to expected
689-
# ax.hist(
690-
# theory,
691-
# bins=bins,
692-
# log=True,
693-
# alpha=0.5,
694-
# label="Hernquist theoretical",
695-
# )
696-
# ax.legend()
697-
# fig.tight_layout()
698-
#
699-
# return fig
700-
#
701-
# # /def
689+
fig = plt.figure(figsize=(10, 3))
690+
ax = fig.add_subplot(
691+
121,
692+
title="SCF vs theory sampling",
693+
xlabel=r"$\phi$",
694+
ylabel="frequency",
695+
)
696+
_, bins, *_ = ax.hist(sample, bins=30, log=True, alpha=0.5, label="SCF sample")
697+
# Comparing to expected
698+
ax.hist(
699+
theory,
700+
bins=bins,
701+
log=True,
702+
alpha=0.5,
703+
label="Hernquist theoretical",
704+
)
705+
ax.legend()
706+
fig.tight_layout()
707+
708+
return fig
709+
710+
# /def
702711

703712

704713
# /class

0 commit comments

Comments
 (0)