Skip to content

Commit ad21e82

Browse files
mgarrardfacebook-github-bot
authored andcommitted
Add caching to common methods
Summary: This method is called many, many times during generation and it's computational cost adds up over time. By cacheing it we can significant improvements in computation time, especially in high trial count regimes. Differential Revision: D91552553
1 parent d9b0326 commit ad21e82

1 file changed

Lines changed: 20 additions & 5 deletions

File tree

ax/generation_strategy/generation_node.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,10 @@ def __init__(
190190
self.fallback_specs = (
191191
fallback_specs if fallback_specs is not None else DEFAULT_FALLBACK
192192
)
193+
# Cache for trials_from_node property to avoid recomputation
194+
# on every access. Invalidated when trial count changes.
195+
self._trials_from_node_cache: set[int] | None = None
196+
self._cached_trial_count: int = -1
193197

194198
@property
195199
def name(self) -> str:
@@ -724,17 +728,28 @@ def _pick_fitted_adapter_to_gen_from(self) -> GeneratorSpec:
724728
def trials_from_node(self) -> set[int]:
725729
"""Returns a set containing the indices of trials generated by this node.
726730
731+
Results are cached and invalidated when the experiment's trial count changes.
732+
727733
Returns:
728734
Set[int]: A set containing all the indices of trials generated by this node.
729735
"""
736+
current_trial_count = len(self.experiment.trials)
737+
if (
738+
self._trials_from_node_cache is not None
739+
and self._cached_trial_count == current_trial_count
740+
):
741+
return self._trials_from_node_cache
742+
743+
# (re)-build cache
730744
trials_from_node = set()
731-
for _idx, trial in self.experiment.trials.items():
745+
for trial in self.experiment.trials.values():
732746
for gr in trial.generator_runs:
733-
if (
734-
gr._generation_node_name is not None
735-
and gr._generation_node_name == self.name
736-
):
747+
if gr._generation_node_name == self.name:
737748
trials_from_node.add(trial.index)
749+
break
750+
751+
self._trials_from_node_cache = trials_from_node
752+
self._cached_trial_count = current_trial_count
738753
return trials_from_node
739754

740755
@property

0 commit comments

Comments
 (0)