Skip to content

Commit 0693be6

Browse files
committed
++
1 parent e62e80d commit 0693be6

File tree

7 files changed

+85
-41
lines changed

7 files changed

+85
-41
lines changed

pyphare/pyphare/pharesee/hierarchy/fromh5.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@
2424
particle_files_patterns = ("domain", "patchGhost", "levelGhost")
2525

2626

27-
def get_all_available_quantities_from_h5(filepath, time=0, exclude=["tags"]):
27+
def get_all_available_quantities_from_h5(filepath, time=0, exclude=["tags"], hier=None):
2828
time = format_timestamp(time)
29-
hier = None
3029
path = Path(filepath)
3130
for h5 in path.glob("*.h5"):
3231
if h5.parent == path and h5.stem not in exclude:

pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py

+25-13
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from copy import deepcopy
33
import numpy as np
44

5+
from typing import Any
6+
57
from .hierarchy import PatchHierarchy, format_timestamp
68
from .patchdata import FieldData, ParticleData
79
from .patchlevel import PatchLevel
@@ -10,7 +12,6 @@
1012
from ...core.gridlayout import GridLayout
1113
from ...core.phare_utilities import listify
1214
from ...core.phare_utilities import refinement_ratio
13-
from pyphare.pharesee import particles as mparticles
1415

1516

1617
field_qties = {
@@ -562,15 +563,24 @@ def _compute_scalardiv(patch_datas, **kwargs):
562563
class EqualityReport:
563564
ok: bool
564565
reason: str
566+
ref: Any = None
567+
cmp: Any = None
565568

566569
def __bool__(self):
567570
return self.ok
568571

569572
def __repr__(self):
570573
return self.reason
571574

575+
def __post_init__(self):
576+
not_nones = [a is not None for a in [self.ref, self.cmp]]
577+
if all(not_nones):
578+
assert id(self.ref) != id(self.cmp)
579+
else:
580+
assert not any(not_nones)
581+
572582

573-
def hierarchy_compare(this, that):
583+
def hierarchy_compare(this, that, atol=1e-16):
574584
if not isinstance(this, PatchHierarchy) or not isinstance(that, PatchHierarchy):
575585
return EqualityReport(False, "class type mismatch")
576586

@@ -596,24 +606,26 @@ def hierarchy_compare(this, that):
596606
patch_cmp = patch_level_cmp.patches[patch_idx]
597607

598608
if patch_ref.patch_datas.keys() != patch_cmp.patch_datas.keys():
599-
print(list(patch_ref.patch_datas.keys()))
600-
print(list(patch_cmp.patch_datas.keys()))
601609
return EqualityReport(False, "data keys mismatch")
602610

603611
for patch_data_key in patch_ref.patch_datas.keys():
604612
patch_data_ref = patch_ref.patch_datas[patch_data_key]
605613
patch_data_cmp = patch_cmp.patch_datas[patch_data_key]
606614

607-
if patch_data_cmp != patch_data_ref:
608-
msg = f"data mismatch: {patch_data_key} {type(patch_data_cmp).__name__} {type(patch_data_ref).__name__}"
609-
return EqualityReport(False, msg)
615+
if not patch_data_cmp.compare(patch_data_ref, atol=atol):
616+
msg = f"data mismatch: {type(patch_data_ref).__name__} {patch_data_key}"
617+
return EqualityReport(
618+
False, msg, patch_data_cmp, patch_data_ref
619+
)
610620

611621
return EqualityReport(True, "OK")
612622

613623

614-
def single_patch_for_LO(hier, qties=None):
624+
def single_patch_for_LO(hier, qties=None, skip=None):
615625
def _skip(qty):
616-
return qties is not None and qty not in qties
626+
return (qties is not None and qty not in qties) or (
627+
skip is not None and qty in skip
628+
)
617629

618630
cier = deepcopy(hier)
619631
sim = hier.sim
@@ -633,22 +645,22 @@ def _skip(qty):
633645
layout, v.field_name, None, centering=v.centerings
634646
)
635647
l0_pds[k].dataset = np.zeros(l0_pds[k].size)
648+
patch_box = hier.level(0, t).patches[0].box
649+
l0_pds[k][patch_box] = v[patch_box]
636650

637651
elif isinstance(v, ParticleData):
638652
l0_pds[k] = deepcopy(v)
639653
else:
640654
raise RuntimeError("unexpected state")
641655

642-
for patch in hier.level(0, t).patches:
656+
for patch in hier.level(0, t).patches[1:]:
643657
for k, v in patch.patch_datas.items():
644658
if _skip(k):
645659
continue
646660
if isinstance(v, FieldData):
647661
l0_pds[k][patch.box] = v[patch.box]
648662
elif isinstance(v, ParticleData):
649-
l0_pds[k].dataset = mparticles.aggregate(
650-
[l0_pds[k].dataset, v.dataset]
651-
)
663+
l0_pds[k].dataset.add(v.dataset)
652664
else:
653665
raise RuntimeError("unexpected state")
654666
return cier

