Skip to content

Implement alm2map#4

Open
ASKabalan wants to merge 28 commits intomainfrom
implement-alm2map
Open

Implement alm2map#4
ASKabalan wants to merge 28 commits intomainfrom
implement-alm2map

Conversation

@ASKabalan
Copy link
Copy Markdown
Member

@ASKabalan ASKabalan commented Apr 24, 2026

This pull request introduces several improvements and additions to the codebase, focusing on expanding the API surface, enhancing testing infrastructure, updating CI workflows, and increasing test coverage for spherical harmonic transforms. The most significant changes include exposing more functions in the public API, adding new and more comprehensive pytest fixtures, updating dependencies and Python versions in CI, and adding new tests for spherical harmonic transforms.

API Expansion:

  • Exposed additional spherical harmonic transform and analysis functions in the public API by updating the imports and __all__ list in jax_healpy/__init__.py [1] [2].
  • Added mask_bad to the public API in jax_healpy/pixelfunc.py.

Testing Infrastructure Improvements:

  • Added new and more comprehensive pytest fixtures in tests/conftest.py to provide parameterized nside, lmax, synthesized maps, and power spectra for use in tests.
  • Added pytest-rerunfailures to test dependencies and defined [tool.setuptools] packages in pyproject.toml to improve test reliability and packaging [1] [2].

Continuous Integration (CI) Updates:

  • Updated the Python version in CI to 3.11 and 3.12, dropped 3.10, and upgraded actions/setup-python from v3 to v5 in .github/workflows/ci.yml [1] [2].
  • Modified CI steps to install a specific branch of s2fft, print installed package versions, and add more robust pytest commands.

Test Coverage Enhancements:

  • Added new tests for alm2cl with various parameterizations in tests/sphtfunc/test_cl.py to improve coverage of spherical harmonic transform functionality.
  • Refactored tests in tests/sphtfunc/test_map_alm.py to use new fixtures and updated imports for clarity and maintainability.

These changes collectively improve the usability, reliability, and test coverage of the package, especially around spherical harmonic transforms and their analysis.

ASKabalan and others added 28 commits November 21, 2025 14:27
* alm2map/alm2map_pol
* map2alm/map2alm_pol
* precompute functions for spin=0, -2, 2
* almxfl
* synalm

Tests to be

Co-Authored-By: Magwos <magdy.morshed.fr@gmail.com>
Co-Authored-By: ArtemBasyrov <artem.m.basyrov@gmail.com>
* ud_grade
* get_nside
* _ud_grade_core
Tests to be written
* alm2map/alm2map_pol
* map2alm/map2alm_pol
* precompute functions for spin=0, -2, 2
* almxfl
* synalm

Tests to be written
…hough not completed in waiting of s2fft answer
@ASKabalan ASKabalan requested review from Magwos and pchanial April 24, 2026 15:02
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR significantly expands jax_healpy’s spherical-harmonic (sphtfunc) API surface (e.g., power spectra, smoothing, synthesis utilities), exposes the new functions at the package top level, and adds a broad set of tests/fixtures while updating CI and test dependencies to support the new coverage.

Changes:

  • Add/expand sphtfunc functionality (e.g., alm2cl, anafast, synalm, synfast, smoothing/beam helpers, spin transforms) and export them via jax_healpy.__init__.
  • Add new pytest fixtures and many new/updated tests covering transforms, power spectra, smoothing, and spin-weighted behavior.
  • Update CI Python versions and test dependencies (including pytest-rerunfailures) and adjust workflow steps.

Reviewed changes

Copilot reviewed 12 out of 16 changed files in this pull request and generated 13 comments.

