|
6 | 6 | import pyopencl as cl
|
7 | 7 | import numpy as np
|
8 | 8 | import pytato as pt
|
9 |
| -from pytato.transform import Mapper |
| 9 | +from pytato.transform import Mapper, CombineMapper |
10 | 10 | from pytato.array import (Array, Placeholder, Stack, Roll,
|
11 | 11 | AxisPermutation, DataWrapper, Reshape,
|
12 |
| - Concatenate) |
| 12 | + Concatenate, DictOfNamedArrays, IndexBase, |
| 13 | + SizeParam) |
13 | 14 | from pytools.tag import Tag
|
14 | 15 |
|
15 | 16 |
|
@@ -369,4 +370,138 @@ class QuuxTag(TestlibTag):
|
369 | 370 |
|
370 | 371 | # }}}
|
371 | 372 |
|
| 373 | + |
| 374 | +# {{{ utilities for test_push_indirections_* |
| 375 | + |
| 376 | +class _IndexeeArraysMaterializedChecker(CombineMapper[bool]): |
| 377 | + def combine(self, *args: bool) -> bool: |
| 378 | + return all(args) |
| 379 | + |
| 380 | + def map_placeholder(self, expr: Placeholder) -> bool: |
| 381 | + return True |
| 382 | + |
| 383 | + def map_data_wrapper(self, expr: DataWrapper) -> bool: |
| 384 | + return True |
| 385 | + |
| 386 | + def map_size_param(self, expr: SizeParam) -> bool: |
| 387 | + return True |
| 388 | + |
| 389 | + def _map_index_base(self, expr: IndexBase) -> bool: |
| 390 | + from pytato.transform.indirections import _is_materialized |
| 391 | + return self.combine( |
| 392 | + _is_materialized(expr.array) or isinstance(expr.array, IndexBase), |
| 393 | + self.rec(expr.array) |
| 394 | + ) |
| 395 | + |
| 396 | + |
| 397 | +def are_all_indexees_materialized_nodes( |
| 398 | + expr: Union[Array, DictOfNamedArrays]) -> bool: |
| 399 | + """ |
| 400 | + Returns *True* only if all indexee arrays are either materialized nodes, |
| 401 | + OR, other indexing nodes that have materialized indexees. |
| 402 | + """ |
| 403 | + return _IndexeeArraysMaterializedChecker()(expr) |
| 404 | + |
| 405 | + |
| 406 | +class _IndexerArrayDatawrapperChecker(CombineMapper[bool]): |
| 407 | + def combine(self, *args: bool) -> bool: |
| 408 | + return all(args) |
| 409 | + |
| 410 | + def map_placeholder(self, expr: Placeholder) -> bool: |
| 411 | + return True |
| 412 | + |
| 413 | + def map_data_wrapper(self, expr: DataWrapper) -> bool: |
| 414 | + return True |
| 415 | + |
| 416 | + def map_size_param(self, expr: SizeParam) -> bool: |
| 417 | + return True |
| 418 | + |
| 419 | + def _map_index_base(self, expr: IndexBase) -> bool: |
| 420 | + return self.combine( |
| 421 | + *[isinstance(idx, DataWrapper) |
| 422 | + for idx in expr.indices |
| 423 | + if isinstance(idx, Array)], |
| 424 | + super()._map_index_base(expr), |
| 425 | + ) |
| 426 | + |
| 427 | + |
| 428 | +def are_all_indexer_arrays_datawrappers( |
| 429 | + expr: Union[Array, DictOfNamedArrays]) -> bool: |
| 430 | + """ |
| 431 | + Returns *True* only if all indexer arrays are instances of |
| 432 | + :class:`~pytato.array.DataWrapper`. |
| 433 | + """ |
| 434 | + return _IndexerArrayDatawrapperChecker()(expr) |
| 435 | + |
| 436 | +# }}} |
| 437 | + |
| 438 | + |
| 439 | +# {{{ auto_test_vs_ref |
| 440 | + |
| 441 | +class AutoTestFailureException(RuntimeError): |
| 442 | + """ |
| 443 | + Raised by :func:`auto_test_vs_ref` when the expressions do NOT match. |
| 444 | + """ |
| 445 | + |
| 446 | + |
| 447 | +def auto_test_vs_ref(cl_ctx: "cl.Context", |
| 448 | + actual: Union[Array, DictOfNamedArrays], |
| 449 | + desired: Union[Array, DictOfNamedArrays], |
| 450 | + *, |
| 451 | + rtol: float = 1e-07, |
| 452 | + atol: float = 0) -> None: |
| 453 | + import pyopencl.array as cla |
| 454 | + import loopy as lp |
| 455 | + from pytato.transform import InputGatherer |
| 456 | + |
| 457 | + if isinstance(desired, Array): |
| 458 | + if not isinstance(actual, Array): |
| 459 | + raise AutoTestFailureException("'actual' is not an 'Array'") |
| 460 | + |
| 461 | + desired = pt.make_dict_of_named_arrays({"_pt_out": desired}) |
| 462 | + actual = pt.make_dict_of_named_arrays({"_pt_out": actual}) |
| 463 | + else: |
| 464 | + assert isinstance(desired, DictOfNamedArrays) |
| 465 | + if not isinstance(actual, DictOfNamedArrays): |
| 466 | + raise AutoTestFailureException("'actual' is not" |
| 467 | + " a 'DictOfNamedArrays'") |
| 468 | + |
| 469 | + cq = cl.CommandQueue(cl_ctx) |
| 470 | + |
| 471 | + if (any(isinstance(inp, Placeholder) |
| 472 | + for inp in InputGatherer()(actual)) |
| 473 | + or any(isinstance(inp, Placeholder) |
| 474 | + for inp in InputGatherer()(desired))): |
| 475 | + raise NotImplementedError("Expression graphs with placeholders not" |
| 476 | + " yet supported in auto_test_vs_ref.") |
| 477 | + |
| 478 | + actual_prg = pt.generate_loopy(actual, options=lp.Options(return_dict=True)) |
| 479 | + desired_prg = pt.generate_loopy(desired, options=lp.Options(return_dict=True)) |
| 480 | + |
| 481 | + _, actual_out_dict = actual_prg(cq) |
| 482 | + _, desired_out_dict = desired_prg(cq) |
| 483 | + |
| 484 | + if set(actual_out_dict) != set(desired_out_dict): |
| 485 | + raise AutoTestFailureException( |
| 486 | + "Different outputs obtained from the 2 expressions. " |
| 487 | + f" '{set(actual_out_dict.keys())}' vs '{set(desired_out_dict.keys())}'" |
| 488 | + ) |
| 489 | + |
| 490 | + for output_name, desired_out in desired_out_dict.items(): |
| 491 | + actual_out = actual_out_dict[output_name] |
| 492 | + |
| 493 | + if isinstance(desired_out, cla.Array): |
| 494 | + desired_out = desired_out.get() |
| 495 | + if isinstance(actual_out, cla.Array): |
| 496 | + actual_out = actual_out.get() |
| 497 | + |
| 498 | + try: |
| 499 | + np.testing.assert_allclose(actual_out, desired_out, |
| 500 | + rtol=rtol, atol=atol) |
| 501 | + except AssertionError as e: |
| 502 | + raise AutoTestFailureException( |
| 503 | + f"While comparing '{output_name}': \n{e.args[0]}") |
| 504 | + |
| 505 | +# }}} |
| 506 | + |
372 | 507 | # vim: foldmethod=marker
|
0 commit comments