Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
repos:
- repo: https://github.com/hadialqattan/pycln
rev: "v2.5.0"
rev: "v2.6.0"
hooks:
- id: pycln
args:
- --all

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.6
rev: v0.15.0
hooks:
- id: ruff
- id: ruff-format

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 'v5.0.0'
rev: 'v6.0.0'
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-merge-conflict

- repo: https://github.com/PyCQA/bandit
rev: '1.8.0'
rev: '1.9.3'
hooks:
- id: bandit
files: ^jax_healpy/
117 changes: 49 additions & 68 deletions jax_healpy/pixelfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,9 +1424,8 @@ def pix2vec(nside: int, ipix: ArrayLike, nest: bool = False) -> Array:

Returns
-------
x, y, z : floats, scalar or array-like
The coordinates of vector corresponding to input pixels. Scalar if all input
are scalar, array otherwise. Usual numpy broadcasting rules apply.
vec : float, array
Array of shape (*input_shape, 3) containing the 3D position vectors.
Comment thread
ASKabalan marked this conversation as resolved.

See Also
--------
Expand Down Expand Up @@ -1462,7 +1461,7 @@ def _pix2vec_ring(nside, pixels):
(1 - z) * (1 + z),
)
)
return jnp.array([sin_theta * jnp.cos(phi), sin_theta * jnp.sin(phi), z]).T
return jnp.stack([sin_theta * jnp.cos(phi), sin_theta * jnp.sin(phi), z], axis=-1)
Comment thread
ASKabalan marked this conversation as resolved.


@partial(jit, static_argnames=['lonlat'])
Expand All @@ -1482,8 +1481,7 @@ def ang2vec(theta: ArrayLike, phi: ArrayLike, lonlat: bool = False) -> Array:
Returns
-------
vec : float, array
if theta and phi are vectors, the result is a 2D array with a vector per row
otherwise, it is a 1D array of shape (3,)
Array of shape (*input_shape, 3) containing the 3D position vectors.

See Also
--------
Expand All @@ -1497,7 +1495,7 @@ def ang2vec(theta: ArrayLike, phi: ArrayLike, lonlat: bool = False) -> Array:
x = sin_theta * jnp.cos(phi)
y = sin_theta * jnp.sin(phi)
z = jnp.cos(theta)
return jnp.array([x, y, z]).T
return jnp.stack([x, y, z], axis=-1)


@partial(jit, static_argnames=['lonlat'])
Expand All @@ -1507,24 +1505,27 @@ def vec2ang(vectors: ArrayLike, lonlat: bool = False) -> tuple[Array, Array]:
Parameters
----------
vectors : float, array-like
the vector(s) to convert, shape is (3,) or (N, 3)
the vector(s) to convert, shape is (*batch, 3)
lonlat : bool, optional
If True, return angles will be longitude and latitude in degree,
otherwise, angles will be co-latitude and longitude in radians (default)

Returns
-------
theta, phi : float, tuple of two arrays
the colatitude and longitude in radians
the colatitude and longitude in radians, each of shape (*batch,)

