Skip to content

Commit

Permalink
Merge pull request #260 from okhowang/fix/loadable
Browse files Browse the repository at this point in the history
fix(loadable): cache value in setChannel
  • Loading branch information
eko authored Jan 8, 2025
2 parents 33d992d + acdc0e8 commit 70d4931
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
6 changes: 6 additions & 0 deletions lib/cache/loadable.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type LoadableCache[T any] struct {
loadFunc LoadFunction[T]
cache CacheInterface[T]
setChannel chan *loadableKeyValue[T]
setCache sync.Map
setterWg *sync.WaitGroup
}

Expand Down Expand Up @@ -55,6 +56,7 @@ func (c *LoadableCache[T]) setter() {

cacheKey := c.getCacheKey(item.key)
c.singleFlight.Forget(cacheKey)
c.setCache.Delete(cacheKey)
}
}

Expand All @@ -69,6 +71,9 @@ func (c *LoadableCache[T]) Get(ctx context.Context, key any) (T, error) {

// Unable to find in cache, try to load it from load function
cacheKey := c.getCacheKey(key)
if v, ok := c.setCache.Load(cacheKey); ok {
return v.(T), nil
}
zero := *new(T)

loadedResult, err, _ := c.singleFlight.Do(
Expand All @@ -89,6 +94,7 @@ func (c *LoadableCache[T]) Get(ctx context.Context, key any) (T, error) {
}

// Then, put it back in cache
c.setCache.Store(cacheKey, object)
c.setChannel <- &loadableKeyValue[T]{key, object}

return object, err
Expand Down
27 changes: 27 additions & 0 deletions lib/cache/loadable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"
"time"

"github.com/eko/gocache/lib/v4/store"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
)
Expand Down Expand Up @@ -299,3 +300,29 @@ func TestLoadableGetType(t *testing.T) {
// When - Then
assert.Equal(t, LoadableType, cache.GetType())
}

func TestLoadableGetTwice(t *testing.T) {
// Given
ctrl := gomock.NewController(t)

cache1 := NewMockSetterCacheInterface[any](ctrl)

var counter atomic.Uint64
loadFunc := func(_ context.Context, key any) (any, error) {
return counter.Add(1), nil
}

cache := NewLoadable[any](loadFunc, cache1)

key := 1
cache1.EXPECT().Get(context.Background(), key).Return(nil, store.NotFound{}).Times(2)
cache1.EXPECT().Set(context.Background(), key, uint64(1)).Times(1)
v1, err1 := cache.Get(context.Background(), key)
v2, err2 := cache.Get(context.Background(), key) // setter may not be called now because it's done by another goroutine
assert.NoError(t, err1)
assert.NoError(t, err2)
assert.Equal(t, uint64(1), v1)
assert.Equal(t, uint64(1), v2)
assert.Equal(t, uint64(1), counter.Load())
_ = cache.Close() // wait for setter
}

0 comments on commit 70d4931

Please sign in to comment.