-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexecution.py
More file actions
434 lines (359 loc) · 13.4 KB
/
execution.py
File metadata and controls
434 lines (359 loc) · 13.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
"""Unified remote execution module for GKE backend.
This module consolidates the common execution logic shared between different
backend implementations, reducing code duplication and improving maintainability.
"""
import inspect
import os
import tempfile
import uuid
from dataclasses import dataclass, field
from typing import Any, Callable, Optional
import cloudpickle
from absl import logging
from google.api_core import exceptions as google_exceptions
from keras_remote.backend import gke_client, pathways_client
from keras_remote.constants import (
get_default_cluster_name,
get_default_zone,
zone_to_region,
)
from keras_remote.credentials import ensure_credentials
from keras_remote.data import _make_data_ref
from keras_remote.infra import container_builder
from keras_remote.infra.infra import get_default_project
from keras_remote.utils import packager, storage
@dataclass
class JobContext:
    """Encapsulates all state for a remote job execution.

    Instances are usually built through :meth:`from_params`, which resolves
    zone/project/cluster defaults. ``bucket_name``, ``region`` and
    ``display_name`` are derived in ``__post_init__`` and cannot be passed
    to the constructor. The artifact fields start as ``None`` and are filled
    in as the job moves through the execution phases.
    """

    # Function and arguments
    func: Callable
    args: tuple
    kwargs: dict
    env_vars: dict
    # Configuration
    accelerator: str
    container_image: Optional[str]
    zone: str
    project: str
    cluster_name: str
    # Generated identifiers: "job-" followed by 8 random hex digits.
    job_id: str = field(default_factory=lambda: f"job-{uuid.uuid4().hex[:8]}")
    # Derived values (computed in __post_init__, excluded from __init__)
    bucket_name: str = field(init=False)
    region: str = field(init=False)
    display_name: str = field(init=False)
    # Data volumes {mount_path: Data}
    volumes: Optional[dict] = None
    # Configuration modifiers
    spot: bool = False
    # Artifact paths (set during prepare phase)
    payload_path: Optional[str] = None
    context_path: Optional[str] = None
    requirements_path: Optional[str] = None  # requirements.txt or pyproject.toml
    image_uri: Optional[str] = None

    def __post_init__(self):
        self.bucket_name = f"{self.project}-kr-{self.cluster_name}-jobs"
        self.region = zone_to_region(self.zone)
        self.display_name = (
            f"keras-remote-{self._callable_name(self.func)}-{self.job_id}"
        )

    @staticmethod
    def _callable_name(func: Callable) -> str:
        """Return a human-readable name for any callable.

        Plain functions report their ``__name__``. ``functools.partial``
        objects are unwrapped to the function they bind, and callable
        instances (which have no ``__name__``) fall back to their class name,
        so building a context never fails on an unusual callable.
        """
        import functools

        while isinstance(func, functools.partial):
            func = func.func
        return getattr(func, "__name__", type(func).__name__)

    @classmethod
    def from_params(
        cls,
        func: Callable,
        args: tuple,
        kwargs: dict,
        accelerator: str,
        container_image: Optional[str],
        zone: Optional[str],
        project: Optional[str],
        env_vars: dict,
        cluster_name: Optional[str] = None,
        volumes: Optional[dict] = None,
        spot: bool = False,
    ) -> "JobContext":
        """Factory method with default resolution for zone/project/cluster.

        Falsy ``zone``, ``project`` or ``cluster_name`` values are replaced by
        ``get_default_zone()``, ``get_default_project()`` and
        ``get_default_cluster_name()`` respectively.

        Raises:
          ValueError: If no project is given and none can be resolved.
        """
        if not zone:
            zone = get_default_zone()
        if not project:
            project = get_default_project()
        if not project:
            raise ValueError(
                "project must be specified or set KERAS_REMOTE_PROJECT"
                " (or GOOGLE_CLOUD_PROJECT) environment variable"
            )
        if not cluster_name:
            cluster_name = get_default_cluster_name()
        return cls(
            func=func,
            args=args,
            kwargs=kwargs,
            env_vars=env_vars,
            accelerator=accelerator,
            container_image=container_image,
            zone=zone,
            project=project,
            cluster_name=cluster_name,
            volumes=volumes,
            spot=spot,
        )
