Skip to content

maintain shapes for ang2vec and vec2ang functions for batched inputs#2

Open
ASKabalan wants to merge 5 commits intomainfrom
maintaint-shapes-batching
Open

maintain shapes for ang2vec and vec2ang functions for batched inputs#2
ASKabalan wants to merge 5 commits intomainfrom
maintaint-shapes-batching

Conversation

@ASKabalan
Copy link
Copy Markdown
Member

@ASKabalan ASKabalan commented Feb 11, 2026

This pull request improves support for multi-dimensional (batch) inputs in the core HEALPix coordinate and neighbor functions, ensuring that all relevant functions preserve batch shapes and work seamlessly with arrays of arbitrary shape. The changes also update the documentation and add comprehensive tests to verify correct behavior for batched inputs.

Batch shape support and documentation updates:

  • Updated ang2vec, pix2vec, and vec2ang to return outputs with shapes that preserve input batch dimensions, using jnp.stack(..., axis=-1) instead of .T and updating docstrings to clarify output shapes. [1] [2] [3] [4]
  • Refactored get_all_neighbours and internal neighbor-finding logic to operate on arrays with arbitrary batch dimensions, ensuring outputs have shapes like (8, *batch) or (9, *batch) as appropriate. [1] [2] [3] [4] [5] [6]
  • Updated docstrings throughout to clarify input and output shapes for batch support. [1] [2] [3]

Testing improvements:

  • Added new tests to tests/pixelfunc/test_ang_vec.py that verify all major functions—ang2vec, vec2ang, pix2vec, and get_all_neighbours—correctly handle and preserve arbitrary batch dimensions, including round-trip and consistency checks.

Bug fixes and code cleanup:

  • Fixed shape handling in vec2ang and get_all_neighbours to avoid unnecessary flattening and reshaping, improving clarity and correctness for both scalar and array inputs. [1] [2] [3]

@ASKabalan
Copy link
Copy Markdown
Member Author

Goal

The main goal is to avoid excessive and unneeded reshaping because in distributed setup .. jax healpy can work out of the box with no extra work on our part ..

However if you reshape this can cause at worse a communication but most of the time it causes a local copy or a bitcast to reorder the array (so an unnessecary copy)

example

import jax
jax.config.update('jax_num_cpu_devices', 4)
jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as jnp
import jax_healpy as jhp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.sharding import AxisType

mesh = jax.make_mesh((2, 2), ("X", "Y"),
                     axis_types=(AxisType.Auto, AxisType.Auto))

sharding = NamedSharding(mesh, P("X", "Y"))
points = jax.random.normal(jax.random.PRNGKey(0), (4, 4 ,4 ,3))
points = lax.with_sharding_constraint(points, sharding)
print(f"Sharding of points is {points.sharding.spec}")


def fn(points):
    theta, phi = jhp.vec2ang(points)
    print(f"Shape and sharding of theta is {theta.shape}, {theta.sharding.spec}")
    print(f"Shape and sharding of phi is {phi.shape}, {phi.sharding.spec}")
    return theta, phi

a = fn(points)

With old code this gives

Sharding of points is PartitionSpec('X', 'Y')
Shape and sharding of theta is (64,), PartitionSpec('X',)
Shape and sharding of phi is (64,), PartitionSpec('X',)

With new code this gives

Sharding of points is PartitionSpec('X', 'Y')
Shape and sharding of theta is (4, 4, 4), PartitionSpec('X', 'Y')
Shape and sharding of phi is (4, 4, 4), PartitionSpec('X', 'Y')

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Improves support for batched (multi-dimensional) inputs across core coordinate conversion and neighbor-query APIs so outputs preserve the input batch shape.

Changes:

  • Update pix2vec, ang2vec, and vec2ang to return arrays shaped like (*batch, 3) / (*batch,) by using jnp.stack(..., axis=-1) and ... indexing.
  • Refactor get_all_neighbours and internal neighbor logic to operate element-wise on arbitrary batch shapes and return (8, *batch) / (9, *batch) outputs.
  • Add/adjust tests to verify correct shape preservation and consistency for multi-dimensional batches.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
jax_healpy/pixelfunc.py Shape-preserving outputs for vector/angle conversions and batched neighbor finding.
tests/pixelfunc/test_ang_vec.py New/updated tests covering multi-dimensional batch shapes and round-trips.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread jax_healpy/pixelfunc.py
Comment thread jax_healpy/pixelfunc.py
Comment thread jax_healpy/pixelfunc.py Outdated
Comment on lines +2304 to +2306
Neighbor coordinates that may be outside face boundaries. Shape: (8, *batch)
neighbor_face : Array
Face numbers for neighbors (initially same as original). Shape: (8, N)
Face numbers for neighbors (initially same as original). Shape: (8, *batch)
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neighbor_face is documented as an input to _handle_face_boundaries, but the current implementation never references it (it uses original_face/computed new_face instead). Either remove neighbor_face from the signature/docstring (and call site) or incorporate it into the correction logic so the function interface matches what’s actually used.

Copilot uses AI. Check for mistakes.
ASKabalan and others added 3 commits February 11, 2026 11:07
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread jax_healpy/pixelfunc.py
Comment thread jax_healpy/pixelfunc.py Outdated
Comment on lines +2197 to +2201
else:
# Original behavior: return only 8 neighbors for backward compatibility
# Reshape result to (8, *input_shape)
# Original behavior: return only 8 neighbors
if input_shape == ():
# Scalar input - should return shape (8,), not (8, 1)
return neighbors_flat.squeeze() # Remove the extra dimension
else:
# Array input - reshape from (8, N) to (8, *input_shape)
return neighbors_flat.reshape((8,) + input_shape)
return neighbors.squeeze()
return neighbors
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_all_neighbours now preserves arbitrary batch shapes (returns (8, *input_shape) / (9, *input_shape)), but the docstring currently describes array outputs as (8, N) / (9, N). Please update the docstring to reflect generalized batch behavior.

Copilot uses AI. Check for mistakes.
@ASKabalan
Copy link
Copy Markdown
Member Author

@pchanial can be reviewed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants