# Helper methods from ax/service/scheduler.py; both take the scheduler
# instance as `self`.


def _check_if_can_run_trials(self) -> bool:
    """Check if we have capacity to run new trials.

    Returns:
        ``False`` when the number of pending trials has reached the value
        from ``_get_max_pending_trials()``, or when the runner reports a
        capacity below 1 (a reported capacity of ``-1`` is treated as "no
        capacity limit"); ``True`` otherwise.
    """
    capacity = self.runner.poll_available_capacity()
    max_pending_trials = self._get_max_pending_trials()
    num_pending_trials = len(self.pending_trials)

    # Check if we've reached maximum pending trials.
    if num_pending_trials >= max_pending_trials:
        self.logger.debug(
            f"`max_pending_trials={max_pending_trials}` and {num_pending_trials} "
            "trials are currently pending; not initiating any additional trials."
        )
        return False

    # Check if we have available capacity (-1 means the runner imposes no
    # limit, so only non-negative readings below 1 block new trials).
    if capacity != -1 and capacity < 1:
        self.logger.debug("There is no capacity to run any trials.")
        return False

    return True


def _compute_num_trials_to_run(self, max_new_trials: int) -> tuple[int, int]:
    """Compute how many trials to deploy and how many of them must be new.

    The number of trials to run starts from the runner's available capacity
    (or 1 when ``options.run_trials_in_batches`` is falsy), is capped by the
    remaining room under the max-pending-trials limit, and, when
    ``options.total_trials`` is set, by the number of trials still left in
    that total.

    Args:
        max_new_trials: Upper bound on the number of trials to generate.

    Returns:
        A ``(num_trials_to_run, num_new_trials)`` tuple. The first element
        is the slice length to take from ``candidate_trials``; the second is
        how many additional trials to generate once the available candidate
        trials have been counted. Both values are never negative.
    """
    capacity = (
        self.runner.poll_available_capacity()
        if self.options.run_trials_in_batches
        else 1
    )
    total_trials = self.options.total_trials
    pending_headroom = self._get_max_pending_trials() - len(self.pending_trials)

    # A capacity of -1 means "no capacity limit": only the pending-trials
    # headroom restricts how many trials can be started.
    num_trials_to_run = (
        pending_headroom if capacity == -1 else min(pending_headroom, capacity)
    )

    if total_trials is not None:
        left_in_total = total_trials - len(self.trials_expecting_data)
        num_trials_to_run = min(num_trials_to_run, left_in_total)

    # Any of the bounds above can go negative (e.g. more trials expecting
    # data than `total_trials`). A negative count used as a slice stop would
    # select candidates from the *end* of the list instead of none, so clamp.
    num_trials_to_run = max(num_trials_to_run, 0)

    num_candidates_used = len(self.candidate_trials[:num_trials_to_run])
    num_new_trials = min(num_trials_to_run - num_candidates_used, max_new_trials)

    return num_trials_to_run, max(num_new_trials, 0)
evaluation(s) if stopping criterion is not triggered, maximum parallelism is not currently reached, and capacity allows. @@ -1157,10 +1204,28 @@ def run(self, max_new_trials: int, timeout_hours: float | None = None) -> bool: # Check if capacity allows for running new evaluations and generate as many # trials as possible, limited by capacity and model requirements. self._sleep_if_too_early_to_poll() - existing_trials, new_trials = self._prepare_trials( + + if not self._check_if_can_run_trials(): + return False + + existing_trial_count, new_trial_count = self._compute_num_trials_to_run( max_new_trials=max_new_trials ) + existing_trials = self.candidate_trials[:existing_trial_count] + new_trials, gen_candidates_err = ( + self.generate_candidates(num_trials=new_trial_count) + if new_trial_count > 0 + else ([], None) + ) + + if gen_candidates_err: + new_trials = [] + self.logger.exception( + "An unexpected error occurred while generating " + f"trials: {gen_candidates_err}. " + ) + if not existing_trials and not new_trials: # Unable to gen. new run due to max parallelism limit or need for data # or unable to run trials due to lack of capacity.