diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e34872..7f9fb99 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,19 +1,19 @@ repos: - repo: https://github.com/hadialqattan/pycln - rev: "v2.5.0" + rev: "v2.6.0" hooks: - id: pycln args: - --all - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.6 + rev: v0.15.0 hooks: - id: ruff - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: 'v5.0.0' + rev: 'v6.0.0' hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -21,7 +21,7 @@ repos: - id: check-merge-conflict - repo: https://github.com/PyCQA/bandit - rev: '1.8.0' + rev: '1.9.3' hooks: - id: bandit files: ^jax_healpy/ diff --git a/jax_healpy/pixelfunc.py b/jax_healpy/pixelfunc.py index 05580a5..ff21a0b 100644 --- a/jax_healpy/pixelfunc.py +++ b/jax_healpy/pixelfunc.py @@ -1424,9 +1424,8 @@ def pix2vec(nside: int, ipix: ArrayLike, nest: bool = False) -> Array: Returns ------- - x, y, z : floats, scalar or array-like - The coordinates of vector corresponding to input pixels. Scalar if all input - are scalar, array otherwise. Usual numpy broadcasting rules apply. + vec : float, array + Array of shape (*input_shape, 3) containing the 3D position vectors. See Also -------- @@ -1462,7 +1461,7 @@ def _pix2vec_ring(nside, pixels): (1 - z) * (1 + z), ) ) - return jnp.array([sin_theta * jnp.cos(phi), sin_theta * jnp.sin(phi), z]).T + return jnp.stack([sin_theta * jnp.cos(phi), sin_theta * jnp.sin(phi), z], axis=-1) @partial(jit, static_argnames=['lonlat']) @@ -1482,8 +1481,7 @@ def ang2vec(theta: ArrayLike, phi: ArrayLike, lonlat: bool = False) -> Array: Returns ------- vec : float, array - if theta and phi are vectors, the result is a 2D array with a vector per row - otherwise, it is a 1D array of shape (3,) + Array of shape (*input_shape, 3) containing the 3D position vectors. See Also -------- @@ -1497,7 +1495,7 @@ def ang2vec(theta: ArrayLike, phi: ArrayLike, lonlat: bool = False) -> Array: x = sin_theta * jnp.cos(phi) y = sin_theta * jnp.sin(phi) z = jnp.cos(theta) - return jnp.array([x, y, z]).T + return jnp.stack([x, y, z], axis=-1) @partial(jit, static_argnames=['lonlat']) @@ -1507,7 +1505,7 @@ def vec2ang(vectors: ArrayLike, lonlat: bool = False) -> tuple[Array, Array]: Parameters ---------- vectors : float, array-like - the vector(s) to convert, shape is (3,) or (N, 3) + the vector(s) to convert, shape is (*batch, 3) lonlat : bool, optional If True, return angles will be longitude and latitude in degree, otherwise, angles will be co-latitude and longitude in radians (default) @@ -1515,16 +1513,19 @@ def vec2ang(vectors: ArrayLike, lonlat: bool = False) -> tuple[Array, Array]: Returns ------- theta, phi : float, tuple of two arrays - the colatitude and longitude in radians + the colatitude and longitude in radians, each of shape (*batch,) See Also -------- ang2vec, rotator.vec2dir, rotator.dir2vec """ - vectors = vectors.reshape(-1, 3) + vectors = jnp.asarray(vectors) + # Enforce documented shape (*batch, 3) to avoid silently ignoring components + if vectors.ndim == 0 or vectors.shape[-1] != 3: + raise ValueError(f'`vec2ang` expects `vectors` with shape (*batch, 3); got array with shape {vectors.shape}') dnorm = jnp.sqrt(vectors[..., 0] ** 2 + vectors[..., 1] ** 2 + vectors[..., 2] ** 2) - theta = jnp.arccos(vectors[:, 2] / dnorm) - phi = jnp.arctan2(vectors[:, 1], vectors[:, 0]) + theta = jnp.arccos(vectors[..., 2] / dnorm) + phi = jnp.arctan2(vectors[..., 1], vectors[..., 0]) phi = jnp.where(phi < 0, phi + 2 * np.pi, phi) if lonlat: return _thetaphi2lonlat(theta, phi) @@ -2162,7 +2163,6 @@ def get_all_neighbours( # theta contains pixel indices ipix = theta.astype(_pixel_dtype_for(nside)) input_shape = ipix.shape - ipix_flat = ipix.flatten() else: # theta, phi contain angular coordinates - convert to pixels phi = jnp.asarray(phi) @@ -2176,45 +2176,26 @@ def get_all_neighbours( theta_bc, phi_bc = jnp.broadcast_arrays(theta, phi) input_shape = theta_bc.shape - # Convert angular coordinates to pixel indices - ipix_flat = ang2pix(nside, theta_bc.flatten(), phi_bc.flatten(), nest=nest) + # Convert angular coordinates to pixel indices (preserves batch shape) + ipix = ang2pix(nside, theta_bc, phi_bc, nest=nest) - # Convert pixels to (x, y, face) coordinates - ix, iy, face_num = pix2xyf(nside, ipix_flat, nest=nest) + # Convert pixels to (x, y, face) coordinates (all element-wise, preserves batch shape) + ix, iy, face_num = pix2xyf(nside, ipix, nest=nest) - # Vectorized neighbor finding for all pixels - neighbors_flat = _get_all_neighbors_xyf(nside, ix, iy, face_num, nest=nest) + # Vectorized neighbor finding for all pixels — output shape (8, *input_shape) + neighbors = _get_all_neighbors_xyf(nside, ix, iy, face_num, nest=nest) # Conditionally include center pixel based on get_center parameter if get_center: # Add center pixels as first element: [CENTER, SW, W, NW, N, NE, E, SE, S] - if phi is None: - # Pixel mode: center pixels are the input pixels themselves - center_pixels_flat = ipix_flat - else: - # Angular mode: center pixels are pixels at the given coordinates - # We already have ipix_flat from the coordinate conversion above - center_pixels_flat = ipix_flat - - # Combine center + neighbors: shape (9, N) - result_flat = jnp.concatenate([center_pixels_flat[None, :], neighbors_flat], axis=0) - - # Reshape result to (9, *input_shape) - if input_shape == (): - # Scalar input - should return shape (9,), not (9, 1) - return result_flat.squeeze() # Remove the extra dimension - else: - # Array input - reshape from (9, N) to (9, *input_shape) - return result_flat.reshape((9,) + input_shape) + # Combine center + neighbors: shape (9, *input_shape) + result = jnp.concatenate([ipix[None, ...], neighbors], axis=0) else: - # Original behavior: return only 8 neighbors for backward compatibility - # Reshape result to (8, *input_shape) - if input_shape == (): - # Scalar input - should return shape (8,), not (8, 1) - return neighbors_flat.squeeze() # Remove the extra dimension - else: - # Array input - reshape from (8, N) to (8, *input_shape) - return neighbors_flat.reshape((8,) + input_shape) + result = neighbors + + if input_shape == (): + return result.squeeze() + return result def _get_all_neighbors_xyf(nside: int, ix: Array, iy: Array, face_num: Array, nest: bool = False) -> Array: @@ -2236,10 +2217,10 @@ def _get_all_neighbors_xyf(nside: int, ix: Array, iy: Array, face_num: Array, ne nside : int HEALPix resolution parameter (must be power of 2) ix, iy : Array - Face-local x, y coordinates of pixels (shape: (N,)) + Face-local x, y coordinates of pixels (shape: (*batch,)) Valid range: [0, nside-1] for pixels within face face_num : Array - Face numbers of pixels (shape: (N,)) + Face numbers of pixels (shape: (*batch,)) Valid range: [0, 11] for HEALPix faces nest : bool, optional Whether to use NESTED ordering scheme. Default is False (RING ordering). @@ -2247,7 +2228,7 @@ def _get_all_neighbors_xyf(nside: int, ix: Array, iy: Array, face_num: Array, ne Returns ------- neighbors : Array - Neighbor pixel indices for each input pixel. Shape: (8, N) + Neighbor pixel indices for each input pixel. Shape: (8, *batch) Neighbors in order: [SW, W, NW, N, NE, E, SE, S] Non-existent neighbors (at map boundaries) are marked with -1. @@ -2256,22 +2237,24 @@ def _get_all_neighbors_xyf(nside: int, ix: Array, iy: Array, face_num: Array, ne This function implements the exact neighbor-finding logic from the original HEALPix C++ library, ensuring bit-for-bit compatibility with healpy results. """ - n_pixels = ix.shape[0] + batch_shape = ix.shape + batch_ndim = len(batch_shape) # Initialize output array for neighbors - neighbors = jnp.full((8, n_pixels), -1, dtype=_pixel_dtype_for(nside)) + neighbors = jnp.full((8,) + batch_shape, -1, dtype=_pixel_dtype_for(nside)) # Apply 8-direction offsets to get neighbor coordinates - # Use broadcasting: ix[None, :] + _NB_XOFFSET[:, None] -> (8, N) - neighbor_ix = ix[None, :] + _NB_XOFFSET[:, None] # Shape: (8, N) - neighbor_iy = iy[None, :] + _NB_YOFFSET[:, None] # Shape: (8, N) - neighbor_face = jnp.broadcast_to(face_num[None, :], (8, n_pixels)) # Shape: (8, N) + # Reshape offsets to (8, 1, 1, ...) for broadcasting with (*batch,) + offset_shape = (-1,) + (1,) * batch_ndim + neighbor_ix = ix[None, ...] + _NB_XOFFSET.reshape(offset_shape) # Shape: (8, *batch) + neighbor_iy = iy[None, ...] + _NB_YOFFSET.reshape(offset_shape) # Shape: (8, *batch) + neighbor_face = jnp.broadcast_to(face_num[None, ...], (8,) + batch_shape) # Shape: (8, *batch) # Check which neighbors are within the current face (no boundary crossing) # Valid range is [0, nside-1] for both ix and iy within_face = ( (neighbor_ix >= 0) & (neighbor_ix < nside) & (neighbor_iy >= 0) & (neighbor_iy < nside) - ) # Shape: (8, N) + ) # Shape: (8, *batch) # For neighbors within face, convert directly to pixels valid_mask = within_face @@ -2286,7 +2269,7 @@ def _get_all_neighbors_xyf(nside: int, ix: Array, iy: Array, face_num: Array, ne # that handles the most common boundary cases # Apply face boundary corrections using lookup tables - corrected_neighbors = _handle_face_boundaries(nside, neighbor_ix, neighbor_iy, neighbor_face, face_num, nest) + corrected_neighbors = _handle_face_boundaries(nside, neighbor_ix, neighbor_iy, face_num, nest) # Use corrected neighbors where we have boundary crossings neighbors = jnp.where(boundary_mask, corrected_neighbors, neighbors) @@ -2295,7 +2278,7 @@ def _get_all_neighbors_xyf(nside: int, ix: Array, iy: Array, face_num: Array, ne def _handle_face_boundaries( - nside: int, neighbor_ix: Array, neighbor_iy: Array, neighbor_face: Array, original_face: Array, nest: bool + nside: int, neighbor_ix: Array, neighbor_iy: Array, original_face: Array, nest: bool ) -> Array: """Handle neighbor pixels that cross face boundaries. @@ -2317,18 +2300,16 @@ def _handle_face_boundaries( nside : int HEALPix resolution parameter neighbor_ix, neighbor_iy : Array - Neighbor coordinates that may be outside face boundaries. Shape: (8, N) - neighbor_face : Array - Face numbers for neighbors (initially same as original). Shape: (8, N) + Neighbor coordinates that may be outside face boundaries. Shape: (8, *batch) original_face : Array - Original face numbers of input pixels. Shape: (N,) + Original face numbers of input pixels. Shape: (*batch,) nest : bool Whether to use NESTED ordering Returns ------- corrected_neighbors : Array - Corrected neighbor pixel indices. Shape: (8, N) + Corrected neighbor pixel indices. Shape: (8, *batch) Returns -1 for invalid neighbors (outside map boundaries) Notes @@ -2338,17 +2319,17 @@ def _handle_face_boundaries( (_NB_FACEARRAY, _NB_SWAPARRAY) encode the complex geometric relationships between HEALPix faces and handle all 12 face transitions correctly. """ - n_pixels = original_face.shape[0] + batch_shape = original_face.shape # Initialize result with invalid neighbors - result = jnp.full((8, n_pixels), -1, dtype=_pixel_dtype_for(nside)) + result = jnp.full((8,) + batch_shape, -1, dtype=_pixel_dtype_for(nside)) # Process each neighbor direction individually for direction_idx in range(8): # Get coordinates for this direction across all pixels - ix = neighbor_ix[direction_idx, :] # Shape: (n_pixels,) - iy = neighbor_iy[direction_idx, :] # Shape: (n_pixels,) - orig_face = original_face # Shape: (n_pixels,) + ix = neighbor_ix[direction_idx] # Shape: (*batch,) + iy = neighbor_iy[direction_idx] # Shape: (*batch,) + orig_face = original_face # Shape: (*batch,) # Check boundary conditions - exact replication of original algorithm x_low = ix < 0 @@ -2413,7 +2394,7 @@ def _handle_face_boundaries( neighbor_pixels = xyf2pix(nside, corrected_ix, corrected_iy, corrected_face, nest=nest) # Update result for this direction - only valid crossings get neighbor pixels - result = result.at[direction_idx, :].set(jnp.where(valid_crossing, neighbor_pixels, -1)) + result = result.at[direction_idx].set(jnp.where(valid_crossing, neighbor_pixels, -1)) return result diff --git a/tests/pixelfunc/test_ang_vec.py b/tests/pixelfunc/test_ang_vec.py index d3122f6..43717b5 100644 --- a/tests/pixelfunc/test_ang_vec.py +++ b/tests/pixelfunc/test_ang_vec.py @@ -67,11 +67,73 @@ def test_ang2vec_array(lonlat: bool) -> None: def test_vec2ang_array(lonlat: bool) -> None: vec = np.array([[1, 2, 3], [-1, 2, -1]]) theta0, phi0 = hp.vec2ang(vec[0], lonlat=lonlat) - assert theta0.shape == (1,) - assert phi0.shape == (1,) + assert theta0.shape == () + assert phi0.shape == () theta1, phi1 = hp.vec2ang(vec[1], lonlat=lonlat) theta, phi = hp.vec2ang(vec, lonlat=lonlat) assert theta.shape == (2,) assert phi.shape == (2,) - assert_allclose(theta, np.array([theta0[0], theta1[0]]), rtol=1e-14) - assert_allclose(phi, np.array([phi0[0], phi1[0]]), rtol=1e-14) + assert_allclose(theta, np.array([theta0, theta1]), rtol=1e-14) + assert_allclose(phi, np.array([phi0, phi1]), rtol=1e-14) + + +def test_vec2ang_multidim_batch(): + """vec2ang should preserve arbitrary batch dimensions.""" + rng = np.random.default_rng(42) + vecs = rng.normal(size=(3, 2, 3)) # batch shape (3, 2), last dim is xyz + theta, phi = hp.vec2ang(vecs) + assert theta.shape == (3, 2) + assert phi.shape == (3, 2) + # Check against flattened computation + theta_flat, phi_flat = hp.vec2ang(vecs.reshape(-1, 3)) + assert_allclose(theta.ravel(), theta_flat, rtol=1e-14) + assert_allclose(phi.ravel(), phi_flat, rtol=1e-14) + + +def test_ang2vec_multidim_batch(): + """ang2vec should produce (*batch, 3) output for multi-dim inputs.""" + theta = np.array([[0.5, 1.0], [1.5, 2.0], [2.5, 0.3]]) # shape (3, 2) + phi = np.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]) + vec = hp.ang2vec(theta, phi) + assert vec.shape == (3, 2, 3) + # Verify element-wise consistency + for i in range(3): + for j in range(2): + vec_ij = hp.ang2vec(theta[i, j], phi[i, j]) + assert_allclose(vec[i, j], vec_ij, rtol=1e-14) + + +def test_pix2vec_multidim_batch(): + """pix2vec should produce (*batch, 3) output for multi-dim pixel arrays.""" + nside = 8 + pixels = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) # shape (3, 4) + vec = hp.pix2vec(nside, pixels) + assert vec.shape == (3, 4, 3) + # Verify against flat computation + vec_flat = hp.pix2vec(nside, pixels.ravel()) + assert_allclose(vec.reshape(-1, 3), vec_flat, rtol=1e-14) + + +def test_vec2ang_ang2vec_roundtrip_multidim(): + """Round-trip vec2ang(ang2vec(theta, phi)) should preserve multi-dim shapes.""" + theta = np.array([[0.5, 1.0], [1.5, 2.0]]) # (2, 2) + phi = np.array([[0.1, 1.1], [2.1, 3.1]]) + vec = hp.ang2vec(theta, phi) + assert vec.shape == (2, 2, 3) + theta_rt, phi_rt = hp.vec2ang(vec) + assert theta_rt.shape == (2, 2) + assert phi_rt.shape == (2, 2) + assert_allclose(theta_rt, theta, rtol=1e-14, atol=1e-15) + assert_allclose(phi_rt, phi, rtol=1e-14, atol=1e-15) + + +def test_get_all_neighbours_multidim_batch(): + """get_all_neighbours should produce (8, *batch) for multi-dim theta/phi.""" + nside = 8 + theta = np.array([[0.5, 1.0, 1.5], [2.0, 2.5, 0.3]]) # (2, 3) + phi = np.array([[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]]) + neighbors = hp.get_all_neighbours(nside, theta, phi) + assert neighbors.shape == (8, 2, 3) + # Verify against flat computation + neighbors_flat = hp.get_all_neighbours(nside, theta.ravel(), phi.ravel()) + assert_allclose(neighbors.reshape(8, -1), neighbors_flat, rtol=1e-14)