Skip to content

Commit d368b7e

Browse files
committed
Lift linear to ModuleOps
1 parent 452b834 commit d368b7e

File tree

13 files changed

+48
-27
lines changed

13 files changed

+48
-27
lines changed

crates/burn-core/src/nn/linear.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use crate as burn;
2+
use burn_tensor::ops::linear::linear;
23

34
use crate::config::Config;
45
use crate::module::Param;
56
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
6-
use crate::nn::functional::linear;
77
use crate::tensor::{Tensor, backend::Backend};
88

99
use super::Initializer;
@@ -30,8 +30,6 @@ pub struct LinearConfig {
3030
/// Should be created with [LinearConfig]
3131
///
3232
/// `O = IW + b`
33-
///
34-
/// See: [linear][nn::functional::linear]
3533
#[derive(Module, Debug)]
3634
#[module(custom_display)]
3735
pub struct Linear<B: Backend> {
@@ -83,7 +81,8 @@ impl<B: Backend> Linear<B> {
8381
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
8482
let weight = self.weight.val();
8583
let bias = self.bias.as_ref().map(|b| b.val());
86-
linear(input, weight, bias)
84+
85+
B::linear(input, weight, bias)
8786
}
8887
}
8988

crates/burn-core/src/nn/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
pub use burn_tensor::nn::*;
2-
31
/// Attention module
42
pub mod attention;
53

crates/burn-tensor/src/tensor/mod.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,6 @@ pub mod loss;
3232
/// The burn module.
3333
pub mod module;
3434

35-
/// The nn module.
36-
pub mod nn;
37-
3835
/// Operations on tensors module.
3936
pub mod ops;
4037

crates/burn-tensor/src/tensor/nn/functional/mod.rs

Lines changed: 0 additions & 2 deletions
This file was deleted.

crates/burn-tensor/src/tensor/nn/mod.rs

Lines changed: 0 additions & 4 deletions
This file was deleted.

crates/burn-tensor/src/tensor/ops/modules/base.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use core::num::NonZeroUsize;
22

33
use super::{conv, pool, unfold::unfold4d_using_conv2d};
4+
use crate::ops::linear::linear;
45
use crate::{
5-
Shape, TensorMetadata,
6+
Shape, Tensor, TensorMetadata,
67
backend::Backend,
78
ops::{FloatTensor, IntTensor},
89
};
@@ -764,6 +765,39 @@ pub trait ModuleOps<B: Backend> {
764765
output_size: [usize; 2],
765766
options: InterpolateOptions,
766767
) -> FloatTensor<B>;
768+
769+
/// Applies a linear transformation to the input tensor using the given weight and bias.
770+
///
771+
/// ```math
772+
/// y = x @ weight + [bias]
773+
/// ```
774+
///
775+
/// # Arguments
776+
///
777+
/// - `input` is the input tensor, ``[..., d_input]``.
778+
/// - `weight` is the weight tensor, ``[d_input, d_output]``.
779+
/// - `bias` is the bias tensor (optional), ``[d_output]``.
780+
///
781+
/// # Returns
782+
///
783+
/// The transformed tensor, ``[..., d_output]``.
784+
///
785+
/// # Compatibility
786+
///
787+
/// This function differs from PyTorch's ``torch.nn.functional.linear`` in that it does not
788+
/// transpose the weight matrix. In PyTorch, the weight matrix is transposed before
789+
/// multiplication:
790+
///
791+
/// ```math
792+
/// y = x @ weight^T + [bias]
793+
/// ```
794+
fn linear<const D: usize>(
795+
input: Tensor<B, D>,
796+
weight: Tensor<B, 2>,
797+
bias: Option<Tensor<B, 1>>,
798+
) -> Tensor<B, D> {
799+
linear(input, weight, bias)
800+
}
767801
}
768802

769803
#[cfg(test)]

crates/burn-tensor/src/tensor/ops/modules/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@ pub mod pool;
1313

1414
mod base;
1515

16+
/// Module with linear operations.
17+
pub mod linear;
18+
1619
pub use base::*;

crates/burn-tensor/src/tests/mod.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ mod activation;
22
mod clone_invariance;
33
mod grid;
44
mod module;
5-
mod nn;
65
mod ops;
76
mod primitive;
87
mod quantization;
@@ -156,9 +155,6 @@ macro_rules! testgen_with_float_param {
156155
burn_tensor::testgen_silu!();
157156
burn_tensor::testgen_tanh_activation!();
158157

159-
// test nn.functional
160-
burn_tensor::testgen_nn_fn_vector_norm!();
161-
162158
// test grid
163159
burn_tensor::testgen_meshgrid!();
164160

@@ -181,6 +177,7 @@ macro_rules! testgen_with_float_param {
181177
burn_tensor::testgen_module_nearest_interpolate!();
182178
burn_tensor::testgen_module_bilinear_interpolate!();
183179
burn_tensor::testgen_module_bicubic_interpolate!();
180+
burn_tensor::testgen_module_linear!();
184181

185182
// test ops
186183
burn_tensor::testgen_gather_scatter!();

crates/burn-tensor/src/tests/nn/functional/mod.rs

Lines changed: 0 additions & 1 deletion
This file was deleted.

crates/burn-tensor/src/tests/nn/mod.rs

Lines changed: 0 additions & 1 deletion
This file was deleted.

crates/burn-tensor/src/tests/nn/functional/vector_norm.rs renamed to crates/burn-tensor/src/tests/ops/linear.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
#[burn_tensor_testgen::testgen(nn_fn_vector_norm)]
1+
#[burn_tensor_testgen::testgen(module_linear)]
22
mod tests {
33
use super::*;
4-
use burn_tensor::nn::functional::linear;
4+
use burn_tensor::ops::ModuleOps;
55
use burn_tensor::{Tensor, TensorData};
66

77
#[test]
@@ -10,7 +10,7 @@ mod tests {
1010

1111
let x = TestTensor::<1>::from([1.0, 2.0]);
1212

13-
linear(x.clone(), weight.clone(), None)
13+
TestBackend::linear(x.clone(), weight.clone(), None)
1414
.into_data()
1515
.assert_eq(
1616
&TensorData::from([7.0, 10.0]).convert_dtype(x.dtype()),
@@ -24,7 +24,7 @@ mod tests {
2424

2525
let x = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]]);
2626

27-
linear(x.clone(), weight.clone(), None)
27+
TestBackend::linear(x.clone(), weight.clone(), None)
2828
.into_data()
2929
.assert_eq(
3030
&TensorData::from([[[7.0, 10.0], [15.0, 22.0]], [[-7.0, -10.0], [-15.0, -22.0]]])
@@ -40,7 +40,7 @@ mod tests {
4040

4141
let x = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]]);
4242

43-
linear(x.clone(), weight.clone(), bias.clone())
43+
TestBackend::linear(x.clone(), weight.clone(), bias.clone())
4444
.into_data()
4545
.assert_eq(
4646
&TensorData::from([[[8.0, 9.0], [16.0, 21.0]], [[-6.0, -11.0], [-14.0, -23.0]]])

crates/burn-tensor/src/tests/ops/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ mod full;
3030
mod gather_scatter;
3131
mod init;
3232
mod iter_dim;
33+
pub(crate) mod linear;
3334
mod log;
3435
mod log1p;
3536
mod map_comparison;

0 commit comments

Comments
 (0)