|
14 | 14 | import cloudpickle |
15 | 15 | from absl import logging |
16 | 16 |
|
17 | | -from keras_remote.backend import gke_client |
| 17 | +from keras_remote.backend import gke_client, pathways_client |
18 | 18 | from keras_remote.constants import get_default_zone, zone_to_region |
19 | 19 | from keras_remote.infra import container_builder |
20 | 20 | from keras_remote.utils import packager, storage |
@@ -106,59 +106,61 @@ def cleanup_job(self, job: Any, ctx: JobContext) -> None: |
106 | 106 |
|
107 | 107 |
|
108 | 108 | class BaseK8sBackend: |
109 | | - """Base class for Kubernetes-based backends.""" |
| 109 | + """Base class for Kubernetes-based backends.""" |
| 110 | + |
| 111 | + def __init__(self, cluster: Optional[str] = None, namespace: str = "default"): |
| 112 | + self.cluster = cluster |
| 113 | + self.namespace = namespace |
110 | 114 |
|
111 | | - def __init__(self, cluster: Optional[str] = None, namespace: str = "default"): |
112 | | - self.cluster = cluster |
113 | | - self.namespace = namespace |
114 | 115 |
|
115 | 116 | class GKEBackend(BaseK8sBackend): |
116 | | - """Backend adapter for standard GKE Jobs.""" |
117 | | - def submit_job(self, ctx: JobContext) -> Any: |
118 | | - """Submit job to GKE cluster.""" |
119 | | - return gke_client.submit_k8s_job( |
120 | | - display_name=ctx.display_name, |
121 | | - container_uri=ctx.image_uri, |
122 | | - accelerator=ctx.accelerator, |
123 | | - project=ctx.project, |
124 | | - job_id=ctx.job_id, |
125 | | - bucket_name=ctx.bucket_name, |
126 | | - namespace=self.namespace, |
127 | | - ) |
| 117 | + """Backend adapter for standard GKE Jobs.""" |
128 | 118 |
|
129 | | - def wait_for_job(self, job: Any, ctx: JobContext) -> None: |
130 | | - """Wait for GKE job completion.""" |
131 | | - gke_client.wait_for_job(job, namespace=self.namespace) |
| 119 | + def submit_job(self, ctx: JobContext) -> Any: |
| 120 | + """Submit job to GKE cluster.""" |
| 121 | + return gke_client.submit_k8s_job( |
| 122 | + display_name=ctx.display_name, |
| 123 | + container_uri=ctx.image_uri, |
| 124 | + accelerator=ctx.accelerator, |
| 125 | + project=ctx.project, |
| 126 | + job_id=ctx.job_id, |
| 127 | + bucket_name=ctx.bucket_name, |
| 128 | + namespace=self.namespace, |
| 129 | + ) |
| 130 | + |
| 131 | + def wait_for_job(self, job: Any, ctx: JobContext) -> None: |
| 132 | + """Wait for GKE job completion.""" |
| 133 | + gke_client.wait_for_job(job, namespace=self.namespace) |
132 | 134 |
|
133 | | - def cleanup_job(self, job: Any, ctx: JobContext) -> None: |
134 | | - """Clean up K8s job resources.""" |
135 | | - job_name = job.metadata.name |
136 | | - gke_client.cleanup_job(job_name, namespace=self.namespace) |
| 135 | + def cleanup_job(self, job: Any, ctx: JobContext) -> None: |
| 136 | + """Clean up K8s job resources.""" |
| 137 | + job_name = job.metadata.name |
| 138 | + gke_client.cleanup_job(job_name, namespace=self.namespace) |
137 | 139 |
|
138 | 140 |
|
139 | 141 | class PathwaysBackend(BaseK8sBackend): |
140 | | - """Backend adapter for ML Pathways using LeaderWorkerSet.""" |
141 | | - def submit_job(self, ctx: JobContext) -> Any: |
142 | | - """Submit LWS job to GKE cluster.""" |
143 | | - return pathways_client.submit_pathways_job( |
144 | | - display_name=ctx.display_name, |
145 | | - container_uri=ctx.image_uri, |
146 | | - accelerator=ctx.accelerator, |
147 | | - project=ctx.project, |
148 | | - job_id=ctx.job_id, |
149 | | - bucket_name=ctx.bucket_name, |
150 | | - namespace=self.namespace, |
151 | | - ) |
| 142 | + """Backend adapter for ML Pathways using LeaderWorkerSet.""" |
152 | 143 |
|
153 | | - def wait_for_job(self, job: Any, ctx: JobContext) -> None: |
154 | | - """Wait for Pathways LWS completion.""" |
155 | | - pathways_client.wait_for_job(job, ctx.job_id, namespace=self.namespace) |
| 144 | + def submit_job(self, ctx: JobContext) -> Any: |
| 145 | + """Submit LWS job to GKE cluster.""" |
| 146 | + return pathways_client.submit_pathways_job( |
| 147 | + display_name=ctx.display_name, |
| 148 | + container_uri=ctx.image_uri, |
| 149 | + accelerator=ctx.accelerator, |
| 150 | + project=ctx.project, |
| 151 | + job_id=ctx.job_id, |
| 152 | + bucket_name=ctx.bucket_name, |
| 153 | + namespace=self.namespace, |
| 154 | + ) |
156 | 155 |
|
157 | | - def cleanup_job(self, job: Any, ctx: JobContext) -> None: |
158 | | - """Clean up LWS resources.""" |
159 | | - job_name = f"keras-pathways-{ctx.job_id}" |
160 | | - pathways_client.cleanup_job(job_name, namespace=self.namespace) |
| 156 | + def wait_for_job(self, job: Any, ctx: JobContext) -> None: |
| 157 | + """Wait for Pathways LWS completion.""" |
| 158 | + pathways_client.wait_for_job(job, ctx.job_id, namespace=self.namespace) |
161 | 159 |
|
| 160 | + def cleanup_job(self, job: Any, ctx: JobContext) -> None: |
| 161 | + """Clean up LWS resources.""" |
| 162 | + job_name = f"keras-pathways-{ctx.job_id}" |
| 163 | + pathways_client.cleanup_job(job_name, namespace=self.namespace) |
162 | 164 |
|
163 | 165 |
|
164 | 166 | def _find_requirements(start_dir: str) -> Optional[str]: |
|
0 commit comments