class BaseK8sBackend:
    """Base class for Kubernetes-based backends.

    Subclasses must implement :meth:`submit_job` and :meth:`wait_for_job`.
    :meth:`validate_preflight` and :meth:`cleanup_job` are optional hooks
    whose default implementations do nothing.
    """

    def __init__(self, cluster: str, namespace: str = "default"):
        self.cluster = cluster
        self.namespace = namespace

    def validate_preflight(self, ctx: "JobContext") -> None:
        """Perform preflight checks before building container or uploading artifacts."""
        pass

    def submit_job(self, ctx: "JobContext") -> Any:
        """Submit a job to the backend. Returns backend-specific job handle."""
        raise NotImplementedError

    def wait_for_job(self, job: Any, ctx: "JobContext") -> None:
        """Wait for job completion. Raises RuntimeError if job fails."""
        raise NotImplementedError

    def cleanup_job(self, job: Any, ctx: "JobContext") -> None:
        """Optional cleanup after job completion.

        The default is a no-op so backends without resources to release need
        not override it. Raising here would surface from the caller's cleanup
        path instead of the job's own outcome.
        """
        pass
class GKEBackend(BaseK8sBackend):
    """Runs remote jobs as standard Kubernetes Jobs on a GKE cluster.

    Every hook delegates to ``gke_client``, scoped to this backend's
    cluster and namespace.
    """

    def validate_preflight(self, ctx: JobContext) -> None:
        """Ask ``gke_client`` to verify a node pool exists for the accelerator."""
        gke_client.validate_preflight(
            cluster=self.cluster,
            namespace=self.namespace,
            accelerator=ctx.accelerator,
            project=ctx.project,
            zone=ctx.zone,
        )

    def submit_job(self, ctx: JobContext) -> Any:
        """Create the Kubernetes Job; returns the handle from ``gke_client``."""
        job_spec = {
            "display_name": ctx.display_name,
            "container_uri": ctx.image_uri,
            "accelerator": ctx.accelerator,
            "project": ctx.project,
            "job_id": ctx.job_id,
            "bucket_name": ctx.bucket_name,
            "namespace": self.namespace,
            "spot": ctx.spot,
        }
        return gke_client.submit_k8s_job(**job_spec)

    def wait_for_job(self, job: Any, ctx: JobContext) -> None:
        """Block until ``gke_client`` reports the Job has finished."""
        gke_client.wait_for_job(job, namespace=self.namespace)

    def cleanup_job(self, job: Any, ctx: JobContext) -> None:
        """Delete the Job's Kubernetes resources by name.

        Assumes ``job`` is the handle returned by :meth:`submit_job` and
        exposes ``metadata.name``.
        """
        gke_client.cleanup_job(job.metadata.name, namespace=self.namespace)
class PathwaysBackend(BaseK8sBackend):
    """Runs remote jobs on ML Pathways via a LeaderWorkerSet (LWS).

    Every hook delegates to ``pathways_client`` (or ``gke_client`` for the
    preflight check), scoped to this backend's namespace.
    """

    def validate_preflight(self, ctx: JobContext) -> None:
        """Run the same node-pool preflight check as the plain GKE backend."""
        # Pathways workloads are scheduled onto ordinary GKE nodes, so the
        # GKE preflight check applies unchanged.
        preflight_args = dict(
            accelerator=ctx.accelerator,
            project=ctx.project,
            cluster=self.cluster,
            zone=ctx.zone,
            namespace=self.namespace,
        )
        gke_client.validate_preflight(**preflight_args)

    def submit_job(self, ctx: JobContext) -> Any:
        """Create the LWS workload; returns the handle from ``pathways_client``."""
        lws_spec = dict(
            display_name=ctx.display_name,
            container_uri=ctx.image_uri,
            accelerator=ctx.accelerator,
            project=ctx.project,
            job_id=ctx.job_id,
            bucket_name=ctx.bucket_name,
            namespace=self.namespace,
            spot=ctx.spot,
        )
        return pathways_client.submit_pathways_job(**lws_spec)

    def wait_for_job(self, job: Any, ctx: JobContext) -> None:
        """Block until the LWS finishes; the job is tracked by ``ctx.job_id``."""
        pathways_client.wait_for_job(ctx.job_id, namespace=self.namespace)

    def cleanup_job(self, job: Any, ctx: JobContext) -> None:
        """Delete the LWS resources derived from ``ctx.job_id``."""
        # NOTE(review): relies on pathways_client's private _get_job_name
        # helper to reconstruct the resource name.
        pathways_client.cleanup_job(
            pathways_client._get_job_name(ctx.job_id),
            namespace=self.namespace,
        )
