Extract Linear.forward to nn::functional::linear #3147
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

```
@@ Coverage Diff @@
##              main    #3147      +/-   ##
==========================================
- Coverage    81.58%   81.58%   -0.01%
==========================================
  Files          822      823       +1
  Lines       118016   118085      +69
==========================================
+ Hits         96284    96335      +51
- Misses       21732    21750      +18
==========================================
```

View full report in Codecov by Sentry.
I am trying to learn the context. Is there a particular reason for adding `linear_pytorch`? Wouldn't it be easier to transpose the weight when importing, like we do when importing a pt file here? Also, there was a recent conversation about possibly storing different layouts: https://discord.com/channels/1038839012602941528/1091796857996451942/1365333083716325437
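(For context, a shape-level sketch of the two layout conventions under discussion. This is illustrative only: `linear_pytorch` below is a hypothetical wrapper, and it assumes the `linear` function added in this PR is in scope.)

```rust
use burn::tensor::{backend::Backend, Tensor};

// PyTorch's nn.Linear stores `weight` as [d_output, d_input] and computes
//   y = x.matmul(weight.t()) + bias
// while Burn's nn::Linear stores `weight` as [d_input, d_output] and computes
//   y = x.matmul(weight) + bias
// So a PyTorch checkpoint can either be transposed once at import time, or fed
// through a thin compat wrapper such as this hypothetical one:
fn linear_pytorch<B: Backend, const D: usize>(
    input: Tensor<B, D>,        // [..., d_input]
    weight: Tensor<B, 2>,       // [d_output, d_input] -- PyTorch layout
    bias: Option<Tensor<B, 1>>, // [d_output]
) -> Tensor<B, D> {
    // A single transpose restores Burn's [d_input, d_output] layout.
    linear(input, weight.transpose(), bias)
}
```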
@antimora I'll do my best to unpack this without writing a novel; I can also hop on Discord and talk about it this week. I've been working on translating some complex vision models (such as SWIN) to Burn, and they have complex kernels that frequently rely on more primitive ideas. That's my immediate driver. In doing so, I've been bumping up against the awkwardness of not having access to the equivalent of `torch.nn.functional`. In some places, it's clear that Burn made different choices than pytorch about how some of these functions work. Now, in this specific case, I've always hated pytorch's transposed-weight convention. But! Compat functions and option-prefill functions are essentially free.
If you want the `linear_pytorch` variant dropped, that's fine. The question of whether loading or processing of the Linear layer should be aligned, or modally aligned, with the weight structure from other nn layers is more a strategy question that should depend upon how Burn does all the common layers, and I don't have a dog in that fight at all. But I do feel strongly about the utility of the `nn::functional` approach.
@crutcher Thank you for explaining your motivation. I am +1 on an `nn::functional` module, but I am not sure there is value in `linear_pytorch`. BTW, there is a performance hit when transposing the weight matrix (even when only changing strides). @louisfd did a benchmark to see if there was a difference when I suggested a special flag for linear: #1084. This was done before CubeCL, so it could be different now.
I've removed the `linear_pytorch` function.
I don't agree with the `functional` modules in burn nn. We already have some functional APIs, but they are in `burn-tensor`. So I would put that file in `burn_tensor/src/api/module/linear.rs`. Then `burn::nn::Linear` can use that function, similar to how you did this here.
@antimora I've addressed your comments.

@nathanielsimard Are you sure about the move? My thinking is that fleshing out a full functional API will end up refactoring many components in burn-core, and they'll be in a split-logic situation. I'm also concerned that, as we see in most std libs, many of the algorithms will be best expressed in terms of other algorithms, and sequencing becomes messy. Is there any guide to how the layer cake should be stacked? Should there maybe be an intermediate layer?

I'll do what you want with this PR, but I'm concerned that we're walking into a tangle if I continue building this kind of PR.
@nathanielsimard I've moved the tree.
(Force-pushed from 7674050 to 9824b97.)
Lifted to ModuleOps.
The implementation is almost there! See my comments below. But for this small stuff, I don't mind moving things around to complete your work myself. Especially since `unsqueeze` and `flatten` are only available via the public tensor API, not backend ops, so we just need to do `B::float_reshape` instead.

Will do this tomorrow to complete it if you'd like 🙂 otherwise feel free to complete it yourself!
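(A minimal sketch of the kind of helper this implies, assuming only `B::float_reshape` is available at the backend level. The `unsqueeze` name and the right-aligned broadcasting behaviour are my reading of the snippet further below, not a quote of the actual helper, and exact import paths may differ.)

```rust
use burn_tensor::{backend::Backend, ops::FloatTensor, Shape, TensorMetadata};

// Prepend size-1 axes to a float primitive until it has `ndims` dimensions,
// e.g. a [d_in, d_out] weight becomes [1, ..., 1, d_in, d_out] so the batched
// matmul broadcasts over the leading input dimensions.
fn unsqueeze<B: Backend>(tensor: FloatTensor<B>, ndims: usize) -> FloatTensor<B> {
    let shape = tensor.shape();
    let mut dims = vec![1; ndims];
    // Keep the original dims right-aligned.
    let offset = ndims - shape.num_dims();
    for (i, d) in shape.dims.iter().enumerate() {
        dims[offset + i] = *d;
    }
    B::float_reshape(tensor, Shape::from(dims))
}
```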
(Force-pushed from e65982f to 9f2d2fa.)
Replaced `Tensor::<TestBackend, _>::from_data` with `TensorData::from` for improved readability and consistency. Ensured data type conversion aligns with the input tensor's dtype. Added required imports to support changes.
(Force-pushed from 9f2d2fa to 03c2131.)
Can take care of the last-mile delivery tomorrow! Had a couple of things to investigate today (notably that dreaded macOS CI issue)... so time ran out before I could get to this 😅
@crutcher just completed the changes, see my comments below if you are curious.

Also cleaned up the vector norm tests a bit. When I ran the tests locally, the lp norm test failed with f16 (tolerance difference), so I went over them after setting the half-precision tolerance.
```rust
fn linear(
    input: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: Option<FloatTensor<B>>,
) -> FloatTensor<B> {
    let ndims_in = input.shape().num_dims();
    let [d_input, d_output] = weight.shape().dims();
    if ndims_in == 1 {
        // Insert and remove an extra batch dimension for the batch matmul to work.
        let input = B::float_reshape(input, Shape::from([1, d_input]));
        let output = Self::linear(input, weight, bias);
        return B::float_reshape(output, Shape::from([d_output]));
    }

    let weight = unsqueeze::<B>(weight, ndims_in);
    let output = B::float_matmul(input, weight);
    match bias {
        Some(bias) => B::float_add(output, unsqueeze::<B>(bias, ndims_in)),
        None => output,
    }
}
```
This is the equivalent implementation using primitives. Primitives use backend tensor ops. Since `unsqueeze` and `flatten` are user-level APIs (defined on the `Tensor` struct), I couldn't use them here. But we should improve this for convenience.

Note: `FloatTensor<B>` is just a type alias for `B::FloatTensorPrimitive`.
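(For reference, the alias reads roughly like this in burn-tensor; paraphrased from memory rather than copied, so check the source for the exact definition.)

```rust
use burn_tensor::backend::Backend;

// Float tensors at the backend level are just the backend's own primitive type.
pub type FloatTensor<B> = <B as Backend>::FloatTensorPrimitive;
```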
Let's say we have a GEMM kernel, would this part be optimized using it?
@nathanielsimard @wingertge Any plans for a GEMM op?
```rust
linear(
    input,
    self.weight.val(),
    self.bias.as_ref().map(|b| b.val()),
)
```
The forward now simply calls the `module::linear` method.
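(For context, the surrounding method now looks roughly like this. This is a sketch rather than the exact diff, and it assumes `linear` refers to the module function added in this PR and that the usual burn imports are in scope.)

```rust
impl<B: Backend> Linear<B> {
    /// Applies the linear transformation: `output = input.matmul(weight) + bias`.
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        // Delegate to the shared functional implementation.
        linear(
            input,
            self.weight.val(),
            self.bias.as_ref().map(|b| b.val()),
        )
    }
}
```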
```rust
pub fn linear<B: Backend, const D: usize>(
    input: Tensor<B, D>,
    weight: Tensor<B, 2>,
    bias: Option<Tensor<B, 1>>,
) -> Tensor<B, D> {
    Tensor::new(TensorPrimitive::Float(B::linear(
        input.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
    )))
}
```
The module function, which operates on `Tensor` (not backend primitives), is the one that calls the module ops trait implementation. `.tensor()` will simply return the float tensor primitive (floats can also be represented via quantization (WIP), so they're wrapped in an enum). The resulting primitive is simply encapsulated in the `Tensor` struct.
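(The enum mentioned here is, as far as I can tell, shaped roughly like this; paraphrased rather than copied from burn-tensor.)

```rust
use burn_tensor::backend::Backend;

// A float tensor is either a plain float primitive or a quantized one;
// `.tensor()` hands back the float primitive either way.
pub enum TensorPrimitive<B: Backend> {
    Float(B::FloatTensorPrimitive),
    QFloat(B::QuantizedTensorPrimitive),
}
```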
LGTM
Approving my partial contribution 😆
@laggui thanks for fixing this!
Pull Request Template

Checklist
- The `cargo run-checks` command has been executed.

Changes
This begins extracting `nn.<Module>.forward` code into common utility libs under `nn::functional`, by moving the `Linear` impl to `nn::functional::linear`.

Also provides `linear_pytorch` (name?) so there's a direct impl of the transpose semantics under pytorch.

Testing
Tested 1d, and forward (with, without) bias.
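(A hedged sketch of the kind of check described above: 1d and 2d input, with and without bias. It assumes the free `linear` function from this PR is in scope and that a backend device is available; exact paths and names may differ.)

```rust
use burn_tensor::{backend::Backend, Tensor, TensorData};

fn check_linear<B: Backend>(device: &B::Device) {
    // 2 x 2 input against a [d_input = 2, d_output = 3] weight (Burn's layout).
    let input = Tensor::<B, 2>::from_data(TensorData::from([[1.0, 2.0], [3.0, 4.0]]), device);
    let weight = Tensor::<B, 2>::from_data(
        TensorData::from([[0.5, 1.0, -1.0], [2.0, 0.0, 1.0]]),
        device,
    );
    let bias = Tensor::<B, 1>::from_data(TensorData::from([0.1, 0.2, 0.3]), device);

    // Without bias: plain batched matmul.
    let out = linear(input.clone(), weight.clone(), None);
    assert_eq!(out.dims(), [2, 3]);

    // With bias: same output shape, bias broadcast over the batch dimension.
    let out = linear(input.clone(), weight.clone(), Some(bias.clone()));
    assert_eq!(out.dims(), [2, 3]);

    // 1d input: a single [d_input] vector should produce a [d_output] vector.
    let input_1d = Tensor::<B, 1>::from_data(TensorData::from([1.0, 2.0]), device);
    let out = linear(input_1d, weight, Some(bias));
    assert_eq!(out.dims(), [3]);
}
```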