Skip to content

Commit 9f2d2fa

Browse files
committed
move to moduleops
1 parent b4b6ede commit 9f2d2fa

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

crates/burn-core/src/nn/linear.rs

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use crate as burn;
2-
use burn_tensor::ops::linear::linear;
32

43
use crate::config::Config;
54
use crate::module::Param;

crates/burn-tensor/src/tensor/ops/modules/base.rs

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use core::num::NonZeroUsize;
22

33
use super::{conv, pool, unfold::unfold4d_using_conv2d};
4-
use crate::ops::linear::linear;
54
use crate::{
65
Shape, Tensor, TensorMetadata,
76
backend::Backend,
@@ -796,7 +795,17 @@ pub trait ModuleOps<B: Backend> {
796795
weight: Tensor<B, 2>,
797796
bias: Option<Tensor<B, 1>>,
798797
) -> Tensor<B, D> {
799-
linear(input, weight, bias)
798+
if D == 1 {
799+
// Insert and remove an extra batch dimension for the batch matmul to work.
800+
return Self::linear::<2>(input.unsqueeze(), weight, bias).flatten(0, 1);
801+
}
802+
803+
let weight = weight.unsqueeze();
804+
let output = input.matmul(weight);
805+
match bias {
806+
Some(bias) => output + bias.unsqueeze(),
807+
None => output,
808+
}
800809
}
801810
}
802811

crates/burn-tensor/src/tensor/ops/modules/mod.rs

-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,4 @@ pub mod pool;
1313

1414
mod base;
1515

16-
/// Module with linear operations.
17-
pub mod linear;
18-
1916
pub use base::*;

0 commit comments

Comments
 (0)