Skip to content

hier comparae #910

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions pyphare/pyphare/core/phare_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,18 @@ def is_fp32(item):
return isinstance(item, float)


def assert_fp_any_all_close(a, b, atol=1e-16, rtol=0, atol_fp32=None):
def any_fp_tol(a, b, atol=1e-16, rtol=0, atol_fp32=None):
if any([is_fp32(el) for el in [a, b]]):
atol = atol_fp32 if atol_fp32 else atol * 1e8
np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)
return dict(atol=atol, rtol=rtol)


def fp_any_all_close(a, b, atol=1e-16, rtol=0, atol_fp32=None):
return np.allclose(a, b, **any_fp_tol(a, b, atol, rtol, atol_fp32))


def assert_fp_any_all_close(a, b, atol=1e-16, rtol=0, atol_fp32=None):
np.testing.assert_allclose(a, b, **any_fp_tol(a, b, atol, rtol, atol_fp32))


def decode_bytes(input, errors="ignore"):
Expand Down
110 changes: 0 additions & 110 deletions pyphare/pyphare/pharein/examples/job.py

This file was deleted.

2 changes: 2 additions & 0 deletions pyphare/pyphare/pharein/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ def get_user_inputs(jobname):
_init_.PHARE_EXE = True
print(jobname)
jobmodule = importlib.import_module(jobname) # lgtm [py/unused-local-variable]
if jobmodule is None:
raise RuntimeError("failed to import job")
populateDict()
6 changes: 3 additions & 3 deletions pyphare/pyphare/pharein/load_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class LoadBalancer:

def __post_init__(self):
if self.auto and self.every:
raise RuntimeError(f"LoadBalancer cannot work with both 'every' and 'auto'")
raise RuntimeError("LoadBalancer cannot work with both 'every' and 'auto'")

if self.every is None:
self.auto = True
Expand All @@ -50,8 +50,8 @@ def __post_init__(self):
if self._register:
if not gv.sim:
raise RuntimeError(
f"LoadBalancer cannot be registered as no simulation exists"
"LoadBalancer cannot be registered as no simulation exists"
)
if gv.sim.load_balancer:
raise RuntimeError(f"LoadBalancer is already registered to simulation")
raise RuntimeError("LoadBalancer is already registered to simulation")
gv.sim.load_balancer = self
3 changes: 1 addition & 2 deletions pyphare/pyphare/pharesee/hierarchy/fromh5.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
particle_files_patterns = ("domain", "patchGhost", "levelGhost")


def get_all_available_quantities_from_h5(filepath, time=0, exclude=["tags"]):
def get_all_available_quantities_from_h5(filepath, time=0, exclude=["tags"], hier=None):
time = format_timestamp(time)
hier = None
path = Path(filepath)
for h5 in path.glob("*.h5"):
if h5.parent == path and h5.stem not in exclude:
Expand Down
13 changes: 9 additions & 4 deletions pyphare/pyphare/pharesee/hierarchy/hierarchy.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import numpy as np
import matplotlib.pyplot as plt

from .patch import Patch
from .patchlevel import PatchLevel
from ...core.box import Box
from ...core import box as boxm
from ...core.phare_utilities import refinement_ratio
from ...core.phare_utilities import listify

import numpy as np
import matplotlib.pyplot as plt
from ...core.phare_utilities import deep_copy
from ...core.phare_utilities import refinement_ratio


