Skip to content

Commit 08881d9

Browse files
Auto-detect LWS version
1 parent 3a69ed8 commit 08881d9

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

keras_remote/backend/pathways_client.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,22 @@
2020
LWS_PLURAL = "leaderworkersets"
2121

2222

23+
24+
def _get_lws_version(group=LWS_GROUP):
    """Get the preferred version for the LeaderWorkerSet API.

    Lists the API groups advertised by the cluster, finds the one whose
    name matches ``group`` and returns its preferred version. Falls back
    to ``LWS_VERSION`` when the discovery request fails, the group is not
    advertised, or the group reports no preferred version.

    Args:
        group: Name of the API group to look up. Defaults to ``LWS_GROUP``.

    Returns:
        The preferred version string for the group, or ``LWS_VERSION`` as
        a fallback.
    """
    _load_kube_config()
    try:
        # ApisApi only exposes group discovery via get_api_versions();
        # there is no per-name lookup, so fetch the list and filter below.
        group_list = client.ApisApi().get_api_versions()
    except ApiException:
        logger.warning(
            "Failed to retrieve LWS API version from cluster. Defaulting to '%s'",
            LWS_VERSION,
        )
        return LWS_VERSION

    for api_group in group_list.groups or []:
        if api_group.name != group:
            continue
        # preferred_version is optional in the discovery response, so guard
        # against it being absent before reading its version.
        preferred = api_group.preferred_version
        if preferred is not None and preferred.version:
            return preferred.version
        break

    logger.warning(
        "API group '%s' not found or has no preferred version in cluster. "
        "Defaulting to '%s'",
        group,
        LWS_VERSION,
    )
    return LWS_VERSION
37+
38+
2339
def submit_pathways_job(
2440
display_name,
2541
container_uri,
@@ -44,6 +60,7 @@ def submit_pathways_job(
4460
dict: The created LeaderWorkerSet object
4561
"""
4662
_load_kube_config()
63+
lws_version = _get_lws_version()
4764

4865
accel_config = _parse_accelerator(accelerator)
4966
job_name = f"keras-pathways-{job_id}"
@@ -68,14 +85,15 @@ def submit_pathways_job(
6885
bucket_name=bucket_name,
6986
num_workers=num_workers,
7087
namespace=namespace,
88+
version=lws_version,
7189
)
7290

7391
custom_api = client.CustomObjectsApi()
7492

7593
try:
7694
created_lws = custom_api.create_namespaced_custom_object(
7795
group=LWS_GROUP,
78-
version=LWS_VERSION,
96+
version=lws_version,
7997
namespace=namespace,
8098
plural=LWS_PLURAL,
8199
body=lws_manifest,
@@ -176,12 +194,13 @@ def wait_for_job(
176194
def cleanup_job(job_name, namespace="default"):
177195
"""Delete LeaderWorkerSet."""
178196
_load_kube_config()
197+
lws_version = _get_lws_version()
179198
custom_api = client.CustomObjectsApi()
180199

181200
try:
182201
custom_api.delete_namespaced_custom_object(
183202
group=LWS_GROUP,
184-
version=LWS_VERSION,
203+
version=lws_version,
185204
namespace=namespace,
186205
plural=LWS_PLURAL,
187206
name=job_name,
@@ -204,6 +223,7 @@ def _create_lws_spec(
204223
bucket_name,
205224
num_workers,
206225
namespace,
226+
version=LWS_VERSION,
207227
):
208228
"""Create a LeaderWorkerSet manifest."""
209229

@@ -252,7 +272,7 @@ def _create_lws_spec(
252272
pod_template["spec"]["nodeSelector"] = accel_config["node_selector"]
253273

254274
return {
255-
"apiVersion": f"{LWS_GROUP}/{LWS_VERSION}",
275+
"apiVersion": f"{LWS_GROUP}/{version}",
256276
"kind": "LeaderWorkerSet",
257277
"metadata": {
258278
"name": job_name,

0 commit comments

Comments
 (0)