Skip to content

Commit edc0428

Browse files
committed
support sphere volume distributions in python interface
1 parent a5f498c commit edc0428

1 file changed

Lines changed: 32 additions & 15 deletions

File tree

python/SIREN_Controller.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,15 @@ def MergeInteractionCollections(primary_type,int_col_list):
3838

3939
# Parent python class for handling event generation and weighting
4040
class SIREN_Controller:
41-
def __init__(self, events_to_inject, experiment, seed=0):
41+
42+
def __init__(self, events_to_inject, experiment=None, detector_model_file=None, materials_model_file=None, seed=0):
4243
"""
4344
SIREN controller class constructor.
4445
:param int event_to_inject: number of events to generate
45-
:param str experiment: experiment name in string
46-
:param int seed: Optional random number generator seed
46+
:param str experiment: experiment name in string (default None)
47+
:param str detector_model_file: path to the detector model file (default None)
48+
:param str materials_model_file: path to the materials model file (default None)
49+
:param int seed: Optional random number generator seed (default 0)
4750
"""
4851

4952
self.global_start = time.time()
@@ -60,12 +63,19 @@ def __init__(self, events_to_inject, experiment, seed=0):
6063
# Empty list for our interaction trees
6164
self.events = []
6265

63-
# Find the density and materials files
64-
materials_file = _util.get_material_model_path(experiment)
65-
detector_model_file = _util.get_detector_model_path(experiment)
66+
67+
self.detector_model_file = detector_model_file
68+
self.materials_model_file = materials_model_file
69+
if experiment is not None:
70+
# Find the density and materials files
71+
self.materials_model_file = _util.get_material_model_path(experiment)
72+
self.detector_model_file = _util.get_detector_model_path(experiment)
73+
elif (self.detector_model_file is None or self.materials_model_file is None):
74+
print("Must provide either an experiment name or both a detector model file and materials model file. Exiting")
75+
exit(0)
6676

6777
self.detector_model = _detector.DetectorModel()
68-
self.detector_model.LoadMaterialModel(materials_file)
78+
self.detector_model.LoadMaterialModel(materials_model_file)
6979
self.detector_model.LoadDetectorModel(detector_model_file)
7080

7181
# Define the primary injection and physical process
@@ -343,8 +353,7 @@ def GetFiducialVolume(self):
343353
"""
344354
:return: identified fiducial volume for the experiment, None if not found
345355
"""
346-
detector_model_file = _util.get_detector_model_path(self.experiment)
347-
with open(detector_model_file) as file:
356+
with open(self.detector_model_file) as file:
348357
fiducial_line = None
349358
detector_line = None
350359
for line in file:
@@ -360,18 +369,25 @@ def GetFiducialVolume(self):
360369
return _detector.DetectorModel.ParseFiducialVolume(fiducial_line, detector_line)
361370
return None
362371

363-
def GetCylinderVolumePositionDistributionFromSector(self, sector_name):
372+
def GetVolumePositionDistributionFromSector(self, sector_name):
364373
geo = self.GetDetectorSectorGeometry(sector_name)
365374
if geo is None:
366375
print("Sector %s not found. Exiting"%sector_name)
367376
exit(0)
368-
# the position of this cylinder is in geometry coordinates
377+
# the position is in geometry coordinates
369378
# must update to detector coordintes
370379
det_position = self.detector_model.GeoPositionToDetPosition(_detector.GeometryPosition(geo.placement.Position))
371380
det_rotation = geo.placement.Quaternion
372381
det_placement = _geometry.Placement(det_position.get(), det_rotation)
373-
cylinder = _geometry.Cylinder(det_placement,geo.Radius,geo.InnerRadius,geo.Z)
374-
return _distributions.CylinderVolumePositionDistribution(cylinder)
382+
if type(geo)==_geometry.Cylinder:
383+
cylinder = _geometry.Cylinder(det_placement,geo.Radius,geo.InnerRadius,geo.Z)
384+
return _distributions.CylinderVolumePositionDistribution(cylinder)
385+
elif type(geo)==_geometry.Sphere:
386+
sphere = _geometry.Sphere(det_placement,geo.Radius,geo.InnerRadius)
387+
return _distributions.SphereVolumePositionDistribution(sphere)
388+
else:
389+
print("Geometry type %s not supported for position distribution. Exiting"%str(type(geo)))
390+
exit(0)
375391

376392
def GetDetectorModelTargets(self):
377393
"""
@@ -563,7 +579,8 @@ def LoadEvents(self, filename):
563579
# if the weighter exists, calculate the event weight too
564580
def SaveEvents(self, filename, fill_tables_at_exit=True,
565581
hdf5=True, parquet=True, siren_events=True, # filetypes to save events
566-
save_int_probs=False,save_int_params=False,save_survival_probs=False):
582+
save_int_probs=False,save_int_params=False,save_survival_probs=False,
583+
verbose=True):
567584

568585
if siren_events:
569586
_dataclasses.SaveInteractionTrees(self.events, filename)
@@ -589,7 +606,7 @@ def SaveEvents(self, filename, fill_tables_at_exit=True,
589606
if save_int_probs: datasets["int_probs"] = []
590607
if save_survival_probs: datasets["survival_probs"] = []
591608
for ie, event in enumerate(self.events):
592-
print("Saving Event %d/%d " % (ie, len(self.events)), end="\r")
609+
if verbose: print("Saving Event %d/%d " % (ie, len(self.events)), end="\r")
593610
t0 = time.time()
594611
datasets["event_weight"].append(self.weighter.EventWeight(event) if hasattr(self,"weighter") else 0)
595612
if save_int_probs:

0 commit comments

Comments
 (0)