Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GetWithExpirationUpdate - atomic implementation #126

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 81 additions & 29 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,19 @@ import (
type Item struct {
Object interface{}
Expiration int64
mu sync.RWMutex
}

// Returns true if the item has expired.
func (item Item) Expired() bool {
func (item *Item) Expired() bool {
item.mu.RLock()

if item.Expiration == 0 {
item.mu.RUnlock()
return false
}

item.mu.RUnlock()
return time.Now().UnixNano() > item.Expiration
}

Expand All @@ -39,7 +45,7 @@ type Cache struct {

type cache struct {
defaultExpiration time.Duration
items map[string]Item
items map[string]*Item
mu sync.RWMutex
onEvicted func(string, interface{})
janitor *janitor
Expand All @@ -58,7 +64,7 @@ func (c *cache) Set(k string, x interface{}, d time.Duration) {
e = time.Now().Add(d).UnixNano()
}
c.mu.Lock()
c.items[k] = Item{
c.items[k] = &Item{
Object: x,
Expiration: e,
}
Expand All @@ -75,7 +81,7 @@ func (c *cache) set(k string, x interface{}, d time.Duration) {
if d > 0 {
e = time.Now().Add(d).UnixNano()
}
c.items[k] = Item{
c.items[k] = &Item{
Object: x,
Expiration: e,
}
Expand Down Expand Up @@ -119,18 +125,19 @@ func (c *cache) Replace(k string, x interface{}, d time.Duration) error {
// whether the key was found.
func (c *cache) Get(k string) (interface{}, bool) {
c.mu.RLock()

// "Inlining" of get and Expired
item, found := c.items[k]
if !found {
c.mu.RUnlock()
return nil, false
}
if item.Expiration > 0 {
if time.Now().UnixNano() > item.Expiration {
c.mu.RUnlock()
return nil, false
}

if item.Expired() {
c.mu.RUnlock()
return nil, false
}

c.mu.RUnlock()
return item.Object, true
}
Expand All @@ -141,41 +148,81 @@ func (c *cache) Get(k string) (interface{}, bool) {
// whether the key was found.
func (c *cache) GetWithExpiration(k string) (interface{}, time.Time, bool) {
c.mu.RLock()

// "Inlining" of get and Expired
item, found := c.items[k]
if !found {
c.mu.RUnlock()
return nil, time.Time{}, false
}

item.mu.RLock()

if item.Expiration > 0 {
if time.Now().UnixNano() > item.Expiration {
item.mu.RUnlock()
c.mu.RUnlock()
return nil, time.Time{}, false
}

// Return the item and the expiration time
item.mu.RUnlock()
c.mu.RUnlock()
return item.Object, time.Unix(0, item.Expiration), true
}

// If expiration <= 0 (i.e. no expiration time set) then return the item
// and a zeroed time.Time
c.mu.RUnlock()
return item.Object, time.Time{}, true
}

func (c *cache) get(k string) (interface{}, bool) {
// GetWithExpirationUpdate returns item and updates its cache expiration time
// It returns the item or nil, the expiration time if one is set (if the item
// never expires a zero value for time.Time is returned), and a bool indicating
// whether the key was found.
func (c *cache) GetWithExpirationUpdate(k string, d time.Duration) (interface{}, bool) {
c.mu.RLock()

item, found := c.items[k]
if !found {
c.mu.RUnlock()
return nil, false
}
// "Inlining" of Expired

// Don't call item.Expired() here since
// we write lock item.Expiration
item.mu.Lock()

if item.Expiration > 0 {
if time.Now().UnixNano() > item.Expiration {
item.mu.Unlock()
c.mu.RUnlock()
return nil, false
}
}

if d == DefaultExpiration {
d = c.defaultExpiration
}
if d > 0 {
c.items[k].Expiration = time.Now().Add(d).UnixNano()
}

item.mu.Unlock()
c.mu.RUnlock()
return item.Object, true
}

func (c *cache) get(k string) (interface{}, bool) {
item, found := c.items[k]
if !found {
return nil, false
}
// "Inlining" of Expired
if item.Expired() {
return nil, false
}

return item.Object, true
}

Expand Down Expand Up @@ -968,11 +1015,13 @@ func (c *cache) Save(w io.Writer) (err error) {
}
}()
c.mu.RLock()
defer c.mu.RUnlock()

for _, v := range c.items {
gob.Register(v.Object)
}
err = enc.Encode(&c.items)

c.mu.RUnlock()
return
}

Expand Down Expand Up @@ -1001,18 +1050,19 @@ func (c *cache) SaveFile(fname string) error {
// documentation for NewFrom().)
func (c *cache) Load(r io.Reader) error {
dec := gob.NewDecoder(r)
items := map[string]Item{}
items := map[string]*Item{}
err := dec.Decode(&items)
if err == nil {
c.mu.Lock()
defer c.mu.Unlock()

for k, v := range items {
ov, found := c.items[k]
if !found || ov.Expired() {
c.items[k] = v
}
}
}
c.mu.Unlock()
return err
}

Expand All @@ -1035,20 +1085,22 @@ func (c *cache) LoadFile(fname string) error {
}

// Copies all unexpired items in the cache into a new map and returns it.
func (c *cache) Items() map[string]Item {
func (c *cache) Items() map[string]*Item {
c.mu.RLock()
defer c.mu.RUnlock()
m := make(map[string]Item, len(c.items))
now := time.Now().UnixNano()

m := make(map[string]*Item, len(c.items))
for k, v := range c.items {
// "Inlining" of Expired
if v.Expiration > 0 {
if now > v.Expiration {
continue
}
if v.Expired() {
continue
}
m[k] = &Item{
Object: v.Object,
Expiration: v.Expiration,
}
m[k] = v
}

c.mu.RUnlock()
return m
}

Expand All @@ -1064,7 +1116,7 @@ func (c *cache) ItemCount() int {
// Delete all items from the cache.
func (c *cache) Flush() {
c.mu.Lock()
c.items = map[string]Item{}
c.items = map[string]*Item{}
c.mu.Unlock()
}

Expand Down Expand Up @@ -1099,7 +1151,7 @@ func runJanitor(c *cache, ci time.Duration) {
go j.Run(c)
}

func newCache(de time.Duration, m map[string]Item) *cache {
func newCache(de time.Duration, m map[string]*Item) *cache {
if de == 0 {
de = -1
}
Expand All @@ -1110,7 +1162,7 @@ func newCache(de time.Duration, m map[string]Item) *cache {
return c
}

func newCacheWithJanitor(de time.Duration, ci time.Duration, m map[string]Item) *Cache {
func newCacheWithJanitor(de time.Duration, ci time.Duration, m map[string]*Item) *Cache {
c := newCache(de, m)
// This trick ensures that the janitor goroutine (which--granted it
// was enabled--is running DeleteExpired on c forever) does not keep
Expand All @@ -1131,7 +1183,7 @@ func newCacheWithJanitor(de time.Duration, ci time.Duration, m map[string]Item)
// manually. If the cleanup interval is less than one, expired items are not
// deleted from the cache before calling c.DeleteExpired().
func New(defaultExpiration, cleanupInterval time.Duration) *Cache {
items := make(map[string]Item)
items := make(map[string]*Item)
return newCacheWithJanitor(defaultExpiration, cleanupInterval, items)
}

Expand All @@ -1156,6 +1208,6 @@ func New(defaultExpiration, cleanupInterval time.Duration) *Cache {
// gob.Register() the individual types stored in the cache before encoding a
// map retrieved with c.Items(), and to register those same types before
// decoding a blob containing an items map.
func NewFrom(defaultExpiration, cleanupInterval time.Duration, items map[string]Item) *Cache {
func NewFrom(defaultExpiration, cleanupInterval time.Duration, items map[string]*Item) *Cache {
return newCacheWithJanitor(defaultExpiration, cleanupInterval, items)
}
31 changes: 28 additions & 3 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ func TestCacheTimes(t *testing.T) {
}

func TestNewFrom(t *testing.T) {
m := map[string]Item{
"a": Item{
m := map[string]*Item{
"a": &Item{
Object: 1,
Expiration: 0,
},
"b": Item{
"b": &Item{
Object: 2,
Expiration: 0,
},
Expand Down Expand Up @@ -1769,3 +1769,28 @@ func TestGetWithExpiration(t *testing.T) {
t.Error("expiration for e is in the past")
}
}

func TestGetWithExpirationUpdate(t *testing.T) {
var found bool

tc := New(50*time.Millisecond, 1*time.Millisecond)
tc.Set("a", 1, DefaultExpiration)

<-time.After(25 * time.Millisecond)
_, found = tc.GetWithExpirationUpdate("a", DefaultExpiration)
if !found {
t.Error("item `a` not expired yet")
}

<-time.After(25 * time.Millisecond)
_, found = tc.Get("a")
if !found {
t.Error("item `a` not expired yet")
}

<-time.After(30 * time.Millisecond)
_, found = tc.Get("a")
if found {
t.Error("Found `a` when it should have been automatically deleted")
}
}
6 changes: 3 additions & 3 deletions sharded.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ func (sc *shardedCache) DeleteExpired() {
// fields of the items should be checked. Note that explicit synchronization
// is needed to use a cache and its corresponding Items() return values at
// the same time, as the maps are shared.
func (sc *shardedCache) Items() []map[string]Item {
res := make([]map[string]Item, len(sc.cs))
func (sc *shardedCache) Items() []map[string]*Item {
res := make([]map[string]*Item, len(sc.cs))
for i, v := range sc.cs {
res[i] = v.Items()
}
Expand Down Expand Up @@ -171,7 +171,7 @@ func newShardedCache(n int, de time.Duration) *shardedCache {
for i := 0; i < n; i++ {
c := &cache{
defaultExpiration: de,
items: map[string]Item{},
items: map[string]*Item{},
}
sc.cs[i] = c
}
Expand Down