Skip to content

Refactoring run method #236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions impl/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ func (wr *workflowRunnerImpl) Run(input interface{}) (output interface{}, err er
wr.RunnerCtx.SetInput(input)
// Run tasks sequentially
wr.RunnerCtx.SetStatus(ctx.RunningStatus)
doRunner, err := NewDoTaskRunner(wr.Workflow.Do, wr)
doRunner, err := NewDoTaskRunner(wr.Workflow.Do)
if err != nil {
return nil, err
}
wr.RunnerCtx.SetStartedAt(time.Now())
output, err = doRunner.Run(wr.RunnerCtx.GetInput())
output, err = doRunner.Run(wr.RunnerCtx.GetInput(), wr)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion impl/task_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var _ TaskRunner = &ForTaskRunner{}
var _ TaskRunner = &DoTaskRunner{}

type TaskRunner interface {
Run(input interface{}) (interface{}, error)
Run(input interface{}, taskSupport TaskSupport) (interface{}, error)
GetTaskName() string
}

Expand Down
44 changes: 44 additions & 0 deletions impl/task_runner_call_http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright 2025 The Serverless Workflow Specification Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package impl

import (
"fmt"

"github.com/serverlessworkflow/sdk-go/v3/model"
)

type CallHTTPTaskRunner struct {
TaskName string
}

func NewCallHttpRunner(taskName string, task *model.CallHTTP) (taskRunner *CallHTTPTaskRunner, err error) {
if task == nil {
err = model.NewErrValidation(fmt.Errorf("invalid For task %s", taskName), taskName)
} else {
taskRunner = new(CallHTTPTaskRunner)
taskRunner.TaskName = taskName
}
return
}

func (f *CallHTTPTaskRunner) Run(input interface{}, taskSupport TaskSupport) (interface{}, error) {
return input, nil

}

func (f *CallHTTPTaskRunner) GetTaskName() string {
return f.TaskName
}
100 changes: 50 additions & 50 deletions impl/task_runner_do.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,110 +23,110 @@ import (
)

