Skip to content

Commit 70d4931

Browse files
authored
Merge pull request #260 from okhowang/fix/loadable
fix(loadable): cache value in setChannel
2 parents 33d992d + acdc0e8 commit 70d4931

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

lib/cache/loadable.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ type LoadableCache[T any] struct {
2828
loadFunc LoadFunction[T]
2929
cache CacheInterface[T]
3030
setChannel chan *loadableKeyValue[T]
31+
setCache sync.Map
3132
setterWg *sync.WaitGroup
3233
}
3334

@@ -55,6 +56,7 @@ func (c *LoadableCache[T]) setter() {
5556

5657
cacheKey := c.getCacheKey(item.key)
5758
c.singleFlight.Forget(cacheKey)
59+
c.setCache.Delete(cacheKey)
5860
}
5961
}
6062

@@ -69,6 +71,9 @@ func (c *LoadableCache[T]) Get(ctx context.Context, key any) (T, error) {
6971

7072
// Unable to find in cache, try to load it from load function
7173
cacheKey := c.getCacheKey(key)
74+
if v, ok := c.setCache.Load(cacheKey); ok {
75+
return v.(T), nil
76+
}
7277
zero := *new(T)
7378

7479
loadedResult, err, _ := c.singleFlight.Do(
@@ -89,6 +94,7 @@ func (c *LoadableCache[T]) Get(ctx context.Context, key any) (T, error) {
8994
}
9095

9196
// Then, put it back in cache
97+
c.setCache.Store(cacheKey, object)
9298
c.setChannel <- &loadableKeyValue[T]{key, object}
9399

94100
return object, err

lib/cache/loadable_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"testing"
99
"time"
1010

11+
"github.com/eko/gocache/lib/v4/store"
1112
"github.com/stretchr/testify/assert"
1213
"go.uber.org/mock/gomock"
1314
)
@@ -299,3 +300,29 @@ func TestLoadableGetType(t *testing.T) {
299300
// When - Then
300301
assert.Equal(t, LoadableType, cache.GetType())
301302
}
303+
304+
func TestLoadableGetTwice(t *testing.T) {
305+
// Given
306+
ctrl := gomock.NewController(t)
307+
308+
cache1 := NewMockSetterCacheInterface[any](ctrl)
309+
310+
var counter atomic.Uint64
311+
loadFunc := func(_ context.Context, key any) (any, error) {
312+
return counter.Add(1), nil
313+
}
314+
315+
cache := NewLoadable[any](loadFunc, cache1)
316+
317+
key := 1
318+
cache1.EXPECT().Get(context.Background(), key).Return(nil, store.NotFound{}).Times(2)
319+
cache1.EXPECT().Set(context.Background(), key, uint64(1)).Times(1)
320+
v1, err1 := cache.Get(context.Background(), key)
321+
v2, err2 := cache.Get(context.Background(), key) // setter may not be called now because it's done by another goroutine
322+
assert.NoError(t, err1)
323+
assert.NoError(t, err2)
324+
assert.Equal(t, uint64(1), v1)
325+
assert.Equal(t, uint64(1), v2)
326+
assert.Equal(t, uint64(1), counter.Load())
327+
_ = cache.Close() // wait for setter
328+
}

0 commit comments

Comments
 (0)