Skip to content

Commit 3cb487a

Browse files
committed
MLP library tests
1 parent 3640df4 commit 3cb487a

14 files changed

+48374
-138
lines changed

tensilelite/HostLibraryTests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ set(test_sources ${test_sources}
110110
ContractionSelectionLibrary_test.cpp
111111
DataTypes_test.cpp
112112
Predicates_test.cpp
113+
MLPNet_test.cpp
113114
)
114115

115116
if(TENSILE_USE_LLVM)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (C) 2019-2025 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <limits>
#include <random>
#include <stdexcept>
#include <vector>

#include <gtest/gtest.h>

#include <Tensile/MLPClassification.hpp>
33+
constexpr double abs_error = 10. * std::numeric_limits<TensileLite::MLPClassification::dtype>::epsilon();
34+
35+
// Verify the DenseLayer forward pass (y = W*x + b) for a fixed 3-input /
// 4-output layer against an independently precomputed reference output.
TEST(MLPNet, DenseLayer)
{
    using namespace TensileLite;
    using namespace MLPClassification;

    DenseLayer test_dense{
        /* weights (row-major, 4x3) */
        std::vector({0.6634151484170232f, 0.788165180871102f, 0.31248166526753884f,
                     0.23942935302736823f, 0.6809768405365064f, 0.1367808736375885f,
                     0.5374190113071796f, 0.9243177539724999f, 0.2626090032418886f,
                     0.25681768989410403f, 0.9874451518147117f, 0.42539241956479557f}),
        /* bias */
        std::vector({0.5428595043381658f, 0.17016816526861123f,
                     0.848431351596801f, 0.8236885843811014f})
    };

    auto Fout = test_dense({0.43858885174024276f, 0.7889586579958023f, 0.6683846141676605f});

    // Expected output, precomputed outside the library.
    std::vector<dtype> Ftrue({1.6645136731628662f, 0.9038640159734725f,
                              1.9889096507140147f, 1.9997051101391978f});

    for (std::size_t i = 0; i < Fout.size(); i++)
        EXPECT_NEAR(Fout[i], Ftrue[i], abs_error);
}
/**
 * Generate a vector of n samples drawn from a standard normal N(0, 1).
 *
 * NOTE: the engine is default-constructed locally on every call, so each
 * call produces the exact same deterministic sequence. That keeps the
 * tests reproducible, but separate calls are NOT independent draws.
 */
