diff --git a/examples/pathways_example.py b/examples/pathways_example.py
new file mode 100644
index 0000000..0a34628
--- /dev/null
+++ b/examples/pathways_example.py
@@ -0,0 +1,49 @@
+import os
+
+os.environ["KERAS_BACKEND"] = "jax"
+
+import keras
+import numpy as np
+from keras import layers
+
+import keras_remote
+
+
+# A simple model that will be executed remotely on pathways
+@keras_remote.run(accelerator="v5litepod-1", backend="pathways")
+def train_simple_model():
+    print("Running Pathways job on JAX Backend!")
+
+    # Create a simple dataset
+    x = np.random.rand(1000, 10)
+    y = np.random.randint(0, 2, size=(1000, 1))
+
+    # A simple sequential model
+    model = keras.Sequential(
+        [
+            keras.Input(shape=(10,)),
+            layers.Dense(32, activation="relu"),
+            layers.Dense(16, activation="relu"),
+            layers.Dense(1, activation="sigmoid"),
+        ]
+    )
+
+    model.compile(
+        optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
+    )
+
+    print("Model Architecture:")
+    model.summary()
+
+    # Train the model
+    print("\nStarting Training...")
+    history = model.fit(x, y, epochs=5, batch_size=32, validation_split=0.2)
+
+    print("\nTraining completed successfully on Pathways!")
+    return history.history
+
+
+if __name__ == "__main__":
+    print("Submitting Pathways training job...")
+    result_history = train_simple_model()
+    print("Final validation accuracy:", result_history["val_accuracy"][-1])
diff --git a/keras_remote/backend/execution.py b/keras_remote/backend/execution.py
index ba01fbd..cd5e544 100644
--- a/keras_remote/backend/execution.py
+++ b/keras_remote/backend/execution.py
@@ -14,7 +14,7 @@
 import cloudpickle
 from absl import logging
 
-from keras_remote.backend import gke_client
+from keras_remote.backend import gke_client, pathways_client
 from keras_remote.constants import get_default_zone, zone_to_region
 from keras_remote.infra import container_builder
 from keras_remote.utils import packager, storage
@@ -105,13 +105,17 @@
     def cleanup_job(self, job: Any, ctx: JobContext) -> None: ...
 
 
-class GKEBackend:
-    """Backend adapter for GKE."""
+class BaseK8sBackend:
+    """Base class for Kubernetes-based backends."""
 
     def __init__(self, cluster: Optional[str] = None, namespace: str = "default"):
         self.cluster = cluster
         self.namespace = namespace
 
