2020LWS_PLURAL = "leaderworkersets"
2121
2222
23+
24+ def _get_lws_version (group = LWS_GROUP ):
25+ """Get the preferred version for the LeaderWorkerSet API."""
26+ _load_kube_config ()
27+ api = client .ApisApi ()
28+ try :
29+ api_group = api .get_api_group (group )
30+ return api_group .preferred_version .version
31+ except ApiException :
32+ logger .warning (
33+ "Failed to retrieve LWS API version from cluster. Defaulting to '%s'" ,
34+ LWS_VERSION ,
35+ )
36+ return LWS_VERSION
37+
38+
2339def submit_pathways_job (
2440 display_name ,
2541 container_uri ,
@@ -44,6 +60,7 @@ def submit_pathways_job(
4460 dict: The created LeaderWorkerSet object
4561 """
4662 _load_kube_config ()
63+ lws_version = _get_lws_version ()
4764
4865 accel_config = _parse_accelerator (accelerator )
4966 job_name = f"keras-pathways-{ job_id } "
@@ -68,14 +85,15 @@ def submit_pathways_job(
6885 bucket_name = bucket_name ,
6986 num_workers = num_workers ,
7087 namespace = namespace ,
88+ version = lws_version ,
7189 )
7290
7391 custom_api = client .CustomObjectsApi ()
7492
7593 try :
7694 created_lws = custom_api .create_namespaced_custom_object (
7795 group = LWS_GROUP ,
78- version = LWS_VERSION ,
96+ version = lws_version ,
7997 namespace = namespace ,
8098 plural = LWS_PLURAL ,
8199 body = lws_manifest ,
@@ -176,12 +194,13 @@ def wait_for_job(
176194def cleanup_job (job_name , namespace = "default" ):
177195 """Delete LeaderWorkerSet."""
178196 _load_kube_config ()
197+ lws_version = _get_lws_version ()
179198 custom_api = client .CustomObjectsApi ()
180199
181200 try :
182201 custom_api .delete_namespaced_custom_object (
183202 group = LWS_GROUP ,
184- version = LWS_VERSION ,
203+ version = lws_version ,
185204 namespace = namespace ,
186205 plural = LWS_PLURAL ,
187206 name = job_name ,
@@ -204,6 +223,7 @@ def _create_lws_spec(
204223 bucket_name ,
205224 num_workers ,
206225 namespace ,
226+ version = LWS_VERSION ,
207227):
208228 """Create a LeaderWorkerSet manifest."""
209229
@@ -252,7 +272,7 @@ def _create_lws_spec(
252272 pod_template ["spec" ]["nodeSelector" ] = accel_config ["node_selector" ]
253273
254274 return {
255- "apiVersion" : f"{ LWS_GROUP } /{ LWS_VERSION } " ,
275+ "apiVersion" : f"{ LWS_GROUP } /{ version } " ,
256276 "kind" : "LeaderWorkerSet" ,
257277 "metadata" : {
258278 "name" : job_name ,
0 commit comments