Skip to content

Commit d24676a

Browse files
fix: Add retry logic with delay for fetching DAG tasks during TaskFinalStatus resolution
Signed-off-by: khushiiagrawal <[email protected]>
1 parent f087a3d commit d24676a

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

backend/src/v2/driver/resolve.go

Lines changed: 24 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ import (
2020
"errors"
2121
"fmt"
2222
"strings"
23+
"time"
2324

2425
"github.com/golang/glog"
2526
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
@@ -440,17 +441,35 @@ func resolveInputParameter(
440441
return nil, paramError(fmt.Errorf("param runtime value spec of type %T not implemented", t))
441442
}
442443
case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskFinalStatus_:
443-
tasks, err := getDAGTasks(ctx, dag, pipeline, mlmd, nil)
444-
if err != nil {
445-
return nil, err
444+
// Retry logic for obtaining DAG tasks to handle potential race conditions or delays
445+
// in MLMD propagation. We try up to 5 times with a 2-second sleep between attempts.
446+
var tasks map[string]*metadata.Execution
447+
for i := 0; i < 5; i++ {
448+
var err error
449+
tasks, err = getDAGTasks(ctx, dag, pipeline, mlmd, nil)
450+
if err != nil {
451+
return nil, err
452+
}
453+
454+
if len(opts.Task.DependentTasks) < 1 {
455+
return nil, fmt.Errorf("task %v has no dependent tasks", opts.Task.TaskInfo.GetName())
456+
}
457+
458+
producerName := metadata.GetTaskNameWithDagID(opts.Task.DependentTasks[0], dag.Execution.GetID())
459+
if _, ok := tasks[producerName]; ok {
460+
break
461+
}
462+
glog.Warningf("Unable to obtain the node for %s, taskName %s. Retrying...", producerName, opts.TaskName)
463+
time.Sleep(2 * time.Second)
446464
}
447465

448466
if len(opts.Task.DependentTasks) < 1 {
449467
return nil, fmt.Errorf("task %v has no dependent tasks", opts.Task.TaskInfo.GetName())
450468
}
451-
producer, ok := tasks[metadata.GetTaskNameWithDagID(opts.Task.DependentTasks[0], dag.Execution.GetID())]
469+
producerName := metadata.GetTaskNameWithDagID(opts.Task.DependentTasks[0], dag.Execution.GetID())
470+
producer, ok := tasks[producerName]
452471
if !ok {
453-
return nil, fmt.Errorf("producer task, %v, not in tasks", producer.TaskName())
472+
return nil, fmt.Errorf("producer task, %s, not in tasks", producerName) // Fixed potential panic
454473
}
455474
finalStatus := pipelinespec.PipelineTaskFinalStatus{
456475
State: producer.GetExecution().GetLastKnownState().String(),

0 commit comments

Comments
 (0)