From a003b7c00cca42cb53bdb762a985afa51c883f57 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 18 Jul 2025 20:12:40 -0300 Subject: [PATCH 1/2] Push simple filter conditions into $lookup stage. --- django_mongodb_backend/compiler.py | 18 ++++++++-- django_mongodb_backend/query.py | 55 +++++++++++++++++++----------- tests/queries_/test_mql.py | 3 +- 3 files changed, 53 insertions(+), 23 deletions(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 3ef6a1ea7..08f4d0cb5 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -9,14 +9,16 @@ from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When from django.db.models.functions.comparison import Coalesce from django.db.models.functions.math import Power -from django.db.models.lookups import IsNull +from django.db.models.lookups import IsNull, Lookup from django.db.models.sql import compiler from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE from django.db.models.sql.datastructures import BaseTable +from django.db.models.sql.where import AND from django.utils.functional import cached_property from pymongo import ASCENDING, DESCENDING from .query import MongoQuery, wrap_database_errors +from .query_utils import is_direct_value class SQLCompiler(compiler.SQLCompiler): @@ -549,10 +551,22 @@ def get_combinator_queries(self): def get_lookup_pipeline(self): result = [] + where = self.get_where() + promote_filters = defaultdict(list) + for expr in where.children if where and where.connector == AND else (): + if ( + isinstance(expr, Lookup) + and isinstance(expr.lhs, Col) + and (is_direct_value(expr.rhs) or isinstance(expr.rhs, Value)) + ): + promote_filters[expr.lhs.alias].append(expr) + for alias in tuple(self.query.alias_map): if not self.query.alias_refcount[alias] or self.collection_name == alias: continue - result += self.query.alias_map[alias].as_mql(self, self.connection) + result += self.query.alias_map[alias].as_mql( + self, self.connection, promote_filters[alias] + ) return result def _get_aggregate_expressions(self, expr): diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index d59bc1631..96a0db466 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -129,25 +129,12 @@ def extra_where(self, compiler, connection): # noqa: ARG001 raise NotSupportedError("QuerySet.extra() is not supported on MongoDB.") -def join(self, compiler, connection): - lookup_pipeline = [] - lhs_fields = [] - rhs_fields = [] - # Add a join condition for each pair of joining fields. - parent_template = "parent__field__" - for lhs, rhs in self.join_fields: - lhs, rhs = connection.ops.prepare_join_on_clause( - self.parent_alias, lhs, compiler.collection_name, rhs - ) - lhs_fields.append(lhs.as_mql(compiler, connection)) - # In the lookup stage, the reference to this column doesn't include - # the collection name. - rhs_fields.append(rhs.as_mql(compiler, connection)) - # Handle any join conditions besides matching field pairs. - extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias) - if extra: +def join(self, compiler, connection, pushed_expressions=None): + def _get_reroot_replacements(expressions): + if not expressions: + return [] columns = [] - for expr in extra.leaves(): + for expr in expressions: # Determine whether the column needs to be transformed or rerouted # as part of the subquery. for hand_side in ["lhs", "rhs"]: @@ -165,7 +152,7 @@ def join(self, compiler, connection): # based on their rerouted positions in the join pipeline. replacements = {} for col, parent_pos in columns: - column_target = Col(compiler.collection_name, expr.output_field.__class__()) + column_target = Col(compiler.collection_name, col.target, col.output_field) if parent_pos is not None: target_col = f"${parent_template}{parent_pos}" column_target.target.db_column = target_col @@ -173,10 +160,37 @@ def join(self, compiler, connection): else: column_target.target = col.target replacements[col] = column_target - # Apply the transformed expressions in the extra condition. + return replacements + + lookup_pipeline = [] + lhs_fields = [] + rhs_fields = [] + # Add a join condition for each pair of joining fields. + parent_template = "parent__field__" + for lhs, rhs in self.join_fields: + lhs, rhs = connection.ops.prepare_join_on_clause( + self.parent_alias, lhs, compiler.collection_name, rhs + ) + lhs_fields.append(lhs.as_mql(compiler, connection)) + # In the lookup stage, the reference to this column doesn't include + # the collection name. + rhs_fields.append(rhs.as_mql(compiler, connection)) + # Handle any join conditions besides matching field pairs. + extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias) + + if extra: + replacements = _get_reroot_replacements(extra.leaves()) extra_condition = [extra.replace_expressions(replacements).as_mql(compiler, connection)] else: extra_condition = [] + if self.join_type == INNER: + rerooted_replacement = _get_reroot_replacements(pushed_expressions) + resolved_pushed_expressions = [ + expr.replace_expressions(rerooted_replacement).as_mql(compiler, connection) + for expr in pushed_expressions + ] + else: + resolved_pushed_expressions = [] lookup_pipeline = [ { @@ -204,6 +218,7 @@ def join(self, compiler, connection): for i, field in enumerate(rhs_fields) ] + extra_condition + + resolved_pushed_expressions } } } diff --git a/tests/queries_/test_mql.py b/tests/queries_/test_mql.py index d61e5839d..d12ea9602 100644 --- a/tests/queries_/test_mql.py +++ b/tests/queries_/test_mql.py @@ -20,7 +20,8 @@ def test_join(self): "{'$lookup': {'from': 'queries__author', " "'let': {'parent__field__0': '$author_id'}, " "'pipeline': [{'$match': {'$expr': " - "{'$and': [{'$eq': ['$$parent__field__0', '$_id']}]}}}], 'as': 'queries__author'}}, " + "{'$and': [{'$eq': ['$$parent__field__0', '$_id']}, " + "{'$eq': ['$name', 'Bob']}]}}}], 'as': 'queries__author'}}, " "{'$unwind': '$queries__author'}, " "{'$match': {'$expr': {'$eq': ['$queries__author.name', 'Bob']}}}])", ) From 764cd5103ed0ec68579aa56a20bd603ad0fe8864 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 19 Jul 2025 14:55:49 -0300 Subject: [PATCH 2/2] Fix unintended overwrite of column.target. --- django_mongodb_backend/query.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index 96a0db466..ef1f1a040 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -132,7 +132,7 @@ def extra_where(self, compiler, connection): # noqa: ARG001 def join(self, compiler, connection, pushed_expressions=None): def _get_reroot_replacements(expressions): if not expressions: - return [] + return None columns = [] for expr in expressions: # Determine whether the column needs to be transformed or rerouted @@ -152,7 +152,9 @@ def _get_reroot_replacements(expressions): # based on their rerouted positions in the join pipeline. replacements = {} for col, parent_pos in columns: - column_target = Col(compiler.collection_name, col.target, col.output_field) + target = col.target.clone() + target.remote_field = col.target.remote_field + column_target = Col(compiler.collection_name, target) if parent_pos is not None: target_col = f"${parent_template}{parent_pos}" column_target.target.db_column = target_col