
Raise a clear error when pooling window_shape has too many dimensions#5478

Open
adityasingh2400 wants to merge 1 commit into
google:mainfrom
adityasingh2400:fix-pooling-batch-dim-reduction


@adityasingh2400

Fixes #4494

The pooling helpers (avg_pool, max_pool, min_pool) expect inputs laid out as (batch_dims..., spatial_dims..., features) and a window_shape with one entry per spatial dimension. Internally pool computes num_batch_dims as inputs.ndim - (len(window_shape) + 1).

When window_shape has more entries than the input has spatial dimensions, num_batch_dims goes negative. The function then keeps going and the user hits a confusing internal assertion. For example, max_pool(jnp.zeros((4, 3)), (2, 2, 2)) currently fails with AssertionError: len((4, 3)) != len((2, 2, 2, 1)), which gives no hint about the expected dimension layout. The linked issue describes exactly this confusion: it is not clear from the behavior or the error how window_shape relates to the input dimensions.
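The arithmetic behind the negative value can be checked without JAX. The snippet below is only an illustration of the formula quoted above, applied to plain shape tuples; the helper name num_batch_dims_for is hypothetical and is not part of the library.

```python
def num_batch_dims_for(input_shape, window_shape):
    """Apply the formula from the description: ndim - (len(window_shape) + 1).

    Hypothetical helper for illustration; it works on plain tuples, not arrays.
    """
    return len(input_shape) - (len(window_shape) + 1)


# (batch, height, width, features) with a 2-D window: one batch dimension.
print(num_batch_dims_for((8, 32, 32, 3), (2, 2)))  # 1

# (height, width, features) with a 2-D window: no batch dimension.
print(num_batch_dims_for((32, 32, 3), (2, 2)))  # 0

# The failing case from the description: a 2-D input with a 3-entry window.
print(num_batch_dims_for((4, 3), (2, 2, 2)))  # -2
```

A result below zero means the window claims more spatial dimensions than the input has once the trailing features axis is set aside.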

This change adds an explicit guard at the top of pool that raises a ValueError describing the expected (batch_dims..., spatial_dims..., features) layout and the maximum number of window_shape entries allowed for the given input. Valid pooling calls are unaffected since the new branch only triggers when window_shape was already too long to produce a meaningful result.
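For readers who want a feel for the shape of such a guard, here is a minimal sketch. It is not the code from this PR: the function name check_window_shape, the exact condition, and the message wording are all assumptions, and it operates on shape tuples rather than arrays so that it runs without JAX.

```python
def check_window_shape(input_shape, window_shape):
    """Hypothetical guard: reject a window_shape longer than the input allows.

    Assumes the (batch_dims..., spatial_dims..., features) layout, so at most
    ndim - 1 axes can be spatial. Returns the number of batch dimensions.
    """
    ndim = len(input_shape)
    max_window_dims = ndim - 1  # every axis except the trailing features axis
    if len(window_shape) > max_window_dims:
        # Message wording is illustrative, not the text used in the PR.
        raise ValueError(
            f"window_shape {tuple(window_shape)} has {len(window_shape)} entries, "
            f"but inputs of shape {tuple(input_shape)} laid out as "
            f"(batch_dims..., spatial_dims..., features) allow at most "
            f"{max_window_dims}."
        )
    return ndim - (len(window_shape) + 1)


# A valid call passes through and reports its batch dimension count.
print(check_window_shape((8, 32, 32, 3), (2, 2)))  # 1

# The case from the description now fails early with a descriptive error.
try:
    check_window_shape((4, 3), (2, 2, 2))
except ValueError as err:
    print(err)
```

Checking before any reshaping or padding keeps the failure close to the caller's mistake, which is the point the description makes about the old assertion.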

Verification: added tests/linen/linen_test.py::PoolTest::test_pooling_window_shape_too_long_raises, which fails on main with the old cryptic assertion and passes with this change. Ran python -m pytest tests/linen/linen_test.py::PoolTest -q on CPU; all 11 cases pass. ruff check on the changed files reports no issues.

The pool helper computes num_batch_dims as inputs.ndim - (len(window_shape)
+ 1). When window_shape has more entries than the input has spatial
dimensions this value becomes negative, and the user hits a confusing
internal assertion such as len((4, 3)) != len((2, 2, 2, 1)) instead of an
actionable message.

This adds an explicit check that raises a ValueError describing the
expected (batch_dims..., spatial_dims..., features) layout and the maximum
number of window_shape entries for the given input. A regression test
covers the new behavior.

Fixes google#4494

Development

Successfully merging this pull request may close these issues.

Pool functions reduce over batch dimension and not last dimension
