Skip to content

Commit 7d6386e

Browse files
committed
Push conditions into lookup when possible.
1 parent cf0ef15 commit 7d6386e

File tree

4 files changed

+88
-20
lines changed

4 files changed

+88
-20
lines changed

django_mongodb_backend/compiler.py

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@
99
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
1010
from django.db.models.functions.comparison import Coalesce
1111
from django.db.models.functions.math import Power
12-
from django.db.models.lookups import IsNull, Lookup
12+
from django.db.models.lookups import IsNull
1313
from django.db.models.sql import compiler
1414
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
1515
from django.db.models.sql.datastructures import BaseTable
16-
from django.db.models.sql.where import AND, WhereNode
16+
from django.db.models.sql.where import AND, OR, XOR, NothingNode, WhereNode
1717
from django.utils.functional import cached_property
1818
from pymongo import ASCENDING, DESCENDING
1919

2020
from .expressions.search import SearchExpression, SearchVector
2121
from .query import MongoQuery, wrap_database_errors
22-
from .query_utils import is_direct_value
22+
from .query_utils import is_constant_value
2323

2424

2525
class SQLCompiler(compiler.SQLCompiler):
@@ -661,27 +661,61 @@ def get_combinator_queries(self):
661661
combinator_pipeline.append({"$unset": "_id"})
662662
return combinator_pipeline
663663

664+
def _get_pushable_conditions(self):
665+
def collect_pushable(expr, negated=False):
666+
if expr is None or isinstance(expr, NothingNode):
667+
return {}
668+
if isinstance(expr, WhereNode):
669+
negated ^= expr.negated
670+
pushable_expressions = [
671+
collect_pushable(sub_expr, negated=negated)
672+
for sub_expr in expr.children
673+
if sub_expr is not None
674+
]
675+
operator = expr.connector
676+
if operator == XOR:
677+
return {}
678+
if negated:
679+
operator = OR if operator == AND else AND
680+
alias_children = defaultdict(list)
681+
for pe in pushable_expressions:
682+
for alias, expressions in pe.items():
683+
alias_children[alias].append(expressions)
684+
result = {}
685+
for alias, children in alias_children.items():
686+
result[alias] = WhereNode(
687+
children=children,
688+
negated=False,
689+
connector=operator,
690+
)
691+
if operator == AND:
692+
return result
693+
shared_alias = (
694+
set.intersection(*(set(pe) for pe in pushable_expressions))
695+
if pushable_expressions
696+
else set()
697+
)
698+
return {k: v for k, v in result.items() if k in shared_alias}
699+
if isinstance(expr.lhs, Col) and (
700+
is_constant_value(expr.rhs) or getattr(expr.rhs, "is_simple_column", False)
701+
):
702+
alias = expr.lhs.alias
703+
expr = WhereNode(children=[expr], negated=negated)
704+
return {alias: expr}
705+
return {}
706+
707+
return collect_pushable(self.get_where())
708+
664709
def get_lookup_pipeline(self):
665710
result = []
666711
# To improve join performance, push conditions (filters) from the
667712
# WHERE ($match) clause to the JOIN ($lookup) clause.
668-
where = self.get_where()
669-
pushed_filters = defaultdict(list)
670-
for expr in where.children if where and where.connector == AND else ():
671-
# Push only basic lookups; no subqueries or complex conditions.
672-
# To avoid duplication across subqueries, only use the LHS target
673-
# table.
674-
if (
675-
isinstance(expr, Lookup)
676-
and isinstance(expr.lhs, Col)
677-
and (is_direct_value(expr.rhs) or isinstance(expr.rhs, (Value, Col)))
678-
):
679-
pushed_filters[expr.lhs.alias].append(expr)
713+
pushed_filters = self._get_pushable_conditions()
680714
for alias in tuple(self.query.alias_map):
681715
if not self.query.alias_refcount[alias] or self.collection_name == alias:
682716
continue
683717
result += self.query.alias_map[alias].as_mql(
684-
self, self.connection, WhereNode(pushed_filters[alias], connector=AND)
718+
self, self.connection, pushed_filters.get(alias)
685719
)
686720
return result
687721

django_mongodb_backend/query_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from django.core.exceptions import FullResultSet
2+
from django.db.models import F
23
from django.db.models.aggregates import Aggregate
34
from django.db.models.expressions import CombinedExpression, Func, Value
45
from django.db.models.sql.query import Query
@@ -67,7 +68,7 @@ def is_constant_value(value):
6768
else:
6869
constants_sub_expressions = True
6970
constants_sub_expressions = constants_sub_expressions and not (
70-
isinstance(value, Query)
71+
isinstance(value, Query | F)
7172
or value.contains_aggregate
7273
or value.contains_over_clause
7374
or value.contains_column_references

django_mongodb_backend/test.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,28 @@
66
class MongoTestCaseMixin:
77
maxDiff = None
88

9+
COMMUTATIVE_OPERATORS = {"$and", "$or", "$all"}
10+
11+
@staticmethod
12+
def _normalize_query(obj):
13+
if isinstance(obj, dict):
14+
normalized = {}
15+
for k, v in obj.items():
16+
if k in MongoTestCaseMixin.COMMUTATIVE_OPERATORS and isinstance(v, list):
17+
# Only sort for commutative operators
18+
normalized[k] = sorted(
19+
(MongoTestCaseMixin._normalize_query(i) for i in v), key=lambda x: str(x)
20+
)
21+
else:
22+
normalized[k] = MongoTestCaseMixin._normalize_query(v)
23+
return normalized
24+
25+
if isinstance(obj, list):
26+
# Lists not under commutative ops keep their order
27+
return [MongoTestCaseMixin._normalize_query(i) for i in obj]
28+
29+
return obj
30+
931
def assertAggregateQuery(self, query, expected_collection, expected_pipeline):
1032
"""
1133
Assert that the logged query is equal to:
@@ -16,6 +38,10 @@ def assertAggregateQuery(self, query, expected_collection, expected_pipeline):
1638
self.assertEqual(operator, "aggregate")
1739
self.assertEqual(collection, expected_collection)
1840
self.assertEqual(
19-
eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}), # noqa: S307
20-
expected_pipeline,
41+
self._normalize_query(
42+
eval( # noqa: S307
43+
pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}
44+
)
45+
),
46+
self._normalize_query(expected_pipeline),
2147
)

tests/queries_/test_mql.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,14 @@ def test_negated_related_filter_is_not_pushable(self):
281281
"pipeline": [
282282
{
283283
"$match": {
284-
"$expr": {"$and": [{"$eq": ["$$parent__field__0", "$_id"]}]}
284+
"$and": [
285+
{
286+
"$expr": {
287+
"$and": [{"$eq": ["$$parent__field__0", "$_id"]}]
288+
}
289+
},
290+
{"$nor": [{"name": "John"}]},
291+
]
285292
}
286293
}
287294
],

0 commit comments

Comments
 (0)