Skip to content

Commit e16f7fe

Browse files
committed
feat: support method expression providers
1 parent 275e271 commit e16f7fe

8 files changed

Lines changed: 153 additions & 21 deletions

File tree

internal/wire/analyze.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ type call struct {
5252
pkg *types.Package
5353
name string
5454

55+
// methodExprRecv is the receiver type for a method expression provider.
56+
// It is nil for package-level function providers.
57+
methodExprRecv types.Type
58+
5559
// args is a list of arguments to call the provider with. Each element is:
5660
// a) one of the givens (args[i] < len(given)),
5761
// b) the result of a previous provider call (args[i] >= len(given))
@@ -196,16 +200,17 @@ dfs:
196200
}
197201
}
198202
calls = append(calls, call{
199-
kind: kind,
200-
pkg: p.Pkg,
201-
name: p.Name,
202-
args: args,
203-
varargs: p.Varargs,
204-
fieldNames: fieldNames,
205-
ins: ins,
206-
out: curr.t,
207-
hasCleanup: p.HasCleanup,
208-
hasErr: p.HasErr,
203+
kind: kind,
204+
pkg: p.Pkg,
205+
name: p.Name,
206+
methodExprRecv: p.MethodExprRecv,
207+
args: args,
208+
varargs: p.Varargs,
209+
fieldNames: fieldNames,
210+
ins: ins,
211+
out: curr.t,
212+
hasCleanup: p.HasCleanup,
213+
hasErr: p.HasErr,
209214
})
210215
case pv.IsValue():
211216
v := pv.Value()

