Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion keras_remote/cli/infra/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ def _create_tpu_node_pool(cluster, tpu: TpuConfig, zone, project_id, pool_name):
# Single-host TPU slices (1 node) must not specify placement_policy;
# multi-host slices require COMPACT placement with an explicit topology.
is_multi_host = tpu.num_nodes > 1
min_nodes = tpu.num_nodes if is_multi_host else 0
# Autoscaling is enabled, so we need to set the min_node_count to 0.
min_nodes = 0

placement = (
gcp.container.NodePoolPlacementPolicyArgs(
Expand Down
16 changes: 4 additions & 12 deletions keras_remote/cli/infra/program_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ def test_node_count_matches_config(self, gcp_mock):

# With autoscaling, initial_node_count and min_node_count are 0 regardless of
# num_nodes; only max_node_count tracks the configured node count.
call_kwargs = gcp_mock.container.NodePool.call_args.kwargs
self.assertEqual(call_kwargs.get("initial_node_count"), 4)
self.assertEqual(call_kwargs.get("initial_node_count"), 0)
autoscaling_kwargs = (
gcp_mock.container.NodePoolAutoscalingArgs.call_args.kwargs
)
self.assertEqual(autoscaling_kwargs.get("max_node_count"), 4)
self.assertEqual(autoscaling_kwargs.get("min_node_count"), 4)
self.assertEqual(autoscaling_kwargs.get("min_node_count"), 0)

@mock.patch.object(program, "gcp")
def test_pool_name_passed_through(self, gcp_mock):
Expand Down Expand Up @@ -352,21 +352,13 @@ def test_node_pool_scale_to_zero(
cluster, accelerator, "us-central2-b", "my-project", "test-pool"
)

is_multi_host = getattr(accelerator, "num_nodes", 1) > 1

call_kwargs = gcp_mock.container.NodePool.call_args.kwargs
self.assertEqual(
call_kwargs.get("initial_node_count"),
expected_max_count if is_multi_host else 0,
)
self.assertEqual(call_kwargs.get("initial_node_count"), 0)

autoscaling_kwargs = (
gcp_mock.container.NodePoolAutoscalingArgs.call_args.kwargs
)
self.assertEqual(
autoscaling_kwargs.get("min_node_count"),
expected_max_count if is_multi_host else 0,
)
self.assertEqual(autoscaling_kwargs.get("min_node_count"), 0)
self.assertEqual(
autoscaling_kwargs.get("max_node_count"), expected_max_count
)
Expand Down
Loading