diff --git a/pyproject.toml b/pyproject.toml index 67286a687..fcd9e9945 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,11 @@ ignore_errors = true qt_api = "pyside6" [tool.pyright] -ignore = ["src/tagstudio/qt/previews/vendored/pydub/", ".venv/**"] +ignore = [ + ".venv/**", + "src/tagstudio/core/library/json/", + "src/tagstudio/qt/previews/vendored/pydub/", +] include = ["src/tagstudio", "tests"] reportAny = false reportIgnoreCommentWithoutRule = false diff --git a/src/tagstudio/core/library/alchemy/constants.py b/src/tagstudio/core/library/alchemy/constants.py index 26ce1c832..2032d613a 100644 --- a/src/tagstudio/core/library/alchemy/constants.py +++ b/src/tagstudio/core/library/alchemy/constants.py @@ -23,3 +23,14 @@ ) SELECT * FROM ChildTags; """) + +TAG_CHILDREN_ID_QUERY = text(""" +WITH RECURSIVE ChildTags AS ( + SELECT :tag_id AS tag_id + UNION + SELECT tp.child_id AS tag_id + FROM tag_parents tp + INNER JOIN ChildTags c ON tp.parent_id = c.tag_id +) +SELECT tag_id FROM ChildTags; +""") diff --git a/src/tagstudio/core/library/alchemy/db.py b/src/tagstudio/core/library/alchemy/db.py index 78a766f12..026678ddf 100644 --- a/src/tagstudio/core/library/alchemy/db.py +++ b/src/tagstudio/core/library/alchemy/db.py @@ -4,6 +4,7 @@ from pathlib import Path +from typing import override import structlog from sqlalchemy import Dialect, Engine, String, TypeDecorator, create_engine, text @@ -19,12 +20,14 @@ class PathType(TypeDecorator): impl = String cache_ok = True - def process_bind_param(self, value: Path, dialect: Dialect): + @override + def process_bind_param(self, value: Path | None, dialect: Dialect): if value is not None: return Path(value).as_posix() return None - def process_result_value(self, value: str, dialect: Dialect): + @override + def process_result_value(self, value: str | None, dialect: Dialect): if value is not None: return Path(value) return None diff --git a/src/tagstudio/core/library/alchemy/fields.py b/src/tagstudio/core/library/alchemy/fields.py index c4675dc47..faffae079 100644 --- a/src/tagstudio/core/library/alchemy/fields.py +++ b/src/tagstudio/core/library/alchemy/fields.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, override from sqlalchemy import ForeignKey from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship @@ -32,7 +32,7 @@ def type_key(self) -> Mapped[str]: @declared_attr def type(self) -> Mapped[ValueType]: - return relationship(foreign_keys=[self.type_key], lazy=False) # type: ignore + return relationship(foreign_keys=[self.type_key], lazy=False) # type: ignore # pyright: ignore[reportArgumentType] @declared_attr def entry_id(self) -> Mapped[int]: @@ -40,19 +40,20 @@ def entry_id(self) -> Mapped[int]: @declared_attr def entry(self) -> Mapped[Entry]: - return relationship(foreign_keys=[self.entry_id]) # type: ignore + return relationship(foreign_keys=[self.entry_id]) # type: ignore # pyright: ignore[reportArgumentType] @declared_attr def position(self) -> Mapped[int]: return mapped_column(default=0) + @override def __hash__(self): return hash(self.__key()) - def __key(self): + def __key(self): # pyright: ignore[reportUnknownParameterType] raise NotImplementedError - value: Any + value: Any # pyright: ignore class BooleanField(BaseField): @@ -63,7 +64,8 @@ class BooleanField(BaseField): def __key(self): return (self.type, self.value) - def __eq__(self, value) -> bool: + @override + def __eq__(self, value: object) -> bool: if isinstance(value, BooleanField): return self.__key() == value.__key() raise NotImplementedError @@ -74,10 +76,11 @@ class TextField(BaseField): value: Mapped[str | None] - def __key(self) -> tuple: + def __key(self) -> tuple[ValueType, str | None]: return self.type, self.value - def __eq__(self, value) -> bool: + @override + def __eq__(self, value: object) -> bool: if isinstance(value, TextField): return self.__key() == value.__key() elif isinstance(value, DatetimeField): @@ -93,7 +96,8 @@ class DatetimeField(BaseField): def __key(self): return (self.type, self.value) - def __eq__(self, value) -> bool: + @override + def __eq__(self, value: object) -> bool: if isinstance(value, DatetimeField): return self.__key() == value.__key() raise NotImplementedError @@ -107,7 +111,7 @@ class DefaultField: is_default: bool = field(default=False) -class _FieldID(Enum): +class FieldID(Enum): """Only for bootstrapping content of DB table.""" TITLE = DefaultField(id=0, name="Title", type=FieldTypeEnum.TEXT_LINE, is_default=True) diff --git a/src/tagstudio/core/library/alchemy/library.py b/src/tagstudio/core/library/alchemy/library.py index fc56f9194..c24643b23 100644 --- a/src/tagstudio/core/library/alchemy/library.py +++ b/src/tagstudio/core/library/alchemy/library.py @@ -2,12 +2,16 @@ # Licensed under the GPL-3.0 License. # Created for TagStudio: https://github.com/CyanVoxel/TagStudio +# NOTE: This file contains necessary use of deprecated first-party code until that +# code is removed in a future version (prefs). +# pyright: reportDeprecated=false + import re import shutil import time import unicodedata -from collections.abc import Iterable, Iterator +from collections.abc import Iterable, Iterator, MutableSequence from dataclasses import dataclass from datetime import UTC, datetime from os import makedirs @@ -18,7 +22,7 @@ import sqlalchemy import structlog -from humanfriendly import format_timespan +from humanfriendly import format_timespan # pyright: ignore[reportUnknownVariableType] from sqlalchemy import ( URL, ColumnExpressionArgument, @@ -40,6 +44,7 @@ ) from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import ( + InstanceState, Session, contains_eager, joinedload, @@ -47,6 +52,7 @@ noload, selectinload, ) +from typing_extensions import deprecated from tagstudio.core.constants import ( BACKUP_FOLDER_NAME, @@ -81,8 +87,8 @@ from tagstudio.core.library.alchemy.fields import ( BaseField, DatetimeField, + FieldID, TextField, - _FieldID, ) from tagstudio.core.library.alchemy.joins import TagEntry, TagParent from tagstudio.core.library.alchemy.models import ( @@ -210,9 +216,9 @@ class Library: """Class for the Library object, and all CRUD operations made upon it.""" library_dir: Path | None = None - storage_path: Path | str | None + storage_path: Path | str | None = None engine: Engine | None = None - folder: Folder | None + folder: Folder | None = None included_files: set[Path] = set() def __init__(self) -> None: @@ -224,7 +230,7 @@ def __init__(self) -> None: def close(self): if self.engine: self.engine.dispose() - self.library_dir: Path | None = None + self.library_dir = None self.storage_path = None self.folder = None self.included_files = set() @@ -300,8 +306,8 @@ def migrate_json_to_sqlite(self, json_lib: JsonLibrary): ] ) for entry in json_lib.entries: - for field in entry.fields: - for k, v in field.items(): + for field in entry.fields: # pyright: ignore[reportUnknownVariableType] + for k, v in field.items(): # pyright: ignore[reportUnknownVariableType] # Old tag fields get added as tags if k in LEGACY_TAG_FIELD_IDS: self.add_tags_to_entries(entry_ids=entry.id + 1, tag_ids=v) @@ -319,8 +325,8 @@ def migrate_json_to_sqlite(self, json_lib: JsonLibrary): end_time = time.time() logger.info(f"Library Converted! ({format_timespan(end_time - start_time)})") - def get_field_name_from_id(self, field_id: int) -> _FieldID: - for f in _FieldID: + def get_field_name_from_id(self, field_id: int) -> FieldID | None: + for f in FieldID: if field_id == f.value.id: return f return None @@ -482,7 +488,7 @@ def open_sqlite_library(self, library_dir: Path, is_new: bool) -> LibraryStatus: except IntegrityError: session.rollback() - for field in _FieldID: + for field in FieldID: try: session.add( ValueType( @@ -562,7 +568,7 @@ def __apply_repairs_for_db6(self, session: Session): # Repair "Description" fields with a TEXT_LINE key instead of a TEXT_BOX key. desc_stmd = ( update(ValueType) - .where(ValueType.key == _FieldID.DESCRIPTION.name) + .where(ValueType.key == FieldID.DESCRIPTION.name) .values(type=FieldTypeEnum.TEXT_BOX.name) ) session.execute(desc_stmd) @@ -697,8 +703,8 @@ def migrate_sql_to_ts_ignore(self, library_dir: Path): logger.error("[ERROR][Library] Could not generate '.ts_ignore' file!", error=e) # Load legacy extension data - extensions: list[str] = self.prefs(LibraryPrefs.EXTENSION_LIST) # pyright: ignore[reportAssignmentType] - is_exclude_list: bool = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST) # pyright: ignore[reportAssignmentType] + extensions: list[str] = self.prefs(LibraryPrefs.EXTENSION_LIST) # pyright: ignore + is_exclude_list: bool = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST) # pyright: ignore # Copy extensions to '.ts_ignore' file if ts_ignore.exists(): @@ -720,12 +726,6 @@ def default_fields(self) -> list[BaseField]: ) return [x.as_field for x in types] - def delete_item(self, item): - logger.info("deleting item", item=item) - with Session(self.engine) as session: - session.delete(item) - session.commit() - def get_entry(self, entry_id: int) -> Entry | None: """Load entry without joins.""" with Session(self.engine) as session: @@ -794,7 +794,7 @@ def get_entries(self, entry_ids: Iterable[int]) -> list[Entry]: entries = dict((e.id, e) for e in session.scalars(statement)) return [entries[id] for id in entry_ids] - def get_entries_full(self, entry_ids: list[int] | set[int]) -> Iterator[Entry]: + def get_entries_full(self, entry_ids: MutableSequence[int]) -> Iterator[Entry]: """Load entry and join with all joins and all tags.""" with Session(self.engine) as session: statement = select(Entry).where(Entry.id.in_(set(entry_ids))) @@ -864,7 +864,7 @@ def get_tag_entries( @property def entries_count(self) -> int: with Session(self.engine) as session: - return session.scalar(select(func.count(Entry.id))) + return unwrap(session.scalar(select(func.count(Entry.id)))) def all_entries(self, with_joins: bool = False) -> Iterator[Entry]: """Load entries without joins.""" @@ -906,7 +906,7 @@ def tags(self) -> list[Tag]: return list(tags_list) - def verify_ts_folder(self, library_dir: Path) -> bool: + def verify_ts_folder(self, library_dir: Path | None) -> bool: """Verify/create folders required by TagStudio. Returns: @@ -960,7 +960,7 @@ def has_path_entry(self, path: Path) -> bool: with Session(self.engine) as session: return session.query(exists().where(Entry.path == path)).scalar() - def get_paths(self, glob: str | None = None, limit: int = -1) -> list[str]: + def get_paths(self, limit: int = -1) -> list[str]: path_strings: list[str] = [] with Session(self.engine) as session: if limit > 0: @@ -1020,7 +1020,7 @@ def search_library( ids = [] count = 0 for row in rows: - id, count = row._tuple() + id, count = row._tuple() # pyright: ignore[reportPrivateUsage] ids.append(id) end_time = time.time() logger.info(f"SQL Execution finished ({format_timespan(end_time - start_time)})") @@ -1109,17 +1109,13 @@ def update_entry_path(self, entry_id: int | Entry, path: Path) -> bool: session.commit() return True - def remove_tag(self, tag: Tag): + def remove_tag(self, tag_id: int): with Session(self.engine, expire_on_commit=False) as session: try: child_tags = session.scalars( - select(TagParent).where(TagParent.child_id == tag.id) + select(TagParent).where(TagParent.child_id == tag_id) ).all() - tags_query = select(Tag).options( - selectinload(Tag.parent_tags), selectinload(Tag.aliases) - ) - tag = session.scalar(tags_query.where(Tag.id == tag.id)) - aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag.id)) + aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag_id)) for alias in aliases or []: session.delete(alias) @@ -1130,24 +1126,18 @@ def remove_tag(self, tag: Tag): disam_stmt = ( update(Tag) - .where(Tag.disambiguation_id == tag.id) + .where(Tag.disambiguation_id == tag_id) .values(disambiguation_id=None) ) session.execute(disam_stmt) session.flush() - - session.delete(tag) + session.query(Tag).filter_by(id=tag_id).delete() session.commit() - session.expunge(tag) - - return tag except IntegrityError as e: logger.error(e) session.rollback() - return None - def update_field_position( self, field_class: type[BaseField], @@ -1247,7 +1237,7 @@ def field_types(self) -> dict[str, ValueType]: def get_value_type(self, field_key: str) -> ValueType: with Session(self.engine) as session: - field = session.scalar(select(ValueType).where(ValueType.key == field_key)) + field = unwrap(session.scalar(select(ValueType).where(ValueType.key == field_key))) session.expunge(field) return field @@ -1256,7 +1246,7 @@ def add_field_to_entry( entry_id: int, *, field: ValueType | None = None, - field_id: _FieldID | str | None = None, + field_id: FieldID | str | None = None, value: str | datetime | None = None, ) -> bool: logger.info( @@ -1270,9 +1260,9 @@ def add_field_to_entry( assert bool(field) != (field_id is not None) if not field: - if isinstance(field_id, _FieldID): + if isinstance(field_id, FieldID): field_id = field_id.name - field = self.get_value_type(field_id) + field = self.get_value_type(unwrap(field_id)) field_model: TextField | DatetimeField if field.type in (FieldTypeEnum.TEXT_LINE, FieldTypeEnum.TEXT_BOX): @@ -1407,9 +1397,9 @@ def delete_namespace(self, namespace: Namespace | str): def add_tag( self, tag: Tag, - parent_ids: list[int] | set[int] | None = None, - alias_names: list[str] | set[str] | None = None, - alias_ids: list[int] | set[int] | None = None, + parent_ids: MutableSequence[int] | None = None, + alias_names: MutableSequence[str] | None = None, + alias_ids: MutableSequence[int] | None = None, ) -> Tag | None: with Session(self.engine, expire_on_commit=False) as session: try: @@ -1432,7 +1422,7 @@ def add_tag( return None def add_tags_to_entries( - self, entry_ids: int | list[int], tag_ids: int | list[int] | set[int] + self, entry_ids: int | list[int], tag_ids: int | MutableSequence[int] ) -> int: """Add one or more tags to one or more entries. @@ -1461,7 +1451,7 @@ def add_tags_to_entries( return total_added def remove_tags_from_entries( - self, entry_ids: int | list[int], tag_ids: int | list[int] | set[int] + self, entry_ids: int | list[int], tag_ids: int | MutableSequence[int] ) -> bool: """Remove one or more tags from one or more entries.""" entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids @@ -1619,9 +1609,9 @@ def get_tag_hierarchy(self, tag_ids: Iterable[int]) -> dict[int, Tag]: # When calling session.add with this tag instance sqlalchemy will # attempt to create TagParents that already exist. - state = inspect(tag) - # Prevent sqlalchemy from thinking any fields are different from what's commited - # commited_state contains original values for fields that have changed. + state: InstanceState[Tag] = inspect(tag) + # Prevent sqlalchemy from thinking any fields are different from what's committed + # committed_state contains original values for fields that have changed. # empty when no fields have changed state.committed_state.clear() @@ -1679,9 +1669,9 @@ def remove_parent_tag(self, base_id: int, remove_tag_id: int) -> bool: def update_tag( self, tag: Tag, - parent_ids: list[int] | set[int] | None = None, - alias_names: list[str] | set[str] | None = None, - alias_ids: list[int] | set[int] | None = None, + parent_ids: MutableSequence[int] | None = None, + alias_names: MutableSequence[str] | None = None, + alias_ids: MutableSequence[int] | None = None, ) -> None: """Edit a Tag in the Library.""" self.add_tag(tag, parent_ids, alias_names, alias_ids) @@ -1735,7 +1725,13 @@ def update_color(self, old_color_group: TagColorGroup, new_color_group: TagColor else: self.add_color(new_color_group) - def update_aliases(self, tag, alias_ids, alias_names, session): + def update_aliases( + self, + tag: Tag, + alias_ids: MutableSequence[int], + alias_names: MutableSequence[str], + session: Session, + ): prev_aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag.id)).all() for alias in prev_aliases: @@ -1749,7 +1745,7 @@ def update_aliases(self, tag, alias_ids, alias_names, session): alias = TagAlias(alias_name, tag.id) session.add(alias) - def update_parent_tags(self, tag: Tag, parent_ids: list[int] | set[int], session): + def update_parent_tags(self, tag: Tag, parent_ids: MutableSequence[int], session: Session): if tag.id in parent_ids: parent_ids.remove(tag.id) @@ -1822,10 +1818,11 @@ def set_version(self, key: str, value: int) -> None: # by older TagStudio versions. engine = sqlalchemy.inspect(self.engine) if engine and engine.has_table("Preferences"): - pref = session.scalar( - select(Preferences).where(Preferences.key == DB_VERSION_LEGACY_KEY) + pref = unwrap( + session.scalar( + select(Preferences).where(Preferences.key == DB_VERSION_LEGACY_KEY) + ) ) - assert pref is not None pref.value = value # pyright: ignore session.add(pref) session.commit() @@ -1833,15 +1830,19 @@ def set_version(self, key: str, value: int) -> None: logger.error("[Library][ERROR] Couldn't add default tag color namespaces", error=e) session.rollback() - def prefs(self, key: str | LibraryPrefs): + # TODO: Remove this once the 'preferences' table is removed. + @deprecated("Use `get_version() for version and `ts_ignore` system for extension exclusion.") + def prefs(self, key: str | LibraryPrefs): # pyright: ignore[reportUnknownParameterType] # load given item from Preferences table with Session(self.engine) as session: if isinstance(key, LibraryPrefs): - return session.scalar(select(Preferences).where(Preferences.key == key.name)).value + return session.scalar(select(Preferences).where(Preferences.key == key.name)).value # pyright: ignore else: - return session.scalar(select(Preferences).where(Preferences.key == key)).value + return session.scalar(select(Preferences).where(Preferences.key == key)).value # pyright: ignore - def set_prefs(self, key: str | LibraryPrefs, value: Any) -> None: + # TODO: Remove this once the 'preferences' table is removed. + @deprecated("Use `get_version() for version and `ts_ignore` system for extension exclusion.") + def set_prefs(self, key: str | LibraryPrefs, value: Any) -> None: # pyright: ignore[reportExplicitAny] # set given item in Preferences table with Session(self.engine) as session: # load existing preference and update value @@ -1873,7 +1874,7 @@ def mirror_entry_fields(self, *entries: Entry) -> None: # assign the field to all entries for entry in entries: - for field_key, field in fields.items(): + for field_key, field in fields.items(): # pyright: ignore[reportUnknownVariableType] if field_key not in existing_fields: self.add_field_to_entry( entry_id=entry.id, diff --git a/src/tagstudio/core/library/alchemy/models.py b/src/tagstudio/core/library/alchemy/models.py index 4d103be20..223dc0216 100644 --- a/src/tagstudio/core/library/alchemy/models.py +++ b/src/tagstudio/core/library/alchemy/models.py @@ -4,6 +4,7 @@ from datetime import datetime as dt from pathlib import Path +from typing import override from sqlalchemy import JSON, ForeignKey, ForeignKeyConstraint, Integer, event from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -150,25 +151,28 @@ def __init__( self.id = id # pyright: ignore[reportAttributeAccessIssue] super().__init__() + @override def __str__(self) -> str: return f"" + @override def __repr__(self) -> str: return self.__str__() + @override def __hash__(self) -> int: return hash(self.id) - def __lt__(self, other) -> bool: + def __lt__(self, other: "Tag") -> bool: return self.name < other.name - def __le__(self, other) -> bool: + def __le__(self, other: "Tag") -> bool: return self.name <= other.name - def __gt__(self, other) -> bool: + def __gt__(self, other: "Tag") -> bool: return self.name > other.name - def __ge__(self, other) -> bool: + def __ge__(self, other: "Tag") -> bool: return self.name >= other.name @@ -233,6 +237,7 @@ def __init__( date_modified: dt | None = None, date_added: dt | None = None, ) -> None: + super().__init__() self.path = path self.folder = folder self.id = id # pyright: ignore[reportAttributeAccessIssue] @@ -280,8 +285,8 @@ class ValueType(Base): key: Mapped[str] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(nullable=False) type: Mapped[FieldTypeEnum] = mapped_column(default=FieldTypeEnum.TEXT_LINE) - is_default: Mapped[bool] - position: Mapped[int] + is_default: Mapped[bool] # pyright: ignore[reportUninitializedInstanceVariable] + position: Mapped[int] # pyright: ignore[reportUninitializedInstanceVariable] # add relations to other tables text_fields: Mapped[list[TextField]] = relationship("TextField", back_populates="type") @@ -306,7 +311,7 @@ def as_field(self) -> BaseField: @event.listens_for(ValueType, "before_insert") -def slugify_field_key(mapper, connection, target): +def slugify_field_key(mapper, connection, target): # pyright: ignore """Slugify the field key before inserting into the database.""" if not target.key: from tagstudio.core.library.alchemy.library import slugify diff --git a/src/tagstudio/core/library/alchemy/visitors.py b/src/tagstudio/core/library/alchemy/visitors.py index d7f63e9c9..c24c8ed7a 100644 --- a/src/tagstudio/core/library/alchemy/visitors.py +++ b/src/tagstudio/core/library/alchemy/visitors.py @@ -3,13 +3,14 @@ # Created for TagStudio: https://github.com/CyanVoxel/TagStudio import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, override import structlog -from sqlalchemy import ColumnElement, and_, distinct, func, or_, select, text +from sqlalchemy import ColumnElement, and_, distinct, func, or_, select from sqlalchemy.orm import Session from sqlalchemy.sql.operators import ilike_op +from tagstudio.core.library.alchemy.constants import TAG_CHILDREN_ID_QUERY from tagstudio.core.library.alchemy.joins import TagEntry from tagstudio.core.library.alchemy.models import Entry, Tag, TagAlias from tagstudio.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories @@ -32,17 +33,6 @@ logger = structlog.get_logger(__name__) -TAG_CHILDREN_ID_QUERY = text(""" -WITH RECURSIVE ChildTags AS ( - SELECT :tag_id AS tag_id - UNION - SELECT tp.child_id AS tag_id - FROM tag_parents tp - INNER JOIN ChildTags c ON tp.parent_id = c.tag_id -) -SELECT tag_id FROM ChildTags; -""") - def get_filetype_equivalency_list(item: str) -> list[str] | set[str]: for s in FILETYPE_EQUIVALENTS: @@ -56,19 +46,22 @@ def __init__(self, lib: Library) -> None: super().__init__() self.lib = lib - def visit_or_list(self, node: ORList) -> ColumnElement[bool]: + @override + def visit_or_list(self, node: ORList) -> ColumnElement[bool]: # type: ignore tag_ids, bool_expressions = self.__separate_tags(node.elements, only_single=False) if len(tag_ids) > 0: bool_expressions.append(self.__entry_has_any_tags(tag_ids)) return or_(*bool_expressions) - def visit_and_list(self, node: ANDList) -> ColumnElement[bool]: + @override + def visit_and_list(self, node: ANDList) -> ColumnElement[bool]: # type: ignore tag_ids, bool_expressions = self.__separate_tags(node.terms, only_single=True) if len(tag_ids) > 0: bool_expressions.append(self.__entry_has_all_tags(tag_ids)) return and_(*bool_expressions) - def visit_constraint(self, node: Constraint) -> ColumnElement[bool]: + @override + def visit_constraint(self, node: Constraint) -> ColumnElement[bool]: # type: ignore """Returns a Boolean Expression that is true, if the Entry satisfies the constraint.""" if len(node.properties) != 0: raise NotImplementedError("Properties are not implemented yet") # TODO TSQLANG @@ -119,10 +112,12 @@ def visit_constraint(self, node: Constraint) -> ColumnElement[bool]: # raise exception if Constraint stays unhandled raise NotImplementedError("This type of constraint is not implemented yet") - def visit_property(self, node: Property) -> ColumnElement[bool]: + @override + def visit_property(self, node: Property) -> ColumnElement[bool]: # type: ignore raise NotImplementedError("This should never be reached!") - def visit_not(self, node: Not) -> ColumnElement[bool]: + @override + def visit_not(self, node: Not) -> ColumnElement[bool]: # type: ignore return ~self.visit(node.child) def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]: @@ -143,7 +138,7 @@ def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[in ) if not include_children: return tag_ids - outp = [] + outp: list[int] = [] for tag_id in tag_ids: outp.extend(list(session.scalars(TAG_CHILDREN_ID_QUERY, {"tag_id": tag_id}))) return outp @@ -174,6 +169,14 @@ def __separate_tags( elif len(ids) == 1: tag_ids.append(ids[0]) continue + case ConstraintType.FileType: + pass + case ConstraintType.Path: + pass + case ConstraintType.Special: + pass + case _: + raise NotImplementedError(f"Unhandled constraint: '{term.type}'") bool_expressions.append(self.visit(term)) return tag_ids, bool_expressions diff --git a/src/tagstudio/core/query_lang/ast.py b/src/tagstudio/core/query_lang/ast.py index 102203ed7..0323bf26d 100644 --- a/src/tagstudio/core/query_lang/ast.py +++ b/src/tagstudio/core/query_lang/ast.py @@ -1,6 +1,11 @@ +# Copyright (C) 2025 +# Licensed under the GPL-3.0 License. +# Created for TagStudio: https://github.com/CyanVoxel/TagStudio + + from abc import ABC, abstractmethod from enum import Enum -from typing import Generic, TypeVar, Union +from typing import Generic, TypeVar, override class ConstraintType(Enum): @@ -12,7 +17,7 @@ class ConstraintType(Enum): Special = 5 @staticmethod - def from_string(text: str) -> Union["ConstraintType", None]: + def from_string(text: str) -> "ConstraintType | None": return { "tag": ConstraintType.Tag, "tag_id": ConstraintType.TagID, @@ -24,14 +29,16 @@ def from_string(text: str) -> Union["ConstraintType", None]: class AST: - parent: Union["AST", None] = None + parent: "AST | None" = None + @override def __str__(self): class_name = self.__class__.__name__ fields = vars(self) # Get all instance variables as a dictionary field_str = ", ".join(f"{key}={value}" for key, value in fields.items()) return f"{class_name}({field_str})" + @override def __repr__(self) -> str: return self.__str__() diff --git a/src/tagstudio/core/query_lang/parser.py b/src/tagstudio/core/query_lang/parser.py index 8566d5dbb..ff17465d7 100644 --- a/src/tagstudio/core/query_lang/parser.py +++ b/src/tagstudio/core/query_lang/parser.py @@ -1,3 +1,8 @@ +# Copyright (C) 2025 +# Licensed under the GPL-3.0 License. +# Created for TagStudio: https://github.com/CyanVoxel/TagStudio + + from tagstudio.core.query_lang.ast import ( AST, ANDList, @@ -27,7 +32,7 @@ def parse(self) -> AST: if self.next_token.type == TokenType.EOF: return ORList([]) out = self.__or_list() - if self.next_token.type != TokenType.EOF: + if self.next_token.type != TokenType.EOF: # pyright: ignore[reportUnnecessaryComparison] raise ParsingError(self.next_token.start, self.next_token.end, "Syntax Error") return out @@ -41,7 +46,7 @@ def __or_list(self) -> AST: return ORList(terms) if len(terms) > 1 else terms[0] def __is_next_or(self) -> bool: - return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "OR" + return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "OR" # pyright: ignore def __and_list(self) -> AST: elements = [self.__term()] @@ -67,7 +72,7 @@ def __skip_and(self) -> None: raise self.__syntax_error("Unexpected AND") def __is_next_and(self) -> bool: - return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "AND" + return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "AND" # pyright: ignore def __term(self) -> AST: if self.__is_next_not(): @@ -85,11 +90,14 @@ def __term(self) -> AST: return self.__constraint() def __is_next_not(self) -> bool: - return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "NOT" + return self.next_token.type == TokenType.ULITERAL and self.next_token.value.upper() == "NOT" # pyright: ignore def __constraint(self) -> Constraint: if self.next_token.type == TokenType.CONSTRAINTTYPE: - self.last_constraint_type = self.__eat(TokenType.CONSTRAINTTYPE).value + constraint = self.__eat(TokenType.CONSTRAINTTYPE).value + if not isinstance(constraint, ConstraintType): + raise self.__syntax_error() + self.last_constraint_type = constraint value = self.__literal() @@ -98,7 +106,7 @@ def __constraint(self) -> Constraint: self.__eat(TokenType.SBRACKETO) properties.append(self.__property()) - while self.next_token.type == TokenType.COMMA: + while self.next_token.type == TokenType.COMMA: # pyright: ignore[reportUnnecessaryComparison] self.__eat(TokenType.COMMA) properties.append(self.__property()) @@ -110,11 +118,16 @@ def __property(self) -> Property: key = self.__eat(TokenType.ULITERAL).value self.__eat(TokenType.EQUALS) value = self.__literal() + if not isinstance(key, str): + raise self.__syntax_error() return Property(key, value) def __literal(self) -> str: if self.next_token.type in [TokenType.QLITERAL, TokenType.ULITERAL]: - return self.__eat(self.next_token.type).value + literal = self.__eat(self.next_token.type).value + if not isinstance(literal, str): + raise self.__syntax_error() + return literal raise self.__syntax_error() def __eat(self, type: TokenType) -> Token: diff --git a/src/tagstudio/core/query_lang/tokenizer.py b/src/tagstudio/core/query_lang/tokenizer.py index 4970a5feb..a279fbf69 100644 --- a/src/tagstudio/core/query_lang/tokenizer.py +++ b/src/tagstudio/core/query_lang/tokenizer.py @@ -1,5 +1,10 @@ +# Copyright (C) 2025 +# Licensed under the GPL-3.0 License. +# Created for TagStudio: https://github.com/CyanVoxel/TagStudio + + from enum import Enum -from typing import Any +from typing import override from tagstudio.core.query_lang.ast import ConstraintType from tagstudio.core.query_lang.util import ParsingError @@ -21,12 +26,14 @@ class TokenType(Enum): class Token: type: TokenType - value: Any + value: str | ConstraintType | None start: int end: int - def __init__(self, type: TokenType, value: Any, start: int, end: int) -> None: + def __init__( + self, type: TokenType, value: str | ConstraintType | None, start: int, end: int + ) -> None: self.type = type self.value = value self.start = start @@ -40,9 +47,11 @@ def from_type(type: TokenType, pos: int) -> "Token": def EOF(pos: int) -> "Token": # noqa: N802 return Token.from_type(TokenType.EOF, pos) + @override def __str__(self) -> str: return f"Token({self.type}, {self.value}, {self.start}, {self.end})" # pragma: nocover + @override def __repr__(self) -> str: return self.__str__() # pragma: nocover diff --git a/src/tagstudio/core/query_lang/util.py b/src/tagstudio/core/query_lang/util.py index 8deaecf22..95e53dbea 100644 --- a/src/tagstudio/core/query_lang/util.py +++ b/src/tagstudio/core/query_lang/util.py @@ -1,15 +1,26 @@ +# Copyright (C) 2025 +# Licensed under the GPL-3.0 License. +# Created for TagStudio: https://github.com/CyanVoxel/TagStudio + + +from typing import override + + class ParsingError(BaseException): start: int end: int msg: str def __init__(self, start: int, end: int, msg: str = "Syntax Error") -> None: + super().__init__() self.start = start self.end = end self.msg = msg + @override def __str__(self) -> str: return f"Syntax Error {self.start}->{self.end}: {self.msg}" # pragma: nocover + @override def __repr__(self) -> str: return self.__str__() # pragma: nocover diff --git a/src/tagstudio/core/ts_core.py b/src/tagstudio/core/ts_core.py index 7eaa153de..19aebae88 100644 --- a/src/tagstudio/core/ts_core.py +++ b/src/tagstudio/core/ts_core.py @@ -10,7 +10,7 @@ import structlog from tagstudio.core.constants import TS_FOLDER_NAME -from tagstudio.core.library.alchemy.fields import _FieldID +from tagstudio.core.library.alchemy.fields import FieldID from tagstudio.core.library.alchemy.library import Library from tagstudio.core.library.alchemy.models import Entry @@ -46,27 +46,27 @@ def get_gdl_sidecar(cls, filepath: Path, source: str = "") -> dict: return {} if source == "twitter": - info[_FieldID.DESCRIPTION] = json_dump["content"].strip() - info[_FieldID.DATE_PUBLISHED] = json_dump["date"] + info[FieldID.DESCRIPTION] = json_dump["content"].strip() + info[FieldID.DATE_PUBLISHED] = json_dump["date"] elif source == "instagram": - info[_FieldID.DESCRIPTION] = json_dump["description"].strip() - info[_FieldID.DATE_PUBLISHED] = json_dump["date"] + info[FieldID.DESCRIPTION] = json_dump["description"].strip() + info[FieldID.DATE_PUBLISHED] = json_dump["date"] elif source == "artstation": - info[_FieldID.TITLE] = json_dump["title"].strip() - info[_FieldID.ARTIST] = json_dump["user"]["full_name"].strip() - info[_FieldID.DESCRIPTION] = json_dump["description"].strip() - info[_FieldID.TAGS] = json_dump["tags"] + info[FieldID.TITLE] = json_dump["title"].strip() + info[FieldID.ARTIST] = json_dump["user"]["full_name"].strip() + info[FieldID.DESCRIPTION] = json_dump["description"].strip() + info[FieldID.TAGS] = json_dump["tags"] # info["tags"] = [x for x in json_dump["mediums"]["name"]] - info[_FieldID.DATE_PUBLISHED] = json_dump["date"] + info[FieldID.DATE_PUBLISHED] = json_dump["date"] elif source == "newgrounds": # info["title"] = json_dump["title"] # info["artist"] = json_dump["artist"] # info["description"] = json_dump["description"] - info[_FieldID.TAGS] = json_dump["tags"] - info[_FieldID.DATE_PUBLISHED] = json_dump["date"] - info[_FieldID.ARTIST] = json_dump["user"].strip() - info[_FieldID.DESCRIPTION] = json_dump["description"].strip() - info[_FieldID.SOURCE] = json_dump["post_url"].strip() + info[FieldID.TAGS] = json_dump["tags"] + info[FieldID.DATE_PUBLISHED] = json_dump["date"] + info[FieldID.ARTIST] = json_dump["user"].strip() + info[FieldID.DESCRIPTION] = json_dump["description"].strip() + info[FieldID.SOURCE] = json_dump["post_url"].strip() except Exception: logger.exception("Error handling sidecar file.", path=_filepath) diff --git a/src/tagstudio/qt/mixed/tag_database.py b/src/tagstudio/qt/mixed/tag_database.py index 4ff6499e8..180cee9c7 100644 --- a/src/tagstudio/qt/mixed/tag_database.py +++ b/src/tagstudio/qt/mixed/tag_database.py @@ -71,5 +71,5 @@ def delete_tag(self, tag: Tag): if result != QMessageBox.Ok: # type: ignore return - self.lib.remove_tag(tag) + self.lib.remove_tag(tag.id) self.update_tags() diff --git a/src/tagstudio/qt/ts_qt.py b/src/tagstudio/qt/ts_qt.py index 401560c26..934a0d284 100644 --- a/src/tagstudio/qt/ts_qt.py +++ b/src/tagstudio/qt/ts_qt.py @@ -55,7 +55,7 @@ ItemType, SortingModeEnum, ) -from tagstudio.core.library.alchemy.fields import _FieldID +from tagstudio.core.library.alchemy.fields import FieldID from tagstudio.core.library.alchemy.library import Library, LibraryStatus from tagstudio.core.library.alchemy.models import Entry from tagstudio.core.library.ignore import Ignore @@ -1129,7 +1129,7 @@ def run_macro(self, name: MacroID, entry_id: int): elif name == MacroID.BUILD_URL: url = TagStudioCore.build_url(entry, source) if url is not None: - self.lib.add_field_to_entry(entry.id, field_id=_FieldID.SOURCE, value=url) + self.lib.add_field_to_entry(entry.id, field_id=FieldID.SOURCE, value=url) elif name == MacroID.MATCH: TagStudioCore.match_conditions(self.lib, entry.id) elif name == MacroID.CLEAN_URL: diff --git a/tests/test_library.py b/tests/test_library.py index 26f34f23e..447344512 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -13,8 +13,8 @@ from tagstudio.core.enums import DefaultEnum, LibraryPrefs from tagstudio.core.library.alchemy.enums import BrowsingState from tagstudio.core.library.alchemy.fields import ( + FieldID, # pyright: ignore[reportPrivateUsage] TextField, - _FieldID, # pyright: ignore[reportPrivateUsage] ) from tagstudio.core.library.alchemy.library import Library from tagstudio.core.library.alchemy.models import Entry, Tag @@ -174,7 +174,7 @@ def test_remove_tag(library: Library, generate_tag: Callable[..., Tag]): tag_count = len(library.tags) - library.remove_tag(tag) + library.remove_tag(tag.id) assert len(library.tags) == tag_count - 1 @@ -270,7 +270,7 @@ def test_mirror_entry_fields(library: Library, entry_full: Entry): path=Path("xxx"), fields=[ TextField( - type_key=_FieldID.NOTES.name, + type_key=FieldID.NOTES.name, value="notes", position=0, ) @@ -292,8 +292,8 @@ def test_mirror_entry_fields(library: Library, entry_full: Entry): # make sure fields are there after getting it from the library again assert len(entry.fields) == 2 assert {x.type_key for x in entry.fields} == { - _FieldID.TITLE.name, - _FieldID.NOTES.name, + FieldID.TITLE.name, + FieldID.NOTES.name, } @@ -308,14 +308,14 @@ def test_merge_entries(library: Library): folder=folder, path=Path("a"), fields=[ - TextField(type_key=_FieldID.AUTHOR.name, value="Author McAuthorson", position=0), - TextField(type_key=_FieldID.DESCRIPTION.name, value="test description", position=2), + TextField(type_key=FieldID.AUTHOR.name, value="Author McAuthorson", position=0), + TextField(type_key=FieldID.DESCRIPTION.name, value="test description", position=2), ], ) b = Entry( folder=folder, path=Path("b"), - fields=[TextField(type_key=_FieldID.NOTES.name, value="test note", position=1)], + fields=[TextField(type_key=FieldID.NOTES.name, value="test note", position=1)], ) ids = library.add_entries([a, b])