diff --git a/dig.go b/dig.go index 7e76b559..131612b1 100644 --- a/dig.go +++ b/dig.go @@ -476,7 +476,7 @@ func (c *Container) Provide(constructor interface{}, opts ...ProvideOption) erro // // The function may return an error to indicate failure. The error will be // returned to the caller as-is. -func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error { +func (c *Container) Invoke(function interface{}, providedParams ...interface{}) error { ftype := reflect.TypeOf(function) if ftype == nil { return errors.New("can't invoke an untyped nil") @@ -485,7 +485,7 @@ func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error { return fmt.Errorf("can't invoke non-function %v (type %v)", function, ftype) } - pl, err := newParamList(ftype) + pl, err := newParamList(ftype, providedParams...) if err != nil { return err } diff --git a/param.go b/param.go index 2464d9fe..1af8f893 100644 --- a/param.go +++ b/param.go @@ -53,6 +53,7 @@ type param interface { } var ( + _ param = paramProvided{} _ param = paramSingle{} _ param = paramObject{} _ param = paramList{} @@ -120,7 +121,7 @@ func walkParam(p param, v paramVisitor) { } switch par := p.(type) { - case paramSingle, paramGroupedSlice: + case paramSingle, paramGroupedSlice, paramProvided: // No sub-results case paramObject: for _, f := range par.Fields { @@ -161,7 +162,7 @@ func (pl paramList) DotParam() []*dot.Param { // // Variadic arguments of a constructor are ignored and not included as // dependencies. -func newParamList(ctype reflect.Type) (paramList, error) { +func newParamList(ctype reflect.Type, providedParams ...interface{}) (paramList, error) { numArgs := ctype.NumIn() if ctype.IsVariadic() { // NOTE: If the function is variadic, we skip the last argument @@ -175,11 +176,17 @@ func newParamList(ctype reflect.Type) (paramList, error) { } for i := 0; i < numArgs; i++ { - p, err := newParam(ctype.In(i)) - if err != nil { - return pl, errWrapf(err, "bad argument %d", i+1) + if i < len(providedParams) { + pl.Params = append(pl.Params, paramProvided{Param: providedParams[i]}) + + } else { + p, err := newParam(ctype.In(i)) + if err != nil { + return pl, errWrapf(err, "bad argument %d", i+1) + } + + pl.Params = append(pl.Params, p) } - pl.Params = append(pl.Params, p) } return pl, nil @@ -203,6 +210,7 @@ func (pl paramList) BuildList(c containerStore) ([]reflect.Value, error) { return nil, err } } + return args, nil } @@ -452,3 +460,28 @@ func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { } return result, nil } + +// paramProvided is an passed in param. +type paramProvided struct { + Name string + Optional bool + Type reflect.Type + Param interface{} +} + +func (pp paramProvided) DotParam() []*dot.Param { + return []*dot.Param{ + { + Node: &dot.Node{ + Type: pp.Type, + Name: pp.Name, + }, + Optional: pp.Optional, + }, + } +} + +func (pp paramProvided) Build(c containerStore) (reflect.Value, error) { + return reflect.ValueOf(pp.Param), nil +} + diff --git a/stringer.go b/stringer.go index a8ee1f38..56c21dfb 100644 --- a/stringer.go +++ b/stringer.go @@ -104,3 +104,7 @@ func (pt paramGroupedSlice) String() string { // io.Reader[group="foo"] refers to a group of io.Readers called 'foo' return fmt.Sprintf("%v[group=%q]", pt.Type.Elem(), pt.Group) } + +func (pp paramProvided) String() string { + return fmt.Sprintf("%v[%v]", pp.Type, pp.Param) +}