Description
For the past few months I have occasionally encountered performance issues, such as slow execution and out-of-memory errors, when training convolutional neural networks that use custom VJP rules. I believe these issues are connected to the minimal example described below.
Fundamentally, the issue seems to be that when a grouped convolution appears in the backward pass of a custom VJP that is called under jax.vmap(), JAX processes the input as a tensor of shape NC×1×H×W instead of N×C×H×W, where N is the batch size, C is the feature count, and H and W are the spatial dimensions. As one example of the problems this can cause, cuDNN then struggles to optimize the convolution because the tensor has an unusual shape.
This behavior appears whenever a custom VJP rule is defined, and it is not specific to the particular rule used. The code below generates the jaxpr IR for a grouped convolution that does not use a custom VJP rule, and for a convolution with the same kernel that uses a trivial custom VJP rule which simply calls the underlying VJP.
import jax
import jax.numpy as jnp


def conv_kernel(k, c, groups=1):
    assert c % groups == 0
    per_group = c // groups
    return jnp.ones((c, per_group, k, k))


def conv2d(kernel, x):
    c, p, _, _ = kernel.shape
    groups = c // p
    # x is a single unbatched CHW image; add a leading batch dimension of 1.
    x = jnp.expand_dims(x, axis=0)
    out = jax.lax.conv_general_dilated(
        lhs=x,
        rhs=kernel,
        window_strides=(1, 1),
        padding="SAME",
        feature_group_count=groups,
    )
    return jnp.squeeze(out, axis=0)


@jax.custom_vjp
def wrapped_conv2d(kernel, x):
    return conv2d(kernel, x)


def wrapped_conv2d_fwd(kernel, x):
    y, vjp_fn = jax.vjp(conv2d, kernel, x)
    return y, vjp_fn


def wrapped_conv2d_bwd(residuals, g):
    # Trivial backward rule: simply call the VJP computed in the forward pass.
    vjp_fn = residuals
    g_out = vjp_fn(g)
    return g_out


wrapped_conv2d.defvjp(wrapped_conv2d_fwd, wrapped_conv2d_bwd)


def loss_base(kernel, images, labels):
    def model(x):
        x = conv2d(kernel, x)
        logit = jnp.mean(x)
        return jax.nn.sigmoid(logit)

    preds = jax.vmap(model)(images)
    return jnp.mean((preds - labels) ** 2)


def loss_custom(kernel, images, labels):
    def model(x):
        x = wrapped_conv2d(kernel, x)
        logit = jnp.mean(x)
        return jax.nn.sigmoid(logit)

    preds = jax.vmap(model)(images)
    return jnp.mean((preds - labels) ** 2)


batch_size = 32
channels = 256
height = 56
width = 56
kernel_size = 3
groups = 2

images = jnp.ones((batch_size, channels, height, width))
labels = jnp.ones((batch_size,))
kernel = conv_kernel(kernel_size, channels, groups)

loss_fn = jax.jit(jax.grad(loss_base))
traced = loss_fn.trace(kernel, images, labels.astype(jnp.float32))
print(traced.jaxpr)

loss_fn = jax.jit(jax.grad(loss_custom))
traced = loss_fn.trace(kernel, images, labels.astype(jnp.float32))
print(traced.jaxpr)

Here is the output of this script when run on my computer:
{ lambda ; a:f32[256,128,3,3] b:f32[32,256,56,56] c:f32[32]. let
d:f32[32,1,256,56,56] = broadcast_in_dim[
broadcast_dimensions=(0, np.int64(2), np.int64(3), np.int64(4))
shape=(32, 1, 256, 56, 56)
sharding=None
] b
e:f32[32,256,56,56] = reshape[
dimensions=None
new_sizes=(32, 256, 56, 56)
sharding=None
] d
f:f32[32,256,56,56] = conv_general_dilated[
batch_group_count=1
dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3))
feature_group_count=2
lhs_dilation=(1, 1)
out_sharding=None
padding=((1, 1), (1, 1))
precision=None
preferred_element_type=None
rhs_dilation=(1, 1)
window_strides=(1, 1)
] e a
g:f32[32,1,256,56,56] = reshape[
dimensions=None
new_sizes=(32, 1, 256, 56, 56)
sharding=None
] f
h:f32[32,256,56,56] = squeeze[dimensions=(1,)] g
i:f32[32] = reduce_sum[
axes=(np.int64(1), np.int64(2), np.int64(3))
out_sharding=None
] h
j:f32[32] = div i 802816.0:f32[]
k:f32[32] = logistic j
l:f32[32] = sub 1.0:f32[] k
m:f32[32] = mul k l
n:f32[32] = sub k c
o:f32[32] = integer_pow[y=2] n
p:f32[32] = integer_pow[y=1] n
q:f32[32] = mul 2.0:f32[] p
r:f32[] = reduce_sum[axes=(0,) out_sharding=None] o
_:f32[] = div r 32.0:f32[]
s:f32[] = div 1.0:f32[] 32.0:f32[]
t:f32[32] = broadcast_in_dim[
broadcast_dimensions=()
shape=(32,)
sharding=None
] s
u:f32[32] = mul t q
v:f32[32] = mul u m
w:f32[32] = div v 802816.0:f32[]
x:f32[32,256,56,56] = broadcast_in_dim[
broadcast_dimensions=(np.int64(0),)
shape=(32, 256, 56, 56)
sharding=None
] w
y:f32[32,1,256,56,56] = broadcast_in_dim[
broadcast_dimensions=(0, 2, 3, 4)
shape=(32, 1, 256, 56, 56)
sharding=None
] x
z:f32[32,256,56,56] = reshape[
dimensions=None
new_sizes=(32, 256, 56, 56)
sharding=None
] y
ba:f32[256,128,3,3] = conv_general_dilated[
batch_group_count=2
dimension_numbers=ConvDimensionNumbers(lhs_spec=(1, 0, 2, 3), rhs_spec=(1, 0, 2, 3), out_spec=(1, 0, 2, 3))
feature_group_count=1
lhs_dilation=(1, 1)
out_sharding=None
padding=((1, 1), (1, 1))
precision=None
preferred_element_type=None
rhs_dilation=(1, 1)
window_strides=(1, 1)
] e z
in (ba,) }
{ lambda ; a:f32[256,128,3,3] b:f32[32,256,56,56] c:f32[32]. let
d:f32[32,1,256,56,56] = broadcast_in_dim[
broadcast_dimensions=(0, np.int64(2), np.int64(3), np.int64(4))
shape=(32, 1, 256, 56, 56)
sharding=None
] b
e:f32[32,256,56,56] = reshape[
dimensions=None
new_sizes=(32, 256, 56, 56)
sharding=None
] d
f:f32[32,256,56,56] = conv_general_dilated[
batch_group_count=1
dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3))
feature_group_count=2
lhs_dilation=(1, 1)
out_sharding=None
padding=((1, 1), (1, 1))
precision=None
preferred_element_type=None
rhs_dilation=(1, 1)
window_strides=(1, 1)
] e a
g:f32[32,1,256,56,56] = reshape[
dimensions=None
new_sizes=(32, 1, 256, 56, 56)
sharding=None
] f
h:f32[32,256,56,56] = squeeze[dimensions=(1,)] g
i:f32[32] = reduce_sum[
axes=(np.int64(1), np.int64(2), np.int64(3))
out_sharding=None
] h
j:f32[32] = div i 802816.0:f32[]
k:f32[32] = logistic j
l:f32[32] = sub 1.0:f32[] k
m:f32[32] = mul k l
n:f32[32] = sub k c
o:f32[32] = integer_pow[y=2] n
p:f32[32] = integer_pow[y=1] n
q:f32[32] = mul 2.0:f32[] p
r:f32[] = reduce_sum[axes=(0,) out_sharding=None] o
_:f32[] = div r 32.0:f32[]
s:f32[] = div 1.0:f32[] 32.0:f32[]
t:f32[32] = broadcast_in_dim[
broadcast_dimensions=()
shape=(32,)
sharding=None
] s
u:f32[32] = mul t q
v:f32[32] = mul u m
w:f32[32] = div v 802816.0:f32[]
x:f32[32,256,56,56] = broadcast_in_dim[
broadcast_dimensions=(np.int64(0),)
shape=(32, 256, 56, 56)
sharding=None
] w
y:f32[32,1,256,56,56] = broadcast_in_dim[
broadcast_dimensions=(0, np.int64(2), np.int64(3), np.int64(4))
shape=(32, 1, 256, 56, 56)
sharding=None
] x
z:f32[1,8192,56,56] = reshape[
dimensions=(1, 0, 2, 3, 4)
new_sizes=(1, 8192, 56, 56)
sharding=None
] d
ba:f32[1,8192,56,56] = reshape[
dimensions=(1, 0, 2, 3, 4)
new_sizes=(1, 8192, 56, 56)
sharding=None
] y
bb:f32[8192,128,3,3] = conv_general_dilated[
batch_group_count=64
dimension_numbers=ConvDimensionNumbers(lhs_spec=(1, 0, 2, 3), rhs_spec=(1, 0, 2, 3), out_spec=(1, 0, 2, 3))
feature_group_count=1
lhs_dilation=(1, 1)
out_sharding=None
padding=((1, 1), (1, 1))
precision=None
preferred_element_type=None
rhs_dilation=(1, 1)
window_strides=(1, 1)
] z ba
bc:f32[32,256,128,3,3] = reshape[
dimensions=None
new_sizes=(32, 256, 128, 3, 3)
sharding=None
] bb
bd:f32[2,128,128,3,3] = reshape[
dimensions=None
new_sizes=(2, 128, 128, 3, 3)
sharding=None
] a
be:f32[128,256,3,3] = reshape[
dimensions=(1, 0, 2, 3, 4)
new_sizes=(128, 256, 3, 3)
sharding=None
] bd
bf:f32[128,256,3,3] = rev[dimensions=(2, 3)] be
bg:f32[32,256,56,56] = reshape[
dimensions=None
new_sizes=(32, 256, 56, 56)
sharding=None
] y
bh:f32[32,256,56,56] = conv_general_dilated[
batch_group_count=1
dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(1, 0, 2, 3), out_spec=(0, 1, 2, 3))
feature_group_count=2
lhs_dilation=(1, 1)
out_sharding=None
padding=((1, 1), (1, 1))
precision=None
preferred_element_type=None
rhs_dilation=(1, 1)
window_strides=(1, 1)
] bg bf
bi:f32[32,1,256,56,56] = reshape[
dimensions=None
new_sizes=(32, 1, 256, 56, 56)
sharding=None
] bh
_:f32[32,256,56,56] = reduce_sum[axes=(np.int64(1),) out_sharding=None] bi
bj:f32[256,128,3,3] = reduce_sum[axes=(0,) out_sharding=None] bc
in (bj,) }
As you can see, the two traces are similar until the end, where the second trace goes into the following reshaping:
z:f32[1,8192,56,56] = reshape[
dimensions=(1, 0, 2, 3, 4)
new_sizes=(1, 8192, 56, 56)
sharding=None
] d
ba:f32[1,8192,56,56] = reshape[
dimensions=(1, 0, 2, 3, 4)
new_sizes=(1, 8192, 56, 56)
sharding=None
] y
(The second trace also includes an extra convolution, bh, corresponding to the gradient of the loss with respect to the input, which is then discarded; this is not what my issue is about.)
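The 8192 channels in that reshape are consistent with the vmap batch dimension being folded into the convolution operands: 32 × 256 = 8192, and the batch_group_count grows from 2 in the first trace to 32 × 2 = 64 in the second. As a minimal sketch, assuming the definitions from the script above (conv2d, kernel, images, batch_size, channels, height, and width), the same folded layout should also show up when a per-example jax.vjp is vmapped directly, which appears to be effectively what jax.vmap does to the custom rule's backward pass:
# Sketch only: reuses conv2d, kernel, images, batch_size, channels, height and
# width from the script above. Vmapping a per-example VJP mirrors how the
# custom rule's bwd is batched; the printed jaxpr should contain a similarly
# folded grouped convolution rather than one over a 32x256x56x56 operand.
def per_example_vjp(x, g):
    _, vjp_fn = jax.vjp(conv2d, kernel, x)
    return vjp_fn(g)

cotangents = jnp.ones((batch_size, channels, height, width))
print(jax.make_jaxpr(jax.vmap(per_example_vjp))(images, cotangents))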
I have observed this behavior in situations where the backward function of the custom VJP rule does not call jax.vjp(); for example, with the following (possibly incorrect) re-implementation of the rule:
@jax.custom_vjp
def handwritten_vjp_conv2d(kernel, x):
    return conv2d(kernel, x)


def handwritten_vjp_conv2d_fwd(kernel, x):
    y = conv2d(kernel, x)
    return y, (kernel, x)


def handwritten_vjp_conv2d_bwd(residuals, g):
    kernel, x = residuals
    c, p, k, _ = kernel.shape
    groups = c // p
    # Expand dims to match what conv2d expects
    x_expanded = jnp.expand_dims(x, axis=0)
    g_expanded = jnp.expand_dims(g, axis=0)
    # Gradient w.r.t. kernel
    # This is essentially correlating the input with the output gradient
    kernel_grad = jax.lax.conv_general_dilated(
        lhs=x_expanded,
        rhs=g_expanded,
        window_strides=(1, 1),
        padding=((1, 1), (1, 1)),
        dimension_numbers=jax.lax.ConvDimensionNumbers(
            lhs_spec=(1, 0, 2, 3), rhs_spec=(1, 0, 2, 3), out_spec=(1, 0, 2, 3)
        ),
        feature_group_count=1,
        batch_group_count=groups,
    )
    # Gradient w.r.t. input
    # Convolve gradient with flipped kernel
    kernel_flipped = jnp.flip(kernel, axis=(2, 3))
    x_grad = jax.lax.conv_general_dilated(
        lhs=g_expanded,
        rhs=kernel_flipped,
        window_strides=(1, 1),
        padding=((1, 1), (1, 1)),
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        feature_group_count=groups,
    )
    x_grad = jnp.squeeze(x_grad, axis=0)
    return (kernel_grad, x_grad)


handwritten_vjp_conv2d.defvjp(handwritten_vjp_conv2d_fwd, handwritten_vjp_conv2d_bwd)

System info (python version, jaxlib version, accelerator, etc.)
jax: 0.9.0
jaxlib: 0.9.0
numpy: 2.4.1
python: 3.13.5 (main, Jul 11 2025, 22:43:46) [Clang 20.1.4 ]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='mo', release='6.12.59', version='#1-NixOS SMP PREEMPT_DYNAMIC Mon Nov 24 09:36:08 UTC 2025', machine='x86_64')
I have also observed the issue with Python 3.12 and JAX 0.7.2.