Skip to content

Commit f7b7592

Browse files
committed
~~
1 parent 0693be6 commit f7b7592

File tree

5 files changed

+99
-56
lines changed

5 files changed

+99
-56
lines changed

pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py

+37-23
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, field
22
from copy import deepcopy
33
import numpy as np
44

5-
from typing import Any
5+
from typing import Any, List, Tuple
66

77
from .hierarchy import PatchHierarchy, format_timestamp
88
from .patchdata import FieldData, ParticleData
@@ -12,6 +12,7 @@
1212
from ...core.gridlayout import GridLayout
1313
from ...core.phare_utilities import listify
1414
from ...core.phare_utilities import refinement_ratio
15+
from pyphare.core import phare_utilities as phut
1516

1617

1718
field_qties = {
@@ -561,41 +562,53 @@ def _compute_scalardiv(patch_datas, **kwargs):
561562

562563
@dataclass
563564
class EqualityReport:
564-
ok: bool
565-
reason: str
566-
ref: Any = None
567-
cmp: Any = None
565+
failed: List[Tuple[str, Any, Any]] = field(default_factory=lambda: [])
568566

569567
def __bool__(self):
570-
return self.ok
568+
return not self.failed
571569

572570
def __repr__(self):
573-
return self.reason
571+
for msg, ref, cmp in self:
572+
print(msg)
573+
try:
574+
if type(ref) is FieldData:
575+
phut.assert_fp_any_all_close(ref[:], cmp[:], atol=1e-16)
576+
except AssertionError as e:
577+
print(e)
578+
return self.failed[0][0]
574579

575-
def __post_init__(self):
576-
not_nones = [a is not None for a in [self.ref, self.cmp]]
577-
if all(not_nones):
578-
assert id(self.ref) != id(self.cmp)
579-
else:
580-
assert not any(not_nones)
580+
def __call__(self, reason, ref=None, cmp=None):
581+
self.failed.append((reason, ref, cmp))
582+
return self
583+
584+
def __getitem__(self, idx):
585+
return (self.failed[idx][1], self.failed[idx][2])
586+
587+
def __iter__(self):
588+
return self.failed.__iter__()
589+
590+
def __reversed__(self):
591+
return reversed(self.failed)
581592

582593

583594
def hierarchy_compare(this, that, atol=1e-16):
595+
eqr = EqualityReport()
596+
584597
if not isinstance(this, PatchHierarchy) or not isinstance(that, PatchHierarchy):
585-
return EqualityReport(False, "class type mismatch")
598+
return eqr("class type mismatch")
586599

587600
if this.ndim != that.ndim or this.domain_box != that.domain_box:
588-
return EqualityReport(False, "dimensional mismatch")
601+
return eqr("dimensional mismatch")
589602

590603
if this.time_hier.keys() != that.time_hier.keys():
591-
return EqualityReport(False, "timesteps mismatch")
604+
return eqr("timesteps mismatch")
592605

593606
for tidx in this.times():
594607
patch_levels_ref = this.time_hier[tidx]
595608
patch_levels_cmp = that.time_hier[tidx]
596609

597610
if patch_levels_ref.keys() != patch_levels_cmp.keys():
598-
return EqualityReport(False, "levels mismatch")
611+
return eqr("levels mismatch")
599612

600613
for level_idx in patch_levels_cmp.keys():
601614
patch_level_ref = patch_levels_ref[level_idx]
@@ -606,19 +619,20 @@ def hierarchy_compare(this, that, atol=1e-16):
606619
patch_cmp = patch_level_cmp.patches[patch_idx]
607620

608621
if patch_ref.patch_datas.keys() != patch_cmp.patch_datas.keys():
609-
return EqualityReport(False, "data keys mismatch")
622+
return eqr("data keys mismatch")
610623

611624
for patch_data_key in patch_ref.patch_datas.keys():
612625
patch_data_ref = patch_ref.patch_datas[patch_data_key]
613626
patch_data_cmp = patch_cmp.patch_datas[patch_data_key]
614627

615628
if not patch_data_cmp.compare(patch_data_ref, atol=atol):
616629
msg = f"data mismatch: {type(patch_data_ref).__name__} {patch_data_key}"
617-
return EqualityReport(
618-
False, msg, patch_data_cmp, patch_data_ref
619-
)
630+
eqr(msg, patch_data_cmp, patch_data_ref)
631+
632+
if not eqr:
633+
return eqr
620634

621-
return EqualityReport(True, "OK")
635+
return eqr
622636

623637

624638
def single_patch_for_LO(hier, qties=None, skip=None):

src/amr/data/field/initializers/samrai_hdf5_field_initializer.hpp

+9-7
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ template<typename Field_t, typename GridLayout>
3131
void SamraiHDF5FieldInitializer<Field_t, GridLayout>::load(Field_t& field,
3232
GridLayout const& layout) const
3333
{
34-
auto const local_cell
35-
= [&](auto const& box, auto const& point) { return layout.AMRToLocal(point, box); };
36-
3734
auto const& dest_box = layout.AMRBox();
3835
auto const& centering = layout.centering(field.physicalQuantity());
3936
auto const& overlaps = SamraiH5Interface<GridLayout>::INSTANCE().box_intersections(dest_box);
@@ -44,17 +41,22 @@ void SamraiHDF5FieldInitializer<Field_t, GridLayout>::load(Field_t& field,
4441
auto const src_box = pdata.box;
4542
auto const data = h5File.template read_data_set_flat<double>(
4643
pdata.base_path + "/" + field.name() + "##default/field_" + field.name());
47-
core::Box<std::uint32_t, GridLayout::dimension> const lcl_src_box{
44+
core::Box<std::uint32_t, GridLayout::dimension> const lcl_src_gbox{
4845
core::Point{core::ConstArray<std::uint32_t, GridLayout::dimension>()},
4946
core::Point{
5047
core::for_N<GridLayout::dimension, core::for_N_R_mode::make_array>([&](auto i) {
5148
return static_cast<std::uint32_t>(
5249
src_box.upper[i] - src_box.lower[i] + (GridLayout::nbrGhosts() * 2)
5350
+ (centering[i] == core::QtyCentering::primal ? 1 : 0));
5451
})}};
55-
auto data_view = core::make_array_view(data.data(), *lcl_src_box.shape());
56-
for (auto const& point : overlap_box)
57-
field(local_cell(dest_box, point)) = data_view(local_cell(src_box, point));
52+
auto const data_view = core::make_array_view(data.data(), *lcl_src_gbox.shape());
53+
auto const overlap_gb = grow(overlap_box, GridLayout::nbrGhosts());
54+
auto const lcl_src_box = layout.AMRToLocal(overlap_gb, src_box);
55+
auto const lcl_dst_box = layout.AMRToLocal(overlap_gb, dest_box);
56+
auto src_it = lcl_src_box.begin();
57+
auto dst_it = lcl_dst_box.begin();
58+
for (; src_it != lcl_src_box.end(); ++src_it, ++dst_it)
59+
field(*dst_it) = data_view(*src_it);
5860
}
5961
}
6062

src/amr/level_initializer/hybrid_level_initializer.hpp

+8-6
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
#include "amr/level_initializer/level_initializer.hpp"
55
#include "amr/messengers/hybrid_messenger.hpp"
66
#include "amr/messengers/messenger.hpp"
7-
#include "amr/physical_models/hybrid_model.hpp"
87
#include "amr/physical_models/physical_model.hpp"
98
#include "amr/resources_manager/amr_utils.hpp"
109
#include "core/data/grid/gridlayout_utils.hpp"
11-
#include "core/data/ions/ions.hpp"
1210
#include "core/numerics/ampere/ampere.hpp"
1311
#include "core/numerics/interpolator/interpolator.hpp"
1412
#include "core/numerics/moments/moments.hpp"
@@ -43,10 +41,12 @@ namespace solver
4341
: ohm_{dict["algo"]["ohm"]}
4442
{
4543
}
46-
virtual void initialize(std::shared_ptr<hierarchy_t> const& hierarchy, int levelNumber,
47-
std::shared_ptr<level_t> const& oldLevel, IPhysicalModelT& model,
48-
amr::IMessenger<IPhysicalModelT>& messenger, double initDataTime,
49-
bool isRegridding) override
44+
45+
46+
void initialize(std::shared_ptr<hierarchy_t> const& hierarchy, int levelNumber,
47+
std::shared_ptr<level_t> const& oldLevel, IPhysicalModelT& model,
48+
amr::IMessenger<IPhysicalModelT>& messenger, double initDataTime,
49+
bool isRegridding) override
5050
{
5151
core::Interpolator<dimension, interp_order> interpolate_;
5252
auto& hybridModel = static_cast<HybridModel&>(model);
@@ -163,6 +163,8 @@ namespace solver
163163
hybMessenger.prepareStep(hybridModel, level, initDataTime);
164164
}
165165
};
166+
167+
166168
} // namespace solver
167169
} // namespace PHARE
168170

