Weird tensor shape when computing the gradient of batched grouped convolutions with a custom VJP #34639

@LambdaP

Description

For the past few months I have occasionally run into performance issues when training convolutional neural networks that use custom VJP rules, such as slow execution and out-of-memory errors. I believe these issues are connected to the minimal example described below.

Fundamentally, the issue seems to be that when a grouped convolution appears in the backward pass of a custom VJP that is called under jax.vmap(), JAX processes the input as a tensor of shape NC×1×H×W instead of N×C×H×W, where N is the batch size, C is the number of channels, and H and W are the spatial dimensions. As one example of the problems this can cause, cuDNN then struggles to optimize the convolution because the tensor has an unusual shape.
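
To make the shapes concrete with the sizes used in the script below, here is a sketch of my own (not taken from the trace): the batch and channel axes effectively get merged into a single axis of size N*C, leaving a singleton feature axis. In the actual jaxpr the merged axis ends up in a different physical position with transposed dimension numbers, but logically each of the N*C slices is treated as its own single-channel image.

import jax.numpy as jnp

N, C, H, W = 32, 256, 56, 56  # sizes used in the reproduction script below
x = jnp.ones((N, C, H, W))    # the logical N x C x H x W input

# What the backward-pass convolution effectively sees: batch and channels
# merged into one 8192-wide axis with a singleton feature axis, rather than
# a batch of 32 images with 256 channels each.
x_merged = x.reshape(N * C, 1, H, W)
print(x_merged.shape)  # (8192, 1, 56, 56)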

This behavior appears whenever a custom VJP rule is defined and is not specific to the contents of the rule itself. The code below generates the jaxpr IR for a grouped convolution that does not use a custom VJP rule, and for a convolution with the same kernel that uses a trivial custom VJP rule which simply calls the underlying VJP.

import jax
import jax.numpy as jnp


# Build a grouped-convolution kernel of shape (c, c // groups, k, k).
def conv_kernel(k, c, groups=1):
    assert c % groups == 0
    per_group = c // groups
    return jnp.ones((c, per_group, k, k))


# Per-example convolution: add a singleton batch dimension, run a grouped
# convolution, then drop the batch dimension again (meant to be vmapped).
def conv2d(kernel, x):
    c, p, _, _ = kernel.shape
    groups = c // p

    x = jnp.expand_dims(x, axis=0)

    out = jax.lax.conv_general_dilated(
        lhs=x,
        rhs=kernel,
        window_strides=(1, 1),
        padding="SAME",
        feature_group_count=groups,
    )

    return jnp.squeeze(out, axis=0)


# Trivial custom VJP wrapper: the backward rule simply calls the VJP of conv2d.
@jax.custom_vjp
def wrapped_conv2d(kernel, x):
    return conv2d(kernel, x)


def wrapped_conv2d_fwd(kernel, x):
    y, vjp_fn = jax.vjp(conv2d, kernel, x)
    return y, vjp_fn


def wrapped_conv2d_bwd(residuals, g):
    vjp_fn = residuals

    g_out = vjp_fn(g)
    return g_out


wrapped_conv2d.defvjp(wrapped_conv2d_fwd, wrapped_conv2d_bwd)


# Loss whose model uses the plain conv2d (no custom VJP).
def loss_base(kernel, images, labels):
    def model(x):
        x = conv2d(kernel, x)
        logit = jnp.mean(x)
        return jax.nn.sigmoid(logit)

    preds = jax.vmap(model)(images)

    return jnp.mean((preds - labels) ** 2)


# Same loss, but the model goes through the custom-VJP wrapper.
def loss_custom(kernel, images, labels):
    def model(x):
        x = wrapped_conv2d(kernel, x)
        logit = jnp.mean(x)
        return jax.nn.sigmoid(logit)

    preds = jax.vmap(model)(images)

    return jnp.mean((preds - labels) ** 2)


batch_size = 32
channels = 256
height = 56
width = 56

kernel_size = 3
groups = 2

images = jnp.ones((batch_size, channels, height, width))
labels = jnp.ones((batch_size,))
kernel = conv_kernel(kernel_size, channels, groups)

loss_fn = jax.jit(jax.grad(loss_base))
traced = loss_fn.trace(kernel, images, labels.astype(jnp.float32))
print(traced.jaxpr)

loss_fn = jax.jit(jax.grad(loss_custom))
traced = loss_fn.trace(kernel, images, labels.astype(jnp.float32))
print(traced.jaxpr)

Here is the output of this script when run on my machine:

{ lambda ; a:f32[256,128,3,3] b:f32[32,256,56,56] c:f32[32]. let
    d:f32[32,1,256,56,56] = broadcast_in_dim[
      broadcast_dimensions=(0, np.int64(2), np.int64(3), np.int64(4))
      shape=(32, 1, 256, 56, 56)
      sharding=None
    ] b
    e:f32[32,256,56,56] = reshape[
      dimensions=None
      new_sizes=(32, 256, 56, 56)
      sharding=None
    ] d
    f:f32[32,256,56,56] = conv_general_dilated[
      batch_group_count=1
      dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3))
      feature_group_count=2
      lhs_dilation=(1, 1)
      out_sharding=None
      padding=((1, 1), (1, 1))
      precision=None
      preferred_element_type=None
      rhs_dilation=(1, 1)
      window_strides=(1, 1)
    ] e a
    g:f32[32,1,256,56,56] = reshape[
      dimensions=None
      new_sizes=(32, 1, 256, 56, 56)
      sharding=None
    ] f
    h:f32[32,256,56,56] = squeeze[dimensions=(1,)] g
    i:f32[32] = reduce_sum[
      axes=(np.int64(1), np.int64(2), np.int64(3))
      out_sharding=None
    ] h
    j:f32[32] = div i 802816.0:f32[]
    k:f32[32] = logistic j
    l:f32[32] = sub 1.0:f32[] k
    m:f32[32] = mul k l
    n:f32[32] = sub k c
    o:f32[32] = integer_pow[y=2] n
    p:f32[32] = integer_pow[y=1] n
    q:f32[32] = mul 2.0:f32[] p
    r:f32[] = reduce_sum[axes=(0,) out_sharding=None] o
    _:f32[] = div r 32.0:f32[]
    s:f32[] = div 1.0:f32[] 32.0:f32[]
    t:f32[32] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(32,)
      sharding=None
    ] s
    u:f32[32] = mul t q
    v:f32[32] = mul u m
    w:f32[32] = div v 802816.0:f32[]
    x:f32[32,256,56,56] = broadcast_in_dim[
      broadcast_dimensions=(np.int64(0),)
      shape=(32, 256, 56, 56)
      sharding=None
    ] w
    y:f32[32,1,256,56,56] = broadcast_in_dim[
      broadcast_dimensions=(0, 2, 3, 4)
      shape=(32, 1, 256, 56, 56)
      sharding=None
    ] x
    z:f32[32,256,56,56] = reshape[
      dimensions=None
      new_sizes=(32, 256, 56, 56)
      sharding=None
    ] y
    ba:f32[256,128,3,3] = conv_general_dilated[
      batch_group_count=2
      dimension_numbers=ConvDimensionNumbers(lhs_spec=(1, 0, 2, 3), rhs_spec=(1, 0, 2, 3), out_spec=(1, 0, 2, 3))
      feature_group_count=1
      lhs_dilation=(1, 1)
      out_sharding=None
      padding=((1, 1), (1, 1))
      precision=None
      preferred_element_type=None
      rhs_dilation=(1, 1)
      window_strides=(1, 1)
    ] e z
  in (ba,) }
{ lambda ; a:f32[256,128,3,3] b:f32[32,256,56,56] c:f32[32]. let
    d:f32[32,1,256,56,56] = broadcast_in_dim[
      broadcast_dimensions=(0, np.int64(2), np.int64(3), np.int64(4))
      shape=(32, 1, 256, 56, 56)
      sharding=None
    ] b
    e:f32[32,256,56,56] = reshape[
      dimensions=None
      new_sizes=(32, 256, 56, 56)
      sharding=None
    ] d
    f:f32[32,256,56,56] = conv_general_dilated[
      batch_group_count=1
      dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3))
      feature_group_count=2
      lhs_dilation=(1, 1)
      out_sharding=None
      padding=((1, 1), (1, 1))
      precision=None
      preferred_element_type=None
      rhs_dilation=(1, 1)
      window_strides=(1, 1)
    ] e a
    g:f32[32,1,256,56,56] = reshape[
      dimensions=None
      new_sizes=(32, 1, 256, 56, 56)
      sharding=None
    ] f
    h:f32[32,256,56,56] = squeeze[dimensions=(1,)] g
    i:f32[32] = reduce_sum[
      axes=(np.int64(1), np.int64(2), np.int64(3))
      out_sharding=None
    ] h
    j:f32[32] = div i 802816.0:f32[]
    k:f32[32] = logistic j
    l:f32[32] = sub 1.0:f32[] k
    m:f32[32] = mul k l
    n:f32[32] = sub k c
    o:f32[32] = integer_pow[y=2] n
    p:f32[32] = integer_pow[y=1] n
    q:f32[32] = mul 2.0:f32[] p
    r:f32[] = reduce_sum[axes=(0,) out_sharding=None] o
    _:f32[] = div r 32.0:f32[]
    s:f32[] = div 1.0:f32[] 32.0:f32[]
    t:f32[32] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(32,)
      sharding=None
    ] s
    u:f32[32] = mul t q
    v:f32[32] = mul u m
    w:f32[32] = div v 802816.0:f32[]
    x:f32[32,256,56,56] = broadcast_in_dim[
      broadcast_dimensions=(np.int64(0),)
      shape=(32, 256, 56, 56)
      sharding=None
    ] w
    y:f32[32,1,256,56,56] = broadcast_in_dim[
      broadcast_dimensions=(0, np.int64(2), np.int64(3), np.int64(4))
      shape=(32, 1, 256, 56, 56)
      sharding=None
    ] x
    z:f32[1,8192,56,56] = reshape[
      dimensions=(1, 0, 2, 3, 4)
      new_sizes=(1, 8192, 56, 56)
      sharding=None
    ] d
    ba:f32[1,8192,56,56] = reshape[
      dimensions=(1, 0, 2, 3, 4)
      new_sizes=(1, 8192, 56, 56)
      sharding=None
    ] y
    bb:f32[8192,128,3,3] = conv_general_dilated[
      batch_group_count=64
      dimension_numbers=ConvDimensionNumbers(lhs_spec=(1, 0, 2, 3), rhs_spec=(1, 0, 2, 3), out_spec=(1, 0, 2, 3))
      feature_group_count=1
      lhs_dilation=(1, 1)
      out_sharding=None
      padding=((1, 1), (1, 1))
      precision=None
      preferred_element_type=None
      rhs_dilation=(1, 1)
      window_strides=(1, 1)
    ] z ba
    bc:f32[32,256,128,3,3] = reshape[
      dimensions=None
      new_sizes=(32, 256, 128, 3, 3)
      sharding=None
    ] bb
    bd:f32[2,128,128,3,3] = reshape[
      dimensions=None
      new_sizes=(2, 128, 128, 3, 3)
      sharding=None
    ] a
    be:f32[128,256,3,3] = reshape[
      dimensions=(1, 0, 2, 3, 4)
      new_sizes=(128, 256, 3, 3)
      sharding=None
    ] bd
    bf:f32[128,256,3,3] = rev[dimensions=(2, 3)] be
    bg:f32[32,256,56,56] = reshape[
      dimensions=None
      new_sizes=(32, 256, 56, 56)
      sharding=None
    ] y
    bh:f32[32,256,56,56] = conv_general_dilated[
      batch_group_count=1
      dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(1, 0, 2, 3), out_spec=(0, 1, 2, 3))
      feature_group_count=2
      lhs_dilation=(1, 1)
      out_sharding=None
      padding=((1, 1), (1, 1))
      precision=None
      preferred_element_type=None
      rhs_dilation=(1, 1)
      window_strides=(1, 1)
    ] bg bf
    bi:f32[32,1,256,56,56] = reshape[
      dimensions=None
      new_sizes=(32, 1, 256, 56, 56)
      sharding=None
    ] bh
    _:f32[32,256,56,56] = reduce_sum[axes=(np.int64(1),) out_sharding=None] bi
    bj:f32[256,128,3,3] = reduce_sum[axes=(0,) out_sharding=None] bc
  in (bj,) }

As you can see, the two traces are similar until near the end, where the second trace introduces the following reshapes:

    z:f32[1,8192,56,56] = reshape[
      dimensions=(1, 0, 2, 3, 4)
      new_sizes=(1, 8192, 56, 56)
      sharding=None
    ] d
    ba:f32[1,8192,56,56] = reshape[
      dimensions=(1, 0, 2, 3, 4)
      new_sizes=(1, 8192, 56, 56)
      sharding=None
    ] y

(The second trace also includes an extra convolution, bh, corresponding to the gradient of the loss with respect to the input, which is then discarded; that part is not what my issue is about.)
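
Here is a sketch of my own (not part of the trace above) for spotting the problematic convolution without reading the whole IR dump: iterate over the equations of the gradient's jaxpr and print the operand shapes and group counts of every conv_general_dilated. This relies on jaxpr introspection attributes (.eqns, .invars, .params) that are not a stable public API.

closed = jax.make_jaxpr(jax.grad(loss_custom))(kernel, images, labels)
for eqn in closed.jaxpr.eqns:
    if eqn.primitive.name == "conv_general_dilated":
        lhs, rhs = eqn.invars
        # With the custom VJP, the backward-pass convolution shows up with an
        # lhs of shape (1, 8192, 56, 56) and batch_group_count=64.
        print(lhs.aval.shape, rhs.aval.shape,
              eqn.params["feature_group_count"],
              eqn.params["batch_group_count"])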

I have also observed this behavior when the backward function of the custom VJP rule does not call jax.vjp(); for example, with the following (possibly incorrect) re-implementation of the rule:

@jax.custom_vjp
def handwritten_vjp_conv2d(kernel, x):
    return conv2d(kernel, x)


def handwritten_vjp_conv2d_fwd(kernel, x):
    y = conv2d(kernel, x)
    return y, (kernel, x)


def handwritten_vjp_conv2d_bwd(residuals, g):
    kernel, x = residuals
    c, p, k, _ = kernel.shape
    groups = c // p

    # Expand dims to match what conv2d expects
    x_expanded = jnp.expand_dims(x, axis=0)
    g_expanded = jnp.expand_dims(g, axis=0)

    # Gradient w.r.t. kernel
    # This is essentially correlating the input with the output gradient
    kernel_grad = jax.lax.conv_general_dilated(
        lhs=x_expanded,
        rhs=g_expanded,
        window_strides=(1, 1),
        padding=((1, 1), (1, 1)),
        dimension_numbers=jax.lax.ConvDimensionNumbers(
            lhs_spec=(1, 0, 2, 3), rhs_spec=(1, 0, 2, 3), out_spec=(1, 0, 2, 3)
        ),
        feature_group_count=1,
        batch_group_count=groups,
    )

    # Gradient w.r.t. input
    # Convolve gradient with flipped kernel
    kernel_flipped = jnp.flip(kernel, axis=(2, 3))
    x_grad = jax.lax.conv_general_dilated(
        lhs=g_expanded,
        rhs=kernel_flipped,
        window_strides=(1, 1),
        padding=((1, 1), (1, 1)),
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        feature_group_count=groups,
    )
    x_grad = jnp.squeeze(x_grad, axis=0)

    return (kernel_grad, x_grad)


handwritten_vjp_conv2d.defvjp(handwritten_vjp_conv2d_fwd, handwritten_vjp_conv2d_bwd)
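
For completeness, here is a sketch of how one might confirm that the same merged shape shows up with this hand-written rule (loss_handwritten is my own name, mirroring loss_custom above):

def loss_handwritten(kernel, images, labels):
    def model(x):
        x = handwritten_vjp_conv2d(kernel, x)
        logit = jnp.mean(x)
        return jax.nn.sigmoid(logit)

    preds = jax.vmap(model)(images)
    return jnp.mean((preds - labels) ** 2)


# As described above, the traced gradient again shows the batch and channel
# axes merged into a single 8192-wide axis in the backward-pass convolution.
traced = jax.jit(jax.grad(loss_handwritten)).trace(kernel, images, labels)
print(traced.jaxpr)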

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.9.0
jaxlib: 0.9.0
numpy: 2.4.1
python: 3.13.5 (main, Jul 11 2025, 22:43:46) [Clang 20.1.4 ]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='mo', release='6.12.59', version='#1-NixOS SMP PREEMPT_DYNAMIC Mon Nov 24 09:36:08 UTC 2025', machine='x86_64')

I have also observed the issue with Python 3.12 and JAX 0.7.2.
