Skip to content

Commit 3140935

Browse files
committed
Allow Q objects to be used for upsert update condition and index predicate
1 parent 487db71 commit 3140935

File tree

8 files changed

+149
-60
lines changed

8 files changed

+149
-60
lines changed

docs/source/api_reference.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,11 @@ API Reference
2929

3030
.. autoclass:: DateTimeEpoch
3131

32+
.. autoclass:: ExcludedCol
33+
3234
.. automodule:: psqlextra.indexes
3335

36+
.. autoclass:: UniqueIndex
3437
.. autoclass:: ConditionalUniqueIndex
3538
.. autoclass:: CaseInsensitiveUniqueIndex
3639

docs/source/conflict_handling.rst

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -155,35 +155,21 @@ A row level lock is acquired before evaluating the condition and proceeding with
155155
An expression that returns a value of type boolean. Only rows for which this expression returns true will be updated, although all rows will be locked when the ON CONFLICT DO UPDATE action is taken. Note that condition is evaluated last, after a conflict has been identified as a candidate to update.
156156

157157

158-
.. warning::
159-
160-
Always parameterize the input to avoid SQL injections.
161-
162-
Do:
163-
164-
.. code-block:: python
165-
166-
my_name = 'henk'
167-
RawSQL("name != %s", (my_name,))
168-
169-
Not:
170-
171-
.. code-block:: python
172-
173-
RawSQL("name != " + henk, tuple())
174-
175-
176158
.. code-block:: python
177159
178-
from django.db.models.expressions import RawSQL
160+
from psqlextra.expressions import CombinedExpression, ExcludedCol
179161
180162
pk = (
181163
MyModel
182164
.objects
183165
.on_conflict(
184166
['name'],
185167
ConflictAction.UPDATE,
186-
update_condition=RawSQL("priority >= EXCLUDED.priority"),
168+
update_condition=CombinedExpression(
169+
MyModel._meta.get_field('priority').get_col(MyModel._meta.db_table),
170+
'>',
171+
ExcludedCol('priority'),
172+
)
187173
)
188174
.insert(
189175
name='henk',
@@ -197,30 +183,18 @@ A row level lock is acquired before evaluating the condition and proceeding with
197183
print('condition was false-ish and no changes were made')
198184
199185
200-
When writing expressions, refer to the data you're trying to upsert with ``EXCLUDED``. Refer to the existing row by prefixing the name of the table:
201-
202-
.. code-block:: python
203-
204-
RawSQL(MyModel._meta.db_table + '.mycolumn = EXCLUDED.mycolumn')
205-
206-
You can use :meth:`~django:django.db.models.expressions.CombinedExpression` to build simple comparion expressions:
186+
When writing expressions, refer to the data you're trying to upsert with the :class:`psqlextra.expressions.ExcludedCol` expression.
207187

188+
Alternatively, with Django 3.1, :class:`~django:django.db.models.Q` objects can be used instead:
208189

209190
.. code-block:: python
210191
211-
from django.db.models import CombinedExpression, Col, Value
212-
213-
CombinedExpression(
214-
MyModel._meta.get_field('name').get_col(MyModel._meta.db_table)
215-
'=',
216-
Col('EXCLUDED', 'name'),
217-
)
192+
from django.db.models import Q
193+
from psqlextra.expressions import ExcludedCol
218194
219-
CombinedExpression(
220-
MyModel._meta.get_field('active').get_col(MyModel._meta.db_table)
221-
'=',
222-
Value(True),
223-
)
195+
Q(name=ExcludedCol('name'))
196+
Q(name__isnull=True)
197+
Q(name__gt=ExcludedCol('priority'))
224198
225199
226200
ConflictAction.NOTHING

docs/source/expressions.rst

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,31 @@ Use the :class:`~psqlextra.expressions.IsNotNone` expression to perform somethin
8989
.values_list('name', flat=True)
9090
.first()
9191
)
92+
93+
94+
Excluded column
95+
---------------
96+
97+
Use the :class:`~psqlextra.expressions.ExcludedCol` expression when performing an upsert using `ON CONFLICT`_ to refer to a column/field in the data is about to be upserted.
98+
99+
PostgreSQL keeps that data to be upserted in a special table named `EXCLUDED`. This expression is used to refer to a column in that table.
100+
101+
.. code-block:: python
102+
103+
from django.db.models import Q
104+
from psqlextra.expressions import ExcludedCol
105+
106+
(
107+
MyModel
108+
.objects
109+
.on_conflict(
110+
['name'],
111+
ConflictAction.UPDATE,
112+
# translates to `priority > EXCLUDED.priority`
113+
update_condition=Q(priority__gt=ExcludedCol('priority')),
114+
)
115+
.insert(
116+
name='henk',
117+
priority=1,
118+
)
119+
)

psqlextra/compiler.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from collections.abc import Iterable
2+
from typing import Tuple, Union
3+
4+
import django
25

36
from django.core.exceptions import SuspiciousOperation
4-
from django.db.models import Expression, Model
7+
from django.db.models import Expression, Model, Q
58
from django.db.models.fields.related import RelatedField
69
from django.db.models.sql.compiler import SQLInsertCompiler, SQLUpdateCompiler
710
from django.db.utils import ProgrammingError
@@ -155,20 +158,19 @@ def _rewrite_insert_on_conflict(
155158
rewritten_sql = f"{sql} ON CONFLICT {conflict_target}"
156159

157160
if index_predicate:
158-
if isinstance(index_predicate, Expression):
159-
expr_sql, expr_params = self.compile(index_predicate)
160-
rewritten_sql += f" WHERE {expr_sql}"
161-
params += tuple(expr_params)
162-
else:
163-
rewritten_sql += f" WHERE {index_predicate}"
161+
expr_sql, expr_params = self._compile_expression(index_predicate)
162+
rewritten_sql += f" WHERE {expr_sql}"
163+
params += tuple(expr_params)
164164

165165
rewritten_sql += f" DO {conflict_action}"
166166

167167
if conflict_action == "UPDATE":
168168
rewritten_sql += f" SET {update_columns}"
169169

170170
if update_condition:
171-
expr_sql, expr_params = self.compile(update_condition)
171+
expr_sql, expr_params = self._compile_expression(
172+
update_condition
173+
)
172174
rewritten_sql += f" WHERE {expr_sql}"
173175
params += tuple(expr_params)
174176

@@ -319,6 +321,27 @@ def _format_field_value(self, field_name) -> str:
319321
value,
320322
)
321323

324+
def _compile_expression(
325+
self, expression: Union[Expression, Q, str]
326+
) -> Tuple[str, tuple]:
327+
"""Compiles an expression, Q object or raw SQL string into SQL and
328+
tuple of parameters."""
329+
330+
if isinstance(expression, Q):
331+
if django.VERSION < (3, 1):
332+
raise SuspiciousOperation(
333+
"Q objects in psqlextra can only be used with Django 3.1 and newer"
334+
)
335+
336+
return self.query.build_where(expression).as_sql(
337+
self, self.connection
338+
)
339+
340+
elif isinstance(expression, Expression):
341+
return self.compile(expression)
342+
343+
return expression, tuple()
344+
322345
def _assert_valid_field(self, field_name: str):
323346
"""Asserts that a field with the specified name exists on the model and
324347
raises :see:SuspiciousOperation if it does not."""

psqlextra/expressions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,18 @@ def IsNotNone(*fields, default=None):
204204
default=expressions.Value(default),
205205
output_field=CharField(),
206206
)
207+
208+
209+
class ExcludedCol(expressions.Expression):
210+
"""References a column in PostgreSQL's special EXCLUDED column, which is
211+
used in upserts to refer to the data about to be inserted/updated.
212+
213+
See: https://www.postgresql.org/docs/9.5/sql-insert.html#SQL-ON-CONFLICT
214+
"""
215+
216+
def __init__(self, name: str):
217+
self.name = name
218+
219+
def as_sql(self, compiler, connection):
220+
quoted_name = connection.ops.quote_name(self.name)
221+
return f"EXCLUDED.{quoted_name}", tuple()

