Skip to content

Commit 541b6be

Browse files
Multi-stack support
1 parent 91ba1da commit 541b6be

File tree

8 files changed: +61 -22 lines changed

8 files changed: +61 -22 lines changed

keras_remote/backend/execution.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,11 @@
1616
from google.api_core import exceptions as google_exceptions
1717

1818
from keras_remote.backend import gke_client, pathways_client
19-
from keras_remote.constants import get_default_zone, zone_to_region
19+
from keras_remote.constants import (
20+
get_default_cluster_name,
21+
get_default_zone,
22+
zone_to_region,
23+
)
2024
from keras_remote.credentials import ensure_credentials
2125
from keras_remote.data import _make_data_ref
2226
from keras_remote.infra import container_builder
@@ -39,6 +43,7 @@ class JobContext:
3943
container_image: Optional[str]
4044
zone: str
4145
project: str
46+
cluster_name: str
4247

4348
# Generated identifiers
4449
job_id: str = field(default_factory=lambda: f"job-{uuid.uuid4().hex[:8]}")
@@ -58,7 +63,7 @@ class JobContext:
5863
image_uri: Optional[str] = None
5964

6065
def __post_init__(self):
61-
self.bucket_name = f"{self.project}-keras-remote-jobs"
66+
self.bucket_name = f"{self.project}-kr-{self.cluster_name}-jobs"
6267
self.region = zone_to_region(self.zone)
6368
self.display_name = f"keras-remote-{self.func.__name__}-{self.job_id}"
6469

@@ -73,9 +78,10 @@ def from_params(
7378
zone: Optional[str],
7479
project: Optional[str],
7580
env_vars: dict,
81+
cluster_name: Optional[str] = None,
7682
volumes: Optional[dict] = None,
7783
) -> "JobContext":
78-
"""Factory method with default resolution for zone/project."""
84+
"""Factory method with default resolution for zone/project/cluster."""
7985
if not zone:
8086
zone = get_default_zone()
8187
if not project:
@@ -85,6 +91,8 @@ def from_params(
8591
"project must be specified or set KERAS_REMOTE_PROJECT"
8692
" (or GOOGLE_CLOUD_PROJECT) environment variable"
8793
)
94+
if not cluster_name:
95+
cluster_name = get_default_cluster_name()
8896

8997
return cls(
9098
func=func,
@@ -95,6 +103,7 @@ def from_params(
95103
container_image=container_image,
96104
zone=zone,
97105
project=project,
106+
cluster_name=cluster_name,
98107
volumes=volumes,
99108
)
100109

@@ -303,6 +312,7 @@ def _build_container(ctx: JobContext) -> None:
303312
accelerator_type=ctx.accelerator,
304313
project=ctx.project,
305314
zone=ctx.zone,
315+
cluster_name=ctx.cluster_name,
306316
)
307317

308318

keras_remote/backend/execution_test.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,8 +40,9 @@ def test_post_init_derived_fields(self):
4040
container_image=None,
4141
zone="europe-west4-b",
4242
project="my-proj",
43+
cluster_name="my-cluster",
4344
)
44-
self.assertEqual(ctx.bucket_name, "my-proj-keras-remote-jobs")
45+
self.assertEqual(ctx.bucket_name, "my-proj-kr-my-cluster-jobs")
4546
self.assertEqual(ctx.region, "europe-west4")
4647
self.assertTrue(ctx.display_name.startswith("keras-remote-my_train-"))
4748
self.assertRegex(ctx.job_id, r"^job-[0-9a-f]{8}$")
@@ -171,6 +172,7 @@ def _make_ctx(self, container_image=None):
171172
container_image=container_image,
172173
zone="us-central1-a",
173174
project="proj",
175+
cluster_name="keras-remote-cluster",
174176
)
175177

176178
def test_success_flow(self):

