
Conversation

@MalyalaKarthik66 commented Nov 4, 2025

Fixes #21813

Add adaptive pooling support across major backends

This PR adds adaptive pooling support in Keras 3 across all major backends — JAX, TensorFlow, PyTorch, and NumPy.
It introduces both average and max variants with automatic kernel size and stride computation, fully aligned with PyTorch behavior.
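For example, with the PyTorch-style region rule used here (start = floor(i * in / out), end = ceil((i + 1) * in / out)), pooling an input of width 10 down to 3 outputs pools over the regions [0, 4), [3, 7), and [6, 10).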

Changes

New Layers

  • AdaptiveAveragePooling2D: Performs adaptive average pooling to a target spatial size.
  • AdaptiveMaxPooling2D: Performs adaptive max pooling to a target spatial size.

Backend Implementations

  • JAX: Fully implemented adaptive pooling logic (matching PyTorch’s numerical behavior).
  • TensorFlow, PyTorch, and NumPy: Delegated to native ops where available.

Ops API

  • keras.ops.adaptive_avg_pool
  • keras.ops.adaptive_max_pool

Tests

  • Comprehensive unit tests for both layers.
  • PyTorch numerical parity verification (ensures same results as PyTorch).
  • Support for both channels_first and channels_last data formats.
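
For orientation, here is a minimal usage sketch of the proposed API. The output_size argument name and exact layer signatures are assumptions based on this description and the review snippets below, not a confirmed interface.

import numpy as np
import keras

# Hypothetical usage of the layers/ops added in this PR (signatures assumed).
x = np.random.rand(2, 32, 32, 3).astype("float32")  # channels_last input

layer = keras.layers.AdaptiveAveragePooling2D(output_size=(7, 7))
y = layer(x)  # expected shape: (2, 7, 7, 3), regardless of the 32x32 input size

# Backend-agnostic op counterpart
z = keras.ops.adaptive_max_pool(x, output_size=(7, 7), data_format="channels_last")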

@gemini-code-assist

Summary of Changes

Hello @MalyalaKarthik66, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances Keras 3 by adding adaptive average and max pooling layers for 2D spatial data. These new layers allow users to specify a target output size, with the pooling kernel and stride automatically adjusted, providing greater flexibility in network architectures, particularly for tasks requiring fixed-size feature maps regardless of input dimensions. The implementation prioritizes the JAX backend while ensuring seamless integration with other Keras backends.

Highlights

  • New Adaptive Pooling Layers: Introduced AdaptiveAveragePooling2D and AdaptiveMaxPooling2D layers to Keras 3, allowing for a fixed output size regardless of input dimensions.
  • JAX Backend Implementation: Provided a PyTorch-compatible implementation for these adaptive pooling operations specifically for the JAX backend.
  • Unified Ops API: Exposed new keras.ops.adaptive_avg_pool and keras.ops.adaptive_max_pool functions for backend-agnostic usage.
  • Comprehensive Testing: Included extensive unit tests, numerical parity checks against PyTorch, and support for both channels_first and channels_last data formats.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces AdaptiveAveragePooling2D and AdaptiveMaxPooling2D layers, along with their corresponding backend operations. The changes include the layer definitions, JAX backend implementations, ops API, and comprehensive tests. The layer APIs and tests are well-designed. However, the JAX backend implementation has significant performance issues due to the use of Python loops, which are not JIT-compatible. There are also opportunities to improve code quality by removing dead code and reducing duplication. My review provides specific feedback on these points.

Comment on lines 1515 to 1533
for i in range(out_h):
    for j in range(out_w):
        # Calculate pooling region for this output position
        start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32)
        end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32)
        start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32)
        end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32)

        # Extract region and apply average pooling
        if data_format == "channels_last":
            region = inputs[:, start_h:end_h, start_w:end_w, :]
            # Average over spatial dimensions (axis 1, 2)
            pooled = jnp.mean(region, axis=(1, 2))
        else:  # channels_first
            region = inputs[:, :, start_h:end_h, start_w:end_w]
            # Average over spatial dimensions (axis 2, 3)
            pooled = jnp.mean(region, axis=(2, 3))

        result_list.append(pooled)

high

The current implementation of adaptive pooling uses Python for loops to iterate over output positions. This is an anti-pattern in JAX as it prevents JIT compilation and leads to very poor performance, especially for larger inputs or output sizes. The computation should be expressed using JAX's vectorized operations or JIT-compatible loops like lax.fori_loop to achieve good performance. A fully vectorized einsum-based approach for average pooling, or a lax.fori_loop over output pixels for both pooling types, would be significantly more performant. This comment also applies to the adaptive_max_pool implementation.
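
For illustration, here is a minimal vectorized sketch of the einsum approach for the average-pooling case, assuming 2D channels_last inputs and the same PyTorch-style floor/ceil region boundaries used above (helper names are hypothetical):

import jax.numpy as jnp

def _pool_matrix(in_size, out_size):
    # Row i holds uniform weights over the region
    # [floor(i * in / out), ceil((i + 1) * in / out)) and zeros elsewhere.
    i = jnp.arange(out_size)[:, None]
    h = jnp.arange(in_size)[None, :]
    start = jnp.floor(i * in_size / out_size)
    end = jnp.ceil((i + 1) * in_size / out_size)
    mask = (h >= start) & (h < end)
    return mask / (end - start)  # shape (out_size, in_size)

def adaptive_avg_pool2d_nhwc(x, out_h, out_w):
    # x: (batch, in_h, in_w, channels)
    wh = _pool_matrix(x.shape[1], out_h)  # (out_h, in_h)
    ww = _pool_matrix(x.shape[2], out_w)  # (out_w, in_w)
    # One einsum contracts both input spatial axes; no Python loops, JIT-friendly.
    return jnp.einsum("bhwc,oh,pw->bopc", x, wh, ww)

Max pooling cannot be written as a weighted sum, so it would still need a gather/segment-style formulation or a lax.fori_loop over output positions.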

Comment on lines 1469 to 1478
def _adaptive_pool_start_index(output_idx, output_size, input_size):
    """Calculate start index for adaptive pooling (PyTorch compatible)."""
    return jnp.floor((output_idx * input_size) / output_size).astype(jnp.int32)


def _adaptive_pool_end_index(output_idx, output_size, input_size):
    """Calculate end index for adaptive pooling (PyTorch compatible)."""
    return jnp.ceil(((output_idx + 1) * input_size) / output_size).astype(
        jnp.int32
    )

medium

The helper functions _adaptive_pool_start_index and _adaptive_pool_end_index are defined but not used. This dead code should be removed to improve code clarity.

Comment on lines 1481 to 1618
def adaptive_avg_pool(
    inputs, output_size, data_format="channels_last", name=None
):
    """
    Adaptive average pooling for JAX backend (PyTorch-compatible).
    """
    # Convert output_size to tuple
    spatial_dims = inputs.ndim - 2
    if isinstance(output_size, int):
        output_size = (output_size,) * spatial_dims
    else:
        output_size = tuple(output_size)

    # Get spatial shape
    if data_format == "channels_last":
        batch_size = inputs.shape[0]
        channels = inputs.shape[-1]
        spatial_shape = inputs.shape[1:-1]
    else:  # channels_first
        batch_size = inputs.shape[0]
        channels = inputs.shape[1]
        spatial_shape = inputs.shape[2:]

    if len(output_size) != 2:
        raise NotImplementedError(
            "Only 2D adaptive pooling is currently supported"
        )

    out_h, out_w = output_size
    in_h, in_w = spatial_shape

    # Build output by iterating over output positions
    result_list = []

    for i in range(out_h):
        for j in range(out_w):
            # Calculate pooling region for this output position
            start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32)
            end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32)
            start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32)
            end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32)

            # Extract region and apply average pooling
            if data_format == "channels_last":
                region = inputs[:, start_h:end_h, start_w:end_w, :]
                # Average over spatial dimensions (axis 1, 2)
                pooled = jnp.mean(region, axis=(1, 2))
            else:  # channels_first
                region = inputs[:, :, start_h:end_h, start_w:end_w]
                # Average over spatial dimensions (axis 2, 3)
                pooled = jnp.mean(region, axis=(2, 3))

            result_list.append(pooled)

    # Stack results: (out_h*out_w, batch, channels)
    output = jnp.stack(result_list, axis=0)

    # Reshape and transpose to correct output shape
    if data_format == "channels_last":
        # (out_h*out_w, batch, channels) -> (batch, out_h, out_w, channels)
        output = output.reshape(out_h, out_w, batch_size, channels)
        output = jnp.transpose(output, (2, 0, 1, 3))
    else:  # channels_first
        # (out_h*out_w, batch, channels) -> (batch, channels, out_h, out_w)
        output = output.reshape(out_h, out_w, batch_size, channels)
        output = jnp.transpose(output, (2, 3, 0, 1))

    return output


def adaptive_max_pool(
    inputs, output_size, data_format="channels_last", name=None
):
    """
    Adaptive max pooling for JAX backend (PyTorch-compatible).
    """
    # Convert output_size to tuple
    spatial_dims = inputs.ndim - 2
    if isinstance(output_size, int):
        output_size = (output_size,) * spatial_dims
    else:
        output_size = tuple(output_size)

    # Get spatial shape
    if data_format == "channels_last":
        batch_size = inputs.shape[0]
        channels = inputs.shape[-1]
        spatial_shape = inputs.shape[1:-1]
    else:  # channels_first
        batch_size = inputs.shape[0]
        channels = inputs.shape[1]
        spatial_shape = inputs.shape[2:]

    if len(output_size) != 2:
        raise NotImplementedError(
            "Only 2D adaptive pooling is currently supported"
        )

    out_h, out_w = output_size
    in_h, in_w = spatial_shape

    # Build output by iterating over output positions
    result_list = []

    for i in range(out_h):
        for j in range(out_w):
            # Calculate pooling region for this output position
            start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32)
            end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32)
            start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32)
            end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32)

            # Extract region and apply max pooling
            if data_format == "channels_last":
                region = inputs[:, start_h:end_h, start_w:end_w, :]
                # Max over spatial dimensions (axis 1, 2)
                pooled = jnp.max(region, axis=(1, 2))
            else:  # channels_first
                region = inputs[:, :, start_h:end_h, start_w:end_w]
                # Max over spatial dimensions (axis 2, 3)
                pooled = jnp.max(region, axis=(2, 3))

            result_list.append(pooled)

    # Stack results: (out_h*out_w, batch, channels)
    output = jnp.stack(result_list, axis=0)

    # Reshape and transpose to correct output shape
    if data_format == "channels_last":
        # (out_h*out_w, batch, channels) -> (batch, out_h, out_w, channels)
        output = output.reshape(out_h, out_w, batch_size, channels)
        output = jnp.transpose(output, (2, 0, 1, 3))
    else:  # channels_first
        # (out_h*out_w, batch, channels) -> (batch, channels, out_h, out_w)
        output = output.reshape(out_h, out_w, batch_size, channels)
        output = jnp.transpose(output, (2, 3, 0, 1))

    return output

medium

The functions adaptive_avg_pool and adaptive_max_pool are nearly identical, with the only difference being the pooling operation (jnp.mean vs jnp.max). This code duplication can be avoided by creating a generic _adaptive_pool helper function that takes the pooling function as an argument. This would improve maintainability and reduce redundancy.

For example:

def _adaptive_pool(inputs, output_size, data_format, pool_op):
    # ... common setup code ...
    for i in range(out_h):
        for j in range(out_w):
            # ... common region calculation ...
            if data_format == "channels_last":
                region = inputs[:, start_h:end_h, start_w:end_w, :]
                pooled = pool_op(region, axis=(1, 2))
            else:  # channels_first
                region = inputs[:, :, start_h:end_h, start_w:end_w]
                pooled = pool_op(region, axis=(2, 3))
            result_list.append(pooled)
    # ... common reshape and transpose code ...
    return output

def adaptive_avg_pool(inputs, output_size, data_format="channels_last", name=None):
    # ...
    return _adaptive_pool(inputs, output_size, data_format, jnp.mean)

def adaptive_max_pool(inputs, output_size, data_format="channels_last", name=None):
    # ...
    return _adaptive_pool(inputs, output_size, data_format, jnp.max)

Note that this refactoring suggestion still contains the performance issue mentioned in another comment. The primary goal here is to illustrate how to reduce code duplication.

codecov-commenter commented Nov 4, 2025

Codecov Report

❌ Patch coverage is 71.36752% with 67 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.75%. Comparing base (6d06085) to head (f830e93).
⚠️ Report is 4 commits behind head on master.

Files with missing lines                                Patch %   Lines
keras/src/backend/torch/nn.py                            4.44%    43 Missing ⚠️
keras/src/layers/pooling/adaptive_max_pooling2d.py      71.42%    4 Missing and 4 partials ⚠️
keras/src/backend/tensorflow/nn.py                      92.59%    2 Missing and 2 partials ⚠️
...s/src/layers/pooling/adaptive_average_pooling2d.py   85.71%    2 Missing and 2 partials ⚠️
keras/src/ops/nn.py                                     63.63%    2 Missing and 2 partials ⚠️
keras/src/backend/jax/nn.py                             92.30%    1 Missing and 1 partial ⚠️
keras/src/backend/numpy/nn.py                           94.11%    1 Missing and 1 partial ⚠️

❗ There is a different number of reports uploaded between BASE (6d06085) and HEAD (f830e93).

HEAD has 4 uploads less than BASE

Flag             BASE (6d06085)   HEAD (f830e93)
keras                         5                3
keras-openvino                1                0
keras-torch                   1                0
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21820      +/-   ##
==========================================
- Coverage   82.66%   73.75%   -8.92%     
==========================================
  Files         577      580       +3     
  Lines       59419    59672     +253     
  Branches     9313     9362      +49     
==========================================
- Hits        49121    44011    -5110     
- Misses       7898    13420    +5522     
+ Partials     2400     2241     -159     
Flag               Coverage Δ
keras              73.67% <71.36%> (-8.82%) ⬇️
keras-jax          63.21% <37.17%> (-0.12%) ⬇️
keras-numpy        57.49% <41.02%> (-0.08%) ⬇️
keras-openvino     ?
keras-tensorflow   64.07% <48.71%> (-0.06%) ⬇️
keras-torch        ?

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

Development

Successfully merging this pull request may close these issues.

[Feature Request] Add AdaptivePooling - Avg/Max

3 participants