While building a Rust binding over mlx-c, a static review of the convolution path surfaced several int (int32) quantities that overflow for large or extreme-parameter convolutions. Signed-integer overflow is UB in C++, and in practice yields wrong shapes/sizes. Reporting in case these are worth widening to int64_t or guarding with a raised error.
Quantities computed in int:
- Forward output shape (conv_out_shape): per axis, dilation * (kernel - 1), the effective input size input_dilation * (in - 1) + 1, and (effective_in + pad_lo + pad_hi - dilated_kernel) / stride + 1. A dilation/padding/stride near INT_MAX overflows before validation.
- Transposed prelude (conv_transpose_general): computes 1 + dilation * (weight_dim - 1), 2 * padding, and (in - 1) * stride + ... over the weight and parameters before the nested conv_general validates the rank, so an output_padding/dilation near INT_MAX overflows even for a tiny input.
- Negative-padding normalization: slicing every dimension forms dim + 1 and 0 - pad_lo in int32, overflowing for a dimension near INT_MAX.
- Metal implicit GEMM (implicit_gemm_conv_2D_gpu and the 3D path): implicit_M = N * oH * oW is an int32 product of the batch size and the output spatial dims. For a conv whose batch x output-spatial product exceeds INT_MAX this overflows; e.g. conv2d on input [1, 46341, 46341, 1] gives implicit_M = 46341 * 46341 = 2147488281 > 2147483647. The K dimension product(weight_spatial) * channels_per_group and the lcm(input_dilation, stride) jump-table sizes are similarly int32.
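
To make the implicit GEMM arithmetic concrete, here is a minimal sketch of a guarded product. The helpers `checked_mul` and `checked_implicit_m` are hypothetical names of mine, not MLX functions, and the sketch assumes a 32-bit `int` and C++17:

```cpp
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical helper (not part of MLX): multiplies two non-negative int
// values in 64 bits and returns std::nullopt if the result exceeds INT_MAX.
// Assumes int is 32 bits, so the 64-bit product itself cannot overflow.
inline std::optional<int> checked_mul(int a, int b) {
  if (a < 0 || b < 0) {
    return std::nullopt;
  }
  const std::int64_t wide = static_cast<std::int64_t>(a) * b;
  if (wide > std::numeric_limits<int>::max()) {
    return std::nullopt;
  }
  return static_cast<int>(wide);
}

// implicit_M = N * oH * oW, range-checked after each multiplication.
inline std::optional<int> checked_implicit_m(int n, int o_h, int o_w) {
  const std::optional<int> n_oh = checked_mul(n, o_h);
  if (!n_oh) {
    return std::nullopt;
  }
  return checked_mul(*n_oh, o_w);
}
```

With the shape from the example above, `checked_implicit_m(1, 46341, 46341)` returns `std::nullopt` instead of a wrapped value, while `checked_implicit_m(1, 46340, 46340)` still fits and returns 2147395600.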
Most need either multi-GB tensors or extreme parameters, so they're unlikely in normal use, but they are reachable from the safe API and are UB rather than a clean error. Would you accept widening these to int64_t (or adding overflow checks)? Happy to help pin down exact locations.
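
For the "widen or guard" option, this is roughly what I have in mind for a single output axis. It is only a sketch under stated assumptions: `checked_conv_out_dim` is a hypothetical helper rather than MLX's `conv_out_shape`, it assumes the dilated kernel extent is `dilation * (kernel - 1) + 1`, and the exception types are my choice, not necessarily what MLX would raise:

```cpp
#include <cstdint>
#include <limits>
#include <stdexcept>

// Hypothetical helper (not MLX's conv_out_shape): one output axis computed in
// int64_t, following the per-axis formula quoted in this report.
// Assumption: dilated kernel extent = dilation * (kernel - 1) + 1.
inline int checked_conv_out_dim(
    int in,
    int kernel,
    int stride,
    int pad_lo,
    int pad_hi,
    int dilation,
    int input_dilation) {
  if (in < 1 || kernel < 1 || stride < 1 || dilation < 1 ||
      input_dilation < 1) {
    throw std::invalid_argument(
        "conv: sizes, stride and dilations must be >= 1");
  }
  using i64 = std::int64_t;
  // Every factor is in [1, INT_MAX], so each product stays below 2^62 and
  // none of the int64_t operations below can overflow.
  const i64 dilated_kernel =
      static_cast<i64>(dilation) * (static_cast<i64>(kernel) - 1) + 1;
  const i64 effective_in =
      static_cast<i64>(input_dilation) * (static_cast<i64>(in) - 1) + 1;
  const i64 numerator = effective_in + pad_lo + pad_hi - dilated_kernel;
  if (numerator < 0) {
    throw std::invalid_argument(
        "conv: dilated kernel does not fit in the padded input");
  }
  const i64 out = numerator / stride + 1;
  if (out > std::numeric_limits<int>::max()) {
    throw std::overflow_error("conv: output dimension exceeds INT_MAX");
  }
  return static_cast<int>(out);
}
```

Ordinary shapes are unaffected (for example, in = 32, kernel = 3, stride = 2, padding 1/1 still gives 16), while an input_dilation of 2 on a dimension of INT_MAX raises instead of wrapping.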