Extract Linear.forward to nn::functional::linear #3147

Merged: 15 commits merged into tracel-ai:main on May 8, 2025

Conversation

@crutcher (Contributor) commented May 4, 2025

Pull Request Template

Checklist

  • Confirmed that cargo run-checks command has been executed.
  • [n/a] Made sure the book is up to date with changes in this PR.

Changes

This begins extracting nn.<Module>.forward code into common utility libraries under nn::functional, starting by moving the Linear impl to nn::functional::linear.

Also provides linear_pytorch (name?) so there's a direct impl of the transposed-weight semantics used by PyTorch.
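
As a rough illustration of the intended relationship (a hedged sketch, not the code in this PR; it assumes Burn's [d_input, d_output] weight layout and inlines the math rather than calling the new functional linear):

    use burn::tensor::{backend::Backend, Tensor};

    // Sketch: PyTorch's linear stores the weight as [d_output, d_input] and computes
    // x @ W^T + b, while Burn's linear weight is [d_input, d_output] and computes
    // x @ W + b. A PyTorch-compat wrapper is therefore just a transpose up front.
    pub fn linear_pytorch<B: Backend>(
        input: Tensor<B, 2>,        // [batch, d_input]
        weight: Tensor<B, 2>,       // PyTorch layout: [d_output, d_input]
        bias: Option<Tensor<B, 1>>, // [d_output]
    ) -> Tensor<B, 2> {
        let output = input.matmul(weight.transpose()); // x @ W^T
        match bias {
            Some(bias) => output + bias.unsqueeze::<2>(),
            None => output,
        }
    }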

Testing

Tested the 1D case, and forward both with and without bias.

codecov bot commented May 4, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.58%. Comparing base (50d5a20) to head (7978c50).
Report is 3 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3147      +/-   ##
==========================================
- Coverage   81.58%   81.58%   -0.01%     
==========================================
  Files         822      823       +1     
  Lines      118016   118085      +69     
==========================================
+ Hits        96284    96335      +51     
- Misses      21732    21750      +18     

@antimora (Collaborator) commented May 4, 2025

I am trying to learn the context.

Is there a particular reason for adding linear_pytorch? Wouldn't it be easier to transpose the weight when importing, like we do when importing a pt file here?

Also, there was a recent conversation about possibly storing different layouts: https://discord.com/channels/1038839012602941528/1091796857996451942/1365333083716325437

@crutcher (Contributor, Author) commented May 4, 2025

@antimora I'll do my best to unpack this without writing a novel, and I can hop on Discord to talk about it this week as well.

I've been working on translating some complex vision models (such as SWIN) to Burn; they have complex kernels that frequently rely on more primitive operations. That's my immediate motivation.

In doing so, I've been bumping up against the awkwardness of not having access to the equivalent of the torch.nn.functional tree of non-module functions. Sure, many of these components could be forced to be sub-modules, but it gets tedious. It's also easier to test pure functions in isolation: they don't have configs, they don't have Params, etc. If we also consider all the non-AI applications for tensors, having a robust library of common functions that aren't tied to a training/weight-management scheme seems attractive.

In some places, it's clear that Burn made different choices than PyTorch about how some of these functions work. Now, in this specific case, I've always hated the x A^T + b approach (which mainly seems to exist because someone at some point felt weird about Ax + b, the standard linear algebra form).

But!

Compat functions and option-prefill functions are essentially free.

  • they are generally stripped from programs which don't use them.
  • they are generally optimized out of programs which do use them.
  • they memorialize a contract (and the tests for that contract) in a way which is more durable than a side-note in a document.

If you want linear_pytorch removed (or renamed?), I'm OK with that; but I do generally think Burn would benefit from more "this is probably what you want" wrapper/alias functions for implementors.

The question of whether loading or processing of the Linear layer should be aligned with the weight structure of the other nn layers is more of a strategy question that depends on how Burn handles all the common layers, and I don't have a dog in that fight at all.

But I do feel strongly about the utility of the nn::functional tree; and I'm willing to do a lot of work to flesh it out.
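
To make the contrast concrete, a minimal sketch of the two call styles (illustrative only; the functional body below is a stand-in for the free function this PR adds, not its actual code):

    use burn::nn::{Linear, LinearConfig};
    use burn::tensor::{backend::Backend, Tensor};

    // Module route: the weight lives in a managed Param, created from a config on a device.
    fn with_module<B: Backend>(input: Tensor<B, 2>, device: &B::Device) -> Tensor<B, 2> {
        let layer: Linear<B> = LinearConfig::new(64, 32).init(device);
        layer.forward(input)
    }

    // Functional route: the same math as a pure function over plain tensors, with no
    // config or Param bookkeeping, which is what makes it easy to test in isolation.
    fn with_function<B: Backend>(
        input: Tensor<B, 2>,
        weight: Tensor<B, 2>,
        bias: Option<Tensor<B, 1>>,
    ) -> Tensor<B, 2> {
        let output = input.matmul(weight);
        match bias {
            Some(bias) => output + bias.unsqueeze::<2>(),
            None => output,
        }
    }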

@antimora (Collaborator) commented May 5, 2025

@crutcher Thank you for explaining your motivation.

I am +1 on an nn::functional module, but I'm not sure there is value in linear_pytorch, even with a different name. It's very simple functionality, yet specific to PyTorch's way of doing things.

BTW, there is a performance hit when transposing the weight matrix (even when only changing strides). @louisfd did a benchmark to see if there was a difference when I suggested a special flag for linear: #1084. This was done before CubeCL, so it could be different now.

@crutcher (Contributor, Author) commented May 5, 2025

I've removed the linear_pytorch function.

@nathanielsimard (Member) left a comment

I don't agree with putting the functional modules in burn::nn. We already have some functional APIs, but they are in burn-tensor. So I would put that file in burn_tensor/src/api/module/linear.rs. Then burn::nn::Linear can use that function, similar to how you did it here.

@crutcher (Contributor, Author) commented May 5, 2025

@antimora I've addressed your comments.

@nathanielsimard are you sure about the move?

My thoughts on this: fleshing out a full functional API will end up refactoring many components in burn-core, and they'll be in a split-logic situation. I'm also concerned that, as we see in most standard libraries, many of the algorithms are best expressed in terms of other algorithms, and sequencing becomes messy.

Is there any guide to how the layer cake should be stacked? Should there maybe be an intermediate layer?

Maybe:

  • burn-tensor
  • burn-fn-core
  • burn-core (or burn-ai-core)?

I'll do what you want with this PR; but I'm concerned that we're walking into a tangle if I continue building this kind of PR.

@crutcher (Contributor, Author) commented May 5, 2025

@nathanielsimard I've moved the tree.

@crutcher force-pushed the crutcher/nn_func_linear branch from 7674050 to 9824b97 (May 5, 2025 19:20)
@crutcher requested a review from nathanielsimard (May 5, 2025 19:45)
@crutcher (Contributor, Author) commented May 6, 2025

Lifted to ModuleOps

@crutcher requested a review from laggui (May 6, 2025 19:45)
@laggui (Member) left a comment

The implementation is almost there!

See my comments below. But for small stuff like this, I don't mind moving things around myself to complete your work, especially since unsqueeze and flatten are only available via the public tensor API, not as backend ops. So we just need to use B::float_reshape instead.

Will do this tomorrow to complete it if you'd like 🙂 otherwise feel free to complete it yourself!
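
For context, a minimal sketch of what a reshape-based unsqueeze over backend primitives might look like (the helper name, the TensorMetadata import, and Shape exposing its dims as a Vec are assumptions; the actual helper may differ):

    use burn::tensor::{backend::Backend, ops::FloatTensor, Shape, TensorMetadata};

    // Left-pad the primitive's shape with leading 1s so a [d_input, d_output] weight
    // or a [d_output] bias can broadcast against an N-dimensional input in the batched
    // float_matmul / float_add calls.
    fn unsqueeze<B: Backend>(tensor: FloatTensor<B>, ndims_out: usize) -> FloatTensor<B> {
        let shape = tensor.shape();
        let mut dims = vec![1; ndims_out];
        let offset = ndims_out - shape.num_dims();
        for (i, d) in shape.dims.iter().enumerate() {
            dims[offset + i] = *d;
        }
        B::float_reshape(tensor, Shape::from(dims))
    }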

@crutcher force-pushed the crutcher/nn_func_linear branch 2 times, most recently from e65982f to 9f2d2fa (May 7, 2025 16:29)
crutcher added 10 commits May 7, 2025 11:56
Replaced `Tensor::<TestBackend, _>::from_data` with `TensorData::from` for improved readability and consistency. Ensured data type conversion aligns with the input tensor's dtype. Added required imports to support changes.
@crutcher force-pushed the crutcher/nn_func_linear branch from 9f2d2fa to 03c2131 (May 7, 2025 18:56)
@crutcher requested a review from laggui (May 7, 2025 19:30)
@crutcher marked this pull request as draft (May 7, 2025 19:36)
@laggui (Member) commented May 7, 2025

I can take care of the last-mile delivery tomorrow!

Had a couple of things to investigate today (notably that dreaded macOS CI issue), so time ran out before I could get to it 😅

@laggui (Member) left a comment

@crutcher I just completed the changes; see my comments below if you are curious.

I also cleaned up the vector norm tests a bit: when I ran the tests locally, the lp norm test failed with f16 (a tolerance difference), so I went over them and set the half-precision tolerance.

Comment on lines +793 to +814
    fn linear(
        input: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
    ) -> FloatTensor<B> {
        let ndims_in = input.shape().num_dims();
        let [d_input, d_output] = weight.shape().dims();
        if ndims_in == 1 {
            // Insert and remove an extra batch dimension for the batch matmul to work.
            let input = B::float_reshape(input, Shape::from([1, d_input]));
            let output = Self::linear(input, weight, bias);
            return B::float_reshape(output, Shape::from([d_output]));
        }

        let weight = unsqueeze::<B>(weight, ndims_in);
        let output = B::float_matmul(input, weight);
        match bias {
            Some(bias) => B::float_add(output, unsqueeze::<B>(bias, ndims_in)),
            None => output,
        }
    }
}
@laggui (Member):

This is the equivalent implementation using primitives. Primitives use backend tensor ops. Since unsqueeze and flatten are user-level APIs (defined on the Tensor struct), I couldn't use them here. But we should improve this for convenience.

Note: FloatTensor<B> is just a type alias for B::FloatTensorPrimitive.

@antimora (Collaborator):

Let's say we have a GEMM kernel; would this part be optimized using it?

@nathanielsimard @wingertge Any plans for GEMM op?

Comment on lines +83 to +87
linear(
    input,
    self.weight.val(),
    self.bias.as_ref().map(|b| b.val()),
)
@laggui (Member):

The forward now simply calls the module::linear method.

Comment on lines +386 to +396
pub fn linear<B: Backend, const D: usize>(
    input: Tensor<B, D>,
    weight: Tensor<B, 2>,
    bias: Option<Tensor<B, 1>>,
) -> Tensor<B, D> {
    Tensor::new(TensorPrimitive::Float(B::linear(
        input.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
    )))
}
@laggui (Member), May 8, 2025:

The module function, which operates on Tensor (not backend primitives), is the one that calls the module ops trait implementation. .tensor() simply returns the float tensor primitive (floats can also be represented via quantization (WIP), so they're wrapped in an enum).

The resulting primitive is simply encapsulated in the Tensor struct.
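
For reference, a hedged usage sketch of the user-facing function above (the burn::tensor::module::linear import path and the concrete shapes are assumptions, not taken from this PR):

    use burn::tensor::{backend::Backend, module::linear, Distribution, Tensor};

    fn example<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
        // [batch, seq, d_input] input, [d_input, d_output] weight, [d_output] bias.
        let input = Tensor::<B, 3>::random([8, 16, 64], Distribution::Default, device);
        let weight = Tensor::<B, 2>::random([64, 32], Distribution::Default, device);
        let bias = Tensor::<B, 1>::zeros([32], device);
        linear(input, weight, Some(bias)) // -> [8, 16, 32]
    }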

@laggui marked this pull request as ready for review (May 8, 2025 14:46)
@antimora (Collaborator) left a comment

LGTM

@laggui (Member) left a comment

Approving my partial contribution 😆

@crutcher (Contributor, Author) commented May 8, 2025

@laggui thanks for fixing this!

@laggui merged commit 78a2f24 into tracel-ai:main (May 8, 2025)
11 checks passed