Skip to content

Commit a457c66

Browse files
committed
++
1 parent 686239f commit a457c66

File tree

4 files changed

+78
-28
lines changed

4 files changed

+78
-28
lines changed

pyphare/pyphare/pharesee/hierarchy/hierarchy.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
14
from .patch import Patch
25
from .patchlevel import PatchLevel
36
from ...core.box import Box
47
from ...core import box as boxm
5-
from ...core.phare_utilities import refinement_ratio
68
from ...core.phare_utilities import listify
7-
8-
import numpy as np
9-
import matplotlib.pyplot as plt
9+
from ...core.phare_utilities import deep_copy
10+
from ...core.phare_utilities import refinement_ratio
1011

1112

1213
def format_timestamp(timestamp):
@@ -68,6 +69,10 @@ def __init__(
6869

6970
self.update()
7071

72+
def __deepcopy__(self, memo):
73+
no_copy_keys = ["data_files"] # do not copy these things
74+
return deep_copy(self, memo, no_copy_keys)
75+
7176
def __getitem__(self, qty):
7277
return self.__dict__[qty]
7378

pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1-
from .hierarchy import PatchHierarchy
2-
from .patchdata import FieldData
1+
from dataclasses import dataclass
2+
from copy import deepcopy
3+
import numpy as np
4+
5+
from .hierarchy import PatchHierarchy, format_timestamp
6+
from .patchdata import FieldData, ParticleData
37
from .patchlevel import PatchLevel
48
from .patch import Patch
9+
from ...core.box import Box
10+
from ...core.gridlayout import GridLayout
511
from ...core.phare_utilities import listify
612
from ...core.phare_utilities import refinement_ratio
13+
from pyphare.pharesee import particles as mparticles
714

8-
import numpy as np
915

1016
field_qties = {
1117
"EM_B_x": "Bx",
@@ -552,9 +558,6 @@ def _compute_scalardiv(patch_datas, **kwargs):
552558
return tuple(pd_attrs)
553559

554560

555-
from dataclasses import dataclass
556-
557-
558561
@dataclass
559562
class EqualityReport:
560563
ok: bool
@@ -606,3 +609,46 @@ def hierarchy_compare(this, that):
606609
return EqualityReport(False, msg)
607610

608611
return EqualityReport(True, "OK")
612+
613+
614+
def single_patch_for_LO(hier, qties=None):
615+
def _skip(qty):
616+
return qties is not None and qty not in qties
617+
618+
cier = deepcopy(hier)
619+
sim = hier.sim
620+
layout = GridLayout(
621+
Box(sim.origin, sim.cells), sim.origin, sim.dl, interp_order=sim.interp_order
622+
)
623+
p0 = Patch(patch_datas={}, patch_id="", layout=layout)
624+
for t in cier.times():
625+
cier.time_hier[format_timestamp(t)] = {0: cier.level(0, t)}
626+
cier.level(0, t).patches = [deepcopy(p0)]
627+
l0_pds = cier.level(0, t).patches[0].patch_datas
628+
for k, v in hier.level(0, t).patches[0].patch_datas.items():
629+
if _skip(k):
630+
continue
631+
if isinstance(v, FieldData):
632+
l0_pds[k] = FieldData(
633+
layout, v.field_name, None, centering=v.centerings
634+
)
635+
l0_pds[k].dataset = np.zeros(l0_pds[k].size)
636+
637+
elif isinstance(v, ParticleData):
638+
l0_pds[k] = deepcopy(v)
639+
else:
640+
raise RuntimeError("unexpected state")
641+
642+
for patch in hier.level(0, t).patches:
643+
for k, v in patch.patch_datas.items():
644+
if _skip(k):
645+
continue
646+
if isinstance(v, FieldData):
647+
l0_pds[k][patch.box] = v[patch.box]
648+
elif isinstance(v, ParticleData):
649+
l0_pds[k].dataset = mparticles.aggregate(
650+
[l0_pds[k].dataset, v.dataset]
651+
)
652+
else:
653+
raise RuntimeError("unexpected state")
654+
return cier

pyphare/pyphare/pharesee/hierarchy/patchdata.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ def __getitem__(self, box_or_slice):
118118
return self.dataset[box_or_slice]
119119
return self.select(box_or_slice)
120120

121+
def __setitem__(self, box_or_slice, val):
122+
self.__getitem__(box_or_slice)[:] = val
123+
121124
def __init__(self, layout, field_name, data, **kwargs):
122125
"""
123126
:param layout: A GridLayout representing the domain on which data is defined

tests/simulator/test_init_from_restart.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,26 @@
99
from pyphare.core import phare_utilities as phut
1010
from pyphare.pharesee.hierarchy.fromh5 import get_all_available_quantities_from_h5
1111
from pyphare.pharesee.particles import single_patch_per_level_per_pop_from
12-
from pyphare.pharesee.hierarchy.hierarchy_utils import flat_finest_field
12+
from pyphare.pharesee.hierarchy.hierarchy_utils import (
13+
flat_finest_field,
14+
single_patch_for_LO,
15+
hierarchy_compare,
16+
)
1317

1418
from tests.simulator import SimulatorTest, test_restarts
1519
from tests.diagnostic import dump_all_diags
1620

1721
timestep = 0.001
1822
time_step_nbr = 1
19-
first_mpi_size = 1
23+
first_mpi_size = 4
2024
ppc = 100
2125
cells = 200
2226
first_out = "phare_outputs/reinit/first"
2327
secnd_out = "phare_outputs/reinit/secnd"
24-
timestamps = [0] # np.array([timestep * 2, timestep * 4])
28+
timestamps = [0]
2529
restart_idx = Z = 0
2630
simInitArgs = dict(
27-
# largest_patch_size=100,
31+
largest_patch_size=100,
2832
time_step_nbr=time_step_nbr,
2933
time_step=timestep,
3034
cells=cells,
@@ -63,20 +67,12 @@ def test_reinit(self):
6367
Simulator(sim).run().reset()
6468
datahier0 = get_all_available_quantities_from_h5(first_out, timestamps[0])
6569
datahier1 = get_all_available_quantities_from_h5(secnd_out, timestamps[0])
66-
67-
for k in "xyz":
68-
a = flat_finest_field(datahier0, f"B{k}", timestamps[0], 0)
69-
b = flat_finest_field(datahier1, f"B{k}", timestamps[0], 0)
70-
phut.assert_fp_any_all_close(a, b)
71-
72-
def get_merged(hier):
73-
return single_patch_per_level_per_pop_from(hier)
74-
75-
ds = [get_merged(datahier0), get_merged(datahier1)]
76-
for key in ["alpha", "protons"]:
77-
a, b = [d.level(0).patches[0].patch_datas[f"{key}_domain"] for d in ds]
78-
self.assertGreater(a.size(), (cells - 1) * ppc)
79-
self.assertEqual(a, b)
70+
qties = ["protons_domain", "alpha_domain", "Bx", "By", "Bz"]
71+
ds = [single_patch_for_LO(d, qties) for d in [datahier0, datahier1]]
72+
eq = hierarchy_compare(*ds)
73+
if not eq:
74+
print(eq)
75+
self.assertTrue(eq)
8076

8177

8278
def run_first_sim():

0 commit comments

Comments
 (0)