|
9 | 9 | from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When |
10 | 10 | from django.db.models.functions.comparison import Coalesce |
11 | 11 | from django.db.models.functions.math import Power |
12 | | -from django.db.models.lookups import IsNull, Lookup |
| 12 | +from django.db.models.lookups import IsNull |
13 | 13 | from django.db.models.sql import compiler |
14 | 14 | from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE |
15 | 15 | from django.db.models.sql.datastructures import BaseTable |
16 | | -from django.db.models.sql.where import AND, WhereNode |
| 16 | +from django.db.models.sql.where import AND, OR, XOR, NothingNode, WhereNode |
17 | 17 | from django.utils.functional import cached_property |
18 | 18 | from pymongo import ASCENDING, DESCENDING |
19 | 19 |
|
20 | 20 | from .expressions.search import SearchExpression, SearchVector |
21 | 21 | from .query import MongoQuery, wrap_database_errors |
22 | | -from .query_utils import is_direct_value |
| 22 | +from .query_utils import is_constant_value |
23 | 23 |
|
24 | 24 |
|
25 | 25 | class SQLCompiler(compiler.SQLCompiler): |
@@ -661,27 +661,61 @@ def get_combinator_queries(self): |
661 | 661 | combinator_pipeline.append({"$unset": "_id"}) |
662 | 662 | return combinator_pipeline |
663 | 663 |
|
| 664 | + def _get_pushable_conditions(self): |
| 665 | + def collect_pushable(expr, negated=False): |
| 666 | + if expr is None or isinstance(expr, NothingNode): |
| 667 | + return {} |
| 668 | + if isinstance(expr, WhereNode): |
| 669 | + negated ^= expr.negated |
| 670 | + pushable_expressions = [ |
| 671 | + collect_pushable(sub_expr, negated=negated) |
| 672 | + for sub_expr in expr.children |
| 673 | + if sub_expr is not None |
| 674 | + ] |
| 675 | + operator = expr.connector |
| 676 | + if operator == XOR: |
| 677 | + return {} |
| 678 | + if negated: |
| 679 | + operator = OR if operator == AND else AND |
| 680 | + alias_children = defaultdict(list) |
| 681 | + for pe in pushable_expressions: |
| 682 | + for alias, expressions in pe.items(): |
| 683 | + alias_children[alias].append(expressions) |
| 684 | + result = {} |
| 685 | + for alias, children in alias_children.items(): |
| 686 | + result[alias] = WhereNode( |
| 687 | + children=children, |
| 688 | + negated=False, |
| 689 | + connector=operator, |
| 690 | + ) |
| 691 | + if operator == AND: |
| 692 | + return result |
| 693 | + shared_alias = ( |
| 694 | + set.intersection(*(set(pe) for pe in pushable_expressions)) |
| 695 | + if pushable_expressions |
| 696 | + else set() |
| 697 | + ) |
| 698 | + return {k: v for k, v in result.items() if k in shared_alias} |
| 699 | + if isinstance(expr.lhs, Col) and ( |
| 700 | + is_constant_value(expr.rhs) or getattr(expr.rhs, "is_simple_column", False) |
| 701 | + ): |
| 702 | + alias = expr.lhs.alias |
| 703 | + expr = WhereNode(children=[expr], negated=negated) |
| 704 | + return {alias: expr} |
| 705 | + return {} |
| 706 | + |
| 707 | + return collect_pushable(self.get_where()) |
| 708 | + |
664 | 709 | def get_lookup_pipeline(self): |
665 | 710 | result = [] |
666 | 711 | # To improve join performance, push conditions (filters) from the |
667 | 712 | # WHERE ($match) clause to the JOIN ($lookup) clause. |
668 | | - where = self.get_where() |
669 | | - pushed_filters = defaultdict(list) |
670 | | - for expr in where.children if where and where.connector == AND else (): |
671 | | - # Push only basic lookups; no subqueries or complex conditions. |
672 | | - # To avoid duplication across subqueries, only use the LHS target |
673 | | - # table. |
674 | | - if ( |
675 | | - isinstance(expr, Lookup) |
676 | | - and isinstance(expr.lhs, Col) |
677 | | - and (is_direct_value(expr.rhs) or isinstance(expr.rhs, (Value, Col))) |
678 | | - ): |
679 | | - pushed_filters[expr.lhs.alias].append(expr) |
| 713 | + pushed_filters = self._get_pushable_conditions() |
680 | 714 | for alias in tuple(self.query.alias_map): |
681 | 715 | if not self.query.alias_refcount[alias] or self.collection_name == alias: |
682 | 716 | continue |
683 | 717 | result += self.query.alias_map[alias].as_mql( |
684 | | - self, self.connection, WhereNode(pushed_filters[alias], connector=AND) |
| 718 | + self, self.connection, pushed_filters.get(alias) |
685 | 719 | ) |
686 | 720 | return result |
687 | 721 |
|
|
0 commit comments