psqlextra/query.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from django.core.exceptions import SuspiciousOperation
66
from django.db import connections, models, router
7-
from django.db.models import Expression
7+
from django.db.models import Expression, Q
88
from django.db.models.fields import NOT_PROVIDED
99

1010
from .sql import PostgresInsertQuery, PostgresQuery
@@ -82,8 +82,8 @@ def on_conflict(
8282
self,
8383
fields: ConflictTarget,
8484
action: ConflictAction,
85-
index_predicate: Optional[Union[Expression, str]] = None,
86-
update_condition: Optional[Expression] = None,
85+
index_predicate: Optional[Union[Expression, Q, str]] = None,
86+
update_condition: Optional[Union[Expression, Q, str]] = None,
8787
):
8888
"""Sets the action to take when conflicts arise when attempting to
8989
insert/create a new row.
@@ -257,9 +257,9 @@ def upsert(
257257
self,
258258
conflict_target: ConflictTarget,
259259
fields: dict,
260-
index_predicate: Optional[Union[Expression, str]] = None,
260+
index_predicate: Optional[Union[Expression, Q, str]] = None,
261261
using: Optional[str] = None,
262-
update_condition: Optional[Expression] = None,
262+
update_condition: Optional[Union[Expression, Q, str]] = None,
263263
) -> int:
264264
"""Creates a new record or updates the existing one with the specified
265265
data.
@@ -298,9 +298,9 @@ def upsert_and_get(
298298
self,
299299
conflict_target: ConflictTarget,
300300
fields: dict,
301-
index_predicate: Optional[Union[Expression, str]] = None,
301+
index_predicate: Optional[Union[Expression, Q, str]] = None,
302302
using: Optional[str] = None,
303-
update_condition: Optional[Expression] = None,
303+
update_condition: Optional[Union[Expression, Q, str]] = None,
304304
):
305305
"""Creates a new record or updates the existing one with the specified
306306
data and then gets the row.
@@ -340,10 +340,10 @@ def bulk_upsert(
340340
self,
341341
conflict_target: ConflictTarget,
342342
rows: Iterable[Dict],
343-
index_predicate: Optional[Union[Expression, str]] = None,
343+
index_predicate: Optional[Union[Expression, Q, str]] = None,
344344
return_model: bool = False,
345345
using: Optional[str] = None,
346-
update_condition: Optional[Expression] = None,
346+
update_condition: Optional[Union[Expression, Q, str]] = None,
347347
):
348348
"""Creates a set of new records or updates the existing ones with the
349349
specified data.