Show a summary per file
File Description
jax_healpy/sphtfunc.py Adds new sphtfunc APIs (spectra, smoothing, synthesis, spin transforms) and backend selection support.
jax_healpy/__init__.py Re-exports newly added sphtfunc functions as public API.
jax_healpy/pixelfunc.py Exposes mask_bad in the public API list.
tests/conftest.py Adds session-scoped fixtures for nside, lmax, and synthesized maps/spectra.
tests/sphtfunc/conftest.py Adds batched alm generator fixture to support batched transform tests.
tests/sphtfunc/test_map_alm.py Refactors/extends map↔alm tests using new fixtures and adds additional validation cases.
tests/sphtfunc/test_transform_tools.py Adds tests for beam/smoothing helpers vs healpy.
tests/sphtfunc/test_spin_transforms.py Adds spin transform tests and input validation tests.
tests/sphtfunc/test_cl.py Adds alm2cl tests across parameter combinations.
tests/sphtfunc/test_anafast.py Adds extensive tests for anafast, synfast, and synalm behaviors.
pyproject.toml Adds pytest-rerunfailures and sets setuptools package list.
.github/workflows/ci.yml Updates Python versions and changes dependency install / test execution.
.gitignore Ignores .pip-packages*.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread jax_healpy/sphtfunc.py
Comment on lines +986 to +999
# Compute all n(n+1)/2 spectra
# HealPy order: all auto-spectra first, then cross-spectra grouped by separation
spectra = []

# First, all auto-spectra (i, i)
for i in range(n_alms):
cl = _compute_cl(alms1_list[i], L, None)
spectra.append(cl)

# Then, cross-spectra grouped by separation (j-i): (0,1), (1,2), ..., (0,2), (1,3), ...
for separation in range(1, n_alms):
for i in range(n_alms - separation):
j = i + separation
cl = _compute_cl(alms1_list[i], L, alms1_list[j])
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The multiple-alm branch builds spectra in an order that doesn’t match the documented/healpy order (11, 12, 22, 13, 23, 33, ...). Current code appends all auto-spectra first and then cross-spectra by separation, which yields (11, 22, 33, 12, 23, 13, ...) for 3 alms. Please change the iteration to produce diagonal-major order (i.e., nested loops over i<=j) so callers and tests align with healpy.

Suggested change
# Compute all n(n+1)/2 spectra
# HealPy order: all auto-spectra first, then cross-spectra grouped by separation
spectra = []
# First, all auto-spectra (i, i)
for i in range(n_alms):
cl = _compute_cl(alms1_list[i], L, None)
spectra.append(cl)
# Then, cross-spectra grouped by separation (j-i): (0,1), (1,2), ..., (0,2), (1,3), ...
for separation in range(1, n_alms):
for i in range(n_alms - separation):
j = i + separation
cl = _compute_cl(alms1_list[i], L, alms1_list[j])
# Compute all n(n+1)/2 spectra in HealPy diagonal-major order:
# (0,0), (0,1), (1,1), (0,2), (1,2), (2,2), ...
spectra = []
for j in range(n_alms):
for i in range(j + 1):
if i == j:
cl = _compute_cl(alms1_list[i], L, None)
else:
cl = _compute_cl(alms1_list[i], L, alms1_list[j])

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can also be simply a re-ordering of the output rather than adding this condition i==j

Comment thread jax_healpy/sphtfunc.py
Comment on lines +963 to +966
# Check if lmax_out is provided and different from lmax
if lmax_out is not None and lmax is not None and lmax_out != lmax:
raise ValueError(f'lmax_out ({lmax_out}) must equal lmax ({lmax}) if both provided')

Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lmax_out is documented but effectively ignored unless lmax is also provided (only a mismatch check exists). If lmax_out is set while lmax is None, the function still returns the full lmax inferred from the input. Either implement lmax_out by truncating the output spectra to lmax_out+1, or raise NotImplementedError when lmax_out is provided and differs from the resolved lmax.

Copilot uses AI. Check for mistakes.
Comment thread jax_healpy/sphtfunc.py
Comment on lines +385 to +387
# Generate precomputes for both transforms
precomps_pos = generate_precomputes_jax(L, spin, sampling, nside, False)
precomps_neg = generate_precomputes_jax(L, -spin, sampling, nside, False)
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In _map2alm_core spin!=0 branch you call spherical.forward(...) but generate precomputes with generate_precomputes_jax(..., forward=False) (positional False). In the scalar branch the forward transform uses forward=True. If forward controls pix→harm vs harm→pix (as implied by the other helpers), these spin precomputes are for the wrong direction and can degrade/incorrectly compute E/B. Please switch these to forward=True (or use the keyword) for the forward transform case.

