Skip to content

Commit 1999eb8

Browse files
Fix controller and weighter bugs from PR #125 review
Weighter (C++): - Implement SurvivalProbability in Weighter.tcc with [0,1] clamping - Add i_inj bounds-checking in GetInteractionProbabilities/GetSurvivalProbabilities - Use SurvivalProbability method instead of inline 1-InteractionProbability SIREN_Controller (Python): - Fix SetInteractions assertion to check physical process type, not injection - Fix save_int_params to produce per-event nested dicts instead of flat lists - Fix InputDarkNewsDecay to only compute min decay width for matching decays - Fix InputDarkNewsModel to reuse existing secondary processes instead of discarding - Replace exit(0) calls with ValueError/TypeError/RuntimeError - Fix get_material_model_path (non-existent function) with correct path construction - Fix typos: "interseted", "sedcondary" Tests: - Add Weighter_TEST.cxx for one_minus_exp_of_negative helper accuracy - Add test_controller_fixes.py covering all Python and Weighter fixes
1 parent 17e3529 commit 1999eb8

6 files changed

Lines changed: 449 additions & 37 deletions

File tree

projects/injection/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ install(DIRECTORY "${PROJECT_SOURCE_DIR}/projects/injection/public/"
3838
)
3939

4040
#package_add_test(UnitTest_Injector ${PROJECT_SOURCE_DIR}/projects/injection/private/test/Injector_TEST.cxx)
41+
package_add_test(UnitTest_Weighter ${PROJECT_SOURCE_DIR}/projects/injection/private/test/Weighter_TEST.cxx)
4142
if(NOT ${CIBUILDWHEEL})
4243
package_add_test(UnitTest_CCM_HNL ${PROJECT_SOURCE_DIR}/projects/injection/private/test/CCM_HNL_TEST.cxx)
4344
target_link_libraries(UnitTest_CCM_HNL pybind11::embed)

projects/injection/private/Weighter.cxx

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,9 @@ double Weighter::EventWeight(siren::dataclasses::InteractionTree const & tree) c
136136
}
137137

