
Commit e9add5e

Termination message capture (#33)
* feat: Add termination message capture and improve final progress logic
* test: Add unit tests for termination message capture
* refactor: use constants.Node instead of hardcoded container name

Signed-off-by: abhijeet-dhumal <abhijeetdhumal652@gmail.com>
1 parent cf3112e commit e9add5e

2 files changed: 765 additions & 136 deletions

File tree

pkg/rhai/progression/progression.go (145 additions & 135 deletions)
@@ -30,11 +30,11 @@ import (
 	corev1 "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/api/meta"
 	"k8s.io/apimachinery/pkg/labels"
-	corev1ac "k8s.io/client-go/applyconfigurations/core/v1"
 	ctrl "sigs.k8s.io/controller-runtime"
 	"sigs.k8s.io/controller-runtime/pkg/client"
 
 	trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
+	pkgconstants "github.com/kubeflow/trainer/v2/pkg/constants"
 	"github.com/kubeflow/trainer/v2/pkg/rhai/constants"
 )
 
@@ -411,10 +411,11 @@ func PollAndUpdateProgress(ctx context.Context, c client.Client, reader client.R
 	annotationStatus := ToAnnotationStatus(status)
 
 	// Use Patch to avoid conflicts with main controller
-	patch := client.MergeFrom(trainJob.DeepCopy())
+	oldTrainJob := trainJob.DeepCopy()
 	if err := UpdateTrainerStatusAnnotation(trainJob, annotationStatus); err != nil {
 		return false, fmt.Errorf("failed to update trainer status annotation: %w", err)
 	}
+	patch := client.MergeFrom(oldTrainJob)
 	if err := c.Patch(ctx, trainJob, patch); err != nil {
 		return false, fmt.Errorf("failed to patch TrainJob annotations: %w", err)
 	}
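
This hunk reorders an equivalent pattern for readability rather than changing behavior: client.MergeFrom only records the base object, and the merge patch is computed against it when Patch is called, so the DeepCopy captures the pre-mutation state either way. A minimal sketch of the snapshot-then-mutate pattern with standard controller-runtime calls (the helper name and annotation key are hypothetical):

func patchAnnotation(ctx context.Context, c client.Client, job *trainer.TrainJob, key, value string) error {
	base := job.DeepCopy() // snapshot: the patch is diffed against this state

	if job.Annotations == nil {
		job.Annotations = map[string]string{}
	}
	job.Annotations[key] = value // mutate the live object

	// Sends only the annotation delta, avoiding conflicts with
	// concurrent updates from the main controller.
	return c.Patch(ctx, job, client.MergeFrom(base))
}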
@@ -437,14 +438,11 @@ func IsFinalStatusCaptured(trainJob *trainer.TrainJob) bool {
 		return false
 	}
 
-	// Consider final status captured if:
-	// 1. We have a progress percentage (any value, including <100% for epoch-based training)
-	// 2. Estimated remaining time is 0 (indicates training has ended)
+	// Final status is captured if: progress percentage exists AND (remaining time is 0 OR grace period expired)
 	if status.ProgressPercentage == nil {
 		return false
 	}
 
-	// Check if remaining time is explicitly 0 or summary indicates completion
 	hasZeroRemaining := status.EstimatedRemainingSeconds != nil && *status.EstimatedRemainingSeconds == 0
 	hasCompleteSummary := status.EstimatedRemainingTimeSummary == "complete" ||
 		status.EstimatedRemainingTimeSummary == "0 seconds"
@@ -453,17 +451,14 @@ func IsFinalStatusCaptured(trainJob *trainer.TrainJob) bool {
 		return true
 	}
 
-	// Also consider captured if job completed/failed AND preStop window has expired
-	// This handles cases where on_train_end() never fired (pod killed, crash, etc.)
-	// Check if we're past the preStop hook duration since last update
+	// Consider captured if sufficient time has passed since last update (handles pod termination/crash)
 	if status.LastUpdatedTime != "" {
 		lastUpdate, err := time.Parse(time.RFC3339, status.LastUpdatedTime)
 		if err == nil {
 			pollInterval := GetMetricsPollInterval(trainJob)
-			preStopDuration := pollInterval*2 + time.Duration(constants.PreStopBufferSecs)*time.Second
-			gracePeriod := preStopDuration + time.Duration(constants.TerminationGraceBufferSecs)*time.Second
+			// Grace period = 3 polling cycles + buffer (allows multiple retry attempts)
+			gracePeriod := pollInterval*3 + time.Duration(constants.TerminationGraceBufferSecs)*time.Second
 
-			// If it's been longer than preStop + grace period, pod is definitely gone
 			if time.Since(lastUpdate) > gracePeriod {
 				return true // Stop trying, preserve last known state
 			}
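
To make the cutoff concrete, here is a small self-contained example with illustrative numbers; the real interval comes from GetMetricsPollInterval and the buffer from constants.TerminationGraceBufferSecs, so the 30-second and 60-second values below are assumptions:

package main

import (
	"fmt"
	"time"
)

func main() {
	pollInterval := 30 * time.Second      // assumed poll interval
	const terminationGraceBufferSecs = 60 // assumed buffer
	gracePeriod := pollInterval*3 + time.Duration(terminationGraceBufferSecs)*time.Second
	fmt.Println(gracePeriod) // 2m30s: a LastUpdatedTime older than this is treated as final
}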
@@ -473,64 +468,175 @@ func IsFinalStatusCaptured(trainJob *trainer.TrainJob) bool {
 	return false
 }
 
+// CaptureMetricsFromTerminationMessage reads final metrics from the pod's termination message.
+// Called when the job completes/fails to capture the final state written by the training process.
+func CaptureMetricsFromTerminationMessage(ctx context.Context, pod *corev1.Pod) (*AnnotationStatus, error) {
+	if pod == nil {
+		return nil, fmt.Errorf("pod is nil")
+	}
+
+	// Look for the trainer container (typically named "node")
+	for _, containerStatus := range pod.Status.ContainerStatuses {
+		// Check if this is the trainer container
+		if containerStatus.Name != pkgconstants.Node {
+			continue
+		}
+
+		// Check if container has terminated
+		if containerStatus.State.Terminated == nil {
+			continue
+		}
+
+		// Read termination message
+		message := containerStatus.State.Terminated.Message
+		if message == "" {
+			return nil, fmt.Errorf("termination message is empty for container %s", containerStatus.Name)
+		}
+
+		// Parse JSON from termination message
+		var status AnnotationStatus
+		if err := json.Unmarshal([]byte(message), &status); err != nil {
+			return nil, fmt.Errorf("failed to parse termination message JSON: %w", err)
+		}
+
+		// Validate critical fields
+		if status.ProgressPercentage == nil {
+			return nil, fmt.Errorf("termination message missing progressPercentage")
+		}
+
+		status.LastUpdatedTime = time.Now().UTC().Format(time.RFC3339)
+
+		return &status, nil
+	}
+
+	return nil, fmt.Errorf("no terminated trainer container found in pod")
+}
+
 func PollAndUpdateFinalProgress(ctx context.Context, c client.Client, reader client.Reader, trainJob *trainer.TrainJob, completed bool) (bool, error) {
 	if !IsProgressionTrackingEnabled(trainJob) {
 		return false, nil
 	}
 
-	// Try to get final metrics from pod if it still exists
-	pod, err := GetPrimaryPod(ctx, reader, trainJob)
-	if err == nil {
-		metricsPort := GetMetricsPort(trainJob)
-		if status, pollErr := PollTrainingProgress(ctx, pod, metricsPort); pollErr == nil {
-			// Got real metrics from pod - update with final status
-			annotationStatus := ToAnnotationStatus(status)
-			annotationStatus.LastUpdatedTime = time.Now().UTC().Format(time.RFC3339)
+	log := ctrl.Log.WithName("progression").WithValues("trainjob", trainJob.Name, "namespace", trainJob.Namespace)
+
+	// Get pod (including terminated pods for termination message reading).
+	// For final progress, we need to check terminated pods, not just running ones.
+	podList := &corev1.PodList{}
+	if err := reader.List(ctx, podList,
+		client.InNamespace(trainJob.Namespace),
+		client.MatchingLabels{"jobset.sigs.k8s.io/jobset-name": trainJob.Name}); err != nil {
+		log.V(1).Info("Failed to list pods", "error", err)
+		// Fall through to updateFinalStatus
+	} else if len(podList.Items) > 0 {
+		// Find the primary pod (prefer index-0 pods, or use the first one)
+		var pod *corev1.Pod
+		for i := range podList.Items {
+			p := &podList.Items[i]
+			// Look for index-0 pod (primary in most frameworks)
+			if labels := p.Labels; labels != nil {
+				if idx := labels["jobset.sigs.k8s.io/job-index"]; idx == "0" {
+					pod = p
+					break
+				}
+				if idx := labels["training.kubeflow.org/replica-index"]; idx == "0" {
+					pod = p
+					break
+				}
+			}
+		}
+		if pod == nil {
+			pod = &podList.Items[0] // Fallback to first pod
+		}
+
+		// Priority 1: Try termination message (most authoritative for final state)
+		if terminationStatus, termErr := CaptureMetricsFromTerminationMessage(ctx, pod); termErr == nil {
+			log.Info("Captured final metrics from termination message",
+				"progress", *terminationStatus.ProgressPercentage)
 
 			// Add descriptive summary
 			if completed {
-				// Detect early stop: currentStep < totalSteps
 				earlyStop := false
-				if annotationStatus.CurrentStep != nil && annotationStatus.TotalSteps != nil && *annotationStatus.TotalSteps > 0 {
-					if *annotationStatus.CurrentStep < *annotationStatus.TotalSteps {
+				if terminationStatus.CurrentStep != nil && terminationStatus.TotalSteps != nil && *terminationStatus.TotalSteps > 0 {
+					if *terminationStatus.CurrentStep < *terminationStatus.TotalSteps {
 						earlyStop = true
 					}
 				}
 				if earlyStop {
-					annotationStatus.EstimatedRemainingTimeSummary = "complete (early stopped)"
+					terminationStatus.EstimatedRemainingTimeSummary = "complete (early stopped)"
 				} else {
-					annotationStatus.EstimatedRemainingTimeSummary = "complete"
+					terminationStatus.EstimatedRemainingTimeSummary = "complete"
 				}
 			} else {
-				// For failed jobs: show progress context in summary
 				progressPct := 0
-				if annotationStatus.ProgressPercentage != nil {
-					progressPct = *annotationStatus.ProgressPercentage
+				if terminationStatus.ProgressPercentage != nil {
+					progressPct = *terminationStatus.ProgressPercentage
 				}
-				annotationStatus.EstimatedRemainingTimeSummary = fmt.Sprintf("failed at %d%%", progressPct)
+				terminationStatus.EstimatedRemainingTimeSummary = fmt.Sprintf("failed at %d%%", progressPct)
 			}
 
-			// Use Patch to avoid conflicts with main controller
-			patch := client.MergeFrom(trainJob.DeepCopy())
-			if err := UpdateTrainerStatusAnnotation(trainJob, annotationStatus); err != nil {
+			// Update annotation with termination message data
+			oldTrainJob := trainJob.DeepCopy()
+			if err := UpdateTrainerStatusAnnotation(trainJob, terminationStatus); err != nil {
 				return false, fmt.Errorf("failed to update trainer status annotation: %w", err)
 			}
+			patch := client.MergeFrom(oldTrainJob)
 			if err := c.Patch(ctx, trainJob, patch); err != nil {
 				return false, fmt.Errorf("failed to patch TrainJob annotations: %w", err)
 			}
 
 			return true, nil
+		} else {
+			log.V(1).Info("Termination message not available, trying HTTP polling", "error", termErr)
+		}
+
+		// Priority 2: Try HTTP polling (only if pod is still running with IP)
+		if pod.Status.Phase == corev1.PodRunning && pod.Status.PodIP != "" {
+			metricsPort := GetMetricsPort(trainJob)
+			if status, pollErr := PollTrainingProgress(ctx, pod, metricsPort); pollErr == nil {
+				annotationStatus := ToAnnotationStatus(status)
+				annotationStatus.LastUpdatedTime = time.Now().UTC().Format(time.RFC3339)
+
+				// Add descriptive summary
+				if completed {
+					earlyStop := false
+					if annotationStatus.CurrentStep != nil && annotationStatus.TotalSteps != nil && *annotationStatus.TotalSteps > 0 {
+						if *annotationStatus.CurrentStep < *annotationStatus.TotalSteps {
+							earlyStop = true
+						}
+					}
+					if earlyStop {
+						annotationStatus.EstimatedRemainingTimeSummary = "complete (early stopped)"
+					} else {
+						annotationStatus.EstimatedRemainingTimeSummary = "complete"
+					}
+				} else {
+					progressPct := 0
+					if annotationStatus.ProgressPercentage != nil {
+						progressPct = *annotationStatus.ProgressPercentage
+					}
+					annotationStatus.EstimatedRemainingTimeSummary = fmt.Sprintf("failed at %d%%", progressPct)
+				}
+
+				oldTrainJob := trainJob.DeepCopy()
+				if err := UpdateTrainerStatusAnnotation(trainJob, annotationStatus); err != nil {
+					return false, fmt.Errorf("failed to update trainer status annotation: %w", err)
+				}
+				patch := client.MergeFrom(oldTrainJob)
+				if err := c.Patch(ctx, trainJob, patch); err != nil {
+					return false, fmt.Errorf("failed to patch TrainJob annotations: %w", err)
+				}
+
+				return true, nil
+			}
 		}
 	}
 
-	// Pod not available - update final status using existing metrics
-	// For completed: force remaining time to 0 (no work remains)
-	// For failed: keep remaining time estimate (useful for resume)
-	// Use Patch to avoid conflicts with main controller
-	patch := client.MergeFrom(trainJob.DeepCopy())
+	// Priority 3: Update final status using existing metrics (pod not available or both methods failed)
+	oldTrainJob := trainJob.DeepCopy()
 	if err := updateFinalStatus(trainJob, completed); err != nil {
 		return false, fmt.Errorf("failed to update final status: %w", err)
 	}
+	patch := client.MergeFrom(oldTrainJob)
 	if err := c.Patch(ctx, trainJob, patch); err != nil {
 		return false, fmt.Errorf("failed to patch TrainJob annotations: %w", err)
 	}
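
The capture path above assumes the trainer container writes its final metrics as JSON to the Kubernetes termination-message file before exiting; the kubelet surfaces that file (truncated to 4096 bytes) as containerStatus.State.Terminated.Message. A hedged sketch of that producer side, assuming encoding/json and os imports; only the progressPercentage, currentStep, and totalSteps JSON tags are visible in this diff, and the helper itself is hypothetical:

// writeTerminationMessage is illustrative, not part of this commit.
func writeTerminationMessage(progressPct, currentStep, totalSteps int) error {
	msg, err := json.Marshal(map[string]any{
		"progressPercentage": progressPct, // required by CaptureMetricsFromTerminationMessage
		"currentStep":        currentStep,
		"totalSteps":         totalSteps,
	})
	if err != nil {
		return err
	}
	// /dev/termination-log is the default terminationMessagePath in Kubernetes.
	return os.WriteFile("/dev/termination-log", msg, 0o644)
}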
@@ -580,99 +686,6 @@ func updateFinalStatus(trainJob *trainer.TrainJob, completed bool) error {
 	return UpdateTrainerStatusAnnotation(trainJob, &status)
 }
 
-// InjectPreStopHookToApplyConfig adds a preStop lifecycle hook to the primary pod container
-// using Apply Configuration. This is used by the jobset plugin to inject the hook during pod creation.
-func InjectPreStopHookToApplyConfig(podSpecAC *corev1ac.PodSpecApplyConfiguration, trainJob *trainer.TrainJob) error {
-	if !IsProgressionTrackingEnabled(trainJob) || podSpecAC == nil {
-		return nil
-	}
-
-	if len(podSpecAC.Containers) == 0 {
-		return fmt.Errorf("no containers in pod spec")
-	}
-
-	// Calculate preStop duration based on poll interval
-	pollInterval := GetMetricsPollInterval(trainJob)
-	preStopDuration := pollInterval*2 + time.Duration(constants.PreStopBufferSecs)*time.Second
-	preStopSleep := int(preStopDuration.Seconds())
-
-	// Termination grace must be greater than preStop duration
-	terminationGrace := int64((preStopDuration + time.Duration(constants.TerminationGraceBufferSecs)*time.Second).Seconds())
-
-	// Find the primary trainer container by name (typically "node")
-	containerIdx := -1
-	for i, container := range podSpecAC.Containers {
-		if container.Name != nil && *container.Name == "node" {
-			containerIdx = i
-			break
-		}
-	}
-
-	// Fallback to first container if "node" not found
-	if containerIdx == -1 {
-		containerIdx = 0
-	}
-
-	// Inject preStop hook into the target container
-	lifecycle := corev1ac.Lifecycle().
-		WithPreStop(corev1ac.LifecycleHandler().
-			WithExec(corev1ac.ExecAction().
-				WithCommand("sleep", strconv.Itoa(preStopSleep))))
-
-	podSpecAC.Containers[containerIdx].WithLifecycle(lifecycle)
-
-	// Set termination grace period (use max of existing and calculated)
-	if podSpecAC.TerminationGracePeriodSeconds == nil ||
-		*podSpecAC.TerminationGracePeriodSeconds < terminationGrace {
-		podSpecAC.WithTerminationGracePeriodSeconds(terminationGrace)
-	}
-
-	return nil
-}
-
-// InjectPreStopHook adds a preStop lifecycle hook to the primary pod container.
-// The hook keeps the metrics server alive after training completes, ensuring
-// the controller can capture final metrics before pod termination.
-//
-// PreStop duration is calculated as: (2 × poll_interval) + buffer
-// This guarantees at least 2 poll opportunities after training completion.
-func InjectPreStopHook(podSpec *corev1.PodSpec, trainJob *trainer.TrainJob) error {
-	if !IsProgressionTrackingEnabled(trainJob) {
-		return nil
-	}
-
-	if len(podSpec.Containers) == 0 {
-		return fmt.Errorf("no containers in pod spec")
-	}
-
-	// Inject into primary container (index 0)
-	container := &podSpec.Containers[0]
-
-	// Initialize lifecycle if nil
-	if container.Lifecycle == nil {
-		container.Lifecycle = &corev1.Lifecycle{}
-	}
-
-	// Calculate preStop duration based on poll interval
-	pollInterval := GetMetricsPollInterval(trainJob)
-	preStopDuration := pollInterval*2 + time.Duration(constants.PreStopBufferSecs)*time.Second
-	preStopSleep := int(preStopDuration.Seconds())
-
-	// Add preStop hook
-	container.Lifecycle.PreStop = &corev1.LifecycleHandler{
-		Exec: &corev1.ExecAction{
-			Command: []string{"sleep", strconv.Itoa(preStopSleep)},
-		},
-	}
-
-	// Set termination grace period (must be > preStop duration)
-	terminationGrace := preStopDuration + time.Duration(constants.TerminationGraceBufferSecs)*time.Second
-	terminationGraceSecs := int64(terminationGrace.Seconds())
-	podSpec.TerminationGracePeriodSeconds = &terminationGraceSecs
-
-	return nil
-}
-
 // ReconcileProgression handles progression tracking during TrainJob reconciliation.
 // Returns ctrl.Result for requeue behavior and any errors encountered.
 // This should be called at the end of TrainJob reconciliation when progression tracking is enabled.
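
With the preStop helpers gone, the final-state handoff no longer depends on keeping the pod alive through a sleep; it rides on the standard Kubernetes termination-message mechanism instead. For reference, the corev1 container fields that govern where the message is read from (standard k8s.io/api/core/v1 types, shown as a sketch rather than code from this commit):

container := corev1.Container{
	Name:                     pkgconstants.Node,
	TerminationMessagePath:   "/dev/termination-log",            // the Kubernetes default
	TerminationMessagePolicy: corev1.TerminationMessageReadFile, // read the file as written
}
_ = container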
@@ -702,20 +715,17 @@ func ReconcileProgression(ctx context.Context, c client.Client, reader client.Re
 	}
 
 	if (isCompleted || isFailed) && !IsFinalStatusCaptured(trainJob) {
-		// Job just completed/failed - capture final metrics
-		// PreStop hook keeps pod alive, so this should succeed
+		// Capture final metrics (termination message + HTTP polling fallback)
 		captured, pollErr := PollAndUpdateFinalProgress(ctx, c, reader, trainJob, isCompleted)
 		if pollErr != nil {
 			log.V(1).Info("Failed to capture final training progress, will retry", "error", pollErr, "completed", isCompleted)
-			// Requeue quickly - pod should still be alive in preStop window
 			return ctrl.Result{RequeueAfter: 5 * time.Second}, nil
 		}
 		if !captured {
 			log.V(1).Info("Pod not available for final metrics poll, will retry", "completed", isCompleted)
-			// Pod might be in preStop or already terminated, retry a few times
 			return ctrl.Result{RequeueAfter: 2 * time.Second}, nil
 		}
-		log.Info("Captured final training progress", "completed", isCompleted)
+		log.Info("Captured final training progress from HTTP poll", "completed", isCompleted)
 	}
 
 	return ctrl.Result{}, nil
