8 changes: 8 additions & 0 deletions optimizely/odp/lru_cache.py
@@ -91,6 +91,14 @@ def peek(self, key: K) -> Optional[V]:
        element = self.map.get(key)
        return element.value if element is not None else None

    def remove(self, key: K) -> None:
        """Remove the element associated with the provided key from the cache."""
        if self.capacity <= 0:
            return

        with self.lock:
            self.map.pop(key, None)


@dataclass
class CacheElement(Generic[V]):
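
A minimal usage sketch of the new remove() method, assuming the LRUCache(max_size, timeout) constructor arguments used in the tests below; the key and value here are illustrative only:

    from optimizely.odp.lru_cache import LRUCache

    cache = LRUCache(100, 600)  # assumed arguments: max size and timeout, as in the tests below
    cache.save("user-123", ["segment-a"])

    cache.remove("user-123")     # drop a single entry without resetting the whole cache
    assert cache.lookup("user-123") is None

    cache.remove("missing-key")  # removing an absent key is a no-op
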
81 changes: 81 additions & 0 deletions tests/test_lru_cache.py
@@ -130,6 +130,87 @@ def test_reset(self):
        cache.save('cow', 'crate')
        self.assertEqual(cache.lookup('cow'), 'crate')

    def test_remove_non_existent_key(self):
        cache = LRUCache(3, 1000)
        cache.save("1", 100)
        cache.save("2", 200)

        cache.remove("3")  # Doesn't exist

        self.assertEqual(cache.lookup("1"), 100)
        self.assertEqual(cache.lookup("2"), 200)
        self.assertEqual(len(cache.map), 2)

    def test_remove_existing_key(self):
        cache = LRUCache(3, 1000)

        cache.save("1", 100)
        cache.save("2", 200)
        cache.save("3", 300)

        self.assertEqual(cache.lookup("1"), 100)
        self.assertEqual(cache.lookup("2"), 200)
        self.assertEqual(cache.lookup("3"), 300)
        self.assertEqual(len(cache.map), 3)

        cache.remove("2")

        self.assertEqual(cache.lookup("1"), 100)
        self.assertIsNone(cache.lookup("2"))
        self.assertEqual(cache.lookup("3"), 300)
        self.assertEqual(len(cache.map), 2)

    def test_remove_from_zero_sized_cache(self):
        cache = LRUCache(0, 1000)
        cache.save("1", 100)
        cache.remove("1")

        self.assertIsNone(cache.lookup("1"))
        self.assertEqual(len(cache.map), 0)

    def test_remove_and_add_back(self):
        cache = LRUCache(3, 1000)
        cache.save("1", 100)
        cache.save("2", 200)
        cache.save("3", 300)

        cache.remove("2")
        cache.save("2", 201)

        self.assertEqual(cache.lookup("1"), 100)
        self.assertEqual(cache.lookup("2"), 201)
        self.assertEqual(cache.lookup("3"), 300)
        self.assertEqual(len(cache.map), 3)

    def test_thread_safety(self):
        import threading

        max_size = 100
        cache = LRUCache(max_size, 1000)

        for i in range(1, max_size + 1):
            cache.save(str(i), i * 100)

        def remove_key(k):
            cache.remove(str(k))

        threads = []
        for i in range(1, (max_size // 2) + 1):
            thread = threading.Thread(target=remove_key, args=(i,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        for i in range(1, max_size + 1):
            if i <= max_size // 2:
                self.assertIsNone(cache.lookup(str(i)))
            else:
                self.assertEqual(cache.lookup(str(i)), i * 100)

        self.assertEqual(len(cache.map), max_size // 2)

# type checker test
# confirm that LRUCache matches OptimizelySegmentsCache protocol
_: OptimizelySegmentsCache = LRUCache(0, 0)
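
For readers unfamiliar with the pattern in the last line above, here is a minimal sketch of how such a structural check works with typing.Protocol; the SegmentsCache protocol below is an illustrative stand-in, not the SDK's actual OptimizelySegmentsCache definition, and its method signatures are assumptions:

    from typing import Any, Optional, Protocol

    from optimizely.odp.lru_cache import LRUCache

    class SegmentsCache(Protocol):
        # illustrative protocol with assumed signatures
        def reset(self) -> None: ...
        def lookup(self, key: str) -> Optional[Any]: ...
        def save(self, key: str, value: Any) -> None: ...

    # an annotated module-level assignment: a type checker such as mypy verifies that
    # LRUCache structurally satisfies the protocol, so no separate runtime assertion is needed
    _: SegmentsCache = LRUCache(0, 0)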