2 changes: 2 additions & 0 deletions keras/src/backend/jax/__init__.py
@@ -25,6 +25,8 @@
from keras.src.backend.jax.core import shape
from keras.src.backend.jax.core import stop_gradient
from keras.src.backend.jax.core import vectorized_map
from keras.src.backend.jax.nn import adaptive_avg_pool
from keras.src.backend.jax.nn import adaptive_max_pool
from keras.src.backend.jax.rnn import cudnn_ok
from keras.src.backend.jax.rnn import gru
from keras.src.backend.jax.rnn import lstm
58 changes: 58 additions & 0 deletions keras/src/backend/jax/nn.py
@@ -1464,3 +1464,61 @@ def _pair(x):
# ---- reshape -> (N, C*kH*kW, L) ----
_, CKK, oH, oW = patches.shape
return patches.reshape(N, CKK, oH * oW)


def _adaptive_pool(
inputs, output_size, data_format="channels_first", pool_fn=jnp.mean
):
"""
Optimized adaptive pooling for JAX backend, fully vectorized and
tracer-safe.
"""
if isinstance(output_size, int):
output_size = (output_size, output_size)
out_h, out_w = output_size

# Handle data format
if data_format == "channels_last":
inputs = jnp.transpose(inputs, (0, 3, 1, 2)) # NHWC → NCHW
n, c, h, w = inputs.shape

    # Precompute static pooling bins as concrete Python ints (shapes are
    # static under jit, so these are never traced).
    h_bins = [
        ((i * h) // out_h, ((i + 1) * h + out_h - 1) // out_h)
        for i in range(out_h)
    ]
    w_bins = [
        ((j * w) // out_w, ((j + 1) * w + out_w - 1) // out_w)
        for j in range(out_w)
    ]

# Define pooling over one image (C,H,W)
def pool_single_image(img):
pooled_rows = []
for hs, he in h_bins:
pooled_cols = []
for ws, we in w_bins:
region = img[:, hs:he, ws:we]
pooled_cols.append(pool_fn(region, axis=(1, 2)))
pooled_rows.append(jnp.stack(pooled_cols, axis=-1))
return jnp.stack(pooled_rows, axis=-2) # (C, out_h, out_w)

# Vectorize over batch
outputs = jax.vmap(pool_single_image)(inputs) # (N, C, out_h, out_w)

# Convert back if channels_last
if data_format == "channels_last":
outputs = jnp.transpose(outputs, (0, 2, 3, 1))
return outputs


def adaptive_avg_pool(
inputs, output_size, data_format="channels_first", name=None
):
return _adaptive_pool(inputs, output_size, data_format, pool_fn=jnp.mean)


def adaptive_max_pool(
inputs, output_size, data_format="channels_first", name=None
):
return _adaptive_pool(inputs, output_size, data_format, pool_fn=jnp.max)
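
Aside (not from the PR): the floor/ceil bin arithmetic in `_adaptive_pool` splits each spatial axis into `out_size` windows that cover the whole input and may overlap by at most one element when the sizes do not divide evenly. A minimal sketch with a hypothetical helper `adaptive_bins`:

import math

def adaptive_bins(in_size, out_size):
    # Same (floor(i * in / out), ceil((i + 1) * in / out)) bins as above.
    return [
        (math.floor(i * in_size / out_size),
         math.ceil((i + 1) * in_size / out_size))
        for i in range(out_size)
    ]

# Pooling 7 positions down to 3 outputs gives windows of size 3 that
# overlap by one element: [(0, 3), (2, 5), (4, 7)].
print(adaptive_bins(7, 3))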
59 changes: 59 additions & 0 deletions keras/src/backend/numpy/nn.py
@@ -1237,3 +1237,62 @@ def _pair(x):

# ---- reshape -> (N, C*kH*kW, L) ----
return patches.reshape(N, C * k[0] * k[1], -1)


def _adaptive_pool2d(inputs, output_size, mode="avg", data_format=None):
"""Adaptive pooling for 2D inputs."""
from keras.src import backend

data_format = backend.standardize_data_format(data_format)
x = convert_to_tensor(inputs)

if isinstance(output_size, int):
out_h = out_w = int(output_size)
else:
out_h, out_w = output_size

if data_format == "channels_last":
N, H, W, C = x.shape
x_nchw = np.transpose(x, (0, 3, 1, 2))
else:
N, C, H, W = x.shape
x_nchw = x

out = np.empty((N, C, out_h, out_w), dtype=x.dtype)

for i in range(out_h):
h_start = int(np.floor(i * H / out_h))
h_end = int(np.ceil((i + 1) * H / out_h))
h_start = max(0, min(h_start, H - 1))
h_end = max(h_start + 1, min(h_end, H))

for j in range(out_w):
w_start = int(np.floor(j * W / out_w))
w_end = int(np.ceil((j + 1) * W / out_w))
w_start = max(0, min(w_start, W - 1))
w_end = max(w_start + 1, min(w_end, W))

patch = x_nchw[:, :, h_start:h_end, w_start:w_end]

if mode == "avg":
out[:, :, i, j] = np.mean(patch, axis=(2, 3))
else:
out[:, :, i, j] = np.max(patch, axis=(2, 3))

if data_format == "channels_last":
return np.transpose(out, (0, 2, 3, 1))
return out


def adaptive_avg_pool(inputs, output_size, data_format=None):
"""Adaptive average pooling 2D wrapper."""
return _adaptive_pool2d(
inputs, output_size, mode="avg", data_format=data_format
)


def adaptive_max_pool(inputs, output_size, data_format=None):
"""Adaptive max pooling 2D wrapper."""
return _adaptive_pool2d(
inputs, output_size, mode="max", data_format=data_format
)
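
Aside (not from the PR): when the output size divides the input size evenly, the bins are disjoint and the `avg` branch above reduces to plain block averaging, which makes for an easy hand check:

import numpy as np

# 1x1x4x4 input pooled to 2x2: bins on each axis are (0, 2) and (2, 4),
# so out[0, 0, 0, 0] == x[0, 0, 0:2, 0:2].mean() == (0 + 1 + 4 + 5) / 4.
x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
expected = np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
blocks = x.reshape(1, 1, 2, 2, 2, 2).mean(axis=(3, 5))
assert np.allclose(blocks, expected)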
16 changes: 16 additions & 0 deletions keras/src/backend/openvino/nn.py
@@ -133,6 +133,14 @@ def max_pool(
)


def adaptive_max_pool(inputs, output_size, data_format=None):
"""Adaptive max pooling - OpenVINO backend not yet supported."""
raise NotImplementedError(
"adaptive_max_pool is not yet supported for OpenVINO backend. "
"Please use JAX, NumPy, PyTorch, or TensorFlow backend."
)


def average_pool(
inputs,
pool_size,
@@ -145,6 +153,14 @@ def average_pool(
)


def adaptive_avg_pool(inputs, output_size, data_format=None):
"""Adaptive average pooling - OpenVINO backend not yet supported."""
raise NotImplementedError(
"adaptive_avg_pool is not yet supported for OpenVINO backend. "
"Please use JAX, NumPy, PyTorch, or TensorFlow backend."
)


def _adjust_strides_dilation(
x,
num_spatial_dims,
84 changes: 84 additions & 0 deletions keras/src/backend/tensorflow/nn.py
@@ -240,6 +240,48 @@ def max_pool(
return outputs


def adaptive_max_pool(inputs, output_size, data_format=None):
"""Adaptive max pooling 2D for TensorFlow backend."""
import tensorflow as tf

from keras.src import backend

data_format = backend.standardize_data_format(data_format)
x = tf.convert_to_tensor(inputs)

if isinstance(output_size, int):
out_h = out_w = int(output_size)
else:
out_h, out_w = output_size

if data_format == "channels_last":
N, H, W, C = x.shape
x_nchw = tf.transpose(x, [0, 3, 1, 2])
else:
N, C, H, W = x.shape
x_nchw = x

    result_list = []
    # Bin boundaries are plain Python ints (H and W must be statically
    # known), so every slice below has a shape fixed at trace time.
    for i in range(out_h):
        for j in range(out_w):
            h_start = (i * H) // out_h
            h_end = ((i + 1) * H + out_h - 1) // out_h
            w_start = (j * W) // out_w
            w_end = ((j + 1) * W + out_w - 1) // out_w

patch = x_nchw[:, :, h_start:h_end, w_start:w_end]
pooled = tf.reduce_max(patch, axis=[2, 3])
result_list.append(pooled)

output = tf.stack(result_list, axis=0)
output = tf.reshape(output, [out_h, out_w, N, C])
output = tf.transpose(
output, [2, 0, 1, 3] if data_format == "channels_last" else [2, 3, 0, 1]
)

return output


def average_pool(
inputs,
pool_size,
@@ -268,6 +310,48 @@ def average_pool(
return outputs


def adaptive_avg_pool(inputs, output_size, data_format=None):
"""Adaptive average pooling 2D for TensorFlow backend."""
import tensorflow as tf

from keras.src import backend

data_format = backend.standardize_data_format(data_format)
x = tf.convert_to_tensor(inputs)

if isinstance(output_size, int):
out_h = out_w = int(output_size)
else:
out_h, out_w = output_size

if data_format == "channels_last":
N, H, W, C = x.shape
x_nchw = tf.transpose(x, [0, 3, 1, 2])
else:
N, C, H, W = x.shape
x_nchw = x

    result_list = []
    # Bin boundaries are plain Python ints (H and W must be statically
    # known), so every slice below has a shape fixed at trace time.
    for i in range(out_h):
        for j in range(out_w):
            h_start = (i * H) // out_h
            h_end = ((i + 1) * H + out_h - 1) // out_h
            w_start = (j * W) // out_w
            w_end = ((j + 1) * W + out_w - 1) // out_w

patch = x_nchw[:, :, h_start:h_end, w_start:w_end]
pooled = tf.reduce_mean(patch, axis=[2, 3])
result_list.append(pooled)

output = tf.stack(result_list, axis=0)
output = tf.reshape(output, [out_h, out_w, N, C])
output = tf.transpose(
output, [2, 0, 1, 3] if data_format == "channels_last" else [2, 3, 0, 1]
)

return output


def _convert_data_format(data_format, ndim):
if data_format == "channels_last":
if ndim == 3:
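
Aside (not from the PR): a minimal eager-mode sketch of the TensorFlow path, assuming this diff is applied in a Keras source checkout (the `keras.src...` import is the file shown above). Shapes must be static, since the bins are computed with Python ints:

import numpy as np
import tensorflow as tf

from keras.src.backend.tensorflow.nn import adaptive_avg_pool

x = tf.constant(np.arange(32, dtype=np.float32).reshape(1, 4, 4, 2))
y = adaptive_avg_pool(x, output_size=2, data_format="channels_last")
print(y.shape)  # (1, 2, 2, 2)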
88 changes: 88 additions & 0 deletions keras/src/backend/torch/nn.py
@@ -384,6 +384,51 @@ def max_pool(
return outputs


def adaptive_max_pool(inputs, output_size, data_format=None):
"""Adaptive max pooling (1D/2D/3D) with channels_last support."""
inputs = convert_to_tensor(inputs)
num_spatial_dims = inputs.ndim - 2

data_format = backend.standardize_data_format(data_format)
orig_format = data_format
if data_format == "channels_last":
inputs = _transpose_spatial_inputs(inputs)

if isinstance(output_size, int):
torch_output_size = (
output_size
if num_spatial_dims == 1
else (output_size,) * num_spatial_dims
)
else:
torch_output_size = standardize_tuple(
output_size, num_spatial_dims, "output_size"
)

if get_device() == "meta":
inputs = torch.empty(
size=inputs.shape, dtype=inputs.dtype, device="cpu"
)

if num_spatial_dims == 1:
res = tnn.adaptive_max_pool1d(inputs, output_size=torch_output_size)
elif num_spatial_dims == 2:
res = tnn.adaptive_max_pool2d(inputs, output_size=torch_output_size)
elif num_spatial_dims == 3:
res = tnn.adaptive_max_pool3d(inputs, output_size=torch_output_size)
else:
raise ValueError(
"Inputs to adaptive max pooling must have ndim=3, 4 or 5, "
f"Received input shape: {inputs.shape}."
)

outputs = res[0] if isinstance(res, tuple) else res

if orig_format == "channels_last":
outputs = _transpose_spatial_outputs(outputs)
return outputs


def average_pool(
inputs,
pool_size,
@@ -458,6 +503,49 @@ def average_pool(
return outputs


def adaptive_avg_pool(inputs, output_size, data_format=None):
"""Adaptive average pooling (1D/2D/3D) with channels_last support."""
inputs = convert_to_tensor(inputs)
num_spatial_dims = inputs.ndim - 2

data_format = backend.standardize_data_format(data_format)
orig_format = data_format
if data_format == "channels_last":
inputs = _transpose_spatial_inputs(inputs)

if isinstance(output_size, int):
torch_output_size = (
output_size
if num_spatial_dims == 1
else (output_size,) * num_spatial_dims
)
else:
torch_output_size = standardize_tuple(
output_size, num_spatial_dims, "output_size"
)

if get_device() == "meta":
inputs = torch.empty(
size=inputs.shape, dtype=inputs.dtype, device="cpu"
)

if num_spatial_dims == 1:
outputs = tnn.adaptive_avg_pool1d(inputs, output_size=torch_output_size)
elif num_spatial_dims == 2:
outputs = tnn.adaptive_avg_pool2d(inputs, output_size=torch_output_size)
elif num_spatial_dims == 3:
outputs = tnn.adaptive_avg_pool3d(inputs, output_size=torch_output_size)
else:
raise ValueError(
"Inputs to adaptive average pooling must have ndim=3, 4 or 5, "
f"Received input shape: {inputs.shape}."
)

if orig_format == "channels_last":
outputs = _transpose_spatial_outputs(outputs)
return outputs


def conv(
inputs,
kernel,
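
Aside (not from the PR): because the torch path defers to `torch.nn.functional`, a channels_last call can be cross-checked directly against PyTorch. A sketch assuming this diff is applied:

import torch
import torch.nn.functional as F

from keras.src.backend.torch.nn import adaptive_avg_pool

x = torch.randn(2, 5, 7, 3)  # NHWC (channels_last)
y = adaptive_avg_pool(x, output_size=(3, 3), data_format="channels_last")
# The backend transposes to NCHW, calls tnn.adaptive_avg_pool2d, and
# transposes back, so it should agree with PyTorch up to layout.
ref = F.adaptive_avg_pool2d(x.permute(0, 3, 1, 2), (3, 3)).permute(0, 2, 3, 1)
assert torch.allclose(y, ref)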
4 changes: 4 additions & 0 deletions keras/src/layers/__init__.py
@@ -63,6 +63,10 @@
SpectralNormalization,
)
from keras.src.layers.normalization.unit_normalization import UnitNormalization
from keras.src.layers.pooling.adaptive_average_pooling2d import (
AdaptiveAveragePooling2D,
)
from keras.src.layers.pooling.adaptive_max_pooling2d import AdaptiveMaxPooling2D
from keras.src.layers.pooling.average_pooling1d import AveragePooling1D
from keras.src.layers.pooling.average_pooling2d import AveragePooling2D
from keras.src.layers.pooling.average_pooling3d import AveragePooling3D
4 changes: 4 additions & 0 deletions keras/src/layers/pooling/__init__.py
@@ -0,0 +1,4 @@
from keras.src.layers.pooling.adaptive_average_pooling2d import (
AdaptiveAveragePooling2D,
)
from keras.src.layers.pooling.adaptive_max_pooling2d import AdaptiveMaxPooling2D
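
Aside (not from the PR): a usage sketch of the new layers. The diff only shows the imports, so the public `keras.layers` export and the `output_size` constructor argument are assumptions here:

import numpy as np
from keras import layers, models

model = models.Sequential([
    layers.Input(shape=(32, 32, 3)),
    layers.Conv2D(8, 3, padding="same"),
    # Collapses any spatial size to a fixed 4x4 grid, analogous to
    # torch.nn.AdaptiveAvgPool2d((4, 4)).
    layers.AdaptiveAveragePooling2D(output_size=(4, 4)),
    layers.Flatten(),
    layers.Dense(10),
])
print(model.predict(np.zeros((1, 32, 32, 3), dtype="float32")).shape)  # (1, 10)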