Add AdaptiveAveragePooling2D and AdaptiveMaxPooling2D layers #21820
Conversation
Summary of Changes

Hello @MalyalaKarthik66, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances Keras 3 by adding adaptive average and max pooling layers for 2D spatial data. These new layers let users specify a target output size, with the pooling kernel and stride adjusted automatically, providing greater flexibility in network architectures, particularly for tasks that require fixed-size feature maps regardless of input dimensions. The implementation prioritizes the JAX backend while ensuring seamless integration with the other Keras backends.
Code Review
This pull request introduces AdaptiveAveragePooling2D and AdaptiveMaxPooling2D layers, along with their corresponding backend operations. The changes include the layer definitions, JAX backend implementations, ops API, and comprehensive tests. The layer APIs and tests are well-designed. However, the JAX backend implementation has significant performance issues due to the use of Python loops, which are not JIT-compatible. There are also opportunities to improve code quality by removing dead code and reducing duplication. My review provides specific feedback on these points.
keras/src/backend/jax/nn.py (outdated)

    for i in range(out_h):
        for j in range(out_w):
            # Calculate pooling region for this output position
            start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32)
            end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32)
            start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32)
            end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32)

            # Extract region and apply average pooling
            if data_format == "channels_last":
                region = inputs[:, start_h:end_h, start_w:end_w, :]
                # Average over spatial dimensions (axis 1, 2)
                pooled = jnp.mean(region, axis=(1, 2))
            else:  # channels_first
                region = inputs[:, :, start_h:end_h, start_w:end_w]
                # Average over spatial dimensions (axis 2, 3)
                pooled = jnp.mean(region, axis=(2, 3))

            result_list.append(pooled)
The current implementation of adaptive pooling uses Python for loops to iterate over output positions. This is an anti-pattern in JAX as it prevents JIT compilation and leads to very poor performance, especially for larger inputs or output sizes. The computation should be expressed using JAX's vectorized operations or JIT-compatible loops like lax.fori_loop to achieve good performance. A fully vectorized einsum-based approach for average pooling, or a lax.fori_loop over output pixels for both pooling types, would be significantly more performant. This comment also applies to the adaptive_max_pool implementation.
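To illustrate, here is a minimal sketch of the vectorized einsum approach for the average-pooling, channels_last case, assuming shapes are static at trace time. The names _pooling_weights and adaptive_avg_pool2d_vectorized are illustrative, not part of this PR; max pooling does not decompose into a weighted sum, so it would need lax.fori_loop or a gather-based formulation instead.

    import math

    import jax.numpy as jnp
    import numpy as np


    def _pooling_weights(in_size, out_size):
        # Row i holds uniform weights over the PyTorch-style region
        # [floor(i * in / out), ceil((i + 1) * in / out)).
        weights = np.zeros((out_size, in_size), dtype=np.float32)
        for i in range(out_size):
            start = math.floor(i * in_size / out_size)
            end = math.ceil((i + 1) * in_size / out_size)
            weights[i, start:end] = 1.0 / (end - start)
        return jnp.asarray(weights)


    def adaptive_avg_pool2d_vectorized(inputs, output_size):
        # inputs: (batch, in_h, in_w, channels), channels_last.
        out_h, out_w = output_size
        w_h = _pooling_weights(inputs.shape[1], out_h)  # (out_h, in_h)
        w_w = _pooling_weights(inputs.shape[2], out_w)  # (out_w, in_w)
        # A 2D region mean factorizes into per-axis uniform weights, so a
        # single einsum contracts both spatial axes at once; this traces
        # to dense matrix products and is fully JIT-compatible.
        return jnp.einsum("bhwc,oh,pw->bopc", inputs, w_h, w_w)

The weight matrices are built with NumPy in a Python loop, but only at trace time from static shapes, so they become constants in the compiled program.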
keras/src/backend/jax/nn.py (outdated)

    def _adaptive_pool_start_index(output_idx, output_size, input_size):
        """Calculate start index for adaptive pooling (PyTorch compatible)."""
        return jnp.floor((output_idx * input_size) / output_size).astype(jnp.int32)


    def _adaptive_pool_end_index(output_idx, output_size, input_size):
        """Calculate end index for adaptive pooling (PyTorch compatible)."""
        return jnp.ceil(((output_idx + 1) * input_size) / output_size).astype(
            jnp.int32
        )
keras/src/backend/jax/nn.py (outdated)

    def adaptive_avg_pool(
        inputs, output_size, data_format="channels_last", name=None
    ):
        """Adaptive average pooling for JAX backend (PyTorch-compatible)."""
        # Convert output_size to tuple
        spatial_dims = inputs.ndim - 2
        if isinstance(output_size, int):
            output_size = (output_size,) * spatial_dims
        else:
            output_size = tuple(output_size)

        # Get spatial shape
        if data_format == "channels_last":
            batch_size = inputs.shape[0]
            channels = inputs.shape[-1]
            spatial_shape = inputs.shape[1:-1]
        else:  # channels_first
            batch_size = inputs.shape[0]
            channels = inputs.shape[1]
            spatial_shape = inputs.shape[2:]

        if len(output_size) != 2:
            raise NotImplementedError(
                "Only 2D adaptive pooling is currently supported"
            )

        out_h, out_w = output_size
        in_h, in_w = spatial_shape

        # Build output by iterating over output positions
        result_list = []

        for i in range(out_h):
            for j in range(out_w):
                # Calculate pooling region for this output position
                start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32)
                end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32)
                start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32)
                end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32)

                # Extract region and apply average pooling
                if data_format == "channels_last":
                    region = inputs[:, start_h:end_h, start_w:end_w, :]
                    # Average over spatial dimensions (axis 1, 2)
                    pooled = jnp.mean(region, axis=(1, 2))
                else:  # channels_first
                    region = inputs[:, :, start_h:end_h, start_w:end_w]
                    # Average over spatial dimensions (axis 2, 3)
                    pooled = jnp.mean(region, axis=(2, 3))

                result_list.append(pooled)

        # Stack results: (out_h*out_w, batch, channels)
        output = jnp.stack(result_list, axis=0)

        # Reshape and transpose to correct output shape
        if data_format == "channels_last":
            # (out_h*out_w, batch, channels) -> (batch, out_h, out_w, channels)
            output = output.reshape(out_h, out_w, batch_size, channels)
            output = jnp.transpose(output, (2, 0, 1, 3))
        else:  # channels_first
            # (out_h*out_w, batch, channels) -> (batch, channels, out_h, out_w)
            output = output.reshape(out_h, out_w, batch_size, channels)
            output = jnp.transpose(output, (2, 3, 0, 1))

        return output


    def adaptive_max_pool(
        inputs, output_size, data_format="channels_last", name=None
    ):
        """Adaptive max pooling for JAX backend (PyTorch-compatible)."""
        # Convert output_size to tuple
        spatial_dims = inputs.ndim - 2
        if isinstance(output_size, int):
            output_size = (output_size,) * spatial_dims
        else:
            output_size = tuple(output_size)

        # Get spatial shape
        if data_format == "channels_last":
            batch_size = inputs.shape[0]
            channels = inputs.shape[-1]
            spatial_shape = inputs.shape[1:-1]
        else:  # channels_first
            batch_size = inputs.shape[0]
            channels = inputs.shape[1]
            spatial_shape = inputs.shape[2:]

        if len(output_size) != 2:
            raise NotImplementedError(
                "Only 2D adaptive pooling is currently supported"
            )

        out_h, out_w = output_size
        in_h, in_w = spatial_shape

        # Build output by iterating over output positions
        result_list = []

        for i in range(out_h):
            for j in range(out_w):
                # Calculate pooling region for this output position
                start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32)
                end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32)
                start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32)
                end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32)

                # Extract region and apply max pooling
                if data_format == "channels_last":
                    region = inputs[:, start_h:end_h, start_w:end_w, :]
                    # Max over spatial dimensions (axis 1, 2)
                    pooled = jnp.max(region, axis=(1, 2))
                else:  # channels_first
                    region = inputs[:, :, start_h:end_h, start_w:end_w]
                    # Max over spatial dimensions (axis 2, 3)
                    pooled = jnp.max(region, axis=(2, 3))

                result_list.append(pooled)

        # Stack results: (out_h*out_w, batch, channels)
        output = jnp.stack(result_list, axis=0)

        # Reshape and transpose to correct output shape
        if data_format == "channels_last":
            # (out_h*out_w, batch, channels) -> (batch, out_h, out_w, channels)
            output = output.reshape(out_h, out_w, batch_size, channels)
            output = jnp.transpose(output, (2, 0, 1, 3))
        else:  # channels_first
            # (out_h*out_w, batch, channels) -> (batch, channels, out_h, out_w)
            output = output.reshape(out_h, out_w, batch_size, channels)
            output = jnp.transpose(output, (2, 3, 0, 1))

        return output
The functions adaptive_avg_pool and adaptive_max_pool are nearly identical, with the only difference being the pooling operation (jnp.mean vs jnp.max). This code duplication can be avoided by creating a generic _adaptive_pool helper function that takes the pooling function as an argument. This would improve maintainability and reduce redundancy.
For example:

    def _adaptive_pool(inputs, output_size, data_format, pool_op):
        # ... common setup code ...
        for i in range(out_h):
            for j in range(out_w):
                # ... common region calculation ...
                if data_format == "channels_last":
                    region = inputs[:, start_h:end_h, start_w:end_w, :]
                    pooled = pool_op(region, axis=(1, 2))
                else:  # channels_first
                    region = inputs[:, :, start_h:end_h, start_w:end_w]
                    pooled = pool_op(region, axis=(2, 3))
                result_list.append(pooled)
        # ... common reshape and transpose code ...
        return output


    def adaptive_avg_pool(inputs, output_size, data_format="channels_last", name=None):
        # ...
        return _adaptive_pool(inputs, output_size, data_format, jnp.mean)


    def adaptive_max_pool(inputs, output_size, data_format="channels_last", name=None):
        # ...
        return _adaptive_pool(inputs, output_size, data_format, jnp.max)

Note that this refactoring suggestion still contains the performance issue mentioned in another comment. The primary goal here is to illustrate how to reduce code duplication.
Codecov Report

❌ Patch coverage is …
Additional details and impacted files

@@            Coverage Diff             @@
##           master   #21820      +/-   ##
==========================================
- Coverage   82.66%   73.75%    -8.92%
==========================================
  Files         577      580       +3
  Lines       59419    59672     +253
  Branches     9313     9362      +49
==========================================
- Hits        49121    44011    -5110
- Misses       7898    13420    +5522
+ Partials     2400     2241     -159
Fixes #21813
Add adaptive pooling support across major backends
This PR adds adaptive pooling support in Keras 3 across all major backends — JAX, TensorFlow, PyTorch, and NumPy.
It introduces both average and max variants with automatic kernel size and stride computation, fully aligned with PyTorch behavior.
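To make the PyTorch-aligned region computation concrete, here is a small illustration of the start/end index formula used throughout the implementation (the same floor/ceil expressions appear verbatim in the JAX code above):

    import math

    def pool_region(i, in_size, out_size):
        # PyTorch-style bounds for output index i along one spatial axis.
        start = math.floor(i * in_size / out_size)
        end = math.ceil((i + 1) * in_size / out_size)
        return start, end

    # For in_size=10, out_size=3 the regions are [0, 4), [3, 7), [6, 10);
    # note that adjacent regions may overlap when in_size % out_size != 0.
    assert [pool_region(i, 10, 3) for i in range(3)] == [(0, 4), (3, 7), (6, 10)]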
Changes
- New layers: AdaptiveAveragePooling2D and AdaptiveMaxPooling2D
- Backend implementations: JAX, TensorFlow, PyTorch, and NumPy
- Ops API: keras.ops.adaptive_avg_pool and keras.ops.adaptive_max_pool
- Tests: cover both channels_first and channels_last data formats
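For reference, a minimal usage sketch of the new layers, assuming the output_size constructor argument implied by this PR (exact signatures may differ):

    import keras
    import numpy as np

    x = np.random.rand(2, 17, 17, 3).astype("float32")  # channels_last input

    # Output spatial size is fixed at (5, 5) regardless of input height/width.
    avg_pool = keras.layers.AdaptiveAveragePooling2D(output_size=(5, 5))
    max_pool = keras.layers.AdaptiveMaxPooling2D(output_size=(5, 5))

    print(avg_pool(x).shape)  # (2, 5, 5, 3)
    print(max_pool(x).shape)  # (2, 5, 5, 3)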