@@ -68,6 +68,7 @@ func run(p *analysis.Pass) (any, error) {
6868 pass .checkStringLenCompare (n )
6969 case * ast.FuncDecl :
7070 pass .checkFuncArgs (n )
71+ pass .checkContextArgs (n )
7172 case * ast.CallExpr :
7273 pass .checkFlagDefinition (n )
7374 pass .checkLogErrorFormat (n )
@@ -229,6 +230,43 @@ func (pass *Pass) reportFuncArgs(fields []*ast.Field, first, last int) {
229230 pass .report (fields [first ], "Use '%v %v'" , names [2 :], pass .typ (fields [first ].Type ))
230231}
231232
233+ func (pass * Pass ) checkContextArgs (n * ast.FuncDecl ) {
234+ if n .Type .Params == nil {
235+ return
236+ }
237+ expectedArgPos := 0
238+ if len (n .Type .Params .List ) > 0 {
239+ field := n .Type .Params .List [0 ]
240+ if strings .HasSuffix (pass .typ (field .Type ), "*testing.T" ) {
241+ expectedArgPos = 1
242+ }
243+ }
244+ argPos := 0
245+ for _ , field := range n .Type .Params .List {
246+ isContext := pass .typ (field .Type ) == "context.Context"
247+ if len (field .Names ) == 0 {
248+ if isContext {
249+ if argPos != expectedArgPos {
250+ pass .report (field , "Context must be the first argument" )
251+ }
252+ pass .report (field , "Context variable must be named 'ctx'" )
253+ }
254+ argPos ++
255+ }
256+ for _ , name := range field .Names {
257+ if isContext {
258+ if argPos != expectedArgPos {
259+ pass .report (name , "Context must be the first argument" )
260+ }
261+ if name .Name != "ctx" && name .Name != "_" {
262+ pass .report (name , "Context variable must be named 'ctx'" )
263+ }
264+ }
265+ argPos ++
266+ }
267+ }
268+ }
269+
232270func (pass * Pass ) checkFlagDefinition (n * ast.CallExpr ) {
233271 fun , ok := n .Fun .(* ast.SelectorExpr )
234272 if ! ok {
0 commit comments