def _find_requirements(start_dir: str) -> Optional[str]:
"""Search up directory tree for requirements.txt or pyproject.toml.
At each directory level, ``requirements.txt`` is preferred over
``pyproject.toml``. The first match found while walking towards the
filesystem root is returned.
"""
search_dir = start_dir
while search_dir != "/":
req_path = os.path.join(search_dir, "requirements.txt")
if os.path.exists(req_path):
return req_path
pyproject_path = os.path.join(search_dir, "pyproject.toml")
if os.path.exists(pyproject_path):
return pyproject_path
parent_dir = os.path.dirname(search_dir)
if parent_dir == search_dir:
break
search_dir = parent_dir
return None
def _maybe_exclude(data_path, caller_path, exclude_paths):
"""Add data_path to exclude_paths if it's inside the caller directory."""
data_abs = os.path.normpath(data_path)
caller_abs = os.path.normpath(caller_path)
if data_abs.startswith(caller_abs + os.sep) or data_abs == caller_abs:
exclude_paths.add(data_abs)
def _resolve_caller_dir(frame, depth: int) -> str:
    """Return the directory of the module running ``depth`` frames above ``frame``.

    Falls back to the current working directory when the stack is shallower
    than ``depth`` or no module with a ``__file__`` can be found for the
    target frame.
    """
    for _ in range(depth):
        if frame is None:
            break
        frame = frame.f_back
    module = inspect.getmodule(frame) if frame is not None else None
    module_file = getattr(module, "__file__", None)
    if module_file:
        return os.path.dirname(os.path.abspath(module_file))
    return os.getcwd()


def _upload_data_objects(ctx: JobContext, caller_path: str):
    """Upload the Data objects in volumes and call arguments to Cloud Storage.

    Returns:
      A tuple ``(volume_refs, ref_map, exclude_paths)``:
      - ``volume_refs``: one data-ref dict per volume, including its mount path;
      - ``ref_map``: ``id(Data) -> data-ref dict`` for argument replacement;
      - ``exclude_paths``: local data paths under ``caller_path`` that should
        not be zipped into the working-directory context.
    """
    exclude_paths = set()

    volume_refs = []
    for mount_path, data_obj in (ctx.volumes or {}).items():
        gcs_uri = storage.upload_data(ctx.bucket_name, data_obj, ctx.project)
        volume_refs.append(
            _make_data_ref(gcs_uri, data_obj.is_dir, mount_path=mount_path)
        )
        if not data_obj.is_gcs:
            _maybe_exclude(data_obj.path, caller_path, exclude_paths)

    ref_map = {}
    for data_obj, _position in packager.extract_data_refs(ctx.args, ctx.kwargs):
        gcs_uri = storage.upload_data(ctx.bucket_name, data_obj, ctx.project)
        ref_map[id(data_obj)] = _make_data_ref(gcs_uri, data_obj.is_dir)
        if not data_obj.is_gcs:
            _maybe_exclude(data_obj.path, caller_path, exclude_paths)

    return volume_refs, ref_map, exclude_paths


