Skip to content

Commit 81d2f1a

Browse files
daiyippyglove authors
authored andcommitted
Dynamic evaluation to support divide-and-conquer on a search space.
Example: ``` def fun(): return pg.oneof([1, 2], hints='ssd1') + pg.oneof([3, 4], hints='ssd2') ssd1 = pg.hyper.DynamicEvaluationContext(where=lambda x: x.hints == 'ssd1') ssd2 = pg.hyper.DynamicEvaluationContext(where=lambda x: x.hints == 'ssd2') # Partitioning the search space into ssd1 and ssd2. with ssd1.collect(): with ssd2.collect(): fun() # Nested search. for ex1, f1 in pg.sample(ssd1, algorithm1): rs = [] for ex2, f2 in pg.sample(ssd2, algorithm2): with ex1(): with ex2(): r = fun() f2(r) rs.append(r) f1(sum(rs)) ``` PiperOrigin-RevId: 480115321
1 parent 31990e5 commit 81d2f1a

File tree

2 files changed

+343
-89
lines changed

2 files changed

+343
-89
lines changed

pyglove/core/hyper.py

Lines changed: 210 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -2208,6 +2208,17 @@ def foo():
22082208
best_foo, best_reward = example, reward
22092209
"""
22102210

2211+
class _AnnoymousHyperNameAccumulator:
2212+
"""Name accumulator for annoymous hyper primitives."""
2213+
2214+
def __init__(self):
2215+
self.index = 0
2216+
2217+
def next_name(self):
2218+
name = f'decision_{self.index}'
2219+
self.index += 1
2220+
return name
2221+
22112222
def __init__(self,
22122223
where: Optional[Callable[[HyperPrimitive], bool]] = None,
22132224
require_hyper_name: bool = False,
@@ -2236,10 +2247,17 @@ def __init__(self,
22362247
self._where = where
22372248
self._require_hyper_name: bool = require_hyper_name
22382249
self._name_to_hyper: Dict[Text, HyperPrimitive] = dict()
2239-
self._annoymous_hyper_name_accumulator: int = 0
2250+
self._annoymous_hyper_name_accumulator = (
2251+
DynamicEvaluationContext._AnnoymousHyperNameAccumulator())
22402252
self._hyper_dict = symbolic.Dict() if dna_spec is None else None
22412253
self._dna_spec: Optional[geno.DNASpec] = dna_spec
22422254
self._per_thread = per_thread
2255+
self._decision_getter = None
2256+
2257+
@property
2258+
def per_thread(self) -> bool:
2259+
"""Returns True if current context collects/applies decisions per thread."""
2260+
return self._per_thread
22432261

22442262
@property
22452263
def dna_spec(self) -> geno.DNASpec:
@@ -2257,8 +2275,7 @@ def _decision_name(self, hyper_primitive: HyperPrimitive) -> Text:
22572275
raise ValueError(
22582276
f'\'name\' must be specified for hyper '
22592277
f'primitive {hyper_primitive!r}.')
2260-
name = f'decision_{self._annoymous_hyper_name_accumulator}'
2261-
self._annoymous_hyper_name_accumulator += 1
2278+
name = self._annoymous_hyper_name_accumulator.next_name()
22622279
return name
22632280

22642281
@property
@@ -2296,71 +2313,80 @@ def collect(self):
22962313
f'`collect` cannot be called on a dynamic evaluation context that is '
22972314
f'using an external DNASpec: {self._dna_spec}.')
22982315

2299-
with self._collect() as sub_space:
2316+
# Ensure per-thread dynamic evaluation context will not be used
2317+
# together with process-level dynamic evaluation context.
2318+
_dynamic_evaluation_stack.ensure_thread_safety(self)
2319+
2320+
self._hyper_dict = {}
2321+
with dynamic_evaluate(self.add_decision_point, per_thread=self._per_thread):
23002322
try:
2323+
# Push current context to dynamic evaluatoin stack so nested context
2324+
# can defer unresolved hyper primitive to current context.
2325+
_dynamic_evaluation_stack.push(self)
23012326
yield self._hyper_dict
2302-
finally:
2303-
# NOTE(daiyip): when registering new hyper primitives in the sub-space,
2304-
# the keys are already ensured not to conflict with the keys in current
2305-
# search space. Therefore it's safe to update current space.
2306-
self._hyper_dict.update(sub_space)
23072327

2328+
finally:
23082329
# Invalidate DNASpec.
23092330
self._dna_spec = None
23102331

2311-
def _collect(self):
2312-
"""A context manager for collecting hyper primitive within the scope."""
2313-
hyper_dict = symbolic.Dict()
2332+
# Pop current context from dynamic evaluatoin stack.
2333+
_dynamic_evaluation_stack.pop(self)
23142334

2315-
def _register_child(c):
2335+
def add_decision_point(self, hyper_primitive: HyperPrimitive):
2336+
"""Registers a parameter with current context and return its first value."""
2337+
def _add_child_decision_point(c):
23162338
if isinstance(c, types.LambdaType):
23172339
s = schema.get_signature(c)
23182340
if not s.args and not s.has_wildcard_args:
2319-
with self._collect() as child_hyper:
2341+
sub_context = DynamicEvaluationContext(
2342+
where=self._where, per_thread=self._per_thread)
2343+
sub_context._annoymous_hyper_name_accumulator = ( # pylint: disable=protected-access
2344+
self._annoymous_hyper_name_accumulator)
2345+
with sub_context.collect() as hyper_dict:
23202346
v = c()
2321-
return (v, child_hyper)
2347+
return (v, hyper_dict)
23222348
return (c, c)
23232349

2324-
def _register_hyper_primitive(hyper_primitive):
2325-
"""Registers a decision point from an hyper_primitive."""
2326-
if self._where and not self._where(hyper_primitive):
2327-
# Skip hyper primitives that do not pass the `where` predicate.
2328-
return hyper_primitive
2329-
2330-
if isinstance(hyper_primitive, Template):
2331-
return hyper_primitive.value
2332-
2333-
assert isinstance(hyper_primitive, HyperPrimitive), hyper_primitive
2334-
name = self._decision_name(hyper_primitive)
2335-
if isinstance(hyper_primitive, Choices):
2336-
candidate_values, candidates = zip(
2337-
*[_register_child(c) for c in hyper_primitive.candidates])
2338-
if hyper_primitive.choices_distinct:
2339-
assert hyper_primitive.num_choices <= len(hyper_primitive.candidates)
2340-
v = [candidate_values[i] for i in range(hyper_primitive.num_choices)]
2341-
else:
2342-
v = [candidate_values[0]] * hyper_primitive.num_choices
2343-
hyper_primitive = hyper_primitive.clone(deep=True, override={
2344-
'candidates': list(candidates)
2345-
})
2346-
first_value = v[0] if isinstance(hyper_primitive, ChoiceValue) else v
2347-
elif isinstance(hyper_primitive, Float):
2348-
first_value = hyper_primitive.min_value
2350+
if self._where and not self._where(hyper_primitive):
2351+
# Delegate the resolution of hyper primitives that do not pass
2352+
# the `where` predicate to its parent context.
2353+
parent_context = _dynamic_evaluation_stack.get_parent(self)
2354+
if parent_context is not None:
2355+
return parent_context.add_decision_point(hyper_primitive)
2356+
return hyper_primitive
2357+
2358+
if isinstance(hyper_primitive, Template):
2359+
return hyper_primitive.value
2360+
2361+
assert isinstance(hyper_primitive, HyperPrimitive), hyper_primitive
2362+
name = self._decision_name(hyper_primitive)
2363+
if isinstance(hyper_primitive, Choices):
2364+
candidate_values, candidates = zip(
2365+
*[_add_child_decision_point(c) for c in hyper_primitive.candidates])
2366+
if hyper_primitive.choices_distinct:
2367+
assert hyper_primitive.num_choices <= len(hyper_primitive.candidates)
2368+
v = [candidate_values[i] for i in range(hyper_primitive.num_choices)]
23492369
else:
2350-
assert isinstance(hyper_primitive, CustomHyper), hyper_primitive
2351-
first_value = hyper_primitive.decode(hyper_primitive.first_dna())
2370+
v = [candidate_values[0]] * hyper_primitive.num_choices
2371+
hyper_primitive = hyper_primitive.clone(deep=True, override={
2372+
'candidates': list(candidates)
2373+
})
2374+
first_value = v[0] if isinstance(hyper_primitive, ChoiceValue) else v
2375+
elif isinstance(hyper_primitive, Float):
2376+
first_value = hyper_primitive.min_value
2377+
else:
2378+
assert isinstance(hyper_primitive, CustomHyper), hyper_primitive
2379+
first_value = hyper_primitive.decode(hyper_primitive.first_dna())
23522380

2353-
if (name in self._name_to_hyper
2354-
and hyper_primitive != self._name_to_hyper[name]):
2355-
raise ValueError(
2356-
f'Found different hyper primitives under the same name {name!r}: '
2357-
f'Instance1={self._name_to_hyper[name]!r}, '
2358-
f'Instance2={hyper_primitive!r}.')
2359-
hyper_dict[name] = hyper_primitive
2360-
self._name_to_hyper[name] = hyper_primitive
2361-
return first_value
2362-
return dynamic_evaluate(
2363-
_register_hyper_primitive, hyper_dict, per_thread=self._per_thread)
2381+
if (name in self._name_to_hyper
2382+
and hyper_primitive != self._name_to_hyper[name]):
2383+
raise ValueError(
2384+
f'Found different hyper primitives under the same name {name!r}: '
2385+
f'Instance1={self._name_to_hyper[name]!r}, '
2386+
f'Instance2={hyper_primitive!r}.')
2387+
self._hyper_dict[name] = hyper_primitive
2388+
self._name_to_hyper[name] = hyper_primitive
2389+
return first_value
23642390

23652391
def _decision_getter_and_evaluation_finalizer(
23662392
self, decisions: Union[geno.DNA, List[Union[int, float, str]]]):
@@ -2461,6 +2487,7 @@ def err_on_unused_decisions():
24612487
f'Found extra decision values that are not used: {remaining!r}')
24622488
return get_decision_by_position, err_on_unused_decisions
24632489

2490+
@contextlib.contextmanager
24642491
def apply(
24652492
self, decisions: Union[geno.DNA, List[Union[int, float, str]]]):
24662493
"""Context manager for applying decisions.
@@ -2482,65 +2509,159 @@ def fun():
24822509
decisions: A DNA or a list of numbers or strings as decisions for currrent
24832510
search space.
24842511
2485-
Returns:
2486-
Context manager for applying decisions to the function that defines the
2487-
search space.
2512+
Yields:
2513+
None
24882514
"""
24892515
if not isinstance(decisions, (geno.DNA, list)):
24902516
raise ValueError('`decisions` should be a DNA or a list of numbers.')
24912517

2492-
get_decision, evaluation_finalizer = (
2518+
# Ensure per-thread dynamic evaluation context will not be used
2519+
# together with process-level dynamic evaluation context.
2520+
_dynamic_evaluation_stack.ensure_thread_safety(self)
2521+
2522+
get_current_decision, evaluation_finalizer = (
24932523
self._decision_getter_and_evaluation_finalizer(decisions))
24942524

2525+
has_errors = False
2526+
with dynamic_evaluate(self.evaluate, per_thread=self._per_thread):
2527+
try:
2528+
# Set decision getter for current decision.
2529+
self._decision_getter = get_current_decision
2530+
2531+
# Push current context to dynamic evaluation stack so nested context
2532+
# can delegate evaluate to current context.
2533+
_dynamic_evaluation_stack.push(self)
2534+
2535+
yield
2536+
except Exception:
2537+
has_errors = True
2538+
raise
2539+
finally:
2540+
# Pop current context from dynamic evaluatoin stack.
2541+
_dynamic_evaluation_stack.pop(self)
2542+
2543+
# Reset decisions.
2544+
self._decision_getter = None
2545+
2546+
# Call evaluation finalizer to make sure all decisions are used.
2547+
if not has_errors:
2548+
evaluation_finalizer()
2549+
2550+
def evaluate(self, hyper_primitive: HyperPrimitive):
2551+
"""Evaluates a hyper primitive based on current decisions."""
2552+
if self._decision_getter is None:
2553+
raise ValueError(
2554+
'`evaluate` needs to be called under the `apply` context.')
2555+
2556+
get_current_decision = self._decision_getter
24952557
def _apply_child(c):
24962558
if isinstance(c, types.LambdaType):
24972559
s = schema.get_signature(c)
24982560
if not s.args and not s.has_wildcard_args:
24992561
return c()
25002562
return c
25012563

2502-
def _apply_decision(hyper_primitive: HyperPrimitive):
2503-
"""Apply a decision value to an hyper_primitive object."""
2504-
if self._where and not self._where(hyper_primitive):
2505-
# Skip hyper primitives that do not pass the `where` predicate.
2506-
return hyper_primitive
2507-
2508-
if isinstance(hyper_primitive, Float):
2509-
return get_decision(hyper_primitive)
2510-
2511-
if isinstance(hyper_primitive, CustomHyper):
2512-
return hyper_primitive.decode(geno.DNA(get_decision(hyper_primitive)))
2513-
2514-
assert isinstance(hyper_primitive, Choices)
2515-
value = symbolic.List()
2516-
for i in range(hyper_primitive.num_choices):
2517-
# NOTE(daiyip): during registering the hyper primitives when
2518-
# constructing the search space, we will need to evaluate every
2519-
# candidate in order to pick up sub search spaces correctly, which is
2520-
# not necessary for `pg.DynamicEvaluationContext.apply`.
2521-
value.append(_apply_child(
2522-
hyper_primitive.candidates[get_decision(hyper_primitive, i)]))
2523-
if isinstance(hyper_primitive, ChoiceValue):
2524-
assert len(value) == 1
2525-
value = value[0]
2526-
return value
2527-
return dynamic_evaluate(
2528-
_apply_decision,
2529-
exit_fn=evaluation_finalizer,
2530-
per_thread=self._per_thread)
2564+
if self._where and not self._where(hyper_primitive):
2565+
# Delegate the resolution of hyper primitives that do not pass
2566+
# the `where` predicate to its parent context.
2567+
parent_context = _dynamic_evaluation_stack.get_parent(self)
2568+
if parent_context is not None:
2569+
return parent_context.evaluate(hyper_primitive)
2570+
return hyper_primitive
2571+
2572+
if isinstance(hyper_primitive, Float):
2573+
return get_current_decision(hyper_primitive)
2574+
2575+
if isinstance(hyper_primitive, CustomHyper):
2576+
return hyper_primitive.decode(
2577+
geno.DNA(get_current_decision(hyper_primitive)))
2578+
2579+
assert isinstance(hyper_primitive, Choices), hyper_primitive
2580+
value = symbolic.List()
2581+
for i in range(hyper_primitive.num_choices):
2582+
# NOTE(daiyip): during registering the hyper primitives when
2583+
# constructing the search space, we will need to evaluate every
2584+
# candidate in order to pick up sub search spaces correctly, which is
2585+
# not necessary for `pg.DynamicEvaluationContext.apply`.
2586+
value.append(_apply_child(
2587+
hyper_primitive.candidates[get_current_decision(hyper_primitive, i)]))
2588+
if isinstance(hyper_primitive, ChoiceValue):
2589+
assert len(value) == 1
2590+
value = value[0]
2591+
return value
2592+
2593+
2594+
# We maintain a stack of dynamic evaluation context for support search space
2595+
# combination
2596+
class _DynamicEvaluationStack:
2597+
"""Dynamic evaluation stack used for dealing with nested evaluation."""
2598+
2599+
_TLS_KEY = 'dynamic_evaluation_stack'
2600+
2601+
def __init__(self):
2602+
self._global_stack = []
2603+
2604+
def ensure_thread_safety(self, context: DynamicEvaluationContext):
2605+
if ((context.per_thread and self._global_stack)
2606+
or (not context.per_thread and self._local_stack)):
2607+
raise ValueError(
2608+
'Nested dynamic evaluation contexts must be either all per-thread '
2609+
'or all process-wise. Please check the `per_thread` argument of '
2610+
'the `pg.hyper.DynamicEvaluationContext` objects being used.')
2611+
2612+
@property
2613+
def _local_stack(self):
2614+
"""Returns thread-local stack."""
2615+
stack = getattr(_thread_local_state, self._TLS_KEY, None)
2616+
if stack is None:
2617+
stack = []
2618+
setattr(_thread_local_state, self._TLS_KEY, stack)
2619+
return stack
2620+
2621+
def push(self, context: DynamicEvaluationContext):
2622+
"""Pushes the context to the stack."""
2623+
stack = self._local_stack if context.per_thread else self._global_stack
2624+
stack.append(context)
2625+
2626+
def pop(self, context: DynamicEvaluationContext):
2627+
"""Pops the context from the stack."""
2628+
stack = self._local_stack if context.per_thread else self._global_stack
2629+
assert stack
2630+
stack_top = stack.pop(-1)
2631+
assert stack_top is context, (stack_top, context)
2632+
2633+
def get_parent(
2634+
self,
2635+
context: DynamicEvaluationContext) -> Optional[DynamicEvaluationContext]:
2636+
"""Returns the parent context of the input context."""
2637+
stack = self._local_stack if context.per_thread else self._global_stack
2638+
parent = None
2639+
for i in reversed(range(1, len(stack))):
2640+
if context is stack[i]:
2641+
parent = stack[i - 1]
2642+
break
2643+
return parent
2644+
2645+
2646+
# System-wise dynamic evaluation stack.
2647+
_dynamic_evaluation_stack = _DynamicEvaluationStack()
25312648

25322649

25332650
def trace(
25342651
fun: Callable[[], Any],
2652+
*,
2653+
where: Optional[Callable[[HyperPrimitive], bool]] = None,
25352654
require_hyper_name: bool = False,
2536-
per_thread: bool = True
2537-
) -> DynamicEvaluationContext:
2655+
per_thread: bool = True) -> DynamicEvaluationContext:
25382656
"""Trace the hyper primitives called within a function by executing it.
25392657
25402658
See examples in :class:`pyglove.hyper.DynamicEvaluationContext`.
25412659
25422660
Args:
25432661
fun: Function in which the search space is defined.
2662+
where: A callable object that decide whether a hyper primitive should be
2663+
included when being instantiated under `collect`.
2664+
If None, all hyper primitives under `collect` will be included.
25442665
require_hyper_name: If True, all hyper primitives defined in this scope
25452666
will need to carry their names, which is usually a good idea when the
25462667
function that instantiates the hyper primtives need to be called multiple
@@ -2552,7 +2673,7 @@ def trace(
25522673
An DynamicEvaluationContext that can be passed to `pg.sample`.
25532674
"""
25542675
context = DynamicEvaluationContext(
2555-
require_hyper_name=require_hyper_name, per_thread=per_thread)
2676+
where=where, require_hyper_name=require_hyper_name, per_thread=per_thread)
25562677
with context.collect():
25572678
fun()
25582679
return context

0 commit comments

Comments
 (0)