internal/wire/parse.go

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ type Provider struct {
156156
// Name is the name of the Go object.
157157
Name string
158158

159+
// MethodExprRecv is the receiver type for a method expression provider.
160+
// It is nil for package-level function providers.
161+
MethodExprRecv types.Type
162+
159163
// Pos is the source position of the func keyword or type spec
160164
// defining this provider.
161165
Pos token.Pos
@@ -718,6 +722,12 @@ func valueSpecForVar(fset *token.FileSet, files []*ast.File, obj *types.Var) *as
718722
func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Expr, varName string) (interface{}, []error) {
719723
exprPos := oc.fset.Position(expr.Pos())
720724
expr = astutil.Unparen(expr)
725+
if sel, ok := expr.(*ast.SelectorExpr); ok {
726+
if selInfo := info.Selections[sel]; selInfo != nil && selInfo.Kind() == types.MethodExpr {
727+
p, errs := processMethodExprProvider(oc.fset, info, sel, selInfo)
728+
return p, notePositionAll(exprPos, errs)
729+
}
730+
}
721731
if obj := qualifiedIdentObject(info, expr); obj != nil {
722732
item, errs := oc.get(obj)
723733
return item, mapErrors(errs, func(err error) error {
@@ -877,28 +887,45 @@ func qualifiedIdentObject(info *types.Info, expr ast.Expr) types.Object {
877887
func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []error) {
878888
sig := fn.Type().(*types.Signature)
879889
fpos := fn.Pos()
890+
return processProviderSignature(fset, fn.Pkg(), fn.Name(), nil, fpos, sig)
891+
}
892+
893+
func processMethodExprProvider(fset *token.FileSet, info *types.Info, expr *ast.SelectorExpr, sel *types.Selection) (*Provider, []error) {
894+
obj, ok := sel.Obj().(*types.Func)
895+
if !ok {
896+
return nil, []error{fmt.Errorf("%s is not a function", expr.Sel.Name)}
897+
}
898+
sig, ok := info.TypeOf(expr).(*types.Signature)
899+
if !ok {
900+
return nil, []error{fmt.Errorf("method expression %s does not have a function signature", expr.Sel.Name)}
901+
}
902+
return processProviderSignature(fset, obj.Pkg(), obj.Name(), sel.Recv(), expr.Pos(), sig)
903+
}
904+
905+
func processProviderSignature(fset *token.FileSet, pkg *types.Package, name string, recv types.Type, pos token.Pos, sig *types.Signature) (*Provider, []error) {
880906
providerSig, err := funcOutput(sig)
881907
if err != nil {
882-
return nil, []error{notePosition(fset.Position(fpos), fmt.Errorf("wrong signature for provider %s: %v", fn.Name(), err))}
908+
return nil, []error{notePosition(fset.Position(pos), fmt.Errorf("wrong signature for provider %s: %v", name, err))}
883909
}
884910
params := sig.Params()
885911
provider := &Provider{
886-
Pkg: fn.Pkg(),
887-
Name: fn.Name(),
888-
Pos: fn.Pos(),
889-
Args: make([]ProviderInput, params.Len()),
890-
Varargs: sig.Variadic(),
891-
Out: []types.Type{providerSig.out},
892-
HasCleanup: providerSig.cleanup,
893-
HasErr: providerSig.err,
912+
Pkg: pkg,
913+
Name: name,
914+
MethodExprRecv: recv,
915+
Pos: pos,
916+
Args: make([]ProviderInput, params.Len()),
917+
Varargs: sig.Variadic(),
918+
Out: []types.Type{providerSig.out},
919+
HasCleanup: providerSig.cleanup,
920+
HasErr: providerSig.err,
894921
}
895922
for i := 0; i < params.Len(); i++ {
896923
provider.Args[i] = ProviderInput{
897924
Type: params.At(i).Type(),
898925
}
899926
for j := 0; j < i; j++ {
900927
if types.Identical(provider.Args[i].Type, provider.Args[j].Type) {
901-
return nil, []error{notePosition(fset.Position(fpos), fmt.Errorf("provider has multiple parameters of type %s", types.TypeString(provider.Args[j].Type, nil)))}
928+
return nil, []error{notePosition(fset.Position(pos), fmt.Errorf("provider has multiple parameters of type %s", types.TypeString(provider.Args[j].Type, nil)))}
902929
}
903930
}
904931
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright 2018 The Wire Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package main
16+
17+
import "fmt"
18+
19+
func main() {
20+
db, err := initDB()
21+
if err != nil {
22+
panic(err)
23+
}
24+
fmt.Println(db.DSN)
25+
}
26+
27+
type Options struct {
28+
DSN string
29+
}
30+
31+
type DB struct {
32+
DSN string
33+
}
34+
35+
func provideOptions() *Options {
36+
return &Options{DSN: "postgres://wire"}
37+
}
38+
39+
func (o *Options) ToDB() (*DB, error) {
40+
return &DB{DSN: o.DSN}, nil
41+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright 2018 The Wire Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//go:build wireinject
16+
// +build wireinject
17+
18+
package main
19+
20+
import "github.com/goforj/wire"
21+
22+
func initDB() (*DB, error) {
23+
wire.Build(
24+
provideOptions,
25+
(*Options).ToDB,
26+
)
27+
return nil, nil
28+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
example.com/foo
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
postgres://wire

internal/wire/testdata/MethodExprProvider/want/wire_gen.go

Lines changed: 18 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/wire/wire.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ func (ig *injectorGen) funcProviderCall(lname string, c *call, injectSig outputS
720720
ig.p(", %s", ig.errVar)
721721
}
722722
ig.p(" := ")
723-
ig.p("%s(", ig.g.qualifiedID(c.pkg.Name(), c.pkg.Path(), c.name))
723+
ig.p("%s(", ig.funcProviderExpr(c))
724724
for i, a := range c.args {
725725
if i > 0 {
726726
ig.p(", ")
@@ -750,6 +750,17 @@ func (ig *injectorGen) funcProviderCall(lname string, c *call, injectSig outputS
750750
}
751751
}
752752

753+
func (ig *injectorGen) funcProviderExpr(c *call) string {
754+
if c.methodExprRecv == nil {
755+
return ig.g.qualifiedID(c.pkg.Name(), c.pkg.Path(), c.name)
756+
}
757+
recv := types.TypeString(c.methodExprRecv, ig.g.qualifyPkg)
758+
if _, ok := c.methodExprRecv.(*types.Pointer); ok {
759+
recv = "(" + recv + ")"
760+
}
761+
return recv + "." + c.name
762+
}
763+
753764
func (ig *injectorGen) structProviderCall(lname string, c *call) {
754765
ig.p("\t%s", lname)
755766
ig.p(" := ")

0 commit comments

Comments
 (0)