Skip to content

Commit 357f9eb

Browse files
mgrange1998 and meta-codesync[bot]
authored and committed
Rename get_max_parallelism to get_max_concurrency in AxClient (facebook#4923)
Summary: Pull Request resolved: facebook#4923 Renames `AxClient.get_max_parallelism()` to `get_max_concurrency()` and updates internal variable names, comments, and docstrings to use "concurrency" terminology. The old `get_max_parallelism` is preserved as a deprecated stub raising `NotImplementedError`. Also updates `get_recommended_max_parallelism` to point to the new name, and imports `MaxParallelismReachedException` / `MaxGenerationParallelism` under concurrency-named aliases. `get_max_parallelism` is only used directly in ad-hoc notebooks, making this a low-risk rename. Reviewed By: saitcakmak Differential Revision: D93771849 fbshipit-source-id: d19096a66f81f214167fdfff8354d97a2a3b24d3
1 parent 72a3a58 commit 357f9eb

File tree

2 files changed

+31
-26
lines changed

2 files changed

+31
-26
lines changed

ax/service/ax_client.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -836,39 +836,39 @@ def get_trials_data_frame(self) -> pd.DataFrame:
836836
"""
837837
return self.experiment.to_df()
838838

839-
def get_max_parallelism(self) -> list[tuple[int, int]]:
840-
"""Retrieves maximum number of trials that can be scheduled in parallel
839+
def get_max_concurrency(self) -> list[tuple[int, int]]:
840+
"""Retrieves maximum number of trials that can be scheduled concurrently
841841
at different stages of optimization.
842842
843843
Some optimization algorithms profit significantly from sequential
844844
optimization (i.e. suggest a few points, get updated with data for them,
845845
repeat, see https://ax.dev/docs/bayesopt.html).
846-
Parallelism setting indicates how many trials should be running simulteneously
846+
Concurrency setting indicates how many trials should be running simultaneously
847847
(generated, but not yet completed with data).
848848
849849
The output of this method is mapping of form
850-
{num_trials -> max_parallelism_setting}, where the max_parallelism_setting
851-
is used for num_trials trials. If max_parallelism_setting is -1, as
852-
many of the trials can be ran in parallel, as necessary. If num_trials
853-
in a tuple is -1, then the corresponding max_parallelism_setting
850+
{num_trials -> max_concurrency_setting}, where the max_concurrency_setting
851+
is used for num_trials trials. If max_concurrency_setting is -1, as
852+
many of the trials can be ran concurrently, as necessary. If num_trials
853+
in a tuple is -1, then the corresponding max_concurrency_setting
854854
should be used for all subsequent trials.
855855
856856
For example, if the returned list is [(5, -1), (12, 6), (-1, 3)],
857-
the schedule could be: run 5 trials with any parallelism, run 6 trials in
858-
parallel twice, run 3 trials in parallel for as long as needed. Here,
857+
the schedule could be: run 5 trials with any concurrency, run 6 trials
858+
concurrently twice, run 3 trials concurrently for as long as needed. Here,
859859
'running' a trial means obtaining a next trial from `AxClient` through
860860
get_next_trials and completing it with data when available.
861861
862862
Returns:
863-
Mapping of form {num_trials -> max_parallelism_setting}.
863+
Mapping of form {num_trials -> max_concurrency_setting}.
864864
"""
865-
parallelism_settings = []
865+
concurrency_settings = []
866866
for node in self.generation_strategy._nodes:
867-
# Check pausing_criteria for max parallelism
868-
max_parallelism = None
867+
# Check pausing_criteria for max concurrency
868+
max_concurrency = None
869869
for pc in node.pausing_criteria:
870870
if isinstance(pc, MaxGenerationParallelism):
871-
max_parallelism = pc.threshold
871+
max_concurrency = pc.threshold
872872
break
873873
# Try to get num_trials from the node. If there's no MinTrials
874874
# criterion (unlimited trials), num_trials will raise UserInputError.
@@ -877,13 +877,16 @@ def get_max_parallelism(self) -> list[tuple[int, int]]:
877877
num_trials = node.num_trials
878878
except UserInputError:
879879
num_trials = -1
880-
parallelism_settings.append(
880+
concurrency_settings.append(
881881
(
882882
num_trials,
883-
max_parallelism if max_parallelism is not None else num_trials,
883+
max_concurrency if max_concurrency is not None else num_trials,
884884
)
885885
)
886-
return parallelism_settings
886+
return concurrency_settings
887+
888+
def get_max_parallelism(self) -> list[tuple[int, int]]:
889+
raise NotImplementedError("Use `get_max_concurrency` instead.")
887890

888891
def get_optimization_trace(
889892
self, objective_optimum: float | None = None
@@ -1702,8 +1705,8 @@ def __repr__(self) -> str:
17021705
@staticmethod
17031706
def get_recommended_max_parallelism() -> None:
17041707
raise NotImplementedError(
1705-
"Use `get_max_parallelism` instead; parallelism levels are now "
1706-
"enforced in generation strategy, so max parallelism is no longer "
1708+
"Use `get_max_concurrency` instead; concurrency levels are now "
1709+
"enforced in generation strategy, so max concurrency is no longer "
17071710
"just recommended."
17081711
)
17091712

ax/service/tests/test_ax_client.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,14 +1615,14 @@ def test_keep_generating_without_data(self) -> None:
16151615
]
16161616
self.assertEqual(len(node0_pausing_criteria), 0)
16171617

1618-
# Check that max_parallelism is None by verifying no MaxGenerationParallelism
1618+
# Check that max_concurrency is None by verifying no MaxGenerationConcurrency
16191619
# criterion exists in pausing_criteria
1620-
node1_max_parallelism = [
1620+
node1_max_concurrency = [
16211621
pc
16221622
for pc in ax_client.generation_strategy._nodes[1].pausing_criteria
16231623
if isinstance(pc, MaxGenerationParallelism)
16241624
]
1625-
self.assertEqual(len(node1_max_parallelism), 0)
1625+
self.assertEqual(len(node1_max_concurrency), 0)
16261626

16271627
for _ in range(10):
16281628
ax_client.get_next_trial()
@@ -1938,17 +1938,17 @@ def test_relative_oc_without_sq(self) -> None:
19381938
def test_recommended_parallelism(self) -> None:
19391939
ax_client = AxClient()
19401940
with self.assertRaisesRegex(AssertionError, "No generation strategy"):
1941-
ax_client.get_max_parallelism()
1941+
ax_client.get_max_concurrency()
19421942
ax_client.create_experiment(
19431943
parameters=[
19441944
{"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
19451945
{"name": "y", "type": "range", "bounds": [0.0, 15.0]},
19461946
],
19471947
)
1948-
self.assertEqual(ax_client.get_max_parallelism(), [(5, 5), (-1, 3)])
1948+
self.assertEqual(ax_client.get_max_concurrency(), [(5, 5), (-1, 3)])
19491949
self.assertEqual(
19501950
run_trials_using_recommended_parallelism(
1951-
ax_client, ax_client.get_max_parallelism(), 20
1951+
ax_client, ax_client.get_max_concurrency(), 20
19521952
),
19531953
0,
19541954
)
@@ -2319,6 +2319,8 @@ def test_deprecated_save_load_method_errors(self) -> None:
23192319
ax_client.load_experiment("test_experiment")
23202320
with self.assertRaises(NotImplementedError):
23212321
ax_client.get_recommended_max_parallelism()
2322+
with self.assertRaises(NotImplementedError):
2323+
ax_client.get_max_parallelism()
23222324

23232325
def test_find_last_trial_with_parameterization(self) -> None:
23242326
ax_client = AxClient()
@@ -2871,7 +2873,7 @@ def test_estimate_early_stopping_savings(self) -> None:
28712873

28722874
self.assertEqual(ax_client.estimate_early_stopping_savings(), 0)
28732875

2874-
def test_max_parallelism_exception_when_early_stopping(self) -> None:
2876+
def test_max_concurrency_exception_when_early_stopping(self) -> None:
28752877
ax_client = AxClient()
28762878
ax_client.create_experiment(
28772879
parameters=[

0 commit comments

Comments
 (0)