diff --git a/cache.go b/cache.go index ef5ca860..021f879f 100644 --- a/cache.go +++ b/cache.go @@ -4,6 +4,8 @@ import ( "context" "sync" "time" + + "github.com/redis/rueidis/internal/cache" ) // NewCacheStoreFn can be provided in ClientOption for using a custom CacheStore implementation @@ -28,7 +30,7 @@ type CacheStore interface { // Update is called when receiving the response of the request sent by the above Flight Case 1 from redis. // It should not only update the store but also deliver the response to all CacheEntry.Wait and return a desired client side PXAT of the response. // Note that the server side expire time can be retrieved from RedisMessage.CachePXAT. - Update(key, cmd string, val RedisMessage) (pxat int64) + Update(key, cmd string, val RedisMessage, now time.Time) (pxat int64) // Cancel is called when the request sent by the above Flight Case 1 failed. // It should not only deliver the error to all CacheEntry.Wait but also remove the CacheEntry from the store. Cancel(key, cmd string, err error) @@ -89,7 +91,7 @@ func (a *adapter) Flight(key, cmd string, ttl time.Duration, now time.Time) (Red return RedisMessage{}, flight } -func (a *adapter) Update(key, cmd string, val RedisMessage) (sxat int64) { +func (a *adapter) Update(key, cmd string, val RedisMessage, _ time.Time) (sxat int64) { a.mu.Lock() entries := a.flights[key] if flight, ok := entries[cmd].(*adapterEntry); ok { @@ -99,7 +101,7 @@ func (a *adapter) Update(key, cmd string, val RedisMessage) (sxat int64) { val.setExpireAt(sxat) } a.store.Set(key+cmd, val) - flight.set(val, nil) + flight.setVal(val) entries[cmd] = nil } a.mu.Unlock() @@ -110,7 +112,7 @@ func (a *adapter) Cancel(key, cmd string, err error) { a.mu.Lock() entries := a.flights[key] if flight, ok := entries[cmd].(*adapterEntry); ok { - flight.set(RedisMessage{}, err) + flight.setErr(err) entries[cmd] = nil } a.mu.Unlock() @@ -152,7 +154,7 @@ func (a *adapter) Close(err error) { for _, entries := range flights { for _, e := range entries { if e != nil { - e.(*adapterEntry).set(RedisMessage{}, err) + e.(*adapterEntry).setErr(err) } } } @@ -165,16 +167,97 @@ type adapterEntry struct { xat int64 } -func (a *adapterEntry) set(val RedisMessage, err error) { - a.err, a.val = err, val +func (a *adapterEntry) setVal(val RedisMessage) { + a.val = val + close(a.ch) +} + +func (a *adapterEntry) setErr(err error) { + a.err = err close(a.ch) } func (a *adapterEntry) Wait(ctx context.Context) (RedisMessage, error) { + ctxCh := ctx.Done() + if ctxCh == nil { + <-a.ch + return a.val, a.err + } select { - case <-ctx.Done(): + case <-ctxCh: return RedisMessage{}, ctx.Err() case <-a.ch: return a.val, a.err } } + +// NewChainedCache returns a CacheStore optimized for concurrency, memory efficiency and GC, compared to +// the default client side caching CacheStore. However, it is not yet optimized for DoMultiCache. +func NewChainedCache(limit int) CacheStore { + return &chained{ + flights: cache.NewDoubleMap[*adapterEntry](64), + cache: cache.NewLRUDoubleMap[[]byte](64, int64(limit)), + } +} + +type chained struct { + flights *cache.DoubleMap[*adapterEntry] + cache *cache.LRUDoubleMap[[]byte] +} + +func (f *chained) Flight(key, cmd string, ttl time.Duration, now time.Time) (RedisMessage, CacheEntry) { + ts := now.UnixMilli() + if e, ok := f.cache.Find(key, cmd, ts); ok { + var ret RedisMessage + _ = ret.CacheUnmarshalView(e) + return ret, nil + } + xat := ts + ttl.Milliseconds() + if af, ok := f.flights.FindOrInsert(key, cmd, func() *adapterEntry { + return &adapterEntry{ch: make(chan struct{}), xat: xat} + }); ok { + return RedisMessage{}, af + } + return RedisMessage{}, nil +} + +func (f *chained) Update(key, cmd string, val RedisMessage, now time.Time) (sxat int64) { + if af, ok := f.flights.Find(key, cmd); ok { + sxat = val.getExpireAt() + if af.xat < sxat || sxat == 0 { + sxat = af.xat + val.setExpireAt(sxat) + } + bs := val.CacheMarshal(nil) + f.cache.Insert(key, cmd, int64(len(bs)+len(key)+len(cmd))+int64(cache.LRUEntrySize)+64, sxat, now.UnixMilli(), bs) + if f.flights.Delete(key, cmd) { + af.setVal(val) + } + } + return sxat +} + +func (f *chained) Cancel(key, cmd string, err error) { + if af, ok := f.flights.Find(key, cmd); ok { + if f.flights.Delete(key, cmd) { + af.setErr(err) + } + } +} + +func (f *chained) Delete(keys []RedisMessage) { + if keys == nil { + f.cache.Reset() + } else { + for _, k := range keys { + f.cache.Delete(k.string) + } + } +} + +func (f *chained) Close(err error) { + f.cache.DeleteAll() + f.flights.Close(func(entry *adapterEntry) { + entry.setErr(err) + }) +} diff --git a/cache_test.go b/cache_test.go index 3e6d12f6..acb574da 100644 --- a/cache_test.go +++ b/cache_test.go @@ -32,7 +32,7 @@ func test(t *testing.T, storeFn func() CacheStore) { v = RedisMessage{typ: '+', string: "val"} v.setExpireAt(now.Add(time.Second).UnixMilli()) - if pttl := store.Update("key", "cmd", v); pttl < now.Add(90*time.Millisecond).UnixMilli() || pttl > now.Add(100*time.Millisecond).UnixMilli() { + if pttl := store.Update("key", "cmd", v, now); pttl < now.Add(90*time.Millisecond).UnixMilli() || pttl > now.Add(100*time.Millisecond).UnixMilli() { t.Fatal("Update should return a desired pttl") } @@ -104,8 +104,8 @@ func test(t *testing.T, storeFn func() CacheStore) { } { store.Flight("key", "cmd1", time.Millisecond*100, now) store.Flight("key", "cmd2", time.Millisecond*100, now) - store.Update("key", "cmd1", RedisMessage{typ: '+', string: "val"}) - store.Update("key", "cmd2", RedisMessage{typ: '+', string: "val"}) + store.Update("key", "cmd1", RedisMessage{typ: '+', string: "val"}, now) + store.Update("key", "cmd2", RedisMessage{typ: '+', string: "val"}, now) store.Delete(deletions) @@ -130,7 +130,7 @@ func test(t *testing.T, storeFn func() CacheStore) { v = RedisMessage{typ: '+', string: "val"} v.setExpireAt(now.Add(time.Millisecond).UnixMilli()) - store.Update("key", "cmd", v) + store.Update("key", "cmd", v, now) v, e = store.Flight("key", "cmd", time.Second, now.Add(time.Millisecond)) if v.typ != 0 || e != nil { @@ -183,6 +183,11 @@ func TestCacheStore(t *testing.T) { return NewSimpleCacheAdapter(&simple{store: map[string]RedisMessage{}}) }) }) + t.Run("FlattenCache", func(t *testing.T) { + test(t, func() CacheStore { + return NewChainedCache(DefaultCacheBytes) + }) + }) } type simple struct { diff --git a/internal/cache/chain.go b/internal/cache/chain.go new file mode 100644 index 00000000..fafb781a --- /dev/null +++ b/internal/cache/chain.go @@ -0,0 +1,74 @@ +package cache + +type node[V any] struct { + key string + next *node[V] + val V +} +type chain[V any] struct { + node[V] +} + +func (h *chain[V]) find(key string) (val V, ok bool) { + if h.node.key == key { + return h.node.val, true + } + for curr := h.node.next; curr != nil; curr = curr.next { + if curr.key == key { + return curr.val, true + } + } + return val, ok +} + +func (h *chain[V]) insert(key string, val V) { + if h.node.key == "" { + h.node.key = key + h.node.val = val + } else if h.node.key == key { + h.node.val = val + } else { + n := &node[V]{key: key, val: val} + n.next = h.node.next + h.node.next = n + } +} + +func (h *chain[V]) empty() bool { + return h.node.next == nil && h.node.key == "" +} + +func (h *chain[V]) delete(key string) (bool, bool) { + var zero V + if h.node.key == key { + h.node.key = "" + h.node.val = zero + return h.node.next == nil, true + } + + if h.node.next == nil { + return h.node.key == "", false + } + + if h.node.next.key == key { + h.node.next.key = "" + h.node.next.val = zero + h.node.next, h.node.next.next = h.node.next.next, nil + return h.empty(), true + } + + prev := h.node.next + curr := h.node.next.next + deleted := false + for curr != nil { + if curr.key == key { + curr.key = "" + curr.val = zero + prev.next, curr.next = curr.next, nil + deleted = true + break + } + prev, curr = curr, curr.next + } + return h.empty(), deleted +} diff --git a/internal/cache/chain_test.go b/internal/cache/chain_test.go new file mode 100644 index 00000000..2e191f3b --- /dev/null +++ b/internal/cache/chain_test.go @@ -0,0 +1,64 @@ +package cache + +import ( + "testing" +) + +func TestChain(t *testing.T) { + h := chain[int]{} + if h.empty() != true { + t.Fatal("chain is not empty") + } + if _, ok := h.find("any"); ok { + t.Fatal("value is found") + } + if empty, deleted := h.delete("any"); !empty || deleted { + t.Fatal("not empty") + } + h.insert("1", 1) + h.insert("2", 2) + h.insert("3", 3) + if v, ok := h.find("1"); !ok || v != 1 { + t.Fatal("value is not found") + } + if v, ok := h.find("2"); !ok || v != 2 { + t.Fatal("value is not found") + } + if v, ok := h.find("3"); !ok || v != 3 { + t.Fatal("value is not found") + } + if empty, deleted := h.delete("1"); empty || !deleted { + t.Fatal("empty") + } + if _, ok := h.find("1"); ok { + t.Fatal("value is found") + } + if v, ok := h.find("2"); !ok || v != 2 { + t.Fatal("value is not found") + } + if v, ok := h.find("3"); !ok || v != 3 { + t.Fatal("value is not found") + } + if empty, deleted := h.delete("2"); empty || !deleted { + t.Fatal("empty") + } + if _, ok := h.find("2"); ok { + t.Fatal("value is found") + } + if v, ok := h.find("3"); !ok || v != 3 { + t.Fatal("value is not found") + } + h.insert("4", 4) + if v, ok := h.find("3"); !ok || v != 3 { + t.Fatal("value is not found") + } + if v, ok := h.find("4"); !ok || v != 4 { + t.Fatal("value is not found") + } + if empty, deleted := h.delete("3"); empty || !deleted { + t.Fatal("empty") + } + if empty, deleted := h.delete("4"); !empty || !deleted { + t.Fatal("not empty") + } +} diff --git a/internal/cache/double.go b/internal/cache/double.go new file mode 100644 index 00000000..c1f78f86 --- /dev/null +++ b/internal/cache/double.go @@ -0,0 +1,137 @@ +package cache + +import ( + "runtime" + "sync" +) + +const bpsize = 1024 + +type head[V any] struct { + chain[V] + mu sync.RWMutex +} + +type DoubleMap[V any] struct { + ma map[string]*head[V] + bp sync.Pool + mu sync.RWMutex +} + +func (m *DoubleMap[V]) Find(key1, key2 string) (val V, ok bool) { + m.mu.RLock() + if h := m.ma[key1]; h != nil { + h.mu.RLock() + val, ok = h.find(key2) + h.mu.RUnlock() + } + m.mu.RUnlock() + return +} + +func (m *DoubleMap[V]) FindOrInsert(key1, key2 string, fn func() V) (val V, ok bool) { + m.mu.RLock() + if h := m.ma[key1]; h != nil { + h.mu.Lock() + if val, ok = h.find(key2); !ok { + val = fn() + h.insert(key2, val) + } + h.mu.Unlock() + m.mu.RUnlock() + return + } + if m.ma == nil { + m.mu.RUnlock() + return + } + m.mu.RUnlock() + m.mu.Lock() + h := m.ma[key1] + if h != nil { + if val, ok = h.find(key2); ok { + m.mu.Unlock() + return + } + } else if m.ma == nil { + m.mu.Unlock() + return + } else { + h = &head[V]{} + m.ma[key1] = h + } + val = fn() + h.insert(key2, val) + m.mu.Unlock() + return +} + +func (m *DoubleMap[V]) Delete(key1, key2 string) (deleted bool) { + var empty bool + m.mu.RLock() + if h := m.ma[key1]; h != nil { + h.mu.Lock() + empty, deleted = h.delete(key2) + h.mu.Unlock() + } + m.mu.RUnlock() + if empty { + e := m.bp.Get().(*empties) + e.s = append(e.s, key1) + if len(e.s) < bpsize { + m.bp.Put(e) + return + } + go func(m *DoubleMap[V], e *empties) { + m.delete(e.s) + clear(e.s) + e.s = e.s[:0] + m.bp.Put(e) + }(m, e) + } + return +} + +func (m *DoubleMap[V]) delete(keys []string) { + m.mu.Lock() + for _, key := range keys { + if h := m.ma[key]; h != nil { + if h.empty() { + delete(m.ma, key) + } + } + } + m.mu.Unlock() +} + +func (m *DoubleMap[V]) Close(cb func(V)) { + m.mu.Lock() + for _, h := range m.ma { + if h.node.key != "" { + cb(h.node.val) + } + for curr := h.node.next; curr != nil; curr = curr.next { + cb(curr.val) + } + } + m.ma = nil + m.mu.Unlock() +} + +type empties struct { + s []string +} + +func NewDoubleMap[V any](hint int) *DoubleMap[V] { + m := &DoubleMap[V]{ma: make(map[string]*head[V], hint)} + m.bp.New = func() any { + e := &empties{s: make([]string, 0, bpsize)} + runtime.SetFinalizer(e, func(e *empties) { + if len(e.s) >= 0 { + m.delete(e.s) + } + }) + return e + } + return m +} diff --git a/internal/cache/double_test.go b/internal/cache/double_test.go new file mode 100644 index 00000000..7d94a3b2 --- /dev/null +++ b/internal/cache/double_test.go @@ -0,0 +1,94 @@ +package cache + +import ( + "runtime" + "strconv" + "testing" +) + +func TestDoubleMap(t *testing.T) { + m := NewDoubleMap[int](8) + if _, ok := m.Find("1", "2"); ok { + t.Fatalf("should not find 1 2") + } + if v, ok := m.FindOrInsert("1", "a", func() int { + return 1 + }); ok || v != 1 { + t.Fatalf("should insert 1 but not found") + } + if v, ok := m.FindOrInsert("1", "a", func() int { + return 2 + }); !ok || v != 1 { + t.Fatalf("should found 1") + } + m.Delete("1", "a") + if _, ok := m.Find("1", "2"); ok { + t.Fatalf("should not find 1 2") + } + if v, ok := m.FindOrInsert("1", "a", func() int { + return 2 + }); ok || v != 2 { + t.Fatalf("should insert 1 but not found") + } + if v, ok := m.FindOrInsert("1", "b", func() int { + return 2 + }); ok || v != 2 { + t.Fatalf("should insert 1 but not found") + } + if v, ok := m.FindOrInsert("2", "b", func() int { + return 2 + }); ok || v != 2 { + t.Fatalf("should insert 1 but not found") + } + c := 0 + m.Close(func(i int) { + if i != 2 { + t.Fatalf("should iterate 2") + } + c++ + }) + if c != 3 { + t.Fatalf("should iterate 3 times") + } +} + +func TestDoubleMap_Delete(t *testing.T) { + m := NewDoubleMap[int](bpsize) + for i := 0; i < bpsize; i++ { + m.FindOrInsert(strconv.Itoa(i), "a", func() int { + return 1 + }) + } + for i := 0; i < bpsize-1; i++ { + m.Delete(strconv.Itoa(i), "a") + } + m.Delete(strconv.Itoa(bpsize-1), "a") + runtime.GC() + runtime.GC() + m.mu.Lock() + heads := len(m.ma) + m.mu.Unlock() + if heads != 0 { + t.Fatalf("no shrink") + } +} + +func TestDoubleMap_DeleteGC(t *testing.T) { + m := NewDoubleMap[int](bpsize) + for i := 0; i < bpsize; i++ { + m.FindOrInsert(strconv.Itoa(i), "a", func() int { + return 1 + }) + } + for i := 0; i < bpsize-1; i++ { + m.Delete(strconv.Itoa(i), "a") + } + runtime.GC() + runtime.GC() + m.mu.Lock() + heads := len(m.ma) + m.mu.Unlock() + if heads != 1 { + t.Fatalf("no shrink") + } +} diff --git a/internal/cache/lru.go b/internal/cache/lru.go new file mode 100644 index 00000000..d9ada380 --- /dev/null +++ b/internal/cache/lru.go @@ -0,0 +1,253 @@ +package cache + +import ( + "runtime" + "sync" + "sync/atomic" + "unsafe" +) + +const LRUEntrySize = unsafe.Sizeof(linked[[]byte]{}) + +type linked[V any] struct { + key string + head chain[V] + next unsafe.Pointer + prev unsafe.Pointer + size int64 + ts int64 + mu sync.RWMutex + cnt uint32 + mark int32 +} + +func (h *linked[V]) find(key string, ts int64) (v V, ok bool) { + h.mu.RLock() + defer h.mu.RUnlock() + if h.ts > ts { + return h.head.find(key) + } + return +} + +func (h *linked[V]) close() { + h.mu.Lock() + h.ts = 0 + h.head = chain[V]{} + h.mu.Unlock() +} + +type LRUDoubleMap[V any] struct { + ma map[string]*linked[V] + mi []string + bp sync.Pool + mu sync.RWMutex + head unsafe.Pointer + tail unsafe.Pointer + total int64 + limit int64 + mark int32 +} + +func (m *LRUDoubleMap[V]) Find(key1, key2 string, ts int64) (val V, ok bool) { + m.mu.RLock() + h := m.ma[key1] + if h != nil { + val, ok = h.find(key2, ts) + } + m.mu.RUnlock() + if ok && atomic.AddUint32(&h.cnt, 1)&3 == 0 { + b := m.bp.Get().(*ruBatch[V]) + b.s = append(b.s, h) + if len(b.s) < bpsize { + m.bp.Put(b) + return + } + go func(m *LRUDoubleMap[V], b *ruBatch[V]) { + m.moveToTail(b.s) + clear(b.s) + b.s = b.s[:0] + m.bp.Put(b) + }(m, b) + } + return +} + +func (m *LRUDoubleMap[V]) remove(h *linked[V]) { + h.mark -= 1 + next := h.next + prev := h.prev + h.next = nil + h.prev = nil + if next != nil { + (*linked[V])(next).prev = prev + } + if prev != nil { + (*linked[V])(prev).next = next + } + if m.head == unsafe.Pointer(h) { + m.head = next + } + if m.tail == unsafe.Pointer(h) { + m.tail = prev + } + atomic.AddInt64(&m.total, -h.size) + delete(m.ma, h.key) +} + +func (m *LRUDoubleMap[V]) move(h *linked[V]) { + prev := h.prev + next := h.next + if prev != nil { + (*linked[V])(prev).next = next + } + if next != nil { + (*linked[V])(next).prev = prev + } + h.next = nil + if m.tail != nil && m.tail != unsafe.Pointer(h) { + h.prev = m.tail + (*linked[V])(m.tail).next = unsafe.Pointer(h) + } + m.tail = unsafe.Pointer(h) + if m.head == unsafe.Pointer(h) && next != nil { + m.head = next + } +} + +func (m *LRUDoubleMap[V]) Insert(key1, key2 string, size, ts, now int64, v V) { + m.mu.RLock() + if h := m.ma[key1]; h != nil { + atomic.AddInt64(&m.total, size) + h.mu.Lock() + if h.ts <= now { + atomic.AddInt64(&m.total, -h.size) + h.size = 0 + h.head = chain[V]{} + } + h.ts = ts + h.size += size + h.head.insert(key2, v) + h.mu.Unlock() + m.mu.RUnlock() + return + } + m.mu.RUnlock() + m.mu.Lock() + if m.ma == nil { + m.mu.Unlock() + return + } + atomic.AddInt64(&m.total, size) + for m.head != nil { + h := (*linked[V])(m.head) + if h.ts != 0 && h.ts > now && atomic.LoadInt64(&m.total) <= m.limit { + break + } + m.remove(h) + } + + h := &linked[V]{key: key1, ts: ts, size: size, mark: m.mark} + h.head.insert(key2, v) + m.ma[key1] = h // h must not exist in the map because this Insert is called sequentially. + m.move(h) + if m.head == nil { + m.head = unsafe.Pointer(h) + } + m.mu.Unlock() +} + +func (m *LRUDoubleMap[V]) Delete(key1 string) { + if m.mi == nil { // no need to lock m.mi because this Delete is called sequentially. + m.mi = make([]string, 0, bpsize) + } else if len(m.mi) == bpsize { + m.mu.Lock() + for _, key := range m.mi { + if h := m.ma[key]; h != nil && h.ts == 0 { + m.remove(h) + } + } + if h := m.ma[key1]; h != nil { + m.remove(h) + } + for m.head != nil { + h := (*linked[V])(m.head) + if h.ts != 0 && atomic.LoadInt64(&m.total) <= m.limit { + break + } + m.remove(h) + } + m.mu.Unlock() + clear(m.mi) + return + } + m.mu.RLock() + h := m.ma[key1] + if h != nil { + h.close() + } + m.mu.RUnlock() + if h != nil { + m.mi = append(m.mi, key1) + } +} + +func (m *LRUDoubleMap[V]) DeleteAll() { + m.mu.Lock() + m.ma = nil + m.mi = nil + m.head = nil + m.tail = nil + atomic.StoreInt64(&m.total, 0) + m.mark++ + m.mu.Unlock() +} + +func (m *LRUDoubleMap[V]) Reset() { + m.mu.Lock() + m.ma = make(map[string]*linked[V], len(m.ma)) + m.mi = nil + m.head = nil + m.tail = nil + atomic.StoreInt64(&m.total, 0) + m.mark++ + m.mu.Unlock() +} + +func (m *LRUDoubleMap[V]) moveToTail(s []*linked[V]) { + m.mu.Lock() + defer m.mu.Unlock() + for _, h := range s { + if h.mark == m.mark { + m.move(h) + } + } + for m.head != nil { + h := (*linked[V])(m.head) + if h.ts != 0 && atomic.LoadInt64(&m.total) <= m.limit { + break + } + m.remove(h) + } +} + +type ruBatch[V any] struct { + s []*linked[V] +} + +func NewLRUDoubleMap[V any](hint, limit int64) *LRUDoubleMap[V] { + m := &LRUDoubleMap[V]{ + ma: make(map[string]*linked[V], hint), + limit: limit, + } + m.bp.New = func() interface{} { + b := &ruBatch[V]{s: make([]*linked[V], 0, bpsize)} + runtime.SetFinalizer(b, func(b *ruBatch[V]) { + if len(b.s) > 0 { + m.moveToTail(b.s) + } + }) + return b + } + return m +} diff --git a/internal/cache/lru_test.go b/internal/cache/lru_test.go new file mode 100644 index 00000000..b35878d2 --- /dev/null +++ b/internal/cache/lru_test.go @@ -0,0 +1,239 @@ +package cache + +import ( + "runtime" + "strconv" + "sync/atomic" + "testing" +) + +func TestLRUDoubleMap(t *testing.T) { + m := NewLRUDoubleMap[int](bpsize, bpsize) + if _, ok := m.Find("1", "a", 1); ok { + t.Fatal("should not find 1") + } + m.Insert("1", "a", 1, 2, 1, 1) + m.Insert("1", "b", 1, 2, 1, 2) + m.Insert("2", "c", 1, 2, 1, 3) + if v, ok := m.Find("1", "a", 1); !ok || v != 1 { + t.Fatal("not find 1") + } + if v, ok := m.Find("1", "b", 1); !ok || v != 2 { + t.Fatal("not find 2") + } + if v, ok := m.Find("2", "c", 1); !ok || v != 3 { + t.Fatal("not find 3") + } + if _, ok := m.Find("1", "a", 2); ok { + t.Fatal("should not find") + } + if _, ok := m.Find("1", "b", 2); ok { + t.Fatal("should not find") + } + if _, ok := m.Find("2", "c", 2); ok { + t.Fatal("should not find") + } + m.Delete("1") + if _, ok := m.Find("1", "a", 1); ok { + t.Fatal("should not find") + } + if _, ok := m.Find("1", "b", 1); ok { + t.Fatal("should not find") + } + if v, ok := m.Find("2", "c", 1); !ok || v != 3 { + t.Fatal("not find 3") + } + m.Delete("2") + m.mu.Lock() + heads := len(m.ma) + m.mu.Unlock() + if heads != 2 { + t.Fatal("should have 2 heads") + } + + m.Insert("1", "d", 1, 2, 1, 1) + m.Insert("1", "e", 1, 2, 1, 2) + m.Insert("2", "f", 1, 2, 1, 3) + if v, ok := m.Find("1", "d", 1); !ok || v != 1 { + t.Fatal("not find 1") + } + if v, ok := m.Find("1", "e", 1); !ok || v != 2 { + t.Fatal("not find 2") + } + if v, ok := m.Find("2", "f", 1); !ok || v != 3 { + t.Fatal("not find 3") + } + m.DeleteAll() + if _, ok := m.Find("1", "d", 1); ok { + t.Fatal("should not find") + } + if _, ok := m.Find("1", "e", 1); ok { + t.Fatal("should not find") + } + if _, ok := m.Find("2", "f", 1); ok { + t.Fatal("should not find") + } +} + +func TestLRUDoubleMap_BatchDelete(t *testing.T) { + m := NewLRUDoubleMap[int](bpsize, bpsize) + m.Insert("1", "a", 1, 2, 1, 1) + m.Insert("2", "c", 1, 2, 1, 3) + m.Delete("1") + m.Delete("2") + if _, ok := m.Find("1", "a", 1); ok { + t.Fatal("should not find") + } + if _, ok := m.Find("2", "c", 1); ok { + t.Fatal("should not find") + } + m.mu.Lock() + heads := len(m.ma) + total := atomic.LoadInt64(&m.total) + m.mu.Unlock() + if heads != 2 { + t.Fatal("should have 2 heads") + } + if total == 0 { + t.Fatal("should not have 0 total") + } + for i := 0; i < bpsize; i++ { + m.Delete("1") + m.Delete("2") + } + m.mu.Lock() + heads = len(m.ma) + total = atomic.LoadInt64(&m.total) + m.mu.Unlock() + if heads != 0 { + t.Fatal("should have 0 heads") + } + if total != 0 { + t.Fatal("should have 0 total") + } +} + +func TestLRUCache_LRU_1(t *testing.T) { + m := NewLRUDoubleMap[int](bpsize, bpsize) + for i := 0; i < bpsize; i++ { + m.Insert(strconv.Itoa(i), "a", 2, 2, 1, i) + } + m.mu.Lock() + heads := len(m.ma) + m.mu.Unlock() + if heads != (bpsize / 2) { + t.Fatal("should have bpsize/2 heads", heads) + } + for i := 0; i < bpsize/2; i++ { + if _, ok := m.Find(strconv.Itoa(i), "a", 1); ok { + t.Fatal("should not find") + } + } + for i := bpsize / 2; i < bpsize; i++ { + if v, ok := m.Find(strconv.Itoa(i), "a", 1); !ok || v != i { + t.Fatal("not find") + } + } +} + +func TestLRUCache_LRU_2(t *testing.T) { + m := NewLRUDoubleMap[int](bpsize*2, bpsize*2) + for i := 0; i < bpsize*2; i++ { + m.Insert(strconv.Itoa(i), "a", 1, 2, 1, i) + } + m.mu.Lock() + heads := len(m.ma) + m.mu.Unlock() + if heads != (bpsize * 2) { + t.Fatal("should have bpsize*2 heads", heads) + } + for i := 0; i < bpsize; i++ { + for j := 0; j < 4; j++ { + if v, ok := m.Find(strconv.Itoa(i), "a", 1); !ok || v != i { + t.Fatal("not find") + } + } + } + runtime.GC() + runtime.GC() + for i := bpsize * 2; i < bpsize*3; i++ { + m.Insert(strconv.Itoa(i), "a", 1, 2, 1, i) + } + for i := 0; i < bpsize; i++ { + if v, ok := m.Find(strconv.Itoa(i), "a", 1); !ok || v != i { + t.Fatal("not find", v, ok) + } + } + for i := bpsize * 1; i < bpsize*2; i++ { + if _, ok := m.Find(strconv.Itoa(i), "a", 1); ok { + t.Fatal("should not find") + } + } + for i := bpsize * 2; i < bpsize*3; i++ { + if v, ok := m.Find(strconv.Itoa(i), "a", 1); !ok || v != i { + t.Fatal("not find", v, ok) + } + } +} + +func TestLRUCache_LRU_GC(t *testing.T) { + m := NewLRUDoubleMap[int](bpsize, bpsize) + for i := 0; i < bpsize; i++ { + m.Insert(strconv.Itoa(i), "a", 1, 2, 1, i) + } + for j := 0; j < 4; j++ { + if v, ok := m.Find(strconv.Itoa(bpsize/2), "a", 1); !ok || v != bpsize/2 { + t.Fatal("not find") + } + } + runtime.GC() + runtime.GC() + m.Insert("a", "a", bpsize-1, 2, 1, 0) + m.mu.Lock() + heads := len(m.ma) + total := m.total + m.mu.Unlock() + if heads != 2 { + t.Fatal("should have 2 heads", heads) + } + if total != bpsize { + t.Fatal("should have bpsize", bpsize) + } + for i := 0; i < bpsize; i++ { + if i == bpsize/2 { + if v, ok := m.Find(strconv.Itoa(i), "a", 1); !ok || v != i { + t.Fatal("not find") + } + } else { + if _, ok := m.Find(strconv.Itoa(i), "a", 1); ok { + t.Fatal("should not find") + } + } + } +} + +func TestLRUCache_LRU_GC_2(t *testing.T) { + m := NewLRUDoubleMap[int](bpsize, bpsize) + for i := 0; i < bpsize; i++ { + m.Insert(strconv.Itoa(i), "a", 1, 2, 1, i) + } + for j := 0; j < 4; j++ { + if v, ok := m.Find(strconv.Itoa(bpsize/2), "a", 1); !ok || v != bpsize/2 { + t.Fatal("not find") + } + } + m.Reset() + runtime.GC() + runtime.GC() + m.Insert("a", "a", bpsize-1, 2, 1, 0) + m.mu.Lock() + heads := len(m.ma) + total := m.total + m.mu.Unlock() + if heads != 1 { + t.Fatal("should have 1 heads", heads) + } + if total != bpsize-1 { + t.Fatal("should have bpsize-1", bpsize-1) + } +} diff --git a/lru.go b/lru.go index 897e6e96..180b0b98 100644 --- a/lru.go +++ b/lru.go @@ -218,7 +218,7 @@ func (c *lru) Flights(now time.Time, multi []CacheableTTL, results []RedisResult return missed[:j] } -func (c *lru) Update(key, cmd string, value RedisMessage) (pxat int64) { +func (c *lru) Update(key, cmd string, value RedisMessage, _ time.Time) (pxat int64) { var ch chan struct{} c.mu.Lock() if kc, ok := c.store[key]; ok { diff --git a/lru_test.go b/lru_test.go index a6745ed9..8e7996ff 100644 --- a/lru_test.go +++ b/lru_test.go @@ -26,7 +26,7 @@ func TestLRU(t *testing.T) { } m := RedisMessage{typ: '+', string: "0", values: []RedisMessage{{}}} m.setExpireAt(time.Now().Add(PTTL * time.Millisecond).UnixMilli()) - store.Update("0", "GET", m) + store.Update("0", "GET", m, time.Now()) return store.(*lru) } @@ -49,7 +49,7 @@ func TestLRU(t *testing.T) { t.Fatalf("got unexpected value from the Flight after pttl: %v %v", v, entry) } m := RedisMessage{typ: '+', string: "1"} - lru.Update("1", "GET", m) + lru.Update("1", "GET", m, time.Now()) if v, _ := lru.Flight("1", "GET", TTL, time.Now()); v.typ == 0 { t.Fatalf("did not get the value from the second Flight") } else if v.string != "1" { @@ -98,7 +98,7 @@ func TestLRU(t *testing.T) { lru.Flight(strconv.Itoa(i), "GET", TTL, time.Now()) m := RedisMessage{typ: '+', string: strconv.Itoa(i)} m.setExpireAt(time.Now().Add(PTTL * time.Millisecond).UnixMilli()) - lru.Update(strconv.Itoa(i), "GET", m) + lru.Update(strconv.Itoa(i), "GET", m, time.Now()) } if v, entry := lru.Flight("1", "GET", TTL, time.Now()); v.typ != 0 { t.Fatalf("got evicted value from the first Flight: %v %v", v, entry) @@ -123,7 +123,7 @@ func TestLRU(t *testing.T) { for i := 1; i < Entries; i++ { lru.Flight(strconv.Itoa(i), "GET", TTL, time.Now()) m := RedisMessage{typ: '+', string: strconv.Itoa(i)} - lru.Update(strconv.Itoa(i), "GET", m) + lru.Update(strconv.Itoa(i), "GET", m, time.Now()) } for i := 1; i < Entries; i++ { if v, _ := lru.Flight(strconv.Itoa(i), "GET", TTL, time.Now()); v.string != strconv.Itoa(i) { @@ -157,7 +157,7 @@ func TestLRU(t *testing.T) { m := RedisMessage{typ: '+', string: "this Update should have no effect"} m.setExpireAt(time.Now().Add(PTTL * time.Millisecond).UnixMilli()) - lru.Update("1", "GET", m) + lru.Update("1", "GET", m, time.Now()) for i := 0; i < 2; i++ { // entry should be always nil after the first call if Close if v, entry := lru.Flight("1", "GET", TTL, time.Now()); v.typ != 0 || entry != nil { t.Fatalf("got unexpected value from the first Flight: %v %v", v, entry) @@ -194,7 +194,7 @@ func TestLRU(t *testing.T) { lru.Flight("key", "cmd", time.Second, time.Now()) m := RedisMessage{typ: 1} m.setExpireAt(time.Now().Add(time.Second).UnixMilli()) - lru.Update("key", "cmd", m) + lru.Update("key", "cmd", m, time.Now()) if v := lru.GetTTL("key", "cmd"); !roughly(v, time.Second) { t.Fatalf("unexpected %v", v) } @@ -206,7 +206,7 @@ func TestLRU(t *testing.T) { lru.Flight("key", "cmd", 2*time.Second, time.Now()) m := RedisMessage{typ: 1} m.setExpireAt(time.Now().Add(time.Second).UnixMilli()) - lru.Update("key", "cmd", m) + lru.Update("key", "cmd", m, time.Now()) if v, _ := lru.Flight("key", "cmd", 2*time.Second, time.Now()); v.CacheTTL() != 1 { t.Fatalf("unexpected %v", v.CacheTTL()) } @@ -216,7 +216,7 @@ func TestLRU(t *testing.T) { lru.Flight("key", "cmd", 2*time.Second, time.Now()) m := RedisMessage{typ: 1} m.setExpireAt(time.Now().Add(3 * time.Second).UnixMilli()) - lru.Update("key", "cmd", m) + lru.Update("key", "cmd", m, time.Now()) if v, _ := lru.Flight("key", "cmd", 2*time.Second, time.Now()); v.CacheTTL() != 2 { t.Fatalf("unexpected %v", v.CacheTTL()) } @@ -225,7 +225,7 @@ func TestLRU(t *testing.T) { lru := setup(t) lru.Flight("key", "cmd", 2*time.Second, time.Now()) m := RedisMessage{typ: 1} - lru.Update("key", "cmd", m) + lru.Update("key", "cmd", m, time.Now()) if v, _ := lru.Flight("key", "cmd", 2*time.Second, time.Now()); v.CacheTTL() != 2 { t.Fatalf("unexpected %v", v.CacheTTL()) } @@ -234,7 +234,7 @@ func TestLRU(t *testing.T) { lru := setup(t) lru.Flight("key", "cmd", 2*time.Second, time.Now()) m := RedisMessage{typ: 1} - lru.Update("key", "cmd", m) + lru.Update("key", "cmd", m, time.Now()) if v, _ := lru.Flight("key", "cmd", 2*time.Second, time.Now()); v.CacheTTL() != 2 { t.Fatalf("unexpected %v", v.CacheTTL()) } @@ -260,7 +260,7 @@ func TestLRU(t *testing.T) { t.Fatalf("got unexpected value from the Flight after pttl: %v %v", v, entry) } m := RedisMessage{typ: '+', string: "1"} - lru.Update("1", "GET", m) + lru.Update("1", "GET", m, time.Now()) if v, _ := flights(lru, time.Now(), TTL, "GET", "1"); v.typ == 0 { t.Fatalf("did not get the value from the second Flight") } else if v.string != "1" { @@ -309,7 +309,7 @@ func TestLRU(t *testing.T) { flights(lru, time.Now(), TTL, "GET", strconv.Itoa(i)) m := RedisMessage{typ: '+', string: strconv.Itoa(i)} m.setExpireAt(time.Now().Add(PTTL * time.Millisecond).UnixMilli()) - lru.Update(strconv.Itoa(i), "GET", m) + lru.Update(strconv.Itoa(i), "GET", m, time.Now()) } if v, entry := flights(lru, time.Now(), TTL, "GET", "1"); v.typ != 0 { t.Fatalf("got evicted value from the first Flight: %v %v", v, entry) @@ -334,7 +334,7 @@ func TestLRU(t *testing.T) { for i := 1; i < Entries; i++ { flights(lru, time.Now(), TTL, "GET", strconv.Itoa(i)) m := RedisMessage{typ: '+', string: strconv.Itoa(i)} - lru.Update(strconv.Itoa(i), "GET", m) + lru.Update(strconv.Itoa(i), "GET", m, time.Now()) } for i := 1; i < Entries; i++ { if v, _ := flights(lru, time.Now(), TTL, "GET", strconv.Itoa(i)); v.string != strconv.Itoa(i) { @@ -368,7 +368,7 @@ func TestLRU(t *testing.T) { m := RedisMessage{typ: '+', string: "this Update should have no effect"} m.setExpireAt(time.Now().Add(PTTL * time.Millisecond).UnixMilli()) - lru.Update("1", "GET", m) + lru.Update("1", "GET", m, time.Now()) for i := 0; i < 2; i++ { // entry should be always nil after the first call if Close if v, entry := flights(lru, time.Now(), TTL, "GET", "1"); v.typ != 0 || entry != nil { t.Fatalf("got unexpected value from the first Flight: %v %v", v, entry) @@ -405,7 +405,7 @@ func TestLRU(t *testing.T) { flights(lru, time.Now(), time.Second, "cmd", "key") m := RedisMessage{typ: 1} m.setExpireAt(time.Now().Add(time.Second).UnixMilli()) - lru.Update("key", "cmd", m) + lru.Update("key", "cmd", m, time.Now()) if v := lru.GetTTL("key", "cmd"); !roughly(v, time.Second) { t.Fatalf("unexpected %v", v) } @@ -417,7 +417,7 @@ func TestLRU(t *testing.T) { flights(lru, time.Now(), time.Second*2, "cmd", "key") m := RedisMessage{typ: 1} m.setExpireAt(time.Now().Add(time.Second).UnixMilli()) - lru.Update("key", "cmd", m) + lru.Update("key", "cmd", m, time.Now()) if v, _ := flights(lru, time.Now(), time.Second*2, "cmd", "key"); v.CacheTTL() != 1 { t.Fatalf("unexpected %v", v.CacheTTL()) } @@ -427,7 +427,7 @@ func TestLRU(t *testing.T) { flights(lru, time.Now(), time.Second*2, "cmd", "key") m := RedisMessage{typ: 1} m.setExpireAt(time.Now().Add(3 * time.Second).UnixMilli()) - lru.Update("key", "cmd", m) + lru.Update("key", "cmd", m, time.Now()) if v, _ := flights(lru, time.Now(), time.Second*2, "cmd", "key"); v.CacheTTL() != 2 { t.Fatalf("unexpected %v", v.CacheTTL()) } @@ -436,7 +436,7 @@ func TestLRU(t *testing.T) { lru := setup(t) flights(lru, time.Now(), time.Second*2, "cmd", "key") m := RedisMessage{typ: 1} - lru.Update("key", "cmd", m) + lru.Update("key", "cmd", m, time.Now()) if v, _ := flights(lru, time.Now(), time.Second*2, "cmd", "key"); v.CacheTTL() != 2 { t.Fatalf("unexpected %v", v.CacheTTL()) } @@ -445,7 +445,7 @@ func TestLRU(t *testing.T) { lru := setup(t) flights(lru, time.Now(), time.Second*2, "cmd", "key") m := RedisMessage{typ: 1} - lru.Update("key", "cmd", m) + lru.Update("key", "cmd", m, time.Now()) if v, _ := flights(lru, time.Now(), time.Second*2, "cmd", "key"); v.CacheTTL() != 2 { t.Fatalf("unexpected %v", v.CacheTTL()) } @@ -483,7 +483,7 @@ func BenchmarkLRU(b *testing.B) { lru.Flight(key, "GET", TTL, time.Now()) m := RedisMessage{} m.setExpireAt(time.Now().Add(PTTL * time.Millisecond).UnixMilli()) - lru.Update(key, "GET", m) + lru.Update(key, "GET", m, time.Now()) } }) } diff --git a/pipe.go b/pipe.go index 7c401e7b..1662eba7 100644 --- a/pipe.go +++ b/pipe.go @@ -569,7 +569,7 @@ func (p *pipe) _backgroundRead() (err error) { if pttl := msg.values[i].integer; pttl >= 0 { cp.setExpireAt(now.Add(time.Duration(pttl) * time.Millisecond).UnixMilli()) } - msgs[i].setExpireAt(p.cache.Update(ck, cc, cp)) + msgs[i].setExpireAt(p.cache.Update(ck, cc, cp, now)) } } else { ck, cc := cmds.CacheKey(cacheable) @@ -579,7 +579,7 @@ func (p *pipe) _backgroundRead() (err error) { if pttl := msg.values[ci-1].integer; pttl >= 0 { cp.setExpireAt(now.Add(time.Duration(pttl) * time.Millisecond).UnixMilli()) } - msg.values[ci].setExpireAt(p.cache.Update(ck, cc, cp)) + msg.values[ci].setExpireAt(p.cache.Update(ck, cc, cp, now)) } } if prply { diff --git a/pipe_test.go b/pipe_test.go index 6e187c6e..7bd4fc2f 100644 --- a/pipe_test.go +++ b/pipe_test.go @@ -1779,7 +1779,7 @@ func TestClientSideCachingWithSideChannelMGet(t *testing.T) { time.Sleep(100 * time.Millisecond) m := RedisMessage{typ: '+', string: "OK"} m.setExpireAt(time.Now().Add(10 * time.Millisecond).UnixMilli()) - p.cache.Update("a1", "GET", m) + p.cache.Update("a1", "GET", m, time.Now()) }() v, _ := p.DoCache(context.Background(), Cacheable(cmds.NewMGetCompleted([]string{"MGET", "a1"})), 10*time.Second).AsStrSlice() @@ -2141,7 +2141,7 @@ func TestClientSideCachingWithSideChannelDoMultiCache(t *testing.T) { time.Sleep(100 * time.Millisecond) m := RedisMessage{typ: '+', string: "OK"} m.setExpireAt(time.Now().Add(10 * time.Millisecond).UnixMilli()) - p.cache.Update("a1", "GET", m) + p.cache.Update("a1", "GET", m, time.Now()) }() arr := p.DoMultiCache(context.Background(), []CacheableTTL{