|
1 |
| -use crate::tensor::Tensor; |
2 |
| -use crate::tensor::backend::Backend; |
| 1 | +use crate::{Tensor, backend::Backend}; |
3 | 2 |
|
4 | 3 | /// Applies a linear transformation to the input tensor using the given weight and bias.
|
5 | 4 | ///
|
@@ -43,67 +42,3 @@ pub fn linear<B: Backend, const D: usize>(
|
43 | 42 | None => output,
|
44 | 43 | }
|
45 | 44 | }
|
46 |
| - |
47 |
| -#[cfg(test)] |
48 |
| -mod tests { |
49 |
| - use super::*; |
50 |
| - use crate::TestBackend; |
51 |
| - use burn_tensor::TensorData; |
52 |
| - |
53 |
| - #[test] |
54 |
| - fn test_linear_1d() { |
55 |
| - let weight = |
56 |
| - Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &Default::default()); |
57 |
| - |
58 |
| - let x = Tensor::<TestBackend, 1>::from_data([1.0, 2.0], &Default::default()); |
59 |
| - |
60 |
| - linear(x.clone(), weight.clone(), None) |
61 |
| - .into_data() |
62 |
| - .assert_eq( |
63 |
| - &TensorData::from([7.0, 10.0]).convert_dtype(x.dtype()), |
64 |
| - true, |
65 |
| - ); |
66 |
| - } |
67 |
| - |
68 |
| - #[test] |
69 |
| - fn test_linear_forward_no_bias() { |
70 |
| - let weight = |
71 |
| - Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &Default::default()); |
72 |
| - |
73 |
| - let x = Tensor::<TestBackend, 3>::from_data( |
74 |
| - [[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]], |
75 |
| - &Default::default(), |
76 |
| - ); |
77 |
| - |
78 |
| - linear(x.clone(), weight.clone(), None) |
79 |
| - .into_data() |
80 |
| - .assert_eq( |
81 |
| - &TensorData::from([[[7.0, 10.0], [15.0, 22.0]], [[-7.0, -10.0], [-15.0, -22.0]]]) |
82 |
| - .convert_dtype(x.dtype()), |
83 |
| - true, |
84 |
| - ); |
85 |
| - } |
86 |
| - |
87 |
| - #[test] |
88 |
| - fn test_linear_forward_with_bias() { |
89 |
| - let weight = |
90 |
| - Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &Default::default()); |
91 |
| - let bias = Some(Tensor::<TestBackend, 1>::from_data( |
92 |
| - [1.0, -1.0], |
93 |
| - &Default::default(), |
94 |
| - )); |
95 |
| - |
96 |
| - let x = Tensor::<TestBackend, 3>::from_data( |
97 |
| - [[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]], |
98 |
| - &Default::default(), |
99 |
| - ); |
100 |
| - |
101 |
| - linear(x.clone(), weight.clone(), bias.clone()) |
102 |
| - .into_data() |
103 |
| - .assert_eq( |
104 |
| - &TensorData::from([[[8.0, 9.0], [16.0, 21.0]], [[-6.0, -11.0], [-14.0, -23.0]]]) |
105 |
| - .convert_dtype(x.dtype()), |
106 |
| - true, |
107 |
| - ); |
108 |
| - } |
109 |
| -} |
0 commit comments