See Also
--------
ang2vec, rotator.vec2dir, rotator.dir2vec
"""
vectors = vectors.reshape(-1, 3)
vectors = jnp.asarray(vectors)
Comment thread
ASKabalan marked this conversation as resolved.
# Enforce documented shape (*batch, 3) to avoid silently ignoring components
if vectors.ndim == 0 or vectors.shape[-1] != 3:
raise ValueError(f'`vec2ang` expects `vectors` with shape (*batch, 3); got array with shape {vectors.shape}')
dnorm = jnp.sqrt(vectors[..., 0] ** 2 + vectors[..., 1] ** 2 + vectors[..., 2] ** 2)
theta = jnp.arccos(vectors[:, 2] / dnorm)
phi = jnp.arctan2(vectors[:, 1], vectors[:, 0])
theta = jnp.arccos(vectors[..., 2] / dnorm)
phi = jnp.arctan2(vectors[..., 1], vectors[..., 0])
phi = jnp.where(phi < 0, phi + 2 * np.pi, phi)
if lonlat:
return _thetaphi2lonlat(theta, phi)
Expand Down Expand Up @@ -2162,7 +2163,6 @@ def get_all_neighbours(
# theta contains pixel indices
ipix = theta.astype(_pixel_dtype_for(nside))
input_shape = ipix.shape
ipix_flat = ipix.flatten()
else:
# theta, phi contain angular coordinates - convert to pixels
phi = jnp.asarray(phi)
Expand All @@ -2176,45 +2176,26 @@ def get_all_neighbours(
theta_bc, phi_bc = jnp.broadcast_arrays(theta, phi)
input_shape = theta_bc.shape

# Convert angular coordinates to pixel indices
ipix_flat = ang2pix(nside, theta_bc.flatten(), phi_bc.flatten(), nest=nest)
# Convert angular coordinates to pixel indices (preserves batch shape)
ipix = ang2pix(nside, theta_bc, phi_bc, nest=nest)

# Convert pixels to (x, y, face) coordinates
ix, iy, face_num = pix2xyf(nside, ipix_flat, nest=nest)
# Convert pixels to (x, y, face) coordinates (all element-wise, preserves batch shape)
ix, iy, face_num = pix2xyf(nside, ipix, nest=nest)

# Vectorized neighbor finding for all pixels
neighbors_flat = _get_all_neighbors_xyf(nside, ix, iy, face_num, nest=nest)
# Vectorized neighbor finding for all pixels — output shape (8, *input_shape)
neighbors = _get_all_neighbors_xyf(nside, ix, iy, face_num, nest=nest)

# Conditionally include center pixel based on get_center parameter
if get_center:
# Add center pixels as first element: [CENTER, SW, W, NW, N, NE, E, SE, S]
if phi is None:
# Pixel mode: center pixels are the input pixels themselves
center_pixels_flat = ipix_flat
else:
# Angular mode: center pixels are pixels at the given coordinates
# We already have ipix_flat from the coordinate conversion above
center_pixels_flat = ipix_flat

# Combine center + neighbors: shape (9, N)
result_flat = jnp.concatenate([center_pixels_flat[None, :], neighbors_flat], axis=0)

# Reshape result to (9, *input_shape)
if input_shape == ():
# Scalar input - should return shape (9,), not (9, 1)
return result_flat.squeeze() # Remove the extra dimension
else:
# Array input - reshape from (9, N) to (9, *input_shape)
return result_flat.reshape((9,) + input_shape)
# Combine center + neighbors: shape (9, *input_shape)
result = jnp.concatenate([ipix[None, ...], neighbors], axis=0)
else:
# Original behavior: return only 8 neighbors for backward compatibility
# Reshape result to (8, *input_shape)
if input_shape == ():
# Scalar input - should return shape (8,), not (8, 1)
return neighbors_flat.squeeze() # Remove the extra dimension
else:
# Array input - reshape from (8, N) to (8, *input_shape)
return neighbors_flat.reshape((8,) + input_shape)
result = neighbors

if input_shape == ():
return result.squeeze()
return result


def _get_all_neighbors_xyf(nside: int, ix: Array, iy: Array, face_num: Array, nest: bool = False) -> Array:
Expand All @@ -2236,18 +2217,18 @@ def _get_all_neighbors_xyf(nside: int, ix: Array, iy: Array, face_num: Array, ne
nside : int
HEALPix resolution parameter (must be power of 2)
ix, iy : Array
Face-local x, y coordinates of pixels (shape: (N,))
Face-local x, y coordinates of pixels (shape: (*batch,))
Valid range: [0, nside-1] for pixels within face
face_num : Array
Face numbers of pixels (shape: (N,))
Face numbers of pixels (shape: (*batch,))
Valid range: [0, 11] for HEALPix faces
nest : bool, optional
Whether to use NESTED ordering scheme. Default is False (RING ordering).

Returns
-------
neighbors : Array
Neighbor pixel indices for each input pixel. Shape: (8, N)
Neighbor pixel indices for each input pixel. Shape: (8, *batch)
Neighbors in order: [SW, W, NW, N, NE, E, SE, S]
Non-existent neighbors (at map boundaries) are marked with -1.

Expand All @@ -2256,22 +2237,24 @@ def _get_all_neighbors_xyf(nside: int, ix: Array, iy: Array, face_num: Array, ne
This function implements the exact neighbor-finding logic from the original
HEALPix C++ library, ensuring bit-for-bit compatibility with healpy results.
"""
n_pixels = ix.shape[0]
batch_shape = ix.shape
batch_ndim = len(batch_shape)

# Initialize output array for neighbors
neighbors = jnp.full((8, n_pixels), -1, dtype=_pixel_dtype_for(nside))
neighbors = jnp.full((8,) + batch_shape, -1, dtype=_pixel_dtype_for(nside))

