More aggressive conv shape functions, docs, `maybe_`/`expect_` call patterns. #3709
Conversation
Codecov Report

❌ Your patch check has failed because the patch coverage (53.64%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

```
@@            Coverage Diff             @@
##             main    #3709      +/-   ##
==========================================
+ Coverage   64.33%   64.35%   +0.01%
==========================================
  Files        1156     1156
  Lines      134463   134577     +114
==========================================
+ Hits        86508    86608     +100
- Misses      47955    47969      +14
```
Checking for invalid configurations is a good idea.
Regarding the API changes, I think a better approach would be to implement `output_shape(input_shape, weight_shape, options)` to get the full shape, e.g.:
```rust
/// Calculate the expected output shape `[batch_size, channels_out, spatial_dims, ..]` for a convolution.
pub fn calculate_conv_output_shape<const N: usize>(
    in_shape: &Shape,
    weight_shape: &Shape,
    options: &ConvOptions<N>,
) -> Shape {
    assert_eq!(weight_shape.num_dims(), N + 2);
    assert_eq!(in_shape.num_dims(), N + 2);
    let kernel_size = &weight_shape.dims[2..];
    let mut out_shape = in_shape.clone();
    // Spatial dims
    for (i, size_i) in out_shape.dims[2..].iter_mut().enumerate() {
        *size_i = calculate_conv_output_size(
            kernel_size[i],
            options.stride[i],
            options.padding[i],
            options.dilation[i],
            *size_i,
        );
    }
    // Output channels
    out_shape.dims[1] = weight_shape.dims[0];
    out_shape
}
```

Eventually, we could make the output shape explicit to the backend impl, e.g.:
```rust
pub trait ModuleOps<B: Backend> {
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
        output_shape: Shape,
    ) -> FloatTensor<B>;
}
```

That way, there is a single source of truth for the output shape. We can easily compute the complete output shape before calling the backend op. But that would impact some other APIs at the moment.
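For reference, the per-dimension helper used above computes the standard dilated-convolution output size (the same formula documented for PyTorch's `Conv2d`). A minimal sketch, assuming this is what `calculate_conv_output_size` implements (the name and signature here are taken from the snippet above, not verified against the crate):

```rust
/// Standard output size for one spatial dimension of a dilated convolution.
/// Assumed equivalent to Burn's `calculate_conv_output_size` helper.
pub fn conv_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
) -> usize {
    // floor((size_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1
    (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
}

fn main() {
    // 3x3 kernel, stride 1, padding 1: "same" spatial size.
    assert_eq!(conv_output_size(3, 1, 1, 1, 32), 32);
    // Stride 2 roughly halves the spatial size.
    assert_eq!(conv_output_size(3, 2, 1, 1, 32), 16);
}
```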
@laggui I'm down to provide an API for 2D + conv, but I don't want to lose the raw computation.
Force-pushed 75c32ba to a01c97f
Not sure I follow, what do you mean by raw computation?
@crutcher, might have missed a notification. Still valid?
@laggui, @antimora I did miss the update.
Force-pushed a01c97f to 63a8968

Force-pushed 63a8968 to 4984a39
That makes sense! We'd keep the

@crutcher see also my PR that introduces more
See #3705
Changes
Flush out conv shape calculations, with `maybe_` and `expect_` variants in 1 and N-D; include both `[usize; D]` and `&[usize]` variants.
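As a sketch of the `maybe_`/`expect_` call-pattern pair described above (the function names, signatures, and validity checks here are illustrative assumptions, not the PR's actual API): the `maybe_` variant returns `None` for an invalid configuration, and the `expect_` variant panics with a descriptive message.

```rust
/// Hypothetical `maybe_` variant: returns `None` when the configuration
/// cannot produce an output (zero stride, zero kernel, or the dilated
/// kernel does not fit in the padded input).
pub fn maybe_conv_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
) -> Option<usize> {
    let effective_kernel = dilation.checked_mul(kernel_size.checked_sub(1)?)? + 1;
    let padded = size_in.checked_add(2 * padding)?;
    if stride == 0 || padded < effective_kernel {
        return None;
    }
    Some((padded - effective_kernel) / stride + 1)
}

/// Hypothetical `expect_` variant: same computation, panics on invalid input.
pub fn expect_conv_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
) -> usize {
    maybe_conv_output_size(kernel_size, stride, padding, dilation, size_in)
        .expect("invalid convolution configuration")
}

fn main() {
    assert_eq!(expect_conv_output_size(3, 1, 1, 1, 32), 32);
    // 7-wide kernel on an unpadded 5-element input: no valid output.
    assert_eq!(maybe_conv_output_size(7, 1, 0, 1, 5), None);
}
```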