Skip to content

Commit f5bd670

Browse files
committed
feat: improve accelerator validation and container caching
- Group accelerators by category (CPU, GPU, TPU) for container image sharing, reducing redundant builds.
- Implement preflight check to validate node pool existence before building containers.
- Improve error messages for node selector mismatches in pod scheduling.
- Update simple_demo.py to return training loss and use more idiomatic result retrieval.
- Add preflight check note to README with link to Quick Start.
1 parent c478d56 commit f5bd670

File tree

5 files changed

+106
-23
lines changed

5 files changed

+106
-23
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,9 @@ See [examples/Dockerfile.prebuilt](examples/Dockerfile.prebuilt) for a template.
228228

229229
## Supported Accelerators
230230

231+
Note: each accelerator and topology requires [setting up its own NodePool](#quick-start)
232+
as a prerequisite.
233+
231234
### TPUs
232235

233236
| Type | Configurations |

keras_remote/backend/execution.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ def __init__(self, cluster: Optional[str] = None, namespace: str = "default"):
101101
self.cluster = cluster
102102
self.namespace = namespace
103103

104+
def validate_preflight(self, ctx: JobContext) -> None:
105+
"""Perform preflight checks before building container or uploading artifacts."""
106+
pass
107+
104108
def submit_job(self, ctx: JobContext) -> Any:
105109
"""Submit a job to the backend. Returns backend-specific job handle."""
106110
raise NotImplementedError
@@ -117,6 +121,16 @@ def cleanup_job(self, job: Any, ctx: JobContext) -> None:
117121
class GKEBackend(BaseK8sBackend):
118122
"""Backend adapter for standard GKE Jobs."""
119123

124+
def validate_preflight(self, ctx: JobContext) -> None:
125+
"""Check if the required node pool exists for the accelerator."""
126+
gke_client.validate_preflight(
127+
accelerator=ctx.accelerator,
128+
project=ctx.project,
129+
cluster=self.cluster,
130+
zone=ctx.zone,
131+
namespace=self.namespace,
132+
)
133+
120134
def submit_job(self, ctx: JobContext) -> Any:
121135
"""Submit job to GKE cluster."""
122136
return gke_client.submit_k8s_job(
@@ -142,6 +156,17 @@ def cleanup_job(self, job: Any, ctx: JobContext) -> None:
142156
class PathwaysBackend(BaseK8sBackend):
143157
"""Backend adapter for ML Pathways using LeaderWorkerSet."""
144158

159+
def validate_preflight(self, ctx: JobContext) -> None:
160+
"""Preflight checks for Pathways (currently same as GKE)."""
161+
# Pathways also runs on GKE nodes with specific labels
162+
gke_client.validate_preflight(
163+
accelerator=ctx.accelerator,
164+
project=ctx.project,
165+
cluster=self.cluster,
166+
zone=ctx.zone,
167+
namespace=self.namespace,
168+
)
169+
145170
def submit_job(self, ctx: JobContext) -> Any:
146171
"""Submit LWS job to GKE cluster."""
147172
return pathways_client.submit_pathways_job(
@@ -289,6 +314,9 @@ def execute_remote(ctx: JobContext, backend: BaseK8sBackend) -> Any:
289314
cluster=backend.cluster,
290315
)
291316

317+
# Preflight check
318+
backend.validate_preflight(ctx)
319+
292320
with tempfile.TemporaryDirectory() as tmpdir:
293321
# Phase 1: Package artifacts
294322
_prepare_artifacts(ctx, tmpdir)

keras_remote/backend/gke_client.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,46 @@ def cleanup_job(job_name, namespace="default"):
179179
logging.warning("Failed to delete job %s: %s", job_name, e.reason)
180180

181181

182+
def validate_preflight(accelerator, project, cluster, zone, namespace="default"):
183+
"""Check if the required node pool exists for the accelerator.
184+
185+
Args:
186+
accelerator: Accelerator string (e.g., 'l4', 'v3-8')
187+
project: GCP project ID
188+
cluster: GKE cluster name
189+
zone: GCP zone
190+
namespace: Kubernetes namespace
191+
192+
Raises:
193+
RuntimeError: If no nodes match the required accelerator selector.
194+
"""
195+
_load_kube_config()
196+
accel_config = _parse_accelerator(accelerator)
197+
node_selector = accel_config.get("node_selector")
198+
199+
if not node_selector:
200+
return # CPU or no selector required
201+
202+
core_v1 = client.CoreV1Api()
203+
try:
204+
# Construct label selector string: "key1=val1,key2=val2"
205+
label_selector = ",".join([f"{k}={v}" for k, v in node_selector.items()])
206+
nodes = core_v1.list_node(label_selector=label_selector)
207+
208+
if not nodes.items:
209+
selector_str = ", ".join([f"{k}: {v}" for k, v in node_selector.items()])
210+
raise RuntimeError(
211+
f"Preflight check failed: No nodes match the accelerator selector: {selector_str}. "
212+
"Check that your GKE cluster has a node pool with the correct accelerator type. "
213+
"See all supported accelerator symbols here: \n"
214+
"https://github.com/keras-team/remote#supported-accelerators"
215+
)
216+
except ApiException as e:
217+
# If we can't list nodes due to permissions, log a warning but proceed
218+
# to avoid blocking users with restricted kubeconfig.
219+
logging.warning("Preflight check: Failed to query nodes: %s", e.reason)
220+
221+
182222
def _parse_accelerator(accelerator):
183223
"""Convert accelerator string to GKE pod spec fields."""
184224
parsed = accelerators.parse_accelerator(accelerator)
@@ -374,7 +414,15 @@ def _check_pod_scheduling(core_v1, job_name, namespace):
374414
"didn't match Pod's node affinity/selector" in msg
375415
or "node selector" in msg.lower()
376416
):
417+
selector = pod.spec.node_selector
418+
selector_str = (
419+
", ".join([f"{k}: {v}" for k, v in selector.items()])
420+
if selector
421+
else "None"
422+
)
377423
raise RuntimeError(
378-
"No nodes match the GPU selector. Check that your node pool "
379-
"has the correct GPU type label."
424+
f"No nodes match the accelerator selector: {selector_str}. "
425+
"Check that your node pool has the correct accelerator type label. "
426+
"See all supported accelerator symbols here: \n"
427+
"https://github.com/keras-team/remote#supported-accelerators"
380428
)

keras_remote/backend/gke_client_test.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,10 @@ def test_kubeconfig_fallback(self):
356356

357357

358358
class TestCheckPodScheduling(parameterized.TestCase):
359-
def _make_pending_pod(self, message):
359+
def _make_pending_pod(self, message, node_selector=None):
360360
pod = MagicMock()
361361
pod.status.phase = "Pending"
362+
pod.spec.node_selector = node_selector
362363
condition = MagicMock()
363364
condition.type = "PodScheduled"
364365
condition.status = "False"
@@ -371,16 +372,20 @@ def _make_pending_pod(self, message):
371372
testcase_name="insufficient_gpu",
372373
condition_message="Insufficient nvidia.com/gpu",
373374
error_match="No GPU nodes available",
375+
node_selector=None,
374376
),
375377
dict(
376378
testcase_name="node_selector_mismatch",
377379
condition_message="didn't match Pod's node affinity/selector",
378-
error_match="No nodes match",
380+
error_match="No nodes match the accelerator selector: cloud.google.com/gke-accelerator: nvidia-l4",
381+
node_selector={"cloud.google.com/gke-accelerator": "nvidia-l4"},
379382
),
380383
)
381-
def test_scheduling_failure_raises(self, condition_message, error_match):
384+
def test_scheduling_failure_raises(
385+
self, condition_message, error_match, node_selector
386+
):
382387
mock_core = MagicMock()
383-
pod = self._make_pending_pod(condition_message)
388+
pod = self._make_pending_pod(condition_message, node_selector=node_selector)
384389
mock_core.list_namespaced_pod.return_value.items = [pod]
385390

