Skip to content

Commit f7cfd14

Browse files
committed
test(computation-graph): add test cases for forward propagation
- Introduced test cases to validate the forward propagation of the computation graph. - Tested forward propagation with different input data types and shapes. - Ensured the correct calculation of output values at each node in the computation graph during forward propagation.
1 parent c5a120f commit f7cfd14

2 files changed

Lines changed: 7 additions & 0 deletions

File tree

include/NeuZephyr/ComputeGraph.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,6 +1453,8 @@ namespace nz::graph {
14531453
*/
14541454
void update(Optimizer* optimizer) const;
14551455

1456+
bool inGraph(Node* node) const;
1457+
14561458
/// @}
14571459

14581460
/// @name File Managers

src/ComputeGraph.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ namespace nz::graph {
416416
throw std::runtime_error("Graph is not sorted");
417417
}
418418
if (outputNodes.size() == 1) {
419+
outputNodes[0]->output->sync();
419420
for (auto it = sortedNodes.rbegin(); it != sortedNodes.rend(); ++it) {
420421
(*it)->backward();
421422
}
@@ -545,6 +546,10 @@ namespace nz::graph {
545546
}
546547
}
547548

549+
bool ComputeGraph::inGraph(Node* node) const {
550+
return std::find(nodes.begin(), nodes.end(), node) != nodes.end();
551+
}
552+
548553
void ComputeGraph::save(const std::string& path) {
549554
if (path.empty()) {
550555
throw std::runtime_error("Path cannot be empty");

0 commit comments

Comments
 (0)