
Commit d7fbf7e

Merge branch 'keras-team:main' into simplify-hw-names
2 parents: 4744905 + b57bf26

26 files changed: +1351 −681 lines

.gemini/styleguide.md

Lines changed: 15 additions & 0 deletions
```diff
@@ -1,3 +1,18 @@
+When performing code reviews on pull requests, you must strictly adhere to the following principles in addition to the API design guidelines above:
+
+1. **Question the Necessity of Changes**: Do not assume that the pull request changes are strictly necessary. Critically review the proposed changes to ensure they add real value. Point out any code that solves a non-existent problem or adds unnecessary complexity.
+
+2. **Call out "AI Slop"**: Actively look for and identify "AI slop"—generic, overly verbose, or hallucinated code that lacks context or violates best practices. If you suspect the code is AI slop, explicitly call it out.
+
+3. **Poke Holes in the Implementation**: Your goal is to critically test the logic. Actively search for and point out failing edge cases, race conditions, or unhandled exceptions in the implementation.
+
+4. **Demand Robustness**: Do not accept fragile code. If the proposed code is not robust enough or lacks proper error handling, explicitly tell the author why the current approach is brittle and what must be done to reinforce it.
+
+5. **Respect Existing Repo Patterns**: Before suggesting review comments (like asking users to add boilerplate or specific patterns), actively check for existing design patterns across the repository. Do not suggest adding useless code or structures that contradict or fall outside the established Keras repo coding style.
+
+
+
+
 # Keras Remote API design guidelines
 
 These guidelines are meant to help focus design discussions and help us create delightful developer experiences for remote execution.
```

README.md

Lines changed: 96 additions & 8 deletions
```diff
@@ -72,9 +72,10 @@ This adds the `keras-remote up`, `keras-remote down`, `keras-remote status`, and
 - Python 3.11+
 - Google Cloud SDK (`gcloud`)
 - Run `gcloud auth login` and `gcloud auth application-default login`
-- [Pulumi CLI](https://www.pulumi.com/docs/install/) (required for `[cli]` install only)
 - A Google Cloud project with billing enabled
 
+Note: The Pulumi CLI is bundled and managed automatically. It will be installed to `~/.keras-remote/pulumi` on first use if not already present.
+
 ## Quick Start
 
 ### 1. Configure Google Cloud
```
```diff
@@ -203,15 +204,102 @@ def train():
 
 See [examples/Dockerfile.prebuilt](examples/Dockerfile.prebuilt) for a template.
 
+## Handling Data
+
+Keras Remote provides a declarative and performant Data API to seamlessly make your local and cloud data available to your remote functions.
+
+The Data API is designed to be read-only. It reliably delivers data to your pods at the start of a job. For saving model outputs or checkpointing, you should write directly to GCS from within your function.
+
+Under the hood, the Data API optimizes your workflows with two key features:
+
+- **Smart Caching:** Local data is content-hashed and uploaded to a cache bucket only once. Subsequent job runs that use byte-identical data will hit the cache and skip the upload entirely, drastically speeding up execution.
+- **Automatic Zip Exclusion:** When you reference a data path inside your current working directory, Keras Remote automatically excludes that directory from the project's zipped payload to avoid uploading the same data twice.
+
```
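The content-hash caching described above can be sketched as follows. This is an illustrative helper, not the actual keras_remote implementation — `content_cache_key` and its exact hashing scheme are assumptions; the real code may hash differently:

```python
import hashlib
from pathlib import Path


def content_cache_key(path):
    """Hash a file or a directory tree so that byte-identical data
    always produces the same cache key (illustrative sketch only)."""
    h = hashlib.sha256()
    root = Path(path)
    if root.is_file():
        files = [root]
    else:
        # Sort for determinism across filesystems.
        files = sorted(p for p in root.rglob("*") if p.is_file())
    for f in files:
        # Include the relative path so renames change the key.
        rel = f.name if root.is_file() else str(f.relative_to(root))
        h.update(rel.encode())
        h.update(f.read_bytes())
    return h.hexdigest()


# An upload step would then look up this key in the cache bucket
# and skip the upload entirely on a hit.
```

Because the key depends only on file paths and bytes, a re-run with unchanged data recomputes the same key and hits the cache.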
```diff
+There are three main ways to handle data depending on your workflow:
+
+### 1. Dynamic Data (The `Data` Class)
+
+The simplest and most Pythonic approach is to pass `Data` objects as regular function arguments. The `Data` class wraps a local file/directory path or a Google Cloud Storage (GCS) URI.
+
+On the remote pod, these objects are automatically resolved into plain string paths pointing to the downloaded files, meaning your function code never needs to know about GCS or cloud storage APIs.
+
+```python
+import pandas as pd
+import keras_remote
+from keras_remote import Data
+
+@keras_remote.run(accelerator="v6e-8")
+def train(data_dir):
+    # data_dir is resolved to a dynamic local path on the remote machine
+    df = pd.read_csv(f"{data_dir}/train.csv")
+    # ...
+
+# Uploads the local directory to the remote pod automatically
+train(Data("./my_dataset/"))
+
+# Cache hit: subsequent runs with the same data skip the upload!
+train(Data("./my_dataset/"))
+```
+
```
```diff
+**Note on GCS Directories:** When referencing a GCS directory with the `Data` class, you must include a trailing slash (e.g., `Data("gs://my-bucket/dataset/")`). If you omit the trailing slash, the system will treat it as a single file object.
+
+You can also pass multiple `Data` arguments, or nest them inside lists and dictionaries (e.g., `train(datasets=[Data("./d1"), Data("./d2")])`).
+
```
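Resolving `Data` objects nested inside lists and dictionaries amounts to a recursive structure walk. The sketch below illustrates the idea under stated assumptions — the `Data` stand-in, `resolve`, and the `download` callback are all hypothetical names, not keras_remote internals:

```python
from dataclasses import dataclass


@dataclass
class Data:
    """Stand-in for keras_remote.Data: wraps a path or GCS URI."""
    uri: str


def resolve(obj, download):
    """Recursively replace Data objects nested in lists/tuples/dicts
    with the local path returned by `download` (illustrative sketch)."""
    if isinstance(obj, Data):
        return download(obj.uri)
    if isinstance(obj, list):
        return [resolve(x, download) for x in obj]
    if isinstance(obj, tuple):
        return tuple(resolve(x, download) for x in obj)
    if isinstance(obj, dict):
        return {k: resolve(v, download) for k, v in obj.items()}
    # Plain values (e.g. GCS URI strings) pass through untouched.
    return obj


# Example: pretend every Data object was downloaded to /tmp/<name>
args = {"datasets": [Data("./d1"), Data("./d2")], "epochs": 10}
resolved = resolve(args, lambda uri: f"/tmp/{uri.strip('./')}")
# resolved["datasets"] is now a list of plain string paths
```

On the remote side, the function body then sees only strings, which is why it never needs to import a cloud storage client.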
```diff
+### 2. Static Data (The `volumes` Parameter)
+
+For established training scripts where data requirements are static, you can use the `volumes` parameter in the `@keras_remote.run` decorator. This mounts data at fixed, hardcoded absolute filesystem paths, allowing you to drop `keras_remote` into existing codebases without altering the function signature.
+
+```python
+import pandas as pd
+import keras_remote
+from keras_remote import Data
+
+@keras_remote.run(
+    accelerator="v6e-8",
+    volumes={
+        "/data": Data("./my_dataset/"),
+        "/weights": Data("gs://my-bucket/pretrained-weights/")
+    }
+)
+def train():
+    # Data is guaranteed to be available at these absolute paths
+    df = pd.read_csv("/data/train.csv")
+    model.load_weights("/weights/model.h5")
+    # ...
+
+# No data arguments needed!
+train()
+```
+
```
```diff
+### 3. Direct GCS Streaming (For Large Datasets)
+
+If your dataset is very large (e.g., > 10GB), it is inefficient to download the entire dataset to the remote pod's local disk. Instead, skip the `Data` wrapper entirely and pass a GCS URI string directly. You can then use frameworks with native GCS streaming support (like `tf.data` or `grain`) to read the data on the fly.
+
+```python
+import grain.python as grain
+import keras_remote
+
+@keras_remote.run(accelerator="v6e-8")
+def train(data_uri):
+    # Native GCS reading, no download overhead
+    data_source = grain.ArrayRecordDataSource(data_uri)
+    # ...
+
+# Pass as a plain string, no Data() wrapper needed
+train("gs://my-bucket/arrayrecords/")
+```
+
```
```diff
 ## Configuration
 
 ### Environment Variables
 
-| Variable               | Required | Default         | Description                        |
-| ---------------------- | -------- | --------------- | ---------------------------------- |
-| `KERAS_REMOTE_PROJECT` | Yes      |                 | Google Cloud project ID            |
-| `KERAS_REMOTE_ZONE`    | No       | `us-central1-a` | Default compute zone               |
-| `KERAS_REMOTE_CLUSTER` | No       |                 | GKE cluster name                   |
+| Variable               | Required | Default         | Description             |
+| ---------------------- | -------- | --------------- | ----------------------- |
+| `KERAS_REMOTE_PROJECT` | Yes      |                 | Google Cloud project ID |
+| `KERAS_REMOTE_ZONE`    | No       | `us-central1-a` | Default compute zone    |
+| `KERAS_REMOTE_CLUSTER` | No       |                 | GKE cluster name        |
 
 ### Decorator Parameters
 
```
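The defaults in the table above can be sketched as plain environment-variable lookups. This is a hedged illustration: the real helpers live in `keras_remote.constants` (`get_default_zone`, `get_default_cluster_name`) and may differ; the `GOOGLE_CLOUD_PROJECT` fallback is inferred from the error message in `execution.py` below:

```python
import os


def get_default_zone():
    # Falls back to the documented default zone.
    return os.environ.get("KERAS_REMOTE_ZONE", "us-central1-a")


def get_default_project():
    # KERAS_REMOTE_PROJECT is required; GOOGLE_CLOUD_PROJECT appears
    # to be accepted as a fallback per the library's error message.
    project = os.environ.get("KERAS_REMOTE_PROJECT") or os.environ.get(
        "GOOGLE_CLOUD_PROJECT"
    )
    if not project:
        raise ValueError(
            "project must be specified or set KERAS_REMOTE_PROJECT"
            " (or GOOGLE_CLOUD_PROJECT) environment variable"
        )
    return project
```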
```diff
@@ -345,10 +433,10 @@ keras-remote down
 
 This removes:
 
-- GKE cluster and accelerator node pools (via Pulumi)
+- GKE cluster and accelerator node pools
 - Artifact Registry repository and container images
 - Cloud Storage buckets (jobs and builds)
-Use `--yes` to skip the confirmation prompt.
+Use `--yes` to skip the confirmation prompt.
 
 ## Contributing
```

keras_remote/backend/execution.py

Lines changed: 13 additions & 3 deletions
```diff
@@ -16,7 +16,11 @@
 from google.api_core import exceptions as google_exceptions
 
 from keras_remote.backend import gke_client, pathways_client
-from keras_remote.constants import get_default_zone, zone_to_region
+from keras_remote.constants import (
+    get_default_cluster_name,
+    get_default_zone,
+    zone_to_region,
+)
 from keras_remote.credentials import ensure_credentials
 from keras_remote.data import _make_data_ref
 from keras_remote.infra import container_builder
@@ -39,6 +43,7 @@ class JobContext:
     container_image: Optional[str]
     zone: str
     project: str
+    cluster_name: str
 
     # Generated identifiers
     job_id: str = field(default_factory=lambda: f"job-{uuid.uuid4().hex[:8]}")
@@ -58,7 +63,7 @@ class JobContext:
     image_uri: Optional[str] = None
 
     def __post_init__(self):
-        self.bucket_name = f"{self.project}-keras-remote-jobs"
+        self.bucket_name = f"{self.project}-kr-{self.cluster_name}-jobs"
         self.region = zone_to_region(self.zone)
         self.display_name = f"keras-remote-{self.func.__name__}-{self.job_id}"
 
@@ -73,9 +78,10 @@ def from_params(
         zone: Optional[str],
         project: Optional[str],
         env_vars: dict,
+        cluster_name: Optional[str] = None,
         volumes: Optional[dict] = None,
     ) -> "JobContext":
-        """Factory method with default resolution for zone/project."""
+        """Factory method with default resolution for zone/project/cluster."""
         if not zone:
             zone = get_default_zone()
         if not project:
@@ -85,6 +91,8 @@ def from_params(
                 "project must be specified or set KERAS_REMOTE_PROJECT"
                 " (or GOOGLE_CLOUD_PROJECT) environment variable"
             )
+        if not cluster_name:
+            cluster_name = get_default_cluster_name()
 
         return cls(
             func=func,
@@ -95,6 +103,7 @@ def from_params(
             container_image=container_image,
             zone=zone,
             project=project,
+            cluster_name=cluster_name,
             volumes=volumes,
         )
 
@@ -303,6 +312,7 @@ def _build_container(ctx: JobContext) -> None:
         accelerator_type=ctx.accelerator,
         project=ctx.project,
         zone=ctx.zone,
+        cluster_name=ctx.cluster_name,
     )
 
```
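The derived-field change above can be condensed into a standalone sketch. `JobContextSketch` is a made-up minimal stand-in (the real `JobContext` has many more fields); a plausible motivation for shortening `keras-remote` to `kr` in the bucket name, which the commit does not state, is keeping `project-kr-cluster-jobs` under GCS's 63-character bucket-name limit now that the cluster name is included:

```python
import uuid
from dataclasses import dataclass, field


@dataclass
class JobContextSketch:
    """Minimal stand-in for keras_remote's JobContext (illustrative)."""
    func_name: str
    project: str
    zone: str
    cluster_name: str
    job_id: str = field(default_factory=lambda: f"job-{uuid.uuid4().hex[:8]}")

    def __post_init__(self):
        # New scheme: the jobs bucket is scoped per cluster, so two
        # clusters in one project no longer share a bucket.
        self.bucket_name = f"{self.project}-kr-{self.cluster_name}-jobs"
        self.display_name = f"keras-remote-{self.func_name}-{self.job_id}"


ctx = JobContextSketch("my_train", "my-proj", "europe-west4-b", "my-cluster")
# ctx.bucket_name == "my-proj-kr-my-cluster-jobs"
```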
keras_remote/backend/execution_test.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -40,8 +40,9 @@ def test_post_init_derived_fields(self):
             container_image=None,
             zone="europe-west4-b",
             project="my-proj",
+            cluster_name="my-cluster",
         )
-        self.assertEqual(ctx.bucket_name, "my-proj-keras-remote-jobs")
+        self.assertEqual(ctx.bucket_name, "my-proj-kr-my-cluster-jobs")
         self.assertEqual(ctx.region, "europe-west4")
         self.assertTrue(ctx.display_name.startswith("keras-remote-my_train-"))
         self.assertRegex(ctx.job_id, r"^job-[0-9a-f]{8}$")
@@ -171,6 +172,7 @@ def _make_ctx(self, container_image=None):
             container_image=container_image,
             zone="us-central1-a",
             project="proj",
+            cluster_name="keras-remote-cluster",
         )
 
     def test_success_flow(self):
```

keras_remote/backend/gke_client.py

Lines changed: 19 additions & 3 deletions
```diff
@@ -437,22 +437,38 @@ def _check_node_pool_exists_cached(selector_items) -> bool:
        pool_labels = config_dict.get("labels", {}).copy()

        # Map GKE injected node labels for accelerators mapping
-        accelerators = config_dict.get("accelerators", [])
-        if accelerators:
-            accel_type = accelerators[0].get("acceleratorType", "")
+        accel_config_list = config_dict.get("accelerators", [])
+        if accel_config_list:
+            accel_type = accel_config_list[0].get("acceleratorType", "")
             if accel_type.startswith("tpu-"):
                 pool_labels["cloud.google.com/gke-tpu-accelerator"] = accel_type
             else:
                 pool_labels["cloud.google.com/gke-accelerator"] = accel_type

         # TPU mapping fallback
         machine_type = config_dict.get("machineType", "")
+
+        # Check resource labels for TPU type (common in v5e/v5litepod)
+        resource_labels = config_dict.get("resourceLabels", {})
+        if "goog-gke-accelerator-type" in resource_labels:
+            pool_labels["cloud.google.com/gke-tpu-accelerator"] = resource_labels[
+                "goog-gke-accelerator-type"
+            ]
+
         if machine_type.startswith("ct"):
             # We roughly map TPU topology presence for preflight
             pool_labels["cloud.google.com/gke-tpu-topology"] = selector.get(
                 "cloud.google.com/gke-tpu-topology", ""
             )

+        # Infer accelerator count from machine type using registry
+        # This is robust because it uses the same source of truth as the Pod spec generation
+        for tpu_spec in accelerators.TPUS.values():
+            for chips, topo_spec in tpu_spec.topologies.items():
+                if topo_spec.machine_type == machine_type:
+                    pool_labels["cloud.google.com/gke-accelerator-count"] = str(chips)
+                    break
+
         if all(pool_labels.get(k) == str(v) for k, v in selector.items()):
             return True
     return False
```
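The registry lookup added above is a reverse mapping from machine type to chip count. A self-contained sketch of the idea follows — the `TPUSpec`/`TopologySpec` names and the registry contents are assumptions for illustration, not the actual `keras_remote` `accelerators` module schema:

```python
from dataclasses import dataclass


@dataclass
class TopologySpec:
    machine_type: str


@dataclass
class TPUSpec:
    # chips-per-host -> topology spec
    topologies: dict


# Hypothetical registry entry; the real accelerators.TPUS may differ.
TPUS = {
    "v5e": TPUSpec({
        4: TopologySpec("ct5lp-hightpu-4t"),
        8: TopologySpec("ct5lp-hightpu-8t"),
    }),
}


def infer_chip_count(machine_type):
    """Walk the registry and return the chip count whose topology uses
    this machine type, or None if the machine type is unknown."""
    for tpu_spec in TPUS.values():
        for chips, topo in tpu_spec.topologies.items():
            if topo.machine_type == machine_type:
                return chips
    return None
```

Using one registry for both preflight label inference and Pod spec generation keeps the two from drifting apart, which is the robustness argument the inline comment makes.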

keras_remote/backend/log_streaming.py

Lines changed: 19 additions & 49 deletions
```diff
@@ -5,16 +5,14 @@
 job execution.
 """
 
-import sys
 import threading
-from collections import deque
 
 import urllib3
 from absl import logging
 from kubernetes.client.rest import ApiException
 from rich.console import Console
-from rich.live import Live
-from rich.panel import Panel
+
+from keras_remote.cli.output import LiveOutputPanel
 
 _MAX_DISPLAY_LINES = 25
 
@@ -27,14 +25,13 @@ def _stream_pod_logs(core_v1, pod_name, namespace):
 
     In interactive terminals, logs are displayed in a Rich Live panel.
     In non-interactive contexts (piped output, CI), logs are streamed
-    as raw lines with Rich Rule delimiters.
+    as plain lines with Rule delimiters.
 
     Args:
         core_v1: Kubernetes CoreV1Api client.
         pod_name: Name of the pod to stream logs from.
         namespace: Kubernetes namespace.
     """
-    console = Console()
     resp = None
     try:
         resp = core_v1.read_namespaced_pod_log(
@@ -43,10 +40,22 @@ def _stream_pod_logs(core_v1, pod_name, namespace):
             follow=True,
             _preload_content=False,
         )
-        if console.is_terminal:
-            _render_live_panel(resp, pod_name, console)
-        else:
-            _render_plain(resp, pod_name, console)
+        title = f"Remote logs • {pod_name}"
+        with LiveOutputPanel(
+            title,
+            max_lines=_MAX_DISPLAY_LINES,
+            target_console=Console(),
+            show_subtitle=False,
+        ) as panel:
+            buffer = ""
+            for chunk in resp.stream(decode_content=True):
+                buffer += chunk.decode("utf-8", errors="replace")
+                while "\n" in buffer:
+                    line, buffer = buffer.split("\n", 1)
+                    panel.on_output(line)
+            # Flush remaining partial line
+            if buffer.strip():
+                panel.on_output(buffer)
     except ApiException:
         pass  # Pod deleted or not found
     except urllib3.exceptions.ProtocolError:
@@ -60,45 +69,6 @@ def _stream_pod_logs(core_v1, pod_name, namespace):
             resp.release_conn()
 
 
-def _render_live_panel(resp, pod_name, console):
-    """Render streaming logs inside a Rich Live panel."""
-    lines = deque(maxlen=_MAX_DISPLAY_LINES)
-    title = f"Remote logs • {pod_name}"
-    buffer = ""
-
-    with Live(
-        _make_log_panel(lines, title),
-        console=console,
-        refresh_per_second=4,
-    ) as live:
-        for chunk in resp.stream(decode_content=True):
-            buffer += chunk.decode("utf-8", errors="replace")
-            while "\n" in buffer:
-                line, buffer = buffer.split("\n", 1)
-                lines.append(line)
-                live.update(_make_log_panel(lines, title))
-
-        # Flush remaining partial line
-        if buffer.strip():
-            lines.append(buffer)
-            live.update(_make_log_panel(lines, title))
-
-
-def _render_plain(resp, pod_name, console):
-    """Render streaming logs as raw lines with Rule delimiters."""
-    console.rule(f"Remote logs ({pod_name})", style="blue")
-    for chunk in resp.stream(decode_content=True):
-        sys.stdout.write(chunk.decode("utf-8", errors="replace"))
-        sys.stdout.flush()
-    console.rule("End remote logs", style="blue")
-
-
-def _make_log_panel(lines, title):
-    """Build a Panel renderable from accumulated log lines."""
-    content = "\n".join(lines) if lines else "Waiting for output..."
-    return Panel(content, title=title, border_style="blue")
-
-
 class LogStreamer:
     """Context manager that owns the log-streaming thread lifecycle.
```
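The buffering loop that both the old and new code share — accumulate raw chunks, emit complete lines, carry the partial line across chunk boundaries — can be isolated into a small generator (a standalone sketch, not code from the commit):

```python
def iter_lines(chunks):
    """Split a stream of byte chunks into text lines, carrying partial
    lines across chunk boundaries, mirroring the buffering loop in
    _stream_pod_logs above."""
    buffer = ""
    for chunk in chunks:
        buffer += chunk.decode("utf-8", errors="replace")
        while "\n" in buffer:
            line, buffer = buffer.split("\n", 1)
            yield line
    # Flush the remaining partial line, if any.
    if buffer.strip():
        yield buffer


chunks = [b"hel", b"lo\nwor", b"ld\npart"]
# list(iter_lines(chunks)) -> ["hello", "world", "part"]
```

Carrying the remainder in `buffer` matters because a Kubernetes log stream delivers chunks at arbitrary byte boundaries, so a single log line can easily be split across two chunks.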