def _prepare_artifacts(
    ctx: JobContext, tmpdir: str, caller_frame_depth: int = 3
) -> None:
    """Phase 1: Package function payload and working directory context.

    Uploads any Data objects, swaps the ones in ``ctx.args``/``ctx.kwargs``
    for references, serializes the payload to ``<tmpdir>/payload.pkl``,
    zips the caller's directory to ``<tmpdir>/context.zip`` and records the
    nearest dependency file. Sets ``ctx.payload_path``, ``ctx.context_path``
    and ``ctx.requirements_path``.

    Args:
      ctx: Job context; its args/kwargs may be replaced.
      tmpdir: Directory that receives the payload and context archive.
      caller_frame_depth: How many frames above this function the user's
        code is; frame 0 is this function itself. The default presumably
        matches the call chain through ``execute_remote`` and the public
        entry point that calls it.
    """
    logging.info("Packaging function and context...")
    # Walking f_back is O(depth) and, unlike inspect.stack(), does not
    # gather file/source info for every frame or raise IndexError on a
    # shallow stack (it falls back to the cwd instead).
    caller_path = _resolve_caller_dir(inspect.currentframe(), caller_frame_depth)

    volume_refs, ref_map, exclude_paths = _upload_data_objects(ctx, caller_path)

    # Replace Data with refs in args/kwargs
    if ref_map:
        ctx.args, ctx.kwargs = packager.replace_data_with_refs(
            ctx.args, ctx.kwargs, ref_map
        )

    # Serialize function + args (with volume refs)
    ctx.payload_path = os.path.join(tmpdir, "payload.pkl")
    packager.save_payload(
        ctx.func,
        ctx.args,
        ctx.kwargs,
        ctx.env_vars,
        ctx.payload_path,
        volumes=volume_refs or None,
    )
    logging.info("Payload serialized to %s", ctx.payload_path)

    # Zip working directory (excluding Data paths already uploaded)
    ctx.context_path = os.path.join(tmpdir, "context.zip")
    packager.zip_working_dir(
        caller_path, ctx.context_path, exclude_paths=exclude_paths
    )
    logging.info("Context packaged to %s", ctx.context_path)

    # Find requirements.txt or pyproject.toml
    ctx.requirements_path = _find_requirements(caller_path)
    if ctx.requirements_path:
        logging.info("Found dependency file: %s", ctx.requirements_path)
    else:
        logging.info("No requirements.txt or pyproject.toml found")
def _build_container(ctx: JobContext) -> None:
    """Phase 2: Resolve ``ctx.image_uri``, building an image only if needed.

    A user-supplied ``ctx.container_image`` is used as-is. Otherwise
    ``container_builder`` is asked for an image based on ``python:X.Y-slim``,
    where ``X.Y`` is the local interpreter's version, with the discovered
    dependency file and accelerator type.
    """
    if not ctx.container_image:
        import sys

        logging.info("Building container image...")
        major, minor = sys.version_info[:2]
        ctx.image_uri = container_builder.get_or_build_container(
            base_image=f"python:{major}.{minor}-slim",
            requirements_path=ctx.requirements_path,
            accelerator_type=ctx.accelerator,
            project=ctx.project,
            zone=ctx.zone,
            cluster_name=ctx.cluster_name,
        )
        return
    ctx.image_uri = ctx.container_image
    logging.info("Using custom container: %s", ctx.image_uri)
def _upload_artifacts(ctx: JobContext) -> None:
    """Phase 3: Push the payload and context archive to the job's bucket.

    Raises:
      ValueError: If ``_prepare_artifacts`` has not populated both
        ``ctx.payload_path`` and ``ctx.context_path``.
    """
    if None in (ctx.payload_path, ctx.context_path):
        raise ValueError("payload_path and context_path must be set before upload")

    logging.info("Uploading artifacts to Cloud Storage (job: %s)...", ctx.job_id)
    upload_kwargs = {
        "bucket_name": ctx.bucket_name,
        "job_id": ctx.job_id,
        "payload_path": ctx.payload_path,
        "context_path": ctx.context_path,
        "project": ctx.project,
    }
    storage.upload_artifacts(**upload_kwargs)
