Skip to content

Commit 9824b97

Browse files
committed
fix tests
1 parent bee9ed4 commit 9824b97

File tree

5 files changed

+58
-66
lines changed

5 files changed

+58
-66
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
use crate::tensor::Tensor;
2-
use crate::tensor::backend::Backend;
1+
use crate::{Tensor, backend::Backend};
32

43
/// Applies a linear transformation to the input tensor using the given weight and bias.
54
///
@@ -43,67 +42,3 @@ pub fn linear<B: Backend, const D: usize>(
4342
None => output,
4443
}
4544
}
46-
47-
#[cfg(test)]
48-
mod tests {
49-
use super::*;
50-
use crate::TestBackend;
51-
use burn_tensor::TensorData;
52-
53-
#[test]
54-
fn test_linear_1d() {
55-
let weight =
56-
Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &Default::default());
57-
58-
let x = Tensor::<TestBackend, 1>::from_data([1.0, 2.0], &Default::default());
59-
60-
linear(x.clone(), weight.clone(), None)
61-
.into_data()
62-
.assert_eq(
63-
&TensorData::from([7.0, 10.0]).convert_dtype(x.dtype()),
64-
true,
65-
);
66-
}
67-
68-
#[test]
69-
fn test_linear_forward_no_bias() {
70-
let weight =
71-
Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &Default::default());
72-
73-
let x = Tensor::<TestBackend, 3>::from_data(
74-
[[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]],
75-
&Default::default(),
76-
);
77-
78-
linear(x.clone(), weight.clone(), None)
79-
.into_data()
80-
.assert_eq(
81-
&TensorData::from([[[7.0, 10.0], [15.0, 22.0]], [[-7.0, -10.0], [-15.0, -22.0]]])
82-
.convert_dtype(x.dtype()),
83-
true,
84-
);
85-
}
86-
87-
#[test]
88-
fn test_linear_forward_with_bias() {
89-
let weight =
90-
Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &Default::default());
91-
let bias = Some(Tensor::<TestBackend, 1>::from_data(
92-
[1.0, -1.0],
93-
&Default::default(),
94-
));
95-
96-
let x = Tensor::<TestBackend, 3>::from_data(
97-
[[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]],
98-
&Default::default(),
99-
);
100-
101-
linear(x.clone(), weight.clone(), bias.clone())
102-
.into_data()
103-
.assert_eq(
104-
&TensorData::from([[[8.0, 9.0], [16.0, 21.0]], [[-6.0, -11.0], [-14.0, -23.0]]])
105-
.convert_dtype(x.dtype()),
106-
true,
107-
);
108-
}
109-
}

crates/burn-tensor/src/tests/mod.rs

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ mod activation;
22
mod clone_invariance;
33
mod grid;
44
mod module;
5+
mod nn;
56
mod ops;
67
mod primitive;
78
mod quantization;
@@ -155,6 +156,9 @@ macro_rules! testgen_with_float_param {
155156
burn_tensor::testgen_silu!();
156157
burn_tensor::testgen_tanh_activation!();
157158

159+
// test nn.functional
160+
burn_tensor::testgen_nn_fn_vector_norm!();
161+
158162
// test grid
159163
burn_tensor::testgen_meshgrid!();
160164

Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub(crate) mod vector_norm;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#[burn_tensor_testgen::testgen(nn_fn_vector_norm)]
2+
mod tests {
3+
use super::*;
4+
use burn_tensor::nn::functional::linear;
5+
use burn_tensor::{Tensor, TensorData};
6+
7+
#[test]
8+
fn test_linear_1d() {
9+
let weight = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);
10+
11+
let x = TestTensor::<1>::from([1.0, 2.0]);
12+
13+
linear(x.clone(), weight.clone(), None)
14+
.into_data()
15+
.assert_eq(
16+
&TensorData::from([7.0, 10.0]).convert_dtype(x.dtype()),
17+
true,
18+
);
19+
}
20+
21+
#[test]
22+
fn test_linear_forward_no_bias() {
23+
let weight = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);
24+
25+
let x = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]]);
26+
27+
linear(x.clone(), weight.clone(), None)
28+
.into_data()
29+
.assert_eq(
30+
&TensorData::from([[[7.0, 10.0], [15.0, 22.0]], [[-7.0, -10.0], [-15.0, -22.0]]])
31+
.convert_dtype(x.dtype()),
32+
true,
33+
);
34+
}
35+
36+
#[test]
37+
fn test_linear_forward_with_bias() {
38+
let weight = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);
39+
let bias = Some(TestTensor::<1>::from([1.0, -1.0]));
40+
41+
let x = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]]);
42+
43+
linear(x.clone(), weight.clone(), bias.clone())
44+
.into_data()
45+
.assert_eq(
46+
&TensorData::from([[[8.0, 9.0], [16.0, 21.0]], [[-6.0, -11.0], [-14.0, -23.0]]])
47+
.convert_dtype(x.dtype()),
48+
true,
49+
);
50+
}
51+
}
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub(crate) mod functional;

0 commit comments

Comments
 (0)