138138
std::vector<double> Weighter::GetInteractionProbabilities(siren::dataclasses::InteractionTree const & tree, int i_inj) const {
139-
// Returns the vector of interaction physical probabilities for each process
140-
// Since we are concerned only with the physical probability, we use the first injector since physical processes are the same for all injectors
141-
// HOWEVER the injection bounds will change based on the injector
142-
// so, we allow the user to specify which injector they are interseted in
139+
if(i_inj < 0 || static_cast<size_t>(i_inj) >= injectors.size()) {
140+
throw std::out_of_range("i_inj index out of range in GetInteractionProbabilities");
141+
}
143142

144143
std::vector<double> int_probs;
145144
for(auto const & datum : tree.tree) {
@@ -162,22 +161,23 @@ std::vector<double> Weighter::GetInteractionProbabilities(siren::dataclasses::In
162161
}
163162

164163
std::vector<double> Weighter::GetSurvivalProbabilities(siren::dataclasses::InteractionTree const & tree, int i_inj) const {
165-
// This allows the user to get the survival probabilities for each interaction
166-
// Useful in the case that secondary interactions are restricted to fiducial volumes
164+
if(i_inj < 0 || static_cast<size_t>(i_inj) >= injectors.size()) {
165+
throw std::out_of_range("i_inj index out of range in GetSurvivalProbabilities");
166+
}
167167

168168
std::vector<double> survival_probs;
169169
for(auto const & datum : tree.tree) {
170170
std::tuple<siren::math::Vector3D, siren::math::Vector3D> bounds;
171171
if(datum->depth() == 0) {
172-
std::get<0>(bounds) = datum->record.primary_initial_position; // start location
173-
std::get<1>(bounds) = std::get<0>(injectors[i_inj]->PrimaryInjectionBounds(datum->record)); // start of injection bounds
174-
survival_probs.push_back(1 - primary_process_weighters[i_inj]->InteractionProbability(bounds, datum->record));
172+
std::get<0>(bounds) = datum->record.primary_initial_position;
173+
std::get<1>(bounds) = std::get<0>(injectors[i_inj]->PrimaryInjectionBounds(datum->record));
174+
survival_probs.push_back(primary_process_weighters[i_inj]->SurvivalProbability(bounds, datum->record));
175175
}
176176
else {
177177
try {
178-
std::get<0>(bounds) = datum->record.primary_initial_position; // start location
179-
std::get<1>(bounds) = std::get<0>(injectors[i_inj]->SecondaryInjectionBounds(datum->record)); // start of injection bounds
180-
survival_probs.push_back(1 - secondary_process_weighter_maps[i_inj].at(datum->record.signature.primary_type)->InteractionProbability(bounds, datum->record));
178+
std::get<0>(bounds) = datum->record.primary_initial_position;
179+
std::get<1>(bounds) = std::get<0>(injectors[i_inj]->SecondaryInjectionBounds(datum->record));
180+
survival_probs.push_back(secondary_process_weighter_maps[i_inj].at(datum->record.signature.primary_type)->SurvivalProbability(bounds, datum->record));
181181
} catch(const std::out_of_range& oor) {
182182
std::cout << "Out of Range error: " << oor.what() << '\n';
183183
return {};
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#include <cmath>
2+
#include <stdexcept>
3+
4+
#include <gtest/gtest.h>
5+
6+
// Forward-declare the helper functions from Weighter.tcc (linked via SIREN lib)
7+
// to avoid ODR violations from including the full header.
8+
namespace siren { namespace injection {
9+
double one_minus_exp_of_negative(double x);
10+
double log_one_minus_exp_of_negative(double x);
11+
}}
12+
13+
using namespace siren::injection;
14+
15+
// ---------------------------------------------------------------------------
16+
// one_minus_exp_of_negative
17+
// ---------------------------------------------------------------------------
18+
19+
TEST(WeighterHelpers, OneMinusExpSmallX) {
20+
double x = 1e-5;
21+
double result = one_minus_exp_of_negative(x);
22+
double exact = 1.0 - std::exp(-x);
23+
EXPECT_NEAR(result, exact, 1e-15);
24+
}
25+
26+
TEST(WeighterHelpers, OneMinusExpMediumX) {
27+
double x = 0.5;
28+
double result = one_minus_exp_of_negative(x);
29+
double exact = 1.0 - std::exp(-x);
30+
EXPECT_NEAR(result, exact, 1e-12);
31+
}
32+
33+
TEST(WeighterHelpers, OneMinusExpLargeX) {
34+
double x = 5.0;
35+
double result = one_minus_exp_of_negative(x);
36+
double exact = 1.0 - std::exp(-x);
37+
EXPECT_NEAR(result, exact, 1e-12);
38+
}
39+
40+
TEST(WeighterHelpers, OneMinusExpVerySmallX) {
41+
// Exercises the Taylor expansion branch
42+
double x = 1e-8;
43+
double result = one_minus_exp_of_negative(x);
44+
double exact = 1.0 - std::exp(-x);
45+
EXPECT_NEAR(result, exact, 1e-15);
46+
}
47+
48+
TEST(WeighterHelpers, OneMinusExpAtBranchPoint) {
49+
// Near the 0.1 branch boundary
50+
double x = 0.099;
51+
double result = one_minus_exp_of_negative(x);
52+
double exact = 1.0 - std::exp(-x);
53+
EXPECT_NEAR(result, exact, 1e-12);
54+
55+
x = 0.101;
56+
result = one_minus_exp_of_negative(x);
57+
exact = 1.0 - std::exp(-x);
58+
EXPECT_NEAR(result, exact, 1e-12);
59+
}
60+
61+
TEST(WeighterHelpers, OneMinusExpResultBounded) {
62+
// Result should always be in [0, 1) for non-negative x
63+
for(double x = 0.0; x <= 20.0; x += 0.1) {
64+
double result = one_minus_exp_of_negative(x);
65+
EXPECT_GE(result, 0.0) << "Failed at x=" << x;
66+
EXPECT_LT(result, 1.0) << "Failed at x=" << x;
67+
}
68+
}
69+
70+
// ---------------------------------------------------------------------------
71+
// log_one_minus_exp_of_negative
72+
// ---------------------------------------------------------------------------
73+
74+
TEST(WeighterHelpers, LogOneMinusExpSmallX) {
75+
double x = 1e-5;
76+
double result = log_one_minus_exp_of_negative(x);
77+
double exact = std::log(1.0 - std::exp(-x));
78+
EXPECT_NEAR(result, exact, 1e-10);
79+
}
80+
81+
TEST(WeighterHelpers, LogOneMinusExpMidX) {
82+
double x = 1.5;
83+
double result = log_one_minus_exp_of_negative(x);
84+
double exact = std::log(1.0 - std::exp(-x));
85+
EXPECT_NEAR(result, exact, 1e-12);
86+
}
87+
88+
TEST(WeighterHelpers, LogOneMinusExpLargeX) {
89+
// Exercises the exp-series branch (x > 3)
90+
double x = 5.0;
91+
double result = log_one_minus_exp_of_negative(x);
92+
double exact = std::log(1.0 - std::exp(-x));
93+
EXPECT_NEAR(result, exact, 1e-12);
94+
}
95+
96+
TEST(WeighterHelpers, LogOneMinusExpAtBranchPoints) {
97+
// Test near the 0.1 and 3.0 branch boundaries
98+
for(double x : {0.09, 0.11, 2.99, 3.01}) {
99+
double result = log_one_minus_exp_of_negative(x);
100+
double exact = std::log(1.0 - std::exp(-x));
101+
EXPECT_NEAR(result, exact, 1e-9) << "Failed at x=" << x;
102+
}
103+
}
104+
105+
TEST(WeighterHelpers, LogOneMinusExpIsNegative) {
106+
// log(1 - exp(-x)) is always negative for x > 0
107+
for(double x = 0.001; x <= 20.0; x += 0.1) {
108+
double result = log_one_minus_exp_of_negative(x);
109+
EXPECT_LT(result, 0.0) << "Failed at x=" << x;
110+
}
111+
}

projects/injection/public/SIREN/injection/Weighter.tcc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,15 @@ double ProcessWeighter<ProcessType>::InteractionProbability(std::tuple<siren::ma
142142
return interaction_probability;
143143
}
144144

145+
template<typename ProcessType>
146+
double ProcessWeighter<ProcessType>::SurvivalProbability(std::tuple<siren::math::Vector3D, siren::math::Vector3D> const & bounds, siren::dataclasses::InteractionRecord const & record) const {
147+
double interaction_probability = InteractionProbability(bounds, record);
148+
double survival = 1.0 - interaction_probability;
149+
if(survival < 0.0) survival = 0.0;
150+
if(survival > 1.0) survival = 1.0;
151+
return survival;
152+
}
153+
145154
template<typename ProcessType>
146155
double ProcessWeighter<ProcessType>::NormalizedPositionProbability(std::tuple<siren::math::Vector3D, siren::math::Vector3D> const & bounds, siren::dataclasses::InteractionRecord const & record) const {
147156
using siren::detector::DetectorPosition;

python/SIREN_Controller.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ def __init__(self, events_to_inject, experiment=None, detector_model_file=None,
6969
self.materials_model_file = materials_model_file
7070
if experiment is not None:
7171
# Find the density and materials files
72-
self.materials_model_file = _util.get_material_model_path(experiment)
73-
self.detector_model_file = _util.get_detector_model_path(experiment)
72+
detector_dir = _util.get_detector_model_path(experiment)
73+
self.materials_model_file = os.path.join(detector_dir, "materials.dat")
74+
self.detector_model_file = os.path.join(detector_dir, "densities.dat")
7475
elif (self.detector_model_file is None or self.materials_model_file is None):
75-
print("Must provide either an experiment name or both a detector model file and materials model file. Exiting")
76-
exit(0)
76+
raise ValueError("Must provide either an experiment name or both a detector model file and materials model file")
7777

7878
self.detector_model = _detector.DetectorModel()
7979
self.detector_model.LoadMaterialModel(self.materials_model_file)
@@ -277,20 +277,23 @@ def InputDarkNewsModel(self, primary_type, table_dir, upscattering=True, decay=T
277277
secondary_interaction_collections = []
278278
for secondary_type, decay_list in secondary_decays.items():
279279

280-
# Define a sedcondary injection distribution if necessary
280+
# Define a secondary injection distribution if necessary
281281
inj_sec_defined = False
282282
phys_sec_defined = False
283+
existing_inj_process = None
283284
for secondary_injection_process in self.secondary_injection_processes:
284285
if secondary_injection_process.primary_type == secondary_type:
285286
inj_sec_defined = True
287+
existing_inj_process = secondary_injection_process
286288
for secondary_physical_process in self.secondary_physical_processes:
287289
if secondary_physical_process.primary_type == secondary_type:
288290
phys_sec_defined = True
289291

290-
secondary_injection_process = _injection.SecondaryInjectionProcess()
291-
secondary_physical_process = _injection.PhysicalProcess()
292-
secondary_injection_process.primary_type = secondary_type
293-
secondary_physical_process.primary_type = secondary_type
292+
if inj_sec_defined:
293+
secondary_injection_process = existing_inj_process
294+
else:
295+
secondary_injection_process = _injection.SecondaryInjectionProcess()
296+
secondary_injection_process.primary_type = secondary_type
294297

295298
# Add the secondary position distribution
296299
if fid_vol_secondary and self.fid_vol is not None:
@@ -302,8 +305,12 @@ def InputDarkNewsModel(self, primary_type, table_dir, upscattering=True, decay=T
302305
_distributions.SecondaryPhysicalVertexDistribution()
303306
)
304307

305-
if not inj_sec_defined: self.secondary_injection_processes.append(secondary_injection_process)
306-
if not phys_sec_defined: self.secondary_physical_processes.append(secondary_physical_process)
308+
if not inj_sec_defined:
309+
self.secondary_injection_processes.append(secondary_injection_process)
310+
if not phys_sec_defined:
311+
secondary_physical_process = _injection.PhysicalProcess()
312+
secondary_physical_process.primary_type = secondary_type
313+
self.secondary_physical_processes.append(secondary_physical_process)
307314

308315
secondary_interaction_collections.append(
309316
_interactions.InteractionCollection(secondary_type, decay_list)
@@ -343,9 +350,9 @@ def InputDarkNewsDecay(self, primary_type, table_dir, **kwargs):
343350
decay.dec_case.nu_parent.pdgid
344351
):
345352
primary_decays.append(decay)
346-
total_decay_width = decay.TotalDecayWidth(primary_type)
347-
if total_decay_width < self.DN_min_decay_width:
348-
self.DN_min_decay_width = total_decay_width
353+
total_decay_width = decay.TotalDecayWidth(primary_type)
354+
if total_decay_width > 0 and total_decay_width < self.DN_min_decay_width:
355+
self.DN_min_decay_width = total_decay_width
349356
primary_interaction_collection = _interactions.InteractionCollection(
350357
primary_type, primary_decays
351358
)
@@ -377,8 +384,7 @@ def GetFiducialVolume(self):
377384
def GetVolumePositionDistributionFromSector(self, sector_name):
378385
geo = self.GetDetectorSectorGeometry(sector_name)
379386
if geo is None:
380-
print("Sector %s not found. Exiting"%sector_name)
381-
exit(0)
387+
raise ValueError("Sector %s not found" % sector_name)
382388
# the position is in geometry coordinates
383389
# must update to detector coordintes
384390
det_position = self.detector_model.GeoPositionToDetPosition(_detector.GeometryPosition(geo.placement.Position))
@@ -391,8 +397,7 @@ def GetVolumePositionDistributionFromSector(self, sector_name):
391397
sphere = _geometry.Sphere(det_placement,geo.Radius,geo.InnerRadius)
392398
return _distributions.SphereVolumePositionDistribution(sphere)
393399
else:
394-
print("Geometry type %s not supported for position distribution. Exiting"%str(type(geo)))
395-
exit(0)
400+
raise TypeError("Geometry type %s not supported for position distribution" % str(type(geo)))
396401

397402
def GetDetectorModelTargets(self):
398403
"""
@@ -449,7 +454,7 @@ def SetInteractions(
449454
if self.primary_physical_process.primary_type == _dataclasses.Particle.ParticleType.unknown:
450455
self.primary_physical_process.primary_type = primary_interaction_collection.GetPrimaryType()
451456
else:
452-
assert(self.primary_injection_process.primary_type == primary_interaction_collection.GetPrimaryType())
457+
assert(self.primary_physical_process.primary_type == primary_interaction_collection.GetPrimaryType())
453458
if self.primary_physical_process.interactions is None:
454459
self.primary_physical_process.interactions = primary_interaction_collection
455460
else:
@@ -484,11 +489,10 @@ def SetInteractions(
484489
[sec_phys.interactions, sec_ints])
485490
found_collection = True
486491
if not found_collection and(sec_inj.interactions is None or sec_phys.interactions is None):
487-
print(
488-
"Couldn't find cross section collection for secondary particle %s; Exiting"
492+
raise RuntimeError(
493+
"Couldn't find cross section collection for secondary particle %s"
489494
% record.signature.primary_type
490495
)
491-
exit(0)
492496

493497
# set the stopping condition of the injector with a python function
494498
# must accept two arguments, assumes first is datum and the second is the index of the secondary particle
@@ -636,12 +640,16 @@ def SaveEvents(self, filename, fill_tables_at_exit=True,
636640
"parent_idx",
637641
"num_daughters"]:
638642
datasets[k].append([])
643+
if save_int_params:
644+
datasets.setdefault("int_params", [])
645+
datasets["int_params"].append({})
639646
# loop over interactions
640647
for id, datum in enumerate(event.tree):
641648
if save_int_params:
642-
for param_name,param_value in datum.record.interaction_parameters.items():
643-
if ie==0: datasets[param_name] = []
644-
datasets[param_name].append(param_value)
649+
for param_name, param_value in datum.record.interaction_parameters.items():
650+
if param_name not in datasets["int_params"][-1]:
651+
datasets["int_params"][-1][param_name] = []
652+
datasets["int_params"][-1][param_name].append(param_value)
645653
datasets["vertex"][-1].append(np.array(datum.record.interaction_vertex,dtype=float))
646654
datasets["primary_initial_position"][-1].append(np.array(datum.record.primary_initial_position,dtype=float))
647655

0 commit comments

Comments
 (0)