@@ -171,7 +171,7 @@ def train_mnist(config):
171171 model .parameters (), lr = config ["lr" ], momentum = config ["momentum" ]
172172 )
173173
174- while True :
174+ for _ in range ( 5 ) :
175175 train_func (model , optimizer , train_loader , device )
176176 acc = test_func (model , test_loader , device )
177177 metrics = {"mean_accuracy" : acc }
@@ -180,12 +180,33 @@ def train_mnist(config):
180180 if should_checkpoint :
181181 with tempfile .TemporaryDirectory () as tempdir :
182182 torch .save (model .state_dict (), os .path .join (tempdir , "model.pt" ))
183- train .report (metrics , checkpoint = Checkpoint .from_directory (tempdir ))
183+ tune .report (metrics , checkpoint = Checkpoint .from_directory (tempdir ))
184184 else :
185- train .report (metrics )
185+ tune .report (metrics )
186186
187187
188188if __name__ == "__main__" :
189+ import os as _os
190+ # Ray 2.35.0's get_air_verbosity() expects an int or an AirVerbosity enum, but the
191+ # RHOAI cluster sets AIR_VERBOSITY as a plain string. Wrap the function itself, in both
192+ # modules that bind it, so the fix holds however the env var is re-injected (e.g. via ray.init).
193+ try :
194+ import ray .tune .experimental .output as _ray_output
195+ import ray .tune .tune as _ray_tune_module
196+ _orig_gav = _ray_output .get_air_verbosity
197+ def _fixed_gav (verbose ):
198+ if isinstance (verbose , str ):
199+ try :
200+ verbose = int (verbose )
201+ except (ValueError , TypeError ):
202+ verbose = 2
203+ return _orig_gav (verbose )
204+ _ray_output .get_air_verbosity = _fixed_gav
205+ _ray_tune_module .get_air_verbosity = _fixed_gav
206+ except Exception :
207+ pass
208+ _os .environ .pop ("AIR_VERBOSITY" , None )
209+
189210 # for early stopping
190211 sched = AsyncHyperBandScheduler ()
191212 gpu_value = int ("has to be specified" )
@@ -198,12 +219,8 @@ def train_mnist(config):
198219 scheduler = sched ,
199220 num_samples = 5 ,
200221 ),
201- run_config = train .RunConfig (
222+ run_config = tune .RunConfig (
202223 name = "exp" ,
203- stop = {
204- "mean_accuracy" : 0.98 ,
205- "training_iteration" : 5 ,
206- },
207224 ),
208225 param_space = {
209226 "lr" : tune .loguniform (1e-4 , 1e-2 ),
0 commit comments