# Apply 8-direction offsets to get neighbor coordinates
# Use broadcasting: ix[None, :] + _NB_XOFFSET[:, None] -> (8, N)
neighbor_ix = ix[None, :] + _NB_XOFFSET[:, None] # Shape: (8, N)
neighbor_iy = iy[None, :] + _NB_YOFFSET[:, None] # Shape: (8, N)
neighbor_face = jnp.broadcast_to(face_num[None, :], (8, n_pixels)) # Shape: (8, N)
# Reshape offsets to (8, 1, 1, ...) for broadcasting with (*batch,)
offset_shape = (-1,) + (1,) * batch_ndim
neighbor_ix = ix[None, ...] + _NB_XOFFSET.reshape(offset_shape) # Shape: (8, *batch)
neighbor_iy = iy[None, ...] + _NB_YOFFSET.reshape(offset_shape) # Shape: (8, *batch)
neighbor_face = jnp.broadcast_to(face_num[None, ...], (8,) + batch_shape) # Shape: (8, *batch)

# Check which neighbors are within the current face (no boundary crossing)
# Valid range is [0, nside-1] for both ix and iy
within_face = (
(neighbor_ix >= 0) & (neighbor_ix < nside) & (neighbor_iy >= 0) & (neighbor_iy < nside)
) # Shape: (8, N)
) # Shape: (8, *batch)

