maintain shapes for ang2vec and vec2ang functions for batched inputs#2
maintain shapes for ang2vec and vec2ang functions for batched inputs#2
ang2vec and vec2ang functions for batched inputs#2Conversation
GoalThe main goal is to avoid excessive and unneeded reshaping because in distributed setup .. jax healpy can work out of the box with no extra work on our part .. However if you reshape this can cause at worse a communication but most of the time it causes a local copy or a bitcast to reorder the array (so an unnessecary copy) example import jax
jax.config.update('jax_num_cpu_devices', 4)
jax.config.update('jax_platform_name', 'cpu')
import jax.numpy as jnp
import jax_healpy as jhp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.sharding import AxisType
mesh = jax.make_mesh((2, 2), ("X", "Y"),
axis_types=(AxisType.Auto, AxisType.Auto))
sharding = NamedSharding(mesh, P("X", "Y"))
points = jax.random.normal(jax.random.PRNGKey(0), (4, 4 ,4 ,3))
points = lax.with_sharding_constraint(points, sharding)
print(f"Sharding of points is {points.sharding.spec}")
def fn(points):
theta, phi = jhp.vec2ang(points)
print(f"Shape and sharding of theta is {theta.shape}, {theta.sharding.spec}")
print(f"Shape and sharding of phi is {phi.shape}, {phi.sharding.spec}")
return theta, phi
a = fn(points)
With old code this gives With new code this gives |
There was a problem hiding this comment.
Pull request overview
Improves support for batched (multi-dimensional) inputs across core coordinate conversion and neighbor-query APIs so outputs preserve the input batch shape.
Changes:
- Update
pix2vec,ang2vec, andvec2angto return arrays shaped like(*batch, 3)/(*batch,)by usingjnp.stack(..., axis=-1)and...indexing. - Refactor
get_all_neighboursand internal neighbor logic to operate element-wise on arbitrary batch shapes and return(8, *batch)/(9, *batch)outputs. - Add/adjust tests to verify correct shape preservation and consistency for multi-dimensional batches.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
jax_healpy/pixelfunc.py |
Shape-preserving outputs for vector/angle conversions and batched neighbor finding. |
tests/pixelfunc/test_ang_vec.py |
New/updated tests covering multi-dimensional batch shapes and round-trips. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| Neighbor coordinates that may be outside face boundaries. Shape: (8, *batch) | ||
| neighbor_face : Array | ||
| Face numbers for neighbors (initially same as original). Shape: (8, N) | ||
| Face numbers for neighbors (initially same as original). Shape: (8, *batch) |
There was a problem hiding this comment.
neighbor_face is documented as an input to _handle_face_boundaries, but the current implementation never references it (it uses original_face/computed new_face instead). Either remove neighbor_face from the signature/docstring (and call site) or incorporate it into the correction logic so the function interface matches what’s actually used.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| else: | ||
| # Original behavior: return only 8 neighbors for backward compatibility | ||
| # Reshape result to (8, *input_shape) | ||
| # Original behavior: return only 8 neighbors | ||
| if input_shape == (): | ||
| # Scalar input - should return shape (8,), not (8, 1) | ||
| return neighbors_flat.squeeze() # Remove the extra dimension | ||
| else: | ||
| # Array input - reshape from (8, N) to (8, *input_shape) | ||
| return neighbors_flat.reshape((8,) + input_shape) | ||
| return neighbors.squeeze() | ||
| return neighbors |
There was a problem hiding this comment.
get_all_neighbours now preserves arbitrary batch shapes (returns (8, *input_shape) / (9, *input_shape)), but the docstring currently describes array outputs as (8, N) / (9, N). Please update the docstring to reflect generalized batch behavior.
|
@pchanial can be reviewed |
This pull request improves support for multi-dimensional (batch) inputs in the core HEALPix coordinate and neighbor functions, ensuring that all relevant functions preserve batch shapes and work seamlessly with arrays of arbitrary shape. The changes also update the documentation and add comprehensive tests to verify correct behavior for batched inputs.
Batch shape support and documentation updates:
ang2vec,pix2vec, andvec2angto return outputs with shapes that preserve input batch dimensions, usingjnp.stack(..., axis=-1)instead of.Tand updating docstrings to clarify output shapes. [1] [2] [3] [4]get_all_neighboursand internal neighbor-finding logic to operate on arrays with arbitrary batch dimensions, ensuring outputs have shapes like(8, *batch)or(9, *batch)as appropriate. [1] [2] [3] [4] [5] [6]Testing improvements:
tests/pixelfunc/test_ang_vec.pythat verify all major functions—ang2vec,vec2ang,pix2vec, andget_all_neighbours—correctly handle and preserve arbitrary batch dimensions, including round-trip and consistency checks.Bug fixes and code cleanup:
vec2angandget_all_neighboursto avoid unnecessary flattening and reshaping, improving clarity and correctness for both scalar and array inputs. [1] [2] [3]