Skip to content

Conversation

@crutcher
Copy link
Contributor

@crutcher crutcher commented Sep 12, 2025

See #3705

Changes

Flush out conv shape calculations; with maybe_ and expect_ variants in 1 and N-D; include both [usize; D] and &[usize] variants.

@crutcher
Copy link
Contributor Author

@laggui

@codecov
Copy link

codecov bot commented Sep 12, 2025

Codecov Report

❌ Patch coverage is 53.64964% with 127 lines in your changes missing coverage. Please review.
✅ Project coverage is 64.35%. Comparing base (f007a31) to head (4984a39).

Files with missing lines Patch % Lines
crates/burn-tensor/src/tensor/ops/modules/conv.rs 65.94% 63 Missing ⚠️
crates/burn-fusion/src/ops/module.rs 0.00% 20 Missing ⚠️
crates/burn-router/src/ops/op_module.rs 0.00% 18 Missing ⚠️
crates/burn-cubecl/src/kernel/conv/im2col.rs 0.00% 12 Missing ⚠️
...rates/burn-cubecl/src/kernel/conv/deform_conv2d.rs 0.00% 6 Missing ⚠️
crates/burn-cubecl/src/kernel/conv/direct.rs 0.00% 6 Missing ⚠️
...urn-cubecl/src/kernel/conv/implicit_gemm/launch.rs 0.00% 2 Missing ⚠️

❌ Your patch check has failed because the patch coverage (53.64%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (64.35%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3709      +/-   ##
==========================================
+ Coverage   64.33%   64.35%   +0.01%     
==========================================
  Files        1156     1156              
  Lines      134463   134577     +114     
==========================================
+ Hits        86508    86608     +100     
- Misses      47955    47969      +14     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@crutcher crutcher marked this pull request as ready for review September 12, 2025 20:59
@crutcher crutcher changed the title [WIP] Work on conv shape funcs Work on conv shape funcs Sep 12, 2025
@crutcher crutcher changed the title Work on conv shape funcs More aggressive conv shape functions, docs, maybe_/expect_ call patterns. Sep 12, 2025
Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking for invalid configurations is a good idea.

Regarding the API changes, I think a better approach would be to implement output_shape(input_shape, weight_shape, options) to get the full shape, e.g.

/// Calculate the expected output shape `[batch_size, channels_out, spatial_dims, ..]` for a convolution.
pub fn calculate_conv_output_shape<const N: usize>(
    in_shape: &Shape,
    weight_shape: &Shape,
    options: &ConvOptions<N>,
) -> Shape {
    assert_eq!(weight_shape.num_dims(), N + 2);
    assert_eq!(in_shape.num_dims(), N + 2);

    let kernel_size = &weight_shape.dims[2..];

    let mut out_shape = in_shape.clone();
    // Spatial dims
    for (i, size_i) in out_shape.dims[2..].iter_mut().enumerate() {
        *size_i = calculate_conv_output_size(
            kernel_size[i],
            options.stride[i],
            options.padding[i],
            options.dilation[i],
            *size_i,
        );
    }
    // Output channels
    out_shape.dims[1] = weight_shape.dims[0];

    out_shape
}

Eventually, we could make the output shape explicit to the backend impl, e.g:

pub trait ModuleOps<B: Backend> {
  fn conv2d(
      x: FloatTensor<B>,
      weight: FloatTensor<B>,
      bias: Option<FloatTensor<B>>,
      options: ConvOptions<2>,
      output_shape: Shape,
  ) -> FloatTensor<B>;
}

That way, there is a single source of truth for the output shape. We can easily compute the complete output shape before calling the backend op. But that will impact some other APIs atm.

@crutcher
Copy link
Contributor Author

@laggui I'm down to provide an api for D2 + conv; but I don't want to lose the raw computation.

@crutcher crutcher force-pushed the crutcher/conv_shapes branch from 75c32ba to a01c97f Compare September 15, 2025 21:02
@crutcher crutcher requested a review from laggui September 15, 2025 21:03
@laggui
Copy link
Member

laggui commented Sep 16, 2025

@laggui I'm down to provide an api for D2 + conv; but I don't want to lose the raw computation.

Not sure I follow, what do you mean by raw computation?

@antimora
Copy link
Collaborator

antimora commented Oct 1, 2025

@crutcher, might have missed a notification. Still valid?

@crutcher
Copy link
Contributor Author

crutcher commented Oct 8, 2025

@laggui, @antimora I did miss the update.

raw computation; I want to be able to explicitly re-use the exact same shape calculations in library code that the conv family uses; so that we can tie contracts which depend upon those shapes to conv without those libraries being forced to try and re-implement that logic.

@crutcher crutcher force-pushed the crutcher/conv_shapes branch from a01c97f to 63a8968 Compare October 8, 2025 22:29
@crutcher crutcher force-pushed the crutcher/conv_shapes branch from 63a8968 to 4984a39 Compare October 8, 2025 22:31
@laggui
Copy link
Member

laggui commented Oct 9, 2025

raw computation; I want to be able to explicitly re-use the exact same shape calculations in library code that the conv family uses; so that we can tie contracts which depend upon those shapes to conv without those libraries being forced to try and re-implement that logic.

That makes sense! We'd keep the calculate_conv_output_shape public anyway.

@crutcher see also my PR that introduces more Shape manipulations / calculations #3845. I will follow this up with some changes for the conv output shape similar to my previous comment (without the explicit shape as input).

@crutcher crutcher marked this pull request as draft October 15, 2025 17:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants