Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions examples/example_gke.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def simple_computation(x, y):


# Example 2: Keras model training on CPU
@keras_remote.run(accelerator="cpu")
@keras_remote.run(accelerator="v6e-2x4", cluster="spot-tpu-nodes", spot=True)
def train_simple_model_cpu():
"""Train a simple Keras model on remote CPU."""

Expand Down Expand Up @@ -111,10 +111,10 @@ def main():
print("=" * 60)

# Example 1: Simple computation (CPU)
print("\n--- Example 1: Simple Computation (CPU) ---")
print("Running simple_computation(10, 20) on GKE...")
result = simple_computation(10, 20)
print(f"Result: {result}")
# print("\n--- Example 1: Simple Computation (CPU) ---")
# print("Running simple_computation(10, 20) on GKE...")
# result = simple_computation(10, 20)
# print(f"Result: {result}")

# Example 2: Model training on CPU
print("\n--- Example 2: Keras Model Training (CPU) ---")
Expand Down
47 changes: 46 additions & 1 deletion examples/pathways_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,55 @@


# A simple model that will be executed remotely on pathways
@keras_remote.run(accelerator="v5litepod-1", backend="pathways")
@keras_remote.run(
accelerator="v6e-16", backend="pathways", cluster="keras-team-dogfood"
)
def train_simple_model():
import jax
from jax import lax

print("Running Pathways job on JAX Backend!")

# Verify distributed JAX setup (Pathways auto-initialization)
process_count = jax.process_count()
process_index = jax.process_index()
device_count = jax.device_count()
local_device_count = jax.local_device_count()

print("JAX Distributed Environment:")
print(f" Process Count: {process_count}")
print(f" Process Index: {process_index}")
print(f" Total Devices: {device_count}")
print(f" Local Devices: {local_device_count}")

# Fail if not actually running on multiple hosts
if process_count <= 1:
raise RuntimeError(
f"Pathways verification failed: Expected > 1 processes, but found {process_count}. "
"This indicates the job is NOT running in a multi-host Pathways environment."
)

# Verify collective communication (cross-host psum)
try:
# Use jax.pmap to sum values across all devices in the cluster
x = np.ones(local_device_count)
distributed_sum = jax.pmap(lambda val: lax.psum(val, "i"), axis_name="i")(x)
total_sum = distributed_sum[0]

if total_sum != device_count:
raise RuntimeError(
f"Collective verification failed: Expected psum {device_count}, got {total_sum}"
)
print(
f"Successfully verified collective communication across all {total_sum} devices!"
)
except Exception as e:
print(f"Warning: Collective verification failed: {e}")
if isinstance(e, RuntimeError) and "Collective verification failed" in str(
e
):
raise
Comment on lines +55 to +60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current exception handling logic is risky because it swallows most exceptions. The except Exception block only re-raises a RuntimeError if it contains the specific string "Collective verification failed". Any other exception (including other RuntimeErrors or errors from JAX) will be caught, a warning printed, and then the exception will be swallowed, allowing the script to continue as if no error occurred. This can hide underlying problems and lead to silent failures.

It's safer to ensure that any failure during verification causes the job to fail. I recommend simplifying the exception handling to always re-raise after logging, like this:

  except Exception as e:
    print(f"Error: Collective verification failed: {e}")
    raise

Alternatively, you could remove the try...except block entirely if the goal is to just let any exception from the verification logic fail the job.

References
  1. The current exception handling is not robust as it can swallow critical errors, violating the principle of demanding robust code. The proposed change ensures failures are not silently ignored. (Rule 4: Demand Robustness) (link)


# Create a simple dataset
x = np.random.rand(1000, 10)
y = np.random.randint(0, 2, size=(1000, 1))
Expand Down
7 changes: 7 additions & 0 deletions keras_remote/backend/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class JobContext:
# Data volumes {mount_path: Data}
volumes: Optional[dict] = None

# Configuration modifiers
spot: bool = False

# Artifact paths (set during prepare phase)
payload_path: Optional[str] = None
context_path: Optional[str] = None
Expand All @@ -80,6 +83,7 @@ def from_params(
env_vars: dict,
cluster_name: Optional[str] = None,
volumes: Optional[dict] = None,
spot: bool = False,
) -> "JobContext":
"""Factory method with default resolution for zone/project/cluster."""
if not zone:
Expand All @@ -105,6 +109,7 @@ def from_params(
project=project,
cluster_name=cluster_name,
volumes=volumes,
spot=spot,
)


Expand Down Expand Up @@ -155,6 +160,7 @@ def submit_job(self, ctx: JobContext) -> Any:
job_id=ctx.job_id,
bucket_name=ctx.bucket_name,
namespace=self.namespace,
spot=ctx.spot,
)