+
+class GKEBackend(BaseK8sBackend):
+    """Backend adapter for standard GKE Jobs."""
+
     def submit_job(self, ctx: JobContext) -> Any:
         """Submit job to GKE cluster."""
         return gke_client.submit_k8s_job(
@@ -134,6 +138,31 @@ def cleanup_job(self, job: Any, ctx: JobContext) -> None:
         gke_client.cleanup_job(job_name, namespace=self.namespace)
 
 
+class PathwaysBackend(BaseK8sBackend):
+    """Backend adapter for ML Pathways using LeaderWorkerSet."""
+
+    def submit_job(self, ctx: JobContext) -> Any:
+        """Submit LWS job to GKE cluster."""
+        return pathways_client.submit_pathways_job(
+            display_name=ctx.display_name,
+            container_uri=ctx.image_uri,
+            accelerator=ctx.accelerator,
+            project=ctx.project,
+            job_id=ctx.job_id,
+            bucket_name=ctx.bucket_name,
+            namespace=self.namespace,
+        )
+
+    def wait_for_job(self, job: Any, ctx: JobContext) -> None:
+        """Wait for Pathways LWS completion."""
+        pathways_client.wait_for_job(ctx.job_id, namespace=self.namespace)
+
+    def cleanup_job(self, job: Any, ctx: JobContext) -> None:
+        """Clean up LWS resources."""
+        job_name = pathways_client._get_job_name(ctx.job_id)
+        pathways_client.cleanup_job(job_name, namespace=self.namespace)
+
+
 def _find_requirements(start_dir: str) -> Optional[str]:
     """Search up directory tree for requirements.txt."""
     search_dir = start_dir
diff --git a/keras_remote/backend/gke_client.py b/keras_remote/backend/gke_client.py
index 0bc72ca..f8400ef 100644
--- a/keras_remote/backend/gke_client.py
+++ b/keras_remote/backend/gke_client.py
@@ -300,7 +300,7 @@ def _create_job_spec(
 
     pod_template = client.V1PodTemplateSpec(
         metadata=client.V1ObjectMeta(
-            labels={"app": "keras-remote", "job-id": job_id}
+            labels={"app": "keras-remote", "job-id": job_id, "job-name": job_name}
         ),
         spec=client.V1PodSpec(**pod_spec_kwargs),
     )
diff --git a/keras_remote/backend/pathways_client.py b/keras_remote/backend/pathways_client.py
new file mode 100644
index 0000000..f0dfdb8
--- /dev/null
+++ b/keras_remote/backend/pathways_client.py
@@ -0,0 +1,317 @@
+"""Pathways (LeaderWorkerSet) job submission for keras_remote."""
+
+import time
+
+from kubernetes import client
+from kubernetes.client.rest import ApiException
+
+from keras_remote.backend.gke_client import (
+    _check_pod_scheduling,
+    _load_kube_config,
+    _parse_accelerator,
+    _print_pod_logs,
+)
+from keras_remote.core import accelerators
+from keras_remote.infra import infra
+
+logger = infra.logger
+
+LWS_GROUP = "leaderworkerset.x-k8s.io"
+LWS_VERSION = "v1"
+LWS_PLURAL = "leaderworkersets"
+
+
+def _get_job_name(job_id: str) -> str:
+    """Get the standardized Pathways job name for a given job ID."""
+    return f"keras-pathways-{job_id}"
+
+
+def _get_lws_version(group=LWS_GROUP):
+    """Get the preferred version for the LeaderWorkerSet API."""
+    _load_kube_config()
+    api = client.ApisApi()
+    try:
+        api_groups = api.get_api_versions().groups
+        for api_group in api_groups:
+            if api_group.name == group:
+                return api_group.preferred_version.version
+
+        # If we didn't find the group, raise ApiException to fallback
+        raise ApiException(status=404, reason=f"API group {group} not found")
+    except ApiException:
+        logger.warning(
+            "Failed to retrieve LWS API version from cluster. Defaulting to '%s'",
+            LWS_VERSION,
+        )
+        return LWS_VERSION
+
+
+def submit_pathways_job(
+    display_name,
+    container_uri,
+    accelerator,
+    project,
+    job_id,
+    bucket_name,
+    namespace="default",
+):
+    """Submit a LeaderWorkerSet to GKE cluster.
+
+    Args:
+        display_name: Job display name (used for K8s LWS name)
+        container_uri: Docker container image URI
+        accelerator: TPU type (must be TpuConfig)
+        project: GCP project ID
+        job_id: Unique job identifier
+        bucket_name: GCS bucket name for artifacts
+        namespace: Kubernetes namespace (default: "default")
+
+    Returns:
+        dict: The created LeaderWorkerSet object
+    """
+    _load_kube_config()
+    lws_version = _get_lws_version()
+
+    accel_config = _parse_accelerator(accelerator)
+    job_name = _get_job_name(job_id)
+
+    # Extract num nodes from the TPU configuration
+
+    parsed_config = accelerators.parse_accelerator(accelerator)
+    if (
+        isinstance(parsed_config, accelerators.TpuConfig)
+        and parsed_config.num_nodes > 1
+    ):
+        num_workers = parsed_config.num_nodes - 1
+    else:
+        num_workers = 0
+
+    lws_manifest = _create_lws_spec(
+        job_name=job_name,
+        container_uri=container_uri,
+        accel_config=accel_config,
+        job_id=job_id,
+        bucket_name=bucket_name,
+        num_workers=num_workers,
+        namespace=namespace,
+        version=lws_version,
+    )
+
+    custom_api = client.CustomObjectsApi()
+
+    try:
+        created_lws = custom_api.create_namespaced_custom_object(
+            group=LWS_GROUP,
+            version=lws_version,
+            namespace=namespace,
+            plural=LWS_PLURAL,
+            body=lws_manifest,
+        )
+        logger.info(f"Submitted Pathways job (LWS): {job_name}")
+        logger.info(
+            "View job with: kubectl get %s %s -n %s", LWS_PLURAL, job_name, namespace
+        )
+        return created_lws
+    except ApiException as e:
+        if e.status == 404:
+            raise RuntimeError(
+                "LeaderWorkerSet CRD not found. Please ensure it is "
+                "installed on the cluster. You can install it by running "
+                "the `keras-remote up` command, or by following the "
+                "official LWS installation guide."
+            ) from e
+        else:
+            raise RuntimeError(
+                f"Kubernetes API error: {e.status} - {e.reason}: {e.body}"
+            ) from e
+
+
+def wait_for_job(job_id, namespace="default", timeout=3600, poll_interval=10):
+    """Wait for Pathways Job (LeaderWorkerSet) to complete."""
+    _load_kube_config()
+    core_v1 = client.CoreV1Api()
+
+    job_name = _get_job_name(job_id)
+    start_time = time.time()
+    logged_running = False
+
+    # The leader pod is suffixed with '-0' by LWS
+    leader_pod_name = f"{job_name}-0"
+
+    while True:
+        elapsed = time.time() - start_time
+        if elapsed > timeout:
+            raise RuntimeError(f"Pathways job {job_name} timed out after {timeout}s")
+
+        try:
+            pod = core_v1.read_namespaced_pod(leader_pod_name, namespace)
+            if not logged_running:
+                logger.info(f"Found pod: {leader_pod_name}")
+                logged_running = True
+
+            if pod.status.phase == "Succeeded":
+                logger.info(f"[REMOTE] Job {job_name} completed successfully")
+                return "success"
+
+            if pod.status.phase == "Failed":
+                _print_pod_logs(core_v1, job_name, namespace)
+                raise RuntimeError(f"Pathways job {job_name} failed")
+
+            elif pod.status.phase == "Pending":
+                _check_pod_scheduling(core_v1, job_name, namespace)
+                logger.debug("Pod is Pending...")
+
+        except ApiException as e:
+            if e.status == 404:
+                # Pod might not be created yet
+                pod = None
+            else:
+                raise RuntimeError(
+                    f"Failed to read leader pod status: {e.reason}"
+                ) from e
+
+        if pod is not None and pod.status.container_statuses:
+            container_status = pod.status.container_statuses[0]
+
+            # Check current state
+            if container_status.state.terminated:
+                if container_status.state.terminated.exit_code == 0:
+                    logger.info(f"[REMOTE] Job {job_name} completed successfully")
+                    return "success"
+                else:
+                    _print_pod_logs(core_v1, job_name, namespace)
+                    raise RuntimeError(
+                        f"Pathways job {job_name} failed with exit code "
+                        f"{container_status.state.terminated.exit_code}"
+                    )
+
+            # Check last state (in case it restarted)
+            if container_status.last_state.terminated:
+                if container_status.last_state.terminated.exit_code == 0:
+                    logger.info(
+                        f"[REMOTE] Job {job_name} completed successfully (restarted)"
+                    )
+                    return "success"
+                else:
+                    _print_pod_logs(core_v1, job_name, namespace)
+                    raise RuntimeError(
+                        f"Pathways job {job_name} failed previously with "
+                        f"exit code {container_status.last_state.terminated.exit_code}"
+                    )
+
+        time.sleep(poll_interval)
+
+
+def cleanup_job(job_name, namespace="default"):
+    """Delete LeaderWorkerSet."""
+    _load_kube_config()
+    lws_version = _get_lws_version()
+    custom_api = client.CustomObjectsApi()
+
+    try:
+        custom_api.delete_namespaced_custom_object(
+            group=LWS_GROUP,
+            version=lws_version,
+            namespace=namespace,
+            plural=LWS_PLURAL,
+            name=job_name,
+        )
+        logger.info(f"Deleted LeaderWorkerSet: {job_name}")
+    except ApiException as e:
+        if e.status == 404:
+            # Job already deleted
+            pass
+        else:
+            logger.warning(
+                "Failed to delete LeaderWorkerSet %s: %s",
+                job_name,
+                e.reason,
+            )
+
+
+def _create_lws_spec(
+    job_name,
+    container_uri,
+    accel_config,
+    job_id,
+    bucket_name,
+    num_workers,
+    namespace,
+    version=LWS_VERSION,
+):
+    """Create a LeaderWorkerSet manifest."""
+
+    env_vars = [
+        {"name": "KERAS_BACKEND", "value": "jax"},
+        {
+            "name": "JAX_PLATFORMS",
+            "value": accel_config.get("jax_platform", "cpu"),
+        },
+        {"name": "JOB_ID", "value": job_id},
+        {"name": "GCS_BUCKET", "value": bucket_name},
+        {
+            "name": "MEGASCALE_COORDINATOR_ADDRESS",
+            "value": "$(LWS_LEADER_ADDRESS)",
+        },
+        {"name": "MEGASCALE_NUM_SLICES", "value": str(num_workers + 1)},
+        {"name": "TPU_WORKER_ID", "value": "$(LWS_WORKER_INDEX)"},
+    ]
+
+    tolerations = [
+        {"key": t["key"], "operator": t["operator"], "effect": t["effect"]}
+        for t in accel_config["tolerations"]
+    ]
+
+    pod_template = {
+        "metadata": {
+            "labels": {
+                "app": "keras-remote-pathways",
+                "job-id": job_id,
+                "job-name": job_name,
+            }
+        },
+        "spec": {
+            "containers": [
+                {
+                    "name": "keras-remote-worker",
+                    "image": container_uri,
+                    "command": ["python3", "-u", "/app/remote_runner.py"],
+                    "args": [
+                        f"gs://{bucket_name}/{job_id}/context.zip",
+                        f"gs://{bucket_name}/{job_id}/payload.pkl",
+                        f"gs://{bucket_name}/{job_id}/result.pkl",
+                    ],
+                    "env": env_vars,
+                    "resources": {
+                        "limits": accel_config["resource_limits"],
+                        "requests": accel_config["resource_requests"],
+                    },
+                }
+            ],
+        },
+    }
+
+    if tolerations:
+        pod_template["spec"]["tolerations"] = tolerations
+
+    if accel_config.get("node_selector"):
+        pod_template["spec"]["nodeSelector"] = accel_config["node_selector"]
+
+    return {
+        "apiVersion": f"{LWS_GROUP}/{version}",
+        "kind": "LeaderWorkerSet",
+        "metadata": {
+            "name": job_name,
+            "namespace": namespace,
+            "labels": {"app": "keras-remote-pathways"},
+        },
+        "spec": {
+            "replicas": 1,
+            "leaderWorkerTemplate": {
+                "size": num_workers + 1,  # 1 leader + N workers
+                "restartPolicy": "RecreateGroupOnPodRestart",
+                "leaderTemplate": pod_template,
+                "workerTemplate": pod_template,
+            },
+        },
+    }
diff --git a/keras_remote/cli/commands/up.py b/keras_remote/cli/commands/up.py
index 3d130da..58159d2 100644
--- a/keras_remote/cli/commands/up.py
+++ b/keras_remote/cli/commands/up.py
@@ -8,6 +8,7 @@
     configure_docker_auth,
     configure_kubectl,
     install_gpu_drivers,
+    install_lws,
 )
 from keras_remote.cli.infra.program import create_program
 from keras_remote.cli.infra.stack_manager import get_stack
@@ -102,6 +103,10 @@
     configure_kubectl(cluster_name, zone, project)
     success("kubectl configured")
 
+    console.print("Installing LeaderWorkerSet CRD for Pathways support...")
+    install_lws()
+    success("LWS CRD installed")
+
     if isinstance(accel_config, GpuConfig):
         console.print("Installing NVIDIA GPU device drivers...")
         install_gpu_drivers()
diff --git a/keras_remote/cli/constants.py b/keras_remote/cli/constants.py
index 779606c..a547eb5 100644
--- a/keras_remote/cli/constants.py
+++ b/keras_remote/cli/constants.py
@@ -23,3 +23,5 @@
     "container-engine-accelerators/v1.0.20/"
     "nvidia-driver-installer/cos/daemonset-preloaded.yaml"
 )