tests/test_query_values.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,16 @@ def test_query_values_list_hstore_key(model, modelobj):
5151
assert result[1] == modelobj.title["ar"]
5252

5353

54+
@pytest.mark.skipif(
55+
django.VERSION < (2, 1), reason="requires django 2.1 or newer"
56+
)
5457
def test_query_values_hstore_key_through_fk():
5558
"""Tests whether selecting a single key from a :see:HStoreField using the
5659
query set's .values() works properly when there's a foreign key
5760
relationship involved."""
5861

5962
# this starting working in django 2.1
6063
# see: https://github.com/django/django/commit/20bab2cf9d02a5c6477d8aac066a635986e0d3f3
61-
if django.VERSION < (2, 1):
62-
return
6364

6465
fmodel = get_fake_model({"name": HStoreField()})
6566

tests/test_upsert.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
import django
2+
import pytest
3+
14
from django.db import models
5+
from django.db.models import Q
26
from django.db.models.expressions import CombinedExpression, Value
37

8+
from psqlextra.expressions import ExcludedCol
49
from psqlextra.fields import HStoreField
510

611
from .fake_model import get_fake_model
@@ -78,7 +83,7 @@ def test_upsert_explicit_pk():
7883

7984

8085
def test_upsert_with_update_condition():
81-
"""Tests that a custom expression can be passed as an update condition."""
86+
"""Tests that an expression can be used as an upsert update condition."""
8287

8388
model = get_fake_model(
8489
{
@@ -96,7 +101,7 @@ def test_upsert_with_update_condition():
96101
update_condition=CombinedExpression(
97102
model._meta.get_field("active").get_col(model._meta.db_table),
98103
"=",
99-
Value(True),
104+
ExcludedCol("active"),
100105
),
101106
fields=dict(name="joe", priority=2, active=True),
102107
)
@@ -122,6 +127,46 @@ def test_upsert_with_update_condition():
122127
assert obj1.active
123128

124129

130+
@pytest.mark.skipif(
131+
django.VERSION < (3, 1), reason="requires django 3.1 or newer"
132+
)
133+
def test_upsert_with_update_condition_with_q_object():
134+
"""Tests that :see:Q objects can be used as an upsert update condition."""
135+
136+
model = get_fake_model(
137+
{
138+
"name": models.TextField(unique=True),
139+
"priority": models.IntegerField(),
140+
"active": models.BooleanField(),
141+
}
142+
)
143+
144+
obj1 = model.objects.create(name="joe", priority=1, active=False)
145+
146+
# should not return anything because no rows were affected
147+
assert not model.objects.upsert(
148+
conflict_target=["name"],
149+
update_condition=Q(active=ExcludedCol("active")),
150+
fields=dict(name="joe", priority=2, active=True),
151+
)
152+
153+
obj1.refresh_from_db()
154+
assert obj1.priority == 1
155+
assert not obj1.active
156+
157+
# should return something because one row was affected
158+
obj1_pk = model.objects.upsert(
159+
conflict_target=["name"],
160+
update_condition=Q(active=Value(False)),
161+
fields=dict(name="joe", priority=2, active=True),
162+
)
163+
164+
obj1.refresh_from_db()
165+
assert obj1.pk == obj1_pk
166+
assert obj1.priority == 2
167+
assert obj1.active
168+
169+
125170
def test_upsert_and_get_applies_converters():
126171
"""Tests that converters are properly applied when using upsert_and_get."""
127172

0 commit comments

Comments
 (0)