Skip to content

Commit a4be4f7

Browse files
committed
add testlib utils for checking if indexee is always materialized and implements pt.auto_test_vs_ref
1 parent 8aedcba commit a4be4f7

File tree

1 file changed

+137
-2
lines changed

1 file changed

+137
-2
lines changed

test/testlib.py

Lines changed: 137 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
import pyopencl as cl
77
import numpy as np
88
import pytato as pt
9-
from pytato.transform import Mapper
9+
from pytato.transform import Mapper, CombineMapper
1010
from pytato.array import (Array, Placeholder, Stack, Roll,
1111
AxisPermutation, DataWrapper, Reshape,
12-
Concatenate)
12+
Concatenate, DictOfNamedArrays, IndexBase,
13+
SizeParam)
1314
from pytools.tag import Tag
1415

1516

@@ -369,4 +370,138 @@ class QuuxTag(TestlibTag):
369370

370371
# }}}
371372

373+
374+
# {{{ utilities for test_push_indirections_*
375+
376+
class _IndexeeArraysMaterializedChecker(CombineMapper[bool]):
377+
def combine(self, *args: bool) -> bool:
378+
return all(args)
379+
380+
def map_placeholder(self, expr: Placeholder) -> bool:
381+
return True
382+
383+
def map_data_wrapper(self, expr: DataWrapper) -> bool:
384+
return True
385+
386+
def map_size_param(self, expr: SizeParam) -> bool:
387+
return True
388+
389+
def _map_index_base(self, expr: IndexBase) -> bool:
390+
from pytato.transform.indirections import _is_materialized
391+
return self.combine(
392+
_is_materialized(expr.array) or isinstance(expr.array, IndexBase),
393+
self.rec(expr.array)
394+
)
395+
396+
397+
def are_all_indexees_materialized_nodes(
398+
expr: Union[Array, DictOfNamedArrays]) -> bool:
399+
"""
400+
Returns *True* only if all indexee arrays are either materialized nodes,
401+
OR, other indexing nodes that have materialized indexees.
402+
"""
403+
return _IndexeeArraysMaterializedChecker()(expr)
404+
405+
406+
class _IndexerArrayDatawrapperChecker(CombineMapper[bool]):
407+
def combine(self, *args: bool) -> bool:
408+
return all(args)
409+
410+
def map_placeholder(self, expr: Placeholder) -> bool:
411+
return True
412+
413+
def map_data_wrapper(self, expr: DataWrapper) -> bool:
414+
return True
415+
416+
def map_size_param(self, expr: SizeParam) -> bool:
417+
return True
418+
419+
def _map_index_base(self, expr: IndexBase) -> bool:
420+
return self.combine(
421+
*[isinstance(idx, DataWrapper)
422+
for idx in expr.indices
423+
if isinstance(idx, Array)],
424+
super()._map_index_base(expr),
425+
)
426+
427+
428+
def are_all_indexer_arrays_datawrappers(
429+
expr: Union[Array, DictOfNamedArrays]) -> bool:
430+
"""
431+
Returns *True* only if all indexer arrays are instances of
432+
:class:`~pytato.array.DataWrapper`.
433+
"""
434+
return _IndexerArrayDatawrapperChecker()(expr)
435+
436+
# }}}
437+
438+
439+
# {{{ auto_test_vs_ref
440+
441+
class AutoTestFailureException(RuntimeError):
442+
"""
443+
Raised by :func:`auto_test_vs_ref` when the expressions do NOT match.
444+
"""
445+
446+
447+
def auto_test_vs_ref(cl_ctx: "cl.Context",
448+
actual: Union[Array, DictOfNamedArrays],
449+
desired: Union[Array, DictOfNamedArrays],
450+
*,
451+
rtol: float = 1e-07,
452+
atol: float = 0) -> None:
453+
import pyopencl.array as cla
454+
import loopy as lp
455+
from pytato.transform import InputGatherer
456+
457+
if isinstance(desired, Array):
458+
if not isinstance(actual, Array):
459+
raise AutoTestFailureException("'actual' is not an 'Array'")
460+
461+
desired = pt.make_dict_of_named_arrays({"_pt_out": desired})
462+
actual = pt.make_dict_of_named_arrays({"_pt_out": actual})
463+
else:
464+
assert isinstance(desired, DictOfNamedArrays)
465+
if not isinstance(actual, DictOfNamedArrays):
466+
raise AutoTestFailureException("'actual' is not"
467+
" a 'DictOfNamedArrays'")
468+
469+
cq = cl.CommandQueue(cl_ctx)
470+
471+
if (any(isinstance(inp, Placeholder)
472+
for inp in InputGatherer()(actual))
473+
or any(isinstance(inp, Placeholder)
474+
for inp in InputGatherer()(desired))):
475+
raise NotImplementedError("Expression graphs with placeholders not"
476+
" yet supported in auto_test_vs_ref.")
477+
478+
actual_prg = pt.generate_loopy(actual, options=lp.Options(return_dict=True))
479+
desired_prg = pt.generate_loopy(desired, options=lp.Options(return_dict=True))
480+
481+
_, actual_out_dict = actual_prg(cq)
482+
_, desired_out_dict = desired_prg(cq)
483+
484+
if set(actual_out_dict) != set(desired_out_dict):
485+
raise AutoTestFailureException(
486+
"Different outputs obtained from the 2 expressions. "
487+
f" '{set(actual_out_dict.keys())}' vs '{set(desired_out_dict.keys())}'"
488+
)
489+
490+
for output_name, desired_out in desired_out_dict.items():
491+
actual_out = actual_out_dict[output_name]
492+
493+
if isinstance(desired_out, cla.Array):
494+
desired_out = desired_out.get()
495+
if isinstance(actual_out, cla.Array):
496+
actual_out = actual_out.get()
497+
498+
try:
499+
np.testing.assert_allclose(actual_out, desired_out,
500+
rtol=rtol, atol=atol)
501+
except AssertionError as e:
502+
raise AutoTestFailureException(
503+
f"While comparing '{output_name}': \n{e.args[0]}")
504+
505+
# }}}
506+
372507
# vim: foldmethod=marker

0 commit comments

Comments
 (0)