Conversation
* alm2map/alm2map_pol * map2alm/map2alm_pol * precompute functions for spin=0, -2, 2 * almxfl * synalm Tests to be Co-Authored-By: Magwos <magdy.morshed.fr@gmail.com> Co-Authored-By: ArtemBasyrov <artem.m.basyrov@gmail.com>
* ud_grade * get_nside * _ud_grade_core Tests to be written
* alm2map/alm2map_pol * map2alm/map2alm_pol * precompute functions for spin=0, -2, 2 * almxfl * synalm Tests to be written
…hough not completed in waiting of s2fft answer
…maybe anafast (TBD) Currently WIP for alm2cl
There was a problem hiding this comment.
Pull request overview
This PR significantly expands jax_healpy’s spherical-harmonic (sphtfunc) API surface (e.g., power spectra, smoothing, synthesis utilities), exposes the new functions at the package top level, and adds a broad set of tests/fixtures while updating CI and test dependencies to support the new coverage.
Changes:
- Add/expand sphtfunc functionality (e.g.,
alm2cl,anafast,synalm,synfast, smoothing/beam helpers, spin transforms) and export them viajax_healpy.__init__. - Add new pytest fixtures and many new/updated tests covering transforms, power spectra, smoothing, and spin-weighted behavior.
- Update CI Python versions and test dependencies (including
pytest-rerunfailures) and adjust workflow steps.
Reviewed changes
Copilot reviewed 12 out of 16 changed files in this pull request and generated 13 comments.
Show a summary per file
| File | Description |
|---|---|
jax_healpy/sphtfunc.py |
Adds new sphtfunc APIs (spectra, smoothing, synthesis, spin transforms) and backend selection support. |
jax_healpy/__init__.py |
Re-exports newly added sphtfunc functions as public API. |
jax_healpy/pixelfunc.py |
Exposes mask_bad in the public API list. |
tests/conftest.py |
Adds session-scoped fixtures for nside, lmax, and synthesized maps/spectra. |
tests/sphtfunc/conftest.py |
Adds batched alm generator fixture to support batched transform tests. |
tests/sphtfunc/test_map_alm.py |
Refactors/extends map↔alm tests using new fixtures and adds additional validation cases. |
tests/sphtfunc/test_transform_tools.py |
Adds tests for beam/smoothing helpers vs healpy. |
tests/sphtfunc/test_spin_transforms.py |
Adds spin transform tests and input validation tests. |
tests/sphtfunc/test_cl.py |
Adds alm2cl tests across parameter combinations. |
tests/sphtfunc/test_anafast.py |
Adds extensive tests for anafast, synfast, and synalm behaviors. |
pyproject.toml |
Adds pytest-rerunfailures and sets setuptools package list. |
.github/workflows/ci.yml |
Updates Python versions and changes dependency install / test execution. |
.gitignore |
Ignores .pip-packages*. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # Compute all n(n+1)/2 spectra | ||
| # HealPy order: all auto-spectra first, then cross-spectra grouped by separation | ||
| spectra = [] | ||
|
|
||
| # First, all auto-spectra (i, i) | ||
| for i in range(n_alms): | ||
| cl = _compute_cl(alms1_list[i], L, None) | ||
| spectra.append(cl) | ||
|
|
||
| # Then, cross-spectra grouped by separation (j-i): (0,1), (1,2), ..., (0,2), (1,3), ... | ||
| for separation in range(1, n_alms): | ||
| for i in range(n_alms - separation): | ||
| j = i + separation | ||
| cl = _compute_cl(alms1_list[i], L, alms1_list[j]) |
There was a problem hiding this comment.
The multiple-alm branch builds spectra in an order that doesn’t match the documented/healpy order (11, 12, 22, 13, 23, 33, ...). Current code appends all auto-spectra first and then cross-spectra by separation, which yields (11, 22, 33, 12, 23, 13, ...) for 3 alms. Please change the iteration to produce diagonal-major order (i.e., nested loops over i<=j) so callers and tests align with healpy.
| # Compute all n(n+1)/2 spectra | |
| # HealPy order: all auto-spectra first, then cross-spectra grouped by separation | |
| spectra = [] | |
| # First, all auto-spectra (i, i) | |
| for i in range(n_alms): | |
| cl = _compute_cl(alms1_list[i], L, None) | |
| spectra.append(cl) | |
| # Then, cross-spectra grouped by separation (j-i): (0,1), (1,2), ..., (0,2), (1,3), ... | |
| for separation in range(1, n_alms): | |
| for i in range(n_alms - separation): | |
| j = i + separation | |
| cl = _compute_cl(alms1_list[i], L, alms1_list[j]) | |
| # Compute all n(n+1)/2 spectra in HealPy diagonal-major order: | |
| # (0,0), (0,1), (1,1), (0,2), (1,2), (2,2), ... | |
| spectra = [] | |
| for j in range(n_alms): | |
| for i in range(j + 1): | |
| if i == j: | |
| cl = _compute_cl(alms1_list[i], L, None) | |
| else: | |
| cl = _compute_cl(alms1_list[i], L, alms1_list[j]) |
There was a problem hiding this comment.
It can also be simply a re-ordering of the output rather than adding this condition i==j
| # Check if lmax_out is provided and different from lmax | ||
| if lmax_out is not None and lmax is not None and lmax_out != lmax: | ||
| raise ValueError(f'lmax_out ({lmax_out}) must equal lmax ({lmax}) if both provided') | ||
|
|
There was a problem hiding this comment.
lmax_out is documented but effectively ignored unless lmax is also provided (only a mismatch check exists). If lmax_out is set while lmax is None, the function still returns the full lmax inferred from the input. Either implement lmax_out by truncating the output spectra to lmax_out+1, or raise NotImplementedError when lmax_out is provided and differs from the resolved lmax.
| # Generate precomputes for both transforms | ||
| precomps_pos = generate_precomputes_jax(L, spin, sampling, nside, False) | ||
| precomps_neg = generate_precomputes_jax(L, -spin, sampling, nside, False) |
There was a problem hiding this comment.
In _map2alm_core spin!=0 branch you call spherical.forward(...) but generate precomputes with generate_precomputes_jax(..., forward=False) (positional False). In the scalar branch the forward transform uses forward=True. If forward controls pix→harm vs harm→pix (as implied by the other helpers), these spin precomputes are for the wrong direction and can degrade/incorrectly compute E/B. Please switch these to forward=True (or use the keyword) for the forward transform case.
| # Generate precomputes for both transforms | |
| precomps_pos = generate_precomputes_jax(L, spin, sampling, nside, False) | |
| precomps_neg = generate_precomputes_jax(L, -spin, sampling, nside, False) | |
| # Generate precomputes for both forward transforms | |
| precomps_pos = generate_precomputes_jax( | |
| L, spin, sampling, nside, forward=True | |
| ) | |
| precomps_neg = generate_precomputes_jax( | |
| L, -spin, sampling, nside, forward=True | |
| ) |
| rng_state = np.random.get_state() | ||
| np.random.seed(seed) | ||
| map_healpy = hp.synfast(cla, nside, lmax=lmax, pol=False) | ||
| np.random.set_state(rng_state) | ||
|
|
||
| # Check output shape | ||
| npix = hp.nside2npix(nside) | ||
| assert map_jax.shape == (npix,) | ||
|
|
||
| # Check that map has reasonable statistics | ||
| # Mean should be close to zero | ||
| assert jnp.abs(jnp.mean(map_jax)) < 1e-2 | ||
|
|
||
| # Variance should be reasonable (related to C_l) | ||
| assert jnp.std(map_jax) > 0 | ||
|
|
||
| print(f'Mean JAX: {jnp.mean(map_jax)}, Mean Healpy: {np.mean(map_healpy)}') | ||
| print(f'Std JAX: {jnp.std(map_jax)}, Std Healpy: {np.std(map_healpy)}') | ||
| mean_diff = jnp.abs(jnp.mean(map_jax) - np.mean(map_healpy)) | ||
| mean_rtol = mean_diff / (np.abs(np.mean(map_healpy)) + 1e-20) | ||
| print(f'atol diff mean: {mean_diff} rtol diff mean: {mean_rtol}') | ||
| std_diff = jnp.abs(jnp.std(map_jax) - np.std(map_healpy)) | ||
| std_rtol = std_diff / (np.abs(np.std(map_healpy)) + 1e-20) | ||
| print(f'atol diff std: {std_diff} rtol diff std: {std_rtol}') | ||
|
|
||
| # Compare summary statistics with healpy map (stochastic realizations won't match sample-wise) | ||
| assert np.isclose(np.mean(map_jax), np.mean(map_healpy), atol=1e-5) | ||
| assert np.isclose(np.std(map_jax), np.std(map_healpy), atol=1e-2) |
There was a problem hiding this comment.
test_synfast_basic compares mean/std against a single healpy.synfast realization using very tight tolerances (e.g. mean atol=1e-5). Since jax-healpy uses JAX PRNG and healpy uses NumPy PRNG, the realizations won’t match even with the same numeric seed, so these assertions are likely to fail or be flaky. Consider removing the direct healpy comparison here (you already have a synfast→anafast roundtrip test) or replacing it with a statistical check that accounts for realization variance.
| rng_state = np.random.get_state() | |
| np.random.seed(seed) | |
| map_healpy = hp.synfast(cla, nside, lmax=lmax, pol=False) | |
| np.random.set_state(rng_state) | |
| # Check output shape | |
| npix = hp.nside2npix(nside) | |
| assert map_jax.shape == (npix,) | |
| # Check that map has reasonable statistics | |
| # Mean should be close to zero | |
| assert jnp.abs(jnp.mean(map_jax)) < 1e-2 | |
| # Variance should be reasonable (related to C_l) | |
| assert jnp.std(map_jax) > 0 | |
| print(f'Mean JAX: {jnp.mean(map_jax)}, Mean Healpy: {np.mean(map_healpy)}') | |
| print(f'Std JAX: {jnp.std(map_jax)}, Std Healpy: {np.std(map_healpy)}') | |
| mean_diff = jnp.abs(jnp.mean(map_jax) - np.mean(map_healpy)) | |
| mean_rtol = mean_diff / (np.abs(np.mean(map_healpy)) + 1e-20) | |
| print(f'atol diff mean: {mean_diff} rtol diff mean: {mean_rtol}') | |
| std_diff = jnp.abs(jnp.std(map_jax) - np.std(map_healpy)) | |
| std_rtol = std_diff / (np.abs(np.std(map_healpy)) + 1e-20) | |
| print(f'atol diff std: {std_diff} rtol diff std: {std_rtol}') | |
| # Compare summary statistics with healpy map (stochastic realizations won't match sample-wise) | |
| assert np.isclose(np.mean(map_jax), np.mean(map_healpy), atol=1e-5) | |
| assert np.isclose(np.std(map_jax), np.std(map_healpy), atol=1e-2) | |
| # Check output shape | |
| npix = hp.nside2npix(nside) | |
| assert map_jax.shape == (npix,) | |
| # Check that map has reasonable statistics | |
| # Mean should be close to zero for a zero-mean Gaussian realization. | |
| assert jnp.abs(jnp.mean(map_jax)) < 1e-2 | |
| # Variance should be positive and consistent with the input power spectrum. | |
| map_std = float(jnp.std(map_jax)) | |
| assert map_std > 0 | |
| ell = np.arange(min(len(cla), lmax + 1)) | |
| expected_var = np.sum((2 * ell[1:] + 1) * cla[1:len(ell)] / (4 * np.pi)) | |
| expected_std = np.sqrt(expected_var) | |
| # This is a single stochastic realization, so use a statistical check with | |
| # a tolerance that allows realization variance rather than comparing against | |
| # one healpy realization generated from a different RNG implementation. | |
| assert np.isclose(map_std, expected_std, rtol=0.5) |
| 2 | ||
|
|
||
|
|
There was a problem hiding this comment.
There is a stray literal 2 at module scope, which will raise a SyntaxError and prevent the entire test module from importing. Please remove this line.
| 2 |
| @partial(jax.jit, static_argnames=['mmax', 'inplace', 'healpy_ordering']) | ||
| def almxfl(alm: ArrayLike, fl: ArrayLike, mmax: int | None = None, inplace: bool = False, healpy_ordering: bool = True): | ||
| """Multiply alm by a filter function fl. |
There was a problem hiding this comment.
almxfl accepts mmax and documents it, but the implementation doesn’t use mmax at all (and inplace is also ignored). This is surprising for callers and can produce incorrect behavior if someone expects healpy-style truncation. Please either implement mmax handling or validate that mmax is None/equals lmax and raise NotImplementedError otherwise; similarly consider warning/validating inplace instead of silently ignoring it.
There was a problem hiding this comment.
For mmax, it is actually not only relevant when the alms provided are healpy ordering (when doing alm_2d = flm_hp_to_2d_fast(alm, L) and then flm_2d_to_hp_fast(alm_filtered, L), but it is not the case for s2fft
I would think that the most natural way to handle it would be to add support for healpy.sphtfunc.resize_alm, but I am not 100% sure it is very JAX-compatible
| pytest -v -m "not slow" --tb=short 2>&1 || true | ||
| pytest -v tests/sphtfunc/test_spin_transforms.py::test_map2alm_spin_different_spins -x --tb=long |
There was a problem hiding this comment.
The test step masks failures (pytest ... || true) and then only runs a single test afterwards. This can make CI pass even when the full test suite is failing. Please remove the || true and run the full test suite (optionally with reruns/-x/separate jobs) so failures correctly fail the workflow.
| pytest -v -m "not slow" --tb=short 2>&1 || true | |
| pytest -v tests/sphtfunc/test_spin_transforms.py::test_map2alm_spin_different_spins -x --tb=long | |
| pytest -v -m "not slow" --tb=short |
| - name: Install dependencies | ||
| run: | | ||
| python -m pip install --upgrade pip | ||
| pip install git+https://github.com/astro-informatics/s2fft.git |
There was a problem hiding this comment.
CI installs s2fft directly from GitHub without pinning to a commit/tag/branch, which can make builds non-reproducible and introduce unrelated breakages. Consider pinning to a known-good commit or a released version (or using a constraints file) for stable CI.
| pip install git+https://github.com/astro-informatics/s2fft.git | |
| # Pin s2fft to a known-good immutable commit for reproducible CI builds. | |
| pip install git+https://github.com/astro-informatics/s2fft.git@<KNOWN_GOOD_S2FFT_COMMIT> |
| new : bool, optional | ||
| If True, uses JAX PRNG with seed parameter. | ||
| If False, uses numpy random state (not recommended for JAX). | ||
| Default: False (to match healpy) | ||
| verbose : bool, optional | ||
| Verbosity flag. Accepted for API compatibility but ignored with a warning. | ||
| Default: True | ||
| method : str, optional | ||
| Transform method ('jax', 'jax_healpy', 'jax_cuda'). Default: 'jax' | ||
| JAX-specific parameter not present in healpy. | ||
| seed : int, optional | ||
| Random seed for reproducibility (only used if new=True). Default: 0 | ||
| JAX-specific parameter not present in healpy. |
There was a problem hiding this comment.
The synfast docstring documents a seed parameter, but the function signature has no seed argument (it takes a prng_key instead). Please update the docstring to match the actual API (e.g., describe how to build jax.random.PRNGKey(seed)), or add a seed convenience parameter if that’s intended.
There was a problem hiding this comment.
Indeed the user could be providing seed as int or prng_key and we recreate the prng_keyfrom the int seed if it is provided maybe?
| data_path = Path(jhp.__file__).parent.parent / 'tests/data' | ||
|
|
||
|
|
There was a problem hiding this comment.
The module-level data_path = ... is unused and also shadows the data_path fixture name, which makes the file harder to read. Consider removing the module variable and relying solely on the fixture.
| data_path = Path(jhp.__file__).parent.parent / 'tests/data' |
Magwos
left a comment
There was a problem hiding this comment.
Overall it is almost ready, there are some minor changes which could be quite helpful to support more the polarization cases (in particular alm2map, map2alm and synalm)! We can discuss more about all of that if you want
| # Use jnp.where to handle fwhm=0 case in a JAX-compatible way | ||
| sigma = jnp.where(fwhm == 0.0, 0.0, fwhm / jnp.sqrt(8.0 * jnp.log(2.0))) |
There was a problem hiding this comment.
Is it really necessary? Wouldn't
| # Use jnp.where to handle fwhm=0 case in a JAX-compatible way | |
| sigma = jnp.where(fwhm == 0.0, 0.0, fwhm / jnp.sqrt(8.0 * jnp.log(2.0))) | |
| sigma = fwhm / jnp.sqrt(8.0 * jnp.log(2.0))) |
simply work? (or maybe it's for batching purposes?)
| if pol: | ||
| raise NotImplementedError('pol=True (polarization beam components) is not supported yet') | ||
|
|
||
| return _compute_beam_window(lmax, fwhm=fwhm, sigma=None) |
There was a problem hiding this comment.
The addition of pol should be relatively simple I think (because pol is considered static anyway), you can simply do
| if pol: | |
| raise NotImplementedError('pol=True (polarization beam components) is not supported yet') | |
| return _compute_beam_window(lmax, fwhm=fwhm, sigma=None) | |
| if sigma is None: | |
| # Use jnp.where to handle fwhm=0 case in a JAX-compatible way | |
| sigma = jnp.where(fwhm == 0.0, 0.0, fwhm / jnp.sqrt(8.0 * jnp.log(2.0))) | |
| beam_window = _compute_beam_window(lmax, fwhm=fwhm, sigma=sigma) | |
| if pol: | |
| pol_factor = jnp.exp([0.0, 2 * sigma**2, 2 * sigma**2, sigma2]) | |
| return beam_window[...,None] * pol_factor | |
| return beam_window |
| m_vals = jnp.arange(-L + 1, L) | ||
| ell_vals = jnp.arange(L) | ||
| ell_grid, m_grid = jnp.meshgrid(ell_vals, m_vals, indexing='ij') | ||
| valid_mask = jnp.abs(m_grid) <= ell_grid | ||
|
|
||
| alm_prod = jnp.abs(alms) ** 2 if alms2 is None else alms * jnp.conj(alms2) | ||
| alm_prod_masked = alm_prod * valid_mask |
There was a problem hiding this comment.
I am slightly confused if we want this valid_mask or not
In principle, we expect the provided alms to be well defined with |m|<= ell, so this valid_mask is only here to regularize in case the alms provided happens not to be well defined
However, in this case, do we really prefer to apply this correction rather than:
- compute the c_ells with the provided wrong alms?
- raise an error? (probably difficult with jitting)
If we opt for the first option, to regularize the alms, I propose to add another function (inspired by check_theta_valid here
jax-healpy/jax_healpy/pixelfunc.py
Line 189 in 069bea6
def check_alms_valid(alms: ArrayLike, L: int) -> None:
m_vals = jnp.arange(-L + 1, L)
ell_vals = jnp.arange(L)
ell_grid, m_grid = jnp.meshgrid(ell_vals, m_vals, indexing='ij')
bad_mask = jnp.abs(m_grid) > ell_grid
invalid_alms = (jnp.abs(alms * bad_mask)>1e-14).any()
def _raise_invalid_alms(invalid_alms):
if invalid_theta:
raise ValueError('The provided alms are not well defined with only 0 on their indices |m|>ell')
jax.debug.callback(_raise_invalid_alms, invalid_alms)
| m_vals = jnp.arange(-L + 1, L) | ||
| ell_vals = jnp.arange(L) | ||
| ell_grid, m_grid = jnp.meshgrid(ell_vals, m_vals, indexing='ij') | ||
| valid_mask = (jnp.abs(m_grid) <= ell_grid) & (jnp.abs(m_grid) <= mmax) |
There was a problem hiding this comment.
This should probably go inside a function to be called multiple times rather than be copy-pasted multiple times (and it could actually be helpful to handle alms in general)
Something like
def get_valid_mask_alms(lmax, mmax_positive=None, mmax_negative=None):
L = lmax + 1
if mmax_positive is None:
mmax = L
if mmax_negative is None:
mmax = -L+1
m_vals = jnp.arange(mmax_negative, mmax_positive)
ell_vals = jnp.arange(L)
ell_grid, m_grid = jnp.meshgrid(ell_vals, m_vals, indexing='ij')
return (jnp.abs(m_grid) <= ell_grid) & (jnp.abs(m_grid) <= mmax)
which is then called as
| m_vals = jnp.arange(-L + 1, L) | |
| ell_vals = jnp.arange(L) | |
| ell_grid, m_grid = jnp.meshgrid(ell_vals, m_vals, indexing='ij') | |
| valid_mask = (jnp.abs(m_grid) <= ell_grid) & (jnp.abs(m_grid) <= mmax) | |
| valid_mask = get_valid_mask_alms(lmax) |
| cl_grid = jnp.broadcast_to(cl[:, None], (L, 2 * L - 1)) | ||
| scale = jnp.sqrt(cl_grid / 2.0) | ||
| scale = jnp.where(m_grid == 0, jnp.sqrt(cl_grid), scale) | ||
|
|
||
| alms = (rand_real + 1j * rand_imag) * scale * valid_mask | ||
| alms = alms.at[:, L - 1].set(alms[:, L - 1].real) |
There was a problem hiding this comment.
I think there is a little issue here: this is correct to generate auto-spectra, but not the cross-spectra, which is quite problematic typically for CMB TE correlation (which cannot be mimicked as a batch over this function)
Instead of taking jnp.sqrt, you can simply build the full nstokes-nstokes matrix per ell and use https://docs.jax.dev/en/latest/_autosummary/jax.scipy.linalg.sqrtm.html to take the matrix square root of it to handle cross-correlation
-> If you want, I can make a suggestion handling this!
| ) | ||
|
|
||
| # Transform back to map | ||
| map_out = alm2map(alms_smooth, nside=npix2nside(map_in.shape[0]), lmax=lmax, mmax=mmax, healpy_ordering=False) |
There was a problem hiding this comment.
If the pol = True case is accepted, this should be map_in.shape[-1]
| new : bool, optional | ||
| If True, uses JAX PRNG with seed parameter. | ||
| If False, uses numpy random state (not recommended for JAX). | ||
| Default: False (to match healpy) | ||
| verbose : bool, optional | ||
| Verbosity flag. Accepted for API compatibility but ignored with a warning. | ||
| Default: True | ||
| method : str, optional | ||
| Transform method ('jax', 'jax_healpy', 'jax_cuda'). Default: 'jax' | ||
| JAX-specific parameter not present in healpy. | ||
| seed : int, optional | ||
| Random seed for reproducibility (only used if new=True). Default: 0 | ||
| JAX-specific parameter not present in healpy. |
There was a problem hiding this comment.
Indeed the user could be providing seed as int or prng_key and we recreate the prng_keyfrom the int seed if it is provided maybe?
| maps=maps, | ||
| lmax=target_L - 1, | ||
| mmax=mmax, | ||
| iter=0, |
There was a problem hiding this comment.
This should get an iter=iter parameter! It was set to 0 here because it was not in the original function, but it should definitely be added
| else: | ||
| maps = jnp.asarray(maps) | ||
| nside = npix2nside(maps.shape[-1]) |
There was a problem hiding this comment.
The original behavior of this function is a bit weird, but should actually handle a set of maps of spins 0 as 2 maps (provided as T and 0, see https://healpix.sourceforge.io/html/sub_map2alm_spin.htm)
| else: | ||
| alms = jnp.asarray(alms) | ||
| if lmax is None: | ||
| if healpy_ordering: | ||
| nalm = alms.shape[0] | ||
| lmax = int((-1 + jnp.sqrt(1 + 8 * (nalm - 1))) / 2) | ||
| else: | ||
| lmax = alms.shape[0] - 1 |
There was a problem hiding this comment.
Similar as previous comment above for spin 0, see https://healpix.sourceforge.io/html/sub_map2alm_spin.htm
This pull request introduces several improvements and additions to the codebase, focusing on expanding the API surface, enhancing testing infrastructure, updating CI workflows, and increasing test coverage for spherical harmonic transforms. The most significant changes include exposing more functions in the public API, adding new and more comprehensive pytest fixtures, updating dependencies and Python versions in CI, and adding new tests for spherical harmonic transforms.
API Expansion:
__all__list injax_healpy/__init__.py[1] [2].mask_badto the public API injax_healpy/pixelfunc.py.Testing Infrastructure Improvements:
tests/conftest.pyto provide parameterizednside,lmax, synthesized maps, and power spectra for use in tests.pytest-rerunfailuresto test dependencies and defined[tool.setuptools]packages inpyproject.tomlto improve test reliability and packaging [1] [2].Continuous Integration (CI) Updates:
actions/setup-pythonfrom v3 to v5 in.github/workflows/ci.yml[1] [2].s2fft, print installed package versions, and add more robust pytest commands.Test Coverage Enhancements:
alm2clwith various parameterizations intests/sphtfunc/test_cl.pyto improve coverage of spherical harmonic transform functionality.tests/sphtfunc/test_map_alm.pyto use new fixtures and updated imports for clarity and maintainability.These changes collectively improve the usability, reliability, and test coverage of the package, especially around spherical harmonic transforms and their analysis.