diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 0000000..86d2813 --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,50 @@ +# Global options: + +[mypy] +python_version = 3.8 + +disallow_untyped_defs = True +no_implicit_reexport = True + +warn_unused_configs = True +warn_redundant_casts = True +warn_unused_ignores = True +no_warn_no_return = True +warn_return_any = True +warn_unreachable = True + +plugins = numpy.typing.mypy_plugin + + +####################################### +# Per-module options: + +[mypy-*/tests.*] +ignore_errors = True + +####################################### +# missing imports + +[mypy-astropy.*] +ignore_missing_imports = True + +[mypy-numpy.*] +ignore_missing_imports = True + +[mypy-matplotlib.*] +ignore_missing_imports = True + +[mypy-galpy.*] +ignore_missing_imports = True + +[mypy-pytest.*] +ignore_missing_imports = True + +[mypy-pytest_astropy_header.display] +ignore_missing_imports = True + +[mypy-scipy.*] +ignore_missing_imports = True + +[mypy-setuptools_scm.*] +ignore_missing_imports = True diff --git a/README.rst b/README.rst index 87d3923..bb891e1 100644 --- a/README.rst +++ b/README.rst @@ -13,8 +13,8 @@ Self-Consistent Field (SCF). License ------- -This project is Copyright (c) nathaniel starkman and licensed under -the terms of the BSD 3-Clause license. This package is based upon +This project is Copyright (c) Nathaniel Starkman and Maintainers and licensed +under the terms of the BSD 3-Clause license. This package is based upon the `Astropy package template `_ which is licensed under the BSD 3-clause license. See the licenses folder for more information. @@ -25,29 +25,3 @@ Contributing We love contributions! sampleSCF is open source, built on open source, and we'd love to have you hang out in our community. - -**Imposter syndrome disclaimer**: We want your help. No, really. - -There may be a little voice inside your head that is telling you that you're not -ready to be an open source contributor; that your skills aren't nearly good -enough to contribute. What could you possibly offer a project like this one? - -We assure you - the little voice in your head is wrong. If you can write code at -all, you can contribute code to open source. Contributing to open source -projects is a fantastic way to advance one's coding skills. Writing perfect code -isn't the measure of a good developer (that would disqualify all of us!); it's -trying to create something, making mistakes, and learning from those -mistakes. That's how we all improve, and we are happy to help others learn. - -Being an open source contributor doesn't just mean writing code, either. You can -help out by writing documentation, tests, or even giving feedback about the -project (and yes - that includes giving feedback about the contribution -process). Some of these contributions may be the most valuable to the project as -a whole, because you're coming to the project with fresh eyes, so you can see -the errors and assumptions that seasoned contributors have glossed over. - -Note: This disclaimer was originally written by -`Adrienne Lowe `_ for a -`PyCon talk `_, and was adapted by -sampleSCF based on its use in the README file for the -`MetPy project `_. diff --git a/docs/conf.py b/docs/conf.py index 33ee295..3b3be34 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,7 +25,7 @@ # Thus, any C-extensions that are needed to build the documentation will *not* # be accessible, and the documentation will not build correctly. -# BUILT-IN +# STDLIB import datetime import os import sys @@ -146,7 +146,9 @@ # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). -latex_documents = [("index", project + ".tex", project + u" Documentation", author, "manual")] +latex_documents = [ + ("index", project + ".tex", project + u" Documentation", author, "manual"), +] # -- Options for manual page output ------------------------------------------- diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..03db57d --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,14 @@ +Documentation +============= + +This is the documentation for sampleSCF. + +.. toctree:: + :maxdepth: 2 + + sample_scf/index.rst + +.. note:: The layout of this directory is simply a suggestion. To follow + traditional practice, do *not* edit this page, but instead place + all documentation for the package inside ``sample_scf/``. + You can follow this practice or choose your own layout. diff --git a/docs/sample_scf/index.rst b/docs/sample_scf/index.rst new file mode 100644 index 0000000..53f7dae --- /dev/null +++ b/docs/sample_scf/index.rst @@ -0,0 +1,10 @@ +*********************** +sampleSCF Documentation +*********************** + +This is the documentation for sampleSCF. + +Reference/API +============= + +.. automodapi:: sample_scf diff --git a/pyproject.toml b/pyproject.toml index a917d4a..3ea12e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,4 @@ [build-system] - requires = ["extension-helpers", "setuptools", "setuptools_scm", @@ -8,21 +7,19 @@ requires = ["extension-helpers", build-backend = 'setuptools.build_meta' [tool.isort] -line_length = 100 -multi_line_output = 3 -include_trailing_comma = "True" -force_grid_wrap = 0 -use_parentheses = "True" -ensure_newline_before_comments = "True" -sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] - -known_third_party = ["astropy", "extension_helpers", "setuptools"] -known_localfolder = "sample_scf" + profile = "black" + include_trailing_comma = "True" + force_grid_wrap = 0 + use_parentheses = "True" + ensure_newline_before_comments = "True" + sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] + known_third_party = ["astropy", "extension_helpers", "galpy", "matplotlib", "numpy", "pytest", "scipy", "setuptools"] + known_localfolder = "sample_scf" -import_heading_stdlib = "BUILT-IN" -import_heading_thirdparty = "THIRD PARTY" -import_heading_firstparty = "FIRST PARTY" -import_heading_localfolder = "LOCAL" + import_heading_stdlib = "STDLIB" + import_heading_thirdparty = "THIRD PARTY" + import_heading_firstparty = "FIRST PARTY" + import_heading_localfolder = "LOCAL" [tool.black] line-length = 100 diff --git a/sample_scf/__init__.py b/sample_scf/__init__.py index 222aaf5..ea0d866 100644 --- a/sample_scf/__init__.py +++ b/sample_scf/__init__.py @@ -3,3 +3,14 @@ # LOCAL from sample_scf._astropy_init import * # isort: +split # noqa: F401, F403 +from sample_scf.core import SCFSampler +from sample_scf.exact import ExactSCFSampler +from sample_scf.interpolated import InterpolatedSCFSampler +from sample_scf.representation import FiniteSphericalRepresentation + +__all__ = [ + "SCFSampler", + "ExactSCFSampler", + "InterpolatedSCFSampler", + "FiniteSphericalRepresentation", +] diff --git a/sample_scf/_astropy_init.py b/sample_scf/_astropy_init.py index d47f554..65e6525 100644 --- a/sample_scf/_astropy_init.py +++ b/sample_scf/_astropy_init.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Licensed under a 3-clause BSD style license - see LICENSE.rst -# BUILT-IN +# STDLIB import os __all__ = ["__version__", "test"] diff --git a/sample_scf/_typing.py b/sample_scf/_typing.py new file mode 100644 index 0000000..e9829e2 --- /dev/null +++ b/sample_scf/_typing.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Licensed under a 3-clause BSD style license - see LICENSE.rst + +"""Custom typing.""" + +# STDLIB +from typing import Union + +# THIRD PARTY +from numpy import floating +from numpy.random import Generator, RandomState +from numpy.typing import NDArray + +__all__ = ["RandomGenerator", "RandomLike", "NDArrayF", "FArrayLike"] + +RandomGenerator = Union[RandomState, Generator] +RandomLike = Union[None, int, RandomGenerator] +NDArrayF = NDArray[floating] + +# float array-like +FArrayLike = Union[float, NDArrayF] diff --git a/sample_scf/base_multivariate.py b/sample_scf/base_multivariate.py new file mode 100644 index 0000000..3c609d6 --- /dev/null +++ b/sample_scf/base_multivariate.py @@ -0,0 +1,228 @@ +# -*- coding: utf-8 -*- + +"""Base class for sampling from an SCF Potential.""" + +############################################################################## +# IMPORTS + +from __future__ import annotations + +# STDLIB +from abc import ABCMeta +from typing import Any, List, Optional, Tuple, Type, TypeVar + +# THIRD PARTY +from astropy.coordinates import BaseRepresentation, PhysicsSphericalRepresentation +from astropy.units import Quantity +from galpy.potential import SCFPotential +from numpy import column_stack + +# LOCAL +from .base_univariate import _calculate_Qls, _calculate_rhoTilde, _calculate_Scs +from .base_univariate import phi_distribution_base, r_distribution_base, theta_distribution_base +from sample_scf._typing import NDArrayF, RandomLike + +__all__: List[str] = ["SCFSamplerBase"] + +############################################################################## +# PARAMETERS + +RT = TypeVar("RT", bound=BaseRepresentation) + +############################################################################## +# CODE +############################################################################## + + +class SCFSamplerBase(metaclass=ABCMeta): + """Sample SCF in spherical coordinates. + + The coordinate system is: + - r : [0, infinity) + - theta : [-pi/2, pi/2] (positive at the North pole) + - phi : [0, 2pi) + + Parameters + ---------- + pot : `galpy.potential.SCFPotential` + """ + + _potential: SCFPotential + _r_distribution: r_distribution_base + _theta_distribution: theta_distribution_base + _phi_distribution: phi_distribution_base + + def __init__(self, potential: SCFPotential, **kwargs: Any) -> None: + if not isinstance(potential, SCFPotential): + msg = f"potential must be , not {type(potential)}" + raise TypeError(msg) + + potential.turn_physical_on() + self._potential = potential + + # child classes set up the samplers + # _r_distribution + # _theta_distribution + # _phi_distribution + + # ----------------------------------------------------- + + @property + def potential(self) -> SCFPotential: + """The SCF Potential instance.""" + return self._potential + + @property + def r_distribution(self) -> r_distribution_base: + """Radial coordinate sampler.""" + return self._r_distribution + + @property + def theta_distribution(self) -> theta_distribution_base: + """Inclination coordinate sampler.""" + return self._theta_distribution + + @property + def phi_distribution(self) -> phi_distribution_base: + """Azimuthal coordinate sampler.""" + return self._phi_distribution + + @property + def radial_scale_factor(self) -> Quantity: + """Scale factor to convert dimensionful radii to a dimensionless form.""" + return self._r_distribution._radial_scale_factor + + @property + def nmax(self) -> int: + return self._r_distribution._nmax + + @property + def lmax(self) -> int: + return self._r_distribution._lmax + + # ----------------------------------------------------- + + def calculate_rhoTilde(self, radii: Quantity) -> NDArrayF: + """ + + Parameters + ---------- + radii : (R,) Quantity['length', float] + + returns + ------- + (R, N, L) ndarray[float] + """ + return _calculate_rhoTilde(self, radii) + + def calculate_Qls(self, r: Quantity, rhoTilde=None) -> NDArrayF: + r""" + Radial sums for inclination weighting factors. + The weighting factors measure perturbations from spherical symmetry. + + :math:`Q_l(r) = \sum_{n=0}^{n_{\max}}A_{nl} \tilde{\rho}_{nl0}(r)` + + Parameters + ---------- + r : (R,) Quantity['kpc', float] + Radii. Scalar or 1D array. + + Returns + ------- + Ql : (R, L) array[float] + """ + return _calculate_Qls(self, r=r, rhoTilde=rhoTilde) + + def calculate_Scs( + self, + r: Quantity, + theta: Quantity, + *, + grid: bool = True, + warn: bool = True, + ) -> Tuple[NDArrayF, NDArrayF]: + r"""Radial and inclination sums for azimuthal weighting factors. + + Parameters + ---------- + pot : :class:`galpy.potential.SCFPotential` + Has coefficient matrices Acos and Asin with shape (N, L, L). + r : float or (R,) ndarray[float] + theta : float or (T,) ndarray[float] + grid : bool, optional keyword-only + warn : bool, optional keyword-only + + Returns + ------- + Rm, Sm : (R, T, L) ndarray[float] + Azimuthal weighting factors. + """ + return _calculate_Scs(self, r=r, theta=theta, grid=grid, warn=warn) + + # ----------------------------------------------------- + + def cdf(self, r: Quantity, theta: Quantity, phi: Quantity) -> NDArrayF: + """Cumulative distribution Functions in r, theta(r), phi(r, theta). + + Parameters + ---------- + r : (N,) Quantity ['length'] + theta : (N,) Quantity ['angle'] + phi : (N,) Quantity ['angle'] + + Returns + ------- + (N, 3) ndarray + """ + R: NDArrayF = self.r_distribution.cdf(r) + Theta: NDArrayF = self.theta_distribution.cdf(theta, r=r) + Phi: NDArrayF = self.phi_distribution.cdf(phi, r=r, theta=theta) + + c: NDArrayF = column_stack((R, Theta, Phi)).squeeze() + return c + + def rvs( + self, + *, + size: Optional[int] = None, + random_state: RandomLike = None, + # vectorized: bool = True, + representation_type: Type[RT] = PhysicsSphericalRepresentation, + ) -> RT: + """Sample random variates. + + Parameters + ---------- + size : int or None (optional, keyword-only) + Defining number of random variates. + random_state : int, `~numpy.random.RandomState`, or None (optional, keyword-only) + If seed is None (or numpy.random), the `numpy.random.RandomState` + singleton is used. If seed is an int, a new RandomState instance is + used, seeded with seed. If seed is already a Generator or + RandomState instance then that instance is used. + + Returns + ------- + `~astropy.coordinates.PhysicsSphericalRepresentation` + """ + rs: Quantity + thetas: Quantity + phis: Quantity + + rs = self.r_distribution.rvs(size=size, random_state=random_state) + thetas = self.theta_distribution.rvs(rs, size=size, random_state=random_state) + phis = self.phi_distribution.rvs(rs, thetas, size=size, random_state=random_state) + + crd: RT + crd = PhysicsSphericalRepresentation(r=rs, theta=thetas, phi=phis) + crd = crd.represent_as(representation_type) + + return crd + + def __repr__(self) -> str: + s: str = super().__repr__() + s += f"\n r_distribution: {self.r_distribution!r}" + s += f"\n theta_distribution: {self.theta_distribution!r}" + s += f"\n phi_distribution: {self.phi_distribution!r}" + + return s diff --git a/sample_scf/base_univariate.py b/sample_scf/base_univariate.py new file mode 100644 index 0000000..ed6d47a --- /dev/null +++ b/sample_scf/base_univariate.py @@ -0,0 +1,578 @@ +# -*- coding: utf-8 -*- + +"""Base class for sampling from an SCF Potential.""" + +############################################################################## +# IMPORTS + +from __future__ import annotations + +# STDLIB +import warnings +from abc import ABCMeta +from contextlib import nullcontext +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union + +# THIRD PARTY +import astropy.units as u +from astropy.units import Quantity +from galpy.potential import SCFPotential +from numpy import arange, array, atleast_1d, floating, inf, isinf, nan_to_num, pi, result_type, sum +from numpy import tril_indices, zeros +from numpy.typing import ArrayLike +from scipy._lib._util import check_random_state +from scipy.special import lpmv +from scipy.stats import rv_continuous + +# LOCAL +from sample_scf._typing import NDArrayF, RandomGenerator, RandomLike +from sample_scf.representation import x_of_theta + +if TYPE_CHECKING: + # LOCAL + from .base_multivariate import SCFSamplerBase + +__all__: List[str] = [] # nothing is publicly scoped + +############################################################################## +# CODE +############################################################################## + + +def _calculate_rhoTilde(distr: Union["rv_potential", SCFSamplerBase], /, r: Quantity) -> NDArrayF: + """Compute the r-dependent coefficient matrix. + + Parameters + ---------- + distr : `rv_potential` or `SCFSamplerBase` + r : (R,) Quantity['length', float] + + returns + ------- + (R, N, L) ndarray[float] + """ + # compute the r-dependent coefficient matrix $\tilde{\rho}$ + nmaxp1, lmaxp1 = distr._potential._Acos.shape[:2] + gprs = atleast_1d(r.to_value(u.kpc)) / distr._potential._ro + rhoT = array([distr._potential._rhoTilde(r, N=nmaxp1, L=lmaxp1) for r in gprs]) # (R, N, L) + # this matrix can have incorrect NaN values when r=0, inf + # and needs to be corrected + ind = (r == 0) | isinf(r) + rhoT[ind] = nan_to_num(rhoT[ind], copy=False, posinf=inf, neginf=-inf) + + return rhoT + + +def _calculate_Qls( + distr: Union["rv_potential", SCFSamplerBase], + /, + r: Quantity, + rhoTilde: Optional[NDArrayF] = None, +) -> NDArrayF: + r""" + Compute the radial sums for inclination weighting factors. + The weighting factors measure perturbations from spherical symmetry. + The sin component disappears in the integral. + + :math:`Q_l(r) = \sum_{n=0}^{n_{\max}}A_{nl} \tilde{\rho}_{nl0}(r)` + + Parameters + ---------- + r : (R,) Quantity['kpc', float] + Radii. Scalar or 1D array. + rhoTilde : (R, N, L) array[float] + + Returns + ------- + Ql : (R, L) array[float] + """ + Acos = distr.potential._Acos # (N, L, M) + rhoT = distr.calculate_rhoTilde(r) if rhoTilde is None else rhoTilde + + # inclination weighting factors + Qls: NDArrayF = sum(Acos[None, :, :, 0] * rhoT, axis=1) # (R, L) + # this matrix can have incorrect NaN values when radii=0 because + # rhoTilde will have +/- infs which when summed produce a NaN. + # at r=0 this can be changed to 0. # TODO! double confirm math + ind0 = r == 0 + if not sum(nan_to_num(rhoT[ind0, :, 0], posinf=1, neginf=-1)) == 0: + # note: this if statement works even if ind0 is all False + warnings.warn("Qls have non-cancelling infinities at r==0") + else: + Qls[ind0] = nan_to_num(Qls[ind0], copy=False) # TODO! Nan-> 0 or 1? + + return Qls + + +def _pnts_Scs( + radii: NDArrayF, + theta: NDArrayF, + rhoTilde: NDArrayF, + Acos: NDArrayF, + Asin: NDArrayF, +) -> Tuple[NDArrayF, NDArrayF]: + """Radial and inclination sums for azimuthal weighting factors. + + Parameters + ---------- + radii : (R/T,) ndarray[float] + rhoTilde: (R/T, N, L) ndarray[float] + Acos, Asin : (N, L, L) ndarray[float] + theta : (R/T,) ndarray[float] + + Returns + ------- + Scm, Ssm : (R, T, L) ndarray[float] + Azimuthal weighting factors. + Cosine and Sine, respectively. + + Warns + ----- + RuntimeWarning + For invalid values (inf addition -> Nan). + For overflow encountered related to inf and 0 division. + """ + T: int = len(theta) + L = M = Acos.shape[1] - 1 + # N = Acos.shape[0] - 1 + + # The r-dependent coefficient matrix $\tilde{\rho}$ + RhoT = rhoTilde[..., None] # (R/T, N, L, {M}) + + # need r and theta to be arrays. Maintains units. + x: NDArrayF = x_of_theta(theta) # (T,) + xs = x[:, None, None, None] # (R/T, {N}, {L}, {M}) + + # legendre polynomials + ls, ms = tril_indices(L + 1) # index set I_(L, M) + + lps = zeros((T, L + 1, M + 1)) # (R/T, L, M) + lps[:, ls, ms] = lpmv(ms[None, :], ls[None, :], xs[:, 0, 0, 0]) + Plm = lps[:, None, :, :] # (R/T, {N}, L, M) + + # full S matrices (R/T, N, L, M) # TODO! where's Nlm + # n-sum # (R/T, N, L, M) -> (R, T, L, M) + Sclm = sum(Acos[None, :, :, :] * RhoT * Plm, axis=-3) + Sslm = sum(Asin[None, :, :, :] * RhoT * Plm, axis=-3) + + # fix adding +/- inf -> NaN. happens when r=0. + idx = radii == 0 + Sclm[idx] = nan_to_num(Sclm[idx], posinf=inf, neginf=-inf) + Sslm[idx] = nan_to_num(Sslm[idx], posinf=inf, neginf=-inf) + + # l'-sum # FIXME! confirm correct som + Scm = sum(Sclm, axis=-2) + Ssm = sum(Sslm, axis=-2) + + return Scm, Ssm + + +# TODO! it's possible to make the r, theta grids, flatten, use _pnts_Scs, +# then reshape or asssign by index to the grid. Then the Sc calc is only +# in one place. +def _grid_Scs( + radii: NDArrayF, + thetas: NDArrayF, + rhoTilde: NDArrayF, + Acos: NDArrayF, + Asin: NDArrayF, +) -> Tuple[NDArrayF, NDArrayF]: + """Radial and inclination sums for azimuthal weighting factors. + + Parameters + ---------- + radii : (R,) ndarray[float] + rhoTilde: (R, N, L) ndarray[float] + Acos, Asin : (N, L, L) ndarray[float] + thetas : (T,) ndarray[float] + + Returns + ------- + Scm, Ssm : (R, T, L) ndarray[float] + Azimuthal weighting factors. + Cosine and Sine, respectively. + + Warns + ----- + RuntimeWarning + For invalid values (inf addition -> Nan). + For overflow encountered related to inf and 0 division. + """ + T: int = len(thetas) + L = M = Acos.shape[1] - 1 + + # The r-dependent coefficient matrix $\tilde{\rho}$ + RhoT = rhoTilde[:, None, :, :, None] # (R, {T}, N, L, {M}) + + # need r and theta to be arrays. Maintains units. + x: NDArrayF = x_of_theta(thetas << u.rad) # (T,) + xs = x[None, :, None, None, None] # ({R}, T, {N}, {L}, {M}) + + # legendre polynomials + ls, ms = tril_indices(L + 1) # index set I_(L, M) + + lps = zeros((T, L + 1, M + 1)) # (T, L, M) + lps[:, ls, ms] = lpmv(ms[None, ...], ls[None, ...], xs[0, :, 0, 0, 0, None]) + Plm = lps[None, :, None, :, :] # ({R}, T, {N}, L, M) + + # full S matrices (R, T, N, L, M) + # n-sum # (R, T, N, L, M) -> (R, T, L, M) + Sclm = sum(Acos[None, None, :, :, :] * RhoT * Plm, axis=-3) + Sslm = sum(Asin[None, None, :, :, :] * RhoT * Plm, axis=-3) + + # fix adding +/- inf -> NaN. happens when r=0. + idx = radii == 0 + Sclm[idx, ...] = nan_to_num(Sclm[idx, ...], posinf=inf, neginf=-inf) + Sslm[idx, ...] = nan_to_num(Sslm[idx, ...], posinf=inf, neginf=-inf) + + # l'-sum + Scm = sum(Sclm, axis=-2) + Ssm = sum(Sslm, axis=-2) + + return Scm, Ssm + + +def _calculate_Scs( + distr, + r: Quantity, + theta: Quantity, + *, + grid: bool = True, + warn: bool = True, +) -> Tuple[NDArrayF, NDArrayF]: + r"""Radial and inclination sums for azimuthal weighting factors. + + Parameters + ---------- + r : float or (R,) ndarray[float] + theta : float or (T,) ndarray[float] + + grid : bool, optional keyword-only + warn : bool, optional keyword-only + + Returns + ------- + Rm, Sm : (R, T, L) ndarray[float] + Azimuthal weighting factors. + """ + # need r and theta to be float arrays. + rdtype = result_type(float, result_type(r)) + radii: NDArrayF = atleast_1d(r).astype(rdtype) # (R,) + thetas: NDArrayF = atleast_1d(theta) << u.rad # (T,) + + if not grid and len(thetas) != len(radii): + raise ValueError + + # compute the r-dependent coefficient matrix $\tilde{\rho}$ # (R, N, L) + rhoTilde = _calculate_rhoTilde(distr, radii) + + # pass to actual calculator, which takes the matrices and r, theta grids. + with warnings.catch_warnings() if not warn else nullcontext(): + if not warn: + warnings.filterwarnings( + "ignore", + category=RuntimeWarning, + message="(^invalid value)|(^overflow encountered)", + ) + func = _grid_Scs if grid else _pnts_Scs + Sc, Ss = func( + radii, + thetas, + rhoTilde=rhoTilde, + Acos=distr.potential._Acos, + Asin=distr.potential._Asin, + ) + + return Sc, Ss + + +############################################################################## + + +class rv_potential(rv_continuous, metaclass=ABCMeta): + """ + Modified :class:`scipy.stats.rv_continuous` to use custom ``rvs`` methods. + Made by stripping down the original scipy implementation. + See :class:`scipy.stats.rv_continuous` for details. + + Parameters + ---------- + `rv_continuous` is a base class to construct specific distribution classes + and instances for continuous random variables. It cannot be used + directly as a distribution. + + Parameters + ---------- + potential : `galpy.potential.SCFPotential` + The potential from which to sample. + momtype : int, optional keyword-only + The type of generic moment calculation to use: 0 for pdf, 1 (default) + for ppf. + a : float, optional keyword-only + Lower bound of the support of the distribution, default is minus + infinity. + b : float, optional keyword-only + Upper bound of the support of the distribution, default is plus + infinity. + xtol : float, optional keyword-only + The tolerance for fixed point calculation for generic ppf. + badvalue : float, optional keyword-only + The value in a result arrays that indicates a value that for which + some argument restriction is violated, default is `~numpy.nan`. + name : str, optional keyword-only + The name of the instance. This string is used to construct the default + example for distributions. + longname : str, optional keyword-only + This string is used as part of the first line of the docstring returned + when a subclass has no docstring of its own. Note: `longname` exists + for backwards compatibility, do not use for new subclasses. + shapes : str, optional keyword-only + The shape of the distribution. For example ``"m, n"`` for a + distribution that takes two integers as the two shape arguments for all + its methods. If not provided, shape parameters will be inferred from + the signature of the private methods, ``_pdf`` and ``_cdf`` of the + instance. + extradoc : str, optional keyword-only, deprecated + This string is used as the last part of the docstring returned when a + subclass has no docstring of its own. Note: `extradoc` exists for + backwards compatibility, do not use for new subclasses. + seed : {None, int, `numpy.random.Generator`, + `numpy.random.RandomState`}, optional keyword-only + + If `seed` is None (or `numpy.random`), the `numpy.random.RandomState` + singleton is used. + If `seed` is an int, a new ``RandomState`` instance is used, + seeded with `seed`. + If `seed` is already a ``Generator`` or ``RandomState`` instance then + that instance is used. + """ + + _random_state: RandomGenerator + _potential: SCFPotential + _nmax: int + _lmax: int + _radial_scale_factor: Quantity + + def __init__( + self, + potential: SCFPotential, + *, + momtype: int = 1, + a: Optional[float] = None, + b: Optional[float] = None, + xtol: float = 1e-14, + badvalue: Optional[float] = None, + name: Optional[str] = None, + longname: Optional[str] = None, + shapes: Optional[Tuple[int, ...]] = None, + extradoc: Optional[str] = None, + seed: Optional[int] = None, + ) -> None: + super().__init__( + momtype=momtype, + a=a, + b=b, + xtol=xtol, + badvalue=badvalue, + name=name, + longname=longname, + shapes=shapes, + extradoc=extradoc, + seed=seed, + ) + + if not isinstance(potential, SCFPotential): + msg = f"potential must be , not {type(potential)}" + raise TypeError(msg) + + self._potential = potential + self._nmax = potential._Acos.shape[0] - 1 # 0 inclusive + self._lmax = potential._Acos.shape[1] - 1 # 0 inclusive + self._radial_scale_factor = (potential._a * potential._ro) << u.kpc + + @property + def potential(self) -> SCFPotential: + """The potential from which to sample""" + return self._potential + + @property + def radial_scale_factor(self) -> Quantity: + """Scale factor to convert dimensionful radii to a dimensionless form.""" + return self._radial_scale_factor + + @property + def nmax(self) -> int: + return self._nmax + + @property + def lmax(self) -> int: + return self._lmax + + def calculate_rhoTilde(self, radii: Quantity) -> NDArrayF: + """Compute the r-dependent coefficient matrix. + + Parameters + ---------- + radii : (R,) Quantity['length', float] + + returns + ------- + (R, N, L) ndarray[float] + """ + return _calculate_rhoTilde(self, r=radii) + + # --------------------------------------------------------------- + + def rvs( + self, + *args: Union[floating, ArrayLike], + size: Optional[int] = None, + random_state: RandomLike = None, + **kwargs, + ) -> NDArrayF: + """Random variate sampler. + + Parameters + ---------- + *args + size : int or None (optional, keyword-only) + Size of random variates to generate. + random_state : int, `~numpy.random.RandomState`, or None (optional, keyword-only) + If seed is None (or numpy.random), the `numpy.random.RandomState` + singleton is used. If seed is an int, a new RandomState instance is + used, seeded with seed. If seed is already a Generator or + RandomState instance then that instance is used. + **kwargs + + Returns + ------- + ndarray[float] + Shape 'size'. + """ + # copied from `scipy` + # extra gymnastics needed for a custom random_state + rndm: RandomGenerator + if random_state is not None: + random_state_saved = self._random_state + rndm = check_random_state(random_state) + else: + rndm = self._random_state + random_state_saved = None + + # go directly to `_rvs` + vals: NDArrayF = self._rvs(*args, size=size, random_state=rndm, **kwargs) + + # copied from `scipy` + # do not forget to restore the _random_state + if random_state is not None: + self._random_state = random_state_saved + + return vals.squeeze() # TODO? should it squeeze? + + +# ------------------------------------------------------------------- + + +class r_distribution_base(rv_potential): + """Sample radial coordinate from an SCF potential. + + The potential must have a convergent mass function. + + Parameters + ---------- + potential : `galpy.potential.SCFPotential` + """ + + +class theta_distribution_base(rv_potential): + """Sample inclination coordinate from an SCF potential. + + The potential must have a convergent mass function. + + Parameters + ---------- + potential : `galpy.potential.SCFPotential` + """ + + def __init__(self, potential: SCFPotential, **kwargs) -> None: + kwargs["a"], kwargs["b"] = 1, -1 # allowed range of x + super().__init__(potential, **kwargs) + + self._lrange = arange(0, self._lmax + 1) # lmax inclusive + + def rvs( + self, + *args: Union[floating, ArrayLike], + size: Optional[int] = None, + random_state: RandomLike = None, + ) -> NDArrayF: + return super().rvs( + *args, + size=size, + random_state=random_state, + # return_thetas=return_thetas + ) + + # --------------------------------------------------------------- + + def calculate_Qls(self, r: Quantity, rhoTilde: Optional[NDArrayF] = None) -> NDArrayF: + r""" + Compute the radial sums for inclination weighting factors. + The weighting factors measure perturbations from spherical symmetry. + The sin component disappears in the integral. + + :math:`Q_l(r) = \sum_{n=0}^{n_{\max}}A_{nl} \tilde{\rho}_{nl0}(r)` + + Parameters + ---------- + r : (R,) Quantity['kpc', float] + Radii. Scalar or 1D array. + rhoTilde : (R, N, L) array[float] + + Returns + ------- + Ql : (R, L) array[float] + """ + return _calculate_Qls(self, r=r, rhoTilde=rhoTilde) + + +class phi_distribution_base(rv_potential): + """Sample inclination coordinate from an SCF potential. + + The potential must have a convergent mass function. + + Parameters + ---------- + potential : `galpy.potential.SCFPotential` + """ + + def __init__(self, potential: SCFPotential, **kwargs: Any) -> None: + kwargs["a"], kwargs["b"] = 0, 2 * pi + super().__init__(potential, **kwargs) + + self._lrange = arange(0, self._lmax + 1) + + def calculate_Scs( + self, + r: Quantity, + theta: Quantity, + *, + grid: bool = True, + warn: bool = True, + ) -> Tuple[NDArrayF, NDArrayF]: + r"""Radial and inclination sums for azimuthal weighting factors. + + Parameters + ---------- + r : float or (R,) ndarray[float] + theta : float or (T,) ndarray[float] + + grid : bool, optional keyword-only + warn : bool, optional keyword-only + + Returns + ------- + Rm, Sm : (R, T, L) ndarray[float] + Azimuthal weighting factors. + """ + return _calculate_Scs(self, r=r, theta=theta, grid=grid, warn=warn) diff --git a/sample_scf/conftest.py b/sample_scf/conftest.py index 6608724..3a175e1 100644 --- a/sample_scf/conftest.py +++ b/sample_scf/conftest.py @@ -10,11 +10,14 @@ """ -# BUILT-IN +# STDLIB import os # THIRD PARTY -from astropy.version import version as astropy_version # noqa: F401 +import pytest +from galpy.df import isotropicHernquistdf +from galpy.potential import HernquistPotential, SCFPotential +from numpy import zeros try: # THIRD PARTY @@ -24,6 +27,9 @@ except ImportError: ASTROPY_HEADER = False +# ============================================================================ +# Configuration + def pytest_configure(config): """Configure Pytest with Astropy. @@ -49,15 +55,86 @@ def pytest_configure(config): TESTED_VERSIONS[packagename] = __version__ -# Uncomment the last two lines in this block to treat all DeprecationWarnings as -# exceptions. For Astropy v2.0 or later, there are 2 additional keywords, -# as follow (although default should work for most cases). -# To ignore some packages that produce deprecation warnings on import -# (in addition to 'compiler', 'scipy', 'pygments', 'ipykernel', and -# 'setuptools'), add: -# modules_to_ignore_on_import=['module_1', 'module_2'] -# To ignore some specific deprecation warning messages for Python version -# MAJOR.MINOR or later, add: -# warnings_to_ignore_by_pyver={(MAJOR, MINOR): ['Message to ignore']} -# from astropy.tests.helper import enable_deprecations_as_exceptions # noqa: F401 -# enable_deprecations_as_exceptions() +# ============================================================================ +# Fixtures + +# Hernquist +hernquist_potential = HernquistPotential() +hernquist_potential.turn_physical_on() +hernquist_df = isotropicHernquistdf(hernquist_potential) + +Acos = zeros((5, 6, 6)) +Acos[0, 0, 0] = 1 +_hernquist_scf_potential = SCFPotential(Acos=Acos) +_hernquist_scf_potential.turn_physical_on() + + +# # NFW +# nfw_potential = NFWPotential(normalize=1) +# nfw_potential.turn_physical_on() +# nfw_df = isotropicNFWdf(nfw_potential, rmax=1e4) +# # FIXME! load this up as a test data file +# fpath = get_pkg_data_path("tests/data/nfw.npz", package="sample_scf") +# try: +# data = load(fpath) +# except FileNotFoundError: +# a_scf = 80 +# Acos, Asin = scf_compute_coeffs_axi(nfw_potential.dens, N=40, L=30, a=a_scf) +# savez(fpath, Acos=Acos, Asin=Asin, a_scf=a_scf) +# else: +# data = load(fpath, allow_pickle=True) +# Acos = copy.deepcopy(data["Acos"]) +# Asin = None +# a_scf = data["a_scf"] +# +# _nfw_scf_potential = SCFPotential(Acos=Acos, Asin=None, a=a_scf, normalize=1.0) +# _nfw_scf_potential.turn_physical_on() + + +# Triaxial NFW +# tnfw_potential = TriaxialNFWPotential(normalize=1.0, c=1.4, a=1.0) +# tnfw_potential.turn_physical_on() +# tnfw_df = osipkovmerrittNFWdf(tnfw_potential, rmax=1e4) + + +# ------------------------ +cls_pot_kw = { + _hernquist_scf_potential: {"total_mass": 1.0}, + # _nfw_scf_potential: {"total_mass": 1.0}, +} +theory = { + _hernquist_scf_potential: hernquist_df, + # _nfw_scf_potential: nfw_df, +} + + +@pytest.fixture(scope="session") +def hernquist_scf_potential(): + """Make a SCF of a Hernquist potential. + + This is tested for quality in ``test_conftest.py`` + """ + return _hernquist_scf_potential + + +# @pytest.fixture(scope="session") +# def nfw_scf_potential(): +# """Make a SCF of a triaxial NFW potential.""" +# return _nfw_scf_potential + + +@pytest.fixture( + params=[ + "hernquist_scf_potential", + # "nfw_scf_potential", # TODO! turn on + ], +) +def potentials(request): + if request.param in ("hernquist_scf_potential"): + potential = hernquist_scf_potential.__wrapped__() + # elif request.param == "nfw_scf_potential": + # potential = nfw_scf_potential.__wrapped__() + else: + raise ValueError + + yield potential diff --git a/sample_scf/core.py b/sample_scf/core.py new file mode 100644 index 0000000..e1c409c --- /dev/null +++ b/sample_scf/core.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- + +"""**DOCSTRING**. + +Description. + +""" + +############################################################################## +# IMPORTS + +from __future__ import annotations + +# STDLIB +from collections.abc import Mapping +from typing import Any, Literal, Optional, Type, TypedDict, Union + +# THIRD PARTY +from galpy.potential import SCFPotential + +# LOCAL +from .base_multivariate import SCFSamplerBase +from .base_univariate import rv_potential +from .exact import ExactSCFSampler +from .interpolated import InterpolatedSCFSampler + +__all__ = ["SCFSampler"] + + +############################################################################## +# Parameters + + +class MethodsMapping(TypedDict): + r: rv_potential + theta: rv_potential + phi: rv_potential + + +############################################################################## +# CODE +############################################################################## + + +class SCFSampler(SCFSamplerBase): + """Sample SCF in spherical coordinates. + + The coordinate system is: + - r : [0, infinity) + - theta : [0, pi] (0 at the North pole) + - phi : [0, 2pi) + + Parameters + ---------- + pot : `galpy.potential.SCFPotential` + method : {'interp', 'exact'} or mapping[str, type] + If mapping, must have keys (r, theta, phi) + **kwargs + Passed to to the individual component sampler constructors. + """ + + _sampler: Optional[SCFSamplerBase] + + def __init__( + self, + potential: SCFPotential, + method: Union[Literal["interp", "exact"], MethodsMapping], + **kwargs: Any, + ) -> None: + super().__init__(potential, **kwargs) + + if isinstance(method, Mapping): # mix and match exact and interpolated + sampler = None + r_distribution = method["r"](potential, **kwargs) + theta_distribution = method["theta"](potential, **kwargs) + phi_distribution = method["phi"](potential, **kwargs) + + else: # either exact or interpolated + sampler_cls: Type[SCFSamplerBase] + if method == "interp": + sampler_cls = InterpolatedSCFSampler + elif method == "exact": + sampler_cls = ExactSCFSampler + else: + raise ValueError(f"method = {method} not in " + "{'interp', 'exact'}") + + sampler = sampler_cls(potential, **kwargs) + r_distribution = sampler.r_distribution + theta_distribution = sampler.theta_distribution + phi_distribution = sampler.phi_distribution + + self._sampler: Optional[SCFSamplerBase] = sampler + self._r_distribution = r_distribution + self._theta_distribution = theta_distribution + self._phi_distribution = phi_distribution diff --git a/sample_scf/exact/__init__.py b/sample_scf/exact/__init__.py new file mode 100644 index 0000000..ce8ca03 --- /dev/null +++ b/sample_scf/exact/__init__.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +# Licensed under a 3-clause BSD style license - see LICENSE.rst + +# LOCAL +from .azimuth import exact_phi_distribution, exact_phi_fixed_distribution +from .core import ExactSCFSampler +from .inclination import exact_theta_distribution, exact_theta_fixed_distribution +from .radial import exact_r_distribution + +__all__ = [ + # multivariate + "ExactSCFSampler", + # univariate + "exact_r_distribution", + "exact_theta_fixed_distribution", + "exact_theta_distribution", + "exact_phi_fixed_distribution", + "exact_phi_distribution", +] diff --git a/sample_scf/exact/azimuth.py b/sample_scf/exact/azimuth.py new file mode 100644 index 0000000..4266896 --- /dev/null +++ b/sample_scf/exact/azimuth.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- + +"""Exact sampling.""" + +############################################################################## +# IMPORTS + +from __future__ import annotations + +# STDLIB +from typing import Any, Optional, cast + +# THIRD PARTY +import astropy.units as u +from galpy.potential import SCFPotential +from numpy import arange, atleast_1d, cos, nan_to_num, pi, sin, sum +from numpy.typing import ArrayLike + +# LOCAL +from sample_scf._typing import NDArrayF, RandomLike +from sample_scf.base_univariate import phi_distribution_base + +__all__ = ["exact_phi_fixed_distribution", "exact_phi_distribution"] + + +############################################################################## +# CODE +############################################################################## + + +class exact_phi_distribution_base(phi_distribution_base): + """Sample Azimuthal Coordinate. + + Parameters + ---------- + potential : `galpy.potential.SCFPotential` + **kw + Passed to `scipy.stats.rv_continuous` + "a", "b" are set to [0, 2 pi] + + """ + + def __init__(self, potential: SCFPotential, **kw: Any) -> None: + kw["a"], kw["b"] = 0, 2 * pi + super().__init__(potential, **kw) + self._lrange = arange(0, self._lmax + 1) + + # for compatibility + self._Sc: Optional[NDArrayF] = None + self._Ss: Optional[NDArrayF] = None + + def _cdf(self, phi: NDArrayF, *args: Any, **kw: Any) -> NDArrayF: + r"""Cumulative Distribution Function. + + Parameters + ---------- + phi : float or ndarray[float] ['radian'] + Azimuthal coordinate in radians, :math:`\in [0, 2\pi]`. + *args, **kw + Not used. + + Returns + ------- + cdf : float or ndarray[float] + Shape (len(r), len(theta), len(phi)). + :meth:`numpy.ndarray.squeeze` applied so scalar inputs has scalar + output. + + """ + Rm, Sm = kw.get("Scs", (self._Sc, self._Ss)) # (R/T, L) + + Phis: NDArrayF = atleast_1d(phi)[:, None] # (P, {L}) + + # l = 0 : spherical symmetry + term0: NDArrayF = Phis[..., 0] / (2 * pi) # (1, P) + + # l = 1+ : non-symmetry + factor = 1 / Rm[:, 0] # R0 (R/T,) # can be inf + ms = arange(1, self._lmax)[None, :] # ({R/T/P}, L) + term1p = sum( + (Rm[:, 1:] * sin(ms * Phis) + Sm[:, 1:] * (1 - cos(ms * Phis))) / (2 * pi * ms), + axis=-1, + ) + + cdf: NDArrayF = term0 + nan_to_num(factor * term1p) # (R/T/P,) + # 'factor' can be inf and term1p 0 => inf * 0 = nan -> 0 + + return cdf + + def _ppf_to_solve(self, phi: float, q: float, *args: Any) -> NDArrayF: + # changed from .cdf() to ._cdf() to use default 'r', 'theta' + return self._cdf(*(phi,) + args) - q + + +class exact_phi_fixed_distribution(exact_phi_distribution_base): + """Sample Azimuthal Coordinate at fixed r, theta. + + Parameters + ---------- + potential : `galpy.potential.SCFPotential` + r, theta : float or ndarray[float] + + """ + + def __init__(self, potential: SCFPotential, r: NDArrayF, theta: NDArrayF, **kw: Any) -> None: + super().__init__(potential, **kw) + + # assign fixed r, theta + self._r, self._theta = r, theta + # and can compute the associated assymetry measures + self._Sc, self._Ss = self.calculate_Scs(r, theta, grid=False, warn=False) + + def cdf(self, phi: NDArrayF, *args: Any, **kw: Any) -> NDArrayF: + r"""Cumulative Distribution Function. + + Parameters + ---------- + phi : float or ndarray[float] ['radian'] + Azimuthal coordinate in radians, :math:`\in [0, 2\pi]`. + *args + **kw + + Returns + ------- + cdf : float or ndarray[float] + Shape (len(r), len(theta), len(phi)). + :meth:`numpy.ndarray.squeeze` applied so scalar inputs has scalar + output. + """ + return self._cdf(phi, *args, **kw) + + +class exact_phi_distribution(exact_phi_distribution_base): + def _cdf( + self, + phi: ArrayLike, + *args: Any, + r: Optional[float] = None, + theta: Optional[float] = None, + ) -> NDArrayF: + r"""Cumulative Distribution Function. + + Parameters + ---------- + phi : float or ndarray[float] ['radian'] + Azimuthal coordinate in radians, :math:`\in [0, 2\pi]`. + *args + r : float or ndarray[float], keyword-only + Radial coordinate at which to evaluate the CDF. Not optional. + theta : float or ndarray[float], keyword-only + Inclination coordinate at which to evaluate the CDF. Not optional. + In [-pi/2, pi/2]. + + Returns + ------- + cdf : float or ndarray[float] + Shape (len(r), len(theta), len(phi)). + :meth:`numpy.ndarray.squeeze` applied so scalar inputs has scalar + output. + + Raises + ------ + ValueError + If 'r' or 'theta' are None. + """ + Scs = self.calculate_Scs(cast(float, r), cast(float, theta), grid=False, warn=False) + cdf: NDArrayF = super()._cdf(phi, *args, Scs=Scs) + return cdf + + def cdf(self, phi: ArrayLike, *args: Any, r: float, theta: float) -> NDArrayF: + r"""Cumulative Distribution Function. + + Parameters + ---------- + phi : quantity-like or array-like ['radian'] + Azimuthal angular coordinate, :math:`\in [0, 2\pi]`. If doesn't + have units, must be in radians. + *args + r : float or ndarray[float], keyword-only + Radial coordinate at which to evaluate the CDF. Not optional. + theta : quantity-like or array-like ['radian'], keyword-only + Inclination coordinate at which to evaluate the CDF. Not optional. + In [-pi/2, pi/2]. If doesn't have units, must be in radians. + + Returns + ------- + cdf : float or ndarray[float] + Shape (len(r), len(theta), len(phi)). + :meth:`numpy.ndarray.squeeze` applied so scalar inputs has scalar + output. + """ + phi = u.Quantity(phi, u.rad).value + cdf: NDArrayF = self._cdf(phi, *args, r=r, theta=u.Quantity(theta, u.rad).value) + return cdf + + def rvs( # type: ignore + self, + r: float, + theta: float, + *, + size: Optional[int] = None, + random_state: RandomLike = None, + ) -> NDArrayF: + """Random Variate Sample. + + Parameters + ---------- + r : float + theta : float + size : int or None (optional, keyword-only) + random_state : int or `numpy.random.RandomState` or None (optional, keyword-only) + + Returns + ------- + vals : ndarray[float] + + """ + getattr(self._cdf, "__kwdefaults__", {})["r"] = r + getattr(self._cdf, "__kwdefaults__", {})["theta"] = theta + vals = super().rvs(size=size, random_state=random_state) + getattr(self._cdf, "__kwdefaults__", {})["r"] = None + getattr(self._cdf, "__kwdefaults__", {})["theta"] = None + return vals diff --git a/sample_scf/exact/core.py b/sample_scf/exact/core.py new file mode 100644 index 0000000..5801638 --- /dev/null +++ b/sample_scf/exact/core.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- + +"""Exact sampling.""" + +############################################################################## +# IMPORTS + +from __future__ import annotations + +# STDLIB +from typing import Any, Optional + +# THIRD PARTY +from astropy.coordinates import PhysicsSphericalRepresentation +from galpy.potential import SCFPotential + +# LOCAL +from .azimuth import exact_phi_distribution +from .inclination import exact_theta_distribution +from sample_scf._typing import RandomLike +from sample_scf.base_multivariate import SCFSamplerBase +from sample_scf.exact.radial import exact_r_distribution + +__all__ = ["ExactSCFSampler"] + + +############################################################################## +# CODE +############################################################################## + + +class ExactSCFSampler(SCFSamplerBase): + """SCF sampler in spherical coordinates. + + The coordinate system is: + - r : [0, infinity) + - theta : [-pi/2, pi/2] (positive at the North pole) + - phi : [0, 2pi) + + Parameters + ---------- + pot : `~galpy.potential.SCFPotential` + **kw + Not used. + + """ + + def __init__(self, potential: SCFPotential, **kw: Any) -> None: + super().__init__(potential) + + # make samplers + total_mass = kw.pop("total_mass", None) + self._r_distribution = exact_r_distribution(potential, total_mass=total_mass, **kw) + self._theta_distribution = exact_theta_distribution(potential, **kw) # r=None + self._phi_distribution = exact_phi_distribution(potential, **kw) # r=None, theta=None + + def rvs( + self, *, size: Optional[int] = None, random_state: RandomLike = None + ) -> PhysicsSphericalRepresentation: + """Sample random variates. + + Parameters + ---------- + size : int or None (optional, keyword-only) + Defining number of random variates. + random_state : int, `~numpy.random.RandomState`, or None (optional, keyword-only) + If seed is None (or numpy.random), the `numpy.random.RandomState` + singleton is used. If seed is an int, a new RandomState instance is + used, seeded with seed. If seed is already a Generator or + RandomState instance then that instance is used. + + Returns + ------- + `~astropy.coordinates.PhysicsSphericalRepresentation` + """ + return super().rvs(size=size, random_state=random_state) diff --git a/sample_scf/exact/inclination.py b/sample_scf/exact/inclination.py new file mode 100644 index 0000000..ccfa19a --- /dev/null +++ b/sample_scf/exact/inclination.py @@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- + +"""Exact sampling of inclination coordinate.""" + +############################################################################## +# IMPORTS + +from __future__ import annotations + +# STDLIB +from typing import Any, Optional, Union + +# THIRD PARTY +import astropy.units as u +from astropy.units import Quantity +from galpy.potential import SCFPotential +from numpy import atleast_1d, atleast_2d, floating, nan_to_num, pad +from numpy.polynomial.legendre import legval +from numpy.typing import ArrayLike + +# LOCAL +from sample_scf._typing import NDArrayF, RandomLike +from sample_scf.base_univariate import theta_distribution_base +from sample_scf.representation import theta_of_x, x_of_theta + +__all__ = ["exact_theta_fixed_distribution", "exact_theta_distribution"] + + +############################################################################## +# CODE +############################################################################## + + +class exact_theta_distribution_base(theta_distribution_base): + """Base class for sampling the inclination coordinate.""" + + def _cdf(self, x: NDArrayF, Qls: NDArrayF) -> NDArrayF: + """Cumulative Distribution Function. + + .. math:: + + F_{\theta}(\theta; r) = \frac{1 + \\cos{\theta}}{2} + + \frac{1}{2 Q_0(r)}\\sum_{\\ell=1}^{L_{\\max}}Q_{\\ell}(r) + \frac{\\sin(\theta) P_{\\ell}^{1}(\\cos{\theta})}{\\ell(\\ell+1)} + + Where + + Q_{\\ell}(r) = \\sum_{n=0}^{N_{\\max}} N_{\\ell 0} A_{n\\ell 0}^{(\\cos)} + \tilde{\rho}_{n\\ell}(r) + + Parameters + ---------- + x : number or (T,) array[number] + :math:`x = \\cos\theta`. Must be in the range [-1, 1] + Qls : (R, L) array[float] + Radially-dependent coefficients parameterizing the deviations from + a uniform distribution on the inclination angle. + + Returns + ------- + (R, T) array + """ + xs = atleast_1d(x) # (T,) + Qls = atleast_2d(Qls) # (R, L) + + # l = 0 + term0 = 0.5 * (1.0 - xs) # (T,) + # l = 1+ : non-symmetry + factor = 1.0 / (2.0 * Qls[:, 0]) # (R,) + + wQls = Qls[:, 1:] / (2 * self._lrange[None, 1:] + 1) # apply over (L,) dimension + wQls_lp1 = pad(wQls, [[0, 0], [2, 0]]) # pad start of (L,) dimension + + sumPlp1 = legval(xs, wQls_lp1.T, tensor=True) # (R, T) + sumPlm1 = legval(xs, wQls.T, tensor=True) # (R, T) + + cdf = term0 + nan_to_num((factor * (sumPlm1 - sumPlp1).T).T) # (R, T) + return cdf # TODO! get rid of sf function + + # @abc.abstractmethod + # def _cdf(self, x: NDArrayF, Qls: NDArrayF) -> NDArrayF: + # """Cumulative Distribution Function. + # + # .. math:: + # + # F_{\theta}(\theta; r) = \frac{1 + \cos{\theta}}{2} + + # \frac{1}{2 Q_0(r)}\sum_{\ell=1}^{L_{\max}}Q_{\ell}(r) + # \frac{\sin(\theta) P_{\ell}^{1}(\cos{\theta})}{\ell(\ell+1)} + # + # Where + # + # Q_{\ell}(r) = \sum_{n=0}^{N_{\max}} N_{\ell 0} A_{n\ell 0}^{(\cos)} + # \tilde{\rho}_{n\ell}(r) + # + # Parameters + # ---------- + # x : number or (T,) array[number] + # :math:`x = \cos\theta`. Must be in the range [-1, 1] + # Qls : (R, L) array[float] + # Radially-dependent coefficients parameterizing the deviations from + # a uniform distribution on the inclination angle. + # + # Returns + # ------- + # (R, T) array + # """ + # sf = self._sf(x, Qls) + # return 1.0 - sf + + def _rvs( + self, + *args: Union[floating, ArrayLike], + size: Optional[int] = None, + random_state: RandomLike = None, + # return_thetas: bool = True + ) -> NDArrayF: + xs = super()._rvs(*args, size=size, random_state=random_state) + # ths = theta_of_x(xs) if return_thetas else xs + ths = theta_of_x(xs) + return ths + + def _ppf_to_solve(self, x: float, q: float, *args: Any) -> NDArrayF: + ppf: NDArrayF = self._cdf(*(x,) + args) - q + return ppf + + +class exact_theta_fixed_distribution(exact_theta_distribution_base): + """ + Sample inclination coordinate from an SCF potential. + + Parameters + ---------- + pot : `~galpy.potential.SCFPotential` + r : Quantity or None, optional + If passed, these are the locations at which the theta CDF will be + evaluated. If None (default), then the r coordinate must be given + to the CDF and RVS functions. + **kw: + Not used. + """ + + def __init__(self, potential: SCFPotential, r: Quantity, **kw: Any) -> None: + super().__init__(potential) + + # points at which CDF is defined + self._r = r + self._Qlsatr = self.calculate_Qls(r) + + @property + def fixed_radius(self) -> Quantity: + return self._r + + def _cdf(self, x: ArrayLike, *args: Any) -> NDArrayF: + cdf: NDArrayF = super()._cdf(x, self._Qlsatr) + return cdf + + def cdf(self, theta: Quantity) -> NDArrayF: + """Cumulative distribution function of the given RV. + + Parameters + ---------- + theta : Quantity['angle'] + + Returns + ------- + cdf : ndarray + Cumulative distribution function evaluated at `theta` + """ + return self._cdf(x_of_theta(theta << u.rad)) + + def rvs(self, size: Optional[int] = None, random_state: RandomLike = None) -> NDArrayF: + pts = super().rvs(self._r, size=size, random_state=random_state) + return pts + + +class exact_theta_distribution(exact_theta_distribution_base): + """ + Sample inclination coordinate from an SCF potential. + + Parameters + ---------- + pot : `~galpy.potential.SCFPotential` + + """ + + def _cdf(self, x: NDArrayF, r: float) -> NDArrayF: + Qls = self.calculate_Qls(r) + cdf = super()._cdf(x, Qls) + return cdf + + def cdf(self, theta: Quantity, *args: Any, r: Quantity) -> NDArrayF: + """Cumulative distribution function of the given RV. + + Parameters + ---------- + theta : Quantity['angle'] + *args + Not used. + r : Quantity['length', float] (optional, keyword-only) + + Returns + ------- + cdf : ndarray + Cumulative distribution function evaluated at `theta` + """ + return self._cdf(x_of_theta(theta), *args, r=r) + + def rvs( + self, r: Quantity, *, size: Optional[int] = None, random_state: RandomLike = None + ) -> NDArrayF: + pts = super().rvs(r, size=size, random_state=random_state) + return pts diff --git a/sample_scf/exact/radial.py b/sample_scf/exact/radial.py new file mode 100644 index 0000000..fc13a9d --- /dev/null +++ b/sample_scf/exact/radial.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +"""Exact sampling of radial coordinate.""" + +############################################################################## +# IMPORTS + +from __future__ import annotations + +# STDLIB +from typing import Any, Optional + +# THIRD PARTY +from astropy.units import Quantity +from galpy.potential import SCFPotential +from numpy import atleast_1d, inf, isnan, vectorize + +# LOCAL +from sample_scf._typing import NDArrayF +from sample_scf.base_univariate import r_distribution_base + +__all__ = ["exact_r_distribution"] + + +############################################################################## +# CODE +############################################################################## + + +class exact_r_distribution(r_distribution_base): + """Sample radial coordinate from an SCF potential. + + Parameters + ---------- + pot : `~galpy.potential.SCFPotential` + A potential that can be used to calculate the enclosed mass. + total_mass : Optional + **kw + Not used. + """ + + def __init__( + self, potential: SCFPotential, total_mass: Optional[Quantity] = None, **kw: Any + ) -> None: + # make sampler + kw["a"], kw["b"] = 0, inf # allowed range of r + super().__init__(potential, **kw) + + # normalization for total mass + # TODO! if mass has units + if total_mass is None: + total_mass = potential._mass(inf) + if isnan(total_mass): + raise ValueError( + "total mass is NaN. Need to pass kwarg `total_mass` with a non-NaN value.", + ) + self._mtot = total_mass + # vectorize mass function, which is scalar + self._vec_cdf = vectorize(self._potential._mass) + + def _cdf(self, r: Quantity, *args: Any, **kw: Any) -> NDArrayF: + """Cumulative Distribution Function. + + Parameters + ---------- + r : Quantity ['length'] + *args + **kwargs + + Returns + ------- + mass : array-like + Shape matches 'r'. + """ + mass: NDArrayF = atleast_1d(self._vec_cdf(r)) / self._mtot + mass[r == 0] = 0 + mass[r == inf] = 1 + return mass.item() if mass.shape == (1,) else mass + + cdf = _cdf diff --git a/sample_scf/exact/tests/__init__.py b/sample_scf/exact/tests/__init__.py new file mode 100644 index 0000000..8807810 --- /dev/null +++ b/sample_scf/exact/tests/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +# Licensed under a 3-clause BSD style license - see LICENSE.rst +""" +This module contains package tests. +""" diff --git a/sample_scf/exact/tests/test_core.py b/sample_scf/exact/tests/test_core.py new file mode 100644 index 0000000..1a82423 --- /dev/null +++ b/sample_scf/exact/tests/test_core.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- + +"""Tests for :mod:`sample_scf.exact.core`.""" + + +############################################################################## +# IMPORTS + +# THIRD PARTY +import astropy.units as u +import matplotlib.pyplot as plt +import pytest +from sampler_scf.base_multivariate import SCFSamplerBase + +# LOCAL +from .test_base_multivariate import BaseTest_SCFSamplerBase +from sample_scf import ExactSCFSampler + +############################################################################## +# CODE +############################################################################## + + +class Test_ExactSCFSampler(BaseTest_SCFSamplerBase): + """Test :class:`sample_scf.exact.ExactSCFSampler`.""" + + @pytest.fixture(scope="class") + def rv_cls(self): + return ExactSCFSampler + + def setup_class(self): + # TODO! make sure these are right! + self.expected_rvs = { + 0: dict(r=2.85831468, theta=1.473013568997 * u.rad, phi=4.49366731864 * u.rad), + 1: dict(r=2.85831468, theta=1.473013568997 * u.rad, phi=4.49366731864 * u.rad), + 2: dict( + r=[59.156720319468995, 2.8424809956410684, 71.71466505619023, 5.471148006577435], + theta=[0.365179487932, 1.476190768288, 0.3320725403573, 1.126711132015] * u.rad, + phi=[4.383959499105, 1.3577303436664, 6.134113310024, 0.039145847961457] * u.rad, + ), + } + + # =============================================================== + # Method Tests + + def test_init_attrs(self, sampler): + super().test_init_attrs(sampler) + + hasattr(sampler, "_sampler") + assert sampler._sampler is None or isinstance(sampler._sampler, SCFSamplerBase) + + # TODO! make sure these are correct + @pytest.mark.parametrize( + "r, theta, phi, expected", + [ + (0.0, 0.0, 0.0, [0, 0.5, 0]), + (1.0, 0.0, 0.0, [0.25, 0.5, 0]), + # ([0.0, 1.0], [0.0, 0.0], [0.0, 0.0], [[0, 0.5, 0], [0.25, 0.5, 0]]), + ], + ) + def test_cdf(self, sampler, r, theta, phi, expected): + """Test :meth:`sample_scf.base_multivariate.SCFSamplerBase.cdf`.""" + super().test_cdf(sampler, r, theta, phi, expected) + + @pytest.mark.skip("TODO!") + def test_rvs(self, sampler): + """Test Random Variates Sampler.""" + + # =============================================================== + # Plot Tests + + def test_exact_cdf_plot(self, sampler): + """Plot cdf.""" + kw = dict(marker="o", ms=5, c="k", zorder=5, label="CDF") + cdf = sampler.cdf(rgrid, tgrid, pgrid) + + fig = plt.figure(figsize=(15, 3)) + + # r + ax = fig.add_subplot( + 131, + title=r"$m(\leq r) / m_{tot}$", + xlabel="r", + ylabel=r"$m(\leq r) / m_{tot}$", + ) + ax.semilogx(rgrid, cdf[:, 0], **kw) + + # theta + ax = fig.add_subplot( + 132, + title=r"CDF($\theta$)", + xlabel=r"$\theta$", + ylabel=r"CDF($\theta$)", + ) + ax.plot(tgrid, cdf[:, 1], **kw) + + # phi + ax = fig.add_subplot( + 133, + title=r"CDF($\phi$)", + xlabel=r"$\phi$", + ylabel=r"CDF($\phi$)", + ) + ax.plot(pgrid, cdf[:, 2], **kw) + + return fig + + def test_exact_sampling_plot(self, sampler): + """Plot sampling.""" + samples = sampler.rvs(size=int(1e3), random_state=3) + + fig = plt.figure(figsize=(15, 4)) + + ax = fig.add_subplot( + 131, + title=r"$m(\leq r) / m_{tot}$", + xlabel="r", + ylabel=r"$m(\leq r) / m_{tot}$", + ) + ax.hist(samples.r.value[samples.r < 5e3], log=True, bins=50, density=True) + + ax = fig.add_subplot( + 132, + title=r"CDF($\theta$)", + xlabel=r"$\theta$", + ylabel=r"CDF($\theta$)", + ) + ax.hist(samples.theta.value, bins=50, density=True) + + ax = fig.add_subplot(133, title=r"CDF($\phi$)", xlabel=r"$\phi$", ylabel=r"CDF($\phi$)") + ax.hist(samples.phi.value, bins=50) + + return fig diff --git a/sample_scf/exact/tests/test_exact.py b/sample_scf/exact/tests/test_exact.py new file mode 100644 index 0000000..95b9b10 --- /dev/null +++ b/sample_scf/exact/tests/test_exact.py @@ -0,0 +1,554 @@ +# -*- coding: utf-8 -*- + +"""Tests for :mod:`sample_scf.exact`.""" + + +############################################################################## +# IMPORTS + +# THIRD PARTY +import astropy.units as u +import matplotlib.pyplot as plt +import pytest +from astropy.utils.misc import NumpyRNGContext +from numpy import allclose, atleast_1d, atleast_2d, concatenate, geomspace, isclose, linspace, pi +from numpy import random +from numpy.testing import assert_allclose + +# LOCAL +from .common import phi_distributionTestBase, r_distributionTestBase, theta_distributionTestBase +from .test_base import SCFSamplerTestBase +from sample_scf import conftest +from sample_scf.base_univariate import _calculate_Qls +from sample_scf.exact import ExactSCFSampler +from sample_scf.exact.azimuth import exact_phi_distribution +from sample_scf.exact.inclination import exact_theta_distribution +from sample_scf.exact.radial import exact_r_distribution +from sample_scf.representation import r_of_zeta, x_of_theta + +############################################################################## +# PARAMETERS + +rgrid = concatenate(([0], geomspace(1e-1, 1e3, 29))) # same shape as ↓ +tgrid = linspace(-pi / 2, pi / 2, 30) +pgrid = linspace(0, 2 * pi, 30) + + +############################################################################## +# CODE +############################################################################## + + +class Test_SCFSampler(SCFSamplerTestBase): + """Test :class:`sample_scf.exact.SCFSampler`.""" + + def setup_class(self): + super().setup_class(self) + + # sampler initialization + self.cls = ExactSCFSampler + self.cls_args = () + self.cls_kwargs = {} + self.cls_pot_kw = conftest.cls_pot_kw + + # TODO! make sure these are right! + self.expected_rvs = { + 0: dict(r=2.85831468, theta=1.473013568997 * u.rad, phi=4.49366731864 * u.rad), + 1: dict(r=2.85831468, theta=1.473013568997 * u.rad, phi=4.49366731864 * u.rad), + 2: dict( + r=[59.156720319468995, 2.8424809956410684, 71.71466505619023, 5.471148006577435], + theta=[0.365179487932, 1.476190768288, 0.3320725403573, 1.126711132015] * u.rad, + phi=[4.383959499105, 1.3577303436664, 6.134113310024, 0.039145847961457] * u.rad, + ), + } + + # =============================================================== + # Method Tests + + def test_init(self, potentials): + kw = {**self.cls_kwargs, **self.cls_pot_kw.get(potentials, {})} + instance = self.cls(potentials, *self.cls_args, **kw) + + assert isinstance(instance.r_distribution, exact_r_distribution) + assert isinstance(instance.theta_distribution, exact_theta_distribution) + assert isinstance(instance.phi_distribution, exact_phi_distribution) + + def test_rvs(self, sampler): + """Test Random Variates Sampler.""" + + # TODO! make sure these are correct + @pytest.mark.parametrize( + "r, theta, phi, expected", + [ + (0.0, 0.0, 0.0, [0, 0.5, 0]), + (1.0, 0.0, 0.0, [0.25, 0.5, 0]), + # ([0.0, 1.0], [0.0, 0.0], [0.0, 0.0], [[0, 0.5, 0], [0.25, 0.5, 0]]), + ], + ) + def test_cdf(self, sampler, r, theta, phi, expected): + """Test :meth:`sample_scf.base_multivariate.SCFSamplerBase.cdf`.""" + assert allclose(sampler.cdf(r, theta, phi), expected, atol=1e-16) + + # =============================================================== + # Plot Tests + + def test_exact_cdf_plot(self, sampler): + """Plot cdf.""" + kw = dict(marker="o", ms=5, c="k", zorder=5, label="CDF") + cdf = sampler.cdf(rgrid, tgrid, pgrid) + + fig = plt.figure(figsize=(15, 3)) + + # r + ax = fig.add_subplot( + 131, + title=r"$m(\leq r) / m_{tot}$", + xlabel="r", + ylabel=r"$m(\leq r) / m_{tot}$", + ) + ax.semilogx(rgrid, cdf[:, 0], **kw) + + # theta + ax = fig.add_subplot( + 132, + title=r"CDF($\theta$)", + xlabel=r"$\theta$", + ylabel=r"CDF($\theta$)", + ) + ax.plot(tgrid, cdf[:, 1], **kw) + + # phi + ax = fig.add_subplot( + 133, + title=r"CDF($\phi$)", + xlabel=r"$\phi$", + ylabel=r"CDF($\phi$)", + ) + ax.plot(pgrid, cdf[:, 2], **kw) + + return fig + + def test_exact_sampling_plot(self, sampler): + """Plot sampling.""" + samples = sampler.rvs(size=int(1e3), random_state=3) + + fig = plt.figure(figsize=(15, 4)) + + ax = fig.add_subplot( + 131, + title=r"$m(\leq r) / m_{tot}$", + xlabel="r", + ylabel=r"$m(\leq r) / m_{tot}$", + ) + ax.hist(samples.r.value[samples.r < 5e3], log=True, bins=50, density=True) + + ax = fig.add_subplot( + 132, + title=r"CDF($\theta$)", + xlabel=r"$\theta$", + ylabel=r"CDF($\theta$)", + ) + ax.hist(samples.theta.value, bins=50, density=True) + + ax = fig.add_subplot(133, title=r"CDF($\phi$)", xlabel=r"$\phi$", ylabel=r"CDF($\phi$)") + ax.hist(samples.phi.value, bins=50) + + return fig + + +# ============================================================================ + + +class Test_r_distribution(r_distributionTestBase): + """Test :class:`sample_scf.exact.r_distribution`""" + + def setup_class(self): + super().setup_class(self) + + # sampler initialization + self.cls = exact_r_distribution + self.cls_args = () + self.cls_kwargs = {} + self.cls_pot_kw = conftest.cls_pot_kw + + # time-scale tests + self.cdf_time_scale = 1e-2 # milliseconds + self.rvs_time_scale = 1e-2 # milliseconds + + @pytest.fixture() + def sampler(self, potentials): + """Set up r, theta, or phi sampler.""" + kw = {**self.cls_kwargs, **self.cls_pot_kw.get(potentials, {})} + sampler = self.cls(potentials, *self.cls_args, **kw) + + return sampler + + # =============================================================== + # Method Tests + + @pytest.mark.skip("TODO!") + def test_init(self): + assert False + # test if mgrid is SCFPotential + + # TODO! use hypothesis + @pytest.mark.parametrize("r", random.default_rng(0).uniform(0, 1e4, 10)) + def test__cdf(self, sampler, r): + """Test :meth:`sample_scf.exact.r_distribution._cdf`.""" + super().test__cdf(sampler, r) + + # expected + mass = atleast_1d(sampler._potential._mass(r)) / sampler._mtot + assert_allclose(sampler._cdf(r), mass) + + @pytest.mark.parametrize( + "size, random, expected", + [ + (None, 0, 2.85831468026), + (1, 2, 1.9437661234293), + ((3, 1), 4, [59.156720319468, 2.8424809956410, 71.71466505619]), + ((3, 1), None, [59.156720319468, 2.8424809956410, 71.71466505619]), + ], + ) + def test_rvs(self, sampler, size, random, expected): + """Test :meth:`sample_scf.exact.r_distribution.rvs`.""" + super().test_rvs(sampler, size, random, expected) + + # =============================================================== + # Time Scaling Tests + + # TODO! generalize for subclasses + @pytest.mark.parametrize("size", [1, 10, 100, 1000]) # rm 1e4 + def test_rvs_time_scaling(self, sampler, size): + """Test that the time scales as X * size""" + super().test_rvs_time_scaling(sampler, size) + + # =============================================================== + # Image Tests + + @pytest.mark.mpl_image_compare( + baseline_dir="baseline_images", + # hash_library="baseline_images/path_to_file.json", # TODO! + ) + def test_exact_r_cdf_plot(self, sampler): + fig = plt.figure(figsize=(10, 3)) + + ax = fig.add_subplot( + 111, + title=r"$m(\leq r) / m_{tot}$", + xlabel="r", + ylabel=r"$m(\leq r) / m_{tot}$", + ) + kw = dict(marker="o", ms=5, c="k", zorder=5, label="CDF") + ax.semilogx(rgrid, sampler.cdf(rgrid), **kw) + ax.axvline(0.0, c="tab:blue") + ax.axhline(sampler.cdf(0.0), c="tab:blue", label="r=0") + ax.axvline(1.0, c="tab:green") + ax.axhline(sampler.cdf(1.0), c="tab:green", label="r=1") + ax.axvline(1e2, c="tab:red") + ax.axhline(sampler.cdf(1e2), c="tab:red", label="r=100") + + ax.set_xlim((1e-1, None)) + ax.legend(loc="lower right") + + fig.tight_layout() + return fig + + @pytest.mark.mpl_image_compare( + baseline_dir="baseline_images", + # hash_library="baseline_images/path_to_file.json", + ) + def test_exact_r_sampling_plot(self, sampler): + """Test sampling.""" + with NumpyRNGContext(0): # control the random numbers + sample = sampler.rvs(size=int(1e3)) + sample = sample[sample < 1e4] + + theory = self.theory[sampler._potential].sample(n=int(1e6)).r() + theory = theory[theory < 1e4 * u.kpc] + + fig = plt.figure(figsize=(10, 3)) + ax = fig.add_subplot(121, title="SCF vs theory sampling", xlabel="r", ylabel="frequency") + _, bins, *_ = ax.hist(sample, bins=30, log=True, alpha=0.5, label="SCF sample") + # Comparing to expected + ax.hist( + theory.to_value(u.kpc), + bins=bins, + log=True, + alpha=0.5, + label="Hernquist theoretical", + ) + ax.legend() + fig.tight_layout() + + return fig + + +# ---------------------------------------------------------------------------- + + +class Test_theta_distribution(theta_distributionTestBase): + """Test :class:`sample_scf.exact.theta_distribution`.""" + + def setup_class(self): + super().setup_class(self) + + self.cls = exact_theta_distribution + + self.cdf_time_scale = 1e-3 + self.rvs_time_scale = 7e-2 + + # =============================================================== + # Method Tests + + # TODO! use hypothesis + + @pytest.mark.parametrize( + "x, r", + [ + *zip( + random.default_rng(1).uniform(-1, 1, 10), + r_of_zeta(random.default_rng(1).uniform(-1, 1, 10)), + ), + ], + ) + def test__cdf(self, sampler, x, r): + """Test :meth:`sample_scf.exact.theta_distribution._cdf`.""" + Qls = atleast_2d(_calculate_Qls(sampler._potential, r)) + + # basically a test it's Hernquist, only the first term matters + if allclose(Qls[:, 1:], 0.0): + assert_allclose(sampler._cdf(x, r=r), 0.5 * (x + 1.0)) + + else: + assert False + + # l = 0 + # term0 = 0.5 * (x + 1.0) # (T,) + # # l = 1+ : non-symmetry + # factor = 1.0 / (2.0 * Qls[:, 0]) # (R,) + # term1p = sum( + # (Qls[None, :, 1:] * difPls(x, self._lmax - 1).T[:, None, :]).T, + # axis=0, + # ) + # cdf = term0[None, :] + nan_to_num(factor[:, None] * term1p) # (R, T) + + # assert_allclose(sampler._cdf(x, r=r), cdf) + + @pytest.mark.parametrize("r", r_of_zeta(random.default_rng(0).uniform(-1, 1, 10))) + def test__cdf_edge(self, sampler, r): + """Test :meth:`sample_scf.exact.r_distribution._cdf`.""" + assert isclose(sampler._cdf(-1, r=r), 0.0, atol=1e-16) + assert isclose(sampler._cdf(1, r=r), 1.0, atol=1e-16) + + @pytest.mark.parametrize( + "theta, r", + [ + *zip( + random.default_rng(0).uniform(-pi / 2, pi / 2, 10), + random.default_rng(1).uniform(0, 1e4, 10), + ), + ], + ) + def test_cdf(self, sampler, theta, r): + """Test :meth:`sample_scf.exact.theta_distribution.cdf`.""" + self.test__cdf(sampler, x_of_theta(theta), r) + + @pytest.mark.skip("TODO!") + def test__rvs(self): + """Test :meth:`sample_scf.exact.theta_distribution._rvs`.""" + assert False + + @pytest.mark.skip("TODO!") + def test_rvs(self): + """Test :meth:`sample_scf.exact.theta_distribution.rvs`.""" + assert False + + # =============================================================== + # Time Scaling Tests + + # TODO! generalize for subclasses + @pytest.mark.parametrize("size", [1, 10, 100, 1000]) # rm 1e4 + def test_rvs_time_scaling(self, sampler, size): + """Test that the time scales as X * size""" + super().test_rvs_time_scaling(sampler, size) + + # =============================================================== + # Image Tests + + @pytest.mark.mpl_image_compare( + baseline_dir="baseline_images", + # hash_library="baseline_images/path_to_file.json", # TODO! + ) + def test_exact_theta_cdf_plot(self, sampler): + fig = plt.figure(figsize=(10, 3)) + + # plot 1 + ax = fig.add_subplot( + 121, + title=r"CDF($\theta$)", + xlabel=r"$\theta$", + ylabel=r"CDF($\theta$)", + ) + kw = dict(marker="o", ms=5, c="k", zorder=5, label="CDF") + ax.plot(tgrid, sampler.cdf(tgrid, r=10), **kw) + ax.axvline(-pi / 2, c="tab:blue") + ax.axhline(sampler.cdf(-pi / 2, r=10), c="tab:blue", label=r"$\theta=-\frac{\pi}{2}$") + ax.axvline(0, c="tab:green") + ax.axhline(sampler.cdf(0, r=10), c="tab:green", label=r"$\theta=0$") + ax.axvline(pi / 2, c="tab:red") + ax.axhline(sampler.cdf(pi / 2, r=10), c="tab:red", label=r"$\theta=\frac{\pi}{2}$") + ax.legend(loc="lower right") + + # plot 2 + ax = fig.add_subplot( + 122, + title=r"CDF($x$)", + xlabel=r"x$", + ylabel=r"CDF($x$)", + ) + ax.plot(x_of_theta(tgrid), sampler.cdf(tgrid, r=10), **kw) + ax.axvline(x_of_theta(-1), c="tab:blue") + ax.axhline(sampler.cdf(-1, r=10), c="tab:blue", label=r"$\theta=-\frac{\pi}{2}$") + ax.axvline(x_of_theta(0), c="tab:green") + ax.axhline(sampler.cdf(0, r=10), c="tab:green", label=r"$\theta=0$") + ax.axvline(x_of_theta(1), c="tab:red") + ax.axhline(sampler.cdf(1, r=10), c="tab:red", label=r"$\theta=\frac{\pi}{2}$") + ax.legend(loc="upper left") + + fig.tight_layout() + return fig + + @pytest.mark.mpl_image_compare( + baseline_dir="baseline_images", + # hash_library="baseline_images/path_to_file.json", + ) + def test_exact_theta_sampling_plot(self, sampler): + """Test sampling.""" + with NumpyRNGContext(0): # control the random numbers + sample = sampler.rvs(size=int(1e3), r=10) + sample = sample[sample < 1e4] + + theory = self.theory[sampler._potential].sample(n=int(1e6)).theta() + theory -= pi / 2 * u.rad + + fig = plt.figure(figsize=(10, 3)) + ax = fig.add_subplot( + 121, + title="SCF vs theory sampling", + xlabel=r"$\theta$", + ylabel="frequency", + ) + _, bins, *_ = ax.hist(sample, bins=30, log=True, alpha=0.5, label="SCF sample") + # Comparing to expected + ax.hist( + theory.to_value(u.rad), + bins=bins, + log=True, + alpha=0.5, + label="Hernquist theoretical", + ) + ax.legend() + fig.tight_layout() + + return fig + + +############################################################################### + + +class Test_phi_distribution(phi_distributionTestBase): + """Test :class:`sample_scf.exact.phi_distribution`.""" + + def setup_class(self): + super().setup_class(self) + + self.cls = exact_phi_distribution + + self.cdf_time_scale = 3e-3 + self.rvs_time_scale = 3e-3 + + # =============================================================== + # Method Tests + + @pytest.mark.skip("TODO!") + def test__cdf(self): + """Test :meth:`sample_scf.exactolated.phi_distribution._cdf`.""" + assert False + + @pytest.mark.skip("TODO!") + def test_cdf(self): + """Test :meth:`sample_scf.exactolated.phi_distribution.cdf`.""" + assert False + + @pytest.mark.skip("TODO!") + def test__rvs(self): + """Test :meth:`sample_scf.exactolated.phi_distribution._rvs`.""" + assert False + + @pytest.mark.skip("TODO!") + def test_rvs(self): + """Test :meth:`sample_scf.exactolated.phi_distribution.rvs`.""" + assert False + + # =============================================================== + # Image Tests + + @pytest.mark.mpl_image_compare( + baseline_dir="baseline_images", + # hash_library="baseline_images/path_to_file.json", # TODO! + ) + def test_exact_phi_cdf_plot(self, sampler): + fig = plt.figure(figsize=(5, 3)) + + ax = fig.add_subplot( + 111, + title=r"CDF($\phi$)", + xlabel=r"$\phi$", + ylabel=r"CDF($\phi$)", + ) + kw = dict(marker="o", ms=5, c="k", zorder=5, label="CDF") + ax.plot(pgrid, sampler.cdf(pgrid, r=10, theta=pi / 6), **kw) + ax.axvline(0, c="tab:blue") + ax.axhline(sampler.cdf(0, r=10, theta=pi / 6), c="tab:blue", label=r"$\phi=0$") + ax.axvline(pi, c="tab:green") + ax.axhline(sampler.cdf(pi, r=10, theta=pi / 6), c="tab:green", label=r"$\phi=\pi$") + ax.axvline(2 * pi, c="tab:red") + ax.axhline(sampler.cdf(2 * pi, r=10, theta=pi / 6), c="tab:red", label=r"$\phi=2\pi$") + ax.legend(loc="lower right") + + fig.tight_layout() + return fig + + @pytest.mark.mpl_image_compare( + baseline_dir="baseline_images", + # hash_library="baseline_images/path_to_file.json", + ) + def test_exact_phi_sampling_plot(self, sampler): + """Test sampling.""" + with NumpyRNGContext(0): # control the random numbers + sample = sampler.rvs(size=int(1e3), r=10, theta=pi / 6) + sample = sample[sample < 1e4] + + theory = self.theory[sampler._potential].sample(n=int(1e3)).phi() + + fig = plt.figure(figsize=(10, 3)) + ax = fig.add_subplot( + 121, + title="SCF vs theory sampling", + xlabel=r"$\phi$", + ylabel="frequency", + ) + _, bins, *_ = ax.hist(sample, bins=30, log=True, alpha=0.5, label="SCF sample") + # Comparing to expected + ax.hist( + theory.to_value(u.rad), + bins=bins, + log=True, + alpha=0.5, + label="Hernquist theoretical", + ) + ax.legend() + fig.tight_layout() + + return fig diff --git a/sample_scf/exact/tests/test_utils.py b/sample_scf/exact/tests/test_utils.py new file mode 100644 index 0000000..2a5ac79 --- /dev/null +++ b/sample_scf/exact/tests/test_utils.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +"""Testing :mod:`scample_scf.interpolated.utils`.""" + + +############################################################################## +# IMPORTS + +# THIRD PARTY +import pytest +from numpy import inf, isclose, pi, zeros +from numpy.testing import assert_allclose + +# LOCAL +from sample_scf.base_univariate import _calculate_Qls, _calculate_Scs + +############################################################################## +# TESTS +############################################################################## + + +class Test_Qls: + """Test `sample_scf.base_univariate.Qls`.""" + + # =============================================================== + # Usage Tests + + @pytest.mark.parametrize("r, expected", [(0, 1), (1, 0.01989437), (inf, 0)]) + def test_hernquist(self, hernquist_scf_potential, r, expected): + Qls = _calculate_Qls(hernquist_scf_potential, r=r) + # shape should be L (see setup_class) + assert len(Qls) == 6 + # only 1st index is non-zero + assert isclose(Qls[0], expected) + assert_allclose(Qls[1:], 0) + + @pytest.mark.skip("TODO!") + def test_nfw(self, nfw_scf_potential): + assert False + + +# ------------------------------------------------------------------- + + +class Test_phiScs: + + # =============================================================== + # Tests + + # @pytest.mark.skip("TODO!") + @pytest.mark.parametrize( + "r, theta, expected", + [ + # show it doesn't depend on theta + (0, -pi / 2, (zeros(5), zeros(5))), + (0, 0, (zeros(5), zeros(5))), # special case when x=0 is 0 + (0, pi / 6, (zeros(5), zeros(5))), + (0, pi / 2, (zeros(5), zeros(5))), + # nor on r + (1, -pi / 2, (zeros(5), zeros(5))), + (10, -pi / 4, (zeros(5), zeros(5))), + (100, pi / 6, (zeros(5), zeros(5))), + (1000, pi / 2, (zeros(5), zeros(5))), + # Legendre[n=0, l=0, z=z] = 1 is a special case + (1, 0, (zeros(5), zeros(5))), + (10, 0, (zeros(5), zeros(5))), + (100, 0, (zeros(5), zeros(5))), + (1000, 0, (zeros(5), zeros(5))), + ], + ) + def test_phiScs_hernquist(self, hernquist_scf_potential, r, theta, expected): + Rm, Sm = _calculate_Scs(hernquist_scf_potential, r, theta, warn=False) + assert Rm.shape == Sm.shape + assert Rm.shape == (1, 1, 6) + assert_allclose(Rm[0, 0, 1:], expected[0], atol=1e-16) + assert_allclose(Sm[0, 0, 1:], expected[1], atol=1e-16) + + if theta == 0 and r != 0: + assert Rm[0, 0, 0] != 0 + assert Sm[0, 0, 0] == 0 diff --git a/sample_scf/interpolated/__init__.py b/sample_scf/interpolated/__init__.py new file mode 100644 index 0000000..b294d23 --- /dev/null +++ b/sample_scf/interpolated/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Licensed under a 3-clause BSD style license - see LICENSE.rst + +# LOCAL +from .azimuth import interpolated_phi_distribution +from .core import InterpolatedSCFSampler +from .inclination import interpolated_theta_distribution +from .radial import interpolated_r_distribution + +__all__ = [ + "InterpolatedSCFSampler", + "interpolated_r_distribution", + "interpolated_theta_distribution", + "interpolated_phi_distribution", +] diff --git a/sample_scf/interpolated/azimuth.py b/sample_scf/interpolated/azimuth.py new file mode 100644 index 0000000..9a9b75d --- /dev/null +++ b/sample_scf/interpolated/azimuth.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- + +"""**DOCSTRING**. + +Description. + +""" + +############################################################################## +# IMPORTS + +from __future__ import annotations + +# STDLIB +import itertools +import warnings +from typing import Any, Optional, Tuple + +# THIRD PARTY +import astropy.units as u +from astropy.units import Quantity +from galpy.potential import SCFPotential +from numpy import arange, argsort, column_stack, cos, empty, float64, inf, linspace, meshgrid +from numpy import nan_to_num, pi, random, sin, sum +from numpy.typing import ArrayLike +from scipy.interpolate import RegularGridInterpolator, splev, splrep + +# LOCAL +from .inclination import interpolated_theta_distribution +from .radial import interpolated_r_distribution +from sample_scf._typing import NDArrayF, RandomLike +from sample_scf.base_univariate import _grid_Scs, phi_distribution_base +from sample_scf.representation import x_of_theta, zeta_of_r + +__all__ = ["interpolated_phi_distribution"] + + +############################################################################## +# PARAMETERS + +_phi_filter = dict(category=RuntimeWarning, message="(^invalid value)|(^overflow encountered)") + +############################################################################## +# CODE +############################################################################## + + +class interpolated_phi_distribution(phi_distribution_base): + """SCF phi sampler. + + Parameters + ---------- + potential : `galpy.potential.SCFPotential` + radii : ndarray[float] + thetas : ndarray[float] + phis : ndarray[float] + intrp_step : float, optional + **kw + Passed to `scipy.stats.rv_continuous` + "a", "b" are set to [0, 2 pi] + """ + + def __init__( + self, + potential: SCFPotential, + radii: Quantity, + thetas: Quantity, + phis: Quantity, + nintrp: float = 1e3, + **kw: Any, + ) -> None: + rhoTilde = kw.pop("rhoTilde", None) # must be same sort order as + super().__init__(potential, **kw) # allowed range of r + + self._phi_interpolant = linspace(0, 2 * pi, int(nintrp)) << u.rad + self._ninterpolant = len(self._phi_interpolant) + self._q_interpolant = qarr = linspace(0, 1, self._ninterpolant) + + # ------- + # build CDF + + radii, zetas = interpolated_r_distribution.order_radii(self, radii) # (R,) + thetas, xs = interpolated_theta_distribution.order_thetas(thetas) # (T,) + phis = interpolated_phi_distribution.order_phis(phis) # (P,) + self._phis = phis + + lR, lT, _ = len(radii), len(thetas), len(phis) + Phis = phis.to_value(u.rad)[None, None, :, None] # ({R}, {T}, P, {L}) + + # get Sc, Ss. We have defaults from above. + if rhoTilde is None: + rhoTilde = self.calculate_rhoTilde(radii) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", **_phi_filter) + Sc, Ss = _grid_Scs( + radii, thetas, rhoTilde=rhoTilde, Acos=potential._Acos, Asin=potential._Asin + ) # (R, T, L) + self._Scms, self._Ssms = Sc, Ss + + # l = 0 : spherical symmetry + term0 = Phis[..., 0] / (2 * pi) # (1, 1, P) + # l = 1+ : non-symmetry + with warnings.catch_warnings(): # ignore true_divide RuntimeWarnings + warnings.simplefilter("ignore") + factor = 1.0 / Sc[:, :, :1] # R0 (R, T, 1) + + ms = arange(1, self._lmax + 1)[None, None, None, :] # ({R}, {T}, {P}, L) + term1p = sum( + ((Sc[:, :, None, 1:] * sin(ms * Phis)) + (Ss[:, :, None, 1:] * (1 - cos(ms * Phis)))) + / (2 * pi * ms), + axis=-1, + ) + + # cdfs = term0 + nan_to_num(factor * term1p) # (R, T, P) + cdfs = term0 + nan_to_num(factor * term1p, posinf=inf, neginf=-inf) # (R, T, P) + # 'factor' can be inf and term1p 0 => inf * 0 = nan -> 0 + + # interpolate + # currently assumes a regular grid + self._spl_cdf = RegularGridInterpolator((zetas, xs, phis.to_value(u.rad)), cdfs) + + # ------- + # ppf + # might need cdf strategy to enforce "reality" + # cdfstrategy = get_strategy(cdf_strategy) + + # start by supersampling + Zetas, Xs, Phis = meshgrid(zetas, xs, self._phi_interpolant.value, indexing="ij") + _cdfs = self._spl_cdf((Zetas.ravel(), Xs.ravel(), Phis.ravel())) + _cdfs.shape = (lR, lT, len(self._phi_interpolant)) + + self._cdfs = _cdfs + # return + + # build reverse spline + # TODO! vectorize + ppfs = empty((lR, lT, self._ninterpolant), dtype=float64) + for (i, j) in itertools.product(*map(range, ppfs.shape[:2])): + try: + spl = splrep(_cdfs[i, j, :], self._phi_interpolant.value, s=0) + except ValueError: # CDF is non-real + # STDLIB + import pdb + + pdb.set_trace() + raise + + ppfs[i, j, :] = splev(qarr, spl, ext=0) + # interpolate + self._spl_ppf = RegularGridInterpolator((zetas, xs, qarr), ppfs, bounds_error=False) + + @staticmethod + def order_phis(phis: Quantity) -> Tuple[Quantity]: + """Return ordered phis.""" + psort = argsort(phis) + phis = phis[psort] + return phis + + # --------------------------------------------------------------- + + def _cdf(self, phi: ArrayLike, *args: Any, zeta: ArrayLike, x: ArrayLike) -> NDArrayF: + cdf: NDArrayF = self._spl_cdf((zeta, x, phi)) + return cdf + + def cdf(self, phi: Quantity, *, r: Quantity, theta: Quantity) -> NDArrayF: + # TODO! make sure r, theta in right domain + cdf = self._cdf( + phi, + zeta=zeta_of_r(r, self._radial_scale_factor), + x=x_of_theta(theta << u.rad), + ) + return cdf + + def _ppf(self, q: ArrayLike, *args: Any, r: ArrayLike, theta: NDArrayF, **kw: Any) -> NDArrayF: + zeta = zeta_of_r(r, self._radial_scale_factor) + x = x_of_theta(theta << u.rad) + ppf: NDArrayF = self._spl_ppf(column_stack((zeta, x, q))) + return ppf + + def _rvs( + self, + r: NDArrayF, + theta: NDArrayF, + *args: Any, + random_state: random.RandomState, + size: Optional[int] = None, + ) -> NDArrayF: + # Use inverse cdf algorithm for RV generation. + U = random_state.uniform(size=size) + Y = self._ppf(U, *args, r=r, theta=theta) + return Y + + def rvs( # type: ignore + self, + r: Quantity, + theta: Quantity, + *, + size: Optional[int] = None, + random_state: RandomLike = None, + ) -> NDArrayF: + """Random variate sampler. + + Parameters + ---------- + r : Quantity['length', float] + theta : Quantity['angle', float] + size : int or None (optional, keyword-only) + Size of random variates to generate. + random_state : int, `~numpy.random.RandomState`, or None (optional, keyword-only) + If seed is None (or numpy.random), the `numpy.random.RandomState` + singleton is used. If seed is an int, a new RandomState instance is + used, seeded with seed. If seed is already a Generator or + RandomState instance then that instance is used. + + Returns + ------- + ndarray[float] + Shape 'size'. + """ + return super().rvs(r, theta, size=size, random_state=random_state) << u.rad diff --git a/sample_scf/interpolated/core.py b/sample_scf/interpolated/core.py new file mode 100644 index 0000000..9fb415b --- /dev/null +++ b/sample_scf/interpolated/core.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- + +"""**DOCSTRING**. + +Description. + +""" + +############################################################################## +# IMPORTS + +from __future__ import annotations + +# STDLIB +from typing import Any + +# THIRD PARTY +from astropy.units import Quantity +from galpy.potential import SCFPotential + +# LOCAL +from .azimuth import interpolated_phi_distribution +from .inclination import interpolated_theta_distribution +from .radial import interpolated_r_distribution +from sample_scf._typing import NDArrayF +from sample_scf.base_multivariate import SCFSamplerBase + +__all__ = ["InterpolatedSCFSampler"] + +############################################################################## +# CODE +############################################################################## + + +class InterpolatedSCFSampler(SCFSamplerBase): + r"""Interpolated SCF Sampler. + + Parameters + ---------- + pot : `~galpy.potential.SCFPotential` + radii : array-like[float] + The radial component of the interpolation grid. + thetas : array-like[float] + The inclination component of the interpolation grid. + :math:`\theta \in [-\pi/2, \pi/2]`, from the South to North pole, so + :math:`\theta = 0` is the equator. + phis : array-like[float] + The azimuthal component of the interpolation grid. + :math:`phi \in [0, 2\pi)`. + + **kw: + passed to :class:`~sample_scf.interpolated.interpolated_r_distribution`, + :class:`~sample_scf.interpolated.interpolated_theta_distribution`, + :class:`~sample_scf.interpolated.interpolated_phi_distribution` + + Examples + -------- + For all examples we assume the following imports + + >>> import numpy as np + >>> from galpy import potential + + For the SCF Potential we will use the simple example of a Hernquist sphere. + + >>> Acos = np.zeros((20, 24, 24)) + >>> Acos[0, 0, 0] = 1 # Hernquist potential + >>> pot = potential.SCFPotential(Acos=Acos) + + Now we make the sampler, specifying the grid from which the interpolation + will be built. + + >>> radii = np.geomspace(1e-1, 1e3, 100) + >>> thetas = np.linspace(-np.pi / 2, np.pi / 2, 30) + >>> phis = np.linspace(0, 2 * np.pi, 30) + + >>> sampler = SCFSampler(pot, radii=radii, thetas=thetas, phis=phis) + + Now we can evaluate the CDF + + >>> sampler.cdf(10.0, np.pi/3, np.pi) + array([0.82666461, 0.9330127 , 0.5 ]) + + And draw samples + + >>> sampler.rvs(size=5, random_state=3) + + """ + + def __init__( + self, potential: SCFPotential, radii: Quantity, thetas: Quantity, phis: Quantity, **kw: Any + ) -> None: + super().__init__(potential, **kw) + + # ------------------- + # Radial + + # sampler + self._r_distribution = interpolated_r_distribution(potential, radii, **kw) + + # compute the r-dependent coefficient matrix. + rhoT = self.calculate_rhoTilde(self._radii) + + # ------------------- + # Thetas + + # sampler + self._theta_distribution = interpolated_theta_distribution( + potential, self._radii, thetas, rhoTilde=rhoT, **kw + ) + + # ------------------- + # Phis + + self._phi_distribution = interpolated_phi_distribution( + potential, self._radii, self._thetas, phis, rhoTilde=rhoT, **kw + ) + + @property + def _radii(self) -> Quantity: + return self._r_distribution._radii + + @property + def _zetas(self) -> Quantity: + return self._r_distribution._zetas + + @property + def _thetas(self) -> Quantity: + return self._theta_distribution._thetas + + @property + def _xs(self) -> Quantity: + return self._theta_distribution._xs + + @property + def _Qls(self) -> NDArrayF: + return self._theta_distribution._Qls + + @property + def _phis(self) -> Quantity: + self._phi_distribution._phis + + @property + def _Scms(self) -> NDArrayF: + return self._phi_distribution._Scms + + @property + def _Ssms(self) -> NDArrayF: + return self._phi_distribution._Ssms diff --git a/sample_scf/interpolated/inclination.py b/sample_scf/interpolated/inclination.py new file mode 100644 index 0000000..9076d40 --- /dev/null +++ b/sample_scf/interpolated/inclination.py @@ -0,0 +1,226 @@ +# -*- coding: utf-8 -*- + +"""**DOCSTRING**. + +Description. + +""" + +############################################################################## +# IMPORTS + +from __future__ import annotations + +# STDLIB +from typing import Any, Optional, Tuple, Union + +# THIRD PARTY +import astropy.units as u +from astropy.units import Quantity +from galpy.potential import SCFPotential +from numpy import argsort, array, linspace, pi +from numpy.random import Generator, RandomState +from numpy.typing import ArrayLike +from scipy.interpolate import RectBivariateSpline, splev, splrep + +# LOCAL +from .radial import interpolated_r_distribution +from sample_scf._typing import NDArrayF, RandomLike +from sample_scf.base_univariate import theta_distribution_base +from sample_scf.exact.inclination import exact_theta_distribution_base +from sample_scf.representation import x_of_theta, zeta_of_r + +__all__ = ["interpolated_theta_distribution"] + + +############################################################################## +# CODE +############################################################################## + + +class interpolated_theta_distribution(theta_distribution_base): + """ + Sample inclination coordinate from an SCF potential. + + Parameters + ---------- + pot : `~galpy.potential.SCFPotential` + radii : (R,) Quantity['angle', float] + thetas : (T, ) Quantity ['angle', float] + intrp_step : float, optional + Interpolation step. + **kw + Passed to `scipy.stats.rv_continuous` + "a", "b" are set to [0, pi] + """ + + def __init__( + self, + potential: SCFPotential, + radii: Quantity, + thetas: Quantity, + nintrp: float = 1e3, + **kw: Any, + ) -> None: + rhoTilde: NDArrayF = kw.pop("rhoTilde", None) + super().__init__(potential, **kw) # allowed range of theta + + self._theta_interpolant = linspace(0, pi, num=int(nintrp)) << u.rad + self._x_interpolant = x_of_theta(self._theta_interpolant) + self._q_interpolant = linspace(0, 1, len(self._theta_interpolant)) + + # Sorting + radii, zetas = interpolated_r_distribution.order_radii(self, radii) + thetas, xs = interpolated_theta_distribution.order_thetas(thetas) + self._thetas, self._xs = thetas, xs + + # ------- + # build CDF in shells + + Qls = self.calculate_Qls(radii, rhoTilde=rhoTilde) + # check it's the right shape (R, L) + if Qls.shape != (len(radii), self._lmax + 1): + raise ValueError(f"Qls must be shape ({len(radii)}, {self._lmax + 1})") + self._Qls: NDArrayF = Qls + + # calculate the CDFs exactly # TODO! cleanup + cdfs = exact_theta_distribution_base._cdf(self, xs, Qls) # (R, T) + + # ------- + # interpolate + # assumes a regular grid + + self._spl_cdf = RectBivariateSpline( # (R, T) + zetas, + xs, + cdfs, # (R, T) is anti-theta ordered + bbox=[-1, 1, -1, 1], # [min(zeta), max(zeta), min(x), max(x)] + kx=kw.get("kx", 2), + ky=kw.get("ky", 2), + s=kw.get("s", 0), + ) + + # ppf, one per r, supersampled + # TODO! see if can use this to avoid resplining + _cdfs = self._spl_cdf(zetas, self._x_interpolant[::-1], grid=True) + spls = ( # work through the (R, T) is anti-theta ordered + splrep(_cdfs[i, ::-1], self._theta_interpolant.value, s=0) + for i in range(_cdfs.shape[0]) + ) + ppfs = array([splev(self._q_interpolant, spl, ext=0) for spl in spls]) + + self._spl_ppf = RectBivariateSpline( + zetas, + self._q_interpolant, + ppfs, + bbox=[-1, 1, 0, 1], # [zetamin, zetamax, xmin, xmax] + kx=kw.get("kx", 3), + ky=kw.get("ky", 3), + s=kw.get("s", 0), + ) + + @staticmethod + def order_thetas(thetas: Quantity) -> Tuple[Quantity, NDArrayF]: + """Return ordered thetas and xs. + + Parameters + ---------- + thetas : (T,) Quantity['angle', float] + + Returns + ------- + thetas : (T,) Quantity['angle', float] + xs : (T,) ndarray[float] + """ + xs_unsorted = x_of_theta(thetas << u.rad) # (T,) + xsort = argsort(xs_unsorted) # opposite as theta sort + xs = xs_unsorted[xsort] + thetas = thetas[xsort] + return thetas, xs + + # --------------------------------------------------------------- + + def _cdf(self, x: ArrayLike, *args: Any, zeta: ArrayLike, **kw: Any) -> NDArrayF: + cdf: NDArrayF = self._spl_cdf(zeta, x, grid=False) + return cdf + + def cdf(self, theta: Quantity, r: ArrayLike) -> NDArrayF: + """Cumulative Distribution Function. + + Parameters + ---------- + theta : (T,) Quantity['angle'] + r : (R,) Quantity['length'] + + Returns + ------- + cdf : ndarray[float] + """ + x = x_of_theta(theta << u.rad) + zeta = zeta_of_r(r, scale_radius=self.radial_scale_factor) + cdf = self._cdf(x, zeta=zeta) + return cdf + + def _ppf(self, q: ArrayLike, *, r: ArrayLike, **kw: Any) -> NDArrayF: + """Percent-point function. + + Parameters + ---------- + q : float or (N,) array-like[float] + r : float or (N,) array-like[float] + + Returns + ------- + float or (N,) array-like[float] + Same shape as 'r', 'q'. + """ + zeta = zeta_of_r(r, scale_radius=self.radial_scale_factor) + ppf: NDArrayF = self._spl_ppf(zeta, q, grid=False) + return ppf + + def _rvs( + self, + r: Quantity, + *, + size: Optional[int] = None, + random_state: Union[RandomState, Generator], + # return_thetas: bool = True, # TODO! + ) -> NDArrayF: + """Random variate sampling. + + Parameters + ---------- + r : (R,) Quantity['length', float] + size : int or None (optional, keyword-only) + random_state : int or None (optional, keyword-only) + + Returns + ------- + (size,) array-like[float] + """ + # Use inverse cdf algorithm for RV generation. + U = random_state.uniform(size=size) + Y = self._ppf(U, r=r, grid=False) + return Y + + def rvs( # type: ignore + self, + r: Quantity, + *, + size: Optional[int] = None, + random_state: RandomLike = None, + ) -> Quantity: + """Random variate sampling. + + Parameters + ---------- + r : (R,) Quantity['length', float] + size : int or None (optional, keyword-only) + random_state : int or None (optional, keyword-only) + + Returns + ------- + (R, size) Quantity[float] + Shape 'size'. + """ + return super().rvs(r, size=size, random_state=random_state) << u.rad diff --git a/sample_scf/interpolated/radial.py b/sample_scf/interpolated/radial.py new file mode 100644 index 0000000..b5b42f3 --- /dev/null +++ b/sample_scf/interpolated/radial.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- + +"""Radial sampling.""" + +############################################################################## +# IMPORTS + +from __future__ import annotations + +# STDLIB +from typing import Any, Tuple + +# THIRD PARTY +import astropy.units as u +from astropy.units import Quantity +from galpy.potential import SCFPotential +from numpy import argsort, array, diff, inf, isnan, nanmax, nanmin, where +from numpy.typing import ArrayLike +from scipy.interpolate import InterpolatedUnivariateSpline as IUS + +# LOCAL +from sample_scf._typing import NDArrayF +from sample_scf.base_univariate import r_distribution_base +from sample_scf.representation import r_of_zeta, zeta_of_r + +__all__ = ["interpolated_r_distribution"] + + +############################################################################## +# CODE +############################################################################## + + +class interpolated_r_distribution(r_distribution_base): + """Sample radial coordinate from an SCF potential. + + The potential must have a convergent mass function. + + Parameters + ---------- + potential : `galpy.potential.SCFPotential` + radii : Quantity + Radii at which to interpolate. + **kw + Passed to `scipy.stats.rv_continuous` + "a", "b" are set to [0, inf] + """ + + _interp_in_zeta: bool + + def __init__(self, potential: SCFPotential, radii: Quantity, **kw: Any) -> None: + kw["a"], kw["b"] = 0, nanmax(radii) # allowed range of r + super().__init__(potential, **kw) + + # fraction of total mass grid + # work in zeta, not r, since it is more numerically stable + self._radii, self._zetas = self.order_radii(radii) + self._mgrid = self.calculate_cumulative_mass(self._radii) + + # make splines for fast calculation + self._spl_cdf = IUS(self._zetas, self._mgrid, ext="raise", bbox=[-1, 1], k=1) + self._spl_ppf = IUS(self._mgrid, self._zetas, ext="raise", bbox=[0, 1], k=1) + + def order_radii(self, radii: Quantity) -> Tuple[Quantity, NDArrayF]: + """Return ordered radii and zetas.""" + rsort = argsort(radii) # same as zeta ordering + radii = radii[rsort] + zeta = zeta_of_r(radii, scale_radius=self.radial_scale_factor) + return radii, zeta + + def calculate_cumulative_mass(self, radii: Quantity) -> NDArrayF: + """Calculate cumulative mass function (ie the cdf). + + Parameters + ---------- + radii : (R,) Quantity['length', float] + + Returns + ------- + (R,) ndarray[float] + """ + rgalpy = radii.to_value(u.kpc) / self.potential._ro + mgrid = array([self.potential._mass(x) for x in rgalpy]) # :( + # manual fixes for endpoints and normalization + ind = where(isnan(mgrid))[0] + mgrid[ind[radii[ind] == 0]] = 0 + mgrid = (mgrid - nanmin(mgrid)) / (nanmax(mgrid) - nanmin(mgrid)) # rescale + infind = ind[radii[ind] == inf].squeeze() + mgrid[infind] = 1 + if mgrid[infind - 1] == 1: # munge the rescaling TODO! do better + mgrid[infind - 1] -= min(1e-8, diff(mgrid[slice(infind - 2, infind)]) / 2) + + return mgrid + + # --------------------------------------------------------------- + + def cdf(self, radii: Quantity): # TODO! + return self._cdf(zeta_of_r(radii, self.radial_scale_factor)) + + def _cdf(self, zeta: NDArrayF, *args: Any, **kw: Any) -> NDArrayF: + cdf: NDArrayF = self._spl_cdf(zeta) + # (self._scfmass(zeta) - self._mi) / (self._mf - self._mi) + # TODO! is this normalization even necessary? + return cdf + + def _ppf(self, q: ArrayLike, *args: Any, **kw: Any) -> NDArrayF: + zeta = self._spl_ppf(q) + return r_of_zeta(zeta, self.radial_scale_factor) # TODO! not convert in private function diff --git a/sample_scf/interpolated/tests/__init__.py b/sample_scf/interpolated/tests/__init__.py new file mode 100644 index 0000000..8807810 --- /dev/null +++ b/sample_scf/interpolated/tests/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +# Licensed under a 3-clause BSD style license - see LICENSE.rst +""" +This module contains package tests. +""" diff --git a/sample_scf/interpolated/tests/test_interpolated.py b/sample_scf/interpolated/tests/test_interpolated.py new file mode 100644 index 0000000..1539a44 --- /dev/null +++ b/sample_scf/interpolated/tests/test_interpolated.py @@ -0,0 +1,550 @@ +# -*- coding: utf-8 -*- + +"""Testing :mod:`scample_scf.interpolated`.""" + + +############################################################################## +# IMPORTS + +# THIRD PARTY +import astropy.units as u +import matplotlib.pyplot as plt +import pytest +from astropy.utils.misc import NumpyRNGContext +from numpy import allclose, concatenate, geomspace, inf, isclose, linspace, ndarray, pi, random +from numpy.testing import assert_allclose + +# LOCAL +from .common import phi_distributionTestBase, r_distributionTestBase, theta_distributionTestBase +from .test_base import BaseTest_rv_potential, SCFSamplerTestBase +from sample_scf.base_univariate import _calculate_Qls, _calculate_Scs +from sample_scf.interpolated import InterpolatedSCFSampler +from sample_scf.interpolated.azimuth import interpolated_phi_distribution +from sample_scf.interpolated.inclination import interpolated_theta_distribution +from sample_scf.interpolated.radial import interpolated_r_distribution +from sample_scf.representation import r_of_zeta, x_of_theta, zeta_of_r + +############################################################################## +# PARAMETERS + +rgrid = concatenate(([0], geomspace(1e-1, 1e3, 100), [inf])) +tgrid = linspace(-pi / 2, pi / 2, 30) +pgrid = linspace(0, 2 * pi, 30) + + +############################################################################## +# TESTS +############################################################################## + + +class Test_SCFSampler(SCFSamplerTestBase): + """Test :class:`sample_scf.interpolated.SCFSampler`.""" + + def setup_class(self): + super().setup_class(self) + + self.cls = InterpolatedSCFSampler + self.cls_args = (rgrid, tgrid, pgrid) + self.cls_kwargs = {} + self.cls_pot_kw = {} + + # TODO! make sure these are right! + self.expected_rvs = { + 0: dict(r=2.8473287899985, theta=1.473013568997 * u.rad, phi=3.4482969442579 * u.rad), + 1: dict(r=2.8473287899985, theta=1.473013568997 * u.rad, phi=3.4482969442579 * u.rad), + 2: dict( + r=[55.79997672576021, 2.831600636133138, 66.85343958872159, 5.435971037191061], + theta=[0.3651795356642, 1.476190768304, 0.3320725154563, 1.126711132070] * u.rad, + phi=[6.076027676095, 3.438361627636, 6.11155607905, 4.491321348792] * u.rad, + ), + } + + # =============================================================== + # Method Tests + + # TODO! make sure these are correct + @pytest.mark.parametrize( + "r, theta, phi, expected", + [ + (0, 0, 0, [0, 0.5, 0]), + (1, 0, 0, [0.2505, 0.5, 0]), + ([0, 1], [0, 0], [0, 0], [[0, 0.5, 0], [0.2505, 0.5, 0]]), + ], + ) + def test_cdf(self, sampler, r, theta, phi, expected): + """Test :meth:`sample_scf.base_multivariate.SCFSamplerBase.cdf`.""" + assert allclose(sampler.cdf(r, theta, phi), expected, atol=1e-16) + + # =============================================================== + # Plot Tests + + @pytest.mark.skip("TODO!") + def test_interp_cdf_plot(self): + assert False + + @pytest.mark.skip("TODO!") + def test_interp_sampling_plot(self): + assert False + + +############################################################################## + + +class InterpBaseTest_rv_potential(BaseTest_rv_potential): + def test_init(self, sampler): + """Test initialization.""" + potential = sampler._potential + + assert hasattr(sampler, "_spl_cdf") + assert hasattr(sampler, "_spl_ppf") + + # good + newsampler = self.cls(potential, *self.cls_args) + + # compare that the knots are the same when initializing a second time + # ie that the splines are stable + cdfk = sampler._spl_cdf.get_knots() + ncdfk = newsampler._spl_cdf.get_knots() + if isinstance(cdfk, ndarray): # 1D splines + assert_allclose(ncdfk, cdfk, atol=1e-16) + else: # 2D and 3D splines + for k, nk in zip(cdfk, ncdfk): + assert_allclose(k, nk, atol=1e-16) + + ppfk = sampler._spl_ppf.get_knots() + nppfk = newsampler._spl_ppf.get_knots() + if isinstance(ppfk, ndarray): # 1D splines + assert_allclose(nppfk, ppfk, atol=1e-16) + else: # 2D and 3D splines + for k, nk in zip(ppfk, nppfk): + assert_allclose(k, nk, atol=1e-16) + + # bad + with pytest.raises(TypeError, match="SCFPotential"): + self.cls(None, *self.cls_args) + + +# ---------------------------------------------------------------------------- + + +class Test_r_distribution(r_distributionTestBase, InterpBaseTest_rv_potential): + """Test :class:`sample_scf.sample_interp.interpolated_r_distribution`""" + + def setup_class(self): + super().setup_class(self) + + self.cls = interpolated_r_distribution + self.cls_args = (rgrid,) + self.cls_kwargs = {} + self.cls_pot_kw = {} + + self.cdf_time_scale = 6e-4 # milliseconds + self.rvs_time_scale = 2e-4 # milliseconds + + # =============================================================== + # Method Tests + + def test_init(self, sampler): + """Test initialization.""" + super().test_init(sampler) + + # TODO! test mgrid endpoints, cdf, and ppf + + # TODO! use hypothesis + @pytest.mark.parametrize("r", random.default_rng(0).uniform(0, 1e4, 10)) + def test__cdf(self, sampler, r): + """Test :meth:`sample_scf.interpolated.interpolated_r_distribution._cdf`.""" + super().test__cdf(sampler, r) + + # expected + assert_allclose(sampler._cdf(r), sampler._spl_cdf(zeta_of_r(r))) + + # TODO! use hypothesis + @pytest.mark.parametrize("q", random.default_rng(0).uniform(0, 1, 10)) + def test__ppf(self, sampler, q): + """Test :meth:`sample_scf.interpolated.interpolated_r_distribution._ppf`.""" + # expected + assert_allclose(sampler._ppf(q), r_of_zeta(sampler._spl_ppf(q))) + + # args and kwargs don't matter + assert_allclose(sampler._ppf(q), sampler._ppf(q, 10, test=14)) + + @pytest.mark.parametrize( + "size, random, expected", + [ + (None, 0, 2.84732879), + (1, 2, 1.938060987), + ((3, 1), 4, (55.79997672, 2.831600636, 66.85343958)), + ((3, 1), None, (55.79997672, 2.831600636, 66.85343958)), + ], + ) + def test_rvs(self, sampler, size, random, expected): + """Test :meth:`sample_scf.interpolated.interpolated_r_distribution.rvs`.""" + super().test_rvs(sampler, size, random, expected) + + # =============================================================== + # Image Tests + + @pytest.mark.mpl_image_compare( + baseline_dir="baseline_images", + # hash_library="baseline_images/path_to_file.json", # TODO! + ) + def test_interp_r_cdf_plot(self, sampler): + fig = plt.figure(figsize=(10, 3)) + + ax = fig.add_subplot( + 121, + title=r"$m(\leq r) / m_{tot}$", + xlabel="r", + ylabel=r"$m(\leq r) / m_{tot}$", + ) + kw = dict(marker="o", ms=5, c="k", zorder=5, label="CDF") + ax.semilogx(rgrid, sampler.cdf(rgrid), **kw) + ax.axvline(0, c="tab:blue") + ax.axhline(sampler.cdf(0), c="tab:blue", label="r=0") + ax.axvline(1, c="tab:green") + ax.axhline(sampler.cdf(1), c="tab:green", label="r=1") + ax.axvline(1e2, c="tab:red") + ax.axhline(sampler.cdf(1e2), c="tab:red", label="r=100") + + ax.set_xlim((1e-1, None)) + ax.legend(loc="lower right") + + ax = fig.add_subplot( + 122, + title=r"$m(\leq \zeta) / m_{tot}$", + xlabel=r"$\zeta$", + ylabel=r"$m(\leq \zeta) / m_{tot}$", + ) + ax.plot(zeta_of_r(rgrid), sampler.cdf(rgrid), **kw) + ax.axvline(zeta_of_r(0), c="tab:blue") + ax.axhline(sampler.cdf(0), c="tab:blue", label="r=0") + ax.axvline(zeta_of_r(1), c="tab:green") + ax.axhline(sampler.cdf(1), c="tab:green", label="r=1") + ax.axvline(zeta_of_r(1e2), c="tab:red") + ax.axhline(sampler.cdf(1e2), c="tab:red", label="r=100") + ax.legend(loc="upper left") + + fig.tight_layout() + return fig + + @pytest.mark.mpl_image_compare( + baseline_dir="baseline_images", + # hash_library="baseline_images/path_to_file.json", + ) + def test_interp_r_sampling_plot(self, sampler): + """Test sampling.""" + with NumpyRNGContext(0): # control the random numbers + sample = sampler.rvs(size=int(1e6)) + sample = sample[sample < 1e4] + + theory = self.theory[sampler._potential].sample(n=int(1e6)).r() + theory = theory[theory < 1e4 * u.kpc] + + fig = plt.figure(figsize=(10, 3)) + ax = fig.add_subplot(121, title="SCF vs theory sampling", xlabel="r", ylabel="frequency") + _, bins, *_ = ax.hist( + sample, bins=30, log=True, alpha=0.5, label="SCF sample", c="tab:blue" + ) + # Comparing to expected + ax.hist( + theory.to_value(u.kpc), + bins=bins, + log=True, + edgecolor="black", + linewidth=1.2, + fc=(1, 0, 0, 0.0), + label="Theoretical", + ) + ax.legend() + fig.tight_layout() + + return fig + + +# ---------------------------------------------------------------------------- + + +class Test_theta_distribution(theta_distributionTestBase, InterpBaseTest_rv_potential): + """Test :class:`sample_scf.interpolated.interpolated_theta_distribution`.""" + + def setup_class(self): + super().setup_class(self) + + self.cls = interpolated_theta_distribution + self.cls_args = (rgrid, tgrid) + + self.cdf_time_scale = 3e-4 + self.rvs_time_scale = 6e-4 + + # =============================================================== + # Method Tests + + def test_init(self, sampler): + """Test initialization.""" + super().test_init(sampler) + + # a shape mismatch + Qls = _calculate_Qls(sampler._potential, rgrid[1:-1]) + with pytest.raises(ValueError, match="Qls must be shape"): + sampler.__class__(sampler._potential, rgrid, tgrid, Qls=Qls) + + # TODO! use hypothesis + @pytest.mark.parametrize( + "x, zeta", + [ + *zip( + random.default_rng(0).uniform(-1, 1, 10), + random.default_rng(1).uniform(-1, 1, 10), + ), + ], + ) + def test__cdf(self, sampler, x, zeta): + """Test :meth:`sample_scf.interpolated.interpolated_theta_distribution._cdf`.""" + # expected + assert_allclose(sampler._cdf(x, zeta=zeta), sampler._spl_cdf(zeta, x, grid=False)) + + # args and kwargs don't matter + assert_allclose(sampler._cdf(x, zeta=zeta), sampler._cdf(x, 10, zeta=zeta, test=14)) + + @pytest.mark.parametrize("zeta", random.default_rng(0).uniform(-1, 1, 10)) + def test__cdf_edge(self, sampler, zeta): + """Test :meth:`sample_scf.interpolated.interpolated_theta_distribution._cdf`.""" + assert isclose(sampler._cdf(-1, zeta=zeta), 0.0, atol=1e-16) + assert isclose(sampler._cdf(1, zeta=zeta), 1.0, atol=1e-16) + + @pytest.mark.parametrize( + "theta, r", + [ + *zip( + random.default_rng(0).uniform(-pi / 2, pi / 2, 10), + random.default_rng(1).uniform(0, 1e4, 10), + ), + ], + ) + def test_cdf(self, sampler, theta, r): + """Test :meth:`sample_scf.interpolated.theta_distribution.cdf`.""" + assert_allclose( + sampler.cdf(theta, r), + sampler._spl_cdf( + zeta_of_r(r), + x_of_theta(u.Quantity(theta, u.rad)), + grid=False, + ), + ) + + @pytest.mark.skip("TODO!") + def test__ppf(self): + """Test :meth:`sample_scf.interpolated.theta_distribution._ppf`.""" + assert False + + @pytest.mark.skip("TODO!") + def test__rvs(self): + """Test :meth:`sample_scf.interpolated.theta_distribution._rvs`.""" + assert False + + @pytest.mark.skip("TODO!") + def test_rvs(self): + """Test :meth:`sample_scf.interpolated.theta_distribution.rvs`.""" + assert False + + # =============================================================== + # Image Tests + + @pytest.mark.mpl_image_compare( + baseline_dir="baseline_images", + # hash_library="baseline_images/path_to_file.json", # TODO! + ) + def test_interp_theta_cdf_plot(self, sampler): + fig = plt.figure(figsize=(10, 3)) + + ax = fig.add_subplot( + 121, + title=r"CDF($\theta$)", + xlabel=r"$\theta$", + ylabel=r"CDF($\theta$)", + ) + kw = dict(marker="o", ms=5, c="k", zorder=5, label="CDF") + ax.plot(tgrid, sampler.cdf(tgrid, r=10), **kw) + ax.axvline(-pi / 2, c="tab:blue") + ax.axhline(sampler.cdf(-pi / 2, r=10), c="tab:blue", label=r"$\theta=-\frac{\pi}{2}$") + ax.axvline(0, c="tab:green") + ax.axhline(sampler.cdf(0, r=10), c="tab:green", label=r"$\theta=0$") + ax.axvline(pi / 2, c="tab:red") + ax.axhline(sampler.cdf(pi / 2, r=10), c="tab:red", label=r"$\theta=\frac{\pi}{2}$") + ax.legend(loc="lower right") + + ax = fig.add_subplot( + 122, + title=r"CDF($x$)", + xlabel=r"x$", + ylabel=r"CDF($x$)", + ) + ax.plot(x_of_theta(tgrid), sampler.cdf(tgrid, r=10), **kw) + ax.axvline(x_of_theta(-1), c="tab:blue") + ax.axhline(sampler.cdf(-1, r=10), c="tab:blue", label=r"$\theta=-\frac{\pi}{2}$") + ax.axvline(x_of_theta(0), c="tab:green") + ax.axhline(sampler.cdf(0, r=10), c="tab:green", label=r"$\theta=0$") + ax.axvline(x_of_theta(1), c="tab:red") + ax.axhline(sampler.cdf(1, r=10), c="tab:red", label=r"$\theta=\frac{\pi}{2}$") + ax.legend(loc="upper left") + + fig.tight_layout() + return fig + + @pytest.mark.mpl_image_compare( + baseline_dir="baseline_images", + # hash_library="baseline_images/path_to_file.json", + ) + def test_interp_theta_sampling_plot(self, sampler): + """Test sampling.""" + with NumpyRNGContext(0): # control the random numbers + sample = sampler.rvs(size=int(1e6), r=10) + sample = sample[sample < 1e4] + + theory = self.theory[sampler._potential].sample(n=int(1e6)).theta() + theory -= pi / 2 * u.rad # adjust range back + + fig = plt.figure(figsize=(10, 3)) + ax = fig.add_subplot( + 121, + title="SCF vs theory sampling", + xlabel=r"$\theta$", + ylabel="frequency", + ) + _, bins, *_ = ax.hist( + sample, bins=30, log=True, label="SCF sample", color="tab:blue", alpha=0.5 + ) + # Comparing to expected + ax.hist( + theory.to_value(u.rad), + bins=bins, + log=True, + edgecolor="black", + linewidth=1.2, + fc=(1, 0, 0, 0.0), + label="Theoretical", + ) + ax.legend() + fig.tight_layout() + + return fig + + +# ---------------------------------------------------------------------------- + + +class Test_phi_distribution(phi_distributionTestBase, InterpBaseTest_rv_potential): + """Test :class:`sample_scf.interpolated.interpolated_phi_distribution`.""" + + def setup_class(self): + super().setup_class(self) + + self.cls = interpolated_phi_distribution + self.cls_args = (rgrid, tgrid, pgrid) + + self.cdf_time_scale = 12e-4 + self.rvs_time_scale = 12e-4 + + # =============================================================== + # Method Tests + + def test_init(self, sampler): + """Test :meth:`sample_scf.interpolated.interpolated_phi_distribution._cdf`.""" + # super().test_init(sampler) # doesn't work TODO! + + # a shape mismatch + Scs = _calculate_Scs(sampler._potential, rgrid[1:-1], tgrid[1:-1], warn=False) + with pytest.raises(ValueError, match="Rm, Sm must be shape"): + sampler.__class__(sampler._potential, rgrid, tgrid, pgrid, Scs=Scs) + + @pytest.mark.skip("TODO!") + def test__cdf(self): + """Test :meth:`sample_scf.interpolated.interpolated_phi_distribution._cdf`.""" + assert False + + @pytest.mark.skip("TODO!") + def test_cdf(self): + """Test :meth:`sample_scf.interpolated.interpolated_phi_distribution.cdf`.""" + assert False + + @pytest.mark.skip("TODO!") + def test__ppf(self): + """Test :meth:`sample_scf.interpolated.interpolated_phi_distribution._ppf`.""" + assert False + + @pytest.mark.skip("TODO!") + def test__rvs(self): + """Test :meth:`sample_scf.interpolated.interpolated_phi_distribution._rvs`.""" + assert False + + @pytest.mark.skip("TODO!") + def test_rvs(self): + """Test :meth:`sample_scf.interpolated.interpolated_phi_distribution.rvs`.""" + assert False + + # =============================================================== + # Image Tests + + @pytest.mark.mpl_image_compare( + baseline_dir="baseline_images", + # hash_library="baseline_images/path_to_file.json", # TODO! + ) + def test_interp_phi_cdf_plot(self, sampler): + fig = plt.figure(figsize=(5, 3)) + + ax = fig.add_subplot( + 111, + title=r"CDF($\phi$)", + xlabel=r"$\phi$", + ylabel=r"CDF($\phi$)", + ) + kw = dict(marker="o", ms=5, c="k", zorder=5, label="CDF") + ax.plot(pgrid, sampler.cdf(pgrid, r=10, theta=pi / 6), **kw) + ax.axvline(0, c="tab:blue") + ax.axhline(sampler.cdf(0, r=10, theta=pi / 6), c="tab:blue", label=r"$\phi=0$") + ax.axvline(pi, c="tab:green") + ax.axhline(sampler.cdf(pi, r=10, theta=pi / 6), c="tab:green", label=r"$\phi=\pi$") + ax.axvline(2 * pi, c="tab:red") + ax.axhline(sampler.cdf(2 * pi, r=10, theta=pi / 6), c="tab:red", label=r"$\phi=2\pi$") + ax.legend(loc="lower right") + + fig.tight_layout() + return fig + + @pytest.mark.mpl_image_compare( + baseline_dir="baseline_images", + # hash_library="baseline_images/path_to_file.json", + ) + def test_interp_phi_sampling_plot(self, sampler): + """Test sampling.""" + with NumpyRNGContext(0): # control the random numbers + sample = sampler.rvs(size=int(1e6), r=10, theta=pi / 6) + sample = sample[sample < 1e4] + + theory = self.theory[sampler._potential].sample(n=int(1e6)).phi() + + fig = plt.figure(figsize=(10, 3)) + ax = fig.add_subplot( + 121, + title="SCF vs theory sampling", + xlabel=r"$\phi$", + ylabel="frequency", + ) + _, bins, *_ = ax.hist( + sample, bins=30, log=True, alpha=0.5, c="tab:blue", label="SCF sample" + ) + # Comparing to expected + ax.hist( + theory.to_value(u.rad), + bins=bins, + log=True, + edgecolor="black", + linewidth=1.2, + fc=(1, 0, 0, 0.0), + label="Theoretical", + ) + ax.legend() + fig.tight_layout() + + return fig diff --git a/sample_scf/representation.py b/sample_scf/representation.py new file mode 100644 index 0000000..d5dfe9d --- /dev/null +++ b/sample_scf/representation.py @@ -0,0 +1,488 @@ +# -*- coding: utf-8 -*- + +"""Utility functions.""" + +############################################################################## +# IMPORTS + +from __future__ import annotations + +# STDLIB +from functools import singledispatch +from inspect import isclass +from typing import Dict, Optional, Type, Union, overload + +# THIRD PARTY +import astropy.units as u +from astropy.coordinates import Angle, BaseDifferential, BaseRepresentation +from astropy.coordinates import CartesianRepresentation, Distance, PhysicsSphericalRepresentation +from astropy.coordinates import SphericalRepresentation, UnitSphericalRepresentation +from astropy.units import Quantity, UnitConversionError +from erfa import ufunc as erfa_ufunc +from numpy import abs, all, any, arccos, arctan2, atleast_1d, cos, divide, floating, hypot +from numpy import isfinite, less, nan_to_num, ndarray, sin, isnan +from numpy.typing import ArrayLike + +# LOCAL +from ._typing import NDArrayF + +__all__ = ["FiniteSphericalRepresentation"] + +############################################################################## +# CODE +############################################################################## + + +@singledispatch +def _zeta_of_r( + r: Union[ArrayLike, Quantity], /, scale_radius: Union[NDArrayF, Quantity, None] = None +) -> NDArrayF: + # Default implementation, unless there's a registered specific method. + # -------------- + # Checks: r must be non-negative, and the scale radius must be None or positive + if any(less(r, 0)): + raise ValueError("r must be >= 0") + elif scale_radius is None: + scale_radius = 1 + elif not all(isfinite(scale_radius)) or scale_radius <= 0: + raise ValueError("scale_radius must be a finite number > 0") + + # Calculation + r_a: Quantity = divide(r, scale_radius) # can be inf + zeta: NDArrayF = nan_to_num(divide(r_a - 1, r_a + 1), nan=1.0) + return zeta + + +@overload +@_zeta_of_r.register +def zeta_of_r(r: Quantity, /, scale_radius=None) -> NDArrayF: # type: ignore + # Checks: r must be a non-negative length-type quantity, and the scale + # radius must be None or a positive length-type quantity. + if not isinstance(r, Quantity) or r.unit.physical_type != "length": + raise UnitConversionError("r must be a Quantity with units of 'length'") + elif any(isnan(r)) or any(r < 0): + raise ValueError("r must be >= 0") + elif scale_radius is not None: + if not isinstance(scale_radius, Quantity) or scale_radius.unit.physical_type != "length": + raise TypeError("scale_radius must be a Quantity with units of 'length'") + elif not isfinite(scale_radius) or scale_radius <= 0: + raise ValueError("scale_radius must be a finite number > 0") + else: + scale_radius = 1 * r.unit + + r_a: Quantity = r / scale_radius # can be inf + zeta: NDArrayF = nan_to_num(divide(r_a - 1, r_a + 1), nan=1.0) + return zeta.value + + +def zeta_of_r( + r: Union[NDArrayF, Quantity], /, scale_radius: Union[NDArrayF, Quantity, None] = None +) -> NDArrayF: + r""":math:`\zeta(r) = \frac{r/a - 1}{r/a + 1}`. + + Map the half-infinite domain [0, infinity) -> [-1, 1]. + + Parameters + ---------- + r : (R,) Quantity['length'], position-only + scale_radius : Quantity['length'] or None, optional + If None (default), taken to be 1 in the units of `r`. + + Returns + ------- + (R,) array[floating] + + Raises + ------ + TypeError + If `r` is a Quantity and scale radius is not a Quantity. + If `r` is not a Quantity and scale radius is a Quantity. + UnitConversionError + If `r` is a Quantity but does not have units of length. + If `r` is a Quantity and `scale_radius` is not None and does not have + units of length. + ValueError + If `r` is less than 0. + If `scale_radius` is not None and is less than or equal to 0. + """ + return _zeta_of_r(r, scale_radius=scale_radius) + + +zeta_of_r.__wrapped__ = _zeta_of_r # For easier access. + + +# ------------------------------------------------------------------- + + +def r_of_zeta( + zeta: ndarray, /, scale_radius: Union[float, floating, Quantity, None] = None +) -> Union[NDArrayF, Quantity]: + r""":math:`r = \frac{1 + \zeta}{1 - \zeta}`. + + Map back to the half-infinite domain [0, infinity) <- [-1, 1]. + + Parameters + ---------- + zeta : (R,) array[floating] or (R,) Quantity['dimensionless'], position-only + scale_radius : Quantity['length'] or None, optional + + Returns + ------- + (R,) ndarray[float] or (R,) Quantity['length'] + A |Quantity| if scale_radius is not None, else a `numpy.ndarray`. + + Raises + ------ + UnitConversionError + If `scale_radius` is a |Quantity|, but does not have units of length. + ValueError + If `zeta` is not in [-1, 1]. + If `scale_radius` not in (0, `numpy.inf`). + + Warnings + -------- + RuntimeWarning + If zeta is 1 (r is `numpy.inf`). Don't worry, it's not a problem. + """ + if any(zeta < -1) or any(zeta > 1): + raise ValueError("zeta must be in [-1, 1].") + elif scale_radius is None: + scale_radius = 1 + elif scale_radius <= 0 or not isfinite(scale_radius): + raise ValueError("scale_radius must be in (0, inf).") + elif ( + isinstance(scale_radius, Quantity) + and scale_radius.unit.physical_type != "length" # type: ignore + ): + raise UnitConversionError("scale_radius must have units of length") + + r: NDArrayF = atleast_1d(divide(1 + zeta, 1 - zeta)) + r[r < 0] = 0 # correct small errors + rq: Union[NDArrayF, Quantity] + rq = scale_radius * r + return rq + + +# ------------------------------------------------------------------- + + +def x_of_theta(theta: Union[ndarray, Quantity["angle"]]) -> NDArrayF: # type: ignore + r""":math:`x = \cos{\theta}`. + + Parameters + ---------- + theta : (T,) Quantity['angle'] or array['radian'] + + Returns + ------- + float or (T,) ndarray[floating] + """ + x: NDArrayF = cos(theta) + xval = x if not isinstance(x, Quantity) else x.value + return xval + + +# ------------------------------------------------------------------- + + +def theta_of_x(x: ArrayLike, unit=u.rad) -> Quantity: + r""":math:`\theta = \cos^{-1}{x}`. + + Parameters + ---------- + x : array-like + unit : unit-like['angular'], optional + Output units. + + Returns + ------- + theta : float or ndarray + """ + th: NDArrayF = arccos(x) << u.rad + theta = th << unit + return theta + + +########################################################################### + + +class FiniteSphericalRepresentation(BaseRepresentation): + r""" + Representation of points in 3D spherical coordinates (using the physics + convention for azimuth and inclination from the pole) where the radius and + inclination are rescaled to be on [-1, 1]. + + .. math:: + + \zeta = \frac{1 - r / a}{1 + r/a} x = \cos(\theta) + """ + + _phi: Quantity + _x: NDArrayF + _zeta: NDArrayF + _scale_radius: Union[NDArrayF, Quantity] + + attr_classes: Dict[str, Type[Quantity]] = {"phi": Angle, "x": Quantity, "zeta": Quantity} + + def __init__( + self, + phi: Quantity, + x: Union[NDArrayF, Quantity, None] = None, + zeta: Union[NDArrayF, Quantity, None] = None, + scale_radius: Optional[Quantity] = None, + differentials: Union[BaseDifferential, Dict[str, BaseDifferential]] = None, + copy: bool = True, + ): + # Adjustments if passing unitful quantities + if isinstance(x, Quantity) and x.unit.physical_type == "angle": # type: ignore + x = x_of_theta(x) + if isinstance(zeta, Quantity) and zeta.unit.physical_type == "length": # type: ignore + if scale_radius is None: + scale_radius = 1 * zeta.unit # type: ignore + zeta = zeta_of_r(zeta, scale_radius=scale_radius) + elif scale_radius is None: + raise ValueError("if zeta is not a length, a scale_radius must given") + + super().__init__(phi, x, zeta, copy=copy, differentials=differentials) + self._scale_radius = scale_radius + + # Wrap/validate phi/theta + # Note that _phi already holds our own copy if copy=True. + self._phi.wrap_at(360 * u.deg, inplace=True) + + if any(self._x < -1) or any(self._x > 1): + raise ValueError(f"inclination angle(s) must be within -1 <= angle <= 1, got {x}") + + if any(self._zeta < -1) or any(self._zeta > 1): + raise ValueError(f"distances must be within -1 <= zeta <= 1, got {zeta}") + + @property + def phi(self) -> Quantity: + """The azimuth of the point(s).""" + return self._phi + + @property + def x(self) -> Quantity: + """The elevation of the point(s).""" + return self._x + + @property + def zeta(self) -> Quantity: + """The distance from the origin to the point(s).""" + return self._zeta + + @property + def scale_radius(self) -> Union[NDArrayF, Quantity]: + return self._scale_radius + + # ----------------------------------------------------- + # corresponding PhysicsSpherical coordinates + + @property + def theta(self) -> Quantity: + """The elevation of the point(s).""" + return self.calculate_theta_of_x(self._x) + + @property + def r(self) -> Union[NDArrayF, Quantity]: + """The distance from the origin to the point(s).""" + return Distance(self.calculate_r_of_zeta(self._zeta), copy=False) + + # ----------------------------------------------------- + # conversion functions + + def calculate_zeta_of_r(self, r: Union[NDArrayF, Quantity], /) -> NDArrayF: + r""":math:`\zeta(r) = \frac{r/a - 1}{r/a + 1}`. + + Map the half-infinite domain [0, infinity) -> [-1, 1]. + + Parameters + ---------- + r : (R,) Quantity['length'], position-only + + Returns + ------- + (R,) array[floating] + + See Also + -------- + sample_scf.representation.zeta_of_r + """ + return zeta_of_r(r, scale_radius=self.scale_radius) + + def calculate_r_of_zeta(self, zeta: ndarray, /) -> Union[NDArrayF, Quantity]: + r""":math:`r = \frac{1 + \zeta}{1 - \zeta}`. + + Map back to the half-infinite domain [0, infinity) <- [-1, 1]. + + Parameters + ---------- + zeta : (R,) array[floating] or (R,) Quantity['dimensionless'], position-only + + Returns + ------- + (R,) ndarray[float] or (R,) Quantity['length'] + A |Quantity| if scale_radius is not None, else a `numpy.ndarray`. + + See Also + -------- + sample_scf.representation.r_of_zeta + """ + return r_of_zeta(zeta, scale_radius=self.scale_radius) + + def calculate_x_of_theta(self, theta: Quantity) -> NDArrayF: + r""":math:`x = \cos{\theta}`. + + Parameters + ---------- + theta : (T,) Quantity['angle'] or array['radian'] + + Returns + ------- + float or (T,) ndarray[floating] + """ + return x_of_theta(theta) + + def calculate_theta_of_x(self, x: ArrayLike) -> Quantity: + r""":math:`\theta = \cos^{-1}{x}`. + + Parameters + ---------- + x : array-like + unit : unit-like['angular'] or None, optional + Output units. + + Returns + ------- + theta : float or ndarray + """ + return theta_of_x(x) + + # ----------------------------------------------------- + + # def unit_vectors(self): + # sinphi, cosphi = sin(self.phi), cos(self.phi) + # sintheta, x = sin(self.theta), self.x + # return { + # "phi": CartesianRepresentation(-sinphi, cosphi, 0.0, copy=False), + # "theta": CartesianRepresentation(x * cosphi, x * sinphi, -sintheta, copy=False), + # "r": CartesianRepresentation(sintheta * cosphi, sintheta * sinphi, x, copy=False), + # } + + # TODO! + # def scale_factors(self): + # r = self.r / u.radian + # sintheta = sin(self.theta) + # l = broadcast_to(1.*u.one, self.shape, subok=True) + # return {'phi': r * sintheta, + # 'theta': r, + # 'r': l} + + def represent_as(self, other_class, differential_class=None): + # Take a short cut if the other class is a spherical representation + + if isclass(other_class): + if issubclass(other_class, PhysicsSphericalRepresentation): + diffs = self._re_represent_differentials(other_class, differential_class) + return other_class( + phi=self.phi, theta=self.theta, r=self.r, differentials=diffs, copy=False + ) + elif issubclass(other_class, SphericalRepresentation): + diffs = self._re_represent_differentials(other_class, differential_class) + return other_class( + lon=self.phi, + lat=90 * u.deg - self.theta, + distance=self.r, + differentials=diffs, + copy=False, + ) + elif issubclass(other_class, UnitSphericalRepresentation): + diffs = self._re_represent_differentials(other_class, differential_class) + return other_class( + lon=self.phi, lat=90 * u.deg - self.theta, differentials=diffs, copy=False + ) + + return super().represent_as(other_class, differential_class) + + def to_cartesian(self): + """ + Converts spherical polar coordinates to 3D rectangular cartesian + coordinates. + """ + # We need to convert Distance to Quantity to allow negative values. + d = self.r.view(Quantity) + + x = d * sin(self.theta) * cos(self.phi) + y = d * sin(self.theta) * sin(self.phi) + z = d * cos(self.theta) + + return CartesianRepresentation(x=x, y=y, z=z, copy=False) + + @classmethod + def from_cartesian(cls, cart, scale_radius: Optional[Quantity] = None): + """ + Converts 3D rectangular cartesian coordinates to spherical polar + coordinates. + """ + s = hypot(cart.x, cart.y) + r = hypot(s, cart.z) + + phi = arctan2(cart.y, cart.x) << u.rad + theta = arctan2(s, cart.z) << u.rad + + return cls(phi=phi, x=theta, zeta=r, scale_radius=scale_radius, copy=False) + + @classmethod + def from_physicsspherical( + cls, psphere: PhysicsSphericalRepresentation, scale_radius: Optional[Quantity] = None + ): + """ + Converts spherical polar coordinates. + """ + return cls( + phi=psphere.phi, x=psphere.theta, zeta=psphere.r, scale_radius=scale_radius, copy=False + ) + + def transform(self, matrix, scale_radius: Optional[Quantity] = None): + """Transform the spherical coordinates using a 3x3 matrix. + + This returns a new representation and does not modify the original one. + Any differentials attached to this representation will also be + transformed. + + Parameters + ---------- + matrix : (3,3) array-like + A 3x3 matrix, such as a rotation matrix (or a stack of matrices). + """ + if self.differentials: + # TODO! shortcut if there are differentials. + # Currently just super, which uses Cartesian backend. + rep = super().transform(matrix) + + else: + # apply transformation in unit-spherical coordinates + xyz = erfa_ufunc.s2c(self.phi, 90 * u.deg - self.theta) + p = erfa_ufunc.rxp(matrix, xyz) + lon, lat, ur = erfa_ufunc.p2s(p) # `ur` is transformed unit-`r` + theta = 90 * u.deg - lat + + # create transformed physics-spherical representation, + # reapplying the distance scaling + rep = self.__class__(phi=lon, x=theta, zeta=self.r * ur, scale_radius=scale_radius) + + return rep + + def norm(self): + """Vector norm. + + The norm is the standard Frobenius norm, i.e., the square root of the + sum of the squares of all components with non-angular units. For + spherical coordinates, this is just the absolute value of the radius. + + Returns + ------- + norm : `astropy.units.Quantity` + Vector norm, with the same shape as the representation. + """ + return abs(self.zeta) diff --git a/sample_scf/tests/base.py b/sample_scf/tests/base.py new file mode 100644 index 0000000..c190c57 --- /dev/null +++ b/sample_scf/tests/base.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- + + +############################################################################## +# IMPORTS + +# STDLIB +import time +from abc import ABCMeta, abstractmethod + +# THIRD PARTY +import pytest +from galpy.potential import KeplerPotential +from numpy import linspace + +# LOCAL +from sample_scf.conftest import _hernquist_scf_potential + +############################################################################## +# TESTS +############################################################################## + + +class BaseTest_Sampler(metaclass=ABCMeta): + @pytest.fixture( + scope="class", + params=[ + "hernquist_scf_potential", + # "nfw_scf_potential", # TODO! turn on + ], + ) + def potential(self, request): + if request.param in ("hernquist_scf_potential"): + potential = _hernquist_scf_potential + elif request.param == "nfw_scf_potential": + # potential = nfw_scf_potential.__wrapped__() + raise NotImplementedError + else: + raise NotImplementedError + + yield potential + + @pytest.fixture(scope="class") + @abstractmethod + def rv_cls(self): + """Sample class.""" + raise NotImplementedError + + @pytest.fixture(scope="class") + def rv_cls_args(self): + return () + + @pytest.fixture(scope="class") + def rv_cls_kw(self): + return {} + + @pytest.fixture(scope="class") + def cls_pot_kw(self): + return {} + + @pytest.fixture(scope="class") + def full_rv_cls_kw(self, rv_cls_kw, cls_pot_kw, potential): + return {**rv_cls_kw, **cls_pot_kw.get(potential, {})} + + @pytest.fixture(scope="class") + def sampler(self, rv_cls, potential, rv_cls_args, full_rv_cls_kw): + """Set up r, theta, or phi sampler.""" + sampler = rv_cls(potential, *rv_cls_args, **full_rv_cls_kw) + return sampler + + # cdf tests + + @pytest.fixture(scope="class") + def cdf_args(self): + return () + + @pytest.fixture(scope="class") + def cdf_kw(self): + return {} + + # rvs tests + + @pytest.fixture(scope="class") + def rvs_args(self): + return () + + @pytest.fixture(scope="class") + def rvs_kw(self): + return {} + + # time-scale tests + + def cdf_time_arr(self, size): + return linspace(0, 1e4, size) + + @pytest.fixture(scope="class") + def cdf_time_scale(self): + return 0 + + @pytest.fixture(scope="class") + def rvs_time_scale(self): + return 0 + + # =============================================================== + # Method Tests + + def test_init_wrong_potential(self, rv_cls, rv_cls_args, rv_cls_kw): + """Test initialization when the potential is wrong.""" + # bad value + with pytest.raises(TypeError, match=""): + rv_cls(KeplerPotential(), *rv_cls_args, **rv_cls_kw) + + # --------------------------------------------------------------- + + def test_potential_property(self, sampler): + # Identity + assert sampler.potential is sampler._potential + + # =============================================================== + # Time Scaling Tests + + @pytest.mark.parametrize("size", [1, 10, 100, 1000, 10000]) + def test_cdf_time_scaling(self, sampler, size, cdf_args, cdf_kw, cdf_time_scale): + """Test that the time scales as X * size""" + x = self.cdf_time_arr(size) + tic = time.perf_counter() + sampler.cdf(x, *cdf_args, **cdf_kw) + toc = time.perf_counter() + + assert (toc - tic) < cdf_time_scale * size # linear scaling + + @pytest.mark.parametrize("size", [1, 10, 100, 1000, 10000]) + def test_rvs_time_scaling(self, sampler, size, rvs_args, rvs_kw, rvs_time_scale): + """Test that the time scales as X * size""" + tic = time.perf_counter() + sampler.rvs(size=size, *rvs_args, **rvs_kw) + toc = time.perf_counter() + + assert (toc - tic) < rvs_time_scale * size # linear scaling diff --git a/sample_scf/tests/baseline_images/test_phi_cdf_plot.png b/sample_scf/tests/baseline_images/test_phi_cdf_plot.png new file mode 100644 index 0000000..f63c35e Binary files /dev/null and b/sample_scf/tests/baseline_images/test_phi_cdf_plot.png differ diff --git a/sample_scf/tests/baseline_images/test_phi_sampling_plot.png b/sample_scf/tests/baseline_images/test_phi_sampling_plot.png new file mode 100644 index 0000000..462c115 Binary files /dev/null and b/sample_scf/tests/baseline_images/test_phi_sampling_plot.png differ diff --git a/sample_scf/tests/baseline_images/test_r_cdf_plot.png b/sample_scf/tests/baseline_images/test_r_cdf_plot.png new file mode 100644 index 0000000..b31364e Binary files /dev/null and b/sample_scf/tests/baseline_images/test_r_cdf_plot.png differ diff --git a/sample_scf/tests/baseline_images/test_r_sampling_plot.png b/sample_scf/tests/baseline_images/test_r_sampling_plot.png new file mode 100644 index 0000000..421a270 Binary files /dev/null and b/sample_scf/tests/baseline_images/test_r_sampling_plot.png differ diff --git a/sample_scf/tests/baseline_images/test_theta_cdf_plot.png b/sample_scf/tests/baseline_images/test_theta_cdf_plot.png new file mode 100644 index 0000000..48c7e25 Binary files /dev/null and b/sample_scf/tests/baseline_images/test_theta_cdf_plot.png differ diff --git a/sample_scf/tests/baseline_images/test_theta_sampling_plot.png b/sample_scf/tests/baseline_images/test_theta_sampling_plot.png new file mode 100644 index 0000000..05e9f5d Binary files /dev/null and b/sample_scf/tests/baseline_images/test_theta_sampling_plot.png differ diff --git a/sample_scf/tests/data/__init__.py b/sample_scf/tests/data/__init__.py new file mode 100644 index 0000000..d207d56 --- /dev/null +++ b/sample_scf/tests/data/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +# Licensed under a 3-clause BSD style license - see LICENSE.rst +""" +This module contains package tests. +""" + +# LOCAL +from . import data_Test_rv_potential + +results = {"Test_rv_potential": data_Test_rv_potential.results} diff --git a/sample_scf/tests/data/data_Test_rv_potential.py b/sample_scf/tests/data/data_Test_rv_potential.py new file mode 100644 index 0000000..1eaa1b9 --- /dev/null +++ b/sample_scf/tests/data/data_Test_rv_potential.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Licensed under a 3-clause BSD style license - see LICENSE.rst + +rvs = { + "size": (None, 1, (3, 1), (3, 1)), + "random": (0, 2, 4, None), + "expected": ( + 0.5488135039273248, + 0.43599490214200376, + (0.9670298390136767, 0.5472322491757223, 0.9726843599648843), + (0.9670298390136767, 0.5472322491757223, 0.9726843599648843), + ), +} + + +results = {"rvs": rvs} diff --git a/sample_scf/tests/data/nfw.npz b/sample_scf/tests/data/nfw.npz new file mode 100644 index 0000000..e8b573d Binary files /dev/null and b/sample_scf/tests/data/nfw.npz differ diff --git a/sample_scf/tests/data/scf_coeffs.npz b/sample_scf/tests/data/scf_coeffs.npz new file mode 100644 index 0000000..caf3765 Binary files /dev/null and b/sample_scf/tests/data/scf_coeffs.npz differ diff --git a/sample_scf/tests/data/scf_nfw_coeffs.npz b/sample_scf/tests/data/scf_nfw_coeffs.npz new file mode 100644 index 0000000..66afaba Binary files /dev/null and b/sample_scf/tests/data/scf_nfw_coeffs.npz differ diff --git a/sample_scf/tests/data/scf_tnfw_coeffs.npz b/sample_scf/tests/data/scf_tnfw_coeffs.npz new file mode 100644 index 0000000..88d2b58 Binary files /dev/null and b/sample_scf/tests/data/scf_tnfw_coeffs.npz differ diff --git a/sample_scf/tests/test_base_multivariate.py b/sample_scf/tests/test_base_multivariate.py new file mode 100644 index 0000000..dd043dc --- /dev/null +++ b/sample_scf/tests/test_base_multivariate.py @@ -0,0 +1,208 @@ +# -*- coding: utf-8 -*- + +"""Testing :mod:`scample_scf.base_multivariate`.""" + + +############################################################################## +# IMPORTS + +# STDLIB +from abc import abstractmethod + +# THIRD PARTY +import astropy.units as u +import pytest +from astropy.coordinates import BaseRepresentation, PhysicsSphericalRepresentation +from astropy.utils.misc import NumpyRNGContext +from galpy.potential import SCFPotential +from numpy import ndarray, shape, squeeze, atleast_1d, allclose +from numpy.testing import assert_allclose + +# LOCAL +from .base import BaseTest_Sampler +from .test_base_univariate import rvtestsampler +from sample_scf.base_univariate import phi_distribution_base, r_distribution_base +from sample_scf.base_univariate import theta_distribution_base + +############################################################################## +# TESTS +############################################################################## + + +class BaseTest_SCFSamplerBase(BaseTest_Sampler): + """Test :class:`sample_scf.base_multivariate.SCFSamplerBase`.""" + + @pytest.fixture(scope="class") + @abstractmethod + def rv_cls(self): + raise NotImplementedError + + @pytest.fixture(scope="class") + def r_distribution_cls(self): + return r_distribution_base + + @pytest.fixture(scope="class") + def theta_distribution_cls(self): + return theta_distribution_base + + @pytest.fixture(scope="class") + def phi_distribution_cls(self): + return phi_distribution_base + + def setup_class(self): + self.expected_rvs = { + 0: dict(r=0.548813503927, theta=1.021982822867 * u.rad, phi=0.548813503927 * u.rad), + 1: dict(r=0.548813503927, theta=1.021982822867 * u.rad, phi=0.548813503927 * u.rad), + 2: dict( + r=[0.9670298390136, 0.5472322491757, 0.9726843599648, 0.7148159936743], + theta=[0.603766487781, 1.023564077619, 0.598111966830, 0.855980333120] * u.rad, + phi=[0.9670298390136, 0.547232249175, 0.9726843599648, 0.7148159936743] * u.rad, + ), + } + + # =============================================================== + # Method Tests + + def test_init_attrs(self, sampler): + assert hasattr(sampler, "_potential") + assert hasattr(sampler, "_r_distribution") + assert hasattr(sampler, "_theta_distribution") + assert hasattr(sampler, "_phi_distribution") + + # --------------------------------------------------------------- + + def test_potential_property(self, sampler, potential): + """Test :meth:`sample_scf.base_multivariate.SCFSamplerBase.potential`.""" + # Identity + assert sampler.potential is sampler._potential + # Properties + assert isinstance(sampler.potential, SCFPotential) + + def test_r_distribution_property(self, sampler, r_distribution_cls): + """Test :meth:`sample_scf.base_multivariate.SCFSamplerBase.r_distribution`.""" + # Identity + assert sampler.r_distribution is sampler._r_distribution + # Properties + assert isinstance(sampler.r_distribution, r_distribution_cls) + + def test_theta_distribution_property(self, sampler, theta_distribution_cls): + """Test :meth:`sample_scf.base_multivariate.SCFSamplerBase.theta_distribution`.""" + # Identity + assert sampler.theta_distribution is sampler._theta_distribution + # Properties + assert isinstance(sampler.theta_distribution, theta_distribution_cls) + + def test_phi_distribution_property(self, sampler, phi_distribution_cls): + """Test :meth:`sample_scf.base_multivariate.SCFSamplerBase.phi_distribution`.""" + # Identity + assert sampler.phi_distribution is sampler._phi_distribution + # Properties + assert isinstance(sampler.phi_distribution, phi_distribution_cls) + + def test_radial_scale_factor_property(self, sampler): + # Identity + assert sampler.radial_scale_factor is sampler.r_distribution.radial_scale_factor + + def test_nmax_property(self, sampler): + # Identity + assert sampler.nmax is sampler.r_distribution.nmax + + def test_lmax_property(self, sampler): + # Identity + assert sampler.lmax is sampler.r_distribution.lmax + + # --------------------------------------------------------------- + + @abstractmethod + def test_cdf(self, sampler, position, expected): + """Test cdf method.""" + cdf = sampler.cdf(size=size, *position) + + assert isinstance(cdf, ndarray) + assert False + + assert_allclose(cdf, expected, atol=1e-16) + + @abstractmethod + def test_rvs(self, sampler, size, random, expected): + """Test rvs method. + + The ``NumpyRNGContext`` is to control the random generator used to make + the RandomState. For ``random != None``, this doesn't matter. + + Each child class will need to define the set of expected results. + """ + with NumpyRNGContext(4): + rvs = sampler.rvs(size=size, random_state=random) + + assert isinstance(rvs, BaseRepresentation) + + r = rvs.represent_as(PhysicsSphericalRepresentation) + assert_allclose(r.r, expected.r, atol=1e-16) + assert_allclose(r.theta, expected.theta, atol=1e-16) + assert_allclose(r.phi, expected.phi, atol=1e-16) + + # --------------------------------------------------------------- + + def test_repr(self): + """Test :meth:`sample_scf.base_multivariate.SCFSamplerBase.__repr__`.""" + assert False + + +############################################################################## + + +class Test_SCFSamplerBase(BaseTest_SCFSamplerBase): + @pytest.fixture(scope="class") + def rv_cls(self): + return SCFSamplerBase + + @pytest.fixture() + def sampler(self, potential): + """Set up r, theta, phi sampler.""" + super().sampler(potential) + + sampler._r_distribution = rvtestsampler(potential) + sampler._theta_distribution = rvtestsampler(potential) + sampler._phi_distribution = rvtestsampler(potential) + + return sampler + + # =============================================================== + # Method Tests + + @pytest.mark.parametrize( + "r, theta, phi, expected", + [ + (0, 0, 0, [0, 0, 0]), + (1, 0, 0, [1, 0, 0]), + ([0, 1], [0, 0], [0, 0], [[0, 0, 0], [1, 0, 0]]), + ], + ) + def test_cdf(self, sampler, r, theta, phi, expected): + """Test :meth:`sample_scf.base_multivariate.SCFSamplerBase.cdf`.""" + cdf = sampler.cdf(r, theta, phi) + assert allclose(cdf, expected, atol=1e-16) + + # also test shape + assert tuple(atleast_1d(squeeze((*shape(r), 3)))) == cdf.shape + + @pytest.mark.parametrize( + "id, size, random, vectorized", + [ + (0, None, 0, True), + (0, None, 0, False), + (1, 1, 0, True), + (1, 1, 0, False), + (2, 4, 4, True), + (2, 4, 4, False), + ], + ) + def test_rvs(self, sampler, id, size, random, vectorized): + """Test :meth:`sample_scf.base_multivariate.SCFSamplerBase.rvs`.""" + samples = sampler.rvs(size=size, random_state=random, vectorized=vectorized) + sce = PhysicsSphericalRepresentation(**self.expected_rvs[id]) + + assert_allclose(samples.r, sce.r, atol=1e-16) + assert_allclose(samples.theta.value, sce.theta.value, atol=1e-16) + assert_allclose(samples.phi.value, sce.phi.value, atol=1e-16) diff --git a/sample_scf/tests/test_base_univariate.py b/sample_scf/tests/test_base_univariate.py new file mode 100644 index 0000000..d7e3728 --- /dev/null +++ b/sample_scf/tests/test_base_univariate.py @@ -0,0 +1,258 @@ +# -*- coding: utf-8 -*- + +"""Testing :mod:`scample_scf.base_univariate`.""" + + +############################################################################## +# IMPORTS + +# STDLIB +import inspect +from abc import abstractmethod + +# THIRD PARTY +import pytest +from astropy.utils.misc import NumpyRNGContext +from numpy.testing import assert_allclose +from scipy.stats import rv_continuous + +# LOCAL +from .base import BaseTest_Sampler +from .data import results +from sample_scf.base_univariate import r_distribution_base, rv_potential +from numpy import concatenate, geomspace, inf, pi, linspace, random, atleast_1d + +# import time +# from abc import ABCMeta +# from galpy.potential import KeplerPotential +# from sample_scf.conftest import _hernquist_scf_potential +# import astropy.coordinates as coord +# import astropy.units as u + + +############################################################################## +# PARAMETERS + +radii = concatenate(([0], geomspace(1e-1, 1e3, 28), [inf])) # same shape as ↓ +thetas = linspace(0, pi, 30) +phis = linspace(0, 2 * pi, 30) + + +############################################################################## +# TESTS +############################################################################## + + +class BaseTest_rv_potential(BaseTest_Sampler): + """Test subclasses of `sample_scf.base_univariate.rv_potential`.""" + + # =============================================================== + # Method Tests + + def test_init_signature(self, sampler): + """Test signature is compatible with `scipy.stats.rv_continuous`. + + The subclasses pass to parent by kwargs, so can't check the full + suite of parameters. + """ + sig = inspect.signature(sampler.__init__) + params = sig.parameters + + assert "potential" in params + + def test_init_attrs(self, sampler, rv_cls_args, rv_cls_kw): + """Test attributes set at initialization.""" + # check it has the expected attributes + assert hasattr(sampler, "_potential") + assert hasattr(sampler, "_nmax") + assert hasattr(sampler, "_lmax") + assert hasattr(sampler, "_radial_scale_factor") + + # TODO! expected parameters from scipy rv_continuous + + # --------------------------------------------------------------- + + def test_radial_scale_factor_property(self, sampler): + # Identity + assert sampler.radial_scale_factor is sampler._radial_scale_factor + + def test_nmax_property(self, sampler): + # Identity + assert sampler.nmax is sampler._nmax + + def test_lmax_property(self, sampler): + # Identity + assert sampler.lmax is sampler._lmax + + # --------------------------------------------------------------- + + @abstractmethod + def test_cdf(self, sampler, position, expected): + """Test cdf method.""" + assert_allclose(sampler.cdf(size=len(expected), *position), expected, atol=1e-16) + + @abstractmethod + def test_rvs(self, sampler, size, random, expected): + """Test rvs method. + + The ``NumpyRNGContext`` is to control the random generator used to make + the RandomState. For ``random != None``, this doesn't matter. + + Each child class will need to define the set of expected results. + """ + with NumpyRNGContext(4): + assert_allclose(sampler.rvs(size=size, random_state=random), expected, atol=1e-16) + + +############################################################################## + + +class rvtestsampler(rv_potential): + """A sampler for testing the modified ``rv_continuous`` base class.""" + + def _cdf(self, x, *args, **kwargs): + return x + + cdf = _cdf + + def _rvs(self, *args, size=None, random_state=None): + if random_state is None: + random_state = random + + return atleast_1d(random_state.uniform(size=size)) + + +class Test_rv_potential(BaseTest_rv_potential): + """Test :class:`sample_scf.base_univariate.rv_potential`.""" + + @pytest.fixture(scope="class") + def rv_cls(self): + return rvtestsampler + + @pytest.fixture(scope="class") + def cdf_time_scale(self): + return 4e-6 + + @pytest.fixture(scope="class") + def rvs_time_scale(self): + return 1e-4 + + # =============================================================== + # Method Tests + + def test_init_signature(self, sampler): + """Test signature is compatible with `scipy.stats.rv_continuous`.""" + sig = inspect.signature(sampler.__init__) + params = sig.parameters + + scipyps = inspect.signature(rv_continuous.__init__).parameters + + assert params["momtype"].default == scipyps["momtype"].default + assert params["a"].default == scipyps["a"].default + assert params["b"].default == scipyps["b"].default + assert params["xtol"].default == scipyps["xtol"].default + assert params["badvalue"].default == scipyps["badvalue"].default + assert params["name"].default == scipyps["name"].default + assert params["longname"].default == scipyps["longname"].default + assert params["shapes"].default == scipyps["shapes"].default + assert params["extradoc"].default == scipyps["extradoc"].default + assert params["seed"].default == scipyps["seed"].default + + @pytest.mark.parametrize( # TODO! instead by request.param lookup and index + list(results["Test_rv_potential"]["rvs"].keys()), + zip(*results["Test_rv_potential"]["rvs"].values()), + ) + def test_rvs(self, sampler, size, random, expected): + super().test_rvs(sampler, size, random, expected) + + +############################################################################## + + +class BaseTest_r_distribution_base(BaseTest_rv_potential): + """Test :class:`sample_scf.base_multivariate.r_distribution_base`.""" + + @pytest.fixture(scope="class") + @abstractmethod + def rv_cls(self): + """Sample class.""" + return r_distribution_base + + # =============================================================== + # Method Tests + + +############################################################################## + + +class BaseTest_theta_distribution_base(BaseTest_rv_potential): + """Test :class:`sample_scf.base_multivariate.theta_distribution_base`.""" + + @pytest.fixture(scope="class") + @abstractmethod + def rv_cls(self): + # return theta_distribution_base + raise NotImplementedError + + def cdf_time_arr(self, size: int): + return linspace(0, pi, size) + + # =============================================================== + # Method Tests + + def test_init_attrs(self, sampler): + """Test attributes set at initialization.""" + # super().test_init_attrs(sampler) + + assert hasattr(sampler, "_lrange") + assert min(sampler._lrange) == 0 + assert max(sampler._lrange) == sampler.lmax + 1 + + @pytest.mark.skip("TODO!") + def test_calculate_Qls(self, potential, sampler): + """Test :meth:`sample_scf.base_univariate.theta_distribution_base.calculate_Qls`.""" + assert False + + +############################################################################## + + +class BaseTest_phi_distribution_base(BaseTest_rv_potential): + """Test :class:`sample_scf.base_multivariate.phi_distribution_base`.""" + + @pytest.fixture(scope="class") + @abstractmethod + def rv_cls(self): + # return theta_distribution_base + raise NotImplementedError + + def cdf_time_arr(self, size: int): + return linspace(0, 2 * pi, size) + + # =============================================================== + # Method Tests + + def test_init_attrs(self, sampler): + """Test attributes set at initialization.""" + # super().test_init_attrs(sampler) + raise NotImplementedError + + # l-range + assert hasattr(sampler, "_lrange") + assert min(sampler._lrange) == 0 + assert max(sampler._lrange) == sampler.lmax + 1 + + @pytest.mark.skip("TODO!") + def test_pnts_Scs(self, sampler): + """Test :class:`sample_scf.base_multivariate.phi_distribution_base._pnts_Scs`.""" + assert False + + @pytest.mark.skip("TODO!") + def test_grid_Scs(self, sampler): + """Test :class:`sample_scf.base_multivariate.phi_distribution_base._grid_Scs`.""" + assert False + + @pytest.mark.skip("TODO!") + def test_calculate_Scs(self, sampler): + """Test :class:`sample_scf.base_multivariate.phi_distribution_base.calculate_Scs`.""" + assert False diff --git a/sample_scf/tests/test_conftest.py b/sample_scf/tests/test_conftest.py new file mode 100644 index 0000000..b02ff2c --- /dev/null +++ b/sample_scf/tests/test_conftest.py @@ -0,0 +1,110 @@ +# # -*- coding: utf-8 -*- +# +# """Testing :mod:`~sample_scf.conftest`. +# +# Even the test source should be tested. +# In particular, the potential fixtures need confirmation that the SCF form +# matches the theoretical, within tolerances. +# """ +# +# ############################################################################## +# # IMPORTS +# +# # STDLIB +# import abc +# +# # THIRD PARTY +# import numpy as np +# import pytest +# +# # LOCAL +# from sample_scf import conftest +# +# ############################################################################## +# # TESTS +# ############################################################################## +# +# +# class PytestPotential(metaclass=abc.ABCMeta): +# """Test a Pytest Potential.""" +# +# @classmethod +# @abc.abstractmethod +# def setup_class(self): +# """Setup fixtures for testing.""" +# self.R = linspace(0.0, 3.0, num=1001) +# self.atol = 1e-6 +# self.restrict_ind = ones(1001, dtype=bool) +# +# @pytest.fixture(scope="class") +# @abc.abstractmethod +# def scf_potential(self): +# """The `galpy.potential.SCFPotential` from a `pytest.fixture`.""" +# return +# +# def compare_to_theory(self, theory, scf, atol=1e-6): +# # test that where theory is finite they match and where it's infinite, +# # the scf is NaN +# fnt = ~isinf(theory) +# ind = self.restrict_ind & fnt +# +# assert allclose(theory[ind], scf[ind], atol=atol) +# assert all(isnan(scf[~fnt])) +# +# # =============================================================== +# # sanity checks +# +# def test_df(self): +# assert self.df._pot is self.theory +# +# # =============================================================== +# +# def test_density_along_Rz_equality(self, scf_potential): +# theory = self.theory.dens(self.R, self.R) +# scf = scf_potential.dens(self.R, self.R) +# self.compare_to_theory(theory, scf, atol=self.atol) +# +# @pytest.mark.parametrize("z", [0, 10, 15]) +# def test_density_at_z(self, scf_potential, z): +# theory = self.theory.dens(self.R, z) +# scf = scf_potential.dens(self.R, z) +# self.compare_to_theory(theory, scf, atol=self.atol) +# +# +# # ------------------------------------------------------------------- +# +# +# class Test_hernquist_scf_potential(PytestPotential): +# @classmethod +# def setup_class(self): +# """Setup fixtures for testing.""" +# super().setup_class() +# +# self.theory = conftest.hernquist_potential +# self.df = conftest.hernquist_df +# +# @pytest.fixture(scope="class") +# def scf_potential(self, hernquist_scf_potential): +# """The `galpy.potential.SCFPotential` from a `pytest.fixture`.""" +# return hernquist_scf_potential +# +# +# # ------------------------------------------------------------------- +# +# +# # class Test_nfw_scf_potential(PytestPotential): +# # @classmethod +# # def setup_class(self): +# # """Setup fixtures for testing.""" +# # super().setup_class() +# # +# # self.theory = conftest.nfw_potential +# # self.df = conftest.nfw_df +# # +# # self.atol = 1e-2 +# # self.restrict_ind[:18] = False # skip some of the inner ones +# # +# # @pytest.fixture(scope="class") +# # def scf_potential(self, nfw_scf_potential): +# # """The `galpy.potential.SCFPotential` from a `pytest.fixture`.""" +# # return nfw_scf_potential diff --git a/sample_scf/tests/test_core.py b/sample_scf/tests/test_core.py new file mode 100644 index 0000000..9f02858 --- /dev/null +++ b/sample_scf/tests/test_core.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +"""Testing :mod:`scample_scf.core`.""" + + +############################################################################## +# IMPORTS + +# THIRD PARTY +import astropy.units as u +import pytest + +# LOCAL +from .test_base_multivariate import BaseTest_SCFSamplerBase +from sample_scf import SCFSampler +from sample_scf.exact import exact_phi_distribution, exact_r_distribution, exact_theta_distribution +from sample_scf.core import MethodsMapping +from numpy import allclose, atleast_1d, squeeze, shape + +############################################################################## +# TESTS +############################################################################## + + +def test_MethodsMapping(potentials): + """Test `sample_scf.core.MethodsMapping`.""" + # Good + mm = MethodsMapping( + r=exact_r_distribution(potentials, total_mass=1), + theta=exact_theta_distribution(potentials), + phi=exact_phi_distribution(potentials), + ) + + assert False + + +############################################################################## + + +class Test_SCFSampler(BaseTest_SCFSamplerBase): + """Test :class:`sample_scf.core.SCFSample`.""" + + @pytest.fixture(scope="class") + def rv_cls(self): + return SCFSampler + + @pytest.fixture(scope="class") + def rv_cls_args(self): + return ("interp",) # TODO! iterate over this + + @pytest.fixture(scope="class") + def rv_cls_kw(self): + # return dict(rgrid=rgrid, thetagrid=tgrid, phigrid=pgrid) + return {} + + def setup_class(self): + # TODO! make sure these are right! + self.expected_rvs = { + 0: dict(r=2.8473287899985, theta=1.473013568997 * u.rad, phi=3.4482969442579 * u.rad), + 1: dict(r=2.8473287899985, theta=1.473013568997 * u.rad, phi=3.4482969442579 * u.rad), + 2: dict( + r=[55.79997672576021, 2.831600636133138, 66.85343958872159, 5.435971037191061], + theta=[0.3651795356642, 1.476190768304, 0.3320725154563, 1.126711132070] * u.rad, + phi=[6.076027676095, 3.438361627636, 6.11155607905, 4.491321348792] * u.rad, + ), + } + + # =============================================================== + # Method Tests + + @pytest.mark.parametrize( + "r, theta, phi, expected", + [ + (0, 0, 0, [0, 0.5, 0]), + (1, 0, 0, [0.2505, 0.5, 0]), + ([0, 1], [0, 0], [0, 0], [[0, 0.5, 0], [0.2505, 0.5, 0]]), + ], + ) + def test_cdf(self, sampler, r, theta, phi, expected): + """Test :meth:`sample_scf.base_multivariate.SCFSamplerBase.cdf`.""" + cdf = sampler.cdf(r, theta, phi) + assert allclose(cdf, expected, atol=1e-16) + + # also test shape + assert tuple(atleast_1d(squeeze((*shape(r), 3)))) == cdf.shape diff --git a/sample_scf/tests/test_init.py b/sample_scf/tests/test_init.py index 5e10999..f5a81d0 100644 --- a/sample_scf/tests/test_init.py +++ b/sample_scf/tests/test_init.py @@ -2,17 +2,18 @@ """Some basic tests.""" -__all__ = [ - "test_expected_imports", -] - - ############################################################################## # IMPORTS -# BUILT-IN +# STDLIB import inspect +# LOCAL +import sample_scf +from sample_scf.core import SCFSampler +from sample_scf.exact import ExactSCFSampler +from sample_scf.interpolated import InterpolatedSCFSampler + ############################################################################## # TESTS ############################################################################## @@ -20,17 +21,8 @@ def test_expected_imports(): """Test can import expected modules and objects.""" - # LOCAL - import sample_scf - assert inspect.ismodule(sample_scf) - -# /def - - -# ------------------------------------------------------------------- - - -############################################################################## -# END + assert sample_scf.SCFSampler is SCFSampler + assert sample_scf.ExactSCFSampler is ExactSCFSampler + assert sample_scf.InterpolatedSCFSampler is InterpolatedSCFSampler diff --git a/sample_scf/tests/test_representation.py b/sample_scf/tests/test_representation.py new file mode 100644 index 0000000..4fa7165 --- /dev/null +++ b/sample_scf/tests/test_representation.py @@ -0,0 +1,516 @@ +# -*- coding: utf-8 -*- + +"""Testing :mod:`scample_scf.representation`.""" + + +############################################################################## +# IMPORTS + +# STDLIB +import contextlib +import re + +# THIRD PARTY +import astropy.units as u +import pytest +from astropy.coordinates import CartesianRepresentation, Distance, PhysicsSphericalRepresentation +from astropy.coordinates import SphericalRepresentation, UnitSphericalRepresentation +from astropy.units import Quantity, UnitConversionError, allclose +from numpy import eye, pi, sin, cos, ndarray, inf, array, atleast_1d + +# LOCAL +from sample_scf.representation import FiniteSphericalRepresentation, r_of_zeta, theta_of_x +from sample_scf.representation import x_of_theta, zeta_of_r + +############################################################################## +# TESTS +############################################################################## + + +class Test_FiniteSphericalRepresentation: + """Test :class:`sample_scf.FiniteSphericalRepresentation`.""" + + def setup_class(self): + """Setup class for testing.""" + + self.phi = 0 * u.rad + self.x = 0 + self.zeta = 0 + + @pytest.fixture + def rep_cls(self): + return FiniteSphericalRepresentation + + @pytest.fixture + def differentials(self): + return None # TODO! maybe as a class-level parametrize + + @pytest.fixture + def scale_radius(self): + return 8 * u.kpc + + @pytest.fixture + def rep(self, rep_cls, scale_radius, differentials): + return rep_cls( + self.phi, + x=self.x, + zeta=self.zeta, + scale_radius=scale_radius, + copy=False, + differentials=differentials, + ) + + # =============================================================== + # Method Tests + + def test_init_simple(self, rep_cls): + """ + Test initializing an FiniteSphericalRepresentation. + This is actually mostly tested by the pytest fixtures, which will fail + if bad input is given. + """ + rep = rep_cls(phi=1 * u.deg, x=-1, zeta=0, scale_radius=8 * u.kpc) + + assert isinstance(rep, rep_cls) + assert (rep.phi, rep.x, rep.zeta) == (1 * u.deg, -1, 0) + assert rep.scale_radius == 8 * u.kpc + + def test_init_dimensionless_radius(self, rep_cls): + """Test initialization when scale radius is unit-less.""" + rep = rep_cls(phi=1 * u.deg, x=-1, zeta=0, scale_radius=8) + + assert isinstance(rep, rep_cls) + assert (rep.phi, rep.x, rep.zeta) == (1 * u.deg, -1, 0) + assert rep.scale_radius == 8 + + def test_init_x_is_theta(self, rep_cls): + """Test initialization when x has angular units.""" + rep = rep_cls(phi=1 * u.deg, x=90 * u.deg, zeta=0, scale_radius=8 * u.kpc) + + assert isinstance(rep, rep_cls) + assert rep.phi == 1 * u.deg + assert allclose(rep.x, 0, atol=1e-16) + assert rep.zeta == 0 + assert rep.scale_radius == 8 * u.kpc + + def test_init_zeta_is_r(self, rep_cls): + """Test initialization when zeta has units of length.""" + # When scale_radius is None + rep = rep_cls(phi=1 * u.deg, x=-1, zeta=8 * u.kpc) + assert isinstance(rep, rep_cls) + assert (rep.phi, rep.x, rep.zeta) == (1 * u.deg, -1, 7 / 9) + assert rep.scale_radius == 1 * u.kpc + + # When scale_radius is not None + rep = rep_cls(phi=1 * u.deg, x=-1, zeta=8 * u.kpc, scale_radius=8 * u.kpc) + assert isinstance(rep, rep_cls) + assert (rep.phi, rep.x, rep.zeta) == (1 * u.deg, -1, 0) + assert rep.scale_radius == 8 * u.kpc + + # Scale radius must match the units of zeta + with pytest.raises(TypeError, match="scale_radius must be a Quantity"): + rep_cls(phi=1 * u.deg, x=-1, zeta=8 * u.kpc, scale_radius=8) + + def test_init_needs_scale_radius(self, rep_cls): + """ + Test initialization when zeta is correctly unit-less, but no scale + radius was given. + """ + with pytest.raises(ValueError, match="if zeta is not a length"): + rep_cls(phi=1 * u.deg, x=-1, zeta=0) + + def test_init_x_out_of_bounds(self, rep_cls): + """ + Test initialization when transformed inclination angle is out of bounds. + """ + with pytest.raises(ValueError, match=re.escape("inclination angle(s) must be within")): + rep_cls(phi=1 * u.deg, x=-2, zeta=1 * u.kpc) + + with pytest.raises(ValueError, match=re.escape("inclination angle(s) must be within")): + rep_cls(phi=1 * u.deg, x=2, zeta=1 * u.kpc) + + def test_init_zeta_out_of_bounds(self, rep_cls): + """Test initialization when transformed distance is out of bounds.""" + with pytest.raises(ValueError, match="distances must be within"): + rep_cls(phi=1 * u.deg, x=0, zeta=-2, scale_radius=1) + + with pytest.raises(ValueError, match="distances must be within"): + rep_cls(phi=1 * u.deg, x=0, zeta=2, scale_radius=1) + + # ------------------------------------------- + + def test_phi(self, rep_cls, rep): + """Test :attr:`sample_scf.FiniteSphericalRepresentation.phi`.""" + # class + assert isinstance(rep_cls.phi, property) + + # instance + assert rep.phi is rep._phi + assert isinstance(rep.phi, Quantity) + assert rep.phi.unit.physical_type == "angle" + + def test_x(self, rep_cls, rep): + """Test :attr:`sample_scf.FiniteSphericalRepresentation.x`.""" + # class + assert isinstance(rep_cls.x, property) + + # instance + assert rep.x is rep._x + assert isinstance(rep.x, Quantity) + assert rep.x.unit.physical_type == "dimensionless" + + def test_zeta(self, rep_cls, rep): + """Test :attr:`sample_scf.FiniteSphericalRepresentation.zeta`.""" + # class + assert isinstance(rep_cls.zeta, property) + + # instance + assert rep.zeta is rep._zeta + assert isinstance(rep.zeta, Quantity) + assert rep.zeta.unit.physical_type == "dimensionless" + + def test_scale_radius(self, rep_cls, rep): + """Test :attr:`sample_scf.FiniteSphericalRepresentation.scale_radius`.""" + # class + assert isinstance(rep_cls.scale_radius, property) + + # instance + assert rep.scale_radius is rep._scale_radius + assert isinstance(rep.scale_radius, Quantity) + assert rep.scale_radius.unit.physical_type == "length" + + # ----------------------------------------------------- + # corresponding PhysicsSpherical coordinates + + def test_theta(self, rep_cls, rep): + """Test :attr:`sample_scf.FiniteSphericalRepresentation.theta`.""" + # class + assert isinstance(rep_cls.theta, property) + + # instance + assert rep.theta == rep.calculate_theta_of_x(rep.x) + assert isinstance(rep.theta, Quantity) + assert rep.theta.unit.physical_type == "angle" + + def test_r(self, rep_cls, rep): + """Test :attr:`sample_scf.FiniteSphericalRepresentation.r`.""" + # class + assert isinstance(rep_cls.r, property) + + # instance + assert rep.r == rep.calculate_r_of_zeta(rep.zeta) + assert isinstance(rep.r, Distance) + assert rep.r.unit == rep.scale_radius.unit + assert rep.r.unit.physical_type == "length" + + # ----------------------------------------------------- + # conversion functions + # TODO! from below tests + + @pytest.mark.skip("TODO!") + def test_calculate_zeta_of_r(self, rep): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.calculate_zeta_of_r`.""" + assert False + + @pytest.mark.skip("TODO!") + def test_calculate_r_of_zeta(self, rep): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.calculate_r_of_zeta`.""" + assert False + + @pytest.mark.skip("TODO!") + def test_calculate_x_of_theta(self, rep): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.calculate_x_of_theta`.""" + assert False + + @pytest.mark.skip("TODO!") + def test_calculate_theta_of_x(self, rep): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.calculate_theta_of_x`.""" + assert False + + # ----------------------------------------------------- + + @pytest.mark.skip("TODO!") + def test_unit_vectors(self, rep): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.unit_vectors`.""" + assert False + + @pytest.mark.skip("TODO!") + def test_scale_factors(self, rep): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.scale_factors`.""" + assert False + + # -------------------------------------------- + + def test_represent_as_PhysicsSphericalRepresentation(self, rep): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.represent_as`.""" + r = rep.represent_as(PhysicsSphericalRepresentation) + assert allclose(r.phi, rep.phi) + assert allclose(r.theta, rep.theta) + assert allclose(r.r, rep.r) + + def test_represent_as_SphericalRepresentation(self, rep): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.represent_as`.""" + r = rep.represent_as(SphericalRepresentation) + assert allclose(r.lon, rep.phi) + assert allclose(r.lat, 90 * u.deg - rep.theta) + assert allclose(r.distance, rep.r) + + def test_represent_as_UnitSphericalRepresentation(self, rep): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.represent_as`.""" + r = rep.represent_as(UnitSphericalRepresentation) + assert allclose(r.lon, rep.phi) + assert allclose(r.lat, 90 * u.deg - rep.theta) + + def test_represent_as_CartesianRepresentation(self, rep): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.represent_as`.""" + assert rep.represent_as(CartesianRepresentation) == rep.to_cartesian() + + # -------------------------------------------- + + def test_to_cartesian(self, rep): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.to_cartesian`.""" + r = rep.to_cartesian() + + x = rep.r * sin(rep.theta) * cos(rep.phi) + y = rep.r * sin(rep.theta) * sin(rep.phi) + z = rep.r * cos(rep.theta) + + assert allclose(r.x, x) + assert allclose(r.y, y) + assert allclose(r.z, z) + + def test_from_cartesian(self, rep_cls, rep, scale_radius): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.from_cartesian`.""" + cart = rep.to_cartesian() + + # Not passing a scale radius + r = rep_cls.from_cartesian(cart) + assert rep != r + + r = rep_cls.from_cartesian(cart, scale_radius=scale_radius) + assert allclose(rep.phi, r.phi) + assert allclose(rep.theta, r.theta) + assert allclose(rep.zeta, r.zeta) + + def test_from_physicsspherical(self, rep_cls, rep, scale_radius): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.from_physicsspherical`.""" + psphere = rep.represent_as(PhysicsSphericalRepresentation) + + # Not passing a scale radius + r = rep_cls.from_physicsspherical(psphere) + assert rep != r + + r = rep_cls.from_physicsspherical(psphere, scale_radius=scale_radius) + assert allclose(rep.phi, r.phi) + assert allclose(rep.theta, r.theta) + assert allclose(rep.zeta, r.zeta) + + def test_transform(self, rep, scale_radius): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.transform`.""" + # Identity + matrix = eye(3) + r = rep.transform(matrix, scale_radius) + assert allclose(rep.phi, r.phi) + assert allclose(rep.theta, r.theta) + assert allclose(rep.zeta, r.zeta) + + # alternating coordinates + matrix = array([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + r = rep.transform(matrix, scale_radius) + assert allclose(rep.phi, r.phi - pi / 2 * u.rad) + assert allclose(rep.theta, r.theta) + assert allclose(rep.zeta, r.zeta) + + def test_norm(self, rep): + """Test :meth:`sample_scf.FiniteSphericalRepresentation.norm`.""" + assert rep.norm() == abs(rep.zeta) + + +############################################################################## + + +def test_zeta_of_r_fail(): + """Test :func:`sample_scf.representation.r_of_zeta` with wrong r type.""" + # Negative value + with pytest.raises(ValueError, match="r must be >= 0"): + zeta_of_r(-1) + + # Type mismatch + with pytest.raises(TypeError, match="scale radius cannot be a Quantity"): + zeta_of_r(1, scale_radius=8 * u.kpc) + + # Negative value + with pytest.raises(ValueError, match="scale_radius must be > 0"): + zeta_of_r(1, scale_radius=-1) + + +@pytest.mark.parametrize( + "r, scale_radius, expected, warns", + [ + (0, None, -1.0, False), + (1, None, 0.0, False), + (inf, None, 1.0, RuntimeWarning), # edge case + (10, None, 9 / 11, False), + ([0, 1, inf], None, [-1.0, 0.0, 1.0], False), + ([0, 1, inf], None, [-1.0, 0.0, 1.0], False), + ], +) +def test_zeta_of_r_ArrayLike(r, scale_radius, expected, warns): + """Test :func:`sample_scf.representation.r_of_zeta` with wrong r type.""" + with pytest.warns(warns) if warns is not False else contextlib.nullcontext(): + zeta = zeta_of_r(r, scale_radius=scale_radius) # TODO! scale radius + + assert allclose(zeta, expected) + assert not isinstance(zeta, Quantity) + + +def test_zeta_of_r_Quantity_fail(): + """Test :func:`sample_scf.representation.r_of_zeta`: r=Quantity, with errors.""" + # Wrong units + with pytest.raises(UnitConversionError, match="r must have units of length"): + zeta_of_r(8 * u.s) + + # Negative value + with pytest.raises(ValueError, match="r must be >= 0"): + zeta_of_r(-1 * u.kpc) + + # Type mismatch + with pytest.raises(TypeError, match="scale_radius must be a Quantity"): + zeta_of_r(8 * u.kpc, scale_radius=1) + + # Wrong units + with pytest.raises(UnitConversionError, match="scale_radius must have units of length"): + zeta_of_r(8 * u.kpc, scale_radius=1 * u.s) + + # Non-positive value + with pytest.raises(ValueError, match="scale_radius must be > 0"): + zeta_of_r(1 * u.kpc, scale_radius=-1 * u.kpc) + + +@pytest.mark.parametrize( + "r, scale_radius, expected, warns", + [ + (0 * u.kpc, None, -1.0, False), + (1 * u.kpc, None, 0.0, False), + (inf * u.kpc, None, 1.0, RuntimeWarning), # edge case + (10 * u.km, None, 9 / 11, False), + ([0, 1, inf] * u.kpc, None, [-1.0, 0.0, 1.0], False), + ([0, 1, inf] * u.km, None, [-1.0, 0.0, 1.0], False), + ], +) +def test_zeta_of_r_Quantity(r, scale_radius, expected, warns): + """Test :func:`sample_scf.representation.r_of_zeta` with wrong r type.""" + with pytest.warns(warns) if warns is not False else contextlib.nullcontext(): + zeta = zeta_of_r(r, scale_radius=scale_radius) # TODO! scale radius + + assert allclose(zeta, expected) + assert isinstance(zeta, Quantity) + assert zeta.unit.physical_type == "dimensionless" + + +@pytest.mark.parametrize("r", [0 * u.kpc, 1 * u.kpc, inf * u.kpc, [0, 1, inf] * u.kpc]) +def test_zeta_of_r_roundtrip(r): + """Test zeta and r round trip. Note that Quantities don't round trip.""" + assert allclose(r_of_zeta(zeta_of_r(r, None), 1), r.value) + # TODO! scale radius + + +# ----------------------------------------------------- + + +@pytest.mark.parametrize( + "zeta, expected, warns", + [ + (-1.0, 0, False), + (0.0, 1, False), + (1.0, inf, RuntimeWarning), # edge case + (array([-1.0, 0.0, 1.0]), [0, 1, inf], False), + ], +) +def test_r_of_zeta(zeta, expected, warns): + """Test :func:`sample_scf.representation.r_of_zeta`.""" + with pytest.warns(warns) if warns is not False else contextlib.nullcontext(): + r = r_of_zeta(zeta, 1) + + assert allclose(r, expected) # TODO! scale_radius + assert isinstance(r, ndarray) + + +def test_r_of_zeta_fail(): + """Test when the input is bad.""" + # Under lower bound + with pytest.raises(ValueError, match="zeta must be in"): + r_of_zeta(-2) + + # Above upper bound + with pytest.raises(ValueError, match="zeta must be in"): + r_of_zeta(2) + + +@pytest.mark.parametrize( + "zeta, scale_radius, expected", + [ + (0, 1 * u.pc, 1 * u.pc), + ], +) +def test_r_of_zeta_unit_input(zeta, expected, scale_radius): + """Test when input units.""" + assert allclose(r_of_zeta(zeta, scale_radius), expected) + + +@pytest.mark.skip("TODO!") +@pytest.mark.parametrize("zeta", [-1, 0, 1, [-1, 0, 1]]) +def test_r_of_zeta_roundtrip(zeta): + """Test zeta and r round trip. Note that Quantities don't round trip.""" + assert allclose(zeta_of_r(r_of_zeta(zeta, None), None), zeta) + + +# ----------------------------------------------------- + + +@pytest.mark.parametrize( + "theta, expected", + [ + (0, 1), + (pi / 2, 0), + (pi, -1), + ([0, pi / 2, pi], [1, 0, -1]), # array + # with units + (0 << u.rad, 1), + (pi / 2 << u.rad, 0), + (pi << u.rad, -1), + ([pi, pi / 2, 0] << u.rad, [-1, 0, 1]), # array + ], +) +def test_x_of_theta(theta, expected): + """Test :func:`sample_scf.representation.x_of_theta`.""" + assert allclose(x_of_theta(theta), expected, atol=1e-16) + + +@pytest.mark.parametrize("theta", [0, pi / 2, pi, [0, pi / 2, pi]]) # TODO! units +def test_theta_of_x_roundtrip(theta): + """Test theta and x round trip. Note that Quantities don't round trip.""" + assert allclose(theta_of_x(x_of_theta(theta)), theta << u.rad) + + +# ----------------------------------------------------- + + +@pytest.mark.parametrize( + "x, expected", + [ + (-1, pi), + (0, pi / 2), + (1, 0), + ([-1, 0, 1], [pi, pi / 2, 0]), # array + ], +) +def test_theta_of_x(x, expected): + """Test :func:`sample_scf.representation.theta_of_x`.""" + assert allclose(theta_of_x(x), expected << u.rad) # TODO! units + + +@pytest.mark.parametrize("x", [-1, 0, 1, [-1, 0, 1]]) +def test_roundtrip(x): + """Test x and theta round trip. Note that Quantities don't round trip.""" + assert allclose(x_of_theta(theta_of_x(x)), x, atol=1e-16) # TODO! units diff --git a/sample_scf/utils.py b/sample_scf/utils.py new file mode 100644 index 0000000..650ca2e --- /dev/null +++ b/sample_scf/utils.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- + +"""Local Utilities.""" + + +__all__ = ["plot_corner_samples", "log_prior", "log_prob"] + + +############################################################################## +# IMPORTS + +# STDLIB +from typing import Optional, Tuple, Union + +# THIRD PARTY +import astropy.units as u +import corner +from astropy.coordinates import BaseRepresentation, CartesianRepresentation +from galpy.potential import SCFPotential + +# PROJECT-SPECIFIC +from matplotlib.figure import Figure +from numpy import abs, arctan2, floating, inf, isfinite, log, nan_to_num, ndarray, sign, sqrt +from numpy import square, sum + +############################################################################## +# CODE +############################################################################## + + +def plot_corner_samples( + samples: Union[BaseRepresentation, ndarray], + r_limit: float = 1_000.0, + *, + figs: Optional[Tuple[Figure, Figure]] = None, + include_log: bool = True, + **kw, +) -> Tuple[Figure, Figure]: + """Plot samples. + + Parameters + ---------- + *samples : BaseRepresentation or (N, 3) ndarray + If an `numpy.ndarray`, samples should be in Cartesian coordinates. + r_limit : float + Largerst radius that should be plotted. + Values larger will be masked. + figs : tuple[Figure, Figure] or None, optional keyword-only + include_log : bool, optional keyword-only + + Returns + ------- + tuple[Figure, Figure] + """ + # Convert to ndarray + arr: ndarray + if isinstance(samples, BaseRepresentation): + arr = samples.represent_as(CartesianRepresentation)._values.view(float).reshape(-1, 3) + else: + arr = samples + + # Correcting for large r + r = sqrt(sum(square(arr), axis=1)) + mask = r <= r_limit + + # plot stuff + truths = [0, 0, 0] + hist_kwargs = {"density": True} + hist_kwargs.update(kw.pop("hist_kwargs", {})) + kw.pop("plot_contours", None) + kw.pop("plot_density", None) + + # ----------- + # normal plot + + labels = ["x", "y", "z"] + + fig1 = corner.corner( + arr[mask, :], + labels=labels, + raster=True, + bins=50, + truths=truths, + show_titles=True, + title_kwargs={"fontsize": 12}, + label_kwargs={"fontsize": 13}, + hist_kwargs=hist_kwargs, + plot_contours=False, + plot_density=False, + fig=None if figs is None else figs[0], + **kw, + ) + fig1.suptitle("Samples") + + # ----------- + # logarithmic plot + + if not include_log: + fig2 = None + + else: + + labels = [r"$\log_{10}(x)$", r"$\log_{10}(y)$", r"$\log_{10}(z)$"] + + fig2 = corner.corner( + nan_to_num(sign(arr) * log(abs(arr))), + labels=labels, + raster=True, + bins=50, + truths=truths, + show_titles=True, + title_kwargs={"fontsize": 12}, + label_kwargs={"fontsize": 13}, + fig=None if figs is None else figs[1], + plot_contours=False, + plot_density=False, + hist_kwargs=hist_kwargs, + **kw, + ) + fig2.suptitle("Samples") + + return fig1, fig2 + + +def log_prior(R: floating, r_limit: floating) -> floating: + """Log-Prior. + + Parameters + ---------- + R : float + r_limit : float + + Returns + ------- + float + """ + # outside + if r_limit is not None and R > r_limit: + return -inf + return 0.0 + + +def log_prob( + x: ndarray, /, pot: SCFPotential, rho0: u.Quantity, r_limit: Optional[floating] = None +) -> floating: + """Log-Probability. + + Parameters + ---------- + x : (3, ) array + Cartesian coordinates in kpc + pot : `galpy.potential.SCFPotential` + rho0 : Quantity + The central density. + r_limit : float + + Returns + ------- + float + """ + # Checks + if rho0 == 0: + raise ValueError("`mtot` cannot be 0.") + elif r_limit == 0: + raise ValueError("`r_limit` cannot be 0.") + + # convert Cartesian to Cylindrical coordinates + R = sqrt(sum(square(x))) + z = x[-1] + phi = arctan2(x[1], x[0]) + + # calculate log-prior + lp = log_prior(R, r_limit) + if not isfinite(lp): + return lp + + # the density as a likelihood + logrho0 = log(rho0.value) + dens = pot.dens(R, z, phi).to_value(rho0.unit) + + logdens = nan_to_num(log(dens), copy=False, nan=logrho0, posinf=logrho0) + ll = logdens - logrho0 # normalize the density + + return lp + ll diff --git a/setup.cfg b/setup.cfg index f58a21e..4bf7c73 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,6 +19,7 @@ setup_requires = setuptools_scm install_requires = astropy extension_helpers + galpy matplotlib mypy numpy >= 1.20 @@ -31,7 +32,7 @@ docs = sphinx-astropy [options.package_data] -sample_scf = data/* +sample_scf = data/*, tests/scf_coeffs.npz [tool:pytest] testpaths = "sample_scf" "docs" diff --git a/setup.py b/setup.py index 88b86e2..81872f8 100755 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ # NOTE: The configuration for the package, including the name, version, and # other information are set in the setup.cfg file. -# BUILT-IN +# STDLIB import os import sys @@ -13,6 +13,8 @@ from extension_helpers import get_extensions from setuptools import setup +# from mypyc.build import mypycify + # First provide helpful messages if contributors try and run legacy commands # for tests or docs. @@ -77,10 +79,23 @@ version = '{version}' """.lstrip() +# # TODO! model after https://github.com/python/mypy/blob/master/setup.py +# mypyc_targets = [ +# os.path.join("sample_scf", x) +# for x in ("__init__.py", "base.py", "core.py", "utils.py", "interpolated.py", +# "exact.py") +# ] +# # The targets come out of file system apis in an unspecified +# # order. Sort them so that the mypyc output is deterministic. +# mypyc_targets.sort() + setup( use_scm_version={ "write_to": os.path.join("sample_scf", "version.py"), "write_to_template": VERSION_TEMPLATE, }, ext_modules=get_extensions(), + # name="sample_scf", + # packages=["sample_scf"], + # ext_modules=mypycify(["--disallow-untyped-defs", *mypyc_targets]), ) diff --git a/tox.ini b/tox.ini index 9f2f175..f09cdea 100644 --- a/tox.ini +++ b/tox.ini @@ -13,6 +13,7 @@ isolated_build = true indexserver = NIGHTLY = https://pypi.anaconda.org/scipy-wheels-nightly/simple + [testenv] # Suppress display of matplotlib plots generated during docs build setenv = MPLBACKEND=agg @@ -69,6 +70,7 @@ commands = cov: pytest --pyargs sample_scf {toxinidir}/docs --cov sample_scf --cov-config={toxinidir}/setup.cfg {posargs} cov: coverage xml -o {toxinidir}/coverage.xml + [testenv:build_docs] changedir = docs description = invoke sphinx-build to build the HTML docs @@ -77,6 +79,7 @@ commands = pip freeze sphinx-build -W -b html . _build/html + [testenv:linkcheck] changedir = docs description = check the links in the HTML docs @@ -85,9 +88,17 @@ commands = pip freeze sphinx-build -W -b linkcheck . _build/html + [testenv:codestyle] skip_install = true changedir = . description = check code style, e.g. with flake8 deps = flake8 commands = flake8 sample_scf --count --max-line-length=100 + + +[flake8] +max-line-length = 100 +ignore = + E203, # space before colon + W503 # line break before binary operator