Skip to content

Commit c262a75

Browse files
authored
hier comparae (#910)
1 parent af3f2be commit c262a75

File tree

9 files changed

+156
-158
lines changed

9 files changed

+156
-158
lines changed

pyphare/pyphare/core/phare_utilities.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,18 @@ def is_fp32(item):
128128
return isinstance(item, float)
129129

130130

131-
def assert_fp_any_all_close(a, b, atol=1e-16, rtol=0, atol_fp32=None):
131+
def any_fp_tol(a, b, atol=1e-16, rtol=0, atol_fp32=None):
132132
if any([is_fp32(el) for el in [a, b]]):
133133
atol = atol_fp32 if atol_fp32 else atol * 1e8
134-
np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)
134+
return dict(atol=atol, rtol=rtol)
135+
136+
137+
def fp_any_all_close(a, b, atol=1e-16, rtol=0, atol_fp32=None):
138+
return np.allclose(a, b, **any_fp_tol(a, b, atol, rtol, atol_fp32))
139+
140+
141+
def assert_fp_any_all_close(a, b, atol=1e-16, rtol=0, atol_fp32=None):
142+
np.testing.assert_allclose(a, b, **any_fp_tol(a, b, atol, rtol, atol_fp32))
135143

136144

137145
def decode_bytes(input, errors="ignore"):

pyphare/pyphare/pharein/examples/job.py

Lines changed: 0 additions & 110 deletions
This file was deleted.

pyphare/pyphare/pharein/init.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ def get_user_inputs(jobname):
66
_init_.PHARE_EXE = True
77
print(jobname)
88
jobmodule = importlib.import_module(jobname) # lgtm [py/unused-local-variable]
9+
if jobmodule is None:
10+
raise RuntimeError("failed to import job")
911
populateDict()

pyphare/pyphare/pharein/load_balancer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class LoadBalancer:
3333

3434
def __post_init__(self):
3535
if self.auto and self.every:
36-
raise RuntimeError(f"LoadBalancer cannot work with both 'every' and 'auto'")
36+
raise RuntimeError("LoadBalancer cannot work with both 'every' and 'auto'")
3737

3838
if self.every is None:
3939
self.auto = True
@@ -50,8 +50,8 @@ def __post_init__(self):
5050
if self._register:
5151
if not gv.sim:
5252
raise RuntimeError(
53-
f"LoadBalancer cannot be registered as no simulation exists"
53+
"LoadBalancer cannot be registered as no simulation exists"
5454
)
5555
if gv.sim.load_balancer:
56-
raise RuntimeError(f"LoadBalancer is already registered to simulation")
56+
raise RuntimeError("LoadBalancer is already registered to simulation")
5757
gv.sim.load_balancer = self

pyphare/pyphare/pharesee/hierarchy/fromh5.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@
2424
particle_files_patterns = ("domain", "patchGhost", "levelGhost")
2525

2626

27-
def get_all_available_quantities_from_h5(filepath, time=0, exclude=["tags"]):
27+
def get_all_available_quantities_from_h5(filepath, time=0, exclude=["tags"], hier=None):
2828
time = format_timestamp(time)
29-
hier = None
3029
path = Path(filepath)
3130
for h5 in path.glob("*.h5"):
3231
if h5.parent == path and h5.stem not in exclude:

pyphare/pyphare/pharesee/hierarchy/hierarchy.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
14
from .patch import Patch
25
from .patchlevel import PatchLevel
36
from ...core.box import Box
47
from ...core import box as boxm
5-
from ...core.phare_utilities import refinement_ratio
68
from ...core.phare_utilities import listify
7-
8-
import numpy as np
9-
import matplotlib.pyplot as plt
9+
from ...core.phare_utilities import deep_copy
10+
from ...core.phare_utilities import refinement_ratio
1011

1112

