Skip to content

Commit 440072f

Browse files
warbertoddbaert
andauthored
fix: internal provider comparison causing race conditions in tests (#312)
fix: internal provider comparison causing race conditions in tests Signed-off-by: Bernd Warmuth <[email protected]> Co-authored-by: Todd Baert <[email protected]>
1 parent 890bfd0 commit 440072f

File tree

3 files changed

+106
-33
lines changed

3 files changed

+106
-33
lines changed

Diff for: openfeature/event_executor.go

+8-33
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package openfeature
22

33
import (
44
"fmt"
5-
"reflect"
65
"sync"
76
"time"
87

@@ -68,13 +67,6 @@ type eventPayload struct {
6867
handler FeatureProvider
6968
}
7069

71-
// providerReference is a helper struct to store FeatureProvider with EventHandler capability along with their
72-
// shutdown semaphore
73-
type providerReference struct {
74-
featureProvider FeatureProvider
75-
shutdownSemaphore chan interface{}
76-
}
77-
7870
// AddHandler adds an API(global) level handler
7971
func (e *eventExecutor) AddHandler(t EventType, c EventCallback) {
8072
e.mu.Lock()
@@ -217,14 +209,7 @@ func (e *eventExecutor) registerDefaultProvider(provider FeatureProvider) error
217209
e.mu.Lock()
218210
defer e.mu.Unlock()
219211

220-
// register shutdown semaphore for new default provider
221-
sem := make(chan interface{})
222-
223-
newProvider := providerReference{
224-
featureProvider: provider,
225-
shutdownSemaphore: sem,
226-
}
227-
212+
newProvider := newProviderRef(provider)
228213
oldProvider := e.defaultProviderReference
229214
e.defaultProviderReference = newProvider
230215

@@ -235,14 +220,7 @@ func (e *eventExecutor) registerDefaultProvider(provider FeatureProvider) error
235220
func (e *eventExecutor) registerNamedEventingProvider(associatedClient string, provider FeatureProvider) error {
236221
e.mu.Lock()
237222
defer e.mu.Unlock()
238-
239-
// register shutdown semaphore for new named provider
240-
sem := make(chan interface{})
241-
242-
newProvider := providerReference{
243-
featureProvider: provider,
244-
shutdownSemaphore: sem,
245-
}
223+
newProvider := newProviderRef(provider)
246224

247225
oldProvider := e.namedProviderReference[associatedClient]
248226
e.namedProviderReference[associatedClient] = newProvider
@@ -288,7 +266,7 @@ func (e *eventExecutor) startListeningAndShutdownOld(newProvider providerReferen
288266

289267
// drop from active references
290268
for i, r := range e.activeSubscriptions {
291-
if reflect.DeepEqual(oldReference.featureProvider, r.featureProvider) {
269+
if oldReference.equals(r) {
292270
e.activeSubscriptions = append(e.activeSubscriptions[:i], e.activeSubscriptions[i+1:]...)
293271
}
294272
}
@@ -332,8 +310,7 @@ func (e *eventExecutor) triggerEvent(event Event, handler FeatureProvider) {
332310

333311
// then run client handlers
334312
for domain, reference := range e.namedProviderReference {
335-
if !reflect.DeepEqual(reference.featureProvider, handler) {
336-
// unassociated client, continue to next
313+
if !reference.equals(newProviderRef(handler)) {
337314
continue
338315
}
339316

@@ -343,7 +320,7 @@ func (e *eventExecutor) triggerEvent(event Event, handler FeatureProvider) {
343320
}
344321
}
345322

346-
if !reflect.DeepEqual(e.defaultProviderReference.featureProvider, handler) {
323+
if !e.defaultProviderReference.equals(newProviderRef(handler)) {
347324
return
348325
}
349326

@@ -386,25 +363,23 @@ func (e *eventExecutor) executeHandler(f func(details EventDetails), event Event
386363
// isRunning is a helper till we bump to the latest go version with slices.contains support
387364
func isRunning(provider providerReference, activeProviders []providerReference) bool {
388365
for _, activeProvider := range activeProviders {
389-
if reflect.DeepEqual(activeProvider.featureProvider, provider.featureProvider) {
366+
if activeProvider.equals(provider) {
390367
return true
391368
}
392369
}
393-
394370
return false
395371
}
396372

397373
// isRunning is a helper to check if given provider is already in use
398374
func isBound(provider providerReference, defaultProvider providerReference, namedProviders []providerReference) bool {
399-
if reflect.DeepEqual(provider.featureProvider, defaultProvider.featureProvider) {
375+
if provider.equals(defaultProvider) {
400376
return true
401377
}
402378

403379
for _, namedProvider := range namedProviders {
404-
if reflect.DeepEqual(provider.featureProvider, namedProvider.featureProvider) {
380+
if provider.equals(namedProvider) {
405381
return true
406382
}
407383
}
408-
409384
return false
410385
}

Diff for: openfeature/reference.go

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package openfeature
2+
3+
import (
4+
"reflect"
5+
)
6+
7+
// newProviderRef creates a new providerReference instance that wraps around a FeatureProvider implementation
8+
func newProviderRef(provider FeatureProvider) providerReference {
9+
return providerReference{
10+
featureProvider: provider,
11+
kind: reflect.TypeOf(provider).Kind(),
12+
shutdownSemaphore: make(chan interface{}),
13+
}
14+
}
15+
16+
// providerReference is a helper struct to store FeatureProvider along with their
17+
// shutdown semaphore
18+
type providerReference struct {
19+
featureProvider FeatureProvider
20+
kind reflect.Kind
21+
shutdownSemaphore chan interface{}
22+
}
23+
24+
func (pr providerReference) equals(other providerReference) bool {
25+
if pr.kind == reflect.Ptr && other.kind == reflect.Ptr {
26+
return pr.featureProvider == other.featureProvider
27+
}
28+
return reflect.DeepEqual(pr.featureProvider, other.featureProvider)
29+
}

Diff for: openfeature/reference_test.go

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package openfeature
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestProviderReferenceEquals(t *testing.T) {
8+
9+
type myProvider struct {
10+
NoopProvider
11+
field string
12+
}
13+
14+
p1 := myProvider{}
15+
p2 := myProvider{}
16+
17+
tests := []struct {
18+
name string
19+
pr1 providerReference
20+
pr2 providerReference
21+
expected bool
22+
}{
23+
24+
{
25+
name: "both pointers, different instances",
26+
pr1: newProviderRef(&p1),
27+
pr2: newProviderRef(&p2),
28+
expected: false,
29+
},
30+
{
31+
name: "both pointers, same instance",
32+
pr1: newProviderRef(&p1),
33+
pr2: newProviderRef(&p1),
34+
expected: true,
35+
},
36+
{
37+
name: "different pointers, different instance",
38+
pr1: newProviderRef(p1),
39+
pr2: newProviderRef(&p1),
40+
expected: false,
41+
},
42+
{
43+
name: "no pointers, same instance",
44+
pr1: newProviderRef(p1),
45+
pr2: newProviderRef(p1),
46+
expected: true,
47+
},
48+
{
49+
name: "no pointers, different equal instances",
50+
pr1: newProviderRef(myProvider{field: "A"}),
51+
pr2: newProviderRef(myProvider{field: "A"}),
52+
expected: true,
53+
},
54+
{
55+
name: "no pointers, different not equal instances",
56+
pr1: newProviderRef(myProvider{field: "A"}),
57+
pr2: newProviderRef(myProvider{field: "B"}),
58+
expected: false,
59+
},
60+
}
61+
62+
for _, tt := range tests {
63+
t.Run(tt.name, func(t *testing.T) {
64+
if got := tt.pr1.equals(tt.pr2); got != tt.expected {
65+
t.Errorf("providerReference.equals() = %v, want %v", got, tt.expected)
66+
}
67+
})
68+
}
69+
}

0 commit comments

Comments
 (0)