-
Notifications
You must be signed in to change notification settings - Fork 16
add materialization counter #333
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
68eb741
7949b19
4a578d8
c6a1781
9a79943
80dbca5
de25c6a
9ca88bc
fbdfd6e
29281d9
d01e839
33d7d1e
68f159e
e50ca77
8296c04
2eb85d4
262787c
e65c25d
aaf2bf8
bfe8d0b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -26,13 +26,15 @@ | |||||
""" | ||||||
|
||||||
from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, | ||||||
TYPE_CHECKING) | ||||||
TYPE_CHECKING, Iterable) | ||||||
from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, | ||||||
DictOfNamedArrays, NamedArray, | ||||||
IndexBase, IndexRemappingBase, InputArgumentBase, | ||||||
ShapeType) | ||||||
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper | ||||||
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper, CombineMapper | ||||||
from pytato.loopy import LoopyCall | ||||||
from pytools.tag import Tag | ||||||
from pytato.tags import ImplStored | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
from pytato.distributed import DistributedRecv, DistributedSendRefHolder | ||||||
|
@@ -47,6 +49,11 @@ | |||||
.. autofunction:: get_num_nodes | ||||||
|
||||||
.. autoclass:: DirectPredecessorsGetter | ||||||
|
||||||
.. autoclass:: TagCountMapper | ||||||
.. autofunction:: get_num_tags_of_type | ||||||
|
||||||
.. autofunction:: get_num_materialized | ||||||
""" | ||||||
|
||||||
|
||||||
|
@@ -371,12 +378,78 @@ def post_visit(self, expr: Any) -> None: | |||||
def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: | ||||||
"""Returns the number of nodes in DAG *outputs*.""" | ||||||
|
||||||
from pytato.codegen import normalize_outputs | ||||||
outputs = normalize_outputs(outputs) | ||||||
|
||||||
ncm = NodeCountMapper() | ||||||
ncm(outputs) | ||||||
|
||||||
return ncm.count | ||||||
|
||||||
# }}} | ||||||
|
||||||
|
||||||
# {{{ TagCountMapper | ||||||
|
||||||
class TagCountMapper(CombineMapper[int]): | ||||||
""" | ||||||
Returns the number of nodes in a DAG that are tagged with all the tags in *tags*. | ||||||
""" | ||||||
|
||||||
def __init__(self, tags: Union[Tag, Iterable[Tag]]) -> None: | ||||||
super().__init__() | ||||||
if isinstance(tags, Tag): | ||||||
tags = frozenset((tags,)) | ||||||
elif not isinstance(tags, frozenset): | ||||||
tags = frozenset(tags) | ||||||
self._tags = tags | ||||||
|
||||||
def combine(self, *args: int) -> int: | ||||||
return sum(args) | ||||||
|
||||||
# type-ignore reason: incompatible return type with super class | ||||||
def rec(self, expr: ArrayOrNames) -> int: # type: ignore | ||||||
if expr in self.cache: | ||||||
return self.cache[expr] | ||||||
|
||||||
if isinstance(expr, Array) and self._tags <= expr.tags: | ||||||
result = 1 + super().rec(expr) | ||||||
else: | ||||||
result = 0 + super().rec(expr) | ||||||
|
||||||
self.cache[expr] = 0 | ||||||
return result | ||||||
|
||||||
|
||||||
def get_num_tags_of_type( | ||||||
outputs: Union[Array, DictOfNamedArrays], | ||||||
tags: Union[Tag, Iterable[Tag]]) -> int: | ||||||
"""Returns the number of nodes in DAG *outputs* that are tagged with | ||||||
all the tags in *tags*.""" | ||||||
|
||||||
tcm = TagCountMapper(tags) | ||||||
|
||||||
return tcm(outputs) | ||||||
|
||||||
# }}} | ||||||
|
||||||
|
||||||
def get_num_materialized(outputs: Union[Array, DictOfNamedArrays]) \ | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My main issue with this approach is that it has at least quadratic complexity. A possible improvement would be a |
||||||
-> Dict[ArrayOrNames, int]: | ||||||
"""Returns the number of materialized nodes each node in *outputs* depends on.""" | ||||||
from pytato.transform import rec_get_all_user_nodes | ||||||
users = rec_get_all_user_nodes(outputs) | ||||||
|
||||||
def is_materialized(expr: ArrayOrNames) -> bool: | ||||||
if (isinstance(expr, Array) and | ||||||
any(isinstance(tag, ImplStored) for tag in expr.tags)): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
return True | ||||||
else: | ||||||
return False | ||||||
|
||||||
res: Dict[ArrayOrNames, int] = {} | ||||||
|
||||||
for node in users.keys(): | ||||||
if is_materialized(node): | ||||||
for user in users[node]: | ||||||
res.setdefault(user, 0) | ||||||
res[user] += 1 | ||||||
|
||||||
return res |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.