1213
def format_timestamp(timestamp):
@@ -68,6 +69,10 @@ def __init__(
6869

6970
self.update()
7071

72+
def __deepcopy__(self, memo):
73+
no_copy_keys = ["data_files"] # do not copy these things
74+
return deep_copy(self, memo, no_copy_keys)
75+
7176
def __getitem__(self, qty):
7277
return self.__dict__[qty]
7378

pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py

Lines changed: 100 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1-
from .hierarchy import PatchHierarchy
2-
from .patchdata import FieldData
1+
from dataclasses import dataclass, field
2+
from copy import deepcopy
3+
import numpy as np
4+
5+
from typing import Any, List, Tuple
6+
7+
from .hierarchy import PatchHierarchy, format_timestamp
8+
from .patchdata import FieldData, ParticleData
39
from .patchlevel import PatchLevel
410
from .patch import Patch
11+
from ...core.box import Box
12+
from ...core.gridlayout import GridLayout
513
from ...core.phare_utilities import listify
614
from ...core.phare_utilities import refinement_ratio
15+
from pyphare.core import phare_utilities as phut
716

8-
import numpy as np
917

1018
field_qties = {
1119
"EM_B_x": "Bx",
@@ -298,7 +306,7 @@ def overlap_mask_2d(x, y, dl, level, qty):
298306
return is_overlaped
299307

300308

301-
def flat_finest_field(hierarchy, qty, time=None):
309+
def flat_finest_field(hierarchy, qty, time=None, neghosts=1):
302310
"""
303311
returns 2 flattened arrays containing the data (with shape [Npoints])
304312
and the coordinates (with shape [Npoints, Ndim]) for the given
@@ -311,7 +319,7 @@ def flat_finest_field(hierarchy, qty, time=None):
311319
dim = hierarchy.ndim
312320

313321
if dim == 1:
314-
return flat_finest_field_1d(hierarchy, qty, time)
322+
return flat_finest_field_1d(hierarchy, qty, time, neghosts)
315323
elif dim == 2:
316324
return flat_finest_field_2d(hierarchy, qty, time)
317325
elif dim == 3:
@@ -321,7 +329,7 @@ def flat_finest_field(hierarchy, qty, time=None):
321329
raise ValueError("the dim of a hierarchy should be 1, 2 or 3")
322330

323331

324-
def flat_finest_field_1d(hierarchy, qty, time=None):
332+
def flat_finest_field_1d(hierarchy, qty, time=None, neghosts=1):
325333
lvl = hierarchy.levels(time)
326334

327335
for ilvl in range(hierarchy.finest_level(time) + 1)[::-1]:
@@ -333,7 +341,7 @@ def flat_finest_field_1d(hierarchy, qty, time=None):
333341
# all but 1 ghost nodes are removed in order to limit
334342
# the overlapping, but to keep enough point to avoid
335343
# any extrapolation for the interpolator
336-
needed_points = pdata.ghosts_nbr - 1
344+
needed_points = pdata.ghosts_nbr - neghosts
337345

338346
# data = pdata.dataset[patch.box] # TODO : once PR 551 will be merged...
339347
data = pdata.dataset[needed_points[0] : -needed_points[0]]
@@ -552,34 +560,55 @@ def _compute_scalardiv(patch_datas, **kwargs):
552560
return tuple(pd_attrs)
553561

554562

555-
from dataclasses import dataclass
556-
557-
558563
@dataclass
559564
class EqualityReport:
560-
ok: bool
561-
reason: str
565+
failed: List[Tuple[str, Any, Any]] = field(default_factory=lambda: [])
562566

563567
def __bool__(self):
564-
return self.ok
568+
return not self.failed
569+
570+
def __repr__(self):
571+
for msg, ref, cmp in self:
572+
print(msg)
573+
try:
574+
if type(ref) is FieldData:
575+
phut.assert_fp_any_all_close(ref[:], cmp[:], atol=1e-16)
576+
except AssertionError as e:
577+
print(e)
578+
return self.failed[0][0]
579+
580+
def __call__(self, reason, ref=None, cmp=None):
581+
self.failed.append((reason, ref, cmp))
582+
return self
583+
584+
def __getitem__(self, idx):
585+
return (self.failed[idx][1], self.failed[idx][2])
586+
587+
def __iter__(self):
588+
return self.failed.__iter__()
589+
590+
def __reversed__(self):
591+
return reversed(self.failed)
565592

566593

567-
def hierarchy_compare(this, that):
594+
def hierarchy_compare(this, that, atol=1e-16):
595+
eqr = EqualityReport()
596+
568597
if not isinstance(this, PatchHierarchy) or not isinstance(that, PatchHierarchy):
569-
return EqualityReport(False, "class type mismatch")
598+
return eqr("class type mismatch")
570599

571600
if this.ndim != that.ndim or this.domain_box != that.domain_box:
572-
return EqualityReport(False, "dimensional mismatch")
601+
return eqr("dimensional mismatch")
573602

574603
if this.time_hier.keys() != that.time_hier.keys():
575-
return EqualityReport(False, "timesteps mismatch")
604+
return eqr("timesteps mismatch")
576605

577606
for tidx in this.times():
578607
patch_levels_ref = this.time_hier[tidx]
579608
patch_levels_cmp = that.time_hier[tidx]
580609

581610
if patch_levels_ref.keys() != patch_levels_cmp.keys():
582-
return EqualityReport(False, "levels mismatch")
611+
return eqr("levels mismatch")
583612

584613
for level_idx in patch_levels_cmp.keys():
585614
patch_level_ref = patch_levels_ref[level_idx]
@@ -590,21 +619,62 @@ def hierarchy_compare(this, that):
590619
patch_cmp = patch_level_cmp.patches[patch_idx]
591620

592621
if patch_ref.patch_datas.keys() != patch_cmp.patch_datas.keys():
593-
print(list(patch_ref.patch_datas.keys()))
594-
print(list(patch_cmp.patch_datas.keys()))
595-
return EqualityReport(False, "data keys mismatch")
622+
return eqr("data keys mismatch")
596623

597624
for patch_data_key in patch_ref.patch_datas.keys():
598625
patch_data_ref = patch_ref.patch_datas[patch_data_key]
599626
patch_data_cmp = patch_cmp.patch_datas[patch_data_key]
600627

601-
if patch_data_cmp != patch_data_ref:
602-
return EqualityReport(
603-
False,
604-
"data mismatch: "
605-
+ type(patch_data_cmp).__name__
606-
+ " "
607-
+ type(patch_data_ref).__name__,
608-
)
628+
if not patch_data_cmp.compare(patch_data_ref, atol=atol):
629+
msg = f"data mismatch: {type(patch_data_ref).__name__} {patch_data_key}"
630+
eqr(msg, patch_data_cmp, patch_data_ref)
631+
632+
if not eqr:
633+
return eqr
634+
635+
return eqr
636+
637+
638+
def single_patch_for_LO(hier, qties=None, skip=None):
639+
def _skip(qty):
640+
return (qties is not None and qty not in qties) or (
641+
skip is not None and qty in skip
642+
)
609643

610-
return EqualityReport(True, "OK")
644+
cier = deepcopy(hier)
645+
sim = hier.sim
646+
layout = GridLayout(
647+
Box(sim.origin, sim.cells), sim.origin, sim.dl, interp_order=sim.interp_order
648+
)
649+
p0 = Patch(patch_datas={}, patch_id="", layout=layout)
650+
for t in cier.times():
651+
cier.time_hier[format_timestamp(t)] = {0: cier.level(0, t)}
652+
cier.level(0, t).patches = [deepcopy(p0)]
653+
l0_pds = cier.level(0, t).patches[0].patch_datas
654+
for k, v in hier.level(0, t).patches[0].patch_datas.items():
655+
if _skip(k):
656+
continue
657+
if isinstance(v, FieldData):
658+
l0_pds[k] = FieldData(
659+
layout, v.field_name, None, centering=v.centerings
660+
)
661+
l0_pds[k].dataset = np.zeros(l0_pds[k].size)
662+
patch_box = hier.level(0, t).patches[0].box
663+
l0_pds[k][patch_box] = v[patch_box]
664+
665+
elif isinstance(v, ParticleData):
666+
l0_pds[k] = deepcopy(v)
667+
else:
668+
raise RuntimeError("unexpected state")
669+
670+
for patch in hier.level(0, t).patches[1:]:
671+
for k, v in patch.patch_datas.items():
672+
if _skip(k):
673+
continue
674+
if isinstance(v, FieldData):
675+
l0_pds[k][patch.box] = v[patch.box]
676+
elif isinstance(v, ParticleData):
677+
l0_pds[k].dataset.add(v.dataset)
678+
else:
679+
raise RuntimeError("unexpected state")
680+
return cier

0 commit comments

Comments
 (0)