Skip to content

Commit 1180e10

Browse files
authored
fix(backend): fix simple loop bug (#7578)
* support IR YAML format in API * Check the error message and return false if it is not nil * update error message * fixed simple loop but need cleaning up * Deleted debug logs * remove logs and fix some format * fix static_loop_arguments * change the driver image change the driver image back to the kfp container registry. * change variable declaration * remove logs * remove log * move `ok` definition * change test file for debug purpose * change test for debug purpose * update sample test for static loop * update test file, remove code for debug
1 parent c19facc commit 1180e10

File tree

3 files changed

+43
-13
lines changed

3 files changed

+43
-13
lines changed

backend/src/v2/driver/driver.go

+37-4
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
414414
executorInput := &pipelinespec.ExecutorInput{
415415
Inputs: inputs,
416416
}
417+
glog.Infof("executorInput value: %+v", executorInput)
417418
execution = &Execution{ExecutorInput: executorInput}
418419
condition := opts.Task.GetTriggerPolicy().GetCondition()
419420
if condition != "" {
@@ -436,14 +437,37 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
436437
return execution, fmt.Errorf("ArtifactIterator is not implemented")
437438
}
438439
isIterator := opts.Task.GetParameterIterator() != nil && opts.IterationIndex < 0
440+
// Fan out iterations
439441
if execution.WillTrigger() && isIterator {
440442
iterator := opts.Task.GetParameterIterator()
441-
value, ok := executorInput.GetInputs().GetParameterValues()[iterator.GetItems().GetInputParameter()]
442443
report := func(err error) error {
443444
return fmt.Errorf("iterating on item input %q failed: %w", iterator.GetItemInput(), err)
444445
}
445-
if !ok {
446-
return execution, report(fmt.Errorf("cannot find input parameter"))
446+
// Check the items type of parameterIterator:
447+
// It can be "inputParameter" or "Raw"
448+
var value *structpb.Value
449+
switch iterator.GetItems().GetKind().(type) {
450+
case *pipelinespec.ParameterIteratorSpec_ItemsSpec_InputParameter:
451+
var ok bool
452+
value, ok = executorInput.GetInputs().GetParameterValues()[iterator.GetItems().GetInputParameter()]
453+
if !ok {
454+
return execution, report(fmt.Errorf("cannot find input parameter"))
455+
}
456+
case *pipelinespec.ParameterIteratorSpec_ItemsSpec_Raw:
457+
value_raw := iterator.GetItems().GetRaw()
458+
var unmarshalled_raw interface{}
459+
err = json.Unmarshal([]byte(value_raw), &unmarshalled_raw)
460+
if err != nil {
461+
return execution, fmt.Errorf("error unmarshall raw string: %q", err)
462+
}
463+
value, err = structpb.NewValue(unmarshalled_raw)
464+
if err != nil {
465+
return execution, fmt.Errorf("error converting unmarshalled raw string into protobuf Value type: %q", err)
466+
}
467+
// Add the raw input to the executor input
468+
execution.ExecutorInput.Inputs.ParameterValues[iterator.GetItemInput()] = value
469+
default:
470+
return execution, fmt.Errorf("cannot find parameter iterator")
447471
}
448472
items, err := getItems(value)
449473
if err != nil {
@@ -724,7 +748,16 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
724748
case task.GetArtifactIterator() != nil:
725749
return nil, fmt.Errorf("artifact iterator not implemented yet")
726750
case task.GetParameterIterator() != nil:
727-
itemsInput := task.GetParameterIterator().GetItems().GetInputParameter()
751+
var itemsInput string
752+
if task.GetParameterIterator().GetItems().GetInputParameter() != "" {
753+
// input comes from outside the component
754+
itemsInput = task.GetParameterIterator().GetItems().GetInputParameter()
755+
} else if task.GetParameterIterator().GetItemInput() != "" {
756+
// input comes from static input
757+
itemsInput = task.GetParameterIterator().GetItemInput()
758+
} else {
759+
return nil, fmt.Errorf("cannot retrieve parameter iterator.")
760+
}
728761
items, err := getItems(inputs.ParameterValues[itemsInput])
729762
if err != nil {
730763
return nil, err

samples/core/loop_static/loop_static_test.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
import kfp_server_api
1919
from .loop_static import my_pipeline
2020
from .loop_static_v2 import my_pipeline as my_pipeline_v2
21-
from kfp.samples.test.utils import KfpTask, debug_verify, run_pipeline_func, TestCase
21+
from kfp.samples.test.utils import KfpTask, run_pipeline_func, TestCase
2222

2323

2424
def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun,
2525
tasks: dict[str, KfpTask], **kwargs):
2626
t.assertEqual(run.status, 'Succeeded')
2727
# assert DAG structure
28-
t.assertCountEqual(['print-op', 'for-loop-1'], tasks.keys())
28+
t.assertCountEqual(['print-op', 'for-loop-2'], tasks.keys())
2929
# assert all iteration parameters
3030
t.assertCountEqual(
3131
[{
@@ -37,14 +37,14 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun,
3737
}],
3838
[
3939
x.inputs
40-
.parameters['pipelinechannel--static_loop_arguments-loop-item']
41-
for x in tasks['for-loop-1'].children.values()
40+
.parameters['pipelinechannel--loop-item-param-1']
41+
for x in tasks['for-loop-2'].children.values()
4242
],
4343
)
4444
# assert all iteration outputs
4545
t.assertCountEqual(['12', '1020'], [
4646
x.children['print-op-2'].outputs.parameters['Output']
47-
for x in tasks['for-loop-1'].children.values()
47+
for x in tasks['for-loop-2'].children.values()
4848
])
4949

5050

samples/core/loop_static/loop_static_v2.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,12 @@ def concat_op(a: str, b: str) -> str:
1919
return a + b
2020

2121

22-
_DEFAULT_LOOP_ARGUMENTS = [{'a': '1', 'b': '2'}, {'a': '10', 'b': '20'}]
23-
24-
2522
@dsl.pipeline(name='pipeline-with-loop-static')
2623
def my_pipeline(
27-
static_loop_arguments: List[dict] = _DEFAULT_LOOP_ARGUMENTS,
2824
greeting: str = 'this is a test for looping through parameters',
2925
):
3026
print_task = print_op(text=greeting)
27+
static_loop_arguments = [{'a': '1', 'b': '2'}, {'a': '10', 'b': '20'}]
3128

3229
with dsl.ParallelFor(static_loop_arguments) as item:
3330
concat_task = concat_op(a=item.a, b=item.b)

0 commit comments

Comments
 (0)