Skip to content

Commit ac72975

Browse files
committed
fix: prefer drivers in DI
1 parent 3f7fc7e commit ac72975

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

dependency.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,13 @@ func newDefaultDriver(args DriverArgs) (Driver, error) {
201201
if err := args.Populator.Populate(&injected); err != nil {
202202
return nil, fmt.Errorf("missing dependency for the default queue driver: %w", err)
203203
}
204+
driver, err := driverFromDI(args.Populator)
205+
if err != nil {
206+
return nil, fmt.Errorf("error fetching default driver from DI: %w", err)
207+
}
208+
if driver != nil {
209+
return driver, nil
210+
}
204211
var redisName string
205212
if err := injected.ConfigUnmarshaler.Unmarshal(fmt.Sprintf("queue.%s.redisName", injected.AppName), &redisName); err != nil {
206213
return nil, fmt.Errorf("bad configuration: %w", err)
@@ -257,3 +264,15 @@ func provideConfig() configOut {
257264
}}
258265
return configOut{Config: configs}
259266
}
267+
268+
func driverFromDI(populator contract.DIPopulator) (Driver, error) {
269+
var injected struct {
270+
di.In
271+
Driver `optional:"true"`
272+
}
273+
err := populator.Populate(&injected)
274+
if err != nil {
275+
return nil, err
276+
}
277+
return injected.Driver, nil
278+
}

dependency_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,23 @@ func TestProvideConfigs(t *testing.T) {
133133
c := provideConfig()
134134
assert.NotEmpty(t, c.Config)
135135
}
136+
137+
type driverPopulator struct{}
138+
139+
func (d driverPopulator) Populate(target interface{}) error {
140+
graph := di.NewGraph()
141+
graph.Provide(func() Driver {
142+
return mockDriver{}
143+
})
144+
di.IntoPopulator(graph).Populate(target)
145+
return nil
146+
}
147+
148+
func TestDriverFromDI(t *testing.T) {
149+
driver, err := newDefaultDriver(DriverArgs{
150+
Name: "",
151+
Populator: driverPopulator{},
152+
})
153+
assert.NoError(t, err)
154+
assert.IsType(t, mockDriver{}, driver)
155+
}

0 commit comments

Comments
 (0)