Skip to content

Commit 37a3eb8

Browse files
authored
Merge pull request #29 from SWIFTSIM/bugfix_namedcolumn_mask
Bugfix: Copy named column data when deep copying
2 parents 1a66007 + 29aae3e commit 37a3eb8

File tree

4 files changed

+50
-6
lines changed

4 files changed

+50
-6
lines changed

swiftgalaxy/iterator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ def __init__(
239239
)
240240
else:
241241
num_targets = len(self.halo_catalogue._region_centre)
242-
print(num_targets)
243242
# before evaluating optimized solutions:
244243
if num_targets == 0:
245244
# if we have 0 targets short-circuit

swiftgalaxy/reader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1910,7 +1910,9 @@ def _data_copy(
19101910
)
19111911
if data is not None:
19121912
setattr(
1913-
new_named_columns_helper, f"_{named_column}", data[mask]
1913+
new_named_columns_helper._named_column_dataset,
1914+
f"_{named_column}",
1915+
data[mask],
19141916
)
19151917
else:
19161918
data = getattr(

tests/test_copy.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from copy import deepcopy
2+
import unyt as u
3+
from unyt.testing import assert_allclose_units
4+
5+
abstol_m = 1e2 * u.solMass
6+
reltol_m = 1.0e-4
7+
abstol_nd = 1.0e-4
8+
reltol_nd = 1.0e-4
9+
10+
11+
class TestCopySWIFTGalaxy:
12+
def test_deepcopy_sg(self, sg):
13+
"""
14+
Test that datasets get copied on deep copy.
15+
"""
16+
# lazy load a dataset and a named column
17+
sg.gas.masses
18+
sg.gas.hydrogen_ionization_fractions.neutral
19+
sg_copy = deepcopy(sg)
20+
# check private attribute to not trigger lazy loading
21+
assert_allclose_units(
22+
sg.gas.masses,
23+
sg_copy.gas._particle_dataset._masses,
24+
rtol=reltol_m,
25+
atol=abstol_m,
26+
)
27+
assert_allclose_units(
28+
sg.gas.hydrogen_ionization_fractions.neutral,
29+
sg_copy.gas.hydrogen_ionization_fractions._named_column_dataset._neutral,
30+
rtol=reltol_nd,
31+
atol=abstol_nd,
32+
)

tests/test_masking.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
import pytest
22
import numpy as np
3-
import unyt as u
43
from unyt.testing import assert_allclose_units
54
from toysnap import present_particle_types
65
from swiftgalaxy import MaskCollection
76

8-
abstol_c = 1 * u.pc # less than this is ~0
9-
abstol_v = 10 * u.m / u.s # less than this is ~0
10-
abstol_a = 1.0e-4 * u.rad
117
abstol_nd = 1.0e-4
128
reltol_nd = 1.0e-4
139

@@ -61,6 +57,21 @@ def test_bool_mask(self, sg, particle_name, before_load):
6157
ids = getattr(sg, particle_name).particle_ids
6258
assert_allclose_units(ids_before[mask], ids, rtol=0, atol=0)
6359

60+
@pytest.mark.parametrize("before_load", (True, False))
61+
def test_namedcolumn_masked(self, sg, before_load):
62+
"""
63+
Test that named columns get masked too.
64+
"""
65+
neutral_before = sg.gas.hydrogen_ionization_fractions.neutral
66+
mask = np.random.rand(neutral_before.size) > 0.5
67+
if before_load:
68+
sg.gas.hydrogen_ionization_fractions._named_column_dataset._neutral = None
69+
sg.mask_particles(MaskCollection(**{"gas": mask}))
70+
neutral = sg.gas.hydrogen_ionization_fractions.neutral
71+
assert_allclose_units(
72+
neutral_before[mask], neutral, rtol=reltol_nd, atol=abstol_nd
73+
)
74+
6475

6576
class TestMaskingParticleDatasets:
6677
@pytest.mark.parametrize("particle_name", present_particle_types.values())

0 commit comments

Comments
 (0)