def wait_for_job(self, job: Any, ctx: JobContext) -> None:
Expand Down Expand Up @@ -191,6 +197,7 @@ def submit_job(self, ctx: JobContext) -> Any:
job_id=ctx.job_id,
bucket_name=ctx.bucket_name,
namespace=self.namespace,
spot=ctx.spot,
)

def wait_for_job(self, job: Any, ctx: JobContext) -> None:
Expand Down
66 changes: 55 additions & 11 deletions keras_remote/backend/gke_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def submit_k8s_job(
job_id,
bucket_name,
namespace="default",
spot=False,
):
"""Submit a Kubernetes Job to GKE cluster.

Expand All @@ -42,7 +43,7 @@ def submit_k8s_job(
_load_kube_config()

# Parse accelerator configuration
accel_config = _parse_accelerator(accelerator)
accel_config = _parse_accelerator(accelerator, spot=spot)

# Create job specification
job_name = f"keras-remote-{job_id}"
Expand Down Expand Up @@ -224,9 +225,9 @@ def validate_preflight(
logging.warning("Preflight check: Failed to query nodes: %s", e.reason)


def _parse_accelerator(accelerator):
def _parse_accelerator(accelerator, spot=False):
"""Convert accelerator string to GKE pod spec fields."""
parsed = accelerators.parse_accelerator(accelerator)
parsed = accelerators.parse_accelerator(accelerator, spot=spot)

if parsed is None:
return {
Expand All @@ -238,21 +239,36 @@ def _parse_accelerator(accelerator):
}

if isinstance(parsed, TpuConfig):
return {
# For TPU Podslices (multi-node), resource requests must be per-node.
# num_nodes is 1 for single-host TPUs (v3-8, v4-8, v5litepod-1/4/8).
chips_per_node = parsed.chips // parsed.num_nodes
config = {
"node_selector": {
"cloud.google.com/gke-tpu-accelerator": parsed.gke_accelerator,
"cloud.google.com/gke-tpu-topology": parsed.topology,
},
"resource_limits": {"google.com/tpu": str(parsed.chips)},
"resource_requests": {"google.com/tpu": str(parsed.chips)},
"resource_limits": {"google.com/tpu": str(chips_per_node)},
"resource_requests": {"google.com/tpu": str(chips_per_node)},
"tolerations": [
{"key": "google.com/tpu", "operator": "Exists", "effect": "NoSchedule"}
],
"jax_platform": "tpu",
}

if parsed.spot:
config["node_selector"]["cloud.google.com/gke-spot"] = "true"
config["tolerations"].append(
{
"key": "cloud.google.com/gke-spot",
"operator": "Equal",
"value": "true",
"effect": "NoSchedule",
}
)
return config

# GpuConfig
return {
config = {
"node_selector": {"cloud.google.com/gke-accelerator": parsed.gke_label},
"resource_limits": {"nvidia.com/gpu": str(parsed.count)},
"resource_requests": {"nvidia.com/gpu": str(parsed.count)},
Expand All @@ -261,6 +277,17 @@ def _parse_accelerator(accelerator):
],
"jax_platform": "gpu",
}
if parsed.spot:
config["node_selector"]["cloud.google.com/gke-spot"] = "true"
config["tolerations"].append(
{
"key": "cloud.google.com/gke-spot",
"operator": "Equal",
"value": "true",
"effect": "NoSchedule",
}
)
return config


def _load_kube_config():
Expand Down Expand Up @@ -330,8 +357,10 @@ def _create_job_spec(
],
env=env_vars,
resources=client.V1ResourceRequirements(
limits=accel_config["resource_limits"],
requests=accel_config["resource_requests"],
limits={k: str(v) for k, v in accel_config["resource_limits"].items()},
requests={
k: str(v) for k, v in accel_config["resource_requests"].items()
},
),
)

Expand Down Expand Up @@ -436,6 +465,10 @@ def _check_node_pool_exists_cached(selector_items) -> bool:
config_dict = pool.get("config", {})
pool_labels = config_dict.get("labels", {}).copy()

# Spot VM mapping
if config_dict.get("spot"):
pool_labels["cloud.google.com/gke-spot"] = "true"

# Map GKE injected node labels for accelerators mapping
accel_config_list = config_dict.get("accelerators", [])
if accel_config_list:
Expand All @@ -445,6 +478,13 @@ def _check_node_pool_exists_cached(selector_items) -> bool:
else:
pool_labels["cloud.google.com/gke-accelerator"] = accel_type

# TPU topology mapping from placement policy
placement_policy = pool.get("placementPolicy", {})
if placement_policy and placement_policy.get("tpuTopology"):
pool_labels["cloud.google.com/gke-tpu-topology"] = placement_policy[
"tpuTopology"
]

# TPU mapping fallback
machine_type = config_dict.get("machineType", "")

Expand All @@ -455,7 +495,9 @@ def _check_node_pool_exists_cached(selector_items) -> bool:
"goog-gke-accelerator-type"
]

if machine_type.startswith("ct"):
if machine_type.startswith("ct") and not pool_labels.get(
"cloud.google.com/gke-tpu-topology"
):
# We roughly map TPU topology presence for preflight
pool_labels["cloud.google.com/gke-tpu-topology"] = selector.get(
"cloud.google.com/gke-tpu-topology", ""
Expand All @@ -466,7 +508,9 @@ def _check_node_pool_exists_cached(selector_items) -> bool:
for tpu_spec in accelerators.TPUS.values():
for chips, topo_spec in tpu_spec.topologies.items():
if topo_spec.machine_type == machine_type:
pool_labels["cloud.google.com/gke-accelerator-count"] = str(chips)
pool_labels["cloud.google.com/gke-accelerator-count"] = str(
chips // topo_spec.num_nodes
)
break

if all(pool_labels.get(k) == str(v) for k, v in selector.items()):
Expand Down
64 changes: 64 additions & 0 deletions keras_remote/backend/gke_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ def test_tpu_v3_8(self):
self.assertLen(result["tolerations"], 1)
self.assertEqual(result["tolerations"][0]["key"], "google.com/tpu")

def test_tpu_v3_16_multi_node(self):
    """Multi-host slice: v3-16 (topology 4x4, 4 nodes) must request per-node chips.

    16 total chips spread over 4 nodes means the pod spec should ask for
    4 chips per node, not 16.
    """
    config = _parse_accelerator("v3-16")
    per_node_request = {"google.com/tpu": "4"}
    self.assertEqual(config["resource_limits"], per_node_request)
    self.assertEqual(config["resource_requests"], per_node_request)
    topology = config["node_selector"]["cloud.google.com/gke-tpu-topology"]
    self.assertEqual(topology, "4x4")

def test_tpu_v5litepod_4(self):
result = _parse_accelerator("v5litepod-4")
self.assertEqual(
Expand All @@ -69,6 +78,38 @@ def test_tpu_v5litepod_4(self):
)
self.assertEqual(result["resource_limits"], {"google.com/tpu": "4"})

def test_spot_gpu(self):
    """A ':spot' GPU accelerator gets the spot node selector and toleration."""
    config = _parse_accelerator("l4:spot")
    self.assertEqual(
        config["node_selector"]["cloud.google.com/gke-spot"], "true"
    )
    # Exactly one toleration should target the spot taint.
    spot_tolerations = [
        toleration
        for toleration in config["tolerations"]
        if toleration.get("key") == "cloud.google.com/gke-spot"
    ]
    self.assertLen(spot_tolerations, 1)
    self.assertEqual(spot_tolerations[0]["value"], "true")

def test_spot_tpu(self):
    """A ':spot' TPU accelerator adds spot scheduling without losing TPU bits."""
    config = _parse_accelerator("v6e-8:spot")
    self.assertEqual(
        config["node_selector"]["cloud.google.com/gke-spot"], "true"
    )
    # Exactly one toleration should target the spot taint.
    spot_tolerations = [
        toleration
        for toleration in config["tolerations"]
        if toleration.get("key") == "cloud.google.com/gke-spot"
    ]
    self.assertLen(spot_tolerations, 1)
    self.assertEqual(spot_tolerations[0]["value"], "true")
    # The regular TPU toleration must still be present alongside the spot one.
    tpu_toleration_present = any(
        toleration.get("key") == "google.com/tpu"
        for toleration in config["tolerations"]
    )
    self.assertTrue(tpu_toleration_present)


class TestCreateJobSpec(absltest.TestCase):
def _make_gpu_config(self):
Expand Down Expand Up @@ -421,6 +462,29 @@ def test_tpu_match(self):
)
self.assertTrue(result)

def test_tpu_multi_node_match(self):
    """Test that it correctly identifies a 4-chip-per-node pool for v6e-16."""
    # One node pool of ct6e-standard-4t machines (4 TPU chips per node).
    node_pools = [
        {
            "config": {
                "machineType": "ct6e-standard-4t",
                "accelerators": [{"acceleratorType": "tpu-v6e-slice"}],
                "labels": {},
            }
        }
    ]
    self.mock_run.return_value = json.dumps(node_pools)

    selector = (
        ("cloud.google.com/gke-tpu-accelerator", "tpu-v6e-slice"),
        ("cloud.google.com/gke-tpu-topology", "4x4"),
        ("cloud.google.com/gke-accelerator-count", "4"),
    )
    self.assertTrue(_check_node_pool_exists_cached(selector))

def test_no_match(self):
self.mock_run.return_value = json.dumps(
[
Expand Down
Loading
Loading