Suggested change
# Generate precomputes for both transforms
precomps_pos = generate_precomputes_jax(L, spin, sampling, nside, False)
precomps_neg = generate_precomputes_jax(L, -spin, sampling, nside, False)
# Generate precomputes for both forward transforms
precomps_pos = generate_precomputes_jax(
L, spin, sampling, nside, forward=True
)
precomps_neg = generate_precomputes_jax(
L, -spin, sampling, nside, forward=True
)

Copilot uses AI. Check for mistakes.
Comment on lines +98 to +125
rng_state = np.random.get_state()
np.random.seed(seed)
map_healpy = hp.synfast(cla, nside, lmax=lmax, pol=False)
np.random.set_state(rng_state)

# Check output shape
npix = hp.nside2npix(nside)
assert map_jax.shape == (npix,)

# Check that map has reasonable statistics
# Mean should be close to zero
assert jnp.abs(jnp.mean(map_jax)) < 1e-2

# Variance should be reasonable (related to C_l)
assert jnp.std(map_jax) > 0

print(f'Mean JAX: {jnp.mean(map_jax)}, Mean Healpy: {np.mean(map_healpy)}')
print(f'Std JAX: {jnp.std(map_jax)}, Std Healpy: {np.std(map_healpy)}')
mean_diff = jnp.abs(jnp.mean(map_jax) - np.mean(map_healpy))
mean_rtol = mean_diff / (np.abs(np.mean(map_healpy)) + 1e-20)
print(f'atol diff mean: {mean_diff} rtol diff mean: {mean_rtol}')
std_diff = jnp.abs(jnp.std(map_jax) - np.std(map_healpy))
std_rtol = std_diff / (np.abs(np.std(map_healpy)) + 1e-20)
print(f'atol diff std: {std_diff} rtol diff std: {std_rtol}')

# Compare summary statistics with healpy map (stochastic realizations won't match sample-wise)
assert np.isclose(np.mean(map_jax), np.mean(map_healpy), atol=1e-5)
assert np.isclose(np.std(map_jax), np.std(map_healpy), atol=1e-2)
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_synfast_basic compares mean/std against a single healpy.synfast realization using very tight tolerances (e.g. mean atol=1e-5). Since jax-healpy uses JAX PRNG and healpy uses NumPy PRNG, the realizations won’t match even with the same numeric seed, so these assertions are likely to fail or be flaky. Consider removing the direct healpy comparison here (you already have a synfast→anafast roundtrip test) or replacing it with a statistical check that accounts for realization variance.

Suggested change
rng_state = np.random.get_state()
np.random.seed(seed)
map_healpy = hp.synfast(cla, nside, lmax=lmax, pol=False)
np.random.set_state(rng_state)
# Check output shape
npix = hp.nside2npix(nside)
assert map_jax.shape == (npix,)
# Check that map has reasonable statistics
# Mean should be close to zero
assert jnp.abs(jnp.mean(map_jax)) < 1e-2
# Variance should be reasonable (related to C_l)
assert jnp.std(map_jax) > 0
print(f'Mean JAX: {jnp.mean(map_jax)}, Mean Healpy: {np.mean(map_healpy)}')
print(f'Std JAX: {jnp.std(map_jax)}, Std Healpy: {np.std(map_healpy)}')
mean_diff = jnp.abs(jnp.mean(map_jax) - np.mean(map_healpy))
mean_rtol = mean_diff / (np.abs(np.mean(map_healpy)) + 1e-20)
print(f'atol diff mean: {mean_diff} rtol diff mean: {mean_rtol}')
std_diff = jnp.abs(jnp.std(map_jax) - np.std(map_healpy))
std_rtol = std_diff / (np.abs(np.std(map_healpy)) + 1e-20)
print(f'atol diff std: {std_diff} rtol diff std: {std_rtol}')
# Compare summary statistics with healpy map (stochastic realizations won't match sample-wise)
assert np.isclose(np.mean(map_jax), np.mean(map_healpy), atol=1e-5)
assert np.isclose(np.std(map_jax), np.std(map_healpy), atol=1e-2)
# Check output shape
npix = hp.nside2npix(nside)
assert map_jax.shape == (npix,)
# Check that map has reasonable statistics
# Mean should be close to zero for a zero-mean Gaussian realization.
assert jnp.abs(jnp.mean(map_jax)) < 1e-2
# Variance should be positive and consistent with the input power spectrum.
map_std = float(jnp.std(map_jax))
assert map_std > 0
ell = np.arange(min(len(cla), lmax + 1))
expected_var = np.sum((2 * ell[1:] + 1) * cla[1:len(ell)] / (4 * np.pi))
expected_std = np.sqrt(expected_var)
# This is a single stochastic realization, so use a statistical check with
# a tolerance that allows realization variance rather than comparing against
# one healpy realization generated from a different RNG implementation.
assert np.isclose(map_std, expected_std, rtol=0.5)