+
+LWS_INSTALL_URL = "https://github.com/kubernetes-sigs/lws/releases/download/v0.5.1/manifests.yaml"
diff --git a/keras_remote/cli/infra/post_deploy.py b/keras_remote/cli/infra/post_deploy.py
index 9f0930b..0027f3d 100644
--- a/keras_remote/cli/infra/post_deploy.py
+++ b/keras_remote/cli/infra/post_deploy.py
@@ -7,7 +7,10 @@
 import os
 import subprocess
 
-from keras_remote.cli.constants import NVIDIA_DRIVER_DAEMONSET_URL
+from keras_remote.cli.constants import (
+    LWS_INSTALL_URL,
+    NVIDIA_DRIVER_DAEMONSET_URL,
+)
 
 
 def configure_docker_auth(ar_location):
@@ -62,3 +65,14 @@ subprocess.run(
         ["kubectl", "apply", "-f", NVIDIA_DRIVER_DAEMONSET_URL],
         check=True,
     )
+
+
+def install_lws():
+    """Install the LeaderWorkerSet custom resource controller.
+
+    This enables Pathways scheduling on the GKE cluster.
+    """
+    subprocess.run(
+        ["kubectl", "apply", "--server-side", "-f", LWS_INSTALL_URL],
+        check=True,
+    )
diff --git a/keras_remote/core/core.py b/keras_remote/core/core.py
index b6e00a9..34bff2b 100644
--- a/keras_remote/core/core.py
+++ b/keras_remote/core/core.py
@@ -4,8 +4,10 @@
 from keras_remote.backend.execution import (
     GKEBackend,
     JobContext,
+    PathwaysBackend,
     execute_remote,
 )
