@@ -17,6 +17,7 @@ package pytorch
 
 import (
 	"fmt"
+	"strings"
 	"time"
 
 	kubebatchclient "github.com/kubernetes-sigs/kube-batch/pkg/client/clientset/versioned"
@@ -38,8 +39,10 @@ import (
 	jobinformers "github.com/kubeflow/pytorch-operator/pkg/client/informers/externalversions"
 	jobinformersv1beta2 "github.com/kubeflow/pytorch-operator/pkg/client/informers/externalversions/pytorch/v1beta2"
 	joblisters "github.com/kubeflow/pytorch-operator/pkg/client/listers/pytorch/v1beta2"
+	common "github.com/kubeflow/tf-operator/pkg/apis/common/v1beta2"
 	"github.com/kubeflow/tf-operator/pkg/common/jobcontroller"
 	pylogger "github.com/kubeflow/tf-operator/pkg/logger"
+	"github.com/kubeflow/tf-operator/pkg/util/k8sutil"
 )
 
 const (
@@ -326,18 +329,15 @@ func (pc *PyTorchController) syncPyTorchJob(key string) (bool, error) {
 	return true, err
 }
 
-func getTotalReplicas(obj metav1.Object) int32 {
-	job := obj.(*v1beta2.PyTorchJob)
-	jobReplicas := int32(0)
-	for _, r := range job.Spec.PyTorchReplicaSpecs {
-		jobReplicas += *r.Replicas
-	}
-	return jobReplicas
-}
-
 // reconcilePyTorchJobs checks and updates replicas for each given PyTorchReplicaSpec.
 // It will requeue the job in case of an error while creating/deleting pods/services.
 func (pc *PyTorchController) reconcilePyTorchJobs(job *v1beta2.PyTorchJob) error {
+	jobKey, err := KeyFunc(job)
+	if err != nil {
+		utilruntime.HandleError(fmt.Errorf("couldn't get key for pytorch job object %#v: %v", job, err))
+		return err
+	}
+
 	logger := pylogger.LoggerForJob(job)
 	logger.Infof("Reconcile PyTorchJobs %s", job.Name)
 
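Note: getTotalReplicas is dropped from this file but is still called later in the patch, so it presumably moves to another file in the package, alongside a getTotalFailedReplicas counterpart that the new backoff accounting reads. A rough, self-contained sketch of what the two helpers compute, using simplified stand-in types rather than the operator's real API structs:

package main

import "fmt"

// Simplified stand-ins for the real v1beta2 / common API types; the field
// names here are illustrative, not the operator's actual structs.
type replicaSpec struct {
	Replicas *int32
}

type replicaStatus struct {
	Active, Succeeded, Failed int32
}

type pyTorchJob struct {
	Specs    map[string]*replicaSpec
	Statuses map[string]*replicaStatus
}

// getTotalReplicas sums the desired replica counts across all replica types
// (Master, Worker, ...), mirroring the helper removed from this file.
func getTotalReplicas(job *pyTorchJob) int32 {
	total := int32(0)
	for _, r := range job.Specs {
		if r.Replicas != nil {
			total += *r.Replicas
		}
	}
	return total
}

// getTotalFailedReplicas sums the failed replica counts recorded in the job
// status, which reconcilePyTorchJobs compares against the freshly observed
// number of failed pods to detect new failures.
func getTotalFailedReplicas(job *pyTorchJob) int32 {
	failed := int32(0)
	for _, s := range job.Statuses {
		failed += s.Failed
	}
	return failed
}

func main() {
	one, two := int32(1), int32(2)
	job := &pyTorchJob{
		Specs:    map[string]*replicaSpec{"Master": {Replicas: &one}, "Worker": {Replicas: &two}},
		Statuses: map[string]*replicaStatus{"Worker": {Failed: 1}},
	}
	fmt.Println(getTotalReplicas(job), getTotalFailedReplicas(job)) // 3 1
}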
@@ -355,8 +355,46 @@ func (pc *PyTorchController) reconcilePyTorchJobs(job *v1beta2.PyTorchJob) error
 		return err
 	}
 
+	// retrieve the previous number of retry
+	previousRetry := pc.WorkQueue.NumRequeues(jobKey)
+
+	activePods := k8sutil.FilterActivePods(pods)
+	active := int32(len(activePods))
+	failed := int32(k8sutil.FilterPods(pods, v1.PodFailed))
+	totalReplicas := getTotalReplicas(job)
+	prevReplicasFailedNum := getTotalFailedReplicas(job)
+
+	var failureMessage string
+	jobExceedsLimit := false
+	exceedsBackoffLimit := false
+	pastBackoffLimit := false
+
+	if job.Spec.BackoffLimit != nil {
+		jobHasNewFailure := failed > prevReplicasFailedNum
+		// new failures happen when status does not reflect the failures and active
+		// is different than parallelism, otherwise the previous controller loop
+		// failed updating status so even if we pick up failure it is not a new one
+		exceedsBackoffLimit = jobHasNewFailure && (active != totalReplicas) &&
+			(int32(previousRetry)+1 > *job.Spec.BackoffLimit)
+
+		pastBackoffLimit, err = pc.pastBackoffLimit(job, pods)
+		if err != nil {
+			return err
+		}
+	}
+
+	if exceedsBackoffLimit || pastBackoffLimit {
+		// check if the number of pod restart exceeds backoff (for restart OnFailure only)
+		// OR if the number of failed jobs increased since the last syncJob
+		jobExceedsLimit = true
+		failureMessage = fmt.Sprintf("PyTorchJob %s has failed because it has reached the specified backoff limit", job.Name)
+	} else if pc.pastActiveDeadline(job) {
+		failureMessage = fmt.Sprintf("PyTorchJob %s has failed because it was active longer than specified deadline", job.Name)
+		jobExceedsLimit = true
+	}
+
 	// If the PyTorchJob is terminated, delete all pods and services.
-	if isSucceeded(job.Status) || isFailed(job.Status) {
+	if isSucceeded(job.Status) || isFailed(job.Status) || jobExceedsLimit {
 		if err := pc.deletePodsAndServices(job, pods); err != nil {
 			return err
 		}
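Note: the core of the hunk above is the retry accounting. A job is considered over its limit when a newly observed pod failure, combined with the work queue's requeue count, would push past spec.BackoffLimit. The sketch below distills that predicate into a standalone function with concrete numbers; the parameter names are illustrative and do not come from the operator's code:

package main

import "fmt"

// exceedsBackoffLimit mirrors the predicate built in reconcilePyTorchJobs:
// only a *new* failure counts (failed pods observed now exceed what the status
// already recorded), only while not all replicas are active, and only if the
// next retry would go past the configured backoff limit.
func exceedsBackoffLimit(failed, prevFailed, active, totalReplicas int32, previousRetry int, backoffLimit *int32) bool {
	if backoffLimit == nil {
		return false
	}
	jobHasNewFailure := failed > prevFailed
	return jobHasNewFailure && active != totalReplicas &&
		int32(previousRetry)+1 > *backoffLimit
}

func main() {
	limit := int32(3)
	// 1 newly failed worker, 2 of 3 replicas active, already requeued 3 times:
	// the 4th retry would exceed backoffLimit=3, so the job is marked failed.
	fmt.Println(exceedsBackoffLimit(1, 0, 2, 3, 3, &limit)) // true
	// Same situation but only 1 prior requeue: keep retrying.
	fmt.Println(exceedsBackoffLimit(1, 0, 2, 3, 1, &limit)) // false
}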
@@ -375,7 +413,18 @@ func (pc *PyTorchController) reconcilePyTorchJobs(job *v1beta2.PyTorchJob) error
 
 			}
 		}
-
+		if jobExceedsLimit {
+			pc.Recorder.Event(job, v1.EventTypeNormal, pytorchJobFailedReason, failureMessage)
+			if job.Status.CompletionTime == nil {
+				now := metav1.Now()
+				job.Status.CompletionTime = &now
+			}
+			err := updatePyTorchJobConditions(job, common.JobFailed, pytorchJobFailedReason, failureMessage)
+			if err != nil {
+				logger.Infof("Append pytorchjob condition error: %v", err)
+				return err
+			}
+		}
 		// At this point the pods may have been deleted, so if the job succeeded, we need to manually set the replica status.
 		// If any replicas are still Active, set their status to succeeded.
 		if isSucceeded(job.Status) {
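Note: when jobExceedsLimit is set, the hunk above records an event, stamps CompletionTime, and appends a JobFailed condition via updatePyTorchJobConditions. Below is a minimal sketch of what appending or updating such a condition plausibly looks like, using a trimmed-down stand-in for the common/v1beta2 JobCondition type; the reason string is assumed, since the value of pytorchJobFailedReason is not shown in this diff:

package main

import (
	"fmt"
	"time"
)

// A trimmed-down stand-in for the JobCondition type; the real
// updatePyTorchJobConditions works on the job's Status.Conditions slice.
type jobCondition struct {
	Type               string
	Status             string
	Reason             string
	Message            string
	LastTransitionTime time.Time
}

// setCondition appends a condition of the given type, or updates it in place
// if one already exists, roughly what updatePyTorchJobConditions does when
// reconcilePyTorchJobs marks the job Failed after exceeding its limits.
func setCondition(conds []jobCondition, c jobCondition) []jobCondition {
	for i := range conds {
		if conds[i].Type == c.Type {
			conds[i] = c
			return conds
		}
	}
	return append(conds, c)
}

func main() {
	conds := setCondition(nil, jobCondition{
		Type:               "Failed",
		Status:             "True",
		Reason:             "PyTorchJobFailed", // assumed value of pytorchJobFailedReason
		Message:            "PyTorchJob example has failed because it has reached the specified backoff limit",
		LastTransitionTime: time.Now(),
	})
	fmt.Printf("%+v\n", conds[0])
}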
@@ -434,6 +483,59 @@ func (pc *PyTorchController) satisfiedExpectations(job *v1beta2.PyTorchJob) bool
 	return satisfied
 }
 
+// pastBackoffLimitOnFailure checks if container restartCounts sum exceeds BackoffLimit
+// this method applies only to pods with restartPolicy == OnFailure or Always
+func (pc *PyTorchController) pastBackoffLimit(job *v1beta2.PyTorchJob, pods []*v1.Pod) (bool, error) {
+	if job.Spec.BackoffLimit == nil {
+		return false, nil
+	}
+	logger := pylogger.LoggerForJob(job)
+	result := int32(0)
+	for rtype, spec := range job.Spec.PyTorchReplicaSpecs {
+		if spec.RestartPolicy != common.RestartPolicyOnFailure && spec.RestartPolicy != common.RestartPolicyAlways {
+			logger.Warnf("The restart policy of replica %v of the job %v is not OnFailure or Always. Not counted in backoff limit.", rtype, job.Name)
+			continue
+		}
+		// Convert PyTorchReplicaType to lower string.
+		rt := strings.ToLower(string(rtype))
+		pods, err := pc.FilterPodsForReplicaType(pods, rt)
+		if err != nil {
+			return false, err
+		}
+		for i := range pods {
+			po := pods[i]
+			if po.Status.Phase != v1.PodRunning {
+				continue
+			}
+			for j := range po.Status.InitContainerStatuses {
+				stat := po.Status.InitContainerStatuses[j]
+				result += stat.RestartCount
+			}
+			for j := range po.Status.ContainerStatuses {
+				stat := po.Status.ContainerStatuses[j]
+				result += stat.RestartCount
+			}
+		}
+	}
+
+	if *job.Spec.BackoffLimit == 0 {
+		return result > 0, nil
+	}
+	return result >= *job.Spec.BackoffLimit, nil
+}
+
+// pastActiveDeadline checks if job has ActiveDeadlineSeconds field set and if it is exceeded.
+func (pc *PyTorchController) pastActiveDeadline(job *v1beta2.PyTorchJob) bool {
+	if job.Spec.ActiveDeadlineSeconds == nil || job.Status.StartTime == nil {
+		return false
+	}
+	now := metav1.Now()
+	start := job.Status.StartTime.Time
+	duration := now.Time.Sub(start)
+	allowedDuration := time.Duration(*job.Spec.ActiveDeadlineSeconds) * time.Second
+	return duration >= allowedDuration
+}
+
 func (pc *PyTorchController) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) {
 	return pc.getPyTorchJobFromName(namespace, name)
 }
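Note: pastActiveDeadline is a plain duration comparison against Status.StartTime. A tiny self-contained sketch with assumed values, showing how the cutoff behaves:

package main

import (
	"fmt"
	"time"
)

// pastActiveDeadline distilled to its core: the job has run longer than
// ActiveDeadlineSeconds allows once now - startTime >= the configured limit.
func pastActiveDeadline(startTime time.Time, activeDeadlineSeconds *int64, now time.Time) bool {
	if activeDeadlineSeconds == nil || startTime.IsZero() {
		return false
	}
	allowed := time.Duration(*activeDeadlineSeconds) * time.Second
	return now.Sub(startTime) >= allowed
}

func main() {
	deadline := int64(600) // 10 minutes
	start := time.Now().Add(-15 * time.Minute)
	fmt.Println(pastActiveDeadline(start, &deadline, time.Now())) // true: running 15m against a 10m deadline
	fmt.Println(pastActiveDeadline(start, nil, time.Now()))       // false: no deadline set
}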