Skip to content

Commit 215d3be

Browse files
WaVEVtimgraham
authored andcommitted
INTPYTHON-827 Prevent model update values from being interpreted as expressions
1 parent 88f2631 commit 215d3be

File tree

7 files changed

+104
-5
lines changed

7 files changed

+104
-5
lines changed

django_mongodb_backend/compiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,9 @@ def execute_sql(self, result_type):
883883
f"{field.__class__.__name__}."
884884
)
885885
prepared = field.get_db_prep_save(value, connection=self.connection)
886-
if hasattr(value, "as_mql"):
886+
if is_direct_value(value):
887+
prepared = {"$literal": prepared}
888+
else:
887889
prepared = prepared.as_mql(self, self.connection, as_expr=True)
888890
values[field.column] = prepared
889891
try:

django_mongodb_backend/expressions/builtins.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,9 @@ def when(self, compiler, connection):
211211

212212
def value(self, compiler, connection, as_expr=False): # noqa: ARG001
213213
value = self.value
214-
if isinstance(value, (list, int, str)) and as_expr:
215-
# Wrap lists, numbers, and strings in $literal to avoid ambiguity when
216-
# Value is used in aggregate() or update_many()'s $set.
214+
if isinstance(value, (list, int, str, dict, tuple)) and as_expr:
215+
# Wrap lists, numbers, strings, dicts, and tuples in $literal to avoid
216+
# ambiguity when Value is used in aggregate() or update_many()'s $set.
217217
return {"$literal": value}
218218
if isinstance(value, Decimal):
219219
return Decimal128(value)

django_mongodb_backend/test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,16 @@ def assertInsertQuery(self, query, expected_collection, expected_documents):
2828
self.assertEqual(operator, "insert_many")
2929
self.assertEqual(collection, expected_collection)
3030
self.assertEqual(eval(pipeline[:-1], self.query_types), expected_documents) # noqa: S307
31+
32+
def assertUpdateQuery(self, query, expected_collection, expected_condition, expected_set):
33+
"""
34+
Assert that the logged query is equal to:
35+
db.{expected_collection}.update_many({expected_condition}, {expected_set})
36+
"""
37+
prefix, pipeline = query.split("(", 1)
38+
_, collection, operator = prefix.split(".")
39+
self.assertEqual(operator, "update_many")
40+
self.assertEqual(collection, expected_collection)
41+
condition, set_expression = eval(pipeline[:-1], self.query_types, {}) # noqa: S307
42+
self.assertEqual(condition, expected_condition)
43+
self.assertEqual(set_expression, expected_set)

docs/releases/5.2.x.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ Bug fixes
1919
pipeline.
2020
- Made :class:`~django.db.models.Value` wrap strings in ``$literal`` to
2121
prevent dollar-prefixed strings from being interpreted as expressions.
22+
Also wrapped dictionaries and tuples to prevent the same for them.
23+
- Made model update queries wrap values in ``$literal`` to prevent values from
24+
being interpreted as expressions.
2225

2326
Performance improvements
2427
------------------------

tests/basic_/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,11 @@ class Author(models.Model):
66

77
def __str__(self):
88
return self.name
9+
10+
11+
class Blob(models.Model):
12+
name = models.CharField(max_length=10)
13+
data = models.JSONField(null=True)
14+
15+
def __str__(self):
16+
return self.name

tests/basic_/test_escaping.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from django_mongodb_backend.test import MongoTestCaseMixin
99

10-
from .models import Author
10+
from .models import Author, Blob
1111

1212

1313
class ModelCreationTests(MongoTestCaseMixin, TestCase):
@@ -23,6 +23,69 @@ def test_dollar_prefixed_string(self):
2323
)
2424

2525