template <typename T>
std::vector<T> normal_random_vector(std::size_t n) {
    std::default_random_engine gen;
    std::normal_distribution<T> dist(T(0), T(1));
    std::vector<T> v(n);
    // Lambda instead of std::bind: clearer and inlinable.
    std::generate(v.begin(), v.end(), [&] { return dist(gen); });
    return v;
}
// Build a DenseLayer with n_out*n_in normally-distributed weights and
// n_out biases. Because normal_random_vector restarts the same sequence
// on every call, the bias entries coincide with the leading weights.
TensileLite::MLPClassification::DenseLayer
random_dense_layer(std::size_t n_in, std::size_t n_out) {
    using namespace TensileLite;
    using namespace MLPClassification;
    return DenseLayer(normal_random_vector<dtype>(n_in * n_out),
                      normal_random_vector<dtype>(n_out));
}
// Build a ResBlock with random layers:
//   linear1: n_in -> hidden, linear2: hidden -> n_out,
//   res (skip connection): n_in -> n_out.
TensileLite::MLPClassification::ResBlock
random_resblock(std::size_t n_in, std::size_t hidden, std::size_t n_out) {
    using namespace TensileLite;
    using namespace MLPClassification;
    ResBlock r;
    r.linear1 = random_dense_layer(n_in, hidden);
    r.linear2 = random_dense_layer(hidden, n_out);
    r.res     = random_dense_layer(n_in, n_out);
    return r;
}
// Exercise a size combination (16 -> 3) for which DenseLayer has a
// hardcoded, optimized code path, and compare against a naive row-major
// mat-vec-plus-bias reference computed inline.
TEST(MLPNet, DenseLayerFixed)
{
    using namespace TensileLite;
    using namespace MLPClassification;

    /* DenseLayer has some sizes hardcoded for optimization, test these sizes */
    int n_in = 16, n_out = 3;

    auto weights = normal_random_vector<dtype>(n_out*n_in);
    auto bias = normal_random_vector<dtype>(n_out);
    DenseLayer test_dense{weights, bias};
    // EXPECT_FALSE(std::string(typeid(test_dense.W.get()).name()).find("WeightMatrixFixed") == std::string::npos);

    auto Fin = normal_random_vector<dtype>(n_in);
    auto Fout = test_dense(Fin);
    // Naive reference: ftrue[i] = bias[i] + sum_j W[i][j] * Fin[j].
    for (int i = 0; i < n_out; i++) {
        dtype ftrue = bias[i];
        for (int j = 0; j < n_in; j++)
            ftrue += weights[i*n_in+j] * Fin[j];
        EXPECT_NEAR(Fout[i], ftrue, abs_error);
    }
}
// Constructing a DenseLayer with inconsistent dimensions must throw:
// 3 weights cannot form a whole number of rows for 2 bias entries.
TEST(MLPNet, DenseLayerDimFail)
{
    using namespace TensileLite;
    using namespace MLPClassification;

    /* weights dimension is not a multiple of bias dimension */
    EXPECT_THROW(
        (DenseLayer{std::vector({1.f, 2.f, 3.f}), std::vector({1.f, 2.f})}),
        std::runtime_error);
}
// Verify StandardScaler normalizes in place as (x - mean) / scale, both
// against the scaler's own stored parameters and a precomputed reference.
TEST(MLPNet, StandardScaler)
{
    using namespace TensileLite;
    using namespace MLPClassification;

    StandardScaler test_scaler{
        /* mean */ std::vector{0.4525329262019901f, 0.8647806535129754f},
        /* scale */ std::vector{0.05201354125426511f, 0.06123320047178044f}
    };

    std::vector Fin{3.991355396203433e-04f, 6.381927186481492e-01f};
    auto F = Fin;
    test_scaler(F);   // normalizes F in place

    for (std::size_t i = 0; i < F.size(); i++)
        EXPECT_NEAR(F[i], (Fin[i] - test_scaler.mean[i]) / test_scaler.scale[i], abs_error);

    // Same check against externally precomputed values.
    std::vector<dtype> Ftrue({-8.692616956267996f, -3.7004097959774316f});
    for (std::size_t i = 0; i < F.size(); i++)
        EXPECT_NEAR(F[i], Ftrue[i], abs_error);
}
// Verify the ResBlock forward pass against a manual composition of its
// layers: relu(linear2(relu(linear1(x))) + res(x)).
TEST(MLPNet, ResBlock)
{
    using namespace TensileLite;
    using namespace MLPClassification;

    int n_in = 3, h = 6, n_out = 5;

    ResBlock b = random_resblock(n_in, h, n_out);
    auto Fin = normal_random_vector<dtype>(n_in);
    auto Fout = b(Fin);

    // Reference computation from the block's own layers.
    auto Ftmp = b.linear1(Fin);
    for (auto& f : Ftmp)          // ReLU after first layer
        f = f > 0 ? f : 0.;
    auto Ftrue = b.linear2(Ftmp);
    auto Fres = b.res(Fin);       // skip connection
    for (std::size_t i = 0; i < Fres.size(); i++)
        Ftrue[i] += Fres[i];
    for (auto& f : Ftrue)         // final ReLU
        f = f > 0 ? f : 0.;

    for (std::size_t i = 0; i < Ftrue.size(); i++)
        EXPECT_NEAR(Ftrue[i], Fout[i], abs_error);
}
// End-to-end smoke test: assemble a small MLPNet (two residual blocks,
// final dense layer, standard scaler), check it reports itself valid and
// that predict() returns one finite score per solution.
TEST(MLPNet, MLPNet)
{
    using namespace TensileLite;
    using namespace MLPClassification;

    std::size_t n_solutions = 9, n_features = MLPNet::n_features;
    std::size_t h1 = 3, h2 = 5, h3 = 4, h4 = 7;

    MLPNet net;
    net.res_blocks.push_back(random_resblock(n_features, h1, h2));
    net.res_blocks.push_back(random_resblock(h2, h3, h4));
    net.dense = random_dense_layer(h4, n_solutions);
    net.scaler.mean = std::vector<dtype>(n_features, .7);
    net.scaler.scale = std::vector<dtype>(n_features, 3.6);

    EXPECT_TRUE(net.valid());

    // NOTE(review): the problem key has 4 entries while the net expects
    // n_features inputs — presumably predict() expands the key into the
    // feature vector internally; confirm against MLPNet::predict.
    std::vector<float> probkey = normal_random_vector<float>(4);
    auto Fout = net.predict(probkey);
    EXPECT_TRUE(Fout.size() == n_solutions);

    for (auto fi : Fout) {
        EXPECT_TRUE(std::isfinite(fi));
        // isfinite already excludes NaN; kept as an explicit extra check.
        EXPECT_FALSE(std::isnan(fi));
    }
}

tensilelite/HostLibraryTests/configs/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,15 @@ if(TENSILE_USE_LLVM)
2626
set(SOLUTION_LIBRARY_FILES_
2727
${SOLUTION_LIBRARY_FILES_}
2828
"${CMAKE_CURRENT_SOURCE_DIR}/SolutionLibraries/Kernels.yaml"
29+
"${CMAKE_CURRENT_SOURCE_DIR}/SolutionLibraries/Mlp_Kernels.yaml"
2930
)
3031
endif()
3132

3233
if(TENSILE_USE_MSGPACK)
3334
set(SOLUTION_LIBRARY_FILES_
3435
${SOLUTION_LIBRARY_FILES_}
3536
"${CMAKE_CURRENT_SOURCE_DIR}/SolutionLibraries/Kernels.dat"
37+
"${CMAKE_CURRENT_SOURCE_DIR}/SolutionLibraries/Mlp_Kernels.dat"
3638
)
3739
endif()
3840

Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)