diff --git a/django_mongodb_backend/cache.py b/django_mongodb_backend/cache.py index 00b903afe..2037dd856 100644 --- a/django_mongodb_backend/cache.py +++ b/django_mongodb_backend/cache.py @@ -1,8 +1,10 @@ import pickle from datetime import datetime, timezone +from base64 import b64encode, b64decode from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache from django.core.cache.backends.db import Options +from django.core.signing import Signer, BadSignature from django.db import connections, router from django.utils.functional import cached_property from pymongo import ASCENDING, DESCENDING, IndexModel, ReturnDocument @@ -10,8 +12,9 @@ class MongoSerializer: - def __init__(self, protocol=None): + def __init__(self, protocol=None, signer=True, salt=None): self.protocol = pickle.HIGHEST_PROTOCOL if protocol is None else protocol + self.signer = Signer(salt=salt) if signer else None def dumps(self, obj): # For better incr() and decr() atomicity, don't pickle integers. @@ -19,12 +22,15 @@ def dumps(self, obj): # subclasses like bool. 
if type(obj) is int: # noqa: E721 return obj - return pickle.dumps(obj, self.protocol) + pickled_data = pickle.dumps(obj, protocol=self.protocol) # noqa: S301 + return self.signer.sign(b64encode(pickled_data).decode()) if self.signer else pickled_data def loads(self, data): try: return int(data) except (ValueError, TypeError): + if self.signer is not None: + data = b64decode(self.signer.unsign(data)) return pickle.loads(data) # noqa: S301 @@ -39,6 +45,8 @@ class CacheEntry: _meta = Options(collection_name) self.cache_model_class = CacheEntry + self._sign_cache = params.get("ENABLE_SIGNING", True) + self._salt = params.get("SALT", None) def create_indexes(self): expires_index = IndexModel("expires_at", expireAfterSeconds=0) @@ -47,7 +55,7 @@ def create_indexes(self): @cached_property def serializer(self): - return MongoSerializer(self.pickle_protocol) + return MongoSerializer(self.pickle_protocol, self._sign_cache, self._salt) @property def collection_for_read(self): @@ -84,7 +92,13 @@ def get_many(self, keys, version=None): with self.collection_for_read.find( {"key": {"$in": tuple(keys_map)}, **self._filter_expired(expired=False)} ) as cursor: - return {keys_map[row["key"]]: self.serializer.loads(row["value"]) for row in cursor} + results = {} + for row in cursor: + try: + results[keys_map[row["key"]]] = self.serializer.loads(row["value"]) + except (BadSignature, TypeError): + self.delete(row["key"]) + return results def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): key = self.make_and_validate_key(key, version=version) diff --git a/docs/source/topics/cache.rst b/docs/source/topics/cache.rst index 881e1b78b..f58b2852d 100644 --- a/docs/source/topics/cache.rst +++ b/docs/source/topics/cache.rst @@ -32,6 +32,11 @@ In addition, the cache is culled based on ``CULL_FREQUENCY`` when ``add()`` or ``set()`` is called, if ``MAX_ENTRIES`` is exceeded. See :ref:`django:cache_arguments` for an explanation of these two options. 
+Cache entries include an HMAC signature to ensure data integrity by default. +You can disable this by setting ``ENABLE_SIGNING`` to ``False``. +Signatures can also include an optional salt parameter by setting ``SALT`` +to a string value. + Creating the cache collection ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/cache_/tests.py b/tests/cache_/tests.py index c28b549e5..60f449bc3 100644 --- a/tests/cache_/tests.py +++ b/tests/cache_/tests.py @@ -97,6 +97,7 @@ def caches_setting_for_tests(base=None, exclude=None, **params): BACKEND="django_mongodb_backend.cache.MongoDBCache", # Spaces are used in the name to ensure quoting/escaping works. LOCATION="test cache collection", + ENABLE_SIGNING=False, ), ) @modify_settings( @@ -955,6 +956,22 @@ def test_serializer_dumps(self): self.assertIsInstance(cache.serializer.dumps("abc"), bytes) +@override_settings( + CACHES=caches_setting_for_tests( + BACKEND="django_mongodb_backend.cache.MongoDBCache", + # Spaces are used in the name to ensure quoting/escaping works. + LOCATION="test cache collection", + ENABLE_SIGNING=True, + SALT="test-salt", + ), +) +class SignedCacheTests(CacheTests): + def test_serializer_dumps(self): + # The serializer should return a string for signed caches. + self.assertEqual(cache.serializer.dumps(123), 123) + self.assertIsInstance(cache.serializer.dumps(True), str) + self.assertIsInstance(cache.serializer.dumps("abc"), str) + class DBCacheRouter: """A router that puts the cache table on the 'other' database."""