Skip to content

Commit 4db88f6

Browse files
committed
Push conditions into lookup when is possible.
1 parent 06894c5 commit 4db88f6

File tree

4 files changed

+88
-20
lines changed

4 files changed

+88
-20
lines changed

django_mongodb_backend/compiler.py

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@
99
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
1010
from django.db.models.functions.comparison import Coalesce
1111
from django.db.models.functions.math import Power
12-
from django.db.models.lookups import IsNull, Lookup
12+
from django.db.models.lookups import IsNull
1313
from django.db.models.sql import compiler
1414
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
1515
from django.db.models.sql.datastructures import BaseTable
16-
from django.db.models.sql.where import AND, WhereNode
16+
from django.db.models.sql.where import AND, OR, XOR, NothingNode, WhereNode
1717
from django.utils.functional import cached_property
1818
from pymongo import ASCENDING, DESCENDING
1919

2020
from .expressions.search import SearchExpression, SearchVector
2121
from .query import MongoQuery, wrap_database_errors
22-
from .query_utils import is_direct_value
22+
from .query_utils import is_constant_value
2323

2424

2525
class SQLCompiler(compiler.SQLCompiler):
@@ -658,27 +658,61 @@ def get_combinator_queries(self):
658658
combinator_pipeline.append({"$unset": "_id"})
659659
return combinator_pipeline
660660

661+
def _get_pushable_conditions(self):
662+
def collect_pushable(expr, negated=False):
663+
if expr is None or isinstance(expr, NothingNode):
664+
return {}
665+
if isinstance(expr, WhereNode):
666+
negated ^= expr.negated
667+
pushable_expressions = [
668+
collect_pushable(sub_expr, negated=negated)
669+
for sub_expr in expr.children
670+
if sub_expr is not None
671+
]
672+
operator = expr.connector
673+
if operator == XOR:
674+
return {}
675+
if negated:
676+
operator = OR if operator == AND else AND
677+
alias_children = defaultdict(list)
678+
for pe in pushable_expressions:
679+
for alias, expressions in pe.items():
680+
alias_children[alias].append(expressions)
681+
result = {}
682+
for alias, children in alias_children.items():
683+
result[alias] = WhereNode(
684+
children=children,
685+
negated=False,
686+
connector=operator,
687+
)
688+
if operator == AND:
689+
return result
690+
shared_alias = (
691+
set.intersection(*(set(pe) for pe in pushable_expressions))
692+
if pushable_expressions
693+
else set()
694+
)
695+
return {k: v for k, v in result.items() if k in shared_alias}
696+
if isinstance(expr.lhs, Col) and (
697+
is_constant_value(expr.rhs) or getattr(expr.rhs, "is_simple_column", False)
698+
):
699+
alias = expr.lhs.alias
700+
expr = WhereNode(children=[expr], negated=negated)
701+
return {alias: expr}
702+
return {}
703+
704+
return collect_pushable(self.get_where())
705+
661706
def get_lookup_pipeline(self):
662707
result = []
663708
# To improve join performance, push conditions (filters) from the
664709
# WHERE ($match) clause to the JOIN ($lookup) clause.
665-
where = self.get_where()
666-
pushed_filters = defaultdict(list)
667-
for expr in where.children if where and where.connector == AND else ():
668-
# Push only basic lookups; no subqueries or complex conditions.
669-
# To avoid duplication across subqueries, only use the LHS target
670-
# table.
671-
if (
672-
isinstance(expr, Lookup)
673-
and isinstance(expr.lhs, Col)
674-
and (is_direct_value(expr.rhs) or isinstance(expr.rhs, (Value, Col)))
675-
):
676-
pushed_filters[expr.lhs.alias].append(expr)
710+
pushed_filters = self._get_pushable_conditions()
677711
for alias in tuple(self.query.alias_map):
678712
if not self.query.alias_refcount[alias] or self.collection_name == alias:
679713
continue
680714
result += self.query.alias_map[alias].as_mql(
681-
self, self.connection, WhereNode(pushed_filters[alias], connector=AND)
715+
self, self.connection, pushed_filters.get(alias)
682716
)
683717
return result
684718

django_mongodb_backend/query_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from django.core.exceptions import FullResultSet
2+
from django.db.models import F
23
from django.db.models.aggregates import Aggregate
34
from django.db.models.expressions import CombinedExpression, Func, Value
45
from django.db.models.sql.query import Query
@@ -67,7 +68,7 @@ def is_constant_value(value):
6768
else:
6869
constants_sub_expressions = True
6970
constants_sub_expressions = constants_sub_expressions and not (
70-
isinstance(value, Query)
71+
isinstance(value, Query | F)
7172
or value.contains_aggregate
7273
or value.contains_over_clause
7374
or value.contains_column_references

django_mongodb_backend/test.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,28 @@
66
class MongoTestCaseMixin:
77
maxDiff = None
88

9+
COMMUTATIVE_OPERATORS = {"$and", "$or", "$all"}
10+
11+
@staticmethod
12+
def _normalize_query(obj):
13+
if isinstance(obj, dict):
14+
normalized = {}
15+
for k, v in obj.items():
16+
if k in MongoTestCaseMixin.COMMUTATIVE_OPERATORS and isinstance(v, list):
17+
# Only sort for commutative operators
18+
normalized[k] = sorted(
19+
(MongoTestCaseMixin._normalize_query(i) for i in v), key=lambda x: str(x)
20+
)
21+
else:
22+
normalized[k] = MongoTestCaseMixin._normalize_query(v)
23+
return normalized
24+
25+
if isinstance(obj, list):
26+
# Lists not under commutative ops keep their order
27+
return [MongoTestCaseMixin._normalize_query(i) for i in obj]
28+
29+
return obj
30+
931
def assertAggregateQuery(self, query, expected_collection, expected_pipeline):
1032
"""
1133
Assert that the logged query is equal to:
@@ -16,6 +38,10 @@ def assertAggregateQuery(self, query, expected_collection, expected_pipeline):
1638
self.assertEqual(operator, "aggregate")
1739
self.assertEqual(collection, expected_collection)
1840
self.assertEqual(
19-
eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}), # noqa: S307
20-
expected_pipeline,
41+
self._normalize_query(
42+
eval( # noqa: S307
43+
pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}
44+
)
45+
),
46+
self._normalize_query(expected_pipeline),
2147
)

tests/queries_/test_mql.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,14 @@ def test_negated_related_filter_is_not_pushable(self):
281281
"pipeline": [
282282
{
283283
"$match": {
284-
"$expr": {"$and": [{"$eq": ["$$parent__field__0", "$_id"]}]}
284+
"$and": [
285+
{
286+
"$expr": {
287+
"$and": [{"$eq": ["$$parent__field__0", "$_id"]}]
288+
}
289+
},
290+
{"$nor": [{"name": "John"}]},
291+
]
285292
}
286293
}
287294
],

0 commit comments

Comments
 (0)