@@ -2208,6 +2208,17 @@ def foo():
22082208 best_foo, best_reward = example, reward
22092209 """
22102210
2211+ class _AnnoymousHyperNameAccumulator :
2212+ """Name accumulator for annoymous hyper primitives."""
2213+
2214+ def __init__ (self ):
2215+ self .index = 0
2216+
2217+ def next_name (self ):
2218+ name = f'decision_{ self .index } '
2219+ self .index += 1
2220+ return name
2221+
22112222 def __init__ (self ,
22122223 where : Optional [Callable [[HyperPrimitive ], bool ]] = None ,
22132224 require_hyper_name : bool = False ,
@@ -2236,10 +2247,17 @@ def __init__(self,
22362247 self ._where = where
22372248 self ._require_hyper_name : bool = require_hyper_name
22382249 self ._name_to_hyper : Dict [Text , HyperPrimitive ] = dict ()
2239- self ._annoymous_hyper_name_accumulator : int = 0
2250+ self ._annoymous_hyper_name_accumulator = (
2251+ DynamicEvaluationContext ._AnnoymousHyperNameAccumulator ())
22402252 self ._hyper_dict = symbolic .Dict () if dna_spec is None else None
22412253 self ._dna_spec : Optional [geno .DNASpec ] = dna_spec
22422254 self ._per_thread = per_thread
2255+ self ._decision_getter = None
2256+
2257+ @property
2258+ def per_thread (self ) -> bool :
2259+ """Returns True if current context collects/applies decisions per thread."""
2260+ return self ._per_thread
22432261
22442262 @property
22452263 def dna_spec (self ) -> geno .DNASpec :
@@ -2257,8 +2275,7 @@ def _decision_name(self, hyper_primitive: HyperPrimitive) -> Text:
22572275 raise ValueError (
22582276 f'\' name\' must be specified for hyper '
22592277 f'primitive { hyper_primitive !r} .' )
2260- name = f'decision_{ self ._annoymous_hyper_name_accumulator } '
2261- self ._annoymous_hyper_name_accumulator += 1
2278+ name = self ._annoymous_hyper_name_accumulator .next_name ()
22622279 return name
22632280
22642281 @property
@@ -2296,71 +2313,80 @@ def collect(self):
22962313 f'`collect` cannot be called on a dynamic evaluation context that is '
22972314 f'using an external DNASpec: { self ._dna_spec } .' )
22982315
2299- with self ._collect () as sub_space :
2316+ # Ensure per-thread dynamic evaluation context will not be used
2317+ # together with process-level dynamic evaluation context.
2318+ _dynamic_evaluation_stack .ensure_thread_safety (self )
2319+
2320+ self ._hyper_dict = {}
2321+ with dynamic_evaluate (self .add_decision_point , per_thread = self ._per_thread ):
23002322 try :
2323+ # Push current context to dynamic evaluatoin stack so nested context
2324+ # can defer unresolved hyper primitive to current context.
2325+ _dynamic_evaluation_stack .push (self )
23012326 yield self ._hyper_dict
2302- finally :
2303- # NOTE(daiyip): when registering new hyper primitives in the sub-space,
2304- # the keys are already ensured not to conflict with the keys in current
2305- # search space. Therefore it's safe to update current space.
2306- self ._hyper_dict .update (sub_space )
23072327
2328+ finally :
23082329 # Invalidate DNASpec.
23092330 self ._dna_spec = None
23102331
2311- def _collect (self ):
2312- """A context manager for collecting hyper primitive within the scope."""
2313- hyper_dict = symbolic .Dict ()
2332+ # Pop current context from dynamic evaluatoin stack.
2333+ _dynamic_evaluation_stack .pop (self )
23142334
2315- def _register_child (c ):
2335+ def add_decision_point (self , hyper_primitive : HyperPrimitive ):
2336+ """Registers a parameter with current context and return its first value."""
2337+ def _add_child_decision_point (c ):
23162338 if isinstance (c , types .LambdaType ):
23172339 s = schema .get_signature (c )
23182340 if not s .args and not s .has_wildcard_args :
2319- with self ._collect () as child_hyper :
2341+ sub_context = DynamicEvaluationContext (
2342+ where = self ._where , per_thread = self ._per_thread )
2343+ sub_context ._annoymous_hyper_name_accumulator = ( # pylint: disable=protected-access
2344+ self ._annoymous_hyper_name_accumulator )
2345+ with sub_context .collect () as hyper_dict :
23202346 v = c ()
2321- return (v , child_hyper )
2347+ return (v , hyper_dict )
23222348 return (c , c )
23232349
2324- def _register_hyper_primitive (hyper_primitive ):
2325- """Registers a decision point from an hyper_primitive."""
2326- if self ._where and not self ._where (hyper_primitive ):
2327- # Skip hyper primitives that do not pass the `where` predicate.
2328- return hyper_primitive
2329-
2330- if isinstance (hyper_primitive , Template ):
2331- return hyper_primitive .value
2332-
2333- assert isinstance (hyper_primitive , HyperPrimitive ), hyper_primitive
2334- name = self ._decision_name (hyper_primitive )
2335- if isinstance (hyper_primitive , Choices ):
2336- candidate_values , candidates = zip (
2337- * [_register_child (c ) for c in hyper_primitive .candidates ])
2338- if hyper_primitive .choices_distinct :
2339- assert hyper_primitive .num_choices <= len (hyper_primitive .candidates )
2340- v = [candidate_values [i ] for i in range (hyper_primitive .num_choices )]
2341- else :
2342- v = [candidate_values [0 ]] * hyper_primitive .num_choices
2343- hyper_primitive = hyper_primitive .clone (deep = True , override = {
2344- 'candidates' : list (candidates )
2345- })
2346- first_value = v [0 ] if isinstance (hyper_primitive , ChoiceValue ) else v
2347- elif isinstance (hyper_primitive , Float ):
2348- first_value = hyper_primitive .min_value
2350+ if self ._where and not self ._where (hyper_primitive ):
2351+ # Delegate the resolution of hyper primitives that do not pass
2352+ # the `where` predicate to its parent context.
2353+ parent_context = _dynamic_evaluation_stack .get_parent (self )
2354+ if parent_context is not None :
2355+ return parent_context .add_decision_point (hyper_primitive )
2356+ return hyper_primitive
2357+
2358+ if isinstance (hyper_primitive , Template ):
2359+ return hyper_primitive .value
2360+
2361+ assert isinstance (hyper_primitive , HyperPrimitive ), hyper_primitive
2362+ name = self ._decision_name (hyper_primitive )
2363+ if isinstance (hyper_primitive , Choices ):
2364+ candidate_values , candidates = zip (
2365+ * [_add_child_decision_point (c ) for c in hyper_primitive .candidates ])
2366+ if hyper_primitive .choices_distinct :
2367+ assert hyper_primitive .num_choices <= len (hyper_primitive .candidates )
2368+ v = [candidate_values [i ] for i in range (hyper_primitive .num_choices )]
23492369 else :
2350- assert isinstance (hyper_primitive , CustomHyper ), hyper_primitive
2351- first_value = hyper_primitive .decode (hyper_primitive .first_dna ())
2370+ v = [candidate_values [0 ]] * hyper_primitive .num_choices
2371+ hyper_primitive = hyper_primitive .clone (deep = True , override = {
2372+ 'candidates' : list (candidates )
2373+ })
2374+ first_value = v [0 ] if isinstance (hyper_primitive , ChoiceValue ) else v
2375+ elif isinstance (hyper_primitive , Float ):
2376+ first_value = hyper_primitive .min_value
2377+ else :
2378+ assert isinstance (hyper_primitive , CustomHyper ), hyper_primitive
2379+ first_value = hyper_primitive .decode (hyper_primitive .first_dna ())
23522380
2353- if (name in self ._name_to_hyper
2354- and hyper_primitive != self ._name_to_hyper [name ]):
2355- raise ValueError (
2356- f'Found different hyper primitives under the same name { name !r} : '
2357- f'Instance1={ self ._name_to_hyper [name ]!r} , '
2358- f'Instance2={ hyper_primitive !r} .' )
2359- hyper_dict [name ] = hyper_primitive
2360- self ._name_to_hyper [name ] = hyper_primitive
2361- return first_value
2362- return dynamic_evaluate (
2363- _register_hyper_primitive , hyper_dict , per_thread = self ._per_thread )
2381+ if (name in self ._name_to_hyper
2382+ and hyper_primitive != self ._name_to_hyper [name ]):
2383+ raise ValueError (
2384+ f'Found different hyper primitives under the same name { name !r} : '
2385+ f'Instance1={ self ._name_to_hyper [name ]!r} , '
2386+ f'Instance2={ hyper_primitive !r} .' )
2387+ self ._hyper_dict [name ] = hyper_primitive
2388+ self ._name_to_hyper [name ] = hyper_primitive
2389+ return first_value
23642390
23652391 def _decision_getter_and_evaluation_finalizer (
23662392 self , decisions : Union [geno .DNA , List [Union [int , float , str ]]]):
@@ -2461,6 +2487,7 @@ def err_on_unused_decisions():
24612487 f'Found extra decision values that are not used: { remaining !r} ' )
24622488 return get_decision_by_position , err_on_unused_decisions
24632489
2490+ @contextlib .contextmanager
24642491 def apply (
24652492 self , decisions : Union [geno .DNA , List [Union [int , float , str ]]]):
24662493 """Context manager for applying decisions.
@@ -2482,65 +2509,159 @@ def fun():
24822509 decisions: A DNA or a list of numbers or strings as decisions for currrent
24832510 search space.
24842511
2485- Returns:
2486- Context manager for applying decisions to the function that defines the
2487- search space.
2512+ Yields:
2513+ None
24882514 """
24892515 if not isinstance (decisions , (geno .DNA , list )):
24902516 raise ValueError ('`decisions` should be a DNA or a list of numbers.' )
24912517
2492- get_decision , evaluation_finalizer = (
2518+ # Ensure per-thread dynamic evaluation context will not be used
2519+ # together with process-level dynamic evaluation context.
2520+ _dynamic_evaluation_stack .ensure_thread_safety (self )
2521+
2522+ get_current_decision , evaluation_finalizer = (
24932523 self ._decision_getter_and_evaluation_finalizer (decisions ))
24942524
2525+ has_errors = False
2526+ with dynamic_evaluate (self .evaluate , per_thread = self ._per_thread ):
2527+ try :
2528+ # Set decision getter for current decision.
2529+ self ._decision_getter = get_current_decision
2530+
2531+ # Push current context to dynamic evaluation stack so nested context
2532+ # can delegate evaluate to current context.
2533+ _dynamic_evaluation_stack .push (self )
2534+
2535+ yield
2536+ except Exception :
2537+ has_errors = True
2538+ raise
2539+ finally :
2540+ # Pop current context from dynamic evaluatoin stack.
2541+ _dynamic_evaluation_stack .pop (self )
2542+
2543+ # Reset decisions.
2544+ self ._decision_getter = None
2545+
2546+ # Call evaluation finalizer to make sure all decisions are used.
2547+ if not has_errors :
2548+ evaluation_finalizer ()
2549+
2550+ def evaluate (self , hyper_primitive : HyperPrimitive ):
2551+ """Evaluates a hyper primitive based on current decisions."""
2552+ if self ._decision_getter is None :
2553+ raise ValueError (
2554+ '`evaluate` needs to be called under the `apply` context.' )
2555+
2556+ get_current_decision = self ._decision_getter
24952557 def _apply_child (c ):
24962558 if isinstance (c , types .LambdaType ):
24972559 s = schema .get_signature (c )
24982560 if not s .args and not s .has_wildcard_args :
24992561 return c ()
25002562 return c
25012563
2502- def _apply_decision (hyper_primitive : HyperPrimitive ):
2503- """Apply a decision value to an hyper_primitive object."""
2504- if self ._where and not self ._where (hyper_primitive ):
2505- # Skip hyper primitives that do not pass the `where` predicate.
2506- return hyper_primitive
2507-
2508- if isinstance (hyper_primitive , Float ):
2509- return get_decision (hyper_primitive )
2510-
2511- if isinstance (hyper_primitive , CustomHyper ):
2512- return hyper_primitive .decode (geno .DNA (get_decision (hyper_primitive )))
2513-
2514- assert isinstance (hyper_primitive , Choices )
2515- value = symbolic .List ()
2516- for i in range (hyper_primitive .num_choices ):
2517- # NOTE(daiyip): during registering the hyper primitives when
2518- # constructing the search space, we will need to evaluate every
2519- # candidate in order to pick up sub search spaces correctly, which is
2520- # not necessary for `pg.DynamicEvaluationContext.apply`.
2521- value .append (_apply_child (
2522- hyper_primitive .candidates [get_decision (hyper_primitive , i )]))
2523- if isinstance (hyper_primitive , ChoiceValue ):
2524- assert len (value ) == 1
2525- value = value [0 ]
2526- return value
2527- return dynamic_evaluate (
2528- _apply_decision ,
2529- exit_fn = evaluation_finalizer ,
2530- per_thread = self ._per_thread )
2564+ if self ._where and not self ._where (hyper_primitive ):
2565+ # Delegate the resolution of hyper primitives that do not pass
2566+ # the `where` predicate to its parent context.
2567+ parent_context = _dynamic_evaluation_stack .get_parent (self )
2568+ if parent_context is not None :
2569+ return parent_context .evaluate (hyper_primitive )
2570+ return hyper_primitive
2571+
2572+ if isinstance (hyper_primitive , Float ):
2573+ return get_current_decision (hyper_primitive )
2574+
2575+ if isinstance (hyper_primitive , CustomHyper ):
2576+ return hyper_primitive .decode (
2577+ geno .DNA (get_current_decision (hyper_primitive )))
2578+
2579+ assert isinstance (hyper_primitive , Choices ), hyper_primitive
2580+ value = symbolic .List ()
2581+ for i in range (hyper_primitive .num_choices ):
2582+ # NOTE(daiyip): during registering the hyper primitives when
2583+ # constructing the search space, we will need to evaluate every
2584+ # candidate in order to pick up sub search spaces correctly, which is
2585+ # not necessary for `pg.DynamicEvaluationContext.apply`.
2586+ value .append (_apply_child (
2587+ hyper_primitive .candidates [get_current_decision (hyper_primitive , i )]))
2588+ if isinstance (hyper_primitive , ChoiceValue ):
2589+ assert len (value ) == 1
2590+ value = value [0 ]
2591+ return value
2592+
2593+
2594+ # We maintain a stack of dynamic evaluation context for support search space
2595+ # combination
2596+ class _DynamicEvaluationStack :
2597+ """Dynamic evaluation stack used for dealing with nested evaluation."""
2598+
2599+ _TLS_KEY = 'dynamic_evaluation_stack'
2600+
2601+ def __init__ (self ):
2602+ self ._global_stack = []
2603+
2604+ def ensure_thread_safety (self , context : DynamicEvaluationContext ):
2605+ if ((context .per_thread and self ._global_stack )
2606+ or (not context .per_thread and self ._local_stack )):
2607+ raise ValueError (
2608+ 'Nested dynamic evaluation contexts must be either all per-thread '
2609+ 'or all process-wise. Please check the `per_thread` argument of '
2610+ 'the `pg.hyper.DynamicEvaluationContext` objects being used.' )
2611+
2612+ @property
2613+ def _local_stack (self ):
2614+ """Returns thread-local stack."""
2615+ stack = getattr (_thread_local_state , self ._TLS_KEY , None )
2616+ if stack is None :
2617+ stack = []
2618+ setattr (_thread_local_state , self ._TLS_KEY , stack )
2619+ return stack
2620+
2621+ def push (self , context : DynamicEvaluationContext ):
2622+ """Pushes the context to the stack."""
2623+ stack = self ._local_stack if context .per_thread else self ._global_stack
2624+ stack .append (context )
2625+
2626+ def pop (self , context : DynamicEvaluationContext ):
2627+ """Pops the context from the stack."""
2628+ stack = self ._local_stack if context .per_thread else self ._global_stack
2629+ assert stack
2630+ stack_top = stack .pop (- 1 )
2631+ assert stack_top is context , (stack_top , context )
2632+
2633+ def get_parent (
2634+ self ,
2635+ context : DynamicEvaluationContext ) -> Optional [DynamicEvaluationContext ]:
2636+ """Returns the parent context of the input context."""
2637+ stack = self ._local_stack if context .per_thread else self ._global_stack
2638+ parent = None
2639+ for i in reversed (range (1 , len (stack ))):
2640+ if context is stack [i ]:
2641+ parent = stack [i - 1 ]
2642+ break
2643+ return parent
2644+
2645+
2646+ # System-wise dynamic evaluation stack.
2647+ _dynamic_evaluation_stack = _DynamicEvaluationStack ()
25312648
25322649
25332650def trace (
25342651 fun : Callable [[], Any ],
2652+ * ,
2653+ where : Optional [Callable [[HyperPrimitive ], bool ]] = None ,
25352654 require_hyper_name : bool = False ,
2536- per_thread : bool = True
2537- ) -> DynamicEvaluationContext :
2655+ per_thread : bool = True ) -> DynamicEvaluationContext :
25382656 """Trace the hyper primitives called within a function by executing it.
25392657
25402658 See examples in :class:`pyglove.hyper.DynamicEvaluationContext`.
25412659
25422660 Args:
25432661 fun: Function in which the search space is defined.
2662+ where: A callable object that decide whether a hyper primitive should be
2663+ included when being instantiated under `collect`.
2664+ If None, all hyper primitives under `collect` will be included.
25442665 require_hyper_name: If True, all hyper primitives defined in this scope
25452666 will need to carry their names, which is usually a good idea when the
25462667 function that instantiates the hyper primtives need to be called multiple
@@ -2552,7 +2673,7 @@ def trace(
25522673 An DynamicEvaluationContext that can be passed to `pg.sample`.
25532674 """
25542675 context = DynamicEvaluationContext (
2555- require_hyper_name = require_hyper_name , per_thread = per_thread )
2676+ where = where , require_hyper_name = require_hyper_name , per_thread = per_thread )
25562677 with context .collect ():
25572678 fun ()
25582679 return context
0 commit comments