Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ ignore_errors = true
qt_api = "pyside6"

[tool.pyright]
ignore = ["src/tagstudio/qt/helpers/vendored/pydub/", ".venv/**"]
ignore = [
".venv/**",
"src/tagstudio/core/library/json/",
"src/tagstudio/qt/helpers/vendored/pydub/",
]
include = ["src/tagstudio", "tests"]
reportAny = false
reportIgnoreCommentWithoutRule = false
Expand Down
11 changes: 11 additions & 0 deletions src/tagstudio/core/library/alchemy/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,14 @@
)
SELECT * FROM ChildTags;
""")

TAG_CHILDREN_ID_QUERY = text("""
WITH RECURSIVE ChildTags AS (
SELECT :tag_id AS tag_id
UNION
SELECT tp.child_id AS tag_id
FROM tag_parents tp
INNER JOIN ChildTags c ON tp.parent_id = c.tag_id
)
SELECT tag_id FROM ChildTags;
""")
7 changes: 5 additions & 2 deletions src/tagstudio/core/library/alchemy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


from pathlib import Path
from typing import override

import structlog
from sqlalchemy import Dialect, Engine, String, TypeDecorator, create_engine, text
Expand All @@ -19,12 +20,14 @@ class PathType(TypeDecorator):
impl = String
cache_ok = True

def process_bind_param(self, value: Path, dialect: Dialect):
@override
def process_bind_param(self, value: Path | None, dialect: Dialect):
if value is not None:
return Path(value).as_posix()
return None

def process_result_value(self, value: str, dialect: Dialect):
@override
def process_result_value(self, value: str | None, dialect: Dialect):
if value is not None:
return Path(value)
return None
Expand Down
24 changes: 14 additions & 10 deletions src/tagstudio/core/library/alchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, override

from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
Expand All @@ -32,27 +32,28 @@ def type_key(self) -> Mapped[str]:

@declared_attr
def type(self) -> Mapped[ValueType]:
return relationship(foreign_keys=[self.type_key], lazy=False) # type: ignore
return relationship(foreign_keys=[self.type_key], lazy=False) # type: ignore # pyright: ignore[reportArgumentType]

@declared_attr
def entry_id(self) -> Mapped[int]:
return mapped_column(ForeignKey("entries.id"))

@declared_attr
def entry(self) -> Mapped[Entry]:
return relationship(foreign_keys=[self.entry_id]) # type: ignore
return relationship(foreign_keys=[self.entry_id]) # type: ignore # pyright: ignore[reportArgumentType]

@declared_attr
def position(self) -> Mapped[int]:
return mapped_column(default=0)

@override
def __hash__(self):
return hash(self.__key())

def __key(self):
def __key(self): # pyright: ignore[reportUnknownParameterType]
raise NotImplementedError

value: Any
value: Any # pyright: ignore


class BooleanField(BaseField):
Expand All @@ -63,7 +64,8 @@ class BooleanField(BaseField):
def __key(self):
return (self.type, self.value)

def __eq__(self, value) -> bool:
@override
def __eq__(self, value: object) -> bool:
if isinstance(value, BooleanField):
return self.__key() == value.__key()
raise NotImplementedError
Expand All @@ -74,10 +76,11 @@ class TextField(BaseField):

value: Mapped[str | None]

def __key(self) -> tuple:
def __key(self) -> tuple[ValueType, str | None]:
return self.type, self.value

def __eq__(self, value) -> bool:
@override
def __eq__(self, value: object) -> bool:
if isinstance(value, TextField):
return self.__key() == value.__key()
elif isinstance(value, DatetimeField):
Expand All @@ -93,7 +96,8 @@ class DatetimeField(BaseField):
def __key(self):
return (self.type, self.value)

def __eq__(self, value) -> bool:
@override
def __eq__(self, value: object) -> bool:
if isinstance(value, DatetimeField):
return self.__key() == value.__key()
raise NotImplementedError
Expand All @@ -107,7 +111,7 @@ class DefaultField:
is_default: bool = field(default=False)


class _FieldID(Enum):
class FieldID(Enum):
"""Only for bootstrapping content of DB table."""

TITLE = DefaultField(id=0, name="Title", type=FieldTypeEnum.TEXT_LINE, is_default=True)
Expand Down
Loading