src/core/data/grid/gridlayout.hpp

+39-7
Original file line numberDiff line numberDiff line change
@@ -832,17 +832,19 @@ namespace core
832832
* This method only deals with **cell** indexes.
833833
*/
834834
template<typename T>
835-
NO_DISCARD auto AMRToLocal(Box<T, dimension> const& AMRBox) const
835+
NO_DISCARD auto AMRToLocal(Box<T, dimension> const& AMRBox,
836+
Box<int, dimension> const& localbox) const
836837
{
837838
static_assert(std::is_integral_v<T>, "Error, must be MeshIndex (integral Point)");
838-
auto localBox = Box<std::uint32_t, dimension>{};
839-
840-
localBox.lower = AMRToLocal(AMRBox.lower);
841-
localBox.upper = AMRToLocal(AMRBox.upper);
842-
843-
return localBox;
839+
return Box<std::uint32_t, dimension>{AMRToLocal(AMRBox.lower, localbox),
840+
AMRToLocal(AMRBox.upper, localbox)};
844841
}
845842

843+
template<typename T>
844+
NO_DISCARD auto AMRToLocal(Box<T, dimension> const& AMRBox) const
845+
{
846+
return AMRToLocal(AMRBox, AMRBox_);
847+
}
846848

847849

848850
template<typename Field, std::size_t nbr_points>
@@ -1171,6 +1173,22 @@ namespace core
11711173
evalOnBox_(field, fn, indices);
11721174
}
11731175

