Skip to content

Commit 4eaa21d

Browse files
majosminducer
authored andcommitted
fix off-by-one error and add test
1 parent 08c2223 commit 4eaa21d

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

pytato/analysis/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,8 +560,9 @@ class NodeMaxDepthMapper(CachedWalkMapper):
560560

561561
def __init__(self) -> None:
562562
super().__init__()
563-
self.depth = 0
564-
self.max_depth = 0
563+
# Want the first rec() call to increment to 0, so start at -1
564+
self.depth = -1
565+
self.max_depth = -1
565566

566567
# FIXME: Do I need this?
567568
# type-ignore-reason: dropped the extra `*args, **kwargs`.

test/test_pytato.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,16 @@ def test_large_dag_with_duplicates_count():
765765
dag, count_duplicates=False)
766766

767767

768+
def test_nodemaxdepthmapper():
769+
from pytato.analysis import get_max_node_depth
770+
771+
x = pt.make_placeholder("x", shape=(10, 4), dtype=np.float64)
772+
for i in range(9):
773+
x = x + 1
774+
775+
assert get_max_node_depth(x) == 10
776+
777+
768778
def test_rec_get_user_nodes():
769779
x1 = pt.make_placeholder("x1", shape=(10, 4))
770780
x2 = pt.make_placeholder("x2", shape=(10, 4))

0 commit comments

Comments
 (0)