diff --git a/examples/example_gke.py b/examples/example_gke.py index 08bd0a5..5901b5b 100644 --- a/examples/example_gke.py +++ b/examples/example_gke.py @@ -54,7 +54,7 @@ def simple_computation(x, y): # Example 2: Keras model training on CPU -@keras_remote.run(accelerator="cpu") +@keras_remote.run(accelerator="v6e-2x4", cluster="spot-tpu-nodes", spot=True) def train_simple_model_cpu(): """Train a simple Keras model on remote CPU.""" @@ -111,10 +111,10 @@ def main(): print("=" * 60) # Example 1: Simple computation (CPU) - print("\n--- Example 1: Simple Computation (CPU) ---") - print("Running simple_computation(10, 20) on GKE...") - result = simple_computation(10, 20) - print(f"Result: {result}") + # print("\n--- Example 1: Simple Computation (CPU) ---") + # print("Running simple_computation(10, 20) on GKE...") + # result = simple_computation(10, 20) + # print(f"Result: {result}") # Example 2: Model training on CPU print("\n--- Example 2: Keras Model Training (CPU) ---") diff --git a/examples/pathways_example.py b/examples/pathways_example.py index 0a34628..c4ba1aa 100644 --- a/examples/pathways_example.py +++ b/examples/pathways_example.py @@ -10,10 +10,55 @@ # A simple model that will be executed remotely on pathways -@keras_remote.run(accelerator="v5litepod-1", backend="pathways") +@keras_remote.run( + accelerator="v6e-16", backend="pathways", cluster="keras-team-dogfood" +) def train_simple_model(): + import jax + from jax import lax + print("Running Pathways job on JAX Backend!") + # Verify distributed JAX setup (Pathways auto-initialization) + process_count = jax.process_count() + process_index = jax.process_index() + device_count = jax.device_count() + local_device_count = jax.local_device_count() + + print("JAX Distributed Environment:") + print(f" Process Count: {process_count}") + print(f" Process Index: {process_index}") + print(f" Total Devices: {device_count}") + print(f" Local Devices: {local_device_count}") + + # Fail if not actually running on multiple hosts + if process_count <= 1: + raise RuntimeError( + f"Pathways verification failed: Expected > 1 processes, but found {process_count}. " + "This indicates the job is NOT running in a multi-host Pathways environment." + ) + + # Verify collective communication (cross-host psum) + try: + # Use jax.pmap to sum values across all devices in the cluster + x = np.ones(local_device_count) + distributed_sum = jax.pmap(lambda val: lax.psum(val, "i"), axis_name="i")(x) + total_sum = distributed_sum[0] + + if total_sum != device_count: + raise RuntimeError( + f"Collective verification failed: Expected psum {device_count}, got {total_sum}" + ) + print( + f"Successfully verified collective communication across all {total_sum} devices!" + ) + except Exception as e: + print(f"Warning: Collective verification failed: {e}") + if isinstance(e, RuntimeError) and "Collective verification failed" in str( + e + ): + raise + # Create a simple dataset x = np.random.rand(1000, 10) y = np.random.randint(0, 2, size=(1000, 1)) diff --git a/keras_remote/backend/execution.py b/keras_remote/backend/execution.py index ae9b985..5c30aa8 100644 --- a/keras_remote/backend/execution.py +++ b/keras_remote/backend/execution.py @@ -56,6 +56,9 @@ class JobContext: # Data volumes {mount_path: Data} volumes: Optional[dict] = None + # Configuration modifiers + spot: bool = False + # Artifact paths (set during prepare phase) payload_path: Optional[str] = None context_path: Optional[str] = None @@ -80,6 +83,7 @@ def from_params( env_vars: dict, cluster_name: Optional[str] = None, volumes: Optional[dict] = None, + spot: bool = False, ) -> "JobContext": """Factory method with default resolution for zone/project/cluster.""" if not zone: @@ -105,6 +109,7 @@ def from_params( project=project, cluster_name=cluster_name, volumes=volumes, + spot=spot, ) @@ -155,6 +160,7 @@ def submit_job(self, ctx: JobContext) -> Any: job_id=ctx.job_id, bucket_name=ctx.bucket_name, namespace=self.namespace, + spot=ctx.spot, ) def wait_for_job(self, job: Any, ctx: JobContext) -> None: @@ -191,6 +197,7 @@ def submit_job(self, ctx: JobContext) -> Any: job_id=ctx.job_id, bucket_name=ctx.bucket_name, namespace=self.namespace, + spot=ctx.spot, ) def wait_for_job(self, job: Any, ctx: JobContext) -> None: diff --git a/keras_remote/backend/gke_client.py b/keras_remote/backend/gke_client.py index 58f11c8..d557a31 100644 --- a/keras_remote/backend/gke_client.py +++ b/keras_remote/backend/gke_client.py @@ -23,6 +23,7 @@ def submit_k8s_job( job_id, bucket_name, namespace="default", + spot=False, ): """Submit a Kubernetes Job to GKE cluster. @@ -42,7 +43,7 @@ def submit_k8s_job( _load_kube_config() # Parse accelerator configuration - accel_config = _parse_accelerator(accelerator) + accel_config = _parse_accelerator(accelerator, spot=spot) # Create job specification job_name = f"keras-remote-{job_id}" @@ -224,9 +225,9 @@ def validate_preflight( logging.warning("Preflight check: Failed to query nodes: %s", e.reason) -def _parse_accelerator(accelerator): +def _parse_accelerator(accelerator, spot=False): """Convert accelerator string to GKE pod spec fields.""" - parsed = accelerators.parse_accelerator(accelerator) + parsed = accelerators.parse_accelerator(accelerator, spot=spot) if parsed is None: return { @@ -238,21 +239,36 @@ def _parse_accelerator(accelerator): } if isinstance(parsed, TpuConfig): - return { + # For TPU Podslices (multi-node), resource requests must be per-node. + # num_nodes is 1 for single-host TPUs (v3-8, v4-8, v5litepod-1/4/8). + chips_per_node = parsed.chips // parsed.num_nodes + config = { "node_selector": { "cloud.google.com/gke-tpu-accelerator": parsed.gke_accelerator, "cloud.google.com/gke-tpu-topology": parsed.topology, }, - "resource_limits": {"google.com/tpu": str(parsed.chips)}, - "resource_requests": {"google.com/tpu": str(parsed.chips)}, + "resource_limits": {"google.com/tpu": str(chips_per_node)}, + "resource_requests": {"google.com/tpu": str(chips_per_node)}, "tolerations": [ {"key": "google.com/tpu", "operator": "Exists", "effect": "NoSchedule"} ], "jax_platform": "tpu", } + if parsed.spot: + config["node_selector"]["cloud.google.com/gke-spot"] = "true" + config["tolerations"].append( + { + "key": "cloud.google.com/gke-spot", + "operator": "Equal", + "value": "true", + "effect": "NoSchedule", + } + ) + return config + # GpuConfig - return { + config = { "node_selector": {"cloud.google.com/gke-accelerator": parsed.gke_label}, "resource_limits": {"nvidia.com/gpu": str(parsed.count)}, "resource_requests": {"nvidia.com/gpu": str(parsed.count)}, @@ -261,6 +277,17 @@ def _parse_accelerator(accelerator): ], "jax_platform": "gpu", } + if parsed.spot: + config["node_selector"]["cloud.google.com/gke-spot"] = "true" + config["tolerations"].append( + { + "key": "cloud.google.com/gke-spot", + "operator": "Equal", + "value": "true", + "effect": "NoSchedule", + } + ) + return config def _load_kube_config(): @@ -330,8 +357,10 @@ def _create_job_spec( ], env=env_vars, resources=client.V1ResourceRequirements( - limits=accel_config["resource_limits"], - requests=accel_config["resource_requests"], + limits={k: str(v) for k, v in accel_config["resource_limits"].items()}, + requests={ + k: str(v) for k, v in accel_config["resource_requests"].items() + }, ), ) @@ -436,6 +465,10 @@ def _check_node_pool_exists_cached(selector_items) -> bool: config_dict = pool.get("config", {}) pool_labels = config_dict.get("labels", {}).copy() + # Spot VM mapping + if config_dict.get("spot"): + pool_labels["cloud.google.com/gke-spot"] = "true" + # Map GKE injected node labels for accelerators mapping accel_config_list = config_dict.get("accelerators", []) if accel_config_list: @@ -445,6 +478,13 @@ def _check_node_pool_exists_cached(selector_items) -> bool: else: pool_labels["cloud.google.com/gke-accelerator"] = accel_type + # TPU topology mapping from placement policy + placement_policy = pool.get("placementPolicy", {}) + if placement_policy and placement_policy.get("tpuTopology"): + pool_labels["cloud.google.com/gke-tpu-topology"] = placement_policy[ + "tpuTopology" + ] + # TPU mapping fallback machine_type = config_dict.get("machineType", "") @@ -455,7 +495,9 @@ def _check_node_pool_exists_cached(selector_items) -> bool: "goog-gke-accelerator-type" ] - if machine_type.startswith("ct"): + if machine_type.startswith("ct") and not pool_labels.get( + "cloud.google.com/gke-tpu-topology" + ): # We roughly map TPU topology presence for preflight pool_labels["cloud.google.com/gke-tpu-topology"] = selector.get( "cloud.google.com/gke-tpu-topology", "" @@ -466,7 +508,9 @@ def _check_node_pool_exists_cached(selector_items) -> bool: for tpu_spec in accelerators.TPUS.values(): for chips, topo_spec in tpu_spec.topologies.items(): if topo_spec.machine_type == machine_type: - pool_labels["cloud.google.com/gke-accelerator-count"] = str(chips) + pool_labels["cloud.google.com/gke-accelerator-count"] = str( + chips // topo_spec.num_nodes + ) break if all(pool_labels.get(k) == str(v) for k, v in selector.items()): diff --git a/keras_remote/backend/gke_client_test.py b/keras_remote/backend/gke_client_test.py index 0897717..41fd77f 100644 --- a/keras_remote/backend/gke_client_test.py +++ b/keras_remote/backend/gke_client_test.py @@ -58,6 +58,15 @@ def test_tpu_v3_8(self): self.assertLen(result["tolerations"], 1) self.assertEqual(result["tolerations"][0]["key"], "google.com/tpu") + def test_tpu_v3_16_multi_node(self): + # v3-16 has 4 nodes and 16 total chips -> 4 chips per node + result = _parse_accelerator("v3-16") + self.assertEqual(result["resource_limits"], {"google.com/tpu": "4"}) + self.assertEqual(result["resource_requests"], {"google.com/tpu": "4"}) + self.assertEqual( + result["node_selector"]["cloud.google.com/gke-tpu-topology"], "4x4" + ) + def test_tpu_v5litepod_4(self): result = _parse_accelerator("v5litepod-4") self.assertEqual( @@ -69,6 +78,38 @@ def test_tpu_v5litepod_4(self): ) self.assertEqual(result["resource_limits"], {"google.com/tpu": "4"}) + def test_spot_gpu(self): + result = _parse_accelerator("l4:spot") + self.assertEqual( + result["node_selector"]["cloud.google.com/gke-spot"], "true" + ) + # Check for spot toleration + spot_tol = [ + t + for t in result["tolerations"] + if t.get("key") == "cloud.google.com/gke-spot" + ] + self.assertLen(spot_tol, 1) + self.assertEqual(spot_tol[0]["value"], "true") + + def test_spot_tpu(self): + result = _parse_accelerator("v6e-8:spot") + self.assertEqual( + result["node_selector"]["cloud.google.com/gke-spot"], "true" + ) + # Check for spot toleration + spot_tol = [ + t + for t in result["tolerations"] + if t.get("key") == "cloud.google.com/gke-spot" + ] + self.assertLen(spot_tol, 1) + self.assertEqual(spot_tol[0]["value"], "true") + # Should still have TPU toleration + self.assertTrue( + any(t.get("key") == "google.com/tpu" for t in result["tolerations"]) + ) + class TestCreateJobSpec(absltest.TestCase): def _make_gpu_config(self): @@ -421,6 +462,29 @@ def test_tpu_match(self): ) self.assertTrue(result) + def test_tpu_multi_node_match(self): + """Test that it correctly identifies a 4-chip-per-node pool for v6e-16.""" + self.mock_run.return_value = json.dumps( + [ + { + "config": { + "machineType": "ct6e-standard-4t", + "accelerators": [{"acceleratorType": "tpu-v6e-slice"}], + "labels": {}, + } + } + ] + ) + + result = _check_node_pool_exists_cached( + ( + ("cloud.google.com/gke-tpu-accelerator", "tpu-v6e-slice"), + ("cloud.google.com/gke-tpu-topology", "4x4"), + ("cloud.google.com/gke-accelerator-count", "4"), + ) + ) + self.assertTrue(result) + def test_no_match(self): self.mock_run.return_value = json.dumps( [ diff --git a/keras_remote/backend/pathways_client.py b/keras_remote/backend/pathways_client.py index c923286..1a77707 100644 --- a/keras_remote/backend/pathways_client.py +++ b/keras_remote/backend/pathways_client.py @@ -53,6 +53,7 @@ def submit_pathways_job( job_id, bucket_name, namespace="default", + spot=False, ): """Submit a LeaderWorkerSet to GKE cluster. @@ -71,12 +72,10 @@ def submit_pathways_job( _load_kube_config() lws_version = _get_lws_version() - accel_config = _parse_accelerator(accelerator) + parsed_config = accelerators.parse_accelerator(accelerator, spot=spot) + accel_config = _parse_accelerator(accelerator, spot=spot) job_name = _get_job_name(job_id) - # Extract num nodes from the TPU configuration - - parsed_config = accelerators.parse_accelerator(accelerator) if ( isinstance(parsed_config, accelerators.TpuConfig) and parsed_config.num_nodes > 1 @@ -137,6 +136,7 @@ def wait_for_job(job_id, namespace="default", timeout=3600, poll_interval=10): # The leader pod is suffixed with '-0' by LWS leader_pod_name = f"{job_name}-0" + logged_pending = set() with LogStreamer(core_v1, namespace) as streamer: while True: elapsed = time.time() - start_time @@ -160,7 +160,7 @@ def wait_for_job(job_id, namespace="default", timeout=3600, poll_interval=10): raise RuntimeError(f"Pathways job {job_name} failed") elif pod.status.phase == "Pending": - _check_pod_scheduling(core_v1, job_name, namespace) + _check_pod_scheduling(core_v1, job_name, namespace, logged_pending) logging.debug("Pod is Pending...") elif pod.status.phase == "Running": @@ -262,10 +262,12 @@ def _create_lws_spec( {"name": "TPU_WORKER_ID", "value": "$(LWS_WORKER_INDEX)"}, ] - tolerations = [ - {"key": t["key"], "operator": t["operator"], "effect": t["effect"]} - for t in accel_config["tolerations"] - ] + tolerations = [] + for t in accel_config["tolerations"]: + entry = {"key": t["key"], "operator": t["operator"], "effect": t["effect"]} + if "value" in t: + entry["value"] = t["value"] + tolerations.append(entry) pod_template = { "metadata": { @@ -288,8 +290,12 @@ def _create_lws_spec( ], "env": env_vars, "resources": { - "limits": accel_config["resource_limits"], - "requests": accel_config["resource_requests"], + "limits": { + k: str(v) for k, v in accel_config["resource_limits"].items() + }, + "requests": { + k: str(v) for k, v in accel_config["resource_requests"].items() + }, }, } ], diff --git a/keras_remote/backend/pathways_client_test.py b/keras_remote/backend/pathways_client_test.py index c4f5b06..9e6bad5 100644 --- a/keras_remote/backend/pathways_client_test.py +++ b/keras_remote/backend/pathways_client_test.py @@ -143,6 +143,33 @@ def test_env_vars(self): self.assertEqual(env["MEGASCALE_NUM_SLICES"], "4") self.assertEqual(env["TPU_WORKER_ID"], "$(LWS_WORKER_INDEX)") + def test_spot_spec(self): + """Test that spot selectors and tolerations are added when present.""" + accel_config = self._make_tpu_accel_config() + accel_config["node_selector"]["cloud.google.com/gke-spot"] = "true" + accel_config["tolerations"].append( + { + "key": "cloud.google.com/gke-spot", + "operator": "Equal", + "value": "true", + "effect": "NoSchedule", + } + ) + + spec = self._make_spec(accel_config=accel_config) + pod_spec = spec["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"] + + self.assertEqual( + pod_spec["nodeSelector"]["cloud.google.com/gke-spot"], "true" + ) + spot_tol = [ + t + for t in pod_spec["tolerations"] + if t.get("key") == "cloud.google.com/gke-spot" + ] + self.assertLen(spot_tol, 1) + self.assertEqual(spot_tol[0]["value"], "true") + def test_tpu_accel_config(self): """Test resources, tolerations, and node selector for TPU config.""" spec = self._make_spec(accel_config=self._make_tpu_accel_config()) diff --git a/keras_remote/cli/commands/pool.py b/keras_remote/cli/commands/pool.py index 40c728e..b996be9 100644 --- a/keras_remote/cli/commands/pool.py +++ b/keras_remote/cli/commands/pool.py @@ -29,13 +29,14 @@ def pool(): "v5litepod, v5p, v6e, v3 (with optional count/topology)", ) @click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt") -def pool_add(project, zone, cluster_name, accelerator, yes): +@click.option("--spot", is_flag=True, help="Use Spot VMs for node pool") +def pool_add(project, zone, cluster_name, accelerator, yes, spot): """Add an accelerator node pool to the cluster.""" banner("keras-remote Pool Add") # Parse the accelerator spec first to fail fast on bad input. try: - accel_config = accelerators.parse_accelerator(accelerator) + accel_config = accelerators.parse_accelerator(accelerator, spot=spot) except ValueError as e: raise click.BadParameter(str(e), param_hint="--accelerator") from e diff --git a/keras_remote/cli/infra/program.py b/keras_remote/cli/infra/program.py index 18dd209..bf1d7c0 100644 --- a/keras_remote/cli/infra/program.py +++ b/keras_remote/cli/infra/program.py @@ -237,6 +237,7 @@ def _create_gpu_node_pool(cluster, gpu: GpuConfig, zone, project_id, pool_name): ], labels={RESOURCE_NAME_PREFIX: "true"}, max_run_duration=f"{NODE_MAX_RUN_DURATION_SECONDS}s", # 24 hours + spot=gpu.spot, ), ) @@ -276,7 +277,10 @@ def _create_tpu_node_pool(cluster, tpu: TpuConfig, zone, project_id, pool_name): machine_type=tpu.machine_type, oauth_scopes=_BASE_OAUTH_SCOPES, labels={RESOURCE_NAME_PREFIX: "true"}, - max_run_duration=f"{NODE_MAX_RUN_DURATION_SECONDS}s", # 24 hours + max_run_duration=None + if tpu.spot + else f"{NODE_MAX_RUN_DURATION_SECONDS}s", # 24 hours + spot=tpu.spot, ), placement_policy=placement, ) diff --git a/keras_remote/core/accelerators.py b/keras_remote/core/accelerators.py index fa57acd..6150db0 100644 --- a/keras_remote/core/accelerators.py +++ b/keras_remote/core/accelerators.py @@ -18,6 +18,7 @@ class GpuConfig: count: int # number of GPUs (1, 2, 4, …) gke_label: str # "nvidia-l4" — K8s node selector value machine_type: str # "g2-standard-4" — GKE node pool machine type + spot: bool = False @dataclass(frozen=True) @@ -30,6 +31,7 @@ class TpuConfig: gke_accelerator: str # "tpu-v5-lite-podslice" machine_type: str # "ct5lp-hightpu-4t" num_nodes: int # GKE node pool node count + spot: bool = False Accelerator = Union[GpuConfig, TpuConfig, None] @@ -245,7 +247,7 @@ def _resolve_tpu_alias(name: str) -> str: return _TPU_ALIASES.get(name, name) -def parse_accelerator(accel_str: str) -> Accelerator: +def parse_accelerator(accel_str: str, spot: bool = False) -> Accelerator: """Parse an accelerator string into a fully resolved config. Returns GpuConfig, TpuConfig, or None (for "cpu"). @@ -268,15 +270,18 @@ def parse_accelerator(accel_str: str) -> Accelerator: v6e > v5p > v5litepod for TPUs). """ s = accel_str.strip().lower() + if s.endswith(":spot"): + spot = True + s = s[:-5] if s == "cpu" or (s.startswith("cpu:") and s[4:].isdigit()): return None if s == "gpu": - return make_gpu(DEFAULT_GPU, 1) + return make_gpu(DEFAULT_GPU, 1, spot=spot) if s == "tpu": - return make_tpu(DEFAULT_TPU, TPUS[DEFAULT_TPU].default_chips) + return make_tpu(DEFAULT_TPU, TPUS[DEFAULT_TPU].default_chips, spot=spot) # 1) Try parsing as GPU is_gpu_explicit = s.startswith("gpu:") @@ -286,7 +291,7 @@ def parse_accelerator(accel_str: str) -> Accelerator: count = int(gpu_str) for gpu_name in _PREFERRED_GPUS: if gpu_name in GPUS and count in GPUS[gpu_name].counts: - return make_gpu(gpu_name, count) + return make_gpu(gpu_name, count, spot=spot) if is_gpu_explicit: valid_counts = sorted( set(c for spec in GPUS.values() for c in spec.counts) @@ -297,13 +302,13 @@ def parse_accelerator(accel_str: str) -> Accelerator: name = _resolve_gpu_alias(gpu_str) if name in GPUS: - return make_gpu(name, 1) + return make_gpu(name, 1, spot=spot) m = _MULTI_GPU_RE.match(gpu_str) if m: name = _resolve_gpu_alias(m.group(1)) if name in GPUS: - return make_gpu(name, int(m.group(2))) + return make_gpu(name, int(m.group(2)), spot=spot) if is_gpu_explicit: raise ValueError(f"Unknown GPU accelerator: '{accel_str}'") @@ -316,7 +321,7 @@ def parse_accelerator(accel_str: str) -> Accelerator: chips = int(tpu_str) for tpu_name in _PREFERRED_TPUS: if tpu_name in TPUS and chips in TPUS[tpu_name].topologies: - return make_tpu(tpu_name, chips) + return make_tpu(tpu_name, chips, spot=spot) if is_tpu_explicit: valid_chips = sorted( set(c for spec in TPUS.values() for c in spec.topologies) @@ -327,7 +332,7 @@ def parse_accelerator(accel_str: str) -> Accelerator: name = _resolve_tpu_alias(tpu_str) if name in TPUS: - return make_tpu(name, TPUS[name].default_chips) + return make_tpu(name, TPUS[name].default_chips, spot=spot) m = _TPU_TOPO_RE.match(tpu_str) if m: @@ -336,7 +341,7 @@ def parse_accelerator(accel_str: str) -> Accelerator: topo_str = m.group(2) for chips, topo_spec in TPUS[name].topologies.items(): if topo_spec.topology == topo_str: - return make_tpu(name, chips) + return make_tpu(name, chips, spot=spot) valid = [ts.topology for ts in TPUS[name].topologies.values()] raise ValueError( f"Topology '{topo_str}' not supported for '{name}'. " @@ -347,7 +352,7 @@ def parse_accelerator(accel_str: str) -> Accelerator: if m: name = _resolve_tpu_alias(m.group(1)) if name in TPUS: - return make_tpu(name, int(m.group(2))) + return make_tpu(name, int(m.group(2)), spot=spot) raise ValueError( f"Unknown accelerator: '{accel_str}'. " @@ -380,7 +385,7 @@ def generate_pool_name(accel: GpuConfig | TpuConfig) -> str: raise TypeError(f"Expected GpuConfig or TpuConfig, got {type(accel)}") -def make_gpu(name: str, count: int) -> GpuConfig: +def make_gpu(name: str, count: int, spot: bool = False) -> GpuConfig: spec = GPUS[name] if count not in spec.counts: raise ValueError( @@ -392,10 +397,11 @@ def make_gpu(name: str, count: int) -> GpuConfig: count=count, gke_label=spec.gke_label, machine_type=spec.counts[count], + spot=spot, ) -def make_tpu(name: str, chips: int) -> TpuConfig: +def make_tpu(name: str, chips: int, spot: bool = False) -> TpuConfig: spec = TPUS[name] if chips not in spec.topologies: raise ValueError( @@ -410,4 +416,5 @@ def make_tpu(name: str, chips: int) -> TpuConfig: gke_accelerator=spec.gke_accelerator, machine_type=topo_spec.machine_type, num_nodes=topo_spec.num_nodes, + spot=spot, ) diff --git a/keras_remote/core/accelerators_test.py b/keras_remote/core/accelerators_test.py index ba48cf4..7ff818a 100644 --- a/keras_remote/core/accelerators_test.py +++ b/keras_remote/core/accelerators_test.py @@ -14,6 +14,36 @@ ) +class TestParseSpot(absltest.TestCase): + def test_suffix_gpu(self): + result = parse_accelerator("l4:spot") + self.assertIsInstance(result, GpuConfig) + self.assertTrue(result.spot) + self.assertEqual(result.name, "l4") + + def test_suffix_tpu(self): + result = parse_accelerator("v6e-8:spot") + self.assertIsInstance(result, TpuConfig) + self.assertTrue(result.spot) + self.assertEqual(result.name, "v6e") + + def test_explicit_flag_gpu(self): + result = parse_accelerator("l4", spot=True) + self.assertIsInstance(result, GpuConfig) + self.assertTrue(result.spot) + + def test_explicit_flag_tpu(self): + result = parse_accelerator("v6e-8", spot=True) + self.assertIsInstance(result, TpuConfig) + self.assertTrue(result.spot) + + def test_default_is_not_spot(self): + result = parse_accelerator("l4") + self.assertFalse(result.spot) + result = parse_accelerator("v6e-8") + self.assertFalse(result.spot) + + class TestParseGpuDirect(parameterized.TestCase): def test_l4(self): result = parse_accelerator("gpu:l4") diff --git a/keras_remote/core/core.py b/keras_remote/core/core.py index 1d1534e..0e08166 100644 --- a/keras_remote/core/core.py +++ b/keras_remote/core/core.py @@ -22,6 +22,7 @@ def run( backend=None, namespace="default", volumes=None, + spot=False, ): """Execute function on remote TPU/GPU. @@ -74,7 +75,7 @@ def wrapper(*args, **kwargs): resolved_backend = backend if resolved_backend is None: try: - accel_config = accelerators.parse_accelerator(accelerator) + accel_config = accelerators.parse_accelerator(accelerator, spot=spot) # Use Pathways for multi-host TPUs (if supported) or simplified logic # For now, let's default to GKE unless explicit or strictly needed if ( @@ -100,6 +101,7 @@ def wrapper(*args, **kwargs): namespace, env_vars, volumes, + spot, ) elif resolved_backend == "pathways": return _execute_on_pathways( @@ -114,6 +116,7 @@ def wrapper(*args, **kwargs): namespace, env_vars, volumes, + spot, ) else: raise ValueError( @@ -137,6 +140,7 @@ def _execute_on_gke( namespace, env_vars, volumes, + spot, ): """Execute function on GKE cluster with GPU/TPU nodes.""" # Get GKE-specific defaults @@ -156,6 +160,7 @@ def _execute_on_gke( env_vars, cluster_name=cluster, volumes=volumes, + spot=spot, ) return execute_remote(ctx, GKEBackend(cluster=cluster, namespace=namespace)) @@ -172,6 +177,7 @@ def _execute_on_pathways( namespace, env_vars, volumes, + spot, ): """Execute function on GKE cluster via ML Pathways.""" if not cluster: @@ -190,6 +196,7 @@ def _execute_on_pathways( env_vars, cluster_name=cluster, volumes=volumes, + spot=spot, ) return execute_remote( ctx, PathwaysBackend(cluster=cluster, namespace=namespace)