+from keras_remote.core import accelerators
 
 
 def run(
@@ -15,6 +17,7 @@
     project=None,
     capture_env_vars=None,
     cluster=None,
+    backend=None,
     namespace="default",
 ):
     """Execute function on remote TPU/GPU.
@@ -26,7 +29,8 @@
         project: GCP project (default: from KERAS_REMOTE_PROJECT)
         capture_env_vars: List of environment variable names or patterns (ending
            in *) to propagate to the remote environment. Defaults to None.
-        cluster: GKE cluster name (default: from KERAS_REMOTE_GKE_CLUSTER)
+        cluster: GKE cluster name (default: from KERAS_REMOTE_CLUSTER)
+        backend: Backend to use ('gke' or 'pathways')
         namespace: Kubernetes namespace (default: 'default')
     """
 
@@ -45,18 +49,53 @@ def wrapper(*args, **kwargs):
             elif pattern in os.environ:
                 env_vars[pattern] = os.environ[pattern]
 
-        return _execute_on_gke(
-            func,
-            args,
-            kwargs,
-            accelerator,
-            container_image,
-            zone,
-            project,
-            cluster,
-            namespace,
-            env_vars,
-        )
+        # Resolve backend
+        resolved_backend = backend
+        if resolved_backend is None:
+            try:
+                accel_config = accelerators.parse_accelerator(accelerator)
+                # Use Pathways for multi-host TPUs (if supported) or simplified logic
+                # For now, let's default to GKE unless explicit or strictly needed
+                if (
+                    isinstance(accel_config, accelerators.TpuConfig)
+                    and accel_config.num_nodes > 1
+                ):
+                    resolved_backend = "pathways"
+                else:
+                    resolved_backend = "gke"
+            except ValueError:
+                resolved_backend = "gke"
+
+        if resolved_backend == "gke":
+            return _execute_on_gke(
+                func,
+                args,
+                kwargs,
+                accelerator,
+                container_image,
+                zone,
+                project,
+                cluster,
+                namespace,
+                env_vars,
+            )
+        elif resolved_backend == "pathways":
+            return _execute_on_pathways(
+                func,
+                args,
+                kwargs,
+                accelerator,
+                container_image,
+                zone,
+                project,
+                cluster,
+                namespace,
+                env_vars,
+            )
+        else:
+            raise ValueError(
+                f"Unknown backend: {resolved_backend}. Use 'gke', 'pathways', or None for auto-detection"
+            )
 
     return wrapper
 
