diff --git a/cache.go b/cache.go index db88d2f..57b1980 100644 --- a/cache.go +++ b/cache.go @@ -13,13 +13,19 @@ import ( type Item struct { Object interface{} Expiration int64 + mu sync.RWMutex } // Returns true if the item has expired. -func (item Item) Expired() bool { +func (item *Item) Expired() bool { + item.mu.RLock() + if item.Expiration == 0 { + item.mu.RUnlock() return false } + + item.mu.RUnlock() return time.Now().UnixNano() > item.Expiration } @@ -39,7 +45,7 @@ type Cache struct { type cache struct { defaultExpiration time.Duration - items map[string]Item + items map[string]*Item mu sync.RWMutex onEvicted func(string, interface{}) janitor *janitor @@ -58,7 +64,7 @@ func (c *cache) Set(k string, x interface{}, d time.Duration) { e = time.Now().Add(d).UnixNano() } c.mu.Lock() - c.items[k] = Item{ + c.items[k] = &Item{ Object: x, Expiration: e, } @@ -75,7 +81,7 @@ func (c *cache) set(k string, x interface{}, d time.Duration) { if d > 0 { e = time.Now().Add(d).UnixNano() } - c.items[k] = Item{ + c.items[k] = &Item{ Object: x, Expiration: e, } @@ -119,18 +125,19 @@ func (c *cache) Replace(k string, x interface{}, d time.Duration) error { // whether the key was found. func (c *cache) Get(k string) (interface{}, bool) { c.mu.RLock() + // "Inlining" of get and Expired item, found := c.items[k] if !found { c.mu.RUnlock() return nil, false } - if item.Expiration > 0 { - if time.Now().UnixNano() > item.Expiration { - c.mu.RUnlock() - return nil, false - } + + if item.Expired() { + c.mu.RUnlock() + return nil, false } + c.mu.RUnlock() return item.Object, true } @@ -141,6 +148,7 @@ func (c *cache) Get(k string) (interface{}, bool) { // whether the key was found. func (c *cache) GetWithExpiration(k string) (interface{}, time.Time, bool) { c.mu.RLock() + // "Inlining" of get and Expired item, found := c.items[k] if !found { @@ -148,34 +156,73 @@ func (c *cache) GetWithExpiration(k string) (interface{}, time.Time, bool) { return nil, time.Time{}, false } + item.mu.RLock() + if item.Expiration > 0 { if time.Now().UnixNano() > item.Expiration { + item.mu.RUnlock() c.mu.RUnlock() return nil, time.Time{}, false } // Return the item and the expiration time + item.mu.RUnlock() c.mu.RUnlock() return item.Object, time.Unix(0, item.Expiration), true } // If expiration <= 0 (i.e. no expiration time set) then return the item // and a zeroed time.Time - c.mu.RUnlock() return item.Object, time.Time{}, true } -func (c *cache) get(k string) (interface{}, bool) { +// GetWithExpirationUpdate returns item and updates its cache expiration time +// It returns the item or nil, the expiration time if one is set (if the item +// never expires a zero value for time.Time is returned), and a bool indicating +// whether the key was found. +func (c *cache) GetWithExpirationUpdate(k string, d time.Duration) (interface{}, bool) { + c.mu.RLock() + item, found := c.items[k] if !found { + c.mu.RUnlock() return nil, false } - // "Inlining" of Expired + + // Don't call item.Expired() here since + // we write lock item.Expiration + item.mu.Lock() + if item.Expiration > 0 { if time.Now().UnixNano() > item.Expiration { + item.mu.Unlock() + c.mu.RUnlock() return nil, false } } + + if d == DefaultExpiration { + d = c.defaultExpiration + } + if d > 0 { + c.items[k].Expiration = time.Now().Add(d).UnixNano() + } + + item.mu.Unlock() + c.mu.RUnlock() + return item.Object, true +} + +func (c *cache) get(k string) (interface{}, bool) { + item, found := c.items[k] + if !found { + return nil, false + } + // "Inlining" of Expired + if item.Expired() { + return nil, false + } + return item.Object, true } @@ -968,11 +1015,13 @@ func (c *cache) Save(w io.Writer) (err error) { } }() c.mu.RLock() - defer c.mu.RUnlock() + for _, v := range c.items { gob.Register(v.Object) } err = enc.Encode(&c.items) + + c.mu.RUnlock() return } @@ -1001,11 +1050,11 @@ func (c *cache) SaveFile(fname string) error { // documentation for NewFrom().) func (c *cache) Load(r io.Reader) error { dec := gob.NewDecoder(r) - items := map[string]Item{} + items := map[string]*Item{} err := dec.Decode(&items) if err == nil { c.mu.Lock() - defer c.mu.Unlock() + for k, v := range items { ov, found := c.items[k] if !found || ov.Expired() { @@ -1013,6 +1062,7 @@ func (c *cache) Load(r io.Reader) error { } } } + c.mu.Unlock() return err } @@ -1035,20 +1085,22 @@ func (c *cache) LoadFile(fname string) error { } // Copies all unexpired items in the cache into a new map and returns it. -func (c *cache) Items() map[string]Item { +func (c *cache) Items() map[string]*Item { c.mu.RLock() - defer c.mu.RUnlock() - m := make(map[string]Item, len(c.items)) - now := time.Now().UnixNano() + + m := make(map[string]*Item, len(c.items)) for k, v := range c.items { // "Inlining" of Expired - if v.Expiration > 0 { - if now > v.Expiration { - continue - } + if v.Expired() { + continue + } + m[k] = &Item{ + Object: v.Object, + Expiration: v.Expiration, } - m[k] = v } + + c.mu.RUnlock() return m } @@ -1064,7 +1116,7 @@ func (c *cache) ItemCount() int { // Delete all items from the cache. func (c *cache) Flush() { c.mu.Lock() - c.items = map[string]Item{} + c.items = map[string]*Item{} c.mu.Unlock() } @@ -1099,7 +1151,7 @@ func runJanitor(c *cache, ci time.Duration) { go j.Run(c) } -func newCache(de time.Duration, m map[string]Item) *cache { +func newCache(de time.Duration, m map[string]*Item) *cache { if de == 0 { de = -1 } @@ -1110,7 +1162,7 @@ func newCache(de time.Duration, m map[string]Item) *cache { return c } -func newCacheWithJanitor(de time.Duration, ci time.Duration, m map[string]Item) *Cache { +func newCacheWithJanitor(de time.Duration, ci time.Duration, m map[string]*Item) *Cache { c := newCache(de, m) // This trick ensures that the janitor goroutine (which--granted it // was enabled--is running DeleteExpired on c forever) does not keep @@ -1131,7 +1183,7 @@ func newCacheWithJanitor(de time.Duration, ci time.Duration, m map[string]Item) // manually. If the cleanup interval is less than one, expired items are not // deleted from the cache before calling c.DeleteExpired(). func New(defaultExpiration, cleanupInterval time.Duration) *Cache { - items := make(map[string]Item) + items := make(map[string]*Item) return newCacheWithJanitor(defaultExpiration, cleanupInterval, items) } @@ -1156,6 +1208,6 @@ func New(defaultExpiration, cleanupInterval time.Duration) *Cache { // gob.Register() the individual types stored in the cache before encoding a // map retrieved with c.Items(), and to register those same types before // decoding a blob containing an items map. -func NewFrom(defaultExpiration, cleanupInterval time.Duration, items map[string]Item) *Cache { +func NewFrom(defaultExpiration, cleanupInterval time.Duration, items map[string]*Item) *Cache { return newCacheWithJanitor(defaultExpiration, cleanupInterval, items) } diff --git a/cache_test.go b/cache_test.go index de3e9d6..655b865 100644 --- a/cache_test.go +++ b/cache_test.go @@ -107,12 +107,12 @@ func TestCacheTimes(t *testing.T) { } func TestNewFrom(t *testing.T) { - m := map[string]Item{ - "a": Item{ + m := map[string]*Item{ + "a": &Item{ Object: 1, Expiration: 0, }, - "b": Item{ + "b": &Item{ Object: 2, Expiration: 0, }, @@ -1769,3 +1769,28 @@ func TestGetWithExpiration(t *testing.T) { t.Error("expiration for e is in the past") } } + +func TestGetWithExpirationUpdate(t *testing.T) { + var found bool + + tc := New(50*time.Millisecond, 1*time.Millisecond) + tc.Set("a", 1, DefaultExpiration) + + <-time.After(25 * time.Millisecond) + _, found = tc.GetWithExpirationUpdate("a", DefaultExpiration) + if !found { + t.Error("item `a` not expired yet") + } + + <-time.After(25 * time.Millisecond) + _, found = tc.Get("a") + if !found { + t.Error("item `a` not expired yet") + } + + <-time.After(30 * time.Millisecond) + _, found = tc.Get("a") + if found { + t.Error("Found `a` when it should have been automatically deleted") + } +} diff --git a/sharded.go b/sharded.go index bcc0538..7ae92af 100644 --- a/sharded.go +++ b/sharded.go @@ -109,8 +109,8 @@ func (sc *shardedCache) DeleteExpired() { // fields of the items should be checked. Note that explicit synchronization // is needed to use a cache and its corresponding Items() return values at // the same time, as the maps are shared. -func (sc *shardedCache) Items() []map[string]Item { - res := make([]map[string]Item, len(sc.cs)) +func (sc *shardedCache) Items() []map[string]*Item { + res := make([]map[string]*Item, len(sc.cs)) for i, v := range sc.cs { res[i] = v.Items() } @@ -171,7 +171,7 @@ func newShardedCache(n int, de time.Duration) *shardedCache { for i := 0; i < n; i++ { c := &cache{ defaultExpiration: de, - items: map[string]Item{}, + items: map[string]*Item{}, } sc.cs[i] = c }