386391
with self.assertRaisesRegex(RuntimeError, error_match):

keras_remote/infra/container_builder.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ def get_or_build_container(
4242
Container image URI in Artifact Registry
4343
"""
4444
ar_location = zone_to_ar_location(zone or get_default_zone())
45+
category = accelerators.get_category(accelerator_type)
4546

46-
# Generate deterministic hash from requirements + base image
47+
# Generate deterministic hash from requirements + base image + category
4748
requirements_hash = _hash_requirements(
48-
requirements_path, accelerator_type, base_image
49+
requirements_path, category, base_image
4950
)
5051

51-
# Sanitize accelerator type for image name
52-
sanitized_accel = accelerator_type.replace(":", "-").replace("/", "-")
53-
image_tag = f"{sanitized_accel}-{requirements_hash[:12]}"
52+
# Use category for image name (e.g., 'tpu-hash', 'gpu-hash')
53+
image_tag = f"{category}-{requirements_hash[:12]}"
5454

5555
# Use Artifact Registry
5656
registry = f"{ar_location}-docker.pkg.dev/{project}/keras-remote"
@@ -72,25 +72,25 @@ def get_or_build_container(
7272
return _build_and_push(
7373
base_image,
7474
requirements_path,
75-
accelerator_type,
75+
category,
7676
project,
7777
image_uri,
7878
ar_location,
7979
)
8080

8181

82-
def _hash_requirements(requirements_path, accelerator_type, base_image):
83-
"""Create deterministic hash from requirements + accelerator + remote_runner + base image.
82+
def _hash_requirements(requirements_path, category, base_image):
83+
"""Create deterministic hash from requirements + category + remote_runner + base image.
8484
8585
Args:
8686
requirements_path: Path to requirements.txt (or None)
87-
accelerator_type: TPU/GPU type
87+
category: Accelerator category ('cpu', 'gpu', 'tpu')
8888
base_image: Base Docker image (e.g., 'python:3.12-slim')
8989
9090
Returns:
9191
SHA256 hex digest
9292
"""
93-
content = f"base_image={base_image}\naccelerator={accelerator_type}\n"
93+
content = f"base_image={base_image}\ncategory={category}\n"
9494

9595
if requirements_path and os.path.exists(requirements_path):
9696
with open(requirements_path, "r") as f:
@@ -150,7 +150,7 @@ def _image_exists(image_uri, project):
150150
def _build_and_push(
151151
base_image,
152152
requirements_path,
153-
accelerator_type,
153+
category,
154154
project,
155155
image_uri,
156156
ar_location="us",
@@ -160,7 +160,7 @@ def _build_and_push(
160160
Args:
161161
base_image: Base Docker image
162162
requirements_path: Path to requirements.txt (or None)
163-
accelerator_type: TPU/GPU type
163+
category: Accelerator category ('cpu', 'gpu', 'tpu')
164164
project: GCP project ID
165165
image_uri: Target image URI
166166
ar_location: Artifact Registry multi-region (e.g., 'us')
@@ -173,7 +173,7 @@ def _build_and_push(
173173
dockerfile_content = _generate_dockerfile(
174174
base_image=base_image,
175175
requirements_path=requirements_path,
176-
accelerator_type=accelerator_type,
176+
category=category,
177177
)
178178

179179
dockerfile_path = os.path.join(tmpdir, "Dockerfile")
@@ -255,19 +255,18 @@ def _build_and_push(
255255
raise RuntimeError(f"Build failed with status: {result.status}")
256256

257257

258-
def _generate_dockerfile(base_image, requirements_path, accelerator_type):
258+
def _generate_dockerfile(base_image, requirements_path, category):
259259
"""Generate Dockerfile content based on configuration.
260260
261261
Args:
262262
base_image: Base Docker image
263263
requirements_path: Path to requirements.txt (or None)
264-
accelerator_type: TPU/GPU type
264+
category: Accelerator category ('cpu', 'gpu', 'tpu')
265265
266266
Returns:
267267
Dockerfile content as string
268268
"""
269-
# Determine JAX installation command based on accelerator
270-
category = accelerators.get_category(accelerator_type)
269+
# Determine JAX installation command based on accelerator category
271270
if category == "cpu":
272271
jax_install = "RUN python3 -m pip install jax"
273272
elif category == "tpu":

0 commit comments

Comments (0)