pyphare/pyphare/pharesee/hierarchy/patchdata.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22

3-
from ...core.phare_utilities import deep_copy, fp_any_all_close
3+
from ...core import phare_utilities as phut
4+
5+
# deep_copy, fp_any_all_close, assert_fp_any_all_close
46
from ...core import box as boxm
57
from ...core.box import Box
68

@@ -24,7 +26,7 @@ def __init__(self, layout, quantity):
2426

2527
def __deepcopy__(self, memo):
2628
no_copy_keys = ["dataset"] # do not copy these things
27-
return deep_copy(self, memo, no_copy_keys)
29+
return phut.deep_copy(self, memo, no_copy_keys)
2830

2931

3032
class FieldData(PatchData):
@@ -81,10 +83,12 @@ def __repr__(self):
8183
return self.__str__()
8284

8385
def compare(self, that, atol=1e-16):
84-
return fp_any_all_close(self.dataset[:], that.dataset[:], atol)
86+
return self.field_name == that.field_name and phut.fp_any_all_close(
87+
self.dataset[:], that.dataset[:], atol=atol
88+
)
8589

8690
def __eq__(self, that):
87-
return self.field_name == that.field_name and self.compare(that)
91+
return self.compare(that)
8892

8993
def __ne__(self, that):
9094
return not (self == that)
@@ -228,5 +232,9 @@ def __getitem__(self, box):
228232
def size(self):
229233
return self.dataset.size()
230234

231-
def __eq__(self, that):
235+
def compare(self, that, *args, **kwargs):
236+
"""args/kwargs may include atol for consistency with field::compare"""
232237
return self.name == that.name and self.dataset == that.dataset
238+
239+
def __eq__(self, that):
240+
return self.compare(that)

pyphare/pyphare/pharesee/particles.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ def size(self):
7777

7878
def __eq__(self, that):
7979
if isinstance(that, Particles):
80+
if self.size() != that.size():
81+
print(
82+
f"particles.py:Particles::eq size diff: {self.size()} != {that.size()}"
83+
)
84+
return False
8085
# fails on OSX for some reason
8186
set_check = set(self.as_tuples()) == set(that.as_tuples())
8287
if set_check:
@@ -88,9 +93,12 @@ def __eq__(self, that):
8893
print(f"particles.py:Particles::eq failed with: {ex}")
8994
print_trace()
9095
return False
91-
96+
print(f"particles.py:Particles::eq bad type: {type(that)}")
9297
return False
9398

