Skip to content

Commit b7c4594

Browse files
committed
Implement simple caching option
1 parent 762ae38 commit b7c4594

File tree

5 files changed

+240
-1
lines changed

5 files changed

+240
-1
lines changed

django_sorcery/db/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@
144144
from __future__ import absolute_import, print_function, unicode_literals
145145

146146
from .sqlalchemy import SQLAlchemy # noqa
147+
from .strategy_options import FromCache # noqa
147148
from .utils import dbdict
148149

149150

django_sorcery/db/query.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ class Query(sa.orm.Query):
5353
A customized sqlalchemy query
5454
"""
5555

56+
def __init__(self, entities, session=None):
57+
ents = entities if isinstance(entities, (list, tuple)) else (entities,)
58+
ent_zero = next(iter(ents), None)
59+
ent_zero = getattr(ent_zero, "class_", ent_zero)
60+
self.caching_option = getattr(ent_zero, "__caching_option__", None) or None
61+
super(Query, self).__init__(entities, session=session)
62+
5663
def get(self, *args, **kwargs):
5764
"""
5865
Return an instance based on the given primary key identifier, either as args or
@@ -121,6 +128,22 @@ def _lookup_to_expression(self, lookup, value):
121128

122129
return lhs == value
123130

131+
def __iter__(self):
132+
"""
133+
override __iter__ to pull results from caching option
134+
"""
135+
super_ = super(Query, self)
136+
137+
if self.caching_option:
138+
return self.caching_option.get(super_)
139+
else:
140+
return super_.__iter__()
141+
142+
def invalidate(self):
143+
"""Invalidate the cache value represented by this Query."""
144+
if self.caching_option:
145+
self.caching_option.invalidate(self)
146+
124147

125148
class QueryProperty(object):
126149
def __init__(self, db, model=None, *args, **kwargs):
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import absolute_import, print_function, unicode_literals
3+
4+
from sqlalchemy.orm.interfaces import MapperOption
5+
6+
from django.core.exceptions import ImproperlyConfigured
7+
8+
9+
def _key_from_query(query):
10+
"""
11+
Given a Query, create a cache key.
12+
13+
There are many approaches to this; here we use the simplest, which is to create an md5 hash of the text of the SQL
14+
statement, combined with stringified versions of all the bound parameters within it. There's a bit of a
15+
performance hit with compiling out "query.statement" here; other approaches include setting up an explicit cache
16+
key with a particular Query, then combining that with the bound parameter values.
17+
"""
18+
19+
stmt = query.with_labels().statement
20+
compiled = stmt.compile()
21+
params = compiled.params
22+
return " ".join([str(compiled)] + [str(params[k]) for k in sorted(params)])
23+
24+
25+
class FromCache(MapperOption):
26+
"""Specifies that a Query should load results from a cache."""
27+
28+
propagate_to_loaders = False
29+
30+
def __init__(self, region, expiration_time=None, key_maker=_key_from_query):
31+
"""
32+
Provides caching mechanism for a query
33+
--------------------------------------
34+
35+
region: any
36+
The cache region. Can be a dogpile.cache region object
37+
38+
expiration_time: int or datetime.timedelta
39+
The expiration time that will be passed to region.
40+
41+
keymaker: callable
42+
A callable that will take the query and generate a cache key out of it.
43+
44+
Note that this approach does *not* detach the loaded objects from the current session. If the cache backend is
45+
an in-process cache (like "memory") and lives beyond the scope of the current session's transaction, those
46+
objects may be expired. The method here can be modified to first expunge() each loaded item from the current
47+
session before returning the list of items, so that the items in the cache are not the same ones in the
48+
current Session.
49+
"""
50+
if region is None:
51+
raise ImproperlyConfigured("FromCache requires a cache region")
52+
self.expiration_time = expiration_time
53+
self.key_maker = key_maker
54+
self.region = region
55+
56+
def process_query(self, query):
57+
"""Process a Query during normal loading operation."""
58+
query.caching_option = self
59+
60+
def get(self, query, merge=True, createfunc=None):
61+
"""
62+
Return the value from the cache for this query.
63+
"""
64+
createfunc = query.__iter__
65+
cache_key = self.key_maker(query)
66+
cached_value = self.region.get_or_create(
67+
cache_key, lambda: list(createfunc()), expiration_time=self.expiration_time
68+
)
69+
if merge:
70+
cached_value = query.merge_result(cached_value, load=False)
71+
return cached_value
72+
73+
def invalidate(self, query):
74+
cache_key = self.key_maker(query)
75+
self.region.delete(cache_key)

