Skip to content

Commit 57a3604

Browse files
Adds auto-login support to @run()
1 parent ebb54cb commit 57a3604

File tree

8 files changed

+616
-149
lines changed

8 files changed

+616
-149
lines changed

AGENTS.md

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ keras_remote/
1616
├── cli/ # CLI for infrastructure provisioning (Pulumi-based)
1717
│ ├── commands/ # up, down, status, config
1818
│ └── infra/ # Pulumi programs and stack management
19+
├── credentials.py # Credential verification & auto-setup (shared by core & CLI)
1920
└── constants.py # Zone/region utilities
2021
```
2122

@@ -24,6 +25,7 @@ keras_remote/
2425
```python
2526
@keras_remote.run() called
2627
→ JobContext.from_params() # Resolve config from args/env vars
28+
→ ensure_credentials() # Verify/auto-configure gcloud, ADC, kubeconfig, Docker auth
2729
→ _prepare_artifacts() # Serialize function (cloudpickle), zip working dir
2830
→ _build_container() # Build or retrieve cached Docker image
2931
→ _upload_artifacts() # Upload payload.pkl, context.zip to GCS
@@ -35,23 +37,24 @@ keras_remote/
3537

3638
## Key Modules
3739

38-
| Module | Responsibility |
39-
| ---------------------------- | ----------------------------------------------------------------------------- |
40-
| `core/core.py` | `@run()` decorator, backend routing, env var capture |
41-
| `core/accelerators.py` | Accelerator registry (`GPUS`, `TPUS`), parser (`parse_accelerator`) |
42-
| `backend/execution.py` | `JobContext` dataclass, `BackendClient` protocol, `execute_remote()` pipeline |
43-
| `backend/gke_client.py` | K8s Job creation, status polling, pod log retrieval |
44-
| `backend/pathways_client.py` | LeaderWorkerSet creation for multi-host TPUs |
45-
| `infra/container_builder.py` | Content-hashed Docker image building via Cloud Build |
46-
| `utils/packager.py` | `save_payload()` (cloudpickle), `zip_working_dir()` |
47-
| `utils/storage.py` | GCS upload/download/cleanup for job artifacts |
48-
| `runner/remote_runner.py` | Runs inside container: deserialize, execute, upload result |
49-
| `cli/main.py` | CLI entry point (`keras-remote` command) |
40+
| Module | Responsibility |
41+
| ---------------------------- | -------------------------------------------------------------------------------- |
42+
| `core/core.py` | `@run()` decorator, backend routing, env var capture |
43+
| `core/accelerators.py` | Accelerator registry (`GPUS`, `TPUS`), parser (`parse_accelerator`) |
44+
| `credentials.py` | Credential verification & auto-setup (gcloud, ADC, kubeconfig, Docker auth) |
45+
| `backend/execution.py` | `JobContext` dataclass, `BaseK8sBackend` base class, `execute_remote()` pipeline |
46+
| `backend/gke_client.py` | K8s Job creation, status polling, pod log retrieval |
47+
| `backend/pathways_client.py` | LeaderWorkerSet creation for multi-host TPUs |
48+
| `infra/container_builder.py` | Content-hashed Docker image building via Cloud Build |
49+
| `utils/packager.py` | `save_payload()` (cloudpickle), `zip_working_dir()` |
50+
| `utils/storage.py` | GCS upload/download/cleanup for job artifacts |
51+
| `runner/remote_runner.py` | Runs inside container: deserialize, execute, upload result |
52+
| `cli/main.py` | CLI entry point (`keras-remote` command) |
5053

5154
## Key Abstractions
5255

5356
- **`JobContext`** (`backend/execution.py`): Mutable dataclass carrying all job state through the pipeline — inputs, generated IDs, artifact paths, image URI.
54-
- **`BackendClient`** protocol (`backend/execution.py`): Interface with `submit_job`, `wait_for_job`, `cleanup_job`. Implemented by `GKEBackend` and `PathwaysBackend`.
57+
- **`BaseK8sBackend`** (`backend/execution.py`): Base class with `submit_job`, `wait_for_job`, `cleanup_job`. Subclassed by `GKEBackend` and `PathwaysBackend`.
5558
- **`GpuConfig` / `TpuConfig`** (`core/accelerators.py`): Frozen dataclasses for accelerator metadata. Single source of truth used by runtime, container builder, and CLI.
5659
- **`InfraConfig`** (`cli/config.py`): CLI provisioning configuration (project, zone, cluster, accelerator).
5760

@@ -62,7 +65,6 @@ keras_remote/
6265
- **Formatter/linter**: `ruff` (2-space indent, 80-char line length target, E501 ignored)
6366
- **Rules**: B, E, F, N, PYI, T20, TID, SIM, W, I, NPY
6467
- **Dataclasses**: Frozen for immutable configs, mutable for state objects
65-
- **Protocols**: Used for backend abstraction (not ABCs)
6668

6769
### Environment Variables
6870