Copilot uses AI. Check for mistakes.
Comment on lines +124 to +126
2


Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a stray literal 2 at module scope, which will raise a SyntaxError and prevent the entire test module from importing. Please remove this line.

Suggested change
2

Copilot uses AI. Check for mistakes.
Comment thread jax_healpy/sphtfunc.py
Comment on lines +781 to +783
@partial(jax.jit, static_argnames=['mmax', 'inplace', 'healpy_ordering'])
def almxfl(alm: ArrayLike, fl: ArrayLike, mmax: int | None = None, inplace: bool = False, healpy_ordering: bool = True):
"""Multiply alm by a filter function fl.
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

almxfl accepts mmax and documents it, but the implementation doesn’t use mmax at all (and inplace is also ignored). This is surprising for callers and can produce incorrect behavior if someone expects healpy-style truncation. Please either implement mmax handling or validate that mmax is None/equals lmax and raise NotImplementedError otherwise; similarly consider warning/validating inplace instead of silently ignoring it.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For mmax, it is actually not only relevant when the alms provided are healpy ordering (when doing alm_2d = flm_hp_to_2d_fast(alm, L) and then flm_2d_to_hp_fast(alm_filtered, L), but it is not the case for s2fft
I would think that the most natural way to handle it would be to add support for healpy.sphtfunc.resize_alm, but I am not 100% sure it is very JAX-compatible

Comment thread .github/workflows/ci.yml
Comment on lines +49 to +50
pytest -v -m "not slow" --tb=short 2>&1 || true
pytest -v tests/sphtfunc/test_spin_transforms.py::test_map2alm_spin_different_spins -x --tb=long
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test step masks failures (pytest ... || true) and then only runs a single test afterwards. This can make CI pass even when the full test suite is failing. Please remove the || true and run the full test suite (optionally with reruns/-x/separate jobs) so failures correctly fail the workflow.

Suggested change
pytest -v -m "not slow" --tb=short 2>&1 || true
pytest -v tests/sphtfunc/test_spin_transforms.py::test_map2alm_spin_different_spins -x --tb=long
pytest -v -m "not slow" --tb=short

Copilot uses AI. Check for mistakes.
Comment thread .github/workflows/ci.yml
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install git+https://github.com/astro-informatics/s2fft.git
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CI installs s2fft directly from GitHub without pinning to a commit/tag/branch, which can make builds non-reproducible and introduce unrelated breakages. Consider pinning to a known-good commit or a released version (or using a constraints file) for stable CI.

Suggested change
pip install git+https://github.com/astro-informatics/s2fft.git
# Pin s2fft to a known-good immutable commit for reproducible CI builds.
pip install git+https://github.com/astro-informatics/s2fft.git@<KNOWN_GOOD_S2FFT_COMMIT>

Copilot uses AI. Check for mistakes.
Comment thread jax_healpy/sphtfunc.py
Comment on lines +1610 to +1622
new : bool, optional
If True, uses JAX PRNG with seed parameter.
If False, uses numpy random state (not recommended for JAX).
Default: False (to match healpy)
verbose : bool, optional
Verbosity flag. Accepted for API compatibility but ignored with a warning.
Default: True
method : str, optional
Transform method ('jax', 'jax_healpy', 'jax_cuda'). Default: 'jax'
JAX-specific parameter not present in healpy.
seed : int, optional
Random seed for reproducibility (only used if new=True). Default: 0
JAX-specific parameter not present in healpy.
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The synfast docstring documents a seed parameter, but the function signature has no seed argument (it takes a prng_key instead). Please update the docstring to match the actual API (e.g., describe how to build jax.random.PRNGKey(seed)), or add a seed convenience parameter if that’s intended.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed the user could be providing seed as int or prng_key and we recreate the prng_keyfrom the int seed if it is provided maybe?

Comment thread tests/sphtfunc/test_cl.py
Comment on lines +20 to +22
data_path = Path(jhp.__file__).parent.parent / 'tests/data'


Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The module-level data_path = ... is unused and also shadows the data_path fixture name, which makes the file harder to read. Consider removing the module variable and relying solely on the fixture.

Suggested change
data_path = Path(jhp.__file__).parent.parent / 'tests/data'

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member

@Magwos Magwos left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall it is almost ready, there are some minor changes which could be quite helpful to support more the polarization cases (in particular alm2map, map2alm and synalm)! We can discuss more about all of that if you want

Comment thread jax_healpy/sphtfunc.py
Comment on lines +162 to +163
# Use jnp.where to handle fwhm=0 case in a JAX-compatible way
sigma = jnp.where(fwhm == 0.0, 0.0, fwhm / jnp.sqrt(8.0 * jnp.log(2.0)))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it really necessary? Wouldn't

Suggested change
# Use jnp.where to handle fwhm=0 case in a JAX-compatible way
sigma = jnp.where(fwhm == 0.0, 0.0, fwhm / jnp.sqrt(8.0 * jnp.log(2.0)))
sigma = fwhm / jnp.sqrt(8.0 * jnp.log(2.0)))

simply work? (or maybe it's for batching purposes?)

Comment thread jax_healpy/sphtfunc.py
Comment on lines +888 to +891
if pol:
raise NotImplementedError('pol=True (polarization beam components) is not supported yet')

return _compute_beam_window(lmax, fwhm=fwhm, sigma=None)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The addition of pol should be relatively simple I think (because pol is considered static anyway), you can simply do

Suggested change
if pol:
raise NotImplementedError('pol=True (polarization beam components) is not supported yet')
return _compute_beam_window(lmax, fwhm=fwhm, sigma=None)
if sigma is None:
# Use jnp.where to handle fwhm=0 case in a JAX-compatible way
sigma = jnp.where(fwhm == 0.0, 0.0, fwhm / jnp.sqrt(8.0 * jnp.log(2.0)))
beam_window = _compute_beam_window(lmax, fwhm=fwhm, sigma=sigma)
if pol:
pol_factor = jnp.exp([0.0, 2 * sigma**2, 2 * sigma**2, sigma2])
return beam_window[...,None] * pol_factor
return beam_window

Comment thread jax_healpy/sphtfunc.py
Comment on lines +172 to +178
m_vals = jnp.arange(-L + 1, L)
ell_vals = jnp.arange(L)
ell_grid, m_grid = jnp.meshgrid(ell_vals, m_vals, indexing='ij')
valid_mask = jnp.abs(m_grid) <= ell_grid

alm_prod = jnp.abs(alms) ** 2 if alms2 is None else alms * jnp.conj(alms2)
alm_prod_masked = alm_prod * valid_mask
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am slightly confused if we want this valid_mask or not

In principle, we expect the provided alms to be well defined with |m|<= ell, so this valid_mask is only here to regularize in case the alms provided happens not to be well defined

However, in this case, do we really prefer to apply this correction rather than:

  • compute the c_ells with the provided wrong alms?
  • raise an error? (probably difficult with jitting)

If we opt for the first option, to regularize the alms, I propose to add another function (inspired by check_theta_valid here

def check_theta_valid(theta):
)

def check_alms_valid(alms: ArrayLike, L: int) -> None:
    m_vals = jnp.arange(-L + 1, L)
    ell_vals = jnp.arange(L)
    ell_grid, m_grid = jnp.meshgrid(ell_vals, m_vals, indexing='ij')
    bad_mask = jnp.abs(m_grid) > ell_grid

    invalid_alms = (jnp.abs(alms * bad_mask)>1e-14).any()
    def _raise_invalid_alms(invalid_alms):
        if invalid_theta:
            raise ValueError('The provided alms are not well defined with only 0 on their indices |m|>ell')

    jax.debug.callback(_raise_invalid_alms, invalid_alms)

Comment thread jax_healpy/sphtfunc.py
Comment on lines +203 to +206
m_vals = jnp.arange(-L + 1, L)
ell_vals = jnp.arange(L)
ell_grid, m_grid = jnp.meshgrid(ell_vals, m_vals, indexing='ij')
valid_mask = (jnp.abs(m_grid) <= ell_grid) & (jnp.abs(m_grid) <= mmax)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably go inside a function to be called multiple times rather than be copy-pasted multiple times (and it could actually be helpful to handle alms in general)

Something like

def get_valid_mask_alms(lmax, mmax_positive=None, mmax_negative=None):
    L = lmax + 1
    if mmax_positive is None:
       mmax  = L
    if mmax_negative is None:
       mmax  = -L+1

    m_vals = jnp.arange(mmax_negative, mmax_positive)
    ell_vals = jnp.arange(L)
    ell_grid, m_grid = jnp.meshgrid(ell_vals, m_vals, indexing='ij')
    return (jnp.abs(m_grid) <= ell_grid) & (jnp.abs(m_grid) <= mmax)

which is then called as

Suggested change
m_vals = jnp.arange(-L + 1, L)
ell_vals = jnp.arange(L)
ell_grid, m_grid = jnp.meshgrid(ell_vals, m_vals, indexing='ij')
valid_mask = (jnp.abs(m_grid) <= ell_grid) & (jnp.abs(m_grid) <= mmax)
valid_mask = get_valid_mask_alms(lmax)

Comment thread jax_healpy/sphtfunc.py
Comment on lines +208 to +213
cl_grid = jnp.broadcast_to(cl[:, None], (L, 2 * L - 1))
scale = jnp.sqrt(cl_grid / 2.0)
scale = jnp.where(m_grid == 0, jnp.sqrt(cl_grid), scale)

alms = (rand_real + 1j * rand_imag) * scale * valid_mask
alms = alms.at[:, L - 1].set(alms[:, L - 1].real)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a little issue here: this is correct to generate auto-spectra, but not the cross-spectra, which is quite problematic typically for CMB TE correlation (which cannot be mimicked as a batch over this function)

Instead of taking jnp.sqrt, you can simply build the full nstokes-nstokes matrix per ell and use https://docs.jax.dev/en/latest/_autosummary/jax.scipy.linalg.sqrtm.html to take the matrix square root of it to handle cross-correlation
-> If you want, I can make a suggestion handling this!

Comment thread jax_healpy/sphtfunc.py
)

# Transform back to map
map_out = alm2map(alms_smooth, nside=npix2nside(map_in.shape[0]), lmax=lmax, mmax=mmax, healpy_ordering=False)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the pol = True case is accepted, this should be map_in.shape[-1]

Comment thread jax_healpy/sphtfunc.py
Comment on lines +1610 to +1622
new : bool, optional
If True, uses JAX PRNG with seed parameter.
If False, uses numpy random state (not recommended for JAX).
Default: False (to match healpy)
verbose : bool, optional
Verbosity flag. Accepted for API compatibility but ignored with a warning.
Default: True
method : str, optional
Transform method ('jax', 'jax_healpy', 'jax_cuda'). Default: 'jax'
JAX-specific parameter not present in healpy.
seed : int, optional
Random seed for reproducibility (only used if new=True). Default: 0
JAX-specific parameter not present in healpy.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed the user could be providing seed as int or prng_key and we recreate the prng_keyfrom the int seed if it is provided maybe?

Comment thread jax_healpy/sphtfunc.py
maps=maps,
lmax=target_L - 1,
mmax=mmax,
iter=0,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should get an iter=iter parameter! It was set to 0 here because it was not in the original function, but it should definitely be added

Comment thread jax_healpy/sphtfunc.py
Comment on lines +1757 to +1759
else:
maps = jnp.asarray(maps)
nside = npix2nside(maps.shape[-1])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original behavior of this function is a bit weird, but should actually handle a set of maps of spins 0 as 2 maps (provided as T and 0, see https://healpix.sourceforge.io/html/sub_map2alm_spin.htm)

Comment thread jax_healpy/sphtfunc.py
Comment on lines +1858 to +1865
else:
alms = jnp.asarray(alms)
if lmax is None:
if healpy_ordering:
nalm = alms.shape[0]
lmax = int((-1 + jnp.sqrt(1 + 8 * (nalm - 1))) / 2)
else:
lmax = alms.shape[0] - 1
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar as previous comment above for spin 0, see https://healpix.sourceforge.io/html/sub_map2alm_spin.htm

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants