Skip to content

Commit 3b222ce

Browse files
Add Spot instance support (GKE Spot VM node selectors, tolerations, node pools, and CLI flag)
1 parent c24041c commit 3b222ce

File tree

11 files changed

+188
-35
lines changed

11 files changed

+188
-35
lines changed

examples/example_gke.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def simple_computation(x, y):
5454

5555

5656
# Example 2: Keras model training on CPU
57-
@keras_remote.run(accelerator="cpu")
57+
@keras_remote.run(accelerator="v6e-2x4", cluster="spot-tpu-nodes", spot=True)
5858
def train_simple_model_cpu():
5959
"""Train a simple Keras model on remote CPU."""
6060

@@ -111,10 +111,10 @@ def main():
111111
print("=" * 60)
112112

113113
# Example 1: Simple computation (CPU)
114-
print("\n--- Example 1: Simple Computation (CPU) ---")
115-
print("Running simple_computation(10, 20) on GKE...")
116-
result = simple_computation(10, 20)
117-
print(f"Result: {result}")
114+
# print("\n--- Example 1: Simple Computation (CPU) ---")
115+
# print("Running simple_computation(10, 20) on GKE...")
116+
# result = simple_computation(10, 20)
117+
# print(f"Result: {result}")
118118

119119
# Example 2: Model training on CPU
120120
print("\n--- Example 2: Keras Model Training (CPU) ---")

keras_remote/backend/execution.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ class JobContext:
5656
# Data volumes {mount_path: Data}
5757
volumes: Optional[dict] = None
5858

59+
# Configuration modifiers
60+
spot: bool = False
61+
5962
# Artifact paths (set during prepare phase)
6063
payload_path: Optional[str] = None
6164
context_path: Optional[str] = None
@@ -80,6 +83,7 @@ def from_params(
8083
env_vars: dict,
8184
cluster_name: Optional[str] = None,
8285
volumes: Optional[dict] = None,
86+
spot: bool = False,
8387
) -> "JobContext":
8488
"""Factory method with default resolution for zone/project/cluster."""
8589
if not zone:
@@ -105,6 +109,7 @@ def from_params(
105109
project=project,
106110
cluster_name=cluster_name,
107111
volumes=volumes,
112+
spot=spot,
108113
)
109114

110115

@@ -155,6 +160,7 @@ def submit_job(self, ctx: JobContext) -> Any:
155160
job_id=ctx.job_id,
156161
bucket_name=ctx.bucket_name,
157162
namespace=self.namespace,
163+
spot=ctx.spot,
158164
)
159165

160166
def wait_for_job(self, job: Any, ctx: JobContext) -> None:
@@ -191,6 +197,7 @@ def submit_job(self, ctx: JobContext) -> Any:
191197
job_id=ctx.job_id,
192198
bucket_name=ctx.bucket_name,
193199
namespace=self.namespace,
200+
spot=ctx.spot,
194201
)
195202

196203
def wait_for_job(self, job: Any, ctx: JobContext) -> None:

keras_remote/backend/gke_client.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def submit_k8s_job(
2323
job_id,
2424
bucket_name,
2525
namespace="default",
26+
spot=False,
2627
):
2728
"""Submit a Kubernetes Job to GKE cluster.
2829
@@ -42,7 +43,7 @@ def submit_k8s_job(
4243
_load_kube_config()
4344

4445
# Parse accelerator configuration
45-
accel_config = _parse_accelerator(accelerator)
46+
accel_config = _parse_accelerator(accelerator, spot=spot)
4647

4748
# Create job specification
4849
job_name = f"keras-remote-{job_id}"
@@ -224,9 +225,9 @@ def validate_preflight(
224225
logging.warning("Preflight check: Failed to query nodes: %s", e.reason)
225226

226227

227-
def _parse_accelerator(accelerator):
228+
def _parse_accelerator(accelerator, spot=False):
228229
"""Convert accelerator string to GKE pod spec fields."""
229-
parsed = accelerators.parse_accelerator(accelerator)
230+
parsed = accelerators.parse_accelerator(accelerator, spot=spot)
230231

231232
if parsed is None:
232233
return {
@@ -241,7 +242,7 @@ def _parse_accelerator(accelerator):
241242
# For TPU Podslices (multi-node), resource requests must be per-node.
242243
# num_nodes is 1 for single-host TPUs (v3-8, v4-8, v5litepod-1/4/8).
243244
chips_per_node = parsed.chips // parsed.num_nodes
244-
return {
245+
config = {
245246
"node_selector": {
246247
"cloud.google.com/gke-tpu-accelerator": parsed.gke_accelerator,
247248
"cloud.google.com/gke-tpu-topology": parsed.topology,
@@ -254,8 +255,20 @@ def _parse_accelerator(accelerator):
254255
"jax_platform": "tpu",
255256
}
256257

258+
if parsed.spot:
259+
config["node_selector"]["cloud.google.com/gke-spot"] = "true"
260+
config["tolerations"].append(
261+
{
262+
"key": "cloud.google.com/gke-spot",
263+
"operator": "Equal",
264+
"value": "true",
265+
"effect": "NoSchedule",
266+
}
267+
)
268+
return config
269+
257270
# GpuConfig
258-
return {
271+
config = {
259272
"node_selector": {"cloud.google.com/gke-accelerator": parsed.gke_label},
260273
"resource_limits": {"nvidia.com/gpu": str(parsed.count)},
261274
"resource_requests": {"nvidia.com/gpu": str(parsed.count)},
@@ -264,6 +277,17 @@ def _parse_accelerator(accelerator):
264277
],
265278
"jax_platform": "gpu",
266279
}
280+
if parsed.spot:
281+
config["node_selector"]["cloud.google.com/gke-spot"] = "true"
282+
config["tolerations"].append(
283+
{
284+
"key": "cloud.google.com/gke-spot",
285+
"operator": "Equal",
286+
"value": "true",
287+
"effect": "NoSchedule",
288+
}
289+
)
290+
return config
267291

268292

269293
def _load_kube_config():
@@ -441,6 +465,10 @@ def _check_node_pool_exists_cached(selector_items) -> bool:
441465
config_dict = pool.get("config", {})
442466
pool_labels = config_dict.get("labels", {}).copy()
443467

468+
# Spot VM mapping
469+
if config_dict.get("spot"):
470+
pool_labels["cloud.google.com/gke-spot"] = "true"
471+
444472
# Map GKE injected node labels for accelerators mapping
445473
accel_config_list = config_dict.get("accelerators", [])
446474
if accel_config_list:
@@ -450,6 +478,13 @@ def _check_node_pool_exists_cached(selector_items) -> bool:
450478
else:
451479
pool_labels["cloud.google.com/gke-accelerator"] = accel_type
452480

481+
# TPU topology mapping from placement policy
482+
placement_policy = pool.get("placementPolicy", {})
483+
if placement_policy and placement_policy.get("tpuTopology"):
484+
pool_labels["cloud.google.com/gke-tpu-topology"] = placement_policy[
485+
"tpuTopology"
486+
]
487+
453488
# TPU mapping fallback
454489
machine_type = config_dict.get("machineType", "")
455490

@@ -460,7 +495,9 @@ def _check_node_pool_exists_cached(selector_items) -> bool:
460495
"goog-gke-accelerator-type"
461496
]
462497

463-
if machine_type.startswith("ct"):
498+
if machine_type.startswith("ct") and not pool_labels.get(
499+
"cloud.google.com/gke-tpu-topology"
500+
):
464501
# We roughly map TPU topology presence for preflight
465502
pool_labels["cloud.google.com/gke-tpu-topology"] = selector.get(
466503
"cloud.google.com/gke-tpu-topology", ""

keras_remote/backend/gke_client_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,38 @@ def test_tpu_v5litepod_4(self):
7878
)
7979
self.assertEqual(result["resource_limits"], {"google.com/tpu": "4"})
8080

81+
def test_spot_gpu(self):
82+
result = _parse_accelerator("l4:spot")
83+
self.assertEqual(
84+
result["node_selector"]["cloud.google.com/gke-spot"], "true"
85+
)
86+
# Check for spot toleration
87+
spot_tol = [
88+
t
89+
for t in result["tolerations"]
90+
if t.get("key") == "cloud.google.com/gke-spot"
91+
]
92+
self.assertLen(spot_tol, 1)
93+
self.assertEqual(spot_tol[0]["value"], "true")
94+
95+
def test_spot_tpu(self):
96+
result = _parse_accelerator("v6e-8:spot")
97+
self.assertEqual(
98+
result["node_selector"]["cloud.google.com/gke-spot"], "true"
99+
)
100+
# Check for spot toleration
101+
spot_tol = [
102+
t
103+
for t in result["tolerations"]
104+
if t.get("key") == "cloud.google.com/gke-spot"
105+
]
106+
self.assertLen(spot_tol, 1)
107+
self.assertEqual(spot_tol[0]["value"], "true")
108+
# Should still have TPU toleration
109+
self.assertTrue(
110+
any(t.get("key") == "google.com/tpu" for t in result["tolerations"])
111+
)
112+
81113

82114
class TestCreateJobSpec(absltest.TestCase):
83115
def _make_gpu_config(self):

keras_remote/backend/pathways_client.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def submit_pathways_job(
5353
job_id,
5454
bucket_name,
5555
namespace="default",
56+
spot=False,
5657
):
5758
"""Submit a LeaderWorkerSet to GKE cluster.
5859
@@ -71,12 +72,10 @@ def submit_pathways_job(
7172
_load_kube_config()
7273
lws_version = _get_lws_version()
7374

74-
accel_config = _parse_accelerator(accelerator)
75+
parsed_config = accelerators.parse_accelerator(accelerator, spot=spot)
76+
accel_config = _parse_accelerator(accelerator, spot=spot)
7577
job_name = _get_job_name(job_id)
7678

77-
# Extract num nodes from the TPU configuration
78-
79-
parsed_config = accelerators.parse_accelerator(accelerator)
8079
if (
8180
isinstance(parsed_config, accelerators.TpuConfig)
8281
and parsed_config.num_nodes > 1
@@ -263,10 +262,12 @@ def _create_lws_spec(
263262
{"name": "TPU_WORKER_ID", "value": "$(LWS_WORKER_INDEX)"},
264263
]
265264

266-
tolerations = [
267-
{"key": t["key"], "operator": t["operator"], "effect": t["effect"]}
268-
for t in accel_config["tolerations"]
269-
]
265+
tolerations = []
266+
for t in accel_config["tolerations"]:
267+
entry = {"key": t["key"], "operator": t["operator"], "effect": t["effect"]}
268+
if "value" in t:
269+
entry["value"] = t["value"]
270+
tolerations.append(entry)
270271

271272
pod_template = {
272273
"metadata": {

keras_remote/backend/pathways_client_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,33 @@ def test_env_vars(self):
143143
self.assertEqual(env["MEGASCALE_NUM_SLICES"], "4")
144144
self.assertEqual(env["TPU_WORKER_ID"], "$(LWS_WORKER_INDEX)")
145145

146+
def test_spot_spec(self):
147+
"""Test that spot selectors and tolerations are added when present."""
148+
accel_config = self._make_tpu_accel_config()
149+
accel_config["node_selector"]["cloud.google.com/gke-spot"] = "true"
150+
accel_config["tolerations"].append(
151+
{
152+
"key": "cloud.google.com/gke-spot",
153+
"operator": "Equal",
154+
"value": "true",
155+
"effect": "NoSchedule",
156+
}
157+
)
158+
159+
spec = self._make_spec(accel_config=accel_config)
160+
pod_spec = spec["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"]
161+
162+
self.assertEqual(
163+
pod_spec["nodeSelector"]["cloud.google.com/gke-spot"], "true"
164+
)
165+
spot_tol = [
166+
t
167+
for t in pod_spec["tolerations"]
168+
if t.get("key") == "cloud.google.com/gke-spot"
169+
]
170+
self.assertLen(spot_tol, 1)
171+
self.assertEqual(spot_tol[0]["value"], "true")
172+
146173
def test_tpu_accel_config(self):
147174
"""Test resources, tolerations, and node selector for TPU config."""
148175
spec = self._make_spec(accel_config=self._make_tpu_accel_config())

keras_remote/cli/commands/pool.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,14 @@ def pool():
2929
"v5litepod, v5p, v6e, v3 (with optional count/topology)",
3030
)
3131
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt")
32-
def pool_add(project, zone, cluster_name, accelerator, yes):
32+
@click.option("--spot", is_flag=True, help="Use Spot VMs for node pool")
33+
def pool_add(project, zone, cluster_name, accelerator, yes, spot):
3334
"""Add an accelerator node pool to the cluster."""
3435
banner("keras-remote Pool Add")
3536

3637
# Parse the accelerator spec first to fail fast on bad input.
3738
try:
38-
accel_config = accelerators.parse_accelerator(accelerator)
39+
accel_config = accelerators.parse_accelerator(accelerator, spot=spot)
3940
except ValueError as e:
4041
raise click.BadParameter(str(e), param_hint="--accelerator") from e
4142

keras_remote/cli/infra/program.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def _create_gpu_node_pool(cluster, gpu: GpuConfig, zone, project_id, pool_name):
237237
],
238238
labels={RESOURCE_NAME_PREFIX: "true"},
239239
max_run_duration=f"{NODE_MAX_RUN_DURATION_SECONDS}s", # 24 hours
240+
spot=gpu.spot,
240241
),
241242
)
242243

@@ -276,7 +277,10 @@ def _create_tpu_node_pool(cluster, tpu: TpuConfig, zone, project_id, pool_name):
276277
machine_type=tpu.machine_type,
277278
oauth_scopes=_BASE_OAUTH_SCOPES,
278279
labels={RESOURCE_NAME_PREFIX: "true"},
279-
max_run_duration=f"{NODE_MAX_RUN_DURATION_SECONDS}s", # 24 hours
280+
max_run_duration=None
281+
if tpu.spot
282+
else f"{NODE_MAX_RUN_DURATION_SECONDS}s", # 24 hours
283+
spot=tpu.spot,
280284
),
281285
placement_policy=placement,
282286
)

0 commit comments

Comments (0)