88import numpy as np
99
1010from rocketpy .mathutils .function import Function
11+ from rocketpy .stochastic .custom_sampler import CustomSampler
1112
1213from ..tools import get_distribution
1314
@@ -96,6 +97,10 @@ def _set_stochastic(self, seed=None):
9697 attr_value = self ._validate_list (input_name , input_value )
9798 elif isinstance (input_value , (int , float )):
9899 attr_value = self ._validate_scalar (input_name , input_value )
100+ elif isinstance (input_value , CustomSampler ):
101+ attr_value = self ._validate_custom_sampler (
102+ input_name , input_value , seed
103+ )
99104 else :
100105 raise AssertionError (
101106 f"'{ input_name } ' must be a tuple, list, int, or float"
@@ -436,6 +441,33 @@ def _validate_positive_int_list(self, input_name, input_value):
436441 isinstance (member , int ) and member >= 0 for member in input_value
437442 ), f"`{ input_name } ` must be a list of positive integers"
438443
444+ def _validate_custom_sampler (self , input_name , sampler , seed = None ):
445+ """
446+ Validate a custom sampler.
447+
448+ Parameters
449+ ----------
450+ input_name : str
451+ Name of the input argument.
452+ sampler : CustomSampler object
453+ Custom sampler provided by the user
454+ seed : int, optional
455+ Seed for the random number generator. The default is None
456+
457+ Raises
458+ ------
459+ AssertionError
460+ If the input is not in a valid format.
461+ """
462+ try :
463+ sampler .reset_seed (seed )
464+ except RuntimeError as e :
465+ raise RuntimeError (
466+ f"An error occurred in the 'reset_seed' of { input_name } CustomSampler"
467+ ) from e
468+
469+ return sampler
470+
439471 def _validate_airfoil (self , airfoil ):
440472 """
441473 Validate airfoil input.
@@ -490,9 +522,17 @@ def dict_generator(self):
490522 generated_dict = {}
491523 for arg , value in self .__dict__ .items ():
492524 if isinstance (value , tuple ):
493- generated_dict [arg ] = value [- 1 ](value [0 ], value [1 ])
525+ dist_sampler = value [- 1 ]
526+ generated_dict [arg ] = dist_sampler (value [0 ], value [1 ])
494527 elif isinstance (value , list ):
495528 generated_dict [arg ] = choice (value ) if value else value
529+ elif isinstance (value , CustomSampler ):
530+ try :
531+ generated_dict [arg ] = value .sample (n_samples = 1 )[0 ]
532+ except RuntimeError as e :
533+ raise RuntimeError (
534+ f"An error occurred in the 'sample' of { arg } CustomSampler"
535+ ) from e
496536 self .last_rnd_dict = generated_dict
497537 yield generated_dict
498538
@@ -527,6 +567,12 @@ def format_attribute(attr, value):
527567 f"{ nominal_value :.5f} ± "
528568 f"{ std_dev :.5f} ({ dist_func .__name__ } )"
529569 )
570+ elif isinstance (value , CustomSampler ):
571+ sampler_name = type (value ).__name__
572+ return (
573+ f"\t { attr .ljust (max_str_length )} "
574+ f"\t { sampler_name .ljust (max_str_length )} "
575+ )
530576 return None
531577
532578 attributes = {k : v for k , v in self .__dict__ .items () if not k .startswith ("_" )}
@@ -550,6 +596,9 @@ def format_attribute(attr, value):
550596 list_attributes = [
551597 attr for attr , val in items if isinstance (val , list ) and len (val ) > 1
552598 ]
599+ custom_attributes = [
600+ attr for attr , val in items if isinstance (val , CustomSampler )
601+ ]
553602
554603 if constant_attributes :
555604 report .append ("\n Constant Attributes:" )
@@ -568,5 +617,10 @@ def format_attribute(attr, value):
568617 report .extend (
569618 format_attribute (attr , attributes [attr ]) for attr in list_attributes
570619 )
620+ if custom_attributes :
621+ report .append ("\n Stochastic Attributes with Custom user samplers:" )
622+ report .extend (
623+ format_attribute (attr , attributes [attr ]) for attr in custom_attributes
624+ )
571625
572626 print ("\n " .join (filter (None , report )))
0 commit comments