@@ -80,12 +80,12 @@ def test_node_count_matches_config(self, gcp_mock):
8080
8181 # Due to multi-host TPU workaround, initial_node_count is equal to num_nodes
8282 call_kwargs = gcp_mock .container .NodePool .call_args .kwargs
83- self .assertEqual (call_kwargs .get ("initial_node_count" ), 4 )
83+ self .assertEqual (call_kwargs .get ("initial_node_count" ), 0 )
8484 autoscaling_kwargs = (
8585 gcp_mock .container .NodePoolAutoscalingArgs .call_args .kwargs
8686 )
8787 self .assertEqual (autoscaling_kwargs .get ("max_node_count" ), 4 )
88- self .assertEqual (autoscaling_kwargs .get ("min_node_count" ), 4 )
88+ self .assertEqual (autoscaling_kwargs .get ("min_node_count" ), 0 )
8989
9090 @mock .patch .object (program , "gcp" )
9191 def test_pool_name_passed_through (self , gcp_mock ):
@@ -352,21 +352,13 @@ def test_node_pool_scale_to_zero(
352352 cluster , accelerator , "us-central2-b" , "my-project" , "test-pool"
353353 )
354354
355- is_multi_host = getattr (accelerator , "num_nodes" , 1 ) > 1
356-
357355 call_kwargs = gcp_mock .container .NodePool .call_args .kwargs
358- self .assertEqual (
359- call_kwargs .get ("initial_node_count" ),
360- expected_max_count if is_multi_host else 0 ,
361- )
356+ self .assertEqual (call_kwargs .get ("initial_node_count" ), 0 )
362357
363358 autoscaling_kwargs = (
364359 gcp_mock .container .NodePoolAutoscalingArgs .call_args .kwargs
365360 )
366- self .assertEqual (
367- autoscaling_kwargs .get ("min_node_count" ),
368- expected_max_count if is_multi_host else 0 ,
369- )
361+ self .assertEqual (autoscaling_kwargs .get ("min_node_count" ), 0 )
370362 self .assertEqual (
371363 autoscaling_kwargs .get ("max_node_count" ), expected_max_count
372364 )
0 commit comments