@@ -89,7 +89,9 @@ def _set_stochastic(self, seed=None):
8989 attr_value = None
9090 if input_value is not None :
9191 if "factor" in input_name :
92- attr_value = self ._validate_factors (input_name , input_value )
92+ attr_value = self ._validate_factors (
93+ input_name , input_value , seed
94+ )
9395 elif input_name not in self .exception_list :
9496 if isinstance (input_value , tuple ):
9597 attr_value = self ._validate_tuple (input_name , input_value )
@@ -104,6 +106,7 @@ def _set_stochastic(self, seed=None):
104106 else :
105107 raise AssertionError (
106108 f"'{ input_name } ' must be a tuple, list, int, or float"
109+ "or a custom sampler"
107110 )
108111 else :
109112 attr_value = [getattr (self .obj , input_name )]
@@ -285,7 +288,7 @@ def _validate_scalar(self, input_name, input_value, getattr=getattr): # pylint:
285288 get_distribution ("normal" , self .__random_number_generator ),
286289 )
287290
288- def _validate_factors (self , input_name , input_value ):
291+ def _validate_factors (self , input_name , input_value , seed ):
289292 """
290293 Validate factor arguments.
291294
@@ -313,8 +316,12 @@ def _validate_factors(self, input_name, input_value):
313316 return self ._validate_tuple_factor (input_name , input_value )
314317 elif isinstance (input_value , list ):
315318 return self ._validate_list_factor (input_name , input_value )
319+ elif isinstance (input_value , CustomSampler ):
320+ return self ._validate_custom_sampler (input_name , input_value , seed )
316321 else :
317- raise AssertionError (f"`{ input_name } `: must be either a tuple or list" )
322+ raise AssertionError (
323+ f"`{ input_name } `: must be either a tuple or listor a custom sampler"
324+ )
318325
319326 def _validate_tuple_factor (self , input_name , factor_tuple ):
320327 """
@@ -463,7 +470,7 @@ def _validate_custom_sampler(self, input_name, sampler, seed=None):
463470 sampler .reset_seed (seed )
464471 except RuntimeError as e :
465472 raise RuntimeError (
466- f"An error occurred in the 'reset_seed' of { input_name } CustomSampler"
473+ f"An error occurred in the 'reset_seed' method of { input_name } CustomSampler"
467474 ) from e
468475
469476 return sampler
@@ -531,7 +538,7 @@ def dict_generator(self):
531538 generated_dict [arg ] = value .sample (n_samples = 1 )[0 ]
532539 except RuntimeError as e :
533540 raise RuntimeError (
534- f"An error occurred in the 'sample' of { arg } CustomSampler"
541+ f"An error occurred in the 'sample' method of { arg } CustomSampler"
535542 ) from e
536543 self .last_rnd_dict = generated_dict
537544 yield generated_dict
0 commit comments