def format_timestamp(timestamp):
Expand Down Expand Up @@ -68,6 +69,10 @@ def __init__(

self.update()

def __deepcopy__(self, memo):
no_copy_keys = ["data_files"] # do not copy these things
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deepcopy crashes without this

return deep_copy(self, memo, no_copy_keys)

def __getitem__(self, qty):
return self.__dict__[qty]

Expand Down
130 changes: 100 additions & 30 deletions pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from .hierarchy import PatchHierarchy
from .patchdata import FieldData
from dataclasses import dataclass, field
from copy import deepcopy
import numpy as np

from typing import Any, List, Tuple

from .hierarchy import PatchHierarchy, format_timestamp
from .patchdata import FieldData, ParticleData
from .patchlevel import PatchLevel
from .patch import Patch
from ...core.box import Box
from ...core.gridlayout import GridLayout
from ...core.phare_utilities import listify
from ...core.phare_utilities import refinement_ratio
from pyphare.core import phare_utilities as phut

import numpy as np

field_qties = {
"EM_B_x": "Bx",
Expand Down Expand Up @@ -298,7 +306,7 @@
return is_overlaped


def flat_finest_field(hierarchy, qty, time=None):
def flat_finest_field(hierarchy, qty, time=None, neghosts=1):
"""
returns 2 flattened arrays containing the data (with shape [Npoints])
and the coordinates (with shape [Npoints, Ndim]) for the given
Expand All @@ -311,7 +319,7 @@
dim = hierarchy.ndim

if dim == 1:
return flat_finest_field_1d(hierarchy, qty, time)
return flat_finest_field_1d(hierarchy, qty, time, neghosts)
elif dim == 2:
return flat_finest_field_2d(hierarchy, qty, time)
elif dim == 3:
Expand All @@ -321,7 +329,7 @@
raise ValueError("the dim of a hierarchy should be 1, 2 or 3")


def flat_finest_field_1d(hierarchy, qty, time=None):
def flat_finest_field_1d(hierarchy, qty, time=None, neghosts=1):
lvl = hierarchy.levels(time)

for ilvl in range(hierarchy.finest_level(time) + 1)[::-1]:
Expand All @@ -333,7 +341,7 @@
# all but 1 ghost nodes are removed in order to limit
# the overlapping, but to keep enough point to avoid
# any extrapolation for the interpolator
needed_points = pdata.ghosts_nbr - 1
needed_points = pdata.ghosts_nbr - neghosts

# data = pdata.dataset[patch.box] # TODO : once PR 551 will be merged...
data = pdata.dataset[needed_points[0] : -needed_points[0]]
Expand Down Expand Up @@ -552,34 +560,55 @@
return tuple(pd_attrs)


from dataclasses import dataclass


@dataclass
class EqualityReport:
ok: bool
reason: str
failed: List[Tuple[str, Any, Any]] = field(default_factory=lambda: [])

def __bool__(self):
return self.ok
return not self.failed

def __repr__(self):
for msg, ref, cmp in self:
print(msg)
try:
if type(ref) is FieldData:
phut.assert_fp_any_all_close(ref[:], cmp[:], atol=1e-16)
except AssertionError as e:
print(e)
return self.failed[0][0]

def __call__(self, reason, ref=None, cmp=None):
self.failed.append((reason, ref, cmp))
return self

def __getitem__(self, idx):
return (self.failed[idx][1], self.failed[idx][2])

def __iter__(self):
return self.failed.__iter__()

def __reversed__(self):
return reversed(self.failed)
Comment on lines +565 to +591
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Correct the __repr__ method in the EqualityReport class.

The __repr__ method currently prints messages and returns one of them, which is not standard behavior for __repr__. The purpose of __repr__ is to return a string that would recreate the object when evaluated (or at least give a detailed representation). Additionally, printing within __repr__ is not recommended.

Apply this diff to correct the __repr__ method:

-    def __repr__(self):
-        for msg, ref, cmp in self:
-            print(msg)
-            try:
-                if type(ref) is FieldData:
-                    phut.assert_fp_any_all_close(ref[:], cmp[:], atol=1e-16)
-            except AssertionError as e:
-                print(e)
-        return self.failed[0][0]
+    def __repr__(self):
+        return f"EqualityReport(failed={self.failed})"

This change ensures that __repr__ returns a string representation of the EqualityReport object without side effects.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
failed: List[Tuple[str, Any, Any]] = field(default_factory=lambda: [])
def __bool__(self):
return self.ok
return not self.failed
def __repr__(self):
for msg, ref, cmp in self:
print(msg)
try:
if type(ref) is FieldData:
phut.assert_fp_any_all_close(ref[:], cmp[:], atol=1e-16)
except AssertionError as e:
print(e)
return self.failed[0][0]
def __call__(self, reason, ref=None, cmp=None):
self.failed.append((reason, ref, cmp))
return self
def __getitem__(self, idx):
return (self.failed[idx][1], self.failed[idx][2])
def __iter__(self):
return self.failed.__iter__()
def __reversed__(self):
return reversed(self.failed)
failed: List[Tuple[str, Any, Any]] = field(default_factory=lambda: [])
def __bool__(self):
return not self.failed
def __repr__(self):
return f"EqualityReport(failed={self.failed})"
def __call__(self, reason, ref=None, cmp=None):
self.failed.append((reason, ref, cmp))
return self
def __getitem__(self, idx):
return (self.failed[idx][1], self.failed[idx][2])
def __iter__(self):
return self.failed.__iter__()
def __reversed__(self):
return reversed(self.failed)



def hierarchy_compare(this, that):
def hierarchy_compare(this, that, atol=1e-16):
eqr = EqualityReport()

if not isinstance(this, PatchHierarchy) or not isinstance(that, PatchHierarchy):
return EqualityReport(False, "class type mismatch")
return eqr("class type mismatch")

if this.ndim != that.ndim or this.domain_box != that.domain_box:
return EqualityReport(False, "dimensional mismatch")
return eqr("dimensional mismatch")

if this.time_hier.keys() != that.time_hier.keys():
return EqualityReport(False, "timesteps mismatch")
return eqr("timesteps mismatch")

for tidx in this.times():
patch_levels_ref = this.time_hier[tidx]
patch_levels_cmp = that.time_hier[tidx]

if patch_levels_ref.keys() != patch_levels_cmp.keys():
return EqualityReport(False, "levels mismatch")
return eqr("levels mismatch")

for level_idx in patch_levels_cmp.keys():
patch_level_ref = patch_levels_ref[level_idx]
Expand All @@ -590,21 +619,62 @@
patch_cmp = patch_level_cmp.patches[patch_idx]

if patch_ref.patch_datas.keys() != patch_cmp.patch_datas.keys():
print(list(patch_ref.patch_datas.keys()))
print(list(patch_cmp.patch_datas.keys()))
return EqualityReport(False, "data keys mismatch")
return eqr("data keys mismatch")

for patch_data_key in patch_ref.patch_datas.keys():
patch_data_ref = patch_ref.patch_datas[patch_data_key]
patch_data_cmp = patch_cmp.patch_datas[patch_data_key]

if patch_data_cmp != patch_data_ref:
return EqualityReport(
False,
"data mismatch: "
+ type(patch_data_cmp).__name__
+ " "
+ type(patch_data_ref).__name__,
)
if not patch_data_cmp.compare(patch_data_ref, atol=atol):
msg = f"data mismatch: {type(patch_data_ref).__name__} {patch_data_key}"
eqr(msg, patch_data_cmp, patch_data_ref)

if not eqr:
return eqr

return eqr


def single_patch_for_LO(hier, qties=None, skip=None):
def _skip(qty):
return (qties is not None and qty not in qties) or (
skip is not None and qty in skip
)

return EqualityReport(True, "OK")
cier = deepcopy(hier)
sim = hier.sim
layout = GridLayout(
Box(sim.origin, sim.cells), sim.origin, sim.dl, interp_order=sim.interp_order
)
p0 = Patch(patch_datas={}, patch_id="", layout=layout)
for t in cier.times():
cier.time_hier[format_timestamp(t)] = {0: cier.level(0, t)}
cier.level(0, t).patches = [deepcopy(p0)]
l0_pds = cier.level(0, t).patches[0].patch_datas
for k, v in hier.level(0, t).patches[0].patch_datas.items():
if _skip(k):
continue
if isinstance(v, FieldData):
l0_pds[k] = FieldData(
layout, v.field_name, None, centering=v.centerings
)
l0_pds[k].dataset = np.zeros(l0_pds[k].size)
patch_box = hier.level(0, t).patches[0].box
l0_pds[k][patch_box] = v[patch_box]

elif isinstance(v, ParticleData):
l0_pds[k] = deepcopy(v)
else:
raise RuntimeError("unexpected state")

for patch in hier.level(0, t).patches[1:]:
for k, v in patch.patch_datas.items():
if _skip(k):
continue
if isinstance(v, FieldData):
l0_pds[k][patch.box] = v[patch.box]
elif isinstance(v, ParticleData):
l0_pds[k].dataset.add(v.dataset)
else:
raise RuntimeError("unexpected state")
return cier
Loading
Loading