@@ -75,10 +114,10 @@
     namespace,
     env_vars,
 ):
-    """Execute function on GKE cluster with GPU nodes."""
+    """Execute function on GKE cluster with GPU/TPU nodes."""
     # Get GKE-specific defaults
     if not cluster:
-        cluster = os.environ.get("KERAS_REMOTE_GKE_CLUSTER")
+        cluster = os.environ.get("KERAS_REMOTE_CLUSTER")
     if not namespace:
         namespace = os.environ.get("KERAS_REMOTE_GKE_NAMESPACE", "default")
 
@@ -86,3 +125,29 @@ ctx = JobContext.from_params(
         func, args, kwargs, accelerator, container_image, zone, project, env_vars
     )
     return execute_remote(ctx, GKEBackend(cluster=cluster, namespace=namespace))
+
+
+def _execute_on_pathways(
+    func,
+    args,
+    kwargs,
+    accelerator,
+    container_image,
+    zone,
+    project,
+    cluster,
+    namespace,
+    env_vars,
+):
+    """Execute function on GKE cluster via ML Pathways."""
+    if not cluster:
+        cluster = os.environ.get("KERAS_REMOTE_CLUSTER")
+    if not namespace:
+        namespace = os.environ.get("KERAS_REMOTE_GKE_NAMESPACE", "default")
+
+    ctx = JobContext.from_params(
+        func, args, kwargs, accelerator, container_image, zone, project, env_vars
+    )
+    return execute_remote(
+        ctx, PathwaysBackend(cluster=cluster, namespace=namespace)
+    )