26+
class ModelUpdateTests(MongoTestCaseMixin, TestCase):
27+
"""
28+
$-prefixed strings and dict/tuples that could be interpreted as expressions
29+
are escaped in the queries that update model instances.
30+
"""
31+
32+
def test_dollar_prefixed_string(self):
33+
obj = Author.objects.create(name="foobar")
34+
obj.name = "$updated"
35+
with self.assertNumQueries(1) as ctx:
36+
obj.save()
37+
obj.refresh_from_db()
38+
self.assertEqual(obj.name, "$updated")
39+
self.assertUpdateQuery(
40+
ctx.captured_queries[0]["sql"],
41+
"basic__author",
42+
{"_id": obj.id},
43+
[{"$set": {"name": {"$literal": "$updated"}}}],
44+
)
45+
46+
def test_dollar_prefixed_value(self):
47+
obj = Author.objects.create(name="foobar")
48+
obj.name = Value("$updated")
49+
with self.assertNumQueries(1) as ctx:
50+
obj.save()
51+
obj.refresh_from_db()
52+
self.assertEqual(obj.name, "$updated")
53+
self.assertUpdateQuery(
54+
ctx.captured_queries[0]["sql"],
55+
"basic__author",
56+
{"_id": obj.id},
57+
[{"$set": {"name": {"$literal": "$updated"}}}],
58+
)
59+
60+
def test_dict(self):
61+
obj = Blob.objects.create(name="foobar")
62+
obj.data = {"$concat": ["$name", "-", "$name"]}
63+
obj.save()
64+
obj.refresh_from_db()
65+
self.assertEqual(obj.data, {"$concat": ["$name", "-", "$name"]})
66+
67+
def test_dict_value(self):
68+
obj = Blob.objects.create(name="foobar", data={})
69+
obj.data = Value({"$concat": ["$name", "-", "$name"]})
70+
obj.save()
71+
obj.refresh_from_db()
72+
self.assertEqual(obj.data, {"$concat": ["$name", "-", "$name"]})
73+
74+
def test_tuple(self):
75+
obj = Blob.objects.create(name="foobar")
76+
obj.data = ("$name", "-", "$name")
77+
obj.save()
78+
obj.refresh_from_db()
79+
self.assertEqual(obj.data, ["$name", "-", "$name"])
80+
81+
def test_tuple_value(self):
82+
obj = Blob.objects.create(name="foobar")
83+
obj.data = Value(("$name", "-", "$name"))
84+
obj.save()
85+
obj.refresh_from_db()
86+
self.assertEqual(obj.data, ["$name", "-", "$name"])
87+
88+
2689
class AnnotationTests(MongoTestCaseMixin, TestCase):
2790
def test_dollar_prefixed_value(self):
2891
"""Value() escapes dollar prefixed strings."""

tests/expressions_/test_value.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ def test_datetime(self):
2323
def test_decimal(self):
2424
self.assertEqual(Value(Decimal("1.0")).as_mql(None, None), Decimal128("1.0"))
2525

26+
def test_dict_expr(self):
27+
self.assertEqual(
28+
Value({"$foo": "$bar"}).as_mql(None, None, as_expr=True), {"$literal": {"$foo": "$bar"}}
29+
)
30+
2631
def test_list(self):
2732
self.assertEqual(Value([1, 2]).as_mql(None, None, as_expr=True), {"$literal": [1, 2]})
2833

@@ -44,6 +49,11 @@ def test_str(self):
4449
def test_str_expr(self):
4550
self.assertEqual(Value("$foo").as_mql(None, None, as_expr=True), {"$literal": "$foo"})
4651

52+
def test_tuple_expr(self):
53+
self.assertEqual(
54+
Value(("$foo", "$bar")).as_mql(None, None, as_expr=True), {"$literal": ("$foo", "$bar")}
55+
)
56+
4757
def test_uuid(self):
4858
value = uuid.UUID(int=1)
4959
self.assertEqual(Value(value).as_mql(None, None), "00000000000000000000000000000001")

0 commit comments

Comments
 (0)