Raise a clear error when pooling window_shape has too many dimensions#5478
Open
adityasingh2400 wants to merge 1 commit into
Open
Raise a clear error when pooling window_shape has too many dimensions#5478adityasingh2400 wants to merge 1 commit into
adityasingh2400 wants to merge 1 commit into
Conversation
The pool helper computes num_batch_dims as inputs.ndim - (len(window_shape) + 1). When window_shape has more entries than the input has spatial dimensions this value becomes negative, and the user hits a confusing internal assertion such as len((4, 3)) != len((2, 2, 2, 1)) instead of an actionable message. This adds an explicit check that raises a ValueError describing the expected (batch_dims..., spatial_dims..., features) layout and the maximum number of window_shape entries for the given input. A regression test covers the new behavior. Fixes google#4494
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #4494
The pooling helpers (avg_pool, max_pool, min_pool) expect inputs laid out as (batch_dims..., spatial_dims..., features) and a window_shape with one entry per spatial dimension. Internally pool computes num_batch_dims as inputs.ndim - (len(window_shape) + 1).
When window_shape has more entries than the input has spatial dimensions, num_batch_dims goes negative. The function then keeps going and the user hits a confusing internal assertion. For example, max_pool(jnp.zeros((4, 3)), (2, 2, 2)) currently fails with AssertionError: len((4, 3)) != len((2, 2, 2, 1)), which gives no hint about the expected dimension layout. The linked issue describes exactly this confusion: it is not clear from the behavior or the error how window_shape relates to the input dimensions.
This change adds an explicit guard at the top of pool that raises a ValueError describing the expected (batch_dims..., spatial_dims..., features) layout and the maximum number of window_shape entries allowed for the given input. Valid pooling calls are unaffected since the new branch only triggers when window_shape was already too long to produce a meaningful result.
Verification: added tests/linen/linen_test.py::PoolTest::test_pooling_window_shape_too_long_raises which fails on main with the old cryptic assertion and passes with this change. Ran python -m pytest tests/linen/linen_test.py::PoolTest -q on CPU, all 11 cases pass. ruff check on the changed files reports no issues.