
Commit 78a2f24

crutcher and laggui authored

Extract Linear.forward to nn::functional::linear (#3147)

* Extract Linear.forward to nn::functional::linear
* remove linear_pytorch per review
* compat docs
* Refactor test assertions to use TensorData directly. Replaced `Tensor::<TestBackend, _>::from_data` with `TensorData::from` for improved readability and consistency. Ensured data type conversion aligns with the input tensor's dtype. Added required imports to support changes.
* fmt
* move nn.functional to burn-tensor
* doc
* fix tests
* Lift linear to ModuleOps
* move to moduleops
* Refactor linear module ops with primitives
* Fix vector norm cuda tests / f16 tolerance
* Use alloc vec
* Update lock

Co-authored-by: Guillaume Lagrange <[email protected]>

1 parent 90acdbc commit 78a2f24
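The change is easiest to read from the call site: the tensor-level entry point extracted here can be called directly, without going through the `Linear` module. A minimal sketch, assuming a backend `B` and `input`, `weight`, and `bias` tensors of the shapes documented below (the variable names are illustrative, not from the diff):

use burn_tensor::module::linear;

// input: [..., d_input], weight: [d_input, d_output], bias: [d_output].
let y = linear(input, weight, Some(bias)); // y = input @ weight + bias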

File tree

8 files changed: +236 −68 lines changed


Cargo.lock

+16
Some generated files are not rendered by default.

crates/burn-core/src/nn/linear.rs

+16 −14

@@ -1,3 +1,5 @@
+use burn_tensor::module::linear;
+
 use crate as burn;
 
 use crate::config::Config;
@@ -24,7 +26,7 @@ pub struct LinearConfig {
     pub initializer: Initializer,
 }
 
-/// Applies a linear transformation to the input tensor:
+/// Applies a linear transformation to the input tensor.
 ///
 /// Should be created with [LinearConfig]
 ///
@@ -65,24 +67,24 @@ impl LinearConfig {
 impl<B: Backend> Linear<B> {
     /// Applies the forward pass on the input tensor.
     ///
+    /// # Arguments
+    ///
+    /// - `input` - The input tensor of shape `[..., d_input]`.
+    ///
     /// # Shapes
     ///
     /// - input: `[..., d_input]`
     /// - output: `[..., d_output]`
+    ///
+    /// # Returns
+    ///
+    /// The transformed tensor of shape `[..., d_output]`.
     pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
-        if D == 1 {
-            // Insert and remove an extra batch dimension for the batch matmul to work.
-            return Self::forward::<2>(self, input.unsqueeze()).flatten(0, 1);
-        }
-
-        let weight = self.weight.val().unsqueeze();
-        let bias = self.bias.as_ref().map(|b| b.val().unsqueeze());
-        let output = input.matmul(weight);
-
-        match bias {
-            Some(bias) => output + bias,
-            None => output,
-        }
+        linear(
+            input,
+            self.weight.val(),
+            self.bias.as_ref().map(|b| b.val()),
+        )
     }
 }
 
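After this change, `Linear::forward` is a thin wrapper and the old `D == 1` special case moves down into the `ModuleOps::linear` default implementation. A minimal sketch of the equivalence, assuming a concrete backend `B` and a `device` (shapes and values are illustrative):

let layer = LinearConfig::new(4, 8).init::<B>(&device);
let x = Tensor::<B, 2>::ones([2, 4], &device);

// Both paths now run through burn_tensor::module::linear.
let y_module = layer.forward(x.clone());
let y_functional = linear(x, layer.weight.val(), layer.bias.as_ref().map(|b| b.val()));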

crates/burn-tensor/src/tensor/module.rs

+37

@@ -357,3 +357,40 @@ where
         options,
     )))
 }
+
+/// Applies a [linear transformation](crate::ops::ModuleOps::linear) to the input tensor using the given weight and bias.
+///
+/// ```math
+/// y = x @ weight + [bias]
+/// ```
+///
+/// # Arguments:
+///
+/// - `input` is the input tensor, ``[..., d_input]``.
+/// - `weight` is the weight tensor, ``[d_input, d_output]``.
+/// - `bias` is the bias tensor (optional), ``[d_output]``.
+///
+/// # Returns:
+///
+/// The transformed tensor, ``[..., d_output]``.
+///
+/// # Compatibility
+///
+/// This function differs from PyTorch's ``torch.nn.functional.linear`` in that it does not
+/// transpose the weight matrix. In PyTorch, the weight matrix is transposed before
+/// multiplication:
+///
+/// ```math
+/// y = x @ weight^T + [bias]
+/// ```
+pub fn linear<B: Backend, const D: usize>(
+    input: Tensor<B, D>,
+    weight: Tensor<B, 2>,
+    bias: Option<Tensor<B, 1>>,
+) -> Tensor<B, D> {
+    Tensor::new(TensorPrimitive::Float(B::linear(
+        input.primitive.tensor(),
+        weight.primitive.tensor(),
+        bias.map(|b| b.primitive.tensor()),
+    )))
+}
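The compatibility note matters when porting weights: PyTorch stores `nn.Linear` weights as `[d_output, d_input]`, while this function expects `[d_input, d_output]`. A hedged sketch of such a port (`py_weight`, `x`, and `bias` are hypothetical names, not part of this commit):

// py_weight has the PyTorch layout [d_output, d_input]; swap the two dims first.
let weight = py_weight.transpose();
let y = linear(x, weight, bias); // y = x @ weight + [bias]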

crates/burn-tensor/src/tensor/ops/modules/base.rs

+61
@@ -1,3 +1,4 @@
+use alloc::vec;
 use core::num::NonZeroUsize;
 
 use super::{conv, pool, unfold::unfold4d_using_conv2d};
@@ -764,6 +765,66 @@ pub trait ModuleOps<B: Backend> {
         output_size: [usize; 2],
         options: InterpolateOptions,
     ) -> FloatTensor<B>;
+
+    /// Applies a linear transformation to the input tensor using the given weight and bias.
+    ///
+    /// ```math
+    /// y = x @ weight + [bias]
+    /// ```
+    ///
+    /// # Arguments:
+    ///
+    /// - `input` is the input tensor, ``[..., d_input]``.
+    /// - `weight` is the weight tensor, ``[d_input, d_output]``.
+    /// - `bias` is the bias tensor (optional), ``[d_output]``.
+    ///
+    /// # Returns:
+    ///
+    /// The transformed tensor, ``[..., d_output]``.
+    ///
+    /// # Compatibility
+    ///
+    /// This function differs from PyTorch's ``torch.nn.functional.linear`` in that it does not
+    /// transpose the weight matrix. In PyTorch, the weight matrix is transposed before
+    /// multiplication:
+    ///
+    /// ```math
+    /// y = x @ weight^T + [bias]
+    /// ```
+    fn linear(
+        input: FloatTensor<B>,
+        weight: FloatTensor<B>,
+        bias: Option<FloatTensor<B>>,
+    ) -> FloatTensor<B> {
+        let ndims_in = input.shape().num_dims();
+        let [d_input, d_output] = weight.shape().dims();
+        if ndims_in == 1 {
+            // Insert and remove an extra batch dimension for the batch matmul to work.
+            let input = B::float_reshape(input, Shape::from([1, d_input]));
+            let output = Self::linear(input, weight, bias);
+            return B::float_reshape(output, Shape::from([d_output]));
+        }
+
+        let weight = unsqueeze::<B>(weight, ndims_in);
+        let output = B::float_matmul(input, weight);
+        match bias {
+            Some(bias) => B::float_add(output, unsqueeze::<B>(bias, ndims_in)),
+            None => output,
+        }
+    }
+}
+
+// Unsqueeze op on primitive.
+// TODO: would be nice to have this on primitives too for convenience.
+fn unsqueeze<B: Backend>(tensor: FloatTensor<B>, ndims_out: usize) -> FloatTensor<B> {
+    let shape = tensor.shape();
+    let ndims_in = shape.num_dims();
+
+    let mut dims = vec![1; ndims_out];
+    let num_ones = ndims_out - ndims_in;
+    dims[num_ones..(ndims_in + num_ones)].copy_from_slice(&shape.dims[..ndims_in]);
+
+    B::float_reshape(tensor, Shape::from(dims))
 }
 
 #[cfg(test)]