@@ -30,11 +30,11 @@ import (
3030 corev1 "k8s.io/api/core/v1"
3131 "k8s.io/apimachinery/pkg/api/meta"
3232 "k8s.io/apimachinery/pkg/labels"
33- corev1ac "k8s.io/client-go/applyconfigurations/core/v1"
3433 ctrl "sigs.k8s.io/controller-runtime"
3534 "sigs.k8s.io/controller-runtime/pkg/client"
3635
3736 trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
37+ pkgconstants "github.com/kubeflow/trainer/v2/pkg/constants"
3838 "github.com/kubeflow/trainer/v2/pkg/rhai/constants"
3939)
4040
@@ -411,10 +411,11 @@ func PollAndUpdateProgress(ctx context.Context, c client.Client, reader client.R
411411 annotationStatus := ToAnnotationStatus (status )
412412
413413 // Use Patch to avoid conflicts with main controller
414- patch := client . MergeFrom ( trainJob .DeepCopy () )
414+ oldTrainJob := trainJob .DeepCopy ()
415415 if err := UpdateTrainerStatusAnnotation (trainJob , annotationStatus ); err != nil {
416416 return false , fmt .Errorf ("failed to update trainer status annotation: %w" , err )
417417 }
418+ patch := client .MergeFrom (oldTrainJob )
418419 if err := c .Patch (ctx , trainJob , patch ); err != nil {
419420 return false , fmt .Errorf ("failed to patch TrainJob annotations: %w" , err )
420421 }
@@ -437,14 +438,11 @@ func IsFinalStatusCaptured(trainJob *trainer.TrainJob) bool {
437438 return false
438439 }
439440
440- // Consider final status captured if:
441- // 1. We have a progress percentage (any value, including <100% for epoch-based training)
442- // 2. Estimated remaining time is 0 (indicates training has ended)
441+ // Final status is captured if: progress percentage exists AND (remaining time is 0 OR grace period expired)
443442 if status .ProgressPercentage == nil {
444443 return false
445444 }
446445
447- // Check if remaining time is explicitly 0 or summary indicates completion
448446 hasZeroRemaining := status .EstimatedRemainingSeconds != nil && * status .EstimatedRemainingSeconds == 0
449447 hasCompleteSummary := status .EstimatedRemainingTimeSummary == "complete" ||
450448 status .EstimatedRemainingTimeSummary == "0 seconds"
@@ -453,17 +451,14 @@ func IsFinalStatusCaptured(trainJob *trainer.TrainJob) bool {
453451 return true
454452 }
455453
456- // Also consider captured if job completed/failed AND preStop window has expired
457- // This handles cases where on_train_end() never fired (pod killed, crash, etc.)
458- // Check if we're past the preStop hook duration since last update
454+ // Consider captured if sufficient time has passed since last update (handles pod termination/crash)
459455 if status .LastUpdatedTime != "" {
460456 lastUpdate , err := time .Parse (time .RFC3339 , status .LastUpdatedTime )
461457 if err == nil {
462458 pollInterval := GetMetricsPollInterval (trainJob )
463- preStopDuration := pollInterval * 2 + time . Duration ( constants . PreStopBufferSecs ) * time . Second
464- gracePeriod := preStopDuration + time .Duration (constants .TerminationGraceBufferSecs )* time .Second
459+ // Grace period = 3 polling cycles + buffer (allows multiple retry attempts)
460+ gracePeriod := pollInterval * 3 + time .Duration (constants .TerminationGraceBufferSecs )* time .Second
465461
466- // If it's been longer than preStop + grace period, pod is definitely gone
467462 if time .Since (lastUpdate ) > gracePeriod {
468463 return true // Stop trying, preserve last known state
469464 }
@@ -473,64 +468,175 @@ func IsFinalStatusCaptured(trainJob *trainer.TrainJob) bool {
473468 return false
474469}
475470
471+ // CaptureMetricsFromTerminationMessage reads final metrics from pod's termination message.
472+ // Called when job completes/fails to capture final state written
473+ func CaptureMetricsFromTerminationMessage (ctx context.Context , pod * corev1.Pod ) (* AnnotationStatus , error ) {
474+ if pod == nil {
475+ return nil , fmt .Errorf ("pod is nil" )
476+ }
477+
478+ // Look for the trainer container (typically named "node")
479+ for _ , containerStatus := range pod .Status .ContainerStatuses {
480+ // Check if this is the trainer container
481+ if containerStatus .Name != pkgconstants .Node {
482+ continue
483+ }
484+
485+ // Check if container has terminated
486+ if containerStatus .State .Terminated == nil {
487+ continue
488+ }
489+
490+ // Read termination message
491+ message := containerStatus .State .Terminated .Message
492+ if message == "" {
493+ return nil , fmt .Errorf ("termination message is empty for container %s" , containerStatus .Name )
494+ }
495+
496+ // Parse JSON from termination message
497+ var status AnnotationStatus
498+ if err := json .Unmarshal ([]byte (message ), & status ); err != nil {
499+ return nil , fmt .Errorf ("failed to parse termination message JSON: %w" , err )
500+ }
501+
502+ // Validate critical fields
503+ if status .ProgressPercentage == nil {
504+ return nil , fmt .Errorf ("termination message missing progressPercentage" )
505+ }
506+
507+ status .LastUpdatedTime = time .Now ().UTC ().Format (time .RFC3339 )
508+
509+ return & status , nil
510+ }
511+
512+ return nil , fmt .Errorf ("no terminated trainer container found in pod" )
513+ }
514+
476515func PollAndUpdateFinalProgress (ctx context.Context , c client.Client , reader client.Reader , trainJob * trainer.TrainJob , completed bool ) (bool , error ) {
477516 if ! IsProgressionTrackingEnabled (trainJob ) {
478517 return false , nil
479518 }
480519
481- // Try to get final metrics from pod if it still exists
482- pod , err := GetPrimaryPod (ctx , reader , trainJob )
483- if err == nil {
484- metricsPort := GetMetricsPort (trainJob )
485- if status , pollErr := PollTrainingProgress (ctx , pod , metricsPort ); pollErr == nil {
486- // Got real metrics from pod - update with final status
487- annotationStatus := ToAnnotationStatus (status )
488- annotationStatus .LastUpdatedTime = time .Now ().UTC ().Format (time .RFC3339 )
520+ log := ctrl .Log .WithName ("progression" ).WithValues ("trainjob" , trainJob .Name , "namespace" , trainJob .Namespace )
521+
522+ // Get pod (including terminated pods for termination message reading)
523+ // For final progress, we need to check terminated pods, not just running ones
524+ podList := & corev1.PodList {}
525+ if err := reader .List (ctx , podList ,
526+ client .InNamespace (trainJob .Namespace ),
527+ client.MatchingLabels {"jobset.sigs.k8s.io/jobset-name" : trainJob .Name }); err != nil {
528+ log .V (1 ).Info ("Failed to list pods" , "error" , err )
529+ // Fall through to updateFinalStatus
530+ } else if len (podList .Items ) > 0 {
531+ // Find the primary pod (prefer index-0 pods, or use the first one)
532+ var pod * corev1.Pod
533+ for i := range podList .Items {
534+ p := & podList .Items [i ]
535+ // Look for index-0 pod (primary in most frameworks)
536+ if labels := p .Labels ; labels != nil {
537+ if idx := labels ["jobset.sigs.k8s.io/job-index" ]; idx == "0" {
538+ pod = p
539+ break
540+ }
541+ if idx := labels ["training.kubeflow.org/replica-index" ]; idx == "0" {
542+ pod = p
543+ break
544+ }
545+ }
546+ }
547+ if pod == nil {
548+ pod = & podList .Items [0 ] // Fallback to first pod
549+ }
550+
551+ // Priority 1: Try termination message (most authoritative for final state)
552+ if terminationStatus , termErr := CaptureMetricsFromTerminationMessage (ctx , pod ); termErr == nil {
553+ log .Info ("Captured final metrics from termination message" ,
554+ "progress" , * terminationStatus .ProgressPercentage )
489555
490556 // Add descriptive summary
491557 if completed {
492- // Detect early stop: currentStep < totalSteps
493558 earlyStop := false
494- if annotationStatus .CurrentStep != nil && annotationStatus .TotalSteps != nil && * annotationStatus .TotalSteps > 0 {
495- if * annotationStatus .CurrentStep < * annotationStatus .TotalSteps {
559+ if terminationStatus .CurrentStep != nil && terminationStatus .TotalSteps != nil && * terminationStatus .TotalSteps > 0 {
560+ if * terminationStatus .CurrentStep < * terminationStatus .TotalSteps {
496561 earlyStop = true
497562 }
498563 }
499564 if earlyStop {
500- annotationStatus .EstimatedRemainingTimeSummary = "complete (early stopped)"
565+ terminationStatus .EstimatedRemainingTimeSummary = "complete (early stopped)"
501566 } else {
502- annotationStatus .EstimatedRemainingTimeSummary = "complete"
567+ terminationStatus .EstimatedRemainingTimeSummary = "complete"
503568 }
504569 } else {
505- // For failed jobs: show progress context in summary
506570 progressPct := 0
507- if annotationStatus .ProgressPercentage != nil {
508- progressPct = * annotationStatus .ProgressPercentage
571+ if terminationStatus .ProgressPercentage != nil {
572+ progressPct = * terminationStatus .ProgressPercentage
509573 }
510- annotationStatus .EstimatedRemainingTimeSummary = fmt .Sprintf ("failed at %d%%" , progressPct )
574+ terminationStatus .EstimatedRemainingTimeSummary = fmt .Sprintf ("failed at %d%%" , progressPct )
511575 }
512576
513- // Use Patch to avoid conflicts with main controller
514- patch := client . MergeFrom ( trainJob .DeepCopy () )
515- if err := UpdateTrainerStatusAnnotation (trainJob , annotationStatus ); err != nil {
577+ // Update annotation with termination message data
578+ oldTrainJob := trainJob .DeepCopy ()
579+ if err := UpdateTrainerStatusAnnotation (trainJob , terminationStatus ); err != nil {
516580 return false , fmt .Errorf ("failed to update trainer status annotation: %w" , err )
517581 }
582+ patch := client .MergeFrom (oldTrainJob )
518583 if err := c .Patch (ctx , trainJob , patch ); err != nil {
519584 return false , fmt .Errorf ("failed to patch TrainJob annotations: %w" , err )
520585 }
521586
522587 return true , nil
588+ } else {
589+ log .V (1 ).Info ("Termination message not available, trying HTTP polling" , "error" , termErr )
590+ }
591+
592+ // Priority 2: Try HTTP polling (only if pod is still running with IP)
593+ if pod .Status .Phase == corev1 .PodRunning && pod .Status .PodIP != "" {
594+ metricsPort := GetMetricsPort (trainJob )
595+ if status , pollErr := PollTrainingProgress (ctx , pod , metricsPort ); pollErr == nil {
596+ annotationStatus := ToAnnotationStatus (status )
597+ annotationStatus .LastUpdatedTime = time .Now ().UTC ().Format (time .RFC3339 )
598+
599+ // Add descriptive summary
600+ if completed {
601+ earlyStop := false
602+ if annotationStatus .CurrentStep != nil && annotationStatus .TotalSteps != nil && * annotationStatus .TotalSteps > 0 {
603+ if * annotationStatus .CurrentStep < * annotationStatus .TotalSteps {
604+ earlyStop = true
605+ }
606+ }
607+ if earlyStop {
608+ annotationStatus .EstimatedRemainingTimeSummary = "complete (early stopped)"
609+ } else {
610+ annotationStatus .EstimatedRemainingTimeSummary = "complete"
611+ }
612+ } else {
613+ progressPct := 0
614+ if annotationStatus .ProgressPercentage != nil {
615+ progressPct = * annotationStatus .ProgressPercentage
616+ }
617+ annotationStatus .EstimatedRemainingTimeSummary = fmt .Sprintf ("failed at %d%%" , progressPct )
618+ }
619+
620+ oldTrainJob := trainJob .DeepCopy ()
621+ if err := UpdateTrainerStatusAnnotation (trainJob , annotationStatus ); err != nil {
622+ return false , fmt .Errorf ("failed to update trainer status annotation: %w" , err )
623+ }
624+ patch := client .MergeFrom (oldTrainJob )
625+ if err := c .Patch (ctx , trainJob , patch ); err != nil {
626+ return false , fmt .Errorf ("failed to patch TrainJob annotations: %w" , err )
627+ }
628+
629+ return true , nil
630+ }
523631 }
524632 }
525633
526- // Pod not available - update final status using existing metrics
527- // For completed: force remaining time to 0 (no work remains)
528- // For failed: keep remaining time estimate (useful for resume)
529- // Use Patch to avoid conflicts with main controller
530- patch := client .MergeFrom (trainJob .DeepCopy ())
634+ // Priority 3: Update final status using existing metrics (pod not available or both methods failed)
635+ oldTrainJob := trainJob .DeepCopy ()
531636 if err := updateFinalStatus (trainJob , completed ); err != nil {
532637 return false , fmt .Errorf ("failed to update final status: %w" , err )
533638 }
639+ patch := client .MergeFrom (oldTrainJob )
534640 if err := c .Patch (ctx , trainJob , patch ); err != nil {
535641 return false , fmt .Errorf ("failed to patch TrainJob annotations: %w" , err )
536642 }
@@ -580,99 +686,6 @@ func updateFinalStatus(trainJob *trainer.TrainJob, completed bool) error {
580686 return UpdateTrainerStatusAnnotation (trainJob , & status )
581687}
582688
583- // InjectPreStopHookToApplyConfig adds a preStop lifecycle hook to the primary pod container
584- // using Apply Configuration. This is used by the jobset plugin to inject the hook during pod creation.
585- func InjectPreStopHookToApplyConfig (podSpecAC * corev1ac.PodSpecApplyConfiguration , trainJob * trainer.TrainJob ) error {
586- if ! IsProgressionTrackingEnabled (trainJob ) || podSpecAC == nil {
587- return nil
588- }
589-
590- if len (podSpecAC .Containers ) == 0 {
591- return fmt .Errorf ("no containers in pod spec" )
592- }
593-
594- // Calculate preStop duration based on poll interval
595- pollInterval := GetMetricsPollInterval (trainJob )
596- preStopDuration := pollInterval * 2 + time .Duration (constants .PreStopBufferSecs )* time .Second
597- preStopSleep := int (preStopDuration .Seconds ())
598-
599- // Termination grace must be greater than preStop duration
600- terminationGrace := int64 ((preStopDuration + time .Duration (constants .TerminationGraceBufferSecs )* time .Second ).Seconds ())
601-
602- // Find the primary trainer container by name (typically "node")
603- containerIdx := - 1
604- for i , container := range podSpecAC .Containers {
605- if container .Name != nil && * container .Name == "node" {
606- containerIdx = i
607- break
608- }
609- }
610-
611- // Fallback to first container if "node" not found
612- if containerIdx == - 1 {
613- containerIdx = 0
614- }
615-
616- // Inject preStop hook into the target container
617- lifecycle := corev1ac .Lifecycle ().
618- WithPreStop (corev1ac .LifecycleHandler ().
619- WithExec (corev1ac .ExecAction ().
620- WithCommand ("sleep" , strconv .Itoa (preStopSleep ))))
621-
622- podSpecAC .Containers [containerIdx ].WithLifecycle (lifecycle )
623-
624- // Set termination grace period (use max of existing and calculated)
625- if podSpecAC .TerminationGracePeriodSeconds == nil ||
626- * podSpecAC .TerminationGracePeriodSeconds < terminationGrace {
627- podSpecAC .WithTerminationGracePeriodSeconds (terminationGrace )
628- }
629-
630- return nil
631- }
632-
633- // InjectPreStopHook adds a preStop lifecycle hook to the primary pod container.
634- // The hook keeps the metrics server alive after training completes, ensuring
635- // the controller can capture final metrics before pod termination.
636- //
637- // PreStop duration is calculated as: (2 × poll_interval) + buffer
638- // This guarantees at least 2 poll opportunities after training completion.
639- func InjectPreStopHook (podSpec * corev1.PodSpec , trainJob * trainer.TrainJob ) error {
640- if ! IsProgressionTrackingEnabled (trainJob ) {
641- return nil
642- }
643-
644- if len (podSpec .Containers ) == 0 {
645- return fmt .Errorf ("no containers in pod spec" )
646- }
647-
648- // Inject into primary container (index 0)
649- container := & podSpec .Containers [0 ]
650-
651- // Initialize lifecycle if nil
652- if container .Lifecycle == nil {
653- container .Lifecycle = & corev1.Lifecycle {}
654- }
655-
656- // Calculate preStop duration based on poll interval
657- pollInterval := GetMetricsPollInterval (trainJob )
658- preStopDuration := pollInterval * 2 + time .Duration (constants .PreStopBufferSecs )* time .Second
659- preStopSleep := int (preStopDuration .Seconds ())
660-
661- // Add preStop hook
662- container .Lifecycle .PreStop = & corev1.LifecycleHandler {
663- Exec : & corev1.ExecAction {
664- Command : []string {"sleep" , strconv .Itoa (preStopSleep )},
665- },
666- }
667-
668- // Set termination grace period (must be > preStop duration)
669- terminationGrace := preStopDuration + time .Duration (constants .TerminationGraceBufferSecs )* time .Second
670- terminationGraceSecs := int64 (terminationGrace .Seconds ())
671- podSpec .TerminationGracePeriodSeconds = & terminationGraceSecs
672-
673- return nil
674- }
675-
676689// ReconcileProgression handles progression tracking during TrainJob reconciliation.
677690// Returns ctrl.Result for requeue behavior and any errors encountered.
678691// This should be called at the end of TrainJob reconciliation when progression tracking is enabled.
@@ -702,20 +715,17 @@ func ReconcileProgression(ctx context.Context, c client.Client, reader client.Re
702715 }
703716
704717 if (isCompleted || isFailed ) && ! IsFinalStatusCaptured (trainJob ) {
705- // Job just completed/failed - capture final metrics
706- // PreStop hook keeps pod alive, so this should succeed
718+ // Capture final metrics (termination message + HTTP polling fallback)
707719 captured , pollErr := PollAndUpdateFinalProgress (ctx , c , reader , trainJob , isCompleted )
708720 if pollErr != nil {
709721 log .V (1 ).Info ("Failed to capture final training progress, will retry" , "error" , pollErr , "completed" , isCompleted )
710- // Requeue quickly - pod should still be alive in preStop window
711722 return ctrl.Result {RequeueAfter : 5 * time .Second }, nil
712723 }
713724 if ! captured {
714725 log .V (1 ).Info ("Pod not available for final metrics poll, will retry" , "completed" , isCompleted )
715- // Pod might be in preStop or already terminated, retry a few times
716726 return ctrl.Result {RequeueAfter : 2 * time .Second }, nil
717727 }
718- log .Info ("Captured final training progress" , "completed" , isCompleted )
728+ log .Info ("Captured final training progress from HTTP poll " , "completed" , isCompleted )
719729 }
720730
721731 return ctrl.Result {}, nil
0 commit comments