2121import pytest
2222
2323from cloudai .configurator import CloudAIGymEnv , GridSearchAgent
24- from cloudai .core import Runner , Test , TestRun , TestScenario , TestTemplateStrategy
24+ from cloudai .core import BaseRunner , Runner , Test , TestRun , TestScenario , TestTemplateStrategy
2525from cloudai .systems .slurm import SlurmSystem
2626from cloudai .workloads .nemo_run import (
2727 Data ,
@@ -45,7 +45,7 @@ def nemorun() -> NeMoRunTestDefinition:
4545
4646
4747@pytest .fixture
48- def setup_env (slurm_system : SlurmSystem , nemorun : NeMoRunTestDefinition ) -> tuple [TestRun , Runner ]:
48+ def setup_env (slurm_system : SlurmSystem , nemorun : NeMoRunTestDefinition ) -> tuple [TestRun , BaseRunner ]:
4949 tdef = nemorun .model_copy (deep = True )
5050 tdef .cmd_args .trainer = Trainer (
5151 max_steps = [1000 , 2000 ],
@@ -81,10 +81,10 @@ def setup_env(slurm_system: SlurmSystem, nemorun: NeMoRunTestDefinition) -> tupl
8181
8282 runner = Runner (mode = "dry-run" , system = slurm_system , test_scenario = test_scenario )
8383
84- return test_run , runner
84+ return test_run , runner . runner
8585
8686
87- def test_observation_space (setup_env ):
87+ def test_observation_space (setup_env : tuple [ TestRun , BaseRunner ] ):
8888 test_run , runner = setup_env
8989 env = CloudAIGymEnv (test_run = test_run , runner = runner )
9090 observation_space = env .define_observation_space ()
@@ -147,7 +147,7 @@ def test_compute_reward_invalid():
147147 assert "Available functions: ['inverse', 'negative', 'identity']" in str (exc_info .value )
148148
149149
150- def test_tr_output_path (setup_env : tuple [TestRun , Runner ]):
150+ def test_tr_output_path (setup_env : tuple [TestRun , BaseRunner ]):
151151 test_run , runner = setup_env
152152 test_run .test .test_definition .cmd_args .data .global_batch_size = 8 # avoid constraint check failure
153153 env = CloudAIGymEnv (test_run = test_run , runner = runner )
@@ -160,7 +160,7 @@ def test_tr_output_path(setup_env: tuple[TestRun, Runner]):
160160 assert env .test_run .output_path .name == "42"
161161
162162
163- def test_action_space (nemorun : NeMoRunTestDefinition , setup_env : tuple [TestRun , Runner ]):
163+ def test_action_space (nemorun : NeMoRunTestDefinition , setup_env : tuple [TestRun , BaseRunner ]):
164164 tr , _ = setup_env
165165 nemorun .cmd_args .trainer = Trainer (
166166 max_steps = [1000 , 2000 ], strategy = TrainerStrategy (tensor_model_parallel_size = [1 , 2 ])
@@ -185,7 +185,7 @@ def test_action_space(nemorun: NeMoRunTestDefinition, setup_env: tuple[TestRun,
185185
186186
187187@pytest .mark .parametrize ("num_nodes" , (1 , [1 , 2 ], [3 ]))
188- def test_all_combinations (nemorun : NeMoRunTestDefinition , setup_env : tuple [TestRun , Runner ], num_nodes : int ):
188+ def test_all_combinations (nemorun : NeMoRunTestDefinition , setup_env : tuple [TestRun , BaseRunner ], num_nodes : int ):
189189 tr , _ = setup_env
190190 nemorun .cmd_args .trainer = Trainer (max_steps = [1000 ], strategy = TrainerStrategy (tensor_model_parallel_size = [1 , 2 ]))
191191 nemorun .extra_env_vars ["DSE_VAR" ] = ["1" , "2" , "3" ]
@@ -224,7 +224,7 @@ def test_all_combinations(nemorun: NeMoRunTestDefinition, setup_env: tuple[TestR
224224 assert expected in real_combinations , f"Expected { expected } in all_combinations"
225225
226226
227- def test_all_combinations_non_dse (nemorun : NeMoRunTestDefinition , setup_env : tuple [TestRun , Runner ]):
227+ def test_all_combinations_non_dse (nemorun : NeMoRunTestDefinition , setup_env : tuple [TestRun , BaseRunner ]):
228228 tr , _ = setup_env
229229 tr .test .test_definition = nemorun
230230 assert len (tr .all_combinations ) == 0
0 commit comments