Skip to content

Commit e6ab84b

Browse files
Sets min_nodes to 0 multi-host configs (#88)
* Sets min_nodes to 0 multi-host configs * fix tests
1 parent 83f83c6 commit e6ab84b

File tree

2 files changed

+6
-13
lines changed

2 files changed

+6
-13
lines changed

keras_remote/cli/infra/program.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ def _create_tpu_node_pool(cluster, tpu: TpuConfig, zone, project_id, pool_name):
246246
# Single-host TPU slices (1 node) must not specify placement_policy;
247247
# multi-host slices require COMPACT placement with an explicit topology.
248248
is_multi_host = tpu.num_nodes > 1
249-
min_nodes = tpu.num_nodes if is_multi_host else 0
249+
# Autoscaling is enabled, so we need to set the min_node_count to 0.
250+
min_nodes = 0
250251

251252
placement = (
252253
gcp.container.NodePoolPlacementPolicyArgs(

keras_remote/cli/infra/program_test.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ def test_node_count_matches_config(self, gcp_mock):
8080

8181
# Due to multi-host TPU workaround, initial_node_count is equal to num_nodes
8282
call_kwargs = gcp_mock.container.NodePool.call_args.kwargs
83-
self.assertEqual(call_kwargs.get("initial_node_count"), 4)
83+
self.assertEqual(call_kwargs.get("initial_node_count"), 0)
8484
autoscaling_kwargs = (
8585
gcp_mock.container.NodePoolAutoscalingArgs.call_args.kwargs
8686
)
8787
self.assertEqual(autoscaling_kwargs.get("max_node_count"), 4)
88-
self.assertEqual(autoscaling_kwargs.get("min_node_count"), 4)
88+
self.assertEqual(autoscaling_kwargs.get("min_node_count"), 0)
8989

9090
@mock.patch.object(program, "gcp")
9191
def test_pool_name_passed_through(self, gcp_mock):
@@ -352,21 +352,13 @@ def test_node_pool_scale_to_zero(
352352
cluster, accelerator, "us-central2-b", "my-project", "test-pool"
353353
)
354354

355-
is_multi_host = getattr(accelerator, "num_nodes", 1) > 1
356-
357355
call_kwargs = gcp_mock.container.NodePool.call_args.kwargs
358-
self.assertEqual(
359-
call_kwargs.get("initial_node_count"),
360-
expected_max_count if is_multi_host else 0,
361-
)
356+
self.assertEqual(call_kwargs.get("initial_node_count"), 0)
362357

363358
autoscaling_kwargs = (
364359
gcp_mock.container.NodePoolAutoscalingArgs.call_args.kwargs
365360
)
366-
self.assertEqual(
367-
autoscaling_kwargs.get("min_node_count"),
368-
expected_max_count if is_multi_host else 0,
369-
)
361+
self.assertEqual(autoscaling_kwargs.get("min_node_count"), 0)
370362
self.assertEqual(
371363
autoscaling_kwargs.get("max_node_count"), expected_max_count
372364
)

0 commit comments

Comments
 (0)