Skip to content

Commit c5a120f

Browse files
committed
test(computation-graph): add test cases for forward propagation
- Introduced test cases to validate the forward propagation of the computation graph. - Tested forward propagation with different input data types and shapes. - Ensured the correct calculation of output values at each node in the computation graph during forward propagation.
1 parent b17fdb7 commit c5a120f

1 file changed

Lines changed: 58 additions & 0 deletions

File tree

test/Test.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <TensorOperations.cuh>
55
#include <Nodes.cuh>
66
#include <Optimizer.cuh>
7+
#include <ComputeGraph.cuh>
78
using namespace nz::data;
89
using namespace nz::nodes;
910
using namespace nz::nodes::calc;
@@ -2554,4 +2555,61 @@ TEST(OptimizerBasic, AdaDeltaTest) {
25542555
expected.dataInject(expectedData.begin(), expectedData.end());
25552556
expected.dataInject(grad.begin(), grad.end(), true);
25562557
EXPECT_EQ(*input.output, expected);
2558+
}
2559+
2560+
TEST(ComputeGraph, GraphForwardTest) {
    // Validates ComputeGraph::forward(): builds the graph
    //   (input1 x param1) -> ReLU -> (+ param2) -> MSE(target)
    // feeds it random data, and checks the graph's outputs against a
    // reference forward pass computed directly with the tensor primitives.
    graph::ComputeGraph graph;
    InputNode input1({2, 3, 4, 5});
    InputNode param1({1, 3, 5, 1});
    InputNode param2({2, 3, 4, 1});
    InputNode target({2, 3, 4, 1});
    MatMulNode matmul(&input1, &param1);
    ReLUNode relu(&matmul);
    AddNode add(&relu, &param2);
    MeanSquaredErrorNode mse(&add, &target);

    // Host-side mirrors of each node's input, filled with uniform random
    // floats in [-10, 10) so both positive and negative values reach ReLU.
    MappedTensor input1Data({2, 3, 4, 5});
    MappedTensor param1Data({1, 3, 5, 1});
    MappedTensor param2Data({2, 3, 4, 1});
    MappedTensor targetData({2, 3, 4, 1});
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> dist(-10.0f, 10.0f);
    for (auto& value : input1Data) {
        value = dist(gen);
    }
    for (auto& value : param1Data) {
        value = dist(gen);
    }
    for (auto& value : param2Data) {
        value = dist(gen);
    }
    for (auto& value : targetData) {
        value = dist(gen);
    }

    // Push the same random data into the graph's input nodes.
    input1.dataInject(input1Data.begin(), input1Data.end());
    param1.dataInject(param1Data.begin(), param1Data.end());
    param2.dataInject(param2Data.begin(), param2Data.end());
    target.dataInject(targetData.begin(), targetData.end());

    // Reference computation, independent of the graph machinery.
    MappedTensor mulResult({2, 3, 4, 1});
    GEMMTensorCore(mulResult, input1Data, param1Data);
    auto reluResult = ReLU(mulResult);
    auto addResult = reluResult + param2Data;

    // Mean squared error over all elements.
    // NOTE: the index type matches size() exactly — the original `auto i = 0`
    // deduced a signed int and compared it against an unsigned size.
    float loss = 0.0f;
    for (decltype(addResult.size()) i = 0; i < addResult.size(); ++i) {
        const float diff = addResult[i] - targetData[i];
        loss += diff * diff;
    }
    loss /= static_cast<float>(addResult.size());

    // Register nodes (inputs first, then ops in topological order) and run
    // one forward pass over the whole graph.
    graph.addNode(&input1);
    graph.addNode(&param1);
    graph.addNode(&param2);
    graph.addNode(&target);
    graph.addNode(&matmul);
    graph.addNode(&relu);
    graph.addNode(&add);
    graph.addNode(&mse);
    graph.forward();

    // The pre-loss node output must match the reference exactly; the scalar
    // loss is compared with a tolerance since accumulation order may differ.
    Tensor expected({2, 3, 4, 1});
    expected.dataInject(addResult.begin(), addResult.end());
    EXPECT_EQ(expected, *add.output);
    EXPECT_NEAR(loss, mse.getLoss(), 1e-2);
}

0 commit comments

Comments
 (0)