Skip to content

Commit 271ec0e

Browse files
majosminducer
authored andcommitted
fix more bugs
1 parent 37e66f7 commit 271ec0e

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

pytato/analysis/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -571,13 +571,16 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override]
571571

572572
def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> None:
573573
"""Call the mapper method of *expr* and return the result."""
574-
self.depth += 1
575-
self.max_depth = max(self.max_depth, self.depth)
576-
577-
try:
574+
if isinstance(expr, DictOfNamedArrays):
578575
super().rec(expr, *args, **kwargs)
579-
finally:
580-
self.depth -= 1
576+
else:
577+
self.depth += 1
578+
self.max_depth = max(self.max_depth, self.depth)
579+
580+
try:
581+
super().rec(expr, *args, **kwargs)
582+
finally:
583+
self.depth -= 1
581584

582585

583586
def get_max_node_depth(outputs: Union[Array, DictOfNamedArrays]) -> int:

test/test_pytato.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,10 @@ def test_nodemaxdepthmapper():
769769
from pytato.analysis import get_max_node_depth
770770

771771
x = pt.make_placeholder("x", shape=(10, 4), dtype=np.float64)
772-
for _ in range(9):
772+
773+
assert get_max_node_depth(x) == 0
774+
775+
for _ in range(10):
773776
x = x + 1
774777

775778
assert get_max_node_depth(x) == 10

0 commit comments

Comments
 (0)