Skip to content

Commit a5727b7

Browse files
committed
feat(subtract): added Subtract function to wire
Signed-off-by: Giau. Tran Minh <[email protected]>
1 parent 0675cdc commit a5727b7

File tree

7 files changed

+279
-21
lines changed

7 files changed

+279
-21
lines changed

Diff for: internal/wire/parse.go

+114-3
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,9 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
546546
case "NewSet":
547547
pset, errs := oc.processNewSet(info, pkgPath, call, nil, varName)
548548
return pset, notePositionAll(exprPos, errs)
549+
case "Subtract":
550+
pset, errs := oc.processSubtract(info, pkgPath, call, nil, varName)
551+
return pset, notePositionAll(exprPos, errs)
549552
case "Bind":
550553
b, err := processBind(oc.fset, info, call)
551554
if err != nil {
@@ -590,6 +593,114 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
590593
return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))}
591594
}
592595

596+
func (oc *objectCache) filterType(s *ProviderSet, st types.Type) []error {
597+
hasType := func(outs []types.Type) bool {
598+
for _, o := range outs {
599+
if types.Identical(o, st) {
600+
return true
601+
}
602+
pt, ok := o.(*types.Pointer)
603+
if ok && types.Identical(pt.Elem(), st) {
604+
return true
605+
}
606+
}
607+
return false
608+
}
609+
providers := make([]*Provider, 0, len(s.Providers))
610+
for _, p := range s.Providers {
611+
if !hasType(p.Out) {
612+
providers = append(providers, p)
613+
}
614+
}
615+
s.Providers = providers
616+
617+
bindings := make([]*IfaceBinding, 0, len(s.Bindings))
618+
for _, i := range s.Bindings {
619+
if !types.Identical(i.Iface, st) {
620+
bindings = append(bindings, i)
621+
}
622+
}
623+
s.Bindings = bindings
624+
625+
values := make([]*Value, 0, len(s.Values))
626+
for _, v := range s.Values {
627+
if !types.Identical(v.Out, st) {
628+
values = append(values, v)
629+
}
630+
}
631+
s.Values = values
632+
633+
fields := make([]*Field, 0, len(s.Fields))
634+
for _, f := range s.Fields {
635+
if !hasType(f.Out) {
636+
fields = append(fields, f)
637+
}
638+
}
639+
s.Fields = fields
640+
641+
imports := make([]*ProviderSet, 0, len(s.Imports))
642+
for _, p := range s.Imports {
643+
clone := *p
644+
if errs := oc.filterType(&clone, st); len(errs) > 0 {
645+
return errs
646+
}
647+
imports = append(imports, &clone)
648+
}
649+
s.Imports = imports
650+
651+
var errs []error
652+
s.providerMap, s.srcMap, errs = buildProviderMap(oc.fset, oc.hasher, s)
653+
if len(errs) > 0 {
654+
return errs
655+
}
656+
return nil
657+
}
658+
659+
func (oc *objectCache) processSubtract(info *types.Info, pkgPath string, call *ast.CallExpr, args *InjectorArgs, varName string) (interface{}, []error) {
660+
// Assumes that call.Fun is wire.Subtract.
661+
if len(call.Args) < 2 {
662+
return nil, []error{notePosition(oc.fset.Position(call.Pos()),
663+
errors.New("call to Subtract must specify types to be subtracted"))}
664+
}
665+
firstArg, errs := oc.processExpr(info, pkgPath, call.Args[0], "")
666+
if len(errs) > 0 {
667+
return nil, errs
668+
}
669+
set, ok := firstArg.(*ProviderSet)
670+
if !ok {
671+
return nil, []error{notePosition(oc.fset.Position(call.Pos()),
672+
fmt.Errorf("first argument to Subtract must be a Set")),
673+
}
674+
}
675+
pset := &ProviderSet{
676+
Pos: call.Pos(),
677+
InjectorArgs: args,
678+
PkgPath: pkgPath,
679+
VarName: varName,
680+
// Copy the other fields.
681+
Providers: set.Providers,
682+
Bindings: set.Bindings,
683+
Values: set.Values,
684+
Fields: set.Fields,
685+
Imports: set.Imports,
686+
}
687+
ec := new(errorCollector)
688+
for _, arg := range call.Args[1:] {
689+
ptr, ok := info.TypeOf(arg).(*types.Pointer)
690+
if !ok {
691+
ec.add(notePosition(oc.fset.Position(arg.Pos()),
692+
errors.New("argument to Subtract must be a pointer"),
693+
))
694+
continue
695+
}
696+
ec.add(oc.filterType(pset, ptr.Elem())...)
697+
}
698+
if len(ec.errors) > 0 {
699+
return nil, ec.errors
700+
}
701+
return pset, nil
702+
}
703+
593704
func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast.CallExpr, args *InjectorArgs, varName string) (*ProviderSet, []error) {
594705
// Assumes that call.Fun is wire.NewSet or wire.Build.
595706

@@ -1173,9 +1284,9 @@ func (pt ProvidedType) IsNil() bool {
11731284
//
11741285
// - For a function provider, this is the first return value type.
11751286
// - For a struct provider, this is either the struct type or the pointer type
1176-
// whose element type is the struct type.
1177-
// - For a value, this is the type of the expression.
1178-
// - For an argument, this is the type of the argument.
1287+
// whose element type is the struct type.
1288+
// - For a value, this is the type of the expression.
1289+
// - For an argument, this is the type of the argument.
11791290
func (pt ProvidedType) Type() types.Type {
11801291
return pt.t
11811292
}

Diff for: internal/wire/testdata/Subtract/foo/foo.go

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright 2018 The Wire Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package main
16+
17+
import (
18+
"github.com/google/wire"
19+
)
20+
21+
type context struct{}
22+
23+
func main() {}
24+
25+
type FooOptions struct{}
26+
type Foo string
27+
type Bar struct{}
28+
type BarName string
29+
30+
func (b *Bar) Bar() {}
31+
32+
func provideFooOptions() *FooOptions {
33+
return &FooOptions{}
34+
}
35+
36+
func provideFoo(*FooOptions) Foo {
37+
return Foo("foo")
38+
}
39+
40+
func provideBar(Foo, BarName) *Bar {
41+
return &Bar{}
42+
}
43+
44+
type BarService interface {
45+
Bar()
46+
}
47+
48+
type FooBar struct {
49+
BarService
50+
Foo
51+
}
52+
53+
var Set = wire.NewSet(
54+
provideFooOptions,
55+
provideFoo,
56+
provideBar,
57+
)
58+
59+
var SuperSet = wire.NewSet(Set,
60+
wire.Struct(new(FooBar), "*"),
61+
wire.Bind(new(BarService), new(*Bar)),
62+
)
63+
64+
type FakeBarService struct{}
65+
66+
func (f *FakeBarService) Bar() {}

Diff for: internal/wire/testdata/Subtract/foo/wire.go

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright 2018 The Wire Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//go:build wireinject
16+
// +build wireinject
17+
18+
package main
19+
20+
import (
21+
// "strings"
22+
23+
"github.com/google/wire"
24+
)
25+
26+
func inject(name BarName, opts *FooOptions) *Bar {
27+
panic(wire.Build(wire.Subtract(Set, new(FooOptions))))
28+
}
29+
30+
func injectBarService(name BarName, opts *FakeBarService) *FooBar {
31+
panic(wire.Build(
32+
wire.Subtract(SuperSet, new(BarService)),
33+
wire.Bind(new(BarService), new(*FakeBarService)),
34+
))
35+
}
36+
37+
func injectFooBarService(name BarName, opts *FooOptions, bar *FakeBarService) *FooBar {
38+
panic(wire.Build(
39+
wire.Subtract(SuperSet, new(FooOptions), new(BarService)),
40+
wire.Bind(new(BarService), new(*FakeBarService)),
41+
))
42+
}

Diff for: internal/wire/testdata/Subtract/pkg

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
example.com/foo

Diff for: internal/wire/testdata/Subtract/want/program_out.txt

Whitespace-only changes.

Diff for: internal/wire/testdata/Subtract/want/wire_gen.go

+34
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: wire.go

+22-18
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ func NewSet(...interface{}) ProviderSet {
5959
return ProviderSet{}
6060
}
6161

62+
func Subtract(...interface{}) ProviderSet {
63+
return ProviderSet{}
64+
}
65+
6266
// Build is placed in the body of an injector function template to declare the
6367
// providers to use. The Wire code generation tool will fill in an
6468
// implementation of the function. The arguments to Build are interpreted the
@@ -156,12 +160,12 @@ type StructProvider struct{}
156160
//
157161
// For example:
158162
//
159-
// type S struct {
160-
// MyFoo *Foo
161-
// MyBar *Bar
162-
// }
163-
// var Set = wire.NewSet(wire.Struct(new(S), "MyFoo")) -> inject only S.MyFoo
164-
// var Set = wire.NewSet(wire.Struct(new(S), "*")) -> inject all fields
163+
// type S struct {
164+
// MyFoo *Foo
165+
// MyBar *Bar
166+
// }
167+
// var Set = wire.NewSet(wire.Struct(new(S), "MyFoo")) -> inject only S.MyFoo
168+
// var Set = wire.NewSet(wire.Struct(new(S), "*")) -> inject all fields
165169
func Struct(structType interface{}, fieldNames ...string) StructProvider {
166170
return StructProvider{}
167171
}
@@ -175,22 +179,22 @@ type StructFields struct{}
175179
//
176180
// The following example would provide Foo and Bar using S.MyFoo and S.MyBar respectively:
177181
//
178-
// type S struct {
179-
// MyFoo Foo
180-
// MyBar Bar
181-
// }
182+
// type S struct {
183+
// MyFoo Foo
184+
// MyBar Bar
185+
// }
182186
//
183-
// func NewStruct() S { /* ... */ }
184-
// var Set = wire.NewSet(wire.FieldsOf(new(S), "MyFoo", "MyBar"))
187+
// func NewStruct() S { /* ... */ }
188+
// var Set = wire.NewSet(wire.FieldsOf(new(S), "MyFoo", "MyBar"))
185189
//
186-
// or
190+
// or
187191
//
188-
// func NewStruct() *S { /* ... */ }
189-
// var Set = wire.NewSet(wire.FieldsOf(new(*S), "MyFoo", "MyBar"))
192+
// func NewStruct() *S { /* ... */ }
193+
// var Set = wire.NewSet(wire.FieldsOf(new(*S), "MyFoo", "MyBar"))
190194
//
191-
// If the structType argument is a pointer to a pointer to a struct, then FieldsOf
192-
// additionally provides a pointer to each field type (e.g., *Foo and *Bar in the
193-
// example above).
195+
// If the structType argument is a pointer to a pointer to a struct, then FieldsOf
196+
// additionally provides a pointer to each field type (e.g., *Foo and *Bar in the
197+
// example above).
194198
func FieldsOf(structType interface{}, fieldNames ...string) StructFields {
195199
return StructFields{}
196200
}

0 commit comments

Comments
 (0)