diff --git a/src/tagstudio/core/library/alchemy/visitors.py b/src/tagstudio/core/library/alchemy/visitors.py index b3d173e3c..017d76104 100644 --- a/src/tagstudio/core/library/alchemy/visitors.py +++ b/src/tagstudio/core/library/alchemy/visitors.py @@ -14,6 +14,7 @@ from tagstudio.core.library.alchemy.models import Entry, Tag, TagAlias from tagstudio.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories from tagstudio.core.query_lang.ast import ( + AST, ANDList, BaseVisitor, Constraint, @@ -58,42 +59,15 @@ def __init__(self, lib: Library) -> None: self.lib = lib def visit_or_list(self, node: ORList) -> ColumnElement[bool]: - return or_(*[self.visit(element) for element in node.elements]) + tag_ids, bool_expressions = self.__separate_tags(node.elements, only_single=False) + if len(tag_ids) > 0: + bool_expressions.append(self.__entry_has_any_tags(tag_ids)) + return or_(*bool_expressions) def visit_and_list(self, node: ANDList) -> ColumnElement[bool]: - tag_ids: list[int] = [] - bool_expressions: list[ColumnElement[bool]] = [] - - # Search for TagID / unambiguous Tag Constraints and store the respective tag ids separately - for term in node.terms: - if isinstance(term, Constraint) and len(term.properties) == 0: - match term.type: - case ConstraintType.TagID: - try: - tag_ids.append(int(term.value)) - except ValueError: - logger.error( - "[SQLBoolExpressionBuilder] Could not cast value to an int Tag ID", - value=term.value, - ) - continue - case ConstraintType.Tag: - if len(ids := self.__get_tag_ids(term.value)) == 1: - tag_ids.append(ids[0]) - continue - - bool_expressions.append(self.visit(term)) - - # If there are at least two tag ids use a relational division query - # to efficiently check all of them - if len(tag_ids) > 1: + tag_ids, bool_expressions = self.__separate_tags(node.terms, only_single=True) + if len(tag_ids) > 0: bool_expressions.append(self.__entry_has_all_tags(tag_ids)) - # If there is just one tag id, check the normal way - elif len(tag_ids) == 1: - bool_expressions.append( - self.__entry_satisfies_expression(TagEntry.tag_id == tag_ids[0]) - ) - return and_(*bool_expressions) def visit_constraint(self, node: Constraint) -> ColumnElement[bool]: @@ -102,9 +76,9 @@ def visit_constraint(self, node: Constraint) -> ColumnElement[bool]: raise NotImplementedError("Properties are not implemented yet") # TODO TSQLANG if node.type == ConstraintType.Tag: - return self.__entry_matches_tag_ids(self.__get_tag_ids(node.value)) + return self.__entry_has_any_tags(self.__get_tag_ids(node.value)) elif node.type == ConstraintType.TagID: - return self.__entry_matches_tag_ids([int(node.value)]) + return self.__entry_has_any_tags([int(node.value)]) elif node.type == ConstraintType.Path: ilike = False glob = False @@ -153,15 +127,6 @@ def visit_property(self, node: Property) -> ColumnElement[bool]: def visit_not(self, node: Not) -> ColumnElement[bool]: return ~self.visit(node.child) - def __entry_matches_tag_ids(self, tag_ids: list[int]) -> ColumnElement[bool]: - """Returns a boolean expression that is true if the entry has at least one of the supplied tags.""" # noqa: E501 - return ( - select(1) - .correlate(Entry) - .where(and_(TagEntry.entry_id == Entry.id, TagEntry.tag_id.in_(tag_ids))) - .exists() - ) - def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]: """Given a tag name find the ids of all tags that this name could refer to.""" with Session(self.lib.engine) as session: @@ -185,6 +150,36 @@ def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[in outp.extend(list(session.scalars(TAG_CHILDREN_ID_QUERY, {"tag_id": tag_id}))) return outp + def __separate_tags( + self, terms: list[AST], only_single: bool = True + ) -> tuple[list[int], list[ColumnElement[bool]]]: + tag_ids: list[int] = [] + bool_expressions: list[ColumnElement[bool]] = [] + + for term in terms: + if isinstance(term, Constraint) and len(term.properties) == 0: + match term.type: + case ConstraintType.TagID: + try: + tag_ids.append(int(term.value)) + except ValueError: + logger.error( + "[SQLBoolExpressionBuilder] Could not cast value to an int Tag ID", + value=term.value, + ) + continue + case ConstraintType.Tag: + ids = self.__get_tag_ids(term.value) + if not only_single: + tag_ids.extend(ids) + continue + elif len(ids) == 1: + tag_ids.append(ids[0]) + continue + + bool_expressions.append(self.visit(term)) + return tag_ids, bool_expressions + def __entry_has_all_tags(self, tag_ids: list[int]) -> ColumnElement[bool]: """Returns Binary Expression that is true if the Entry has all provided tag ids.""" # Relational Division Query @@ -195,9 +190,8 @@ def __entry_has_all_tags(self, tag_ids: list[int]) -> ColumnElement[bool]: .having(func.count(distinct(TagEntry.tag_id)) == len(tag_ids)) ) - def __entry_satisfies_expression(self, expr: ColumnElement[bool]) -> ColumnElement[bool]: - """Returns Binary Expression that is true if the Entry satisfies the column expression. - - Executed on: Entry ⟕ TagEntry (Entry LEFT OUTER JOIN TagEntry). - """ - return Entry.id.in_(select(Entry.id).outerjoin(TagEntry).where(expr)) + def __entry_has_any_tags(self, tag_ids: list[int]) -> ColumnElement[bool]: + """Returns Binary Expression that is true if the Entry has any of the provided tag ids.""" + return Entry.id.in_( + select(TagEntry.entry_id).where(TagEntry.tag_id.in_(tag_ids)).distinct() + )