44#include < TensorOperations.cuh>
55#include < Nodes.cuh>
66#include < Optimizer.cuh>
7+ #include < ComputeGraph.cuh>
78using namespace nz ::data;
89using namespace nz ::nodes;
910using namespace nz ::nodes::calc;
@@ -2554,4 +2555,61 @@ TEST(OptimizerBasic, AdaDeltaTest) {
25542555 expected.dataInject (expectedData.begin (), expectedData.end ());
25552556 expected.dataInject (grad.begin (), grad.end (), true );
25562557 EXPECT_EQ (*input.output , expected);
2558+ }
2559+
2560+ TEST (ComputeGraph, GraphForwardTest) {
2561+ graph::ComputeGraph graph;
2562+ InputNode input1 ({2 , 3 , 4 , 5 });
2563+ InputNode param1 ({1 , 3 , 5 , 1 });
2564+ InputNode param2 ({2 , 3 , 4 , 1 });
2565+ InputNode target ({2 , 3 , 4 , 1 });
2566+ MatMulNode matmul (&input1, ¶m1);
2567+ ReLUNode relu (&matmul);
2568+ AddNode add (&relu, ¶m2);
2569+ MeanSquaredErrorNode mse (&add, &target);
2570+ MappedTensor input1Data ({2 , 3 , 4 , 5 });
2571+ MappedTensor param1Data ({1 , 3 , 5 , 1 });
2572+ MappedTensor param2Data ({2 , 3 , 4 , 1 });
2573+ MappedTensor targetData ({2 , 3 , 4 , 1 });
2574+ std::random_device rd;
2575+ std::mt19937 gen (rd ());
2576+ std::uniform_real_distribution<float > dist (-10 .0f , 10 .0f );
2577+ for (auto & i : input1Data) {
2578+ i = dist (gen);
2579+ }
2580+ for (auto & i : param1Data) {
2581+ i = dist (gen);
2582+ }
2583+ for (auto & i : param2Data) {
2584+ i = dist (gen);
2585+ }
2586+ for (auto & i : targetData) {
2587+ i = dist (gen);
2588+ }
2589+ input1.dataInject (input1Data.begin (), input1Data.end ());
2590+ param1.dataInject (param1Data.begin (), param1Data.end ());
2591+ param2.dataInject (param2Data.begin (), param2Data.end ());
2592+ target.dataInject (targetData.begin (), targetData.end ());
2593+ MappedTensor mulResult ({2 , 3 , 4 , 1 });
2594+ GEMMTensorCore (mulResult, input1Data, param1Data);
2595+ auto reluResult = ReLU (mulResult);
2596+ auto addResult = reluResult + param2Data;
2597+ float loss = 0 .0f ;
2598+ for (auto i = 0 ; i < addResult.size (); i++) {
2599+ loss += (addResult[i] - targetData[i]) * (addResult[i] - targetData[i]);
2600+ }
2601+ loss /= static_cast <float >(addResult.size ());
2602+ graph.addNode (&input1);
2603+ graph.addNode (¶m1);
2604+ graph.addNode (¶m2);
2605+ graph.addNode (&target);
2606+ graph.addNode (&matmul);
2607+ graph.addNode (&relu);
2608+ graph.addNode (&add);
2609+ graph.addNode (&mse);
2610+ graph.forward ();
2611+ Tensor expected ({2 , 3 , 4 , 1 });
2612+ expected.dataInject (addResult.begin (), addResult.end ());
2613+ EXPECT_EQ (expected, *add.output );
2614+ EXPECT_NEAR (loss, mse.getLoss (), 1e-2 );
25572615}
0 commit comments