Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor _prepare_trials #3176

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,53 @@ def _check_exit_status_and_report_results(
idle_callback, force_refit=True
)

def _check_if_can_run_trials(self) -> bool:
"""Check if we have capacity to run new trials."""

capacity = self.runner.poll_available_capacity()
max_pending_trials = self._get_max_pending_trials()
num_pending_trials = len(self.pending_trials)
# Check if we've reached maximum pending trials
if num_pending_trials >= max_pending_trials:
self.logger.debug(
f"`max_pending_trials={max_pending_trials}` and {num_pending_trials} "
"trials are currently pending; not initiating any additional trials."
)
return False
# Check if we have available capacity
if capacity != -1 and capacity < 1:
self.logger.debug("There is no capacity to run any trials.")
return False
return True

def _compute_num_trials_to_run(self, max_new_trials: int) -> tuple[int, int]:
num_existing_trials_to_run = (
self.runner.poll_available_capacity()
if self.options.run_trials_in_batches
else 1
)
total_trials = self.options.total_trials
max_pending_trials = self._get_max_pending_trials()

num_pending_trials = len(self.pending_trials)
max_pending_upper_bound = max_pending_trials - num_pending_trials

num_existing_trials_to_run = (
max_pending_upper_bound
if num_existing_trials_to_run == -1
else min(max_pending_upper_bound, num_existing_trials_to_run)
)

if total_trials is not None:
left_in_total = total_trials - len(self.trials_expecting_data)
num_existing_trials_to_run = min(num_existing_trials_to_run, left_in_total)

return num_existing_trials_to_run, min(
num_existing_trials_to_run
- len(self.candidate_trials[:num_existing_trials_to_run]),
max_new_trials,
)

def run(self, max_new_trials: int, timeout_hours: float | None = None) -> bool:
"""Schedules trial evaluation(s) if stopping criterion is not triggered,
maximum parallelism is not currently reached, and capacity allows.
Expand Down Expand Up @@ -1157,10 +1204,28 @@ def run(self, max_new_trials: int, timeout_hours: float | None = None) -> bool:
# Check if capacity allows for running new evaluations and generate as many
# trials as possible, limited by capacity and model requirements.
self._sleep_if_too_early_to_poll()
existing_trials, new_trials = self._prepare_trials(

if not self._check_if_can_run_trials():
return False

existing_trial_count, new_trial_count = self._compute_num_trials_to_run(
max_new_trials=max_new_trials
)

existing_trials = self.candidate_trials[:existing_trial_count]
new_trials, gen_candidates_err = (
self.generate_candidates(num_trials=new_trial_count)
if new_trial_count > 0
else ([], None)
)

if gen_candidates_err:
new_trials = []
self.logger.exception(
"An unexpected error occurred while generating "
f"trials: {gen_candidates_err}. "
)

if not existing_trials and not new_trials:
# Unable to gen. new run due to max parallelism limit or need for data
# or unable to run trials due to lack of capacity.
Expand Down
Loading