
[GENERAL SUPPORT]: scheduler + batch question #3247

Open
@VMLC-PV

Description

Question

I am trying to write a scheduler that runs simulations in batches while still using AxClient, and I cannot make it work properly.
I modified the Scheduler tutorial so that jobs run through a multiprocessing pool and write their results to files in a tmp folder, from which I later fetch them.
I rewrote my code around the branin function so that other people can test it (see code below).

Basically, when I set init_seconds_between_polls in

options=SchedulerOptions(run_trials_in_batches=True, init_seconds_between_polls=0.1, trial_type=TrialType.BATCH_TRIAL, batch_size=4),

to a small value (i.e. shorter than the run time of most processes), it fails with:

Scheduler: MetricFetchE INFO: Because branin is an objective, marking trial 19 as TrialStatus.FAILED.

which I guess comes from #L2025, and I don't understand why this happens.
If I set init_seconds_between_polls=4, so that polling only happens once all jobs are done, things seem to work, so I don't understand what I am missing.
In my real-life case I don't know in advance how long I need to wait between polls, so I expected that setting init_seconds_between_polls to a small value would just check often whether all jobs in the batch are done and then proceed, but that is not what happens...
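
For reference, here is a minimal sketch of the polling setup I expected to work; the backoff fields come from the SchedulerOptions docstring (whether they are the right knobs for my case is part of my question):

from ax.service.scheduler import SchedulerOptions, TrialType

# Poll quickly at first, then back off geometrically, instead of guessing one
# fixed interval that must outlast the slowest job.
options = SchedulerOptions(
    run_trials_in_batches=True,
    trial_type=TrialType.BATCH_TRIAL,
    batch_size=4,
    init_seconds_between_polls=0.1,  # first wait between polls
    seconds_between_polls_backoff_factor=1.5,  # multiply the wait after each poll
)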

Please provide any relevant code snippet if applicable.

from ax.utils.measurement.synthetic_functions import branin

import os, json, uuid, time, random
import numpy as np
import pandas as pd
import torch
from collections import defaultdict
from typing import Any, Dict, NamedTuple, Union, Iterable, Set

from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.metric import Metric, MetricFetchResult, MetricFetchE
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.runner import Runner
from ax.core.trial import Trial
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.modelbridge.transforms.log import Log
from ax.modelbridge.transforms.remove_fixed import RemoveFixed
from ax.modelbridge.transforms.standardize_y import StandardizeY
from ax.modelbridge.transforms.unit_x import UnitX
from ax.service.ax_client import AxClient
from ax.service.scheduler import Scheduler, SchedulerOptions, TrialType
from ax.service.utils.report_utils import exp_to_df
from ax.utils.common.result import Ok, Err

from torch.multiprocessing import Pool, set_start_method
from botorch.acquisition.logei import qLogNoisyExpectedImprovement

try: # needed for multiprocessing when using pytorch
    set_start_method('spawn')
except RuntimeError:
    pass

class MockJob(NamedTuple):
    """Dummy class to represent a job scheduled on `MockJobQueue`."""

    id: int
    parameters: Dict[str, Union[str, float, int, bool]]

    def run(self, job_id, parameters, tmp_dir=None):
        x1 = parameters['x1']
        x2 = parameters['x2']
        branin_data = branin(x1, x2)
        time.sleep(random.uniform(0.1, 3))  # simulate a variable job duration
        # cast to a plain float so json can serialize numpy scalars
        res_dic = {'branin': float(branin_data)}
        print('Job ID:', job_id, 'Parameters:', parameters, 'Results:', res_dic)
        # save the results in the tmp folder as <job_id>.json; the poller treats
        # the existence of this file as the completion signal
        if tmp_dir is not None:
            if not os.path.exists(tmp_dir):
                os.makedirs(tmp_dir)
            with open(os.path.join(tmp_dir, str(job_id) + '.json'), 'w') as fp:
                json.dump(res_dic, fp)

class MockJobQueueClient:
    """Dummy class to represent a job queue where the Ax `Scheduler` will
    deploy trial evaluation runs during optimization.
    """

    # class-level dict, deliberately shared across instances so that the
    # runner's client and the metric's client see the same jobs
    jobs: Dict[str, MockJob] = {}

    def __init__(self, pool=None, tmp_dir=None):
        self.pool = pool
        self.tmp_dir = tmp_dir

    def schedule_job_with_parameters(
        self, parameters: Dict[str, Union[str, float, int, bool]]
    ) -> str:
        """Schedules an evaluation job with given parameters and returns the job ID."""
        # Code to actually schedule the job and produce an ID would go here;
        job_id = str(uuid.uuid4())
        # add the mock run to the queue and submit it to the pool
        self.jobs[job_id] = MockJob(job_id, parameters)
        self.pool.apply_async(
            self.jobs[job_id].run, args=(job_id, parameters, self.tmp_dir)
        )
        return job_id

    def get_job_status(self, job_id: str) -> TrialStatus:
        """Get the status of the job with the given ID. For simplicity of the
        example, return an Ax `TrialStatus`.
        """
        # check if <job_id>.json exists in the tmp directory
        if os.path.exists(os.path.join(self.tmp_dir, str(job_id) + '.json')):
            # load the results
            with open(os.path.join(self.tmp_dir, str(job_id) + '.json'), 'r') as fp:
                res_dic = json.load(fp)
            # check for NaN values in res_dic
            for key in res_dic.keys():
                if np.isnan(res_dic[key]):
                    return TrialStatus.FAILED
            return TrialStatus.COMPLETED
        else:
            return TrialStatus.RUNNING

    def get_outcome_value_for_completed_job(self, job_id: str) -> Dict[str, float]:
        """Get evaluation results for a given completed job."""
        # In a real external system, this would retrieve real relevant outcomes
        # and not a synthetic function value.
        # check if <job_id>.json exists in the tmp directory
        if os.path.exists(os.path.join(self.tmp_dir, str(job_id) + '.json')):
            # load the results, then delete the file
            with open(os.path.join(self.tmp_dir, str(job_id) + '.json'), 'r') as fp:
                res_dic = json.load(fp)
            os.remove(os.path.join(self.tmp_dir, str(job_id) + '.json'))
            return res_dic
        else:
            raise ValueError('The job is not completed yet')
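
# The completion handshake between MockJob.run and MockJobQueueClient: a job
# signals completion by writing <job_id>.json into tmp_dir; get_job_status()
# reports COMPLETED as soon as that file exists, and
# get_outcome_value_for_completed_job() reads it and then deletes it.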



def get_mock_job_queue_client(MOCK_JOB_QUEUE_CLIENT) -> MockJobQueueClient:
    """Obtain the singleton job queue instance."""
    return MOCK_JOB_QUEUE_CLIENT


class MockJobRunner(Runner):  # Deploys trials to external system.

    def __init__(self, pool = None, tmp_dir = None):
        self.pool = pool
        self.tmp_dir = tmp_dir
        self.MOCK_JOB_QUEUE_CLIENT = MockJobQueueClient(pool = self.pool, tmp_dir = self.tmp_dir)

    def _get_mock_job_queue_client(self) -> MockJobQueueClient:
        """Obtain the singleton job queue instance."""
        return self.MOCK_JOB_QUEUE_CLIENT
    
    def run(self, trial: BaseTrial) -> Dict[str, Any]:
        """Deploys a trial based on custom runner subclass implementation.

        Args:
            trial: The trial to deploy.

        Returns:
            Dict of run metadata from the deployment process.
        """
        if not isinstance(trial, (Trial, BatchTrial)):
            raise ValueError("This runner only handles `Trial` and `BatchTrial`.")

        mock_job_queue = self._get_mock_job_queue_client()

        if isinstance(trial, BatchTrial):
            for arm in trial.arms:
                job_id = mock_job_queue.schedule_job_with_parameters(
                    parameters=arm.parameters
                )
                # Stash the job id on the arm itself (an ad hoc attribute); only
                # the dict returned below is attached as `trial.run_metadata` by
                # the base `Scheduler`.
                arm.run_metadata = {"job_id": job_id}
        else:
            job_id = mock_job_queue.schedule_job_with_parameters(
                parameters=trial.arm.parameters
            )

        # This run metadata will be attached to trial as `trial.run_metadata`
        # by the base `Scheduler`.
        return {"job_id": job_id}

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> Dict[TrialStatus, Set[int]]:
        """Checks the status of any non-terminal trials and returns their
        indices as a mapping from TrialStatus to a list of indices. Required
        for runners used with Ax ``Scheduler``.

        NOTE: Does not need to handle waiting between polling calls while trials
        are running; this function should just perform a single poll.

        Args:
            trials: Trials to poll.

        Returns:
            A dictionary mapping TrialStatus to a list of trial indices that have
            the respective status at the time of the polling. This does not need to
            include trials that at the time of polling already have a terminal
            (ABANDONED, FAILED, COMPLETED) status (but it may).
        """
        status_dict = defaultdict(set)
        for trial in trials:
            mock_job_queue = self._get_mock_job_queue_client()
            status = mock_job_queue.get_job_status(
                job_id=trial.run_metadata.get("job_id")
            )
            status_dict[status].add(trial.index)

        return status_dict
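
    # NOTE: `trial.run_metadata.get("job_id")` is the trial-level id returned by
    # `run`, so for a BatchTrial only the last arm's job is polled; the batch can
    # therefore be reported COMPLETED while other arms' result files are missing.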
    
class BraninForMockJobMetric(Metric):  # Pulls data for trial from external system.
    def __init__(self, name = None, pool = None, tmp_dir = None, **kwargs):
        self.pool = pool
        self.tmp_dir = tmp_dir
        self.MOCK_JOB_QUEUE_CLIENT = MockJobQueueClient(pool = self.pool, tmp_dir = self.tmp_dir)
        super().__init__(name=name, **kwargs)

    def _get_mock_job_queue_client(self) -> MockJobQueueClient:
        """Obtain the singleton job queue instance."""
        return self.MOCK_JOB_QUEUE_CLIENT

    def fetch_trial_data(self, trial: BaseTrial) -> MetricFetchResult:
        """Obtains data via fetching it from ` for a given trial."""
        if not isinstance(trial, Trial) and not isinstance(trial, BatchTrial):
            raise ValueError("This metric only handles `Trial`.")

        try:
            mock_job_queue = self._get_mock_job_queue_client()

            # Here we leverage the "job_id" metadata created by `MockJobRunner.run`.
            if isinstance(trial, BatchTrial):
                lst_df_dict = []
                for arm in trial.arms:
                    # use the per-arm job id stashed by `MockJobRunner.run`
                    branin_data = mock_job_queue.get_outcome_value_for_completed_job(
                        job_id=arm.run_metadata.get("job_id")
                    )
                    df_dict = {
                        "trial_index": trial.index,
                        "metric_name": self.name,
                        "arm_name": arm.name,
                        "mean": branin_data.get(self.name),
                        # Can be set to 0.0 if function is known to be noiseless
                        # or to an actual value when SEM is known. Setting SEM to
                        # `None` results in Ax assuming unknown noise and inferring
                        # noise level from data.
                        "sem": None,
                    }
                    lst_df_dict.append(df_dict)
                return Ok(value=Data(df=pd.DataFrame.from_records(lst_df_dict)))
            else:
                # single-arm `Trial`: the job id lives in `trial.run_metadata`
                branin_data = mock_job_queue.get_outcome_value_for_completed_job(
                    job_id=trial.run_metadata.get("job_id")
                )
                df_dict = {
                    "trial_index": trial.index,
                    "metric_name": self.name,
                    "arm_name": trial.arm.name,
                    "mean": branin_data.get(self.name),
                    # Can be set to 0.0 if function is known to be noiseless
                    # or to an actual value when SEM is known. Setting SEM to
                    # `None` results in Ax assuming unknown noise and inferring
                    # noise level from data.
                    "sem": None,
                }
                return Ok(value=Data(df=pd.DataFrame.from_records([df_dict])))
        except Exception as e:
            return Err(
                MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
            )


def create_generation_strategy(models, n_batches, batch_size, max_parallelism, model_kwargs_list, model_gen_kwargs_list):
    """Create a generation strategy for the optimization process using the models,
    the number of batches, and the batch sizes. See the Ax documentation for more
    details: https://ax.dev/tutorials/generation_strategy.html

    Returns
    -------
    GenerationStrategy
        The generation strategy for the optimization process

    Raises
    ------
    ValueError
        If a model is not a string or a Models enum
    """

    steps = []
    for i, model in enumerate(models):
        if isinstance(model, str):
            model = Models[model]
        elif not isinstance(model, Models):
            raise ValueError('Model must be a string or a Models enum')
        steps.append(GenerationStep(
            model=model,
            # each step produces n_batches[i] batches of batch_size[i] trials
            num_trials=n_batches[i] * batch_size[i],
            max_parallelism=min(max_parallelism, batch_size[i]),
            model_kwargs=model_kwargs_list[i],
            model_gen_kwargs=model_gen_kwargs_list[i],
        ))

    return GenerationStrategy(steps=steps)

def main():
    print('----We are Starting the Branin Test----')
    print(branin(0.1,0.1))
    tmp_dir = os.path.join(os.getcwd(),'.tmp_dir')
    print('tmp_dir:',tmp_dir)
    models = ['SOBOL','BOTORCH_MODULAR']
    n_batches = [1,4]
    batch_size = [1,4]
    max_parallelism = 4
    model_gen_kwargs_list = [{}, {}]
    model_kwargs_list = [
        {},
        {
            'torch_device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            'torch_dtype': torch.double,
            'botorch_acqf_class': qLogNoisyExpectedImprovement,
            'transforms': [RemoveFixed, Log, UnitX, StandardizeY],
        },
    ]

    parameter_space = [
        {
            "name": "x1",
            "type": "range",
            "bounds": [-5, 10],
            "value_type": "float",
            "log_scale": False,
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0, 15],
            "value_type": "float",
            "log_scale": False,
        },
    ]
    enforce_sequential_optimization = False
    gs = create_generation_strategy(models, n_batches, batch_size, max_parallelism, model_kwargs_list, model_gen_kwargs_list)

    ax_client = AxClient(generation_strategy=gs, enforce_sequential_optimization=enforce_sequential_optimization)

    ax_client.create_experiment(
        name='test_branin',
        parameters=parameter_space,
    )

    q = Pool(4)
    # the objective is attached via a custom OptimizationConfig (rather than via
    # `objectives` in create_experiment) because the metric needs the pool and tmp_dir
    obj = Objective(metric=BraninForMockJobMetric(name='branin', pool=q, tmp_dir=tmp_dir), minimize=True)
    ax_client.experiment.optimization_config = OptimizationConfig(objective=obj)

    # create runner
    runner = MockJobRunner(pool = q, tmp_dir = tmp_dir)
    ax_client.experiment.runner = runner

    n = 0
    total_trials = sum(np.asarray(n_batches) * np.asarray(batch_size))
    n_step_points = np.cumsum(np.asarray(n_batches) * np.asarray(batch_size))
    scheduler = Scheduler(
        experiment=ax_client.experiment,
        generation_strategy=ax_client.generation_strategy,
        options=SchedulerOptions(run_trials_in_batches=True, init_seconds_between_polls=0.1, trial_type=TrialType.BATCH_TRIAL, batch_size=4),
    )

    while n < total_trials:
        # batch size of the generation step that the next trials fall into
        curr_batch_size = batch_size[np.argmax(n_step_points > n)]
        n += curr_batch_size
        if n > total_trials:
            curr_batch_size = curr_batch_size - (n - total_trials)

        # run one batch worth of trials
        scheduler.run_n_trials(max_trials=curr_batch_size)
    
    q.close()
    q.join()

    df = exp_to_df(ax_client.experiment)
    print(df)

if __name__ == '__main__':
    main()
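
One more thing I am unsure about: MockJob.run writes the JSON file in place, so a frequent poller could in principle observe a partially written file. Here is a minimal sketch of an atomic-write variant I could swap into MockJob.run (the .part suffix is my own convention, not anything from Ax):

import os, json

def save_results_atomically(res_dic, tmp_dir, job_id):
    """Write <job_id>.json via a temporary file plus rename, so the poller never
    sees a half-written results file (os.replace is an atomic rename)."""
    os.makedirs(tmp_dir, exist_ok=True)
    final_path = os.path.join(tmp_dir, str(job_id) + '.json')
    part_path = final_path + '.part'
    with open(part_path, 'w') as fp:
        json.dump(res_dic, fp)
    os.replace(part_path, final_path)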
