Skip to content

Commit a57e9c0

Browse files
mgarrard and meta-codesync[bot]
authored and committed
Add caching to common methods (#4830)
Summary: Pull Request resolved: #4830 This method is called many, many times during generation, and its computational cost adds up over time. By caching it, we can see significant improvements in computation time, especially in high-trial-count regimes. Reviewed By: mpolson64 Differential Revision: D91552553
1 parent 7e0d2c1 commit a57e9c0

2 files changed

Lines changed: 21 additions & 5 deletions

File tree

ax/generation_strategy/generation_node.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,10 @@ def __init__(
190190
self.fallback_specs = (
191191
fallback_specs if fallback_specs is not None else DEFAULT_FALLBACK
192192
)
193+
# Cache for trials_from_node property to avoid recomputation
194+
# on every access. Invalidated when trial count changes.
195+
self._trials_from_node_cache: set[int] = set()
196+
self._cached_trial_count: int | None = None
193197

194198
@property
195199
def name(self) -> str:
@@ -724,17 +728,25 @@ def _pick_fitted_adapter_to_gen_from(self) -> GeneratorSpec:
724728
def trials_from_node(self) -> set[int]:
725729
"""Returns a set containing the indices of trials generated by this node.
726730
731+
Results are cached and invalidated when the experiment's trial count changes.
732+
727733
Returns:
728734
Set[int]: A set containing all the indices of trials generated by this node.
729735
"""
736+
current_trial_count = len(self.experiment.trials)
737+
if self._cached_trial_count == current_trial_count:
738+
return self._trials_from_node_cache
739+
740+
# (re)-build cache
730741
trials_from_node = set()
731-
for _idx, trial in self.experiment.trials.items():
742+
for trial in self.experiment.trials.values():
732743
for gr in trial.generator_runs:
733-
if (
734-
gr._generation_node_name is not None
735-
and gr._generation_node_name == self.name
736-
):
744+
if gr._generation_node_name == self.name:
737745
trials_from_node.add(trial.index)
746+
break
747+
748+
self._trials_from_node_cache = trials_from_node
749+
self._cached_trial_count = current_trial_count
738750
return trials_from_node
739751

740752
@property

ax/generation_strategy/generation_strategy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,10 @@ def _unset_non_persistent_state_fields(self) -> None:
361361
n._step_index = None
362362
if len(n.generator_specs) > 1:
363363
n._generator_spec_to_gen_from = None
364+
# Reset cache fields that are used for performance optimization only
365+
# and should not affect equality comparisons.
366+
n._trials_from_node_cache = set()
367+
n._cached_trial_count = None
364368

365369
# TODO: Deprecate `steps` argument fully in Q1'26.
366370
def _validate_and_set_step_sequence(self, steps: list[GenerationNode]) -> None:

0 commit comments

Comments
 (0)