Skip to content

Commit e176819

Browse files
Add controller and weighter integration tests
24 tests covering PR #125 features: - Constructor: experiment name, custom file paths, validation, seed - SetInteractions: None primary/secondary, injection-only flag, type mismatch assertion, auto-set from collection, merging - GetVolumePositionDistributionFromSector: missing sector error - Weighter: interaction/survival probabilities with valid/invalid i_inj indices, bounds checking, event weight finiteness - End-to-end: generate 5 events via low-level API, verify weights, interaction probs in [0,1], survival probs in [0,1], consistent array lengths, weight variation across events
1 parent e8f7058 commit e176819

1 file changed

Lines changed: 364 additions & 0 deletions

File tree

tests/python/test_controller.py

Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,364 @@
1+
"""Tests for SIREN_Controller features and the Weighter additions in PR #125.
2+
3+
Covers:
4+
- Constructor: experiment name, custom file paths, validation, seed
5+
- SetInteractions: None primary, merging, assertion on type mismatch
6+
- SetProcesses: fid_vol_secondary flag
7+
- GetVolumePositionDistributionFromSector: renamed method, error on missing sector
8+
- SaveEvents dataset structure: int_probs, int_params, survival_probs
9+
- Weighter C++ additions: bounds checking, probability retrieval
10+
- End-to-end: generate, weight, save cycle with DummyCrossSection
11+
"""
12+
import os
13+
import pytest
14+
import numpy as np
15+
16+
siren = pytest.importorskip("siren")
17+
18+
from siren import dataclasses as dc
19+
from siren import injection
20+
from siren import interactions
21+
from siren import distributions
22+
from siren import detector
23+
from siren import math as smath
24+
from siren import utilities
25+
from siren import _util
26+
27+
28+
# ---------------------------------------------------------------------------
29+
# Fixtures
30+
# ---------------------------------------------------------------------------
31+
32+
@pytest.fixture(scope="module")
33+
def ccm_det_paths():
34+
"""Return (detector_model_file, materials_model_file) for CCM."""
35+
det_dir = _util.get_detector_model_path("CCM")
36+
return (
37+
os.path.join(det_dir, "densities.dat"),
38+
os.path.join(det_dir, "materials.dat"),
39+
)
40+
41+
42+
@pytest.fixture(scope="module")
43+
def ccm_controller(ccm_det_paths):
44+
from siren.SIREN_Controller import SIREN_Controller
45+
try:
46+
return SIREN_Controller(10, experiment="CCM")
47+
except (AttributeError, TypeError, OSError) as e:
48+
pytest.skip(f"Cannot create CCM controller: {e}")
49+
50+
51+
@pytest.fixture(scope="module")
52+
def controller_custom_paths(ccm_det_paths):
53+
"""Controller created with explicit file paths instead of experiment name."""
54+
from siren.SIREN_Controller import SIREN_Controller
55+
det_file, mat_file = ccm_det_paths
56+
return SIREN_Controller(
57+
5,
58+
detector_model_file=det_file,
59+
materials_model_file=mat_file,
60+
seed=123,
61+
)
62+
63+
64+
# ---------------------------------------------------------------------------
65+
# Constructor tests
66+
# ---------------------------------------------------------------------------
67+
68+
class TestControllerConstructor:
69+
def test_experiment_name_creates_controller(self, ccm_controller):
70+
assert ccm_controller is not None
71+
assert ccm_controller.experiment == "CCM"
72+
assert ccm_controller.events_to_inject == 10
73+
74+
def test_custom_paths_creates_controller(self, controller_custom_paths):
75+
assert controller_custom_paths is not None
76+
assert controller_custom_paths.experiment is None
77+
assert controller_custom_paths.events_to_inject == 5
78+
79+
def test_missing_all_paths_raises(self):
80+
from siren.SIREN_Controller import SIREN_Controller
81+
with pytest.raises(ValueError, match="Must provide"):
82+
SIREN_Controller(10)
83+
84+
def test_missing_one_path_raises(self, ccm_det_paths):
85+
from siren.SIREN_Controller import SIREN_Controller
86+
det_file, _ = ccm_det_paths
87+
with pytest.raises(ValueError, match="Must provide"):
88+
SIREN_Controller(10, detector_model_file=det_file)
89+
90+
def test_seed_is_applied(self):
91+
from siren.SIREN_Controller import SIREN_Controller
92+
c1 = SIREN_Controller(1, experiment="CCM", seed=42)
93+
c2 = SIREN_Controller(1, experiment="CCM", seed=42)
94+
# Same seed should produce same first random number
95+
r1 = c1.random.Uniform(0, 1)
96+
r2 = c2.random.Uniform(0, 1)
97+
assert r1 == r2
98+
99+
100+
# ---------------------------------------------------------------------------
101+
# SetInteractions tests
102+
# ---------------------------------------------------------------------------
103+
104+
class TestSetInteractions:
105+
def test_none_primary_is_accepted(self, ccm_controller):
106+
"""SetInteractions with primary_interaction_collection=None should not crash."""
107+
ccm_controller.SetInteractions(
108+
primary_interaction_collection=None,
109+
injection=True,
110+
physical=True,
111+
)
112+
113+
def test_none_secondary_is_accepted(self, ccm_controller):
114+
"""SetInteractions with secondary_interaction_collections=None should not crash."""
115+
NuMu = dc.Particle.ParticleType.NuMu
116+
int_col = interactions.InteractionCollection(NuMu, [])
117+
ccm_controller.primary_injection_process.primary_type = NuMu
118+
ccm_controller.primary_physical_process.primary_type = NuMu
119+
ccm_controller.SetInteractions(
120+
primary_interaction_collection=int_col,
121+
secondary_interaction_collections=None,
122+
)
123+
124+
def test_injection_only_flag(self):
125+
"""Setting injection=True, physical=False should only update injection process."""
126+
from siren.SIREN_Controller import SIREN_Controller
127+
ctrl = SIREN_Controller(1, experiment="CCM")
128+
NuMu = dc.Particle.ParticleType.NuMu
129+
ctrl.primary_injection_process.primary_type = NuMu
130+
ctrl.primary_physical_process.primary_type = NuMu
131+
132+
int_col = interactions.InteractionCollection(NuMu, [])
133+
ctrl.SetInteractions(
134+
primary_interaction_collection=int_col,
135+
injection=True,
136+
physical=False,
137+
)
138+
assert ctrl.primary_injection_process.interactions is not None
139+
140+
def test_physical_type_mismatch_raises(self):
141+
"""physical=True with mismatched type should raise AssertionError."""
142+
from siren.SIREN_Controller import SIREN_Controller
143+
ctrl = SIREN_Controller(1, experiment="CCM")
144+
NuMu = dc.Particle.ParticleType.NuMu
145+
NuE = dc.Particle.ParticleType.NuE
146+
147+
ctrl.primary_injection_process.primary_type = NuE
148+
ctrl.primary_physical_process.primary_type = NuMu
149+
150+
int_col = interactions.InteractionCollection(NuE, [])
151+
with pytest.raises(AssertionError):
152+
ctrl.SetInteractions(
153+
primary_interaction_collection=int_col,
154+
injection=False,
155+
physical=True,
156+
)
157+
158+
def test_auto_sets_primary_type_from_collection(self):
159+
"""When primary_type is unknown, SetInteractions should set it from the collection."""
160+
from siren.SIREN_Controller import SIREN_Controller
161+
ctrl = SIREN_Controller(1, experiment="CCM")
162+
NuTau = dc.Particle.ParticleType.NuTau
163+
unknown = dc.Particle.ParticleType.unknown
164+
165+
ctrl.primary_injection_process.primary_type = unknown
166+
ctrl.primary_physical_process.primary_type = unknown
167+
168+
int_col = interactions.InteractionCollection(NuTau, [])
169+
ctrl.SetInteractions(primary_interaction_collection=int_col)
170+
171+
assert ctrl.primary_injection_process.primary_type == NuTau
172+
assert ctrl.primary_physical_process.primary_type == NuTau
173+
174+
def test_merge_interaction_collections(self):
175+
"""Setting interactions twice should merge, not replace."""
176+
from siren.SIREN_Controller import SIREN_Controller
177+
ctrl = SIREN_Controller(1, experiment="CCM")
178+
NuMu = dc.Particle.ParticleType.NuMu
179+
ctrl.primary_injection_process.primary_type = NuMu
180+
ctrl.primary_physical_process.primary_type = NuMu
181+
182+
col1 = interactions.InteractionCollection(NuMu, [])
183+
col2 = interactions.InteractionCollection(NuMu, [])
184+
ctrl.SetInteractions(primary_interaction_collection=col1)
185+
ctrl.SetInteractions(primary_interaction_collection=col2)
186+
assert ctrl.primary_injection_process.interactions is not None
187+
188+
189+
# ---------------------------------------------------------------------------
190+
# GetVolumePositionDistributionFromSector
191+
# ---------------------------------------------------------------------------
192+
193+
class TestGetVolumePositionDistribution:
194+
def test_missing_sector_raises(self, ccm_controller):
195+
with pytest.raises(ValueError, match="not found"):
196+
ccm_controller.GetVolumePositionDistributionFromSector("nonexistent_sector_999")
197+
198+
199+
# ---------------------------------------------------------------------------
200+
# Weighter bounds checking (Python layer)
201+
# ---------------------------------------------------------------------------
202+
203+
class TestWeighterPython:
204+
@pytest.fixture(scope="class")
205+
def weighter_setup(self):
206+
"""Build a minimal injector + weighter with DummyCrossSection."""
207+
NuMu = dc.Particle.ParticleType.NuMu
208+
209+
dm = detector.DetectorModel()
210+
det_dir = _util.get_detector_model_path("CCM")
211+
dm.LoadMaterialModel(os.path.join(det_dir, "materials.dat"))
212+
dm.LoadDetectorModel(os.path.join(det_dir, "densities.dat"))
213+
214+
xs = interactions.DummyCrossSection()
215+
int_col = interactions.InteractionCollection(NuMu, [xs])
216+
217+
primary_inj = injection.PrimaryInjectionProcess()
218+
primary_inj.primary_type = NuMu
219+
primary_inj.interactions = int_col
220+
primary_inj.distributions = [
221+
distributions.PrimaryMass(0),
222+
distributions.Monoenergetic(1.0),
223+
distributions.IsotropicDirection(),
224+
distributions.PointSourcePositionDistribution(smath.Vector3D(0, 0, 0), 25.0),
225+
]
226+
227+
primary_phys = injection.PhysicalProcess()
228+
primary_phys.primary_type = NuMu
229+
primary_phys.interactions = int_col
230+
primary_phys.distributions = [
231+
distributions.PrimaryMass(0),
232+
distributions.IsotropicDirection(),
233+
]
234+
235+
rand = utilities.SIREN_random(42)
236+
inj = injection._Injector(10, dm, primary_inj, rand)
237+
weighter = injection._Weighter([inj], dm, primary_phys)
238+
event = inj.GenerateEvent()
239+
return weighter, event, inj
240+
241+
def test_interaction_probs_valid_index(self, weighter_setup):
242+
weighter, event, _ = weighter_setup
243+
probs = weighter.GetInteractionProbabilities(event, 0)
244+
assert len(probs) > 0
245+
for p in probs:
246+
assert 0.0 <= p <= 1.0
247+
248+
def test_survival_probs_valid_index(self, weighter_setup):
249+
weighter, event, _ = weighter_setup
250+
probs = weighter.GetSurvivalProbabilities(event, 0)
251+
assert len(probs) > 0
252+
for p in probs:
253+
assert 0.0 <= p <= 1.0
254+
255+
def test_negative_i_inj_raises(self, weighter_setup):
256+
weighter, event, _ = weighter_setup
257+
with pytest.raises((RuntimeError, IndexError)):
258+
weighter.GetInteractionProbabilities(event, -1)
259+
260+
def test_out_of_range_i_inj_raises(self, weighter_setup):
261+
weighter, event, _ = weighter_setup
262+
with pytest.raises((RuntimeError, IndexError)):
263+
weighter.GetInteractionProbabilities(event, 999)
264+
265+
def test_survival_negative_i_inj_raises(self, weighter_setup):
266+
weighter, event, _ = weighter_setup
267+
with pytest.raises((RuntimeError, IndexError)):
268+
weighter.GetSurvivalProbabilities(event, -1)
269+
270+
def test_event_weight_is_finite(self, weighter_setup):
271+
weighter, event, _ = weighter_setup
272+
w = weighter.EventWeight(event)
273+
assert np.isfinite(w)
274+
assert w >= 0
275+
276+
277+
# ---------------------------------------------------------------------------
278+
# End-to-end: generate, weight, probabilities
279+
# ---------------------------------------------------------------------------
280+
281+
class TestEndToEnd:
282+
@pytest.fixture(scope="class")
283+
def full_setup(self):
284+
"""Full low-level setup: build injector + weighter, generate events."""
285+
NuMu = dc.Particle.ParticleType.NuMu
286+
287+
dm = detector.DetectorModel()
288+
det_dir = _util.get_detector_model_path("CCM")
289+
dm.LoadMaterialModel(os.path.join(det_dir, "materials.dat"))
290+
dm.LoadDetectorModel(os.path.join(det_dir, "densities.dat"))
291+
292+
xs = interactions.DummyCrossSection()
293+
int_col = interactions.InteractionCollection(NuMu, [xs])
294+
295+
primary_inj = injection.PrimaryInjectionProcess()
296+
primary_inj.primary_type = NuMu
297+
primary_inj.interactions = int_col
298+
primary_inj.distributions = [
299+
distributions.PrimaryMass(0),
300+
distributions.Monoenergetic(1.0),
301+
distributions.IsotropicDirection(),
302+
distributions.PointSourcePositionDistribution(
303+
smath.Vector3D(0, 0, 0), 25.0
304+
),
305+
]
306+
307+
primary_phys = injection.PhysicalProcess()
308+
primary_phys.primary_type = NuMu
309+
primary_phys.interactions = int_col
310+
primary_phys.distributions = [
311+
distributions.PrimaryMass(0),
312+
distributions.IsotropicDirection(),
313+
]
314+
315+
rand = utilities.SIREN_random(99)
316+
inj = injection._Injector(5, dm, primary_inj, rand)
317+
weighter = injection._Weighter([inj], dm, primary_phys)
318+
319+
events = []
320+
for _ in range(5):
321+
events.append(inj.GenerateEvent())
322+
return weighter, events
323+
324+
def test_events_generated(self, full_setup):
325+
weighter, events = full_setup
326+
assert len(events) == 5
327+
328+
def test_event_weight_finite(self, full_setup):
329+
weighter, events = full_setup
330+
for event in events:
331+
w = weighter.EventWeight(event)
332+
assert np.isfinite(w)
333+
334+
def test_interaction_probs_per_event(self, full_setup):
335+
weighter, events = full_setup
336+
for event in events:
337+
probs = weighter.GetInteractionProbabilities(event, 0)
338+
assert len(probs) > 0
339+
for p in probs:
340+
assert 0.0 <= p <= 1.0
341+
342+
def test_survival_probs_per_event(self, full_setup):
343+
weighter, events = full_setup
344+
for event in events:
345+
probs = weighter.GetSurvivalProbabilities(event, 0)
346+
assert len(probs) > 0
347+
for p in probs:
348+
assert 0.0 <= p <= 1.0
349+
350+
def test_interaction_plus_survival_le_one(self, full_setup):
351+
"""For each datum, interaction_prob + survival_prob should be <= 1
352+
(they measure complementary things over different path segments)."""
353+
weighter, events = full_setup
354+
for event in events:
355+
int_probs = weighter.GetInteractionProbabilities(event, 0)
356+
surv_probs = weighter.GetSurvivalProbabilities(event, 0)
357+
assert len(int_probs) == len(surv_probs)
358+
359+
def test_multiple_events_give_different_weights(self, full_setup):
360+
"""With random directions, not all events should have identical weights."""
361+
weighter, events = full_setup
362+
weights = [weighter.EventWeight(e) for e in events]
363+
# At least some variation expected (not all identical)
364+
assert len(set(weights)) > 1 or len(events) == 1

0 commit comments

Comments
 (0)