@@ -1208,6 +1208,7 @@ def test_generate_runner_model_sweep_config(sample_master_config, temp_config_fi
12081208 class Args :
12091209 runner_type = "h200"
12101210 runner_config = runner_file
1211+ runner_node_filter = None
12111212
12121213 result = generate_runner_model_sweep_config (Args (), sample_master_config )
12131214 assert len (result ) > 0
@@ -1224,11 +1225,72 @@ def test_generate_runner_model_sweep_config_invalid_runner(sample_master_config,
12241225 class Args :
12251226 runner_type = "invalid-runner"
12261227 runner_config = runner_file
1228+ runner_node_filter = None
12271229
12281230 with pytest .raises (ValueError , match = "does not exist in runner config" ):
12291231 generate_runner_model_sweep_config (Args (), sample_master_config )
12301232
12311233
1234+ def test_generate_runner_model_sweep_config_with_node_filter (sample_master_config , temp_config_files ):
1235+ """Test runner-model sweep with runner node filter."""
1236+ _ , runner_file = temp_config_files
1237+
1238+ class Args :
1239+ runner_type = "h200"
1240+ runner_config = runner_file
1241+ runner_node_filter = "nv_1"
1242+
1243+ result = generate_runner_model_sweep_config (Args (), sample_master_config )
1244+ # Should only have entries for h200-nv_1
1245+ runners = set (entry ['runner' ] for entry in result )
1246+ assert 'h200-nv_1' in runners
1247+ assert 'h200-nv_2' not in runners
1248+
1249+
1250+ def test_generate_runner_model_sweep_config_with_node_filter_multiple_matches (sample_master_config , temp_config_files ):
1251+ """Test runner-model sweep with runner node filter matching multiple nodes."""
1252+ _ , runner_file = temp_config_files
1253+
1254+ class Args :
1255+ runner_type = "h200"
1256+ runner_config = runner_file
1257+ runner_node_filter = "nv" # Should match both nv_1 and nv_2
1258+
1259+ result = generate_runner_model_sweep_config (Args (), sample_master_config )
1260+ runners = set (entry ['runner' ] for entry in result )
1261+ assert 'h200-nv_1' in runners
1262+ assert 'h200-nv_2' in runners
1263+
1264+
1265+ def test_generate_runner_model_sweep_config_with_node_filter_no_matches (sample_master_config , temp_config_files ):
1266+ """Test runner-model sweep with runner node filter that matches no nodes."""
1267+ _ , runner_file = temp_config_files
1268+
1269+ class Args :
1270+ runner_type = "h200"
1271+ runner_config = runner_file
1272+ runner_node_filter = "nonexistent"
1273+
1274+ with pytest .raises (ValueError , match = "No runner nodes found matching filter" ):
1275+ generate_runner_model_sweep_config (Args (), sample_master_config )
1276+
1277+
1278+ def test_generate_runner_model_sweep_config_without_node_filter (sample_master_config , temp_config_files ):
1279+ """Test runner-model sweep without runner node filter (default behavior)."""
1280+ _ , runner_file = temp_config_files
1281+
1282+ class Args :
1283+ runner_type = "h200"
1284+ runner_config = runner_file
1285+ runner_node_filter = None
1286+
1287+ result = generate_runner_model_sweep_config (Args (), sample_master_config )
1288+ # Should have entries for all h200 nodes
1289+ runners = set (entry ['runner' ] for entry in result )
1290+ assert 'h200-nv_1' in runners
1291+ assert 'h200-nv_2' in runners
1292+
1293+
12321294# Tests for generate_runner_sweep_config
12331295def test_generate_runner_sweep_config (sample_master_config , temp_config_files ):
12341296 """Test runner sweep config generation."""
@@ -1387,6 +1449,27 @@ def test_main_runner_model_sweep(temp_config_files):
13871449 assert len (result ) > 0
13881450
13891451
1452+ def test_main_runner_model_sweep_with_node_filter (temp_config_files ):
1453+ """Test main function with runner-model-sweep command with node filter."""
1454+ master_file , runner_file = temp_config_files
1455+
1456+ test_args = [
1457+ "generate_sweep_configs.py" ,
1458+ "runner-model-sweep" ,
1459+ "--config-files" , master_file ,
1460+ "--runner-config" , runner_file ,
1461+ "--runner-type" , "h200" ,
1462+ "--runner-node-filter" , "nv_1"
1463+ ]
1464+
1465+ with patch ('sys.argv' , test_args ):
1466+ result = main ()
1467+ assert len (result ) > 0
1468+ runners = set (entry ['runner' ] for entry in result )
1469+ assert 'h200-nv_1' in runners
1470+ assert 'h200-nv_2' not in runners
1471+
1472+
13901473def test_main_runner_sweep (temp_config_files ):
13911474 """Test main function with runner-sweep command."""
13921475 master_file , runner_file = temp_config_files
0 commit comments