2
2
from copy import deepcopy
3
3
import numpy as np
4
4
5
+ from typing import Any
6
+
5
7
from .hierarchy import PatchHierarchy , format_timestamp
6
8
from .patchdata import FieldData , ParticleData
7
9
from .patchlevel import PatchLevel
10
12
from ...core .gridlayout import GridLayout
11
13
from ...core .phare_utilities import listify
12
14
from ...core .phare_utilities import refinement_ratio
13
- from pyphare .pharesee import particles as mparticles
14
15
15
16
16
17
field_qties = {
@@ -562,15 +563,24 @@ def _compute_scalardiv(patch_datas, **kwargs):
562
563
class EqualityReport :
563
564
ok : bool
564
565
reason : str
566
+ ref : Any = None
567
+ cmp : Any = None
565
568
566
569
def __bool__ (self ):
567
570
return self .ok
568
571
569
572
def __repr__ (self ):
570
573
return self .reason
571
574
575
+ def __post_init__ (self ):
576
+ not_nones = [a is not None for a in [self .ref , self .cmp ]]
577
+ if all (not_nones ):
578
+ assert id (self .ref ) != id (self .cmp )
579
+ else :
580
+ assert not any (not_nones )
581
+
572
582
573
- def hierarchy_compare (this , that ):
583
+ def hierarchy_compare (this , that , atol = 1e-16 ):
574
584
if not isinstance (this , PatchHierarchy ) or not isinstance (that , PatchHierarchy ):
575
585
return EqualityReport (False , "class type mismatch" )
576
586
@@ -596,24 +606,26 @@ def hierarchy_compare(this, that):
596
606
patch_cmp = patch_level_cmp .patches [patch_idx ]
597
607
598
608
if patch_ref .patch_datas .keys () != patch_cmp .patch_datas .keys ():
599
- print (list (patch_ref .patch_datas .keys ()))
600
- print (list (patch_cmp .patch_datas .keys ()))
601
609
return EqualityReport (False , "data keys mismatch" )
602
610
603
611
for patch_data_key in patch_ref .patch_datas .keys ():
604
612
patch_data_ref = patch_ref .patch_datas [patch_data_key ]
605
613
patch_data_cmp = patch_cmp .patch_datas [patch_data_key ]
606
614
607
- if patch_data_cmp != patch_data_ref :
608
- msg = f"data mismatch: { patch_data_key } { type (patch_data_cmp ).__name__ } { type (patch_data_ref ).__name__ } "
609
- return EqualityReport (False , msg )
615
+ if not patch_data_cmp .compare (patch_data_ref , atol = atol ):
616
+ msg = f"data mismatch: { type (patch_data_ref ).__name__ } { patch_data_key } "
617
+ return EqualityReport (
618
+ False , msg , patch_data_cmp , patch_data_ref
619
+ )
610
620
611
621
return EqualityReport (True , "OK" )
612
622
613
623
614
- def single_patch_for_LO (hier , qties = None ):
624
+ def single_patch_for_LO (hier , qties = None , skip = None ):
615
625
def _skip (qty ):
616
- return qties is not None and qty not in qties
626
+ return (qties is not None and qty not in qties ) or (
627
+ skip is not None and qty in skip
628
+ )
617
629
618
630
cier = deepcopy (hier )
619
631
sim = hier .sim
@@ -633,22 +645,22 @@ def _skip(qty):
633
645
layout , v .field_name , None , centering = v .centerings
634
646
)
635
647
l0_pds [k ].dataset = np .zeros (l0_pds [k ].size )
648
+ patch_box = hier .level (0 , t ).patches [0 ].box
649
+ l0_pds [k ][patch_box ] = v [patch_box ]
636
650
637
651
elif isinstance (v , ParticleData ):
638
652
l0_pds [k ] = deepcopy (v )
639
653
else :
640
654
raise RuntimeError ("unexpected state" )
641
655
642
- for patch in hier .level (0 , t ).patches :
656
+ for patch in hier .level (0 , t ).patches [ 1 :] :
643
657
for k , v in patch .patch_datas .items ():
644
658
if _skip (k ):
645
659
continue
646
660
if isinstance (v , FieldData ):
647
661
l0_pds [k ][patch .box ] = v [patch .box ]
648
662
elif isinstance (v , ParticleData ):
649
- l0_pds [k ].dataset = mparticles .aggregate (
650
- [l0_pds [k ].dataset , v .dataset ]
651
- )
663
+ l0_pds [k ].dataset .add (v .dataset )
652
664
else :
653
665
raise RuntimeError ("unexpected state" )
654
666
return cier
0 commit comments