Skip to content

Commit be2b46c

Browse files
authored
[tst] refactor interfaces (#28)
* [tst] refactor interfaces * edge cases tests
1 parent 6855394 commit be2b46c

File tree

3 files changed

+107
-24
lines changed

3 files changed

+107
-24
lines changed

tst/errors.go

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,21 @@ const expectedError = "Expected error but none received"
1515
type TestingT interface {
1616
Errorf(format string, args ...interface{})
1717
FailNow()
18+
Helper()
1819
}
1920

2021
type ErrorAssertionFunc func(t TestingT, err error) bool
2122

2223
func (e ErrorAssertionFunc) AsRequire() require.ErrorAssertionFunc {
2324
return func(tt require.TestingT, err error, _ ...any) {
24-
if suc := e(tt, err); !suc {
25+
t, is := tt.(TestingT)
26+
if !is { // not possible
27+
tt.Errorf("Wrong TestingT type %T", tt)
28+
tt.FailNow()
29+
return
30+
}
31+
32+
if suc := e(t, err); !suc {
2533
tt.FailNow()
2634
}
2735
}
@@ -30,21 +38,18 @@ func (e ErrorAssertionFunc) AsRequire() require.ErrorAssertionFunc {
3038
func (e ErrorAssertionFunc) AsAssert() assert.ErrorAssertionFunc {
3139
return func(tt assert.TestingT, err error, _ ...any) bool {
3240
t, is := tt.(TestingT)
33-
if is {
34-
return e(t, err)
41+
if !is { // not possible
42+
tt.Errorf("Wrong TestingT type %T", tt)
43+
return false
3544
}
3645

37-
// not possible
38-
tt.Errorf("Wrong TestingT type %T", tt)
39-
return false
46+
return e(t, err)
4047
}
4148
}
4249

4350
func NoError() ErrorAssertionFunc {
4451
return func(t TestingT, err error) bool {
45-
if h, ok := t.(interface{ Helper() }); ok {
46-
h.Helper()
47-
}
52+
t.Helper()
4853

4954
if err != nil {
5055
t.Errorf("Expected nil error but received : %T(%s)", err, err.Error())
@@ -57,9 +62,7 @@ func NoError() ErrorAssertionFunc {
5762

5863
func Error() ErrorAssertionFunc {
5964
return func(t TestingT, err error) bool {
60-
if h, ok := t.(interface{ Helper() }); ok {
61-
h.Helper()
62-
}
65+
t.Helper()
6366

6467
if err == nil {
6568
t.Errorf(expectedError)
@@ -76,9 +79,8 @@ func Error() ErrorAssertionFunc {
7679
// Returns false if the error is nil or doesn't match any expected errors.
7780
func ErrorIs(allExpectedErrors ...error) ErrorAssertionFunc {
7881
return func(t TestingT, err error) bool {
79-
if h, ok := t.(interface{ Helper() }); ok {
80-
h.Helper()
81-
}
82+
t.Helper()
83+
8284
if err == nil {
8385
t.Errorf(expectedError)
8486
return false
@@ -121,9 +123,7 @@ func ErrorIs(allExpectedErrors ...error) ErrorAssertionFunc {
121123

122124
func ErrorOfType[T error](typedAsserts ...func(TestingT, T)) ErrorAssertionFunc {
123125
return func(t TestingT, err error) bool {
124-
if h, ok := t.(interface{ Helper() }); ok {
125-
h.Helper()
126-
}
126+
t.Helper()
127127

128128
if err == nil {
129129
t.Errorf(expectedError)
@@ -152,9 +152,7 @@ func ErrorOfType[T error](typedAsserts ...func(TestingT, T)) ErrorAssertionFunc
152152

153153
func ErrorStringContains(s string) ErrorAssertionFunc {
154154
return func(t TestingT, err error) bool {
155-
if h, ok := t.(interface{ Helper() }); ok {
156-
h.Helper()
157-
}
155+
t.Helper()
158156

159157
if err == nil {
160158
t.Errorf(expectedError)
@@ -173,9 +171,7 @@ func ErrorStringContains(s string) ErrorAssertionFunc {
173171

174172
func All(expected ...ErrorAssertionFunc) ErrorAssertionFunc {
175173
return func(t TestingT, err error) bool {
176-
if h, ok := t.(interface{ Helper() }); ok {
177-
h.Helper()
178-
}
174+
t.Helper()
179175

180176
ret := true
181177
for _, fn := range expected {

tst/errors_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/stretchr/testify/assert"
1313
"github.com/stretchr/testify/mock"
14+
"github.com/stretchr/testify/require"
1415
)
1516

1617
func TestNoError(t *testing.T) {
@@ -249,6 +250,7 @@ func TestErrorOfType(t *testing.T) {
249250
pathErr := &os.PathError{Op: "open", Path: "/test", Err: baseErr}
250251

251252
mt := NewMockTestingT(t)
253+
mt.EXPECT().Helper().Maybe()
252254
ma := &mockErrorTypedAssertionFunc{}
253255
ma.OnAssert(mt, pathErr).Return(true).Times(3)
254256

@@ -337,6 +339,7 @@ type errorAssertionFuncTestCase struct {
337339

338340
func (tc errorAssertionFuncTestCase) Test(t *testing.T) {
339341
mt := NewMockTestingT(t)
342+
mt.EXPECT().Helper().Maybe()
340343
if tc.initMock != nil {
341344
tc.initMock(mt)
342345
}
@@ -402,11 +405,29 @@ func TestTestifyIntegration(t *testing.T) {
402405

403406
// mock T init
404407
mt := NewMockTestingT(t)
408+
mt.EXPECT().Helper().Maybe()
405409
tc.mockTInit(mt)
406410

407411
tc.run(t, mt, f)
408412
})
409413
}
414+
415+
t.Run("mismatch of T", func(t *testing.T) {
416+
t.Run("require testingT", func(t *testing.T) {
417+
mt := &testifyRequireTestingTMock{}
418+
mt.On("Errorf", "Wrong TestingT type %T", mock.Anything).Once()
419+
mt.On("FailNow").Once()
420+
421+
NoError().AsRequire()(mt, nil)
422+
})
423+
t.Run("assert testingT", func(t *testing.T) {
424+
mt := &testifyAssertTestingTMock{}
425+
mt.On("Errorf", "Wrong TestingT type %T", mock.Anything).Once()
426+
427+
got := NoError().AsAssert()(mt, nil)
428+
assert.False(t, got)
429+
})
430+
})
410431
}
411432

412433
func TestAll(t *testing.T) {
@@ -450,6 +471,7 @@ func TestAll(t *testing.T) {
450471
for _, tt := range tests {
451472
t.Run(tt.name, func(t *testing.T) {
452473
mt := NewMockTestingT(t)
474+
mt.EXPECT().Helper().Maybe()
453475
result := All(tt.assertionFuncs...)(mt, errors.New("test error"))
454476
assert.Equal(t, tt.expectResult, result)
455477
})
@@ -562,3 +584,35 @@ func TestFail(t *testing.T) {
562584
})
563585
}
564586
}
587+
588+
var _ (assert.TestingT) = (*testifyAssertTestingTMock)(nil)
589+
590+
type testifyAssertTestingTMock struct {
591+
mock.Mock
592+
}
593+
594+
func (m *testifyAssertTestingTMock) Errorf(format string, args ...any) {
595+
if len(args) > 0 {
596+
m.Called(format, args)
597+
} else {
598+
m.Called(format)
599+
}
600+
}
601+
602+
var _ (require.TestingT) = (*testifyRequireTestingTMock)(nil)
603+
604+
type testifyRequireTestingTMock struct {
605+
mock.Mock
606+
}
607+
608+
func (m *testifyRequireTestingTMock) Errorf(format string, args ...any) {
609+
if len(args) > 0 {
610+
m.Called(format, args)
611+
} else {
612+
m.Called(format)
613+
}
614+
}
615+
616+
func (m *testifyRequireTestingTMock) FailNow() {
617+
m.Called()
618+
}

tst/mocks_test.go

Lines changed: 33 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)