|
| 1 | +"""Tests for SIREN_Controller features and the Weighter additions in PR #125. |
| 2 | +
|
| 3 | +Covers: |
| 4 | + - Constructor: experiment name, custom file paths, validation, seed |
| 5 | + - SetInteractions: None primary, merging, assertion on type mismatch |
| 6 | + - SetProcesses: fid_vol_secondary flag |
| 7 | + - GetVolumePositionDistributionFromSector: renamed method, error on missing sector |
| 8 | + - SaveEvents dataset structure: int_probs, int_params, survival_probs |
| 9 | + - Weighter C++ additions: bounds checking, probability retrieval |
| 10 | + - End-to-end: generate, weight, save cycle with DummyCrossSection |
| 11 | +""" |
| 12 | +import os |
| 13 | +import pytest |
| 14 | +import numpy as np |
| 15 | + |
| 16 | +siren = pytest.importorskip("siren") |
| 17 | + |
| 18 | +from siren import dataclasses as dc |
| 19 | +from siren import injection |
| 20 | +from siren import interactions |
| 21 | +from siren import distributions |
| 22 | +from siren import detector |
| 23 | +from siren import math as smath |
| 24 | +from siren import utilities |
| 25 | +from siren import _util |
| 26 | + |
| 27 | + |
| 28 | +# --------------------------------------------------------------------------- |
| 29 | +# Fixtures |
| 30 | +# --------------------------------------------------------------------------- |
| 31 | + |
| 32 | +@pytest.fixture(scope="module") |
| 33 | +def ccm_det_paths(): |
| 34 | + """Return (detector_model_file, materials_model_file) for CCM.""" |
| 35 | + det_dir = _util.get_detector_model_path("CCM") |
| 36 | + return ( |
| 37 | + os.path.join(det_dir, "densities.dat"), |
| 38 | + os.path.join(det_dir, "materials.dat"), |
| 39 | + ) |
| 40 | + |
| 41 | + |
| 42 | +@pytest.fixture(scope="module") |
| 43 | +def ccm_controller(ccm_det_paths): |
| 44 | + from siren.SIREN_Controller import SIREN_Controller |
| 45 | + try: |
| 46 | + return SIREN_Controller(10, experiment="CCM") |
| 47 | + except (AttributeError, TypeError, OSError) as e: |
| 48 | + pytest.skip(f"Cannot create CCM controller: {e}") |
| 49 | + |
| 50 | + |
| 51 | +@pytest.fixture(scope="module") |
| 52 | +def controller_custom_paths(ccm_det_paths): |
| 53 | + """Controller created with explicit file paths instead of experiment name.""" |
| 54 | + from siren.SIREN_Controller import SIREN_Controller |
| 55 | + det_file, mat_file = ccm_det_paths |
| 56 | + return SIREN_Controller( |
| 57 | + 5, |
| 58 | + detector_model_file=det_file, |
| 59 | + materials_model_file=mat_file, |
| 60 | + seed=123, |
| 61 | + ) |
| 62 | + |
| 63 | + |
| 64 | +# --------------------------------------------------------------------------- |
| 65 | +# Constructor tests |
| 66 | +# --------------------------------------------------------------------------- |
| 67 | + |
| 68 | +class TestControllerConstructor: |
| 69 | + def test_experiment_name_creates_controller(self, ccm_controller): |
| 70 | + assert ccm_controller is not None |
| 71 | + assert ccm_controller.experiment == "CCM" |
| 72 | + assert ccm_controller.events_to_inject == 10 |
| 73 | + |
| 74 | + def test_custom_paths_creates_controller(self, controller_custom_paths): |
| 75 | + assert controller_custom_paths is not None |
| 76 | + assert controller_custom_paths.experiment is None |
| 77 | + assert controller_custom_paths.events_to_inject == 5 |
| 78 | + |
| 79 | + def test_missing_all_paths_raises(self): |
| 80 | + from siren.SIREN_Controller import SIREN_Controller |
| 81 | + with pytest.raises(ValueError, match="Must provide"): |
| 82 | + SIREN_Controller(10) |
| 83 | + |
| 84 | + def test_missing_one_path_raises(self, ccm_det_paths): |
| 85 | + from siren.SIREN_Controller import SIREN_Controller |
| 86 | + det_file, _ = ccm_det_paths |
| 87 | + with pytest.raises(ValueError, match="Must provide"): |
| 88 | + SIREN_Controller(10, detector_model_file=det_file) |
| 89 | + |
| 90 | + def test_seed_is_applied(self): |
| 91 | + from siren.SIREN_Controller import SIREN_Controller |
| 92 | + c1 = SIREN_Controller(1, experiment="CCM", seed=42) |
| 93 | + c2 = SIREN_Controller(1, experiment="CCM", seed=42) |
| 94 | + # Same seed should produce same first random number |
| 95 | + r1 = c1.random.Uniform(0, 1) |
| 96 | + r2 = c2.random.Uniform(0, 1) |
| 97 | + assert r1 == r2 |
| 98 | + |
| 99 | + |
| 100 | +# --------------------------------------------------------------------------- |
| 101 | +# SetInteractions tests |
| 102 | +# --------------------------------------------------------------------------- |
| 103 | + |
| 104 | +class TestSetInteractions: |
| 105 | + def test_none_primary_is_accepted(self, ccm_controller): |
| 106 | + """SetInteractions with primary_interaction_collection=None should not crash.""" |
| 107 | + ccm_controller.SetInteractions( |
| 108 | + primary_interaction_collection=None, |
| 109 | + injection=True, |
| 110 | + physical=True, |
| 111 | + ) |
| 112 | + |
| 113 | + def test_none_secondary_is_accepted(self, ccm_controller): |
| 114 | + """SetInteractions with secondary_interaction_collections=None should not crash.""" |
| 115 | + NuMu = dc.Particle.ParticleType.NuMu |
| 116 | + int_col = interactions.InteractionCollection(NuMu, []) |
| 117 | + ccm_controller.primary_injection_process.primary_type = NuMu |
| 118 | + ccm_controller.primary_physical_process.primary_type = NuMu |
| 119 | + ccm_controller.SetInteractions( |
| 120 | + primary_interaction_collection=int_col, |
| 121 | + secondary_interaction_collections=None, |
| 122 | + ) |
| 123 | + |
| 124 | + def test_injection_only_flag(self): |
| 125 | + """Setting injection=True, physical=False should only update injection process.""" |
| 126 | + from siren.SIREN_Controller import SIREN_Controller |
| 127 | + ctrl = SIREN_Controller(1, experiment="CCM") |
| 128 | + NuMu = dc.Particle.ParticleType.NuMu |
| 129 | + ctrl.primary_injection_process.primary_type = NuMu |
| 130 | + ctrl.primary_physical_process.primary_type = NuMu |
| 131 | + |
| 132 | + int_col = interactions.InteractionCollection(NuMu, []) |
| 133 | + ctrl.SetInteractions( |
| 134 | + primary_interaction_collection=int_col, |
| 135 | + injection=True, |
| 136 | + physical=False, |
| 137 | + ) |
| 138 | + assert ctrl.primary_injection_process.interactions is not None |
| 139 | + |
| 140 | + def test_physical_type_mismatch_raises(self): |
| 141 | + """physical=True with mismatched type should raise AssertionError.""" |
| 142 | + from siren.SIREN_Controller import SIREN_Controller |
| 143 | + ctrl = SIREN_Controller(1, experiment="CCM") |
| 144 | + NuMu = dc.Particle.ParticleType.NuMu |
| 145 | + NuE = dc.Particle.ParticleType.NuE |
| 146 | + |
| 147 | + ctrl.primary_injection_process.primary_type = NuE |
| 148 | + ctrl.primary_physical_process.primary_type = NuMu |
| 149 | + |
| 150 | + int_col = interactions.InteractionCollection(NuE, []) |
| 151 | + with pytest.raises(AssertionError): |
| 152 | + ctrl.SetInteractions( |
| 153 | + primary_interaction_collection=int_col, |
| 154 | + injection=False, |
| 155 | + physical=True, |
| 156 | + ) |
| 157 | + |
| 158 | + def test_auto_sets_primary_type_from_collection(self): |
| 159 | + """When primary_type is unknown, SetInteractions should set it from the collection.""" |
| 160 | + from siren.SIREN_Controller import SIREN_Controller |
| 161 | + ctrl = SIREN_Controller(1, experiment="CCM") |
| 162 | + NuTau = dc.Particle.ParticleType.NuTau |
| 163 | + unknown = dc.Particle.ParticleType.unknown |
| 164 | + |
| 165 | + ctrl.primary_injection_process.primary_type = unknown |
| 166 | + ctrl.primary_physical_process.primary_type = unknown |
| 167 | + |
| 168 | + int_col = interactions.InteractionCollection(NuTau, []) |
| 169 | + ctrl.SetInteractions(primary_interaction_collection=int_col) |
| 170 | + |
| 171 | + assert ctrl.primary_injection_process.primary_type == NuTau |
| 172 | + assert ctrl.primary_physical_process.primary_type == NuTau |
| 173 | + |
| 174 | + def test_merge_interaction_collections(self): |
| 175 | + """Setting interactions twice should merge, not replace.""" |
| 176 | + from siren.SIREN_Controller import SIREN_Controller |
| 177 | + ctrl = SIREN_Controller(1, experiment="CCM") |
| 178 | + NuMu = dc.Particle.ParticleType.NuMu |
| 179 | + ctrl.primary_injection_process.primary_type = NuMu |
| 180 | + ctrl.primary_physical_process.primary_type = NuMu |
| 181 | + |
| 182 | + col1 = interactions.InteractionCollection(NuMu, []) |
| 183 | + col2 = interactions.InteractionCollection(NuMu, []) |
| 184 | + ctrl.SetInteractions(primary_interaction_collection=col1) |
| 185 | + ctrl.SetInteractions(primary_interaction_collection=col2) |
| 186 | + assert ctrl.primary_injection_process.interactions is not None |
| 187 | + |
| 188 | + |
| 189 | +# --------------------------------------------------------------------------- |
| 190 | +# GetVolumePositionDistributionFromSector |
| 191 | +# --------------------------------------------------------------------------- |
| 192 | + |
| 193 | +class TestGetVolumePositionDistribution: |
| 194 | + def test_missing_sector_raises(self, ccm_controller): |
| 195 | + with pytest.raises(ValueError, match="not found"): |
| 196 | + ccm_controller.GetVolumePositionDistributionFromSector("nonexistent_sector_999") |
| 197 | + |
| 198 | + |
| 199 | +# --------------------------------------------------------------------------- |
| 200 | +# Weighter bounds checking (Python layer) |
| 201 | +# --------------------------------------------------------------------------- |
| 202 | + |
| 203 | +class TestWeighterPython: |
| 204 | + @pytest.fixture(scope="class") |
| 205 | + def weighter_setup(self): |
| 206 | + """Build a minimal injector + weighter with DummyCrossSection.""" |
| 207 | + NuMu = dc.Particle.ParticleType.NuMu |
| 208 | + |
| 209 | + dm = detector.DetectorModel() |
| 210 | + det_dir = _util.get_detector_model_path("CCM") |
| 211 | + dm.LoadMaterialModel(os.path.join(det_dir, "materials.dat")) |
| 212 | + dm.LoadDetectorModel(os.path.join(det_dir, "densities.dat")) |
| 213 | + |
| 214 | + xs = interactions.DummyCrossSection() |
| 215 | + int_col = interactions.InteractionCollection(NuMu, [xs]) |
| 216 | + |
| 217 | + primary_inj = injection.PrimaryInjectionProcess() |
| 218 | + primary_inj.primary_type = NuMu |
| 219 | + primary_inj.interactions = int_col |
| 220 | + primary_inj.distributions = [ |
| 221 | + distributions.PrimaryMass(0), |
| 222 | + distributions.Monoenergetic(1.0), |
| 223 | + distributions.IsotropicDirection(), |
| 224 | + distributions.PointSourcePositionDistribution(smath.Vector3D(0, 0, 0), 25.0), |
| 225 | + ] |
| 226 | + |
| 227 | + primary_phys = injection.PhysicalProcess() |
| 228 | + primary_phys.primary_type = NuMu |
| 229 | + primary_phys.interactions = int_col |
| 230 | + primary_phys.distributions = [ |
| 231 | + distributions.PrimaryMass(0), |
| 232 | + distributions.IsotropicDirection(), |
| 233 | + ] |
| 234 | + |
| 235 | + rand = utilities.SIREN_random(42) |
| 236 | + inj = injection._Injector(10, dm, primary_inj, rand) |
| 237 | + weighter = injection._Weighter([inj], dm, primary_phys) |
| 238 | + event = inj.GenerateEvent() |
| 239 | + return weighter, event, inj |
| 240 | + |
| 241 | + def test_interaction_probs_valid_index(self, weighter_setup): |
| 242 | + weighter, event, _ = weighter_setup |
| 243 | + probs = weighter.GetInteractionProbabilities(event, 0) |
| 244 | + assert len(probs) > 0 |
| 245 | + for p in probs: |
| 246 | + assert 0.0 <= p <= 1.0 |
| 247 | + |
| 248 | + def test_survival_probs_valid_index(self, weighter_setup): |
| 249 | + weighter, event, _ = weighter_setup |
| 250 | + probs = weighter.GetSurvivalProbabilities(event, 0) |
| 251 | + assert len(probs) > 0 |
| 252 | + for p in probs: |
| 253 | + assert 0.0 <= p <= 1.0 |
| 254 | + |
| 255 | + def test_negative_i_inj_raises(self, weighter_setup): |
| 256 | + weighter, event, _ = weighter_setup |
| 257 | + with pytest.raises((RuntimeError, IndexError)): |
| 258 | + weighter.GetInteractionProbabilities(event, -1) |
| 259 | + |
| 260 | + def test_out_of_range_i_inj_raises(self, weighter_setup): |
| 261 | + weighter, event, _ = weighter_setup |
| 262 | + with pytest.raises((RuntimeError, IndexError)): |
| 263 | + weighter.GetInteractionProbabilities(event, 999) |
| 264 | + |
| 265 | + def test_survival_negative_i_inj_raises(self, weighter_setup): |
| 266 | + weighter, event, _ = weighter_setup |
| 267 | + with pytest.raises((RuntimeError, IndexError)): |
| 268 | + weighter.GetSurvivalProbabilities(event, -1) |
| 269 | + |
| 270 | + def test_event_weight_is_finite(self, weighter_setup): |
| 271 | + weighter, event, _ = weighter_setup |
| 272 | + w = weighter.EventWeight(event) |
| 273 | + assert np.isfinite(w) |
| 274 | + assert w >= 0 |
| 275 | + |
| 276 | + |
| 277 | +# --------------------------------------------------------------------------- |
| 278 | +# End-to-end: generate, weight, probabilities |
| 279 | +# --------------------------------------------------------------------------- |
| 280 | + |
| 281 | +class TestEndToEnd: |
| 282 | + @pytest.fixture(scope="class") |
| 283 | + def full_setup(self): |
| 284 | + """Full low-level setup: build injector + weighter, generate events.""" |
| 285 | + NuMu = dc.Particle.ParticleType.NuMu |
| 286 | + |
| 287 | + dm = detector.DetectorModel() |
| 288 | + det_dir = _util.get_detector_model_path("CCM") |
| 289 | + dm.LoadMaterialModel(os.path.join(det_dir, "materials.dat")) |
| 290 | + dm.LoadDetectorModel(os.path.join(det_dir, "densities.dat")) |
| 291 | + |
| 292 | + xs = interactions.DummyCrossSection() |
| 293 | + int_col = interactions.InteractionCollection(NuMu, [xs]) |
| 294 | + |
| 295 | + primary_inj = injection.PrimaryInjectionProcess() |
| 296 | + primary_inj.primary_type = NuMu |
| 297 | + primary_inj.interactions = int_col |
| 298 | + primary_inj.distributions = [ |
| 299 | + distributions.PrimaryMass(0), |
| 300 | + distributions.Monoenergetic(1.0), |
| 301 | + distributions.IsotropicDirection(), |
| 302 | + distributions.PointSourcePositionDistribution( |
| 303 | + smath.Vector3D(0, 0, 0), 25.0 |
| 304 | + ), |
| 305 | + ] |
| 306 | + |
| 307 | + primary_phys = injection.PhysicalProcess() |
| 308 | + primary_phys.primary_type = NuMu |
| 309 | + primary_phys.interactions = int_col |
| 310 | + primary_phys.distributions = [ |
| 311 | + distributions.PrimaryMass(0), |
| 312 | + distributions.IsotropicDirection(), |
| 313 | + ] |
| 314 | + |
| 315 | + rand = utilities.SIREN_random(99) |
| 316 | + inj = injection._Injector(5, dm, primary_inj, rand) |
| 317 | + weighter = injection._Weighter([inj], dm, primary_phys) |
| 318 | + |
| 319 | + events = [] |
| 320 | + for _ in range(5): |
| 321 | + events.append(inj.GenerateEvent()) |
| 322 | + return weighter, events |
| 323 | + |
| 324 | + def test_events_generated(self, full_setup): |
| 325 | + weighter, events = full_setup |
| 326 | + assert len(events) == 5 |
| 327 | + |
| 328 | + def test_event_weight_finite(self, full_setup): |
| 329 | + weighter, events = full_setup |
| 330 | + for event in events: |
| 331 | + w = weighter.EventWeight(event) |
| 332 | + assert np.isfinite(w) |
| 333 | + |
| 334 | + def test_interaction_probs_per_event(self, full_setup): |
| 335 | + weighter, events = full_setup |
| 336 | + for event in events: |
| 337 | + probs = weighter.GetInteractionProbabilities(event, 0) |
| 338 | + assert len(probs) > 0 |
| 339 | + for p in probs: |
| 340 | + assert 0.0 <= p <= 1.0 |
| 341 | + |
| 342 | + def test_survival_probs_per_event(self, full_setup): |
| 343 | + weighter, events = full_setup |
| 344 | + for event in events: |
| 345 | + probs = weighter.GetSurvivalProbabilities(event, 0) |
| 346 | + assert len(probs) > 0 |
| 347 | + for p in probs: |
| 348 | + assert 0.0 <= p <= 1.0 |
| 349 | + |
| 350 | + def test_interaction_plus_survival_le_one(self, full_setup): |
| 351 | + """For each datum, interaction_prob + survival_prob should be <= 1 |
| 352 | + (they measure complementary things over different path segments).""" |
| 353 | + weighter, events = full_setup |
| 354 | + for event in events: |
| 355 | + int_probs = weighter.GetInteractionProbabilities(event, 0) |
| 356 | + surv_probs = weighter.GetSurvivalProbabilities(event, 0) |
| 357 | + assert len(int_probs) == len(surv_probs) |
| 358 | + |
| 359 | + def test_multiple_events_give_different_weights(self, full_setup): |
| 360 | + """With random directions, not all events should have identical weights.""" |
| 361 | + weighter, events = full_setup |
| 362 | + weights = [weighter.EventWeight(e) for e in events] |
| 363 | + # At least some variation expected (not all identical) |
| 364 | + assert len(set(weights)) > 1 or len(events) == 1 |
0 commit comments