Skip to content

Commit 97df36f

Browse files
reformat code
1 parent da077fb commit 97df36f

File tree

2 files changed

+36
-15
lines changed

2 files changed

+36
-15
lines changed

keras_remote/backend/gke_client.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""GKE job submission for keras_remote."""
22

3+
import json
4+
import subprocess
35
import time
46
from contextlib import suppress
5-
import subprocess
6-
import json
77

88
from absl import logging
99
from kubernetes import client, config
@@ -400,9 +400,11 @@ def _validate_node_pool_exists(selector: dict) -> bool:
400400
"""Use gcloud to verify that a GKE NodePool matches the pod node selector."""
401401
try:
402402
# Requires gcloud CLI and valid credentials.
403-
out = subprocess.check_output([
404-
"gcloud", "container", "node-pools", "list",
405-
"--format", "json"], text=True, stderr=subprocess.DEVNULL)
403+
out = subprocess.check_output(
404+
["gcloud", "container", "node-pools", "list", "--format", "json"],
405+
text=True,
406+
stderr=subprocess.DEVNULL,
407+
)
406408
pools = json.loads(out)
407409
for pool in pools:
408410
config = pool.get("config", {})
@@ -415,6 +417,7 @@ def _validate_node_pool_exists(selector: dict) -> bool:
415417
# If gcloud is missing or unauthenticated, degrade gracefully and assume pool exists
416418
return True
417419

420+
418421
def _check_pod_scheduling(core_v1, job_name, namespace):
419422
"""Check for pod scheduling issues and raise helpful errors."""
420423
with suppress(ApiException):
@@ -429,9 +432,15 @@ def _check_pod_scheduling(core_v1, job_name, namespace):
429432
if "Insufficient nvidia.com/gpu" in msg:
430433
selector = pod.spec.node_selector or {}
431434
if not _validate_node_pool_exists(selector):
432-
selector_str = ", ".join([f"{k}: {v}" for k, v in selector.items()]) if selector else "None"
433-
raise RuntimeError(f"No GKE node pool exists with selector '{selector_str}'. "
434-
"Please use 'keras-remote pool add' to configure this accelerator.")
435+
selector_str = (
436+
", ".join([f"{k}: {v}" for k, v in selector.items()])
437+
if selector
438+
else "None"
439+
)
440+
raise RuntimeError(
441+
f"No GKE node pool exists with selector '{selector_str}'. "
442+
"Please use 'keras-remote pool add' to configure this accelerator."
443+
)
435444
logging.info(
436445
f"Pod {pod.metadata.name} is Pending: Insufficient nvidia.com/gpu. "
437446
"Waiting for GKE Cluster Autoscaler to provision a new node... (scale-to-zero)\n"
@@ -442,12 +451,18 @@ def _check_pod_scheduling(core_v1, job_name, namespace):
442451
or "node selector" in msg.lower()
443452
):
444453
selector = pod.spec.node_selector or {}
445-
selector_str = ", ".join([f"{k}: {v}" for k, v in selector.items()]) if selector else "None"
446-
454+
selector_str = (
455+
", ".join([f"{k}: {v}" for k, v in selector.items()])
456+
if selector
457+
else "None"
458+
)
459+
447460
if not _validate_node_pool_exists(selector):
448-
raise RuntimeError(f"No GKE node pool exists with selector '{selector_str}'. "
449-
"Please use 'keras-remote pool add' to configure this accelerator.")
450-
461+
raise RuntimeError(
462+
f"No GKE node pool exists with selector '{selector_str}'. "
463+
"Please use 'keras-remote pool add' to configure this accelerator."
464+
)
465+
451466
logging.info(
452467
f"Pod {pod.metadata.name} is Pending: No currently running nodes "
453468
f"match accelerator selector '{selector_str}'. "

keras_remote/cli/infra/program_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,12 +350,18 @@ def test_node_pool_scale_to_zero(
350350
is_multi_host = getattr(accelerator, "num_nodes", 1) > 1
351351

352352
call_kwargs = gcp_mock.container.NodePool.call_args.kwargs
353-
self.assertEqual(call_kwargs.get("initial_node_count"), expected_max_count if is_multi_host else 0)
353+
self.assertEqual(
354+
call_kwargs.get("initial_node_count"),
355+
expected_max_count if is_multi_host else 0,
356+
)
354357

355358
autoscaling_kwargs = (
356359
gcp_mock.container.NodePoolAutoscalingArgs.call_args.kwargs
357360
)
358-
self.assertEqual(autoscaling_kwargs.get("min_node_count"), expected_max_count if is_multi_host else 0)
361+
self.assertEqual(
362+
autoscaling_kwargs.get("min_node_count"),
363+
expected_max_count if is_multi_host else 0,
364+
)
359365
self.assertEqual(
360366
autoscaling_kwargs.get("max_node_count"), expected_max_count
361367
)

0 commit comments

Comments
 (0)