diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 888362c6b..2f7bdd404 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -9,17 +9,17 @@ from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When from django.db.models.functions.comparison import Coalesce from django.db.models.functions.math import Power -from django.db.models.lookups import IsNull, Lookup +from django.db.models.lookups import IsNull from django.db.models.sql import compiler from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE from django.db.models.sql.datastructures import BaseTable -from django.db.models.sql.where import AND, WhereNode +from django.db.models.sql.where import AND, OR, XOR, NothingNode, WhereNode from django.utils.functional import cached_property from pymongo import ASCENDING, DESCENDING from .expressions.search import SearchExpression, SearchVector from .query import MongoQuery, wrap_database_errors -from .query_utils import is_direct_value +from .query_utils import is_constant_value class SQLCompiler(compiler.SQLCompiler): @@ -661,27 +661,61 @@ def get_combinator_queries(self): combinator_pipeline.append({"$unset": "_id"}) return combinator_pipeline + def _get_pushable_conditions(self): + def collect_pushable(expr, negated=False): + if expr is None or isinstance(expr, NothingNode): + return {} + if isinstance(expr, WhereNode): + negated ^= expr.negated + pushable_expressions = [ + collect_pushable(sub_expr, negated=negated) + for sub_expr in expr.children + if sub_expr is not None + ] + operator = expr.connector + if operator == XOR: + return {} + if negated: + operator = OR if operator == AND else AND + alias_children = defaultdict(list) + for pe in pushable_expressions: + for alias, expressions in pe.items(): + alias_children[alias].append(expressions) + result = {} + for alias, children in alias_children.items(): + result[alias] = WhereNode( + children=children, + negated=False, + connector=operator, + ) + if operator == AND: + return result + shared_alias = ( + set.intersection(*(set(pe) for pe in pushable_expressions)) + if pushable_expressions + else set() + ) + return {k: v for k, v in result.items() if k in shared_alias} + if isinstance(expr.lhs, Col) and ( + is_constant_value(expr.rhs) or getattr(expr.rhs, "is_simple_column", False) + ): + alias = expr.lhs.alias + expr = WhereNode(children=[expr], negated=negated) + return {alias: expr} + return {} + + return collect_pushable(self.get_where()) + def get_lookup_pipeline(self): result = [] # To improve join performance, push conditions (filters) from the # WHERE ($match) clause to the JOIN ($lookup) clause. - where = self.get_where() - pushed_filters = defaultdict(list) - for expr in where.children if where and where.connector == AND else (): - # Push only basic lookups; no subqueries or complex conditions. - # To avoid duplication across subqueries, only use the LHS target - # table. - if ( - isinstance(expr, Lookup) - and isinstance(expr.lhs, Col) - and (is_direct_value(expr.rhs) or isinstance(expr.rhs, (Value, Col))) - ): - pushed_filters[expr.lhs.alias].append(expr) + pushed_filters = self._get_pushable_conditions() for alias in tuple(self.query.alias_map): if not self.query.alias_refcount[alias] or self.collection_name == alias: continue result += self.query.alias_map[alias].as_mql( - self, self.connection, WhereNode(pushed_filters[alias], connector=AND) + self, self.connection, pushed_filters.get(alias) ) return result diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index fa04feb75..ea892ec9f 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -1,4 +1,5 @@ from django.core.exceptions import FullResultSet +from django.db.models import F from django.db.models.aggregates import Aggregate from django.db.models.expressions import CombinedExpression, Func, Value from django.db.models.sql.query import Query @@ -67,7 +68,7 @@ def is_constant_value(value): else: constants_sub_expressions = True constants_sub_expressions = constants_sub_expressions and not ( - isinstance(value, Query) + isinstance(value, Query | F) or value.contains_aggregate or value.contains_over_clause or value.contains_column_references diff --git a/django_mongodb_backend/test.py b/django_mongodb_backend/test.py index ee35b4e21..094df7e7d 100644 --- a/django_mongodb_backend/test.py +++ b/django_mongodb_backend/test.py @@ -6,6 +6,28 @@ class MongoTestCaseMixin: maxDiff = None + COMMUTATIVE_OPERATORS = {"$and", "$or", "$all"} + + @staticmethod + def _normalize_query(obj): + if isinstance(obj, dict): + normalized = {} + for k, v in obj.items(): + if k in MongoTestCaseMixin.COMMUTATIVE_OPERATORS and isinstance(v, list): + # Only sort for commutative operators + normalized[k] = sorted( + (MongoTestCaseMixin._normalize_query(i) for i in v), key=lambda x: str(x) + ) + else: + normalized[k] = MongoTestCaseMixin._normalize_query(v) + return normalized + + if isinstance(obj, list): + # Lists not under commutative ops keep their order + return [MongoTestCaseMixin._normalize_query(i) for i in obj] + + return obj + def assertAggregateQuery(self, query, expected_collection, expected_pipeline): """ Assert that the logged query is equal to: @@ -16,6 +38,10 @@ def assertAggregateQuery(self, query, expected_collection, expected_pipeline): self.assertEqual(operator, "aggregate") self.assertEqual(collection, expected_collection) self.assertEqual( - eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}), # noqa: S307 - expected_pipeline, + self._normalize_query( + eval( # noqa: S307 + pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {} + ) + ), + self._normalize_query(expected_pipeline), ) diff --git a/tests/queries_/test_mql.py b/tests/queries_/test_mql.py index e8837bf8a..07e1ee2bd 100644 --- a/tests/queries_/test_mql.py +++ b/tests/queries_/test_mql.py @@ -281,7 +281,14 @@ def test_negated_related_filter_is_not_pushable(self): "pipeline": [ { "$match": { - "$expr": {"$and": [{"$eq": ["$$parent__field__0", "$_id"]}]} + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"$nor": [{"name": "John"}]}, + ] } } ], @@ -742,3 +749,333 @@ def test_or_with_mixed_pushable_and_non_pushable_fields(self): {"$match": {"$or": [{"queries__reader.name": "Alice"}, {"name": "Central"}]}}, ], ) + + def test_double_negation_pushdown(self): + a1 = Author.objects.create(name="Alice") + a2 = Author.objects.create(name="Bob") + b1 = Book.objects.create(title="Book1", author=a1, isbn="111") + Book.objects.create(title="Book2", author=a2, isbn="222") + b3 = Book.objects.create(title="Book3", author=a1, isbn="333") + expected = [b1, b3] + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + Book.objects.filter(~(~models.Q(author__name="Alice") | models.Q(title="Book4"))), + expected, + ) + self.assertAggregateQuery( + ctx.captured_queries[0]["sql"], + "queries__book", + [ + { + "$lookup": { + "from": "queries__author", + "let": {"parent__field__0": "$author_id"}, + "pipeline": [ + { + "$match": { + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "Alice"}, + ] + } + } + ], + "as": "queries__author", + } + }, + {"$unwind": "$queries__author"}, + { + "$match": { + "$nor": [ + { + "$or": [ + {"$nor": [{"queries__author.name": "Alice"}]}, + {"title": "Book4"}, + ] + } + ] + } + }, + ], + ) + + def test_partial_or_pushdown(self): + a1 = Author.objects.create(name="Alice") + a2 = Author.objects.create(name="Bob") + a3 = Author.objects.create(name="Charlie") + b1 = Book.objects.create(title="B1", author=a1, isbn="111") + b2 = Book.objects.create(title="B2", author=a2, isbn="111") + Book.objects.create(title="B3", author=a3, isbn="222") + condition = models.Q(author__name="Alice") | ( + models.Q(author__name="Bob") & models.Q(isbn="111") + ) + expected = [b1, b2] + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual(list(Book.objects.filter(condition)), expected) + self.assertAggregateQuery( + ctx.captured_queries[0]["sql"], + "queries__book", + [ + { + "$lookup": { + "as": "queries__author", + "from": "queries__author", + "let": {"parent__field__0": "$author_id"}, + "pipeline": [ + { + "$match": { + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"$or": [{"name": "Alice"}, {"name": "Bob"}]}, + ] + } + } + ], + } + }, + {"$unwind": "$queries__author"}, + { + "$match": { + "$or": [ + {"queries__author.name": "Alice"}, + {"$and": [{"queries__author.name": "Bob"}, {"isbn": "111"}]}, + ] + } + }, + ], + ) + + def test_multiple_ors_with_partial_pushdown(self): + a1 = Author.objects.create(name="Alice") + a2 = Author.objects.create(name="Bob") + a3 = Author.objects.create(name="Charlie") + a4 = Author.objects.create(name="David") + b1 = Book.objects.create(title="B1", author=a1, isbn="111") + b2 = Book.objects.create(title="B2", author=a1, isbn="222") + b3 = Book.objects.create(title="B3", author=a2, isbn="333") + b4 = Book.objects.create(title="B4", author=a3, isbn="333") + Book.objects.create(title="B5", author=a4, isbn="444") + + left = models.Q(author__name="Alice") & (models.Q(isbn="111") | models.Q(isbn="222")) + right = (models.Q(author__name="Bob") | models.Q(author__name="Charlie")) & models.Q( + isbn="333" + ) + condition = left | right + + expected = [b1, b2, b3, b4] + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual(list(Book.objects.filter(condition)), expected) + + self.assertAggregateQuery( + ctx.captured_queries[0]["sql"], + "queries__book", + [ + { + "$lookup": { + "as": "queries__author", + "from": "queries__author", + "let": {"parent__field__0": "$author_id"}, + "pipeline": [ + { + "$match": { + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + { + "$or": [ + {"name": "Alice"}, + {"$or": [{"name": "Bob"}, {"name": "Charlie"}]}, + ] + }, + ] + } + } + ], + } + }, + {"$unwind": "$queries__author"}, + { + "$match": { + "$or": [ + { + "$and": [ + {"queries__author.name": "Alice"}, + {"$or": [{"isbn": "111"}, {"isbn": "222"}]}, + ] + }, + { + "$and": [ + { + "$or": [ + {"queries__author.name": "Bob"}, + {"queries__author.name": "Charlie"}, + ] + }, + {"isbn": "333"}, + ] + }, + ] + } + }, + ], + ) + + def test_self_join_tag_three_levels_none_pushable(self): + t1 = Tag.objects.create(name="T1") + t2 = Tag.objects.create(name="T2", parent=t1) + t3 = Tag.objects.create(name="T3", parent=t2) + Tag.objects.create(name="T4", parent=t3) + Tag.objects.create(name="T5", parent=t1) + t6 = Tag.objects.create(name="T6", parent=t2) + cond = ( + models.Q(name="T1") | models.Q(parent__name="T2") | models.Q(parent__parent__name="T3") + ) + expected = [t1, t3, t6] + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual(list(Tag.objects.filter(cond)), expected) + self.assertAggregateQuery( + ctx.captured_queries[0]["sql"], + "queries__tag", + # Django translate this kind of queries into left outer join + [ + { + "$lookup": { + "as": "T2", + "from": "queries__tag", + "let": {"parent__field__0": "$parent_id"}, + "pipeline": [ + { + "$match": { + "$expr": {"$and": [{"$eq": ["$$parent__field__0", "$_id"]}]} + } + } + ], + } + }, + { + "$set": { + "T2": { + "$cond": { + "else": "$T2", + "if": { + "$or": [ + {"$eq": [{"$type": "$T2"}, "missing"]}, + {"$eq": [{"$size": "$T2"}, 0]}, + ] + }, + "then": [{}], + } + } + } + }, + {"$unwind": "$T2"}, + { + "$lookup": { + "as": "T3", + "from": "queries__tag", + "let": {"parent__field__0": "$T2.parent_id"}, + "pipeline": [ + { + "$match": { + "$expr": {"$and": [{"$eq": ["$$parent__field__0", "$_id"]}]} + } + } + ], + } + }, + { + "$set": { + "T3": { + "$cond": { + "else": "$T3", + "if": { + "$or": [ + {"$eq": [{"$type": "$T3"}, "missing"]}, + {"$eq": [{"$size": "$T3"}, 0]}, + ] + }, + "then": [{}], + } + } + } + }, + {"$unwind": "$T3"}, + {"$match": {"$or": [{"name": "T1"}, {"T2.name": "T2"}, {"T3.name": "T3"}]}}, + ], + ) + + def test_self_join_tag_three_levels_pushable(self): + t1 = Tag.objects.create(name="T1") + t2 = Tag.objects.create(name="T2", parent=t1) + t3 = Tag.objects.create(name="T3", parent=t2) + Tag.objects.create(name="T4", parent=t3) + Tag.objects.create(name="T5", parent=t1) + Tag.objects.create(name="T6", parent=t2) + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + list(Tag.objects.filter(name="T1", parent__name="T2", parent__parent__name="T3")), + [], + ) + + self.assertAggregateQuery( + ctx.captured_queries[0]["sql"], + "queries__tag", + [ + { + "$lookup": { + "as": "T2", + "from": "queries__tag", + "let": {"parent__field__0": "$parent_id"}, + "pipeline": [ + { + "$match": { + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "T2"}, + ] + } + } + ], + } + }, + {"$unwind": "$T2"}, + { + "$lookup": { + "as": "T3", + "from": "queries__tag", + "let": {"parent__field__0": "$T2.parent_id"}, + "pipeline": [ + { + "$match": { + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "T3"}, + ] + } + } + ], + } + }, + {"$unwind": "$T3"}, + {"$match": {"$and": [{"name": "T1"}, {"T2.name": "T2"}, {"T3.name": "T3"}]}}, + ], + )