// NewTaskRunner creates a TaskRunner instance based on the task type.
func NewTaskRunner(taskName string, task model.Task, taskSupport TaskSupport) (TaskRunner, error) {
func NewTaskRunner(taskName string, task model.Task, workflowDef *model.Workflow) (TaskRunner, error) {
switch t := task.(type) {
case *model.SetTask:
return NewSetTaskRunner(taskName, t, taskSupport)
return NewSetTaskRunner(taskName, t)
case *model.RaiseTask:
return NewRaiseTaskRunner(taskName, t, taskSupport)
return NewRaiseTaskRunner(taskName, t, workflowDef)
case *model.DoTask:
return NewDoTaskRunner(t.Do, taskSupport)
return NewDoTaskRunner(t.Do)
case *model.ForTask:
return NewForTaskRunner(taskName, t, taskSupport)
return NewForTaskRunner(taskName, t)
case *model.CallHTTP:
return NewCallHttpRunner(taskName, t)
default:
return nil, fmt.Errorf("unsupported task type '%T' for task '%s'", t, taskName)
}
}

func NewDoTaskRunner(taskList *model.TaskList, taskSupport TaskSupport) (*DoTaskRunner, error) {
func NewDoTaskRunner(taskList *model.TaskList) (*DoTaskRunner, error) {
return &DoTaskRunner{
TaskList: taskList,
TaskSupport: taskSupport,
TaskList: taskList,
}, nil
}

type DoTaskRunner struct {
TaskList *model.TaskList
TaskSupport TaskSupport
TaskList *model.TaskList
}

func (d *DoTaskRunner) Run(input interface{}) (output interface{}, err error) {
func (d *DoTaskRunner) Run(input interface{}, taskSupport TaskSupport) (output interface{}, err error) {
if d.TaskList == nil {
return input, nil
}
return d.runTasks(input, d.TaskList)
return d.runTasks(input, taskSupport)
}

func (d *DoTaskRunner) GetTaskName() string {
return ""
}

// runTasks runs all defined tasks sequentially.
func (d *DoTaskRunner) runTasks(input interface{}, tasks *model.TaskList) (output interface{}, err error) {
func (d *DoTaskRunner) runTasks(input interface{}, taskSupport TaskSupport) (output interface{}, err error) {
output = input
if tasks == nil {
if d.TaskList == nil {
return output, nil
}

idx := 0
currentTask := (*tasks)[idx]
currentTask := (*d.TaskList)[idx]

for currentTask != nil {
if err = d.TaskSupport.SetTaskDef(currentTask); err != nil {
if err = taskSupport.SetTaskDef(currentTask); err != nil {
return nil, err
}
if err = d.TaskSupport.SetTaskReferenceFromName(currentTask.Key); err != nil {
if err = taskSupport.SetTaskReferenceFromName(currentTask.Key); err != nil {
return nil, err
}

if shouldRun, err := d.shouldRunTask(input, currentTask); err != nil {
if shouldRun, err := d.shouldRunTask(input, taskSupport, currentTask); err != nil {
return output, err
} else if !shouldRun {
idx, currentTask = tasks.Next(idx)
idx, currentTask = d.TaskList.Next(idx)
continue
}

d.TaskSupport.SetTaskStatus(currentTask.Key, ctx.PendingStatus)
taskSupport.SetTaskStatus(currentTask.Key, ctx.PendingStatus)

// Check if this task is a SwitchTask and handle it
if switchTask, ok := currentTask.Task.(*model.SwitchTask); ok {
flowDirective, err := d.evaluateSwitchTask(input, currentTask.Key, switchTask)
flowDirective, err := d.evaluateSwitchTask(input, taskSupport, currentTask.Key, switchTask)
if err != nil {
d.TaskSupport.SetTaskStatus(currentTask.Key, ctx.FaultedStatus)
taskSupport.SetTaskStatus(currentTask.Key, ctx.FaultedStatus)
return output, err
}
d.TaskSupport.SetTaskStatus(currentTask.Key, ctx.CompletedStatus)
taskSupport.SetTaskStatus(currentTask.Key, ctx.CompletedStatus)

// Process FlowDirective: update idx/currentTask accordingly
idx, currentTask = tasks.KeyAndIndex(flowDirective.Value)
idx, currentTask = d.TaskList.KeyAndIndex(flowDirective.Value)
if currentTask == nil {
return nil, fmt.Errorf("flow directive target '%s' not found", flowDirective.Value)
}
continue
}

runner, err := NewTaskRunner(currentTask.Key, currentTask.Task, d.TaskSupport)
runner, err := NewTaskRunner(currentTask.Key, currentTask.Task, taskSupport.GetWorkflowDef())
if err != nil {
return output, err
}

d.TaskSupport.SetTaskStatus(currentTask.Key, ctx.RunningStatus)
if output, err = d.runTask(input, runner, currentTask.Task.GetBase()); err != nil {
d.TaskSupport.SetTaskStatus(currentTask.Key, ctx.FaultedStatus)
taskSupport.SetTaskStatus(currentTask.Key, ctx.RunningStatus)
if output, err = d.runTask(input, taskSupport, runner, currentTask.Task.GetBase()); err != nil {
taskSupport.SetTaskStatus(currentTask.Key, ctx.FaultedStatus)
return output, err
}

d.TaskSupport.SetTaskStatus(currentTask.Key, ctx.CompletedStatus)
taskSupport.SetTaskStatus(currentTask.Key, ctx.CompletedStatus)
input = deepCloneValue(output)
idx, currentTask = tasks.Next(idx)
idx, currentTask = d.TaskList.Next(idx)
}

return output, nil
}

func (d *DoTaskRunner) shouldRunTask(input interface{}, task *model.TaskItem) (bool, error) {
func (d *DoTaskRunner) shouldRunTask(input interface{}, taskSupport TaskSupport, task *model.TaskItem) (bool, error) {
if task.GetBase().If != nil {
output, err := traverseAndEvaluateBool(task.GetBase().If.String(), input, d.TaskSupport.GetContext())
output, err := traverseAndEvaluateBool(task.GetBase().If.String(), input, taskSupport.GetContext())
if err != nil {
return false, model.NewErrExpression(err, task.Key)
}
Expand All @@ -135,15 +135,15 @@ func (d *DoTaskRunner) shouldRunTask(input interface{}, task *model.TaskItem) (b
return true, nil
}

func (d *DoTaskRunner) evaluateSwitchTask(input interface{}, taskKey string, switchTask *model.SwitchTask) (*model.FlowDirective, error) {
func (d *DoTaskRunner) evaluateSwitchTask(input interface{}, taskSupport TaskSupport, taskKey string, switchTask *model.SwitchTask) (*model.FlowDirective, error) {
var defaultThen *model.FlowDirective
for _, switchItem := range switchTask.Switch {
for _, switchCase := range switchItem {
if switchCase.When == nil {
defaultThen = switchCase.Then
continue
}
result, err := traverseAndEvaluateBool(model.NormalizeExpr(switchCase.When.String()), input, d.TaskSupport.GetContext())
result, err := traverseAndEvaluateBool(model.NormalizeExpr(switchCase.When.String()), input, taskSupport.GetContext())
if err != nil {
return nil, model.NewErrExpression(err, taskKey)
}
Expand All @@ -162,39 +162,39 @@ func (d *DoTaskRunner) evaluateSwitchTask(input interface{}, taskKey string, swi
}

// runTask executes an individual task.
func (d *DoTaskRunner) runTask(input interface{}, runner TaskRunner, task *model.TaskBase) (output interface{}, err error) {
func (d *DoTaskRunner) runTask(input interface{}, taskSupport TaskSupport, runner TaskRunner, task *model.TaskBase) (output interface{}, err error) {
taskName := runner.GetTaskName()

d.TaskSupport.SetTaskStartedAt(time.Now())
d.TaskSupport.SetTaskRawInput(input)
d.TaskSupport.SetTaskName(taskName)
taskSupport.SetTaskStartedAt(time.Now())
taskSupport.SetTaskRawInput(input)
taskSupport.SetTaskName(taskName)

if task.Input != nil {
if input, err = d.processTaskInput(task, input, taskName); err != nil {
if input, err = d.processTaskInput(task, input, taskSupport, taskName); err != nil {
return nil, err
}
}

output, err = runner.Run(input)
output, err = runner.Run(input, taskSupport)
if err != nil {
return nil, err
}

d.TaskSupport.SetTaskRawOutput(output)
taskSupport.SetTaskRawOutput(output)

if output, err = d.processTaskOutput(task, output, taskName); err != nil {
if output, err = d.processTaskOutput(task, output, taskSupport, taskName); err != nil {
return nil, err
}

if err = d.processTaskExport(task, output, taskName); err != nil {
if err = d.processTaskExport(task, output, taskSupport, taskName); err != nil {
return nil, err
}

return output, nil
}

// processTaskInput processes task input validation and transformation.
func (d *DoTaskRunner) processTaskInput(task *model.TaskBase, taskInput interface{}, taskName string) (output interface{}, err error) {
func (d *DoTaskRunner) processTaskInput(task *model.TaskBase, taskInput interface{}, taskSupport TaskSupport, taskName string) (output interface{}, err error) {
if task.Input == nil {
return taskInput, nil
}
Expand All @@ -203,20 +203,20 @@ func (d *DoTaskRunner) processTaskInput(task *model.TaskBase, taskInput interfac
return nil, err
}

if output, err = traverseAndEvaluate(task.Input.From, taskInput, taskName, d.TaskSupport.GetContext()); err != nil {
if output, err = traverseAndEvaluate(task.Input.From, taskInput, taskName, taskSupport.GetContext()); err != nil {
return nil, err
}

return output, nil
}

// processTaskOutput processes task output validation and transformation.
func (d *DoTaskRunner) processTaskOutput(task *model.TaskBase, taskOutput interface{}, taskName string) (output interface{}, err error) {
func (d *DoTaskRunner) processTaskOutput(task *model.TaskBase, taskOutput interface{}, taskSupport TaskSupport, taskName string) (output interface{}, err error) {
if task.Output == nil {
return taskOutput, nil
}

if output, err = traverseAndEvaluate(task.Output.As, taskOutput, taskName, d.TaskSupport.GetContext()); err != nil {
if output, err = traverseAndEvaluate(task.Output.As, taskOutput, taskName, taskSupport.GetContext()); err != nil {
return nil, err
}

Expand All @@ -227,12 +227,12 @@ func (d *DoTaskRunner) processTaskOutput(task *model.TaskBase, taskOutput interf
return output, nil
}

func (d *DoTaskRunner) processTaskExport(task *model.TaskBase, taskOutput interface{}, taskName string) (err error) {
func (d *DoTaskRunner) processTaskExport(task *model.TaskBase, taskOutput interface{}, taskSupport TaskSupport, taskName string) (err error) {
if task.Export == nil {
return nil
}

output, err := traverseAndEvaluate(task.Export.As, taskOutput, taskName, d.TaskSupport.GetContext())
output, err := traverseAndEvaluate(task.Export.As, taskOutput, taskName, taskSupport.GetContext())
if err != nil {
return err
}
Expand All @@ -241,7 +241,7 @@ func (d *DoTaskRunner) processTaskExport(task *model.TaskBase, taskOutput interf
return nil
}

d.TaskSupport.SetWorkflowInstanceCtx(output)
taskSupport.SetWorkflowInstanceCtx(output)

return nil
}
Loading
Loading