Skip to content

Commit 11ca8f1

Browse files
Cristian GarciaFlax Authors
authored andcommitted
avoid __eq__ in test_iter_graph
PiperOrigin-RevId: 878762402
1 parent b6a683e commit 11ca8f1

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/nnx/graph_utils_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,11 +1464,11 @@ def test_iter_graph(self, graph):
14641464
root.x = child
14651465
root.y = jnp.ones(2)
14661466

1467-
nodes = [node for _, node in nnx.graph.iter_graph(root, graph=graph)]
1468-
self.assertIn(var0, nodes)
1469-
self.assertIn(var1, nodes)
1470-
self.assertIn(child, nodes)
1471-
self.assertIn(root, nodes)
1467+
node_ids = [id(node) for _, node in nnx.graph.iter_graph(root, graph=graph)]
1468+
self.assertIn(id(var0), node_ids)
1469+
self.assertIn(id(var1), node_ids)
1470+
self.assertIn(id(child), node_ids)
1471+
self.assertIn(id(root), node_ids)
14721472

14731473
def test_iter_graph_tree_mode_shared_variable_raises(self):
14741474
var = nnx.Variable(jnp.zeros(1))

0 commit comments

Comments
 (0)