keras_remote/cli/infra/program.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ def pulumi_program():
7171
# 2. Artifact Registry docker repository
7272
repo = gcp.artifactregistry.Repository(
7373
"keras-remote-repo",
74-
repository_id="keras-remote",
74+
repository_id=f"kr-{cluster_name}",
7575
location=ar_location,
7676
format="DOCKER",
7777
description="keras-remote container images",
@@ -84,7 +84,7 @@ def pulumi_program():
8484

8585
gcp.storage.Bucket(
8686
"keras-remote-jobs-bucket",
87-
name=f"{project_id}-keras-remote-jobs",
87+
name=f"{project_id}-kr-{cluster_name}-jobs",
8888
location=region,
8989
project=project_id,
9090
force_destroy=True,
@@ -93,7 +93,7 @@ def pulumi_program():
9393

9494
gcp.storage.Bucket(
9595
"keras-remote-builds-bucket",
96-
name=f"{project_id}-keras-remote-builds",
96+
name=f"{project_id}-kr-{cluster_name}-builds",
9797
location=ar_location,
9898
project=project_id,
9999
force_destroy=True,
@@ -170,7 +170,7 @@ def pulumi_program():
170170
pulumi.export(
171171
"ar_registry",
172172
repo.name.apply(
173-
lambda _: f"{ar_location}-docker.pkg.dev/{project_id}/keras-remote"
173+
lambda _: f"{ar_location}-docker.pkg.dev/{project_id}/kr-{cluster_name}"
174174
),
175175
)
176176

keras_remote/cli/infra/stack_manager.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,8 +33,9 @@ def get_stack(program_fn, config):
3333
click.echo("Pulumi CLI not found. Installing...")
3434
pulumi_cmd = auto.PulumiCommand.install(root=PULUMI_ROOT)
3535

36-
# Use project ID as stack name so each GCP project gets its own stack
37-
stack_name = config.project
36+
# Each (project, cluster) pair gets its own stack, so multiple clusters
37+
# within the same GCP project are fully independent.
38+
stack_name = f"{config.project}-{config.cluster_name}"
3839

3940
project_settings = auto.ProjectSettings(
4041
name=RESOURCE_NAME_PREFIX,

keras_remote/constants.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,11 @@ def get_default_zone():
1313
return os.environ.get(ZONE_ENV_VAR, DEFAULT_ZONE)
1414

1515

16+
def get_default_cluster_name():
17+
"""Return cluster name from KERAS_REMOTE_CLUSTER env var, or DEFAULT_CLUSTER_NAME."""
18+
return os.environ.get("KERAS_REMOTE_CLUSTER", DEFAULT_CLUSTER_NAME)
19+
20+
1621
def zone_to_region(zone):
1722
"""Convert a GCP zone to its region (e.g. 'us-central1-a' -> 'us-central1')."""
1823
return zone.rsplit("-", 1)[0] if zone and "-" in zone else DEFAULT_REGION

keras_remote/core/core.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -154,6 +154,7 @@ def _execute_on_gke(
154154
zone,
155155
project,
156156
env_vars,
157+
cluster_name=cluster,
157158
volumes=volumes,
158159
)
159160
return execute_remote(ctx, GKEBackend(cluster=cluster, namespace=namespace))
@@ -187,6 +188,7 @@ def _execute_on_pathways(
187188
zone,
188189
project,
189190
env_vars,
191+
cluster_name=cluster,
190192
volumes=volumes,
191193
)
192194
return execute_remote(

keras_remote/infra/container_builder.py

Lines changed: 22 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,11 @@
1414
from google.cloud import artifactregistry_v1, storage
1515
from google.cloud.devtools import cloudbuild_v1
1616

17-
from keras_remote.constants import get_default_zone, zone_to_ar_location
17+
from keras_remote.constants import (
18+
get_default_cluster_name,
19+
get_default_zone,
20+
zone_to_ar_location,
21+
)
1822
from keras_remote.core import accelerators
1923

2024
REMOTE_RUNNER_FILE_NAME = "remote_runner.py"
@@ -26,7 +30,12 @@
2630

2731

2832
def get_or_build_container(
29-
base_image, requirements_path, accelerator_type, project, zone=None
33+
base_image,
34+
requirements_path,
35+
accelerator_type,
36+
project,
37+
zone=None,
38+
cluster_name=None,
3039
):
3140
"""Get existing container or build if requirements changed.
3241
@@ -38,11 +47,13 @@ def get_or_build_container(
3847
accelerator_type: TPU/GPU type (e.g., 'v3-8')
3948
project: GCP project ID
4049
zone: GCP zone for region derivation (defaults to KERAS_REMOTE_ZONE)
50+
cluster_name: GKE cluster name (defaults to KERAS_REMOTE_CLUSTER)
4151
4252
Returns:
4353
Container image URI in Artifact Registry
4454
"""
4555
ar_location = zone_to_ar_location(zone or get_default_zone())
56+
cluster_name = cluster_name or get_default_cluster_name()
4657
category = accelerators.get_category(accelerator_type)
4758

4859
# Generate deterministic hash from requirements + base image + category
@@ -53,8 +64,9 @@ def get_or_build_container(
5364
# Use category for image name (e.g., 'tpu-hash', 'gpu-hash')
5465
image_tag = f"{category}-{requirements_hash[:12]}"
5566

56-
# Use Artifact Registry
57-
registry = f"{ar_location}-docker.pkg.dev/{project}/keras-remote"
67+
# Use Artifact Registry (cluster-scoped repo)
68+
repo_id = f"kr-{cluster_name}"
69+
registry = f"{ar_location}-docker.pkg.dev/{project}/{repo_id}"
5870
image_uri = f"{registry}/base:{image_tag}"
5971

6072
# Check if image exists
@@ -63,7 +75,7 @@ def get_or_build_container(
6375
ar_url = (
6476
"https://console.cloud.google.com/artifacts"
6577
f"/docker/{project}/{ar_location}"
66-
f"/keras-remote/base?project={project}"
78+
f"/{repo_id}/base?project={project}"
6779
)
6880
logging.info("View image: %s", ar_url)
6981
return image_uri
@@ -77,6 +89,7 @@ def get_or_build_container(
7789
project,
7890
image_uri,
7991
ar_location,
92+
cluster_name,
8093
)
8194

8295

@@ -155,6 +168,7 @@ def _build_and_push(
155168
project,
156169
image_uri,
157170
ar_location="us",
171+
cluster_name=None,
158172
):
159173
"""Build and push Docker image using Cloud Build.
160174
@@ -200,8 +214,9 @@ def _build_and_push(
200214
os.path.join(tmpdir, "requirements.txt"), arcname="requirements.txt"
201215
)
202216

203-
# Upload source to GCS
204-
bucket_name = f"{project}-keras-remote-builds"
217+
# Upload source to GCS (cluster-scoped bucket)
218+
cluster_name = cluster_name or get_default_cluster_name()
219+
bucket_name = f"{project}-kr-{cluster_name}-builds"
205220
source_gcs = _upload_build_source(tarball_path, bucket_name, project)
206221

207222
# Submit build to Cloud Build

keras_remote/infra/container_builder_test.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -236,10 +236,11 @@ def test_returns_cached_when_image_exists(self):
236236
accelerator_type="l4",
237237
project="test-proj",
238238
zone="us-central1-a",
239+
cluster_name="my-cluster",
239240
)
240241

241242
mock_build.assert_not_called()
242-
self.assertIn("us-docker.pkg.dev/test-proj/keras-remote/base:", result)
243+
self.assertIn("us-docker.pkg.dev/test-proj/kr-my-cluster/base:", result)
243244

244245
def test_builds_when_image_missing(self):
245246
with (
@@ -249,7 +250,7 @@ def test_builds_when_image_missing(self):
249250
),
250251
mock.patch(
251252
"keras_remote.infra.container_builder._build_and_push",
252-
return_value="us-docker.pkg.dev/proj/keras-remote/base:gpu-bbbbbbbbbbbb",
253+
return_value="us-docker.pkg.dev/proj/kr-my-cluster/base:gpu-bbbbbbbbbbbb",
253254
) as mock_build,
254255
):
255256
result = get_or_build_container(
@@ -258,11 +259,13 @@ def test_builds_when_image_missing(self):
258259
accelerator_type="l4",
259260
project="proj",
260261
zone="us-central1-a",
262+
cluster_name="my-cluster",
261263
)
262264

263265
mock_build.assert_called_once()
264266
self.assertEqual(
265-
result, "us-docker.pkg.dev/proj/keras-remote/base:gpu-bbbbbbbbbbbb"
267+
result,
268+
"us-docker.pkg.dev/proj/kr-my-cluster/base:gpu-bbbbbbbbbbbb",
266269
)
267270

268271
def _get_image_uri(self, accelerator_type, project, zone):
@@ -276,13 +279,14 @@ def _get_image_uri(self, accelerator_type, project, zone):
276279
accelerator_type=accelerator_type,
277280
project=project,
278281
zone=zone,
282+
cluster_name="my-cluster",
279283
)
280284

281285
def test_image_uri_format_tpu_europe(self):
282286
result = self._get_image_uri("v3-4", "my-proj", "europe-west4-b")
283287

284288
self.assertTrue(
285-
result.startswith("europe-docker.pkg.dev/my-proj/keras-remote/base:")
289+
result.startswith("europe-docker.pkg.dev/my-proj/kr-my-cluster/base:")
286290
)
287291
tag = result.split(":")[-1]
288292
self.assertRegex(tag, r"^tpu-[0-9a-f]{12}$")
@@ -291,7 +295,7 @@ def test_image_uri_format_gpu_us(self):
291295
result = self._get_image_uri("a100-80gb", "proj", "us-central1-a")
292296

293297
self.assertTrue(
294-
result.startswith("us-docker.pkg.dev/proj/keras-remote/base:")
298+
result.startswith("us-docker.pkg.dev/proj/kr-my-cluster/base:")
295299
)
296300
tag = result.split(":")[-1]
297301
self.assertRegex(tag, r"^gpu-[0-9a-f]{12}$")

0 commit comments

Comments (0)