Skip to content

Commit 2a582f8

Browse files
author
Sebastian Neira
authored
Add mock call assertions to TestWorkflowEnvironment (#748)
1 parent 8608a59 commit 2a582f8

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

internal/workflow_testsuite.go

+24
Original file line numberDiff line numberDiff line change
@@ -814,3 +814,27 @@ func (e *TestWorkflowEnvironment) SetSearchAttributesOnStart(searchAttributes ma
814814
func (e *TestWorkflowEnvironment) AssertExpectations(t mock.TestingT) bool {
815815
return e.mock.AssertExpectations(t)
816816
}
817+
818+
// AssertCalled asserts that the method was called with the supplied arguments.
819+
// Useful to assert that an Activity was called from within a workflow with the expected arguments.
820+
// Since the first argument is a context, consider using mock.Anything for that argument.
821+
//
822+
// env.OnActivity(namedActivity, mock.Anything, mock.Anything).Return("mock_result", nil)
823+
// env.ExecuteWorkflow(workflowThatCallsActivityWithItsArgument, "Hello")
824+
// env.AssertCalled(t, "namedActivity", mock.Anything, "Hello")
825+
//
826+
// It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
827+
func (e *TestWorkflowEnvironment) AssertCalled(t mock.TestingT, methodName string, arguments ...interface{}) bool {
828+
return e.mock.AssertCalled(t, methodName, arguments...)
829+
}
830+
831+
// AssertNotCalled asserts that the method was not called with the given arguments.
832+
// See AssertCalled for more info.
833+
func (e *TestWorkflowEnvironment) AssertNotCalled(t mock.TestingT, methodName string, arguments ...interface{}) bool {
834+
return e.mock.AssertNotCalled(t, methodName, arguments...)
835+
}
836+
837+
// AssertNumberOfCalls asserts that a method was called expectedCalls times.
838+
func (e *TestWorkflowEnvironment) AssertNumberOfCalls(t mock.TestingT, methodName string, expectedCalls int) bool {
839+
return e.mock.AssertNumberOfCalls(t, methodName, expectedCalls)
840+
}

internal/workflow_testsuite_test.go

+55
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"testing"
3232
"time"
3333

34+
"github.com/stretchr/testify/mock"
3435
"github.com/stretchr/testify/require"
3536
)
3637

@@ -236,3 +237,57 @@ func TestWorkflowStartTimeInsideTestWorkflow(t *testing.T) {
236237
require.NoError(t, env.GetWorkflowResult(&timestamp))
237238
require.Equal(t, env.Now().Unix(), timestamp)
238239
}
240+
241+
func TestActivityAssertCalled(t *testing.T) {
242+
testSuite := &WorkflowTestSuite{}
243+
env := testSuite.NewTestWorkflowEnvironment()
244+
245+
env.RegisterActivity(namedActivity)
246+
env.OnActivity(namedActivity, mock.Anything, mock.Anything).Return("Mock!", nil)
247+
248+
env.ExecuteWorkflow(func(ctx Context, arg1 string) (string, error) {
249+
ctx = WithLocalActivityOptions(ctx, LocalActivityOptions{
250+
ScheduleToCloseTimeout: time.Hour,
251+
StartToCloseTimeout: time.Hour,
252+
})
253+
var result string
254+
err := ExecuteLocalActivity(ctx, "namedActivity", arg1).Get(ctx, &result)
255+
if err != nil {
256+
return "", err
257+
}
258+
return result, nil
259+
}, "Hello")
260+
261+
require.NoError(t, env.GetWorkflowError())
262+
var result string
263+
err := env.GetWorkflowResult(&result)
264+
require.NoError(t, err)
265+
266+
require.Equal(t, "Mock!", result)
267+
env.AssertCalled(t, "namedActivity", mock.Anything, "Hello")
268+
env.AssertNotCalled(t, "namedActivity", mock.Anything, "Bye")
269+
}
270+
271+
func TestActivityAssertNumberOfCalls(t *testing.T) {
272+
testSuite := &WorkflowTestSuite{}
273+
env := testSuite.NewTestWorkflowEnvironment()
274+
275+
env.RegisterActivity(namedActivity)
276+
env.OnActivity(namedActivity, mock.Anything, mock.Anything).Return("Mock!", nil)
277+
278+
env.ExecuteWorkflow(func(ctx Context, arg1 string) (string, error) {
279+
ctx = WithLocalActivityOptions(ctx, LocalActivityOptions{
280+
ScheduleToCloseTimeout: time.Hour,
281+
StartToCloseTimeout: time.Hour,
282+
})
283+
var result string
284+
_ = ExecuteLocalActivity(ctx, "namedActivity", arg1).Get(ctx, &result)
285+
_ = ExecuteLocalActivity(ctx, "namedActivity", arg1).Get(ctx, &result)
286+
_ = ExecuteLocalActivity(ctx, "namedActivity", arg1).Get(ctx, &result)
287+
return result, nil
288+
}, "Hello")
289+
290+
require.NoError(t, env.GetWorkflowError())
291+
env.AssertNumberOfCalls(t, "namedActivity", 3)
292+
env.AssertNumberOfCalls(t, "otherActivity", 0)
293+
}

0 commit comments

Comments
 (0)