Skip to content

Commit a1c9ce0

Browse files
committed
fix(container): address PR review comments for initializer support
- Use GHCR images as default for dataset/model initializers
- Replace suppress with try-except blocks
- Refactor initializer utils with ContainerInitializer dataclass
- Add get_dataset_initializer and get_model_initializer functions
- Remove DataCache support (unsupported in container backend)
- Merge initializer tests into test_train() and test_get_job_logs()
- Remove duplicate test functions

Signed-off-by: HKanoje <hrithik.kanoje@gmail.com>
1 parent 6241fae commit a1c9ce0

File tree

4 files changed

+348
-457
lines changed

4 files changed

+348
-457
lines changed

kubeflow/trainer/backends/container/backend.py

Lines changed: 25 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -210,20 +210,24 @@ def _cleanup_container_resources(
210210
network_id: Network ID to delete.
211211
stop_timeout: Timeout in seconds for stopping containers.
212212
"""
213-
from contextlib import suppress
214-
215213
# Stop and remove containers
216214
if container_ids:
217215
for container_id in container_ids:
218-
with suppress(Exception):
216+
try: # noqa: SIM105
219217
self._adapter.stop_container(container_id, timeout=stop_timeout)
220-
with suppress(Exception):
218+
except Exception:
219+
pass
220+
try: # noqa: SIM105
221221
self._adapter.remove_container(container_id, force=True)
222+
except Exception:
223+
pass
222224

223225
# Delete network
224226
if network_id:
225-
with suppress(Exception):
227+
try: # noqa: SIM105
226228
self._adapter.delete_network(network_id)
229+
except Exception:
230+
pass
227231

228232
# ---- Runtime APIs ----
229233
def list_runtimes(self) -> list[types.Runtime]:
@@ -311,10 +315,10 @@ def train(
311315
except Exception as e:
312316
# Clean up network if initializers fail
313317
logger.error(f"Initializer failed, cleaning up network: {e}")
314-
from contextlib import suppress
315-
316-
with suppress(Exception):
318+
try: # noqa: SIM105
317319
self._adapter.delete_network(network_id)
320+
except Exception:
321+
pass
318322
raise
319323

320324
# Generate training script code (inline, not written to disk)
@@ -527,33 +531,31 @@ def _run_initializers(
527531
"""
528532
# Run dataset initializer if configured
529533
if initializer.dataset:
530-
# Get and pull dataset initializer image
531-
dataset_image = container_utils.get_initializer_image(self.cfg, "dataset")
532-
container_utils.maybe_pull_image(self._adapter, dataset_image, self.cfg.pull_policy)
534+
dataset_init = container_utils.get_dataset_initializer(initializer.dataset, self.cfg)
535+
container_utils.maybe_pull_image(
536+
self._adapter, dataset_init.image, self.cfg.pull_policy
537+
)
533538

534539
logger.debug("Running dataset initializer")
535540
self._run_single_initializer(
536541
job_name=job_name,
537-
initializer_config=initializer.dataset,
542+
container_init=dataset_init,
538543
init_type="dataset",
539-
image=dataset_image,
540544
workdir=workdir,
541545
network_id=network_id,
542546
)
543547
logger.debug("Dataset initializer completed")
544548

545549
# Run model initializer if configured
546550
if initializer.model:
547-
# Get and pull model initializer image
548-
model_image = container_utils.get_initializer_image(self.cfg, "model")
549-
container_utils.maybe_pull_image(self._adapter, model_image, self.cfg.pull_policy)
551+
model_init = container_utils.get_model_initializer(initializer.model, self.cfg)
552+
container_utils.maybe_pull_image(self._adapter, model_init.image, self.cfg.pull_policy)
550553

551554
logger.debug("Running model initializer")
552555
self._run_single_initializer(
553556
job_name=job_name,
554-
initializer_config=initializer.model,
557+
container_init=model_init,
555558
init_type="model",
556-
image=model_image,
557559
workdir=workdir,
558560
network_id=network_id,
559561
)
@@ -562,9 +564,8 @@ def _run_initializers(
562564
def _run_single_initializer(
563565
self,
564566
job_name: str,
565-
initializer_config: types.BaseInitializer,
567+
container_init: container_utils.ContainerInitializer,
566568
init_type: str,
567-
image: str,
568569
workdir: str,
569570
network_id: str,
570571
):
@@ -573,9 +574,8 @@ def _run_single_initializer(
573574
574575
Args:
575576
job_name: Name of the training job.
576-
initializer_config: Initializer configuration.
577+
container_init: ContainerInitializer with image, command, and env.
577578
init_type: Type of initializer ("dataset" or "model").
578-
image: Container image to use.
579579
workdir: Working directory path on host.
580580
network_id: Network ID for containers.
581581
@@ -584,10 +584,6 @@ def _run_single_initializer(
584584
"""
585585
container_name = f"{job_name}-{init_type}-initializer"
586586

587-
# Build command and environment
588-
command = container_utils.build_initializer_command(initializer_config, init_type)
589-
env = container_utils.build_initializer_env(initializer_config, init_type)
590-
591587
# Create labels for tracking
592588
labels = {
593589
f"{self.label_prefix}/trainjob-name": job_name,
@@ -609,11 +605,11 @@ def _run_single_initializer(
609605
# The initializer images use /app as their working directory
610606
# See: https://github.com/kubeflow/trainer/blob/master/cmd/initializers/dataset/Dockerfile
611607
container_id = self._adapter.create_and_start_container(
612-
image=image,
613-
command=command,
608+
image=container_init.image,
609+
command=container_init.command,
614610
name=container_name,
615611
network_id=network_id,
616-
environment=env,
612+
environment=container_init.env,
617613
labels=labels,
618614
volumes=volumes,
619615
working_dir="/app",

0 commit comments

Comments
 (0)