# For neighbors within face, convert directly to pixels
valid_mask = within_face
Expand All @@ -2286,7 +2269,7 @@ def _get_all_neighbors_xyf(nside: int, ix: Array, iy: Array, face_num: Array, ne
# that handles the most common boundary cases

# Apply face boundary corrections using lookup tables
corrected_neighbors = _handle_face_boundaries(nside, neighbor_ix, neighbor_iy, neighbor_face, face_num, nest)
corrected_neighbors = _handle_face_boundaries(nside, neighbor_ix, neighbor_iy, face_num, nest)

# Use corrected neighbors where we have boundary crossings
neighbors = jnp.where(boundary_mask, corrected_neighbors, neighbors)
Expand All @@ -2295,7 +2278,7 @@ def _get_all_neighbors_xyf(nside: int, ix: Array, iy: Array, face_num: Array, ne


def _handle_face_boundaries(
nside: int, neighbor_ix: Array, neighbor_iy: Array, neighbor_face: Array, original_face: Array, nest: bool
nside: int, neighbor_ix: Array, neighbor_iy: Array, original_face: Array, nest: bool
) -> Array:
"""Handle neighbor pixels that cross face boundaries.

Expand All @@ -2317,18 +2300,16 @@ def _handle_face_boundaries(
nside : int
HEALPix resolution parameter
neighbor_ix, neighbor_iy : Array
Neighbor coordinates that may be outside face boundaries. Shape: (8, N)
neighbor_face : Array
Face numbers for neighbors (initially same as original). Shape: (8, N)
Neighbor coordinates that may be outside face boundaries. Shape: (8, *batch)
original_face : Array
Original face numbers of input pixels. Shape: (N,)
Original face numbers of input pixels. Shape: (*batch,)
nest : bool
Whether to use NESTED ordering

Returns
-------
corrected_neighbors : Array
Corrected neighbor pixel indices. Shape: (8, N)
Corrected neighbor pixel indices. Shape: (8, *batch)
Returns -1 for invalid neighbors (outside map boundaries)

Notes
Expand All @@ -2338,17 +2319,17 @@ def _handle_face_boundaries(
(_NB_FACEARRAY, _NB_SWAPARRAY) encode the complex geometric relationships
between HEALPix faces and handle all 12 face transitions correctly.
"""
n_pixels = original_face.shape[0]
batch_shape = original_face.shape

# Initialize result with invalid neighbors
result = jnp.full((8, n_pixels), -1, dtype=_pixel_dtype_for(nside))
result = jnp.full((8,) + batch_shape, -1, dtype=_pixel_dtype_for(nside))

# Process each neighbor direction individually
for direction_idx in range(8):
# Get coordinates for this direction across all pixels
ix = neighbor_ix[direction_idx, :] # Shape: (n_pixels,)
iy = neighbor_iy[direction_idx, :] # Shape: (n_pixels,)
orig_face = original_face # Shape: (n_pixels,)
ix = neighbor_ix[direction_idx] # Shape: (*batch,)
iy = neighbor_iy[direction_idx] # Shape: (*batch,)
orig_face = original_face # Shape: (*batch,)

# Check boundary conditions - exact replication of original algorithm
x_low = ix < 0
Expand Down Expand Up @@ -2413,7 +2394,7 @@ def _handle_face_boundaries(
neighbor_pixels = xyf2pix(nside, corrected_ix, corrected_iy, corrected_face, nest=nest)

# Update result for this direction - only valid crossings get neighbor pixels
result = result.at[direction_idx, :].set(jnp.where(valid_crossing, neighbor_pixels, -1))
result = result.at[direction_idx].set(jnp.where(valid_crossing, neighbor_pixels, -1))

return result

Expand Down
70 changes: 66 additions & 4 deletions tests/pixelfunc/test_ang_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,73 @@ def test_ang2vec_array(lonlat: bool) -> None:
def test_vec2ang_array(lonlat: bool) -> None:
vec = np.array([[1, 2, 3], [-1, 2, -1]])
theta0, phi0 = hp.vec2ang(vec[0], lonlat=lonlat)
assert theta0.shape == (1,)
assert phi0.shape == (1,)
assert theta0.shape == ()
assert phi0.shape == ()
theta1, phi1 = hp.vec2ang(vec[1], lonlat=lonlat)
theta, phi = hp.vec2ang(vec, lonlat=lonlat)
assert theta.shape == (2,)
assert phi.shape == (2,)
assert_allclose(theta, np.array([theta0[0], theta1[0]]), rtol=1e-14)
assert_allclose(phi, np.array([phi0[0], phi1[0]]), rtol=1e-14)
assert_allclose(theta, np.array([theta0, theta1]), rtol=1e-14)
assert_allclose(phi, np.array([phi0, phi1]), rtol=1e-14)


def test_vec2ang_multidim_batch():
"""vec2ang should preserve arbitrary batch dimensions."""
rng = np.random.default_rng(42)
vecs = rng.normal(size=(3, 2, 3)) # batch shape (3, 2), last dim is xyz
theta, phi = hp.vec2ang(vecs)
assert theta.shape == (3, 2)
assert phi.shape == (3, 2)
# Check against flattened computation
theta_flat, phi_flat = hp.vec2ang(vecs.reshape(-1, 3))
assert_allclose(theta.ravel(), theta_flat, rtol=1e-14)
assert_allclose(phi.ravel(), phi_flat, rtol=1e-14)


def test_ang2vec_multidim_batch():
"""ang2vec should produce (*batch, 3) output for multi-dim inputs."""
theta = np.array([[0.5, 1.0], [1.5, 2.0], [2.5, 0.3]]) # shape (3, 2)
phi = np.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
vec = hp.ang2vec(theta, phi)
assert vec.shape == (3, 2, 3)
# Verify element-wise consistency
for i in range(3):
for j in range(2):
vec_ij = hp.ang2vec(theta[i, j], phi[i, j])
assert_allclose(vec[i, j], vec_ij, rtol=1e-14)


def test_pix2vec_multidim_batch():
"""pix2vec should produce (*batch, 3) output for multi-dim pixel arrays."""
nside = 8
pixels = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) # shape (3, 4)
vec = hp.pix2vec(nside, pixels)
assert vec.shape == (3, 4, 3)
# Verify against flat computation
vec_flat = hp.pix2vec(nside, pixels.ravel())
assert_allclose(vec.reshape(-1, 3), vec_flat, rtol=1e-14)


def test_vec2ang_ang2vec_roundtrip_multidim():
"""Round-trip vec2ang(ang2vec(theta, phi)) should preserve multi-dim shapes."""
theta = np.array([[0.5, 1.0], [1.5, 2.0]]) # (2, 2)
phi = np.array([[0.1, 1.1], [2.1, 3.1]])
vec = hp.ang2vec(theta, phi)
assert vec.shape == (2, 2, 3)
theta_rt, phi_rt = hp.vec2ang(vec)
assert theta_rt.shape == (2, 2)
assert phi_rt.shape == (2, 2)
assert_allclose(theta_rt, theta, rtol=1e-14, atol=1e-15)
assert_allclose(phi_rt, phi, rtol=1e-14, atol=1e-15)


def test_get_all_neighbours_multidim_batch():
"""get_all_neighbours should produce (8, *batch) for multi-dim theta/phi."""
nside = 8
theta = np.array([[0.5, 1.0, 1.5], [2.0, 2.5, 0.3]]) # (2, 3)
phi = np.array([[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]])
neighbors = hp.get_all_neighbours(nside, theta, phi)
assert neighbors.shape == (8, 2, 3)
# Verify against flat computation
neighbors_flat = hp.get_all_neighbours(nside, theta.ravel(), phi.ravel())
assert_allclose(neighbors.reshape(8, -1), neighbors_flat, rtol=1e-14)