tests/db/test_strategy_options.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import absolute_import, print_function, unicode_literals
3+
4+
import pytest
5+
6+
from django.core.exceptions import ImproperlyConfigured
7+
8+
from django_sorcery.db import FromCache
9+
from django_sorcery.pytest_plugin import sqlalchemy_profiler # noqa
10+
11+
from ..testapp.models import CachedModel, CachedReference, OtherCachedModel, Owner, UnCachedModel, cache, db
12+
13+
14+
def test_without_region(): # noqa
15+
16+
with pytest.raises(ImproperlyConfigured):
17+
Owner.objects.options(FromCache(None)).filter_by(first_name="foo").one_or_none()
18+
19+
20+
def test_from_cache_option(sqlalchemy_profiler): # noqa
21+
db.add(Owner(first_name="foo", last_name="bar"))
22+
db.flush()
23+
db.expire_all()
24+
cache.cache.clear()
25+
26+
with sqlalchemy_profiler:
27+
Owner.objects.options(FromCache(cache)).filter_by(first_name="foo").one_or_none()
28+
29+
assert sqlalchemy_profiler.stats["select"] == 1
30+
assert len(cache.cache) == 1
31+
32+
with sqlalchemy_profiler:
33+
Owner.objects.options(FromCache(cache)).filter_by(first_name="foo").one_or_none()
34+
35+
assert sqlalchemy_profiler.stats["select"] == 0
36+
assert len(cache.cache) == 1
37+
38+
FromCache(cache).invalidate(Owner.objects.options(FromCache(cache)).filter_by(first_name="foo"))
39+
assert len(cache.cache) == 0
40+
41+
with sqlalchemy_profiler:
42+
Owner.objects.options(FromCache(cache)).filter_by(first_name="foo").one_or_none()
43+
assert sqlalchemy_profiler.stats["select"] == 1
44+
assert len(cache.cache) == 1
45+
46+
Owner.objects.options(FromCache(cache)).filter_by(first_name="foo").invalidate()
47+
assert len(cache.cache) == 0
48+
49+
50+
def test_model_cache_option(sqlalchemy_profiler): # noqa
51+
instance = UnCachedModel(
52+
cached=CachedModel(name="cached"),
53+
other_cached=[OtherCachedModel(name="other cached 1"), OtherCachedModel(name="other cached 2")],
54+
references=[CachedReference(name="ref1"), CachedReference(name="ref2")],
55+
)
56+
57+
db.add(instance)
58+
db.flush()
59+
pk = instance.pk
60+
cached_pk = instance.cached.pk
61+
db.expire_all()
62+
cache.cache.clear()
63+
64+
with sqlalchemy_profiler:
65+
instance = CachedModel.objects.get(cached_pk)
66+
assert sqlalchemy_profiler.stats["select"] == 1
67+
assert len(cache.cache) == 1
68+
69+
with sqlalchemy_profiler:
70+
instance = CachedModel.objects.get(cached_pk)
71+
assert sqlalchemy_profiler.stats["select"] == 0
72+
assert len(cache.cache) == 1
73+
74+
instance = UnCachedModel.objects.filter_by(pk=pk).one()
75+
db.refresh(instance)
76+
with sqlalchemy_profiler:
77+
assert instance.cached.pk == cached_pk
78+
assert sqlalchemy_profiler.stats["select"] == 0
79+
assert len(cache.cache) == 1 # many-to-one cached
80+
81+
instance = UnCachedModel.objects.filter_by(pk=pk).one()
82+
db.refresh(instance)
83+
with sqlalchemy_profiler:
84+
assert len(instance.other_cached) == 2
85+
assert sqlalchemy_profiler.stats["select"] == 1
86+
assert len(cache.cache) == 1 # one-to-many not cached
87+
88+
instance = UnCachedModel.objects.filter_by(pk=pk).one()
89+
db.refresh(instance)
90+
with sqlalchemy_profiler:
91+
assert len(instance.references) == 2
92+
assert sqlalchemy_profiler.stats["select"] == 1
93+
assert len(cache.cache) == 1 # many-to-many not cached

tests/testapp/models.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from django.core.exceptions import ValidationError
88
from django.core.validators import RegexValidator
99

10-
from django_sorcery.db import databases
10+
from django_sorcery.db import FromCache, databases
1111
from django_sorcery.db.models import autocoerce, autocoerce_properties
1212
from django_sorcery.db.query import Query
1313
from django_sorcery.validators import ValidateTogetherModelFields, ValidateUnique
@@ -16,6 +16,23 @@
1616
db = databases.get("test")
1717

1818

19+
class DummyCache(object):
20+
def __init__(self):
21+
self.cache = {}
22+
23+
def get_or_create(self, key, func, **kwargs):
24+
if key not in self.cache:
25+
self.cache[key] = func()
26+
27+
return self.cache[key]
28+
29+
def delete(self, key):
30+
del self.cache[key]
31+
32+
33+
cache = DummyCache()
34+
35+
1936
COLORS = ["", "red", "green", "blue", "silver", "pink"]
2037

2138

@@ -262,6 +279,36 @@ def clean(self, **kwargs):
262279
raise ValidationError("bad model")
263280

264281

282+
class CachedModel(db.Model):
283+
__caching_option__ = FromCache(cache)
284+
285+
pk = db.IntegerField(autoincrement=True, primary_key=True)
286+
name = db.CharField()
287+
288+
289+
class OtherCachedModel(db.Model):
290+
__caching_option__ = FromCache(cache)
291+
292+
pk = db.IntegerField(autoincrement=True, primary_key=True)
293+
name = db.CharField()
294+
295+
296+
class CachedReference(db.Model):
297+
__caching_option__ = FromCache(cache)
298+
299+
pk = db.IntegerField(autoincrement=True, primary_key=True)
300+
name = db.CharField()
301+
302+
303+
class UnCachedModel(db.Model):
304+
pk = db.IntegerField(autoincrement=True, primary_key=True)
305+
name = db.CharField()
306+
307+
cached = db.ManyToOne(CachedModel, backref=db.backref("uncached"))
308+
other_cached = db.OneToMany(OtherCachedModel, backref=db.backref("uncached"))
309+
references = db.ManyToMany(CachedReference, table_name="uncached_refs", backref=db.backref("uncached"))
310+
311+
265312
class ValidateUniqueModel(db.Model):
266313
pk = db.IntegerField(autoincrement=True, primary_key=True)
267314
name = db.CharField()

0 commit comments

Comments
 (0)