@@ -38,12 +38,15 @@ def MergeInteractionCollections(primary_type,int_col_list):
3838
3939# Parent python class for handling event generation and weighting
4040class SIREN_Controller :
41- def __init__ (self , events_to_inject , experiment , seed = 0 ):
41+
42+ def __init__ (self , events_to_inject , experiment = None , detector_model_file = None , materials_model_file = None , seed = 0 ):
4243 """
4344 SIREN controller class constructor.
4445 :param int event_to_inject: number of events to generate
45- :param str experiment: experiment name in string
46- :param int seed: Optional random number generator seed
46+ :param str experiment: experiment name in string (default None)
47+ :param str detector_model_file: path to the detector model file (default None)
48+ :param str materials_model_file: path to the materials model file (default None)
49+ :param int seed: Optional random number generator seed (default 0)
4750 """
4851
4952 self .global_start = time .time ()
@@ -60,12 +63,19 @@ def __init__(self, events_to_inject, experiment, seed=0):
6063 # Empty list for our interaction trees
6164 self .events = []
6265
63- # Find the density and materials files
64- materials_file = _util .get_material_model_path (experiment )
65- detector_model_file = _util .get_detector_model_path (experiment )
66+
67+ self .detector_model_file = detector_model_file
68+ self .materials_model_file = materials_model_file
69+ if experiment is not None :
70+ # Find the density and materials files
71+ self .materials_model_file = _util .get_material_model_path (experiment )
72+ self .detector_model_file = _util .get_detector_model_path (experiment )
73+ elif (self .detector_model_file is None or self .materials_model_file is None ):
74+ print ("Must provide either an experiment name or both a detector model file and materials model file. Exiting" )
75+ exit (0 )
6676
6777 self .detector_model = _detector .DetectorModel ()
68- self .detector_model .LoadMaterialModel (materials_file )
78+ self .detector_model .LoadMaterialModel (materials_model_file )
6979 self .detector_model .LoadDetectorModel (detector_model_file )
7080
7181 # Define the primary injection and physical process
@@ -343,8 +353,7 @@ def GetFiducialVolume(self):
343353 """
344354 :return: identified fiducial volume for the experiment, None if not found
345355 """
346- detector_model_file = _util .get_detector_model_path (self .experiment )
347- with open (detector_model_file ) as file :
356+ with open (self .detector_model_file ) as file :
348357 fiducial_line = None
349358 detector_line = None
350359 for line in file :
@@ -360,18 +369,25 @@ def GetFiducialVolume(self):
360369 return _detector .DetectorModel .ParseFiducialVolume (fiducial_line , detector_line )
361370 return None
362371
363- def GetCylinderVolumePositionDistributionFromSector (self , sector_name ):
372+ def GetVolumePositionDistributionFromSector (self , sector_name ):
364373 geo = self .GetDetectorSectorGeometry (sector_name )
365374 if geo is None :
366375 print ("Sector %s not found. Exiting" % sector_name )
367376 exit (0 )
368- # the position of this cylinder is in geometry coordinates
377+ # the position is in geometry coordinates
369378 # must update to detector coordintes
370379 det_position = self .detector_model .GeoPositionToDetPosition (_detector .GeometryPosition (geo .placement .Position ))
371380 det_rotation = geo .placement .Quaternion
372381 det_placement = _geometry .Placement (det_position .get (), det_rotation )
373- cylinder = _geometry .Cylinder (det_placement ,geo .Radius ,geo .InnerRadius ,geo .Z )
374- return _distributions .CylinderVolumePositionDistribution (cylinder )
382+ if type (geo )== _geometry .Cylinder :
383+ cylinder = _geometry .Cylinder (det_placement ,geo .Radius ,geo .InnerRadius ,geo .Z )
384+ return _distributions .CylinderVolumePositionDistribution (cylinder )
385+ elif type (geo )== _geometry .Sphere :
386+ sphere = _geometry .Sphere (det_placement ,geo .Radius ,geo .InnerRadius )
387+ return _distributions .SphereVolumePositionDistribution (sphere )
388+ else :
389+ print ("Geometry type %s not supported for position distribution. Exiting" % str (type (geo )))
390+ exit (0 )
375391
376392 def GetDetectorModelTargets (self ):
377393 """
@@ -563,7 +579,8 @@ def LoadEvents(self, filename):
563579 # if the weighter exists, calculate the event weight too
564580 def SaveEvents (self , filename , fill_tables_at_exit = True ,
565581 hdf5 = True , parquet = True , siren_events = True , # filetypes to save events
566- save_int_probs = False ,save_int_params = False ,save_survival_probs = False ):
582+ save_int_probs = False ,save_int_params = False ,save_survival_probs = False ,
583+ verbose = True ):
567584
568585 if siren_events :
569586 _dataclasses .SaveInteractionTrees (self .events , filename )
@@ -589,7 +606,7 @@ def SaveEvents(self, filename, fill_tables_at_exit=True,
589606 if save_int_probs : datasets ["int_probs" ] = []
590607 if save_survival_probs : datasets ["survival_probs" ] = []
591608 for ie , event in enumerate (self .events ):
592- print ("Saving Event %d/%d " % (ie , len (self .events )), end = "\r " )
609+ if verbose : print ("Saving Event %d/%d " % (ie , len (self .events )), end = "\r " )
593610 t0 = time .time ()
594611 datasets ["event_weight" ].append (self .weighter .EventWeight (event ) if hasattr (self ,"weighter" ) else 0 )
595612 if save_int_probs :
0 commit comments