1
- from dataclasses import dataclass
1
+ from dataclasses import dataclass , field
2
2
from copy import deepcopy
3
3
import numpy as np
4
4
5
- from typing import Any
5
+ from typing import Any , List , Tuple
6
6
7
7
from .hierarchy import PatchHierarchy , format_timestamp
8
8
from .patchdata import FieldData , ParticleData
12
12
from ...core .gridlayout import GridLayout
13
13
from ...core .phare_utilities import listify
14
14
from ...core .phare_utilities import refinement_ratio
15
+ from pyphare .core import phare_utilities as phut
15
16
16
17
17
18
field_qties = {
@@ -561,41 +562,53 @@ def _compute_scalardiv(patch_datas, **kwargs):
561
562
562
563
@dataclass
563
564
class EqualityReport :
564
- ok : bool
565
- reason : str
566
- ref : Any = None
567
- cmp : Any = None
565
+ failed : List [Tuple [str , Any , Any ]] = field (default_factory = lambda : [])
568
566
569
567
def __bool__ (self ):
570
- return self .ok
568
+ return not self .failed
571
569
572
570
def __repr__ (self ):
573
- return self .reason
571
+ for msg , ref , cmp in self :
572
+ print (msg )
573
+ try :
574
+ if type (ref ) is FieldData :
575
+ phut .assert_fp_any_all_close (ref [:], cmp [:], atol = 1e-16 )
576
+ except AssertionError as e :
577
+ print (e )
578
+ return self .failed [0 ][0 ]
574
579
575
- def __post_init__ (self ):
576
- not_nones = [a is not None for a in [self .ref , self .cmp ]]
577
- if all (not_nones ):
578
- assert id (self .ref ) != id (self .cmp )
579
- else :
580
- assert not any (not_nones )
580
+ def __call__ (self , reason , ref = None , cmp = None ):
581
+ self .failed .append ((reason , ref , cmp ))
582
+ return self
583
+
584
+ def __getitem__ (self , idx ):
585
+ return (self .failed [idx ][1 ], self .failed [idx ][2 ])
586
+
587
+ def __iter__ (self ):
588
+ return self .failed .__iter__ ()
589
+
590
+ def __reversed__ (self ):
591
+ return reversed (self .failed )
581
592
582
593
583
594
def hierarchy_compare (this , that , atol = 1e-16 ):
595
+ eqr = EqualityReport ()
596
+
584
597
if not isinstance (this , PatchHierarchy ) or not isinstance (that , PatchHierarchy ):
585
- return EqualityReport ( False , "class type mismatch" )
598
+ return eqr ( "class type mismatch" )
586
599
587
600
if this .ndim != that .ndim or this .domain_box != that .domain_box :
588
- return EqualityReport ( False , "dimensional mismatch" )
601
+ return eqr ( "dimensional mismatch" )
589
602
590
603
if this .time_hier .keys () != that .time_hier .keys ():
591
- return EqualityReport ( False , "timesteps mismatch" )
604
+ return eqr ( "timesteps mismatch" )
592
605
593
606
for tidx in this .times ():
594
607
patch_levels_ref = this .time_hier [tidx ]
595
608
patch_levels_cmp = that .time_hier [tidx ]
596
609
597
610
if patch_levels_ref .keys () != patch_levels_cmp .keys ():
598
- return EqualityReport ( False , "levels mismatch" )
611
+ return eqr ( "levels mismatch" )
599
612
600
613
for level_idx in patch_levels_cmp .keys ():
601
614
patch_level_ref = patch_levels_ref [level_idx ]
@@ -606,19 +619,20 @@ def hierarchy_compare(this, that, atol=1e-16):
606
619
patch_cmp = patch_level_cmp .patches [patch_idx ]
607
620
608
621
if patch_ref .patch_datas .keys () != patch_cmp .patch_datas .keys ():
609
- return EqualityReport ( False , "data keys mismatch" )
622
+ return eqr ( "data keys mismatch" )
610
623
611
624
for patch_data_key in patch_ref .patch_datas .keys ():
612
625
patch_data_ref = patch_ref .patch_datas [patch_data_key ]
613
626
patch_data_cmp = patch_cmp .patch_datas [patch_data_key ]
614
627
615
628
if not patch_data_cmp .compare (patch_data_ref , atol = atol ):
616
629
msg = f"data mismatch: { type (patch_data_ref ).__name__ } { patch_data_key } "
617
- return EqualityReport (
618
- False , msg , patch_data_cmp , patch_data_ref
619
- )
630
+ eqr (msg , patch_data_cmp , patch_data_ref )
631
+
632
+ if not eqr :
633
+ return eqr
620
634
621
- return EqualityReport ( True , "OK" )
635
+ return eqr
622
636
623
637
624
638
def single_patch_for_LO (hier , qties = None , skip = None ):
0 commit comments