99+
def __ne__(self, that):
100+
return not (self == that)
101+
94102
def select(self, box, box_type="cell"):
95103
"""
96104
select particles from the given box

src/amr/data/initializers/samrai_hdf5_initializer.hpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,18 @@ void SamraiH5Interface<GridLayout>::populate_from(std::string const& dir, int co
116116
int const& mpi_size,
117117
std::string const& field_name)
118118
{
119+
if (restart_files.size()) // executed per pop, but we only need to run this once
120+
return;
121+
119122
for (int rank = 0; rank < mpi_size; ++rank)
120123
{
121124
auto const hdf5_filepath = getRestartFileFullPath(dir, idx, mpi_size, rank);
122125
auto& h5File = *restart_files.emplace_back(std::make_unique<SamraiHDF5File>(hdf5_filepath));
123126
for (auto const& group : h5File.scan_for_groups({"level_0000", field_name}))
124127
{
125-
auto const em_path = group.substr(0, group.rfind("/"));
126-
h5File.patches.emplace_back(h5File.getBoxFromPath(em_path + "/d_box"),
127-
em_path.substr(0, em_path.rfind("/")));
128+
auto const field_path = group.substr(0, group.rfind("/"));
129+
auto const& field_box = h5File.getBoxFromPath(field_path + "/d_box");
130+
h5File.patches.emplace_back(field_box, field_path.substr(0, field_path.rfind("/")));
128131
}
129132
}
130133
}

src/amr/data/particles/initializers/samrai_hdf5_particle_initializer.hpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,14 @@ void SamraiHDF5ParticleInitializer<ParticleArray, GridLayout>::loadParticles(
4040
{
4141
using Packer = core::ParticlePacker<ParticleArray::dimension>;
4242

43-
auto const& overlaps
44-
= SamraiH5Interface<GridLayout>::INSTANCE().box_intersections(layout.AMRBox());
43+
auto const& dest_box = layout.AMRBox();
44+
auto const& overlaps = SamraiH5Interface<GridLayout>::INSTANCE().box_intersections(dest_box);
45+
4546
for (auto const& [overlap_box, h5FilePtr, pdataptr] : overlaps)
4647
{
47-
auto& h5File = *h5FilePtr;
48-
auto& pdata = *pdataptr;
48+
auto& h5File = *h5FilePtr;
49+
auto& pdata = *pdataptr;
50+
4951
std::string const poppath = pdata.base_path + "/" + popname + "##default/domainParticles_";
5052
core::ContiguousParticles<ParticleArray::dimension> soa{0};
5153

@@ -58,14 +60,12 @@ void SamraiHDF5ParticleInitializer<ParticleArray, GridLayout>::loadParticles(
5860
}
5961

6062
for (std::size_t i = 0; i < soa.size(); ++i)
61-
if (auto const p = soa.copy(i); core::isIn(core::Point{p.iCell}, overlap_box))
63+
if (auto const p = soa.copy(i); core::isIn(core::Point{p.iCell}, dest_box))
6264
particles.push_back(p);
6365
}
6466
}
6567

6668

67-
68-
6969
} // namespace PHARE::amr
7070

7171

tests/simulator/test_init_from_restart.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,35 @@
22
import copy
33
import unittest
44
import subprocess
5+
import numpy as np
56
import pyphare.pharein as ph
67

8+
from pyphare.core import phare_utilities as phut
79
from pyphare.simulator.simulator import Simulator
10+
from pyphare.pharesee.hierarchy.patchdata import FieldData, ParticleData
811
from pyphare.pharesee.hierarchy.fromh5 import get_all_available_quantities_from_h5
12+
from pyphare.pharesee.hierarchy.hierarchy import format_timestamp
913
from pyphare.pharesee.hierarchy.hierarchy_utils import single_patch_for_LO
1014
from pyphare.pharesee.hierarchy.hierarchy_utils import hierarchy_compare
1115
from tests.simulator import SimulatorTest, test_restarts
1216
from tests.diagnostic import dump_all_diags
1317

1418

15-
timestep = 0.001
16-
time_step_nbr = 1
19+
time_step = 0.001
20+
time_step_nbr = 5
21+
final_time = time_step_nbr * time_step
1722
first_mpi_size = 4
1823
ppc = 100
1924
cells = 200
2025
first_out = "phare_outputs/reinit/first"
2126
secnd_out = "phare_outputs/reinit/secnd"
22-
timestamps = [0]
23-
restart_idx = Z = 0
27+
# timestamps = [0,time_step]
28+
timestamps = np.arange(0, final_time + time_step, time_step)
29+
restart_idx = Z = 2
2430
simInitArgs = dict(
2531
largest_patch_size=100,
2632
time_step_nbr=time_step_nbr,
27-
time_step=timestep,
33+
time_step=time_step,
2834
cells=cells,
2935
dl=0.3,
3036
init_options=dict(dir=f"{first_out}/00000.00{Z}00", mpi_size=first_mpi_size),
@@ -59,13 +65,22 @@ def test_reinit(self):
5965
sim = ph.Simulation(**copy.deepcopy(simInitArgs))
6066
setup_model(sim)
6167
Simulator(sim).run().reset()
62-
datahier0 = get_all_available_quantities_from_h5(first_out, timestamps[0])
63-
datahier1 = get_all_available_quantities_from_h5(secnd_out, timestamps[0])
68+
fidx, sidx = 2, 0
69+
datahier0 = get_all_available_quantities_from_h5(first_out, timestamps[fidx])
70+
datahier0.time_hier = { # swap times
71+
format_timestamp(timestamps[sidx]): datahier0.time_hier[
72+
format_timestamp(timestamps[fidx])
73+
]
74+
}
75+
datahier1 = get_all_available_quantities_from_h5(secnd_out, timestamps[sidx])
6476
qties = ["protons_domain", "alpha_domain", "Bx", "By", "Bz"]
65-
ds = [single_patch_for_LO(d, qties) for d in [datahier0, datahier1]]
66-
eq = hierarchy_compare(*ds)
77+
skip = None # ["protons_patchGhost", "alpha_patchGhost"]
78+
ds = [single_patch_for_LO(d, qties, skip) for d in [datahier0, datahier1]]
79+
eq = hierarchy_compare(*ds, atol=1e-14)
6780
if not eq:
6881
print(eq)
82+
if type(eq.ref) == FieldData:
83+
phut.assert_fp_any_all_close(eq.ref[:], eq.cmp[:], atol=1e-16)
6984
self.assertTrue(eq)
7085

7186

@@ -86,7 +101,6 @@ def launch():
86101
cmd = f"mpirun -n {first_mpi_size} python3 -O {__file__} lol"
87102
try:
88103
p = subprocess.run(cmd.split(" "), check=True, capture_output=True)
89-
print(p.stdout, p.stderr)
90104
except subprocess.CalledProcessError as e:
91105
print("CalledProcessError", e)
92106

0 commit comments

Comments
 (0)