examples/example_gke.py

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
"""
3636

3737
import os
38-
import subprocess
38+
39+
os.environ["KERAS_BACKEND"] = "jax"
3940

4041
import keras
4142
import numpy as np
@@ -135,31 +136,6 @@ def main():
135136
if __name__ == "__main__":
136137
# Prerequisites:
137138
# 1. Set KERAS_REMOTE_PROJECT environment variable to your GCP project ID
138-
# 2. Configure kubectl: gcloud container clusters get-credentials <cluster> --zone <zone>
139-
# 3. Ensure your GKE cluster has GPU nodes with the required accelerator type
140-
if not os.environ.get("KERAS_REMOTE_PROJECT"):
141-
print("ERROR: KERAS_REMOTE_PROJECT environment variable not set")
142-
print("Please set it to your GCP project ID:")
143-
print(" export KERAS_REMOTE_PROJECT=your-project-id")
144-
exit(1)
145-
146-
# Verify kubectl is configured
147-
try:
148-
result = subprocess.run(
149-
["kubectl", "cluster-info"], capture_output=True, text=True, timeout=10
150-
)
151-
if result.returncode != 0:
152-
print("ERROR: kubectl is not configured or cluster is not accessible")
153-
print("Please configure kubectl:")
154-
print(
155-
" gcloud container clusters get-credentials <cluster-name> --zone <zone>"
156-
)
157-
exit(1)
158-
except FileNotFoundError:
159-
print("ERROR: kubectl not found. Please install kubectl.")
160-
exit(1)
161-
except subprocess.TimeoutExpired:
162-
print("ERROR: kubectl timed out. Check your cluster connectivity.")
163-
exit(1)
164-
139+
# (if `project` param is not provided in the decorator)
140+
# 2. Ensure your GKE cluster has GPU nodes with the required accelerator type
165141
main()

keras_remote/backend/execution.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@
99
import tempfile
1010
import uuid
1111
from dataclasses import dataclass, field
12-
from typing import Any, Callable, Optional, Protocol
12+
from typing import Any, Callable, Optional
1313

1414
import cloudpickle
1515
from absl import logging
1616
from google.api_core import exceptions as google_exceptions
1717

1818
from keras_remote.backend import gke_client, pathways_client
1919
from keras_remote.constants import get_default_zone, zone_to_region
20+
from keras_remote.credentials import ensure_credentials
2021
from keras_remote.infra import container_builder
2122
from keras_remote.utils import packager, storage
2223

@@ -90,28 +91,24 @@ def from_params(
9091
)
9192

9293

93-
class BackendClient(Protocol):
94-
"""Protocol defining the interface for backend clients."""
94+
class BaseK8sBackend:
95+
"""Base class for Kubernetes-based backends."""
96+
97+
def __init__(self, cluster: Optional[str] = None, namespace: str = "default"):
98+
self.cluster = cluster
99+
self.namespace = namespace
95100

96101
def submit_job(self, ctx: JobContext) -> Any:
97102
"""Submit a job to the backend. Returns backend-specific job handle."""
98-
...
103+
raise NotImplementedError
99104

100105
def wait_for_job(self, job: Any, ctx: JobContext) -> None:
101106
"""Wait for job completion. Raises RuntimeError if job fails."""
102-
...
107+
raise NotImplementedError
103108

104109
def cleanup_job(self, job: Any, ctx: JobContext) -> None:
105110
"""Optional cleanup after job completion."""
106-
...
107-
108-
109-
class BaseK8sBackend:
110-
"""Base class for Kubernetes-based backends."""
111-
112-
def __init__(self, cluster: Optional[str] = None, namespace: str = "default"):
113-
self.cluster = cluster
114-
self.namespace = namespace
111+
raise NotImplementedError
115112

116113

117114
class GKEBackend(BaseK8sBackend):
@@ -264,22 +261,28 @@ def _cleanup_and_return(ctx: JobContext, result_payload: dict) -> Any:
264261
raise result_payload["exception"]
265262

266263

267-
def execute_remote(ctx: JobContext, backend: BackendClient) -> Any:
264+
def execute_remote(ctx: JobContext, backend: BaseK8sBackend) -> Any:
268265
"""Execute a function remotely using the specified backend.
269266
270267
This is the unified executor that handles all common phases
271268
and delegates backend-specific operations to the backend client.
272269
273270
Args:
274271
ctx: Job context with function and configuration
275-
backend: Backend client implementing BackendClient protocol
272+
backend: Backend instance (GKEBackend or PathwaysBackend)
276273
277274
Returns:
278275
The result of the remote function execution
279276
280277
Raises:
281278
Exception: Re-raised from remote execution if it failed
282279
"""
280+
ensure_credentials(
281+
project=ctx.project,
282+
zone=ctx.zone,
283+
cluster=backend.cluster,
284+
)
285+
283286
with tempfile.TemporaryDirectory() as tmpdir:
284287
# Phase 1: Package artifacts
285288
_prepare_artifacts(ctx, tmpdir)

keras_remote/backend/execution_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def _make_ctx(self, container_image=None):
151151

152152
def test_success_flow(self):
153153
with (
154+
mock.patch("keras_remote.backend.execution.ensure_credentials"),
154155
mock.patch("keras_remote.backend.execution._build_container"),
155156
mock.patch("keras_remote.backend.execution._upload_artifacts"),
156157
mock.patch(
@@ -174,6 +175,7 @@ def test_success_flow(self):
174175

175176
def test_cleanup_on_wait_failure(self):
176177
with (
178+
mock.patch("keras_remote.backend.execution.ensure_credentials"),
177179
mock.patch("keras_remote.backend.execution._build_container"),
178180
mock.patch("keras_remote.backend.execution._upload_artifacts"),
179181
mock.patch(
@@ -191,6 +193,34 @@ def test_cleanup_on_wait_failure(self):
191193
# cleanup_job is called in finally block even when wait fails
192194
backend.cleanup_job.assert_called_once()
193195

196+
def test_ensure_credentials_called_with_correct_args(self):
197+
with (
198+
mock.patch(
199+
"keras_remote.backend.execution.ensure_credentials"
200+
) as mock_creds,
201+
mock.patch("keras_remote.backend.execution._build_container"),
202+
mock.patch("keras_remote.backend.execution._upload_artifacts"),
203+
mock.patch(
204+
"keras_remote.backend.execution._download_result",
205+
return_value={"success": True, "result": 0},
206+
),
207+
mock.patch(
208+
"keras_remote.backend.execution._cleanup_and_return",
209+
return_value=0,
210+
),
211+
):
212+
ctx = self._make_ctx()
213+
backend = MagicMock()
214+
backend.cluster = "test-cluster"
215+
216+
execute_remote(ctx, backend)
217+
218+
mock_creds.assert_called_once_with(
219+
project="proj",
220+
zone="us-central1-a",
221+
cluster="test-cluster",
222+
)
223+
194224

195225
if __name__ == "__main__":
196226
absltest.main()

keras_remote/cli/prerequisites_check.py

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
1-
"""Prerequisite checks for the keras-remote CLI."""
1+
"""Prerequisite checks for the keras-remote CLI.
2+
3+
Delegates common credential checks (gcloud, auth plugin, ADC) to
4+
:mod:`keras_remote.credentials` and converts ``RuntimeError`` into
5+
``click.ClickException``. CLI-only tool checks (Pulumi, kubectl, Docker)
6+
remain here.
7+
"""
28

39
import shutil
4-
import subprocess
510

611
import click
712

8-
from keras_remote.cli.output import warning
13+
from keras_remote import credentials
914

1015

1116
def check_gcloud():
1217
"""Verify gcloud CLI is installed."""
13-
if not shutil.which("gcloud"):
14-
raise click.ClickException(
15-
"gcloud CLI not found. "
16-
"Install from: https://cloud.google.com/sdk/docs/install"
17-
)
18+
try:
19+
credentials.ensure_gcloud()
20+
except RuntimeError as e:
21+
raise click.ClickException(str(e)) # noqa: B904
1822

1923

2024
def check_pulumi():
@@ -42,37 +46,19 @@ def check_docker():
4246

4347

4448
def check_gke_auth_plugin():
45-
"""Verify gke-gcloud-auth-plugin is installed."""
46-
if not shutil.which("gke-gcloud-auth-plugin"):
47-
warning("gke-gcloud-auth-plugin not found.")
48-
if click.confirm(
49-
"Install it via `gcloud components install gke-gcloud-auth-plugin`?",
50-
default=True,
51-
):
52-
subprocess.run(
53-
["gcloud", "components", "install", "gke-gcloud-auth-plugin"],
54-
check=True,
55-
)
56-
else:
57-
raise click.ClickException(
58-
"gke-gcloud-auth-plugin is required. "
59-
"Install with: gcloud components install gke-gcloud-auth-plugin"
60-
)
49+
"""Verify gke-gcloud-auth-plugin is installed; auto-install if missing."""
50+
try:
51+
credentials.ensure_gke_auth_plugin()
52+
except RuntimeError as e:
53+
raise click.ClickException(str(e)) # noqa: B904
6154

6255

6356
def check_gcloud_auth():
6457
"""Check if gcloud Application Default Credentials are configured."""
65-
result = subprocess.run(
66-
["gcloud", "auth", "application-default", "print-access-token"],
67-
capture_output=True,
68-
)
69-
if result.returncode != 0:
70-
warning("Application Default Credentials not found.")
71-
click.echo("Running: gcloud auth application-default login")
72-
subprocess.run(
73-
["gcloud", "auth", "application-default", "login"],
74-
check=True,
75-
)
58+
try:
59+
credentials.ensure_adc()
60+
except RuntimeError as e:
61+
raise click.ClickException(str(e)) # noqa: B904
7662

7763

7864
def check_all():

0 commit comments

Comments (0)