Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 50 additions & 16 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
from django.db.models.functions.comparison import Coalesce
from django.db.models.functions.math import Power
from django.db.models.lookups import IsNull, Lookup
from django.db.models.lookups import IsNull
from django.db.models.sql import compiler
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
from django.db.models.sql.datastructures import BaseTable
from django.db.models.sql.where import AND, WhereNode
from django.db.models.sql.where import AND, OR, XOR, NothingNode, WhereNode
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING

from .expressions.search import SearchExpression, SearchVector
from .query import MongoQuery, wrap_database_errors
from .query_utils import is_direct_value
from .query_utils import is_constant_value


class SQLCompiler(compiler.SQLCompiler):
Expand Down Expand Up @@ -661,27 +661,61 @@ def get_combinator_queries(self):
combinator_pipeline.append({"$unset": "_id"})
return combinator_pipeline

def _get_pushable_conditions(self):
    """
    Return a dict mapping a table alias to a WhereNode that holds the
    conditions of the WHERE clause that can be evaluated against that
    alias alone.

    A leaf condition is collected when its left-hand side is a ``Col`` and
    its right-hand side is either a constant value or a simple column; it
    is filed under the alias of the left-hand side column. Aliases for
    which nothing can be collected are absent from the result.
    """

    def collect_pushable(expr, negated=False):
        """
        Recursively collect the per-alias conditions of ``expr``.

        ``negated`` is the accumulated negation of the enclosing nodes; it
        is folded into the leaves so that every WhereNode built here is
        itself non-negated.
        """
        if expr is None or isinstance(expr, NothingNode):
            return {}
        if isinstance(expr, WhereNode):
            operator = expr.connector
            # Nothing is collected from an XOR node, so bail out before
            # descending into its children.
            if operator == XOR:
                return {}
            negated ^= expr.negated
            per_child = [
                collect_pushable(child, negated=negated)
                for child in expr.children
                if child is not None
            ]
            if negated:
                # De Morgan: the negation has been pushed down to the
                # children, so the connector flips.
                operator = OR if operator == AND else AND
            grouped = defaultdict(list)
            for conditions in per_child:
                for alias, node in conditions.items():
                    grouped[alias].append(node)
            result = {
                alias: WhereNode(children=nodes, negated=False, connector=operator)
                for alias, nodes in grouped.items()
            }
            if operator == AND:
                # Each conjunct holds on its own, so a partial set of
                # conjuncts per alias is kept as is.
                return result
            # For OR, keep an alias only when every branch contributed a
            # condition for it; otherwise the disjunction for that alias
            # would be missing a branch.
            shared_aliases = (
                set.intersection(*(set(conditions) for conditions in per_child))
                if per_child
                else set()
            )
            return {
                alias: node for alias, node in result.items() if alias in shared_aliases
            }
        # Leaf condition. WHERE children are not guaranteed to expose an
        # ``lhs`` attribute, so look it up defensively and treat anything
        # without a Col on the left-hand side as not pushable instead of
        # raising AttributeError.
        if isinstance(getattr(expr, "lhs", None), Col) and (
            is_constant_value(expr.rhs) or getattr(expr.rhs, "is_simple_column", False)
        ):
            return {expr.lhs.alias: WhereNode(children=[expr], negated=negated)}
        return {}

    return collect_pushable(self.get_where())

def get_lookup_pipeline(self):
result = []
# To improve join performance, push conditions (filters) from the
# WHERE ($match) clause to the JOIN ($lookup) clause.
where = self.get_where()
pushed_filters = defaultdict(list)
for expr in where.children if where and where.connector == AND else ():
# Push only basic lookups; no subqueries or complex conditions.
# To avoid duplication across subqueries, only use the LHS target
# table.
if (
isinstance(expr, Lookup)
and isinstance(expr.lhs, Col)
and (is_direct_value(expr.rhs) or isinstance(expr.rhs, (Value, Col)))
):
pushed_filters[expr.lhs.alias].append(expr)
pushed_filters = self._get_pushable_conditions()
for alias in tuple(self.query.alias_map):
if not self.query.alias_refcount[alias] or self.collection_name == alias:
continue
result += self.query.alias_map[alias].as_mql(
self, self.connection, WhereNode(pushed_filters[alias], connector=AND)
self, self.connection, pushed_filters.get(alias)
)
return result

Expand Down
3 changes: 2 additions & 1 deletion django_mongodb_backend/query_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from django.core.exceptions import FullResultSet
from django.db.models import F
from django.db.models.aggregates import Aggregate
from django.db.models.expressions import CombinedExpression, Func, Value
from django.db.models.sql.query import Query
Expand Down Expand Up @@ -67,7 +68,7 @@ def is_constant_value(value):
else:
constants_sub_expressions = True
constants_sub_expressions = constants_sub_expressions and not (
isinstance(value, Query)
isinstance(value, Query | F)
or value.contains_aggregate
or value.contains_over_clause
or value.contains_column_references
Expand Down
30 changes: 28 additions & 2 deletions django_mongodb_backend/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,28 @@
class MongoTestCaseMixin:
maxDiff = None

COMMUTATIVE_OPERATORS = {"$and", "$or", "$all"}

@staticmethod
def _normalize_query(obj):
    """
    Return a copy of ``obj`` in which the list operands of the operators in
    ``COMMUTATIVE_OPERATORS`` are sorted into a canonical order.

    Dicts (including dict subclasses) are rebuilt as plain dicts and lists
    are rebuilt recursively. Lists that are not the operand of a
    commutative operator keep their order. Any other value is returned
    unchanged.
    """

    def canonical_key(value):
        # Text form used only for ordering. Dict items are sorted so that
        # two dicts that compare equal produce the same key regardless of
        # the order their keys were inserted in; str() would not.
        if isinstance(value, dict):
            items = sorted(
                f"{key!r}: {canonical_key(item)}" for key, item in value.items()
            )
            return "{" + ", ".join(items) + "}"
        if isinstance(value, list):
            return "[" + ", ".join(canonical_key(item) for item in value) + "]"
        return repr(value)

    normalize = MongoTestCaseMixin._normalize_query

    if isinstance(obj, dict):
        normalized = {}
        for key, value in obj.items():
            if key in MongoTestCaseMixin.COMMUTATIVE_OPERATORS and isinstance(value, list):
                # Only sort for commutative operators.
                normalized[key] = sorted(
                    (normalize(item) for item in value), key=canonical_key
                )
            else:
                normalized[key] = normalize(value)
        return normalized

    if isinstance(obj, list):
        # Lists not under commutative ops keep their order.
        return [normalize(item) for item in obj]

    return obj

def assertAggregateQuery(self, query, expected_collection, expected_pipeline):
"""
Assert that the logged query is equal to:
Expand All @@ -16,6 +38,10 @@ def assertAggregateQuery(self, query, expected_collection, expected_pipeline):
self.assertEqual(operator, "aggregate")
self.assertEqual(collection, expected_collection)
self.assertEqual(
eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}), # noqa: S307
expected_pipeline,
self._normalize_query(
eval( # noqa: S307
pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}
)
),
self._normalize_query(expected_pipeline),
)
Loading