Skip to content

Commit eddb08e

Browse files
authored
fixes: support assert.CollectT (#233)
1 parent c3bdbc7 commit eddb08e

File tree

5 files changed

+81
-2
lines changed

5 files changed

+81
-2
lines changed

analyzer/analyzer_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ func TestTestifyLint_NotDefaultCases(t *testing.T) {
7676
dir: "error-as-target",
7777
flags: map[string]string{"disable-all": "true", "enable": checkers.NewErrorIsAs().Name()},
7878
},
79+
{
80+
dir: "error-is-as-issue231",
81+
flags: map[string]string{"enable-all": "true"},
82+
},
7983
{
8084
dir: "error-nil-issue95",
8185
flags: map[string]string{"disable-all": "true", "enable": checkers.NewErrorNil().Name()},
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package errorisasissue231
2+
3+
import (
4+
"fmt"
5+
"strconv"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestEventuallyAsserts(t *testing.T) {
14+
require.EventuallyWithT(t,
15+
func(c *assert.CollectT) {
16+
_, err := strconv.Atoi("a")
17+
if err != nil {
18+
c.Errorf("failed: %v", err)
19+
20+
c.Errorf(fmt.Sprintf("failed: %v", err)) // want "formatter: remove unnecessary fmt\\.Sprintf"
21+
c.Errorf("failed: %v") // want "formatter: c\\.Errorf format %v reads arg #1, but call has 0 args"
22+
return
23+
}
24+
},
25+
time.Second,
26+
time.Millisecond,
27+
)
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package errorisasissue231
2+
3+
import (
4+
"fmt"
5+
"strconv"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestEventuallyAsserts(t *testing.T) {
14+
require.EventuallyWithT(t,
15+
func(c *assert.CollectT) {
16+
_, err := strconv.Atoi("a")
17+
if err != nil {
18+
c.Errorf("failed: %v", err)
19+
20+
c.Errorf("failed: %v", err) // want "formatter: remove unnecessary fmt\\.Sprintf"
21+
c.Errorf("failed: %v") // want "formatter: c\\.Errorf format %v reads arg #1, but call has 0 args"
22+
return
23+
}
24+
},
25+
time.Second,
26+
time.Millisecond,
27+
)
28+
}

internal/checkers/error_is_as.go

+19-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ import (
66
"go/types"
77

88
"golang.org/x/tools/go/analysis"
9+
10+
"github.com/Antonboom/testifylint/internal/analysisutil"
11+
"github.com/Antonboom/testifylint/internal/testify"
912
)
1013

1114
// ErrorIsAs detects situations like
@@ -34,7 +37,7 @@ func (ErrorIsAs) Name() string { return "error-is-as" }
3437
func (checker ErrorIsAs) Check(pass *analysis.Pass, call *CallMeta) *analysis.Diagnostic {
3538
switch call.Fn.NameFTrimmed {
3639
case "Error":
37-
if len(call.Args) >= 2 && isError(pass, call.Args[1]) {
40+
if len(call.Args) >= 2 && isError(pass, call.Args[1]) && !isAssertCollectT(pass, call.Selector.X) {
3841
const proposed = "ErrorIs"
3942
msg := fmt.Sprintf("invalid usage of %[1]s.Error, use %[1]s.%[2]s instead", call.SelectorXStr, proposed)
4043
return newDiagnostic(checker.Name(), call, msg, newSuggestedFuncReplacement(call, proposed))
@@ -146,3 +149,18 @@ func (checker ErrorIsAs) Check(pass *analysis.Pass, call *CallMeta) *analysis.Di
146149
}
147150
return nil
148151
}
152+
153+
func isAssertCollectT(pass *analysis.Pass, e ast.Expr) bool {
154+
ptr, ok := pass.TypesInfo.TypeOf(e).(*types.Pointer)
155+
if !ok {
156+
return false
157+
}
158+
159+
named, ok := ptr.Elem().(*types.Named)
160+
if !ok {
161+
return false
162+
}
163+
164+
collectT := analysisutil.ObjectOf(pass.Pkg, testify.AssertPkgPath, "CollectT")
165+
return named.Obj() == collectT
166+
}

internal/checkers/formatter.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,8 @@ func getMsgPosition(sig *types.Signature) int {
259259
for i := 0; i < sig.Params().Len(); i++ {
260260
param := sig.Params().At(i)
261261

262-
if b, ok := param.Type().(*types.Basic); ok && b.Kind() == types.String && param.Name() == "msg" {
262+
if b, ok := param.Type().(*types.Basic); ok && b.Kind() == types.String && (param.Name() == "msg" ||
263+
param.Name() == "format") { // NOTE(a.telyshev): assert.CollectT case.
263264
return i
264265
}
265266
}

0 commit comments

Comments
 (0)