def _download_result(ctx: JobContext) -> dict:
    """Phase 6: Fetch the job's result file and unpickle it.

    Returns:
      The result payload dict written by the remote runner.
    """
    logging.info("Downloading result...")
    local_path = storage.download_result(
        ctx.bucket_name, ctx.job_id, project=ctx.project
    )
    # NOTE: unpickling can execute arbitrary code; this trusts whoever can
    # write to the job bucket.
    with open(local_path, "rb") as handle:
        payload = cloudpickle.load(handle)
    return payload
def _cleanup_and_return(ctx: JobContext, result_payload: dict) -> Any:
    """Phase 7: Cleanup Cloud Storage artifacts and handle result.

    Artifact deletion is best-effort: the remote run has already finished,
    so a storage error while deleting must not discard its result or hide
    the user's exception. Such errors are logged as warnings instead.

    Args:
      ctx: Job context identifying the bucket and job.
      result_payload: Dict read from the result file; ``"success"`` selects
        between ``"result"`` and ``"exception"``/``"traceback"``.

    Returns:
      ``result_payload["result"]`` when the remote call succeeded.

    Raises:
      Exception: ``result_payload["exception"]`` when it failed.
    """
    logging.info("Cleaning up artifacts...")
    try:
        storage.cleanup_artifacts(ctx.bucket_name, ctx.job_id, project=ctx.project)
    except google_exceptions.GoogleAPIError:
        logging.warning(
            "Failed to clean up artifacts for job %s in bucket %s",
            ctx.job_id,
            ctx.bucket_name,
            exc_info=True,
        )

    if result_payload["success"]:
        logging.info("Remote execution completed successfully")
        return result_payload["result"]
    logging.error("Remote execution failed:\n%s", result_payload["traceback"])
    raise result_payload["exception"]
def execute_remote(ctx: JobContext, backend: BaseK8sBackend) -> Any:
    """Execute a function remotely using the specified backend.

    This is the unified executor that handles all common phases
    and delegates backend-specific operations to the backend client.

    Args:
      ctx: Job context with function and configuration
      backend: Backend instance (GKEBackend or PathwaysBackend)

    Returns:
      The result of the remote function execution

    Raises:
      Exception: Re-raised from remote execution if it failed, or the
        backend's RuntimeError when the job failed without uploading a result.
    """
    ensure_credentials(
        project=ctx.project,
        zone=ctx.zone,
        cluster=backend.cluster,
    )

    # Preflight check before any expensive build/upload work.
    backend.validate_preflight(ctx)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Phase 1: Package artifacts. Keep this a direct call: the helper
        # locates the user's module a fixed number of frames up the stack.
        _prepare_artifacts(ctx, tmpdir)

        # Phase 2: Build or get cached container image
        _build_container(ctx)

        # Phase 3: Upload artifacts to Cloud Storage
        _upload_artifacts(ctx)

        # Phase 4: Submit job (backend-specific)
        logging.info("Submitting job to %s...", backend.__class__.__name__)
        job = backend.submit_job(ctx)

        # Phase 5: Wait for completion (with cleanup on failure)
        job_error = None
        try:
            backend.wait_for_job(job, ctx)
        except RuntimeError as e:
            job_error = e
        finally:
            # A failing cleanup must not mask the job's outcome (or any
            # exception already propagating from wait_for_job), so it is
            # logged rather than raised.
            try:
                backend.cleanup_job(job, ctx)
            except Exception:  # pylint: disable=broad-except
                logging.warning(
                    "Failed to clean up job %s; its resources may need manual"
                    " deletion",
                    ctx.job_id,
                    exc_info=True,
                )

        # Phase 6: Download and deserialize result
        # Try even if the job failed — the runner may have captured a user
        # exception and uploaded the result before exiting with non-zero.
        try:
            result_payload = _download_result(ctx)
        except google_exceptions.NotFound:
            if job_error is None:
                raise
            # Result wasn't uploaded (infrastructure failure), surface the
            # original job error.
            raise job_error from None

        # Phase 7: Cleanup and return/raise
        return _cleanup_and_return(ctx, result_payload)