1176+
template<typename Field>
1177+
auto domainBoxFor(Field const& field) const
1178+
{
1179+
return _BoxFor(field, [&](auto const& centering, auto const direction) {
1180+
return this->physicalStartToEnd(centering, direction);
1181+
});
1182+
}
1183+
1184+
template<typename Field>
1185+
auto ghostBoxFor(Field const& field) const
1186+
{
1187+
return _BoxFor(field, [&](auto const& centering, auto const direction) {
1188+
return this->ghostStartToEnd(centering, direction);
1189+
});
1190+
}
1191+
11741192

11751193
private:
11761194
template<typename Field, typename IndicesFn, typename Fn>
@@ -1206,6 +1224,20 @@ namespace core
12061224
}
12071225

12081226

1227+
template<typename Field, typename Fn>
1228+
auto _BoxFor(Field const& field, Fn startToEnd) const
1229+
{
1230+
constexpr auto directions = std::array{Direction::X, Direction::Y, Direction::Z};
1231+
std::array<std::uint32_t, dimension> lower, upper;
1232+
core::for_N<dimension>([&](auto i) {
1233+
auto const [i0, i1] = startToEnd(field, directions[i]);
1234+
lower[i] = i0;
1235+
upper[i] = i1;
1236+
});
1237+
return Box<std::uint32_t, dimension>{lower, upper};
1238+
}
1239+
1240+
12091241
template<typename Centering, typename StartToEnd>
12101242
auto StartToEndIndices_(Centering const& centering, StartToEnd const&& startToEnd,
12111243
bool const includeEnd = false) const

tests/simulator/test_init_from_restart.py

+6-13
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import pyphare.pharein as ph
77

8-
from pyphare.core import phare_utilities as phut
8+
99
from pyphare.simulator.simulator import Simulator
1010
from pyphare.pharesee.hierarchy.patchdata import FieldData, ParticleData
1111
from pyphare.pharesee.hierarchy.fromh5 import get_all_available_quantities_from_h5
@@ -24,11 +24,9 @@
2424
cells = 200
2525
first_out = "phare_outputs/reinit/first"
2626
secnd_out = "phare_outputs/reinit/secnd"
27-
# timestamps = [0,time_step]
2827
timestamps = np.arange(0, final_time + time_step, time_step)
2928
restart_idx = Z = 2
3029
simInitArgs = dict(
31-
largest_patch_size=100,
3230
time_step_nbr=time_step_nbr,
3331
time_step=time_step,
3432
cells=cells,
@@ -41,7 +39,7 @@
4139
def setup_model(sim):
4240
model = ph.MaxwellianFluidModel(
4341
protons={"mass": 1, "charge": 1, "nbr_part_per_cell": ppc},
44-
alpha={"mass": 4.0, "charge": 1, "nbr_part_per_cell": ppc},
42+
alpha={"mass": 4, "charge": 1, "nbr_part_per_cell": ppc},
4543
)
4644
ph.ElectronModel(closure="isothermal", Te=0.12)
4745
dump_all_diags(model.populations, timestamps=timestamps)
@@ -65,23 +63,18 @@ def test_reinit(self):
6563
sim = ph.Simulation(**copy.deepcopy(simInitArgs))
6664
setup_model(sim)
6765
Simulator(sim).run().reset()
68-
fidx, sidx = 2, 0
66+
fidx, sidx = 4, 2
6967
datahier0 = get_all_available_quantities_from_h5(first_out, timestamps[fidx])
7068
datahier0.time_hier = { # swap times
7169
format_timestamp(timestamps[sidx]): datahier0.time_hier[
7270
format_timestamp(timestamps[fidx])
7371
]
7472
}
7573
datahier1 = get_all_available_quantities_from_h5(secnd_out, timestamps[sidx])
76-
qties = ["protons_domain", "alpha_domain", "Bx", "By", "Bz"]
77-
skip = None # ["protons_patchGhost", "alpha_patchGhost"]
74+
qties = None
75+
skip = ["protons_patchGhost", "alpha_patchGhost"]
7876
ds = [single_patch_for_LO(d, qties, skip) for d in [datahier0, datahier1]]
79-
eq = hierarchy_compare(*ds, atol=1e-14)
80-
if not eq:
81-
print(eq)
82-
if type(eq.ref) == FieldData:
83-
phut.assert_fp_any_all_close(eq.ref[:], eq.cmp[:], atol=1e-16)
84-
self.assertTrue(eq)
77+
self.assertTrue(hierarchy_compare(*ds, atol=1e-12))
8578

8679

8780
def run_first_sim():

0 commit comments

Comments
 (0)