77from ray .tune .search import create_searcher , ConcurrencyLimiter , SEARCH_ALG_IMPORT
88from netpyne .batchtools import runtk
99from collections import namedtuple
10- from batchtk .raytk .search import ray_trial , LABEL_POINTER
11- from batchtk .utils import get_path
10+ from batchtk .utils import get_path , SQLiteStorage , ScriptLogger
1211from io import StringIO
1312import numpy
1413from typing import Any , Callable , Dict , List , Optional , Tuple , Union
1514from netpyne .batchtools import submits
15+ from batchtk import runtk
16+ from batchtk .runtk .trial import trial , LABEL_POINTER
17+ import datetime
1618#import signal #incompatible with signal and threading from ray
1719#import threading
1820
@@ -84,62 +86,8 @@ def ray_optuna_search(dispatcher_constructor: Callable, # constructor for the di
8486 -------
8587 Study: namedtuple('Study', ['algo', 'results'])(algo, results), # named tuple containing the created algorithm and the results of the search
8688 """
87- from ray .tune .search .optuna import OptunaSearch
88-
89- if ray_config is None :
90- ray_config = {}
91- ray_init_kwargs = ray_config #{"runtime_env": {"working_dir:": "."}} | ray_config # do not actually need to specify a working dir, can
92- ray .init (** ray_init_kwargs )# TODO needed for python import statements ?
93- if optuna_config == None :
94- optuna_config = {}
95-
96- storage_path = get_path (checkpoint_path )
97- algo = ConcurrencyLimiter (searcher = OptunaSearch (metric = metric , mode = mode , ** optuna_config ),
98- max_concurrent = max_concurrent ,
99- batch = batch ) #TODO does max_concurrent and batch work?
100-
101- #submit = submit_constructor()
102- #submit.update_templates(
103- # **run_config
104- #)
105- project_path = os .getcwd ()
106-
107- def run (config ):
108- config .update ({'saveFolder' : output_path , 'simLabel' : LABEL_POINTER })
109- data = ray_trial (config = config , label = label , dispatcher_constructor = dispatcher_constructor ,
110- project_path = project_path , output_path = output_path , submit_constructor = submit_constructor ,
111- submit_kwargs = run_config , log = None )
112- if isinstance (metric , str ):#TODO only Optuna supports multiobjective?
113- metrics = {'config' : config , 'data' : data , metric : data [metric ]}
114- session .report (metrics )
115- elif isinstance (metric , (list , tuple )):
116- metrics = {k : data [k ] for k in metric }
117- metrics ['config' ] = config
118- metrics ['data' ] = data
119- session .report (metrics )
120- else :
121- raise ValueError ("metric must be a string or a list/tuple of strings" )
122- tuner = tune .Tuner (
123- run ,
124- tune_config = tune .TuneConfig (
125- search_alg = algo ,
126- num_samples = num_samples ,
127- ),
128- run_config = RunConfig (
129- storage_path = storage_path ,
130- name = label ,
131- ),
132- param_space = params ,
133- )
134-
135- results = tuner .fit ()
136- resultsdf = results .get_dataframe ()
137- resultsdf .to_csv ("{}.csv" .format (label ))
138- #return namedtuple('Study', ['algo', 'results'])(algo, results)
139- if clean_checkpoint :
140- os .system ("rm -r {}" .format (storage_path ))
141- return namedtuple ('Study' , ['algo' , 'results' ])(algo .searcher ._ot_study , results )
142-
89+ from warnings import warn
90+ warn ("ray_optuna_search is deprecated, please use ray_search with algorithm='optuna' instead" , DeprecationWarning )
14391"""
14492Parameters
14593:
@@ -183,7 +131,7 @@ def ray_search(dispatcher_constructor: Callable, # constructor for the dispatche
183131 output_path : Optional [str ] = './batch' , # directory for storing generated files
184132 checkpoint_path : Optional [str ] = './checkpoint' , # directory for storing checkpoint files
185133 max_concurrent : Optional [int ] = 1 , # number of concurrent trials to run at one time
186- batch : Optional [bool ] = True , # whether concurrent trials should run synchronously or asynchronously
134+ batch : Optional [bool ] = True , # whether concurrent trials should run synchronously or asynchronously
187135 num_samples : Optional [int ] = 1 , # number of trials to run
188136 metric : Optional [str ] = None , # metric to optimize, if not supplied, no data will be collated.
189137 mode : Optional [str ] = "min" , # either 'min' or 'max' (whether to minimize or maximize the metric
@@ -196,18 +144,23 @@ def ray_search(dispatcher_constructor: Callable, # constructor for the dispatche
196144 prune_metadata = True , # whether to prune the metadata from the results.csv
197145 remote_dir : Optional [str ] = None , # absolute path for directory to run the search on (for submissions over SSH)
198146 host : Optional [str ] = None , # host to run the search on
199- key : Optional [str ] = None # key for TOTP generator...
147+ key : Optional [str ] = None , # key for TOTP generator...
148+ file_cleanup : Optional [bool | list | tuple ] = True , # whether to clean up accessory files after the search is completed
149+ advanced_logging : Optional [bool | str ] = True ,
200150 ) -> study :
201151
202152 expected_total = params .pop ('_expected_trials_per_sample' ) * num_samples
203153 if (dispatcher_constructor == runtk .dispatchers .SSHDispatcher ) or \
204154 (dispatcher_constructor == SSHGridDispatcher ):
205- if submit_constructor == submits .SGESubmitSFS :
155+ dispatcher_kwargs = None
156+ if submit_constructor == submits .SGESubmitSSH :
206157 from fabric import connection
207158 dispatcher_kwargs = {'connection' : connection .Connection (host )}
208159 if submit_constructor == submits .SlurmSubmitSSH :
209160 from batchtk .utils import TOTPConnection
210161 dispatcher_kwargs = {'connection' : TOTPConnection (host , key )}
162+ if dispatcher_kwargs == None :
163+ raise ValueError ("for SSH based methods, please provide either 'sftp' or None as the comm_type" )
211164 else :
212165 dispatcher_kwargs = {}
213166 if ray_config is None :
@@ -233,6 +186,18 @@ def ray_search(dispatcher_constructor: Callable, # constructor for the dispatche
233186 #TODO class this object for self calls? cleaner? vs nested functions
234187 #TODO clean up working_dir and excludes
235188 storage_path = get_path (checkpoint_path )
189+ adv_path = None
190+ timestamp = datetime .datetime .now ().strftime ("%Y%m%d_%H%M%S" )
191+ if advanced_logging :
192+ if advanced_logging is True :
193+ advanced_logging = "./" #follows from os.getcwd()
194+ adv_path = get_path ("{}/run_{}" .format (advanced_logging , timestamp ))
195+ if isinstance (advanced_logging , str ):
196+ adv_path = get_path (advanced_logging )
197+ os .makedirs (adv_path , exist_ok = True )
198+
199+ if file_cleanup is True :
200+ file_cleanup = (runtk .SGLOUT , runtk .MSGOUT )
236201 load_path = "{}/{}" .format (storage_path , label )
237202 algo = create_searcher (algorithm , ** algorithm_config ) #concurrency may not be accepted by all algo
238203 #search_alg – The search algorithm to use.
@@ -248,13 +213,29 @@ def ray_search(dispatcher_constructor: Callable, # constructor for the dispatche
248213 #submit.update_templates(
249214 # **run_config
250215 #)
def ray_trial(config, label, dispatcher_constructor, project_path, output_path, submit_constructor,
              dispatcher_kwargs=None, submit_kwargs=None, interval=60, data_storage=None, debug_log=None,
              report=('path', 'config', 'data'), cleanup=(runtk.SGLOUT, runtk.MSGOUT), check_storage=False):
    """Run a single tuning trial through batchtk's dispatcher/submit machinery.

    Wraps ``batchtk.runtk.trial.trial``: derives a per-trial id from the Ray Tune
    context, wires in the advanced-logging sinks when enabled, and forwards all
    remaining arguments unchanged.

    Parameters
    ----------
    config: dict of parameter values for this trial.
    label: search label shared by all trials.
    dispatcher_constructor / submit_constructor: batchtk dispatcher/submit classes.
    project_path / output_path: working directory and generated-file directory.
    dispatcher_kwargs / submit_kwargs: extra kwargs forwarded to the constructors.
    interval: polling interval (seconds) while waiting on the trial.
    data_storage: optional storage backend for trial results; if None and
        advanced logging is active, an SQLiteStorage under ``adv_path`` is used.
    debug_log: optional logger; if None and advanced logging is active, a
        ScriptLogger writing to ``adv_path``/trials.log is used.
    report: keys recorded for the trial.
    cleanup: accessory files removed after the trial completes.
    check_storage: whether to consult storage before re-running the trial.

    Returns
    -------
    Whatever ``batchtk.runtk.trial.trial`` returns for this trial (the trial data).
    """
    # BUGFIX: the previous version unconditionally reset debug_log and
    # data_storage to None, silently discarding caller-supplied values.
    # Now the adv_path-based defaults are used only as a fallback.
    if adv_path:  # closure variable set up by ray_search's advanced_logging handling
        if debug_log is None:
            debug_log = ScriptLogger(file_out="{}/trials.log".format(adv_path))
        if data_storage is None:
            data_storage = SQLiteStorage(label='trials', path=adv_path, entries=('path', 'config', 'data'))
    tid = tune.get_context().get_trial_id()
    tid = tid.split('_')[-1]  # trailing token uniquely identifies the trial (int/string)
    return trial(
        config=config, label=label, tid=tid, dispatcher_constructor=dispatcher_constructor,
        project_path=project_path, output_path=output_path, submit_constructor=submit_constructor,
        dispatcher_kwargs=dispatcher_kwargs, submit_kwargs=submit_kwargs, interval=interval,
        data_storage=data_storage, debug_log=debug_log, report=report, cleanup=cleanup, check_storage=check_storage)
230+
251231 project_path = remote_dir or os .getcwd () # if remote_dir is None, then use the current working directory
252232 def run (config ):
253233 config .update ({'saveFolder' : output_path , 'simLabel' : LABEL_POINTER })
254234 data = ray_trial (config = config , label = label , dispatcher_constructor = dispatcher_constructor ,
255235 project_path = project_path , output_path = output_path , submit_constructor = submit_constructor ,
256236 dispatcher_kwargs = dispatcher_kwargs , submit_kwargs = run_config ,
257- interval = sample_interval , log = None , report = report_config )
237+ interval = sample_interval , report = report_config ,
238+ cleanup = file_cleanup , check_storage = False )
258239 if metric is None :
259240 metrics = {'data' : data , '_none_placeholder' : 0 } #TODO, should include 'config' now with purge_metadata?
260241 session .report (metrics )
@@ -411,7 +392,9 @@ def shim(dispatcher_constructor: Optional[Callable] = None, # constructor for th
411392 prune_metadata : Optional [bool ] = True , # whether to prune the metadata from the results.csv
412393 remote_dir : Optional [str ] = None , # absolute path for directory to run the search on (for submissions over SSH)
413394 host : Optional [str ] = None , # host to run the search on
414- key : Optional [str ] = None # key for TOTP generator...
395+ key : Optional [str ] = None , # key for TOTP generator...
396+ file_cleanup : Optional [bool ] = True , # whether to clean up accessory files after the search is completed
397+ advanced_logging : Optional [bool | str ] = True ,
415398 ) -> Dict :
416399 kwargs = locals ()
417400 if metric is None and algorithm not in ['variant_generator' , 'random' , 'grid' ]:
@@ -464,7 +447,9 @@ def search(dispatcher_constructor: Optional[Callable] = None, # constructor for
464447 prune_metadata : Optional [bool ] = True , # whether to prune the metadata from the results.csv
465448 remote_dir : Optional [str ] = None , # absolute path for directory to run the search on (for submissions over SSH)
466449 host : Optional [str ] = None , # host to run the search on
467- key : Optional [str ] = None # key for TOTP generator.
450+ key : Optional [str ] = None , # key for TOTP generator.
451+ file_cleanup : Optional [bool ] = True , # whether to clean up accessory files after the search is completed
452+ advanced_logging : Optional [bool | str ] = True ,
468453 ) -> study : # results of the search -> study.results (raw tune.ResultGrid), study.data (pandas.DataFrame conversion)
469454 """
470455 search(...)
@@ -495,6 +480,10 @@ def search(dispatcher_constructor: Optional[Callable] = None, # constructor for
495480 remote_dir: Optional[str] = None, # absolute path for directory to run the search on (for submissions over SSH)
496481 host: Optional[str] = None, # host to run the search on (for submissions over SSH)
497482 key: Optional[str] = None # key for TOTP generator (for submissions over SSH)
483+ file_cleanup: Optional[bool] = True, # whether to clean up accessory files after the search is completed
484+ advanced_logging: Optional[bool] = True, # enables advanced logging features, checkpoint_db and log_file.
485+ checkpoint_db: Optional[str] = None, # path for checkpoint db file.
486+ log_file: Optional[str] = None, # path for the log file
498487 Creates (upon completed fitting run...)
499488 -------
500489 <label>.csv: file containing the results of the search
@@ -514,20 +503,20 @@ def search(dispatcher_constructor: Optional[Callable] = None, # constructor for
514503"""
515504SEE:
516505'variant_generator'
517- 'random' -> points to variant_generator
506+ 'random' <- deprecated (points to variant_generator)
518507'ax'
519- 'dragonfly'
520- 'skopt'
508+ 'dragonfly' <- deprecated
509+ 'skopt' <- deprecated
521510'hyperopt'
522511'bayesopt'
523512'bohb'
524513'nevergrad'
525514'optuna'
526515'zoopt'
527- 'sigopt'
516+ 'sigopt' <- deprecated
528517'hebo'
529- 'blendsearch'
530- 'cfo'
518+ 'blendsearch' <- deprecated
519+ 'cfo' <- deprecated
531520"""
532521
533522
0 commit comments