Description
Question
I am trying to write a scheduler that runs simulations in batches while still using `AxClient`, and I cannot make it work properly.
I modified the Scheduler tutorial to run the jobs with a multiprocessing pool and write the results to files in a tmp folder, from which I later fetch them.
I adapted my code to use the branin function so that other people can test it (see code below).
Basically, when I set `init_seconds_between_polls` in
`options=SchedulerOptions(run_trials_in_batches=True, init_seconds_between_polls=0.1, trial_type=TrialType.BATCH_TRIAL, batch_size=4)`
to a small value (i.e. shorter than the run time of most processes), it fails with the following:
Scheduler: MetricFetchE INFO: Because branin is an objective, marking trial 19 as TrialStatus.FAILED.
which I guess comes from #L2025, and I don't understand why this happens.
If I set `init_seconds_between_polls=4` to ensure that polling only happens after all the jobs are done, things seem to work, so I don't understand what I am missing.
In my real-life case I don't really know in advance how long I need to wait between polls, so I expected that setting `init_seconds_between_polls` to a small value would simply check frequently whether all jobs in the batch are done and then proceed, but that is not what happens.
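To make the comparison concrete, here is a minimal sketch of the two `SchedulerOptions` configurations I am comparing (variable names are just for illustration); only the polling interval differs:

```python
from ax.service.scheduler import SchedulerOptions, TrialType

# Fails: the first poll (after 0.1 s) happens before most jobs have written their results.
failing_options = SchedulerOptions(
    run_trials_in_batches=True,
    init_seconds_between_polls=0.1,
    trial_type=TrialType.BATCH_TRIAL,
    batch_size=4,
)

# Works in my tests: 4 s is longer than any job's run time (0.1-3 s),
# so all result files exist by the time of the first poll.
working_options = SchedulerOptions(
    run_trials_in_batches=True,
    init_seconds_between_polls=4,
    trial_type=TrialType.BATCH_TRIAL,
    batch_size=4,
)
```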
Please provide any relevant code snippet if applicable.
from ax.utils.measurement.synthetic_functions import branin
import os,sys,json,uuid,time,torch,random
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from functools import partial,reduce
from typing import Any, Dict, NamedTuple, Union, Iterable, Set
import ax
from ax import *
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax import Models
from ax.modelbridge.transforms.standardize_y import StandardizeY
from ax.modelbridge.transforms.unit_x import UnitX
from ax.modelbridge.transforms.remove_fixed import RemoveFixed
from ax.modelbridge.transforms.log import Log
from ax.runners.synthetic import SyntheticRunner
from ax.core.base_trial import BaseTrial
from ax.core.base_trial import TrialStatus
from ax.core.metric import Metric, MetricFetchResult, MetricFetchE
from ax.core.data import Data
from ax.utils.common.result import Ok, Err
from ax.core.runner import Runner
from ax.core.trial import Trial
from ax.service.scheduler import Scheduler, SchedulerOptions, TrialType
from collections import defaultdict
from torch.multiprocessing import Pool, set_start_method
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from ax.core.base_trial import TrialStatus as T
from ax.core.parameter import RangeParameter, ParameterType
try: # needed for multiprocessing when using pytorch
set_start_method('spawn')
except RuntimeError:
pass
class MockJob(NamedTuple):
"""Dummy class to represent a job scheduled on `MockJobQueue`."""
    id: str
parameters: Dict[str, Union[str, float, int, bool]]
def run(self, job_id, parameter, tmp_dir = None):
x1 = parameter['x1']
x2 = parameter['x2']
branin_data = branin(x1, x2)
time.sleep(random.uniform(0.1,3))
res_dic = {'branin':branin_data}
print('Job ID:',job_id, 'Parameters:',parameter, 'Results:',res_dic)
# save the results in tmp folder with the job_id in json format
if tmp_dir is not None:
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
with open(os.path.join(tmp_dir,str(job_id)+'.json'), 'w') as fp:
json.dump(res_dic, fp)
class MockJobQueueClient:
"""Dummy class to represent a job queue where the Ax `Scheduler` will
deploy trial evaluation runs during optimization.
"""
jobs: Dict[str, MockJob] = {}
def __init__(self, pool = None, tmp_dir = None ):
self.pool = pool
self.tmp_dir = tmp_dir
def schedule_job_with_parameters(
self, parameters: Dict[str, Union[str, float, int, bool]]
) -> int:
"""Schedules an evaluation job with given parameters and returns job ID."""
# Code to actually schedule the job and produce an ID would go here;
job_id = str(uuid.uuid4())
mock = MockJob(job_id, parameters)
# add mock run to the queue q
self.jobs[job_id] = MockJob(job_id, parameters)
self.pool.apply_async(self.jobs[job_id].run, args=(job_id, parameters, self.tmp_dir ))
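        # The worker process writes its result to <tmp_dir>/<job_id>.json;
        # get_job_status treats the existence of that file as the completion signal.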
return job_id
def get_job_status(self, job_id: str) -> TrialStatus:
""" "Get status of the job by a given ID. For simplicity of the example,
return an Ax `TrialStatus`.
"""
job = self.jobs[job_id]
# check if job_id.json exists in the tmp directory
if os.path.exists(os.path.join(self.tmp_dir,str(job_id)+'.json')):
#load the results
with open(os.path.join(self.tmp_dir,str(job_id)+'.json'), 'r') as fp:
res_dic = json.load(fp)
            # check whether any value in res_dic is NaN
for key in res_dic.keys():
if np.isnan(res_dic[key]):
return TrialStatus.FAILED
return TrialStatus.COMPLETED
else:
return TrialStatus.RUNNING
def get_outcome_value_for_completed_job(self, job_id: int) -> Dict[str, float]:
"""Get evaluation results for a given completed job."""
job = self.jobs[job_id]
# In a real external system, this would retrieve real relevant outcomes and
# not a synthetic function value.
# check if job_id.json exists in the tmp directory
if os.path.exists(os.path.join(self.tmp_dir,str(job_id)+'.json')):
#load the results
with open(os.path.join(self.tmp_dir,str(job_id)+'.json'), 'r') as fp:
res_dic = json.load(fp)
# delete file
os.remove(os.path.join(self.tmp_dir,str(job_id)+'.json'))
# print('WE ARE DELETING THE FILE')
return res_dic
else:
raise ValueError('The job is not completed yet')
def get_mock_job_queue_client(MOCK_JOB_QUEUE_CLIENT) -> MockJobQueueClient:
"""Obtain the singleton job queue instance."""
return MOCK_JOB_QUEUE_CLIENT
class MockJobRunner(Runner): # Deploys trials to external system.
def __init__(self, pool = None, tmp_dir = None):
self.pool = pool
self.tmp_dir = tmp_dir
self.MOCK_JOB_QUEUE_CLIENT = MockJobQueueClient(pool = self.pool, tmp_dir = self.tmp_dir)
def _get_mock_job_queue_client(self) -> MockJobQueueClient:
"""Obtain the singleton job queue instance."""
return self.MOCK_JOB_QUEUE_CLIENT
def run(self, trial: BaseTrial) -> Dict[str, Any]:
"""Deploys a trial based on custom runner subclass implementation.
Args:
trial: The trial to deploy.
Returns:
Dict of run metadata from the deployment process.
"""
if not isinstance(trial, Trial) and not isinstance(trial, BatchTrial):
raise ValueError("This runner only handles `Trial`.")
mock_job_queue = self._get_mock_job_queue_client()
run_metadata = []
if isinstance(trial, BatchTrial):
for arm in trial.arms:
job_id = mock_job_queue.schedule_job_with_parameters(
parameters=arm.parameters
)
# This run metadata will be attached to trial as `trial.run_metadata`
# by the base `Scheduler`.
arm.run_metadata = {"job_id": job_id}
else:
job_id = mock_job_queue.schedule_job_with_parameters(
parameters=trial.arm.parameters
)
# This run metadata will be attached to trial as `trial.run_metadata`
# by the base `Scheduler`.
return {"job_id": job_id}
def poll_trial_status(
self, trials: Iterable[BaseTrial]
) -> Dict[TrialStatus, Set[int]]:
"""Checks the status of any non-terminal trials and returns their
indices as a mapping from TrialStatus to a list of indices. Required
for runners used with Ax ``Scheduler``.
NOTE: Does not need to handle waiting between polling calls while trials
are running; this function should just perform a single poll.
Args:
trials: Trials to poll.
Returns:
A dictionary mapping TrialStatus to a list of trial indices that have
the respective status at the time of the polling. This does not need to
include trials that at the time of polling already have a terminal
(ABANDONED, FAILED, COMPLETED) status (but it may).
"""
status_dict = defaultdict(set)
for trial in trials:
mock_job_queue = self._get_mock_job_queue_client()
status = mock_job_queue.get_job_status(
job_id=trial.run_metadata.get("job_id")
)
status_dict[status].add(trial.index)
return status_dict
class BraninForMockJobMetric(Metric): # Pulls data for trial from external system.
def __init__(self, name = None, pool = None, tmp_dir = None, **kwargs):
self.pool = pool
self.tmp_dir = tmp_dir
self.MOCK_JOB_QUEUE_CLIENT = MockJobQueueClient(pool = self.pool, tmp_dir = self.tmp_dir)
super().__init__(name=name, **kwargs)
def _get_mock_job_queue_client(self) -> MockJobQueueClient:
"""Obtain the singleton job queue instance."""
return self.MOCK_JOB_QUEUE_CLIENT
def fetch_trial_data(self, trial: BaseTrial) -> MetricFetchResult:
"""Obtains data via fetching it from ` for a given trial."""
if not isinstance(trial, Trial) and not isinstance(trial, BatchTrial):
raise ValueError("This metric only handles `Trial`.")
try:
mock_job_queue = self._get_mock_job_queue_client()
# Here we leverage the "job_id" metadata created by `MockJobRunner.run`.
if isinstance(trial, BatchTrial):
lst_df_dict = []
for arm in trial.arms:
branin_data = mock_job_queue.get_outcome_value_for_completed_job(
job_id=arm.run_metadata.get("job_id")
)
name_ = list(branin_data.keys())[0]
df_dict = {
"trial_index": trial.index,
"metric_name": self.name,
"arm_name": arm.name,
"mean": branin_data.get(self.name),
# Can be set to 0.0 if function is known to be noiseless
# or to an actual value when SEM is known. Setting SEM to
# `None` results in Ax assuming unknown noise and inferring
# noise level from data.
"sem": None,
}
lst_df_dict.append(df_dict)
return Ok(value=Data(df=pd.DataFrame.from_records(lst_df_dict)))
            else:
                # Single-arm `Trial`: the job id is stored on the trial's run metadata
                # (see `MockJobRunner.run`), and the trial has exactly one arm.
                branin_data = mock_job_queue.get_outcome_value_for_completed_job(
                    job_id=trial.run_metadata.get("job_id")
                )
                df_dict = {
                    "trial_index": trial.index,
                    "metric_name": self.name,
                    "arm_name": trial.arm.name,
                    "mean": branin_data.get(self.name),
                    # Can be set to 0.0 if function is known to be noiseless
                    # or to an actual value when SEM is known. Setting SEM to
                    # `None` results in Ax assuming unknown noise and inferring
                    # noise level from data.
                    "sem": None,
                }
                return Ok(value=Data(df=pd.DataFrame.from_records([df_dict])))
except Exception as e:
return Err(
MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
)
def create_generation_strategy(models, n_batches, batch_size, max_parallelism, model_kwargs_list, model_gen_kwargs_list):
""" Create a generation strategy for the optimization process using the models and the number of batches and batch sizes. See ax documentation for more details: https://ax.dev/tutorials/generation_strategy.html
Returns
-------
GenerationStrategy
The generation strategy for the optimization process
Raises
------
ValueError
If the model is not a string or a Models enum
"""
steps = []
for i, model in enumerate(models):
        if isinstance(model, str):
            model = Models[model]
        elif not isinstance(model, Models):
            raise ValueError('Model must be a string or a Models enum')
steps.append(GenerationStep(
model=model,
num_trials=n_batches[i]*batch_size[i],
max_parallelism=min(max_parallelism,batch_size[i]),
model_kwargs= model_kwargs_list[i],
model_gen_kwargs= model_gen_kwargs_list[i],
))
gs = GenerationStrategy(steps=steps, )
return gs
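# Example usage (matching the settings in main() below): 1 Sobol trial
# followed by 16 BoTorch trials (4 batches of 4 arms each).
# gs = create_generation_strategy(
#     models=['SOBOL', 'BOTORCH_MODULAR'], n_batches=[1, 4], batch_size=[1, 4],
#     max_parallelism=4, model_kwargs_list=[{}, {}], model_gen_kwargs_list=[{}, {}],
# )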
from ax.service.utils.report_utils import exp_to_df
def main():
print('----We are Starting the Branin Test----')
print(branin(0.1,0.1))
tmp_dir = os.path.join(os.getcwd(),'.tmp_dir')
print('tmp_dir:',tmp_dir)
models = ['SOBOL','BOTORCH_MODULAR']
n_batches = [1,4]
batch_size = [1,4]
max_parallelism = 4
model_gen_kwargs_list =[{},{}]
model_kwargs_list = [{},{'torch_device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),'torch_dtype': torch.double,'botorch_acqf_class':qLogNoisyExpectedImprovement,'transforms':[RemoveFixed, Log,UnitX, StandardizeY],},]
parameter_space = [
{
"name": "x1",
"type": "range",
"bounds": [-5, 10],
"value_type": "float",
"log_scale": False,
},
{
"name": "x2",
"type": "range",
"bounds": [0, 15],
"value_type": "float",
"log_scale": False,
},
]
enforce_sequential_optimization = False
gs = create_generation_strategy(models, n_batches, batch_size, max_parallelism, model_kwargs_list, model_gen_kwargs_list)
ax_client = AxClient(generation_strategy=gs, enforce_sequential_optimization=enforce_sequential_optimization)
objectives_ = {'branin':ObjectiveProperties(minimize=True)}
ax_client.create_experiment(
name='test_branin',
parameters=parameter_space,
# objectives=objectives_,
)
q = Pool(4)
obj = Objective(metric=BraninForMockJobMetric(name='branin', pool = q, tmp_dir = tmp_dir), minimize=True)
ax_client.experiment.optimization_config=OptimizationConfig(objective=obj)
# create runner
runner = MockJobRunner(pool = q, tmp_dir = tmp_dir)
ax_client.experiment.runner = runner
n = 0
total_trials = sum(np.asarray(n_batches)*np.asarray(batch_size))
n_step_points = np.cumsum(np.asarray(n_batches)*np.asarray(batch_size))
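    # Scheduler configured with the small polling interval (0.1 s) that triggers
    # the MetricFetchE failure described above; init_seconds_between_polls=4 works.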
scheduler = Scheduler(
experiment=ax_client.experiment,
generation_strategy=ax_client.generation_strategy,
options=SchedulerOptions(run_trials_in_batches=True,init_seconds_between_polls=0.1,trial_type=TrialType.BATCH_TRIAL,batch_size=4),
)
while n < total_trials:
# check the current batch size
curr_batch_size = batch_size[np.argmax(n_step_points>n)]
n += curr_batch_size
if n > total_trials:
curr_batch_size = curr_batch_size - (n-total_trials)
scheduler.run_n_trials(max_trials=4)
q.close()
q.join()
df = exp_to_df(ax_client.experiment)
print(df)
if __name__ == '__main__':
main()
Code of Conduct
- I agree to follow Ax's Code of Conduct