1
- from .hierarchy import PatchHierarchy
2
- from .patchdata import FieldData
1
+ from dataclasses import dataclass , field
2
+ from copy import deepcopy
3
+ import numpy as np
4
+
5
+ from typing import Any , List , Tuple
6
+
7
+ from .hierarchy import PatchHierarchy , format_timestamp
8
+ from .patchdata import FieldData , ParticleData
3
9
from .patchlevel import PatchLevel
4
10
from .patch import Patch
11
+ from ...core .box import Box
12
+ from ...core .gridlayout import GridLayout
5
13
from ...core .phare_utilities import listify
6
14
from ...core .phare_utilities import refinement_ratio
15
+ from pyphare .core import phare_utilities as phut
7
16
8
- import numpy as np
9
17
10
18
field_qties = {
11
19
"EM_B_x" : "Bx" ,
@@ -298,7 +306,7 @@ def overlap_mask_2d(x, y, dl, level, qty):
298
306
return is_overlaped
299
307
300
308
301
- def flat_finest_field (hierarchy , qty , time = None ):
309
+ def flat_finest_field (hierarchy , qty , time = None , neghosts = 1 ):
302
310
"""
303
311
returns 2 flattened arrays containing the data (with shape [Npoints])
304
312
and the coordinates (with shape [Npoints, Ndim]) for the given
@@ -311,7 +319,7 @@ def flat_finest_field(hierarchy, qty, time=None):
311
319
dim = hierarchy .ndim
312
320
313
321
if dim == 1 :
314
- return flat_finest_field_1d (hierarchy , qty , time )
322
+ return flat_finest_field_1d (hierarchy , qty , time , neghosts )
315
323
elif dim == 2 :
316
324
return flat_finest_field_2d (hierarchy , qty , time )
317
325
elif dim == 3 :
@@ -321,7 +329,7 @@ def flat_finest_field(hierarchy, qty, time=None):
321
329
raise ValueError ("the dim of a hierarchy should be 1, 2 or 3" )
322
330
323
331
324
- def flat_finest_field_1d (hierarchy , qty , time = None ):
332
+ def flat_finest_field_1d (hierarchy , qty , time = None , neghosts = 1 ):
325
333
lvl = hierarchy .levels (time )
326
334
327
335
for ilvl in range (hierarchy .finest_level (time ) + 1 )[::- 1 ]:
@@ -333,7 +341,7 @@ def flat_finest_field_1d(hierarchy, qty, time=None):
333
341
# all but 1 ghost nodes are removed in order to limit
334
342
# the overlapping, but to keep enough point to avoid
335
343
# any extrapolation for the interpolator
336
- needed_points = pdata .ghosts_nbr - 1
344
+ needed_points = pdata .ghosts_nbr - neghosts
337
345
338
346
# data = pdata.dataset[patch.box] # TODO : once PR 551 will be merged...
339
347
data = pdata .dataset [needed_points [0 ] : - needed_points [0 ]]
@@ -552,34 +560,55 @@ def _compute_scalardiv(patch_datas, **kwargs):
552
560
return tuple (pd_attrs )
553
561
554
562
555
- from dataclasses import dataclass
556
-
557
-
558
563
@dataclass
559
564
class EqualityReport :
560
- ok : bool
561
- reason : str
565
+ failed : List [Tuple [str , Any , Any ]] = field (default_factory = lambda : [])
562
566
563
567
def __bool__ (self ):
564
- return self .ok
568
+ return not self .failed
569
+
570
+ def __repr__ (self ):
571
+ for msg , ref , cmp in self :
572
+ print (msg )
573
+ try :
574
+ if type (ref ) is FieldData :
575
+ phut .assert_fp_any_all_close (ref [:], cmp [:], atol = 1e-16 )
576
+ except AssertionError as e :
577
+ print (e )
578
+ return self .failed [0 ][0 ]
579
+
580
+ def __call__ (self , reason , ref = None , cmp = None ):
581
+ self .failed .append ((reason , ref , cmp ))
582
+ return self
583
+
584
+ def __getitem__ (self , idx ):
585
+ return (self .failed [idx ][1 ], self .failed [idx ][2 ])
586
+
587
+ def __iter__ (self ):
588
+ return self .failed .__iter__ ()
589
+
590
+ def __reversed__ (self ):
591
+ return reversed (self .failed )
565
592
566
593
567
- def hierarchy_compare (this , that ):
594
+ def hierarchy_compare (this , that , atol = 1e-16 ):
595
+ eqr = EqualityReport ()
596
+
568
597
if not isinstance (this , PatchHierarchy ) or not isinstance (that , PatchHierarchy ):
569
- return EqualityReport ( False , "class type mismatch" )
598
+ return eqr ( "class type mismatch" )
570
599
571
600
if this .ndim != that .ndim or this .domain_box != that .domain_box :
572
- return EqualityReport ( False , "dimensional mismatch" )
601
+ return eqr ( "dimensional mismatch" )
573
602
574
603
if this .time_hier .keys () != that .time_hier .keys ():
575
- return EqualityReport ( False , "timesteps mismatch" )
604
+ return eqr ( "timesteps mismatch" )
576
605
577
606
for tidx in this .times ():
578
607
patch_levels_ref = this .time_hier [tidx ]
579
608
patch_levels_cmp = that .time_hier [tidx ]
580
609
581
610
if patch_levels_ref .keys () != patch_levels_cmp .keys ():
582
- return EqualityReport ( False , "levels mismatch" )
611
+ return eqr ( "levels mismatch" )
583
612
584
613
for level_idx in patch_levels_cmp .keys ():
585
614
patch_level_ref = patch_levels_ref [level_idx ]
@@ -590,21 +619,62 @@ def hierarchy_compare(this, that):
590
619
patch_cmp = patch_level_cmp .patches [patch_idx ]
591
620
592
621
if patch_ref .patch_datas .keys () != patch_cmp .patch_datas .keys ():
593
- print (list (patch_ref .patch_datas .keys ()))
594
- print (list (patch_cmp .patch_datas .keys ()))
595
- return EqualityReport (False , "data keys mismatch" )
622
+ return eqr ("data keys mismatch" )
596
623
597
624
for patch_data_key in patch_ref .patch_datas .keys ():
598
625
patch_data_ref = patch_ref .patch_datas [patch_data_key ]
599
626
patch_data_cmp = patch_cmp .patch_datas [patch_data_key ]
600
627
601
- if patch_data_cmp != patch_data_ref :
602
- return EqualityReport (
603
- False ,
604
- "data mismatch: "
605
- + type (patch_data_cmp ).__name__
606
- + " "
607
- + type (patch_data_ref ).__name__ ,
608
- )
628
+ if not patch_data_cmp .compare (patch_data_ref , atol = atol ):
629
+ msg = f"data mismatch: { type (patch_data_ref ).__name__ } { patch_data_key } "
630
+ eqr (msg , patch_data_cmp , patch_data_ref )
631
+
632
+ if not eqr :
633
+ return eqr
634
+
635
+ return eqr
636
+
637
+
638
+ def single_patch_for_LO (hier , qties = None , skip = None ):
639
+ def _skip (qty ):
640
+ return (qties is not None and qty not in qties ) or (
641
+ skip is not None and qty in skip
642
+ )
609
643
610
- return EqualityReport (True , "OK" )
644
+ cier = deepcopy (hier )
645
+ sim = hier .sim
646
+ layout = GridLayout (
647
+ Box (sim .origin , sim .cells ), sim .origin , sim .dl , interp_order = sim .interp_order
648
+ )
649
+ p0 = Patch (patch_datas = {}, patch_id = "" , layout = layout )
650
+ for t in cier .times ():
651
+ cier .time_hier [format_timestamp (t )] = {0 : cier .level (0 , t )}
652
+ cier .level (0 , t ).patches = [deepcopy (p0 )]
653
+ l0_pds = cier .level (0 , t ).patches [0 ].patch_datas
654
+ for k , v in hier .level (0 , t ).patches [0 ].patch_datas .items ():
655
+ if _skip (k ):
656
+ continue
657
+ if isinstance (v , FieldData ):
658
+ l0_pds [k ] = FieldData (
659
+ layout , v .field_name , None , centering = v .centerings
660
+ )
661
+ l0_pds [k ].dataset = np .zeros (l0_pds [k ].size )
662
+ patch_box = hier .level (0 , t ).patches [0 ].box
663
+ l0_pds [k ][patch_box ] = v [patch_box ]
664
+
665
+ elif isinstance (v , ParticleData ):
666
+ l0_pds [k ] = deepcopy (v )
667
+ else :
668
+ raise RuntimeError ("unexpected state" )
669
+
670
+ for patch in hier .level (0 , t ).patches [1 :]:
671
+ for k , v in patch .patch_datas .items ():
672
+ if _skip (k ):
673
+ continue
674
+ if isinstance (v , FieldData ):
675
+ l0_pds [k ][patch .box ] = v [patch .box ]
676
+ elif isinstance (v , ParticleData ):
677
+ l0_pds [k ].dataset .add (v .dataset )
678
+ else :
679
+ raise RuntimeError ("unexpected state" )
680
+ return cier
0 commit comments