Skip to content

perf: Optimize db queries for preview panel #942

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions src/tagstudio/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import shutil
import time
import unicodedata
from collections.abc import Iterator
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from datetime import UTC, datetime
from os import makedirs
Expand Down Expand Up @@ -42,6 +42,7 @@
contains_eager,
joinedload,
make_transient,
noload,
selectinload,
)

Expand Down Expand Up @@ -314,22 +315,21 @@ def get_field_name_from_id(self, field_id: int) -> _FieldID:
return f
return None

def tag_display_name(self, tag_id: int) -> str:
with Session(self.engine) as session:
tag = session.scalar(select(Tag).where(Tag.id == tag_id))
if not tag:
return "<NO TAG>"
def tag_display_name(self, tag: Tag | None) -> str:
if not tag:
return "<NO TAG>"

if tag.disambiguation_id:
if tag.disambiguation_id:
with Session(self.engine) as session:
disam_tag = session.scalar(select(Tag).where(Tag.id == tag.disambiguation_id))
if not disam_tag:
return "<NO DISAM TAG>"
disam_name = disam_tag.shorthand
if not disam_name:
disam_name = disam_tag.name
return f"{tag.name} ({disam_name})"
else:
return tag.name
else:
return tag.name

def open_library(self, library_dir: Path, storage_path: Path | None = None) -> LibraryStatus:
is_new: bool = True
Expand Down Expand Up @@ -1457,6 +1457,36 @@ def get_tag_color(self, slug: str, namespace: str) -> TagColorGroup | None:

return session.scalar(statement)

def get_tag_hierarchy(self, tag_ids: Iterable[int]) -> dict[int, Tag]:
"""Get a dictionary containing tags in `tag_ids` and all of their ancestor tags."""
current_tag_ids: set[int] = set(tag_ids)
all_tag_ids: set[int] = set()
all_tags: dict[int, Tag] = {}
all_tag_parents: dict[int, list[int]] = {}

with Session(self.engine) as session:
while len(current_tag_ids) > 0:
all_tag_ids.update(current_tag_ids)
statement = select(TagParent).where(TagParent.parent_id.in_(current_tag_ids))
tag_parents = session.scalars(statement).fetchall()
current_tag_ids.clear()
for tag_parent in tag_parents:
all_tag_parents.setdefault(tag_parent.parent_id, []).append(tag_parent.child_id)
current_tag_ids.add(tag_parent.child_id)
current_tag_ids = current_tag_ids.difference(all_tag_ids)

statement = select(Tag).where(Tag.id.in_(all_tag_ids))
statement = statement.options(
noload(Tag.parent_tags), selectinload(Tag.aliases), joinedload(Tag.color)
)
tags = session.scalars(statement).fetchall()
for tag in tags:
all_tags[tag.id] = tag
for tag in all_tags.values():
tag.parent_tags = {all_tags[p] for p in all_tag_parents.get(tag.id, [])}

return all_tags

def add_parent_tag(self, parent_id: int, child_id: int) -> bool:
if parent_id == child_id:
return False
Expand Down
3 changes: 3 additions & 0 deletions src/tagstudio/core/library/alchemy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return self.__str__()

def __hash__(self) -> int:
return hash(self.id)

def __lt__(self, other) -> bool:
return self.name < other.name

Expand Down
2 changes: 1 addition & 1 deletion src/tagstudio/qt/modals/tag_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def delete_tag(self, tag: Tag):
message_box = QMessageBox(
QMessageBox.Question, # type: ignore
Translations["tag.remove"],
Translations.format("tag.confirm_delete", tag_name=self.lib.tag_display_name(tag.id)),
Translations.format("tag.confirm_delete", tag_name=self.lib.tag_display_name(tag)),
QMessageBox.Ok | QMessageBox.Cancel, # type: ignore
)

Expand Down
2 changes: 1 addition & 1 deletion src/tagstudio/qt/modals/tag_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def callback(btp: build_tag.BuildTagPanel):

self.edit_modal = PanelModal(
build_tag_panel,
self.lib.tag_display_name(tag.id),
self.lib.tag_display_name(tag),
done_callback=(self.update_tags(self.search_field.text())),
has_save=True,
)
Expand Down
116 changes: 28 additions & 88 deletions src/tagstudio/qt/widgets/preview/field_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,98 +159,38 @@ def hide_containers(self):
c.setHidden(True)

def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]:
"""Get a dictionary of category tags mapped to their respective tags."""
cats: dict[Tag | None, set[Tag]] = {}
cats[None] = set()

base_tag_ids: set[int] = {x.id for x in tags}
exhausted: set[int] = set()
cluster_map: dict[int, set[int]] = {}

def add_to_cluster(tag_id: int, p_ids: list[int] | None = None):
"""Maps a Tag's child tags' IDs back to it's parent tag's ID.

Example:
Tag: ["Johnny Bravo", Parent Tags: "Cartoon Network (TV)", "Character"] maps to:
"Cartoon Network" -> Johnny Bravo,
"Character" -> "Johnny Bravo",
"TV" -> Johnny Bravo"
"""
tag_obj = self.lib.get_tag(tag_id) # Get full object
if p_ids is None:
assert tag_obj is not None
p_ids = tag_obj.parent_ids

for p_id in p_ids:
if cluster_map.get(p_id) is None:
cluster_map[p_id] = set()
# If the p_tag has p_tags of its own, recursively link those to the original Tag.
if tag_id not in cluster_map[p_id]:
cluster_map[p_id].add(tag_id)
p_tag = self.lib.get_tag(p_id) # Get full object
assert p_tag is not None
if p_tag.parent_ids:
add_to_cluster(
tag_id,
[sub_id for sub_id in p_tag.parent_ids if sub_id != tag_id],
)
exhausted.add(p_id)
exhausted.add(tag_id)

for tag in tags:
add_to_cluster(tag.id)
"""Get a dictionary of category tags mapped to their respective tags.

logger.info("[FieldContainers] Entry Cluster", entry_cluster=exhausted)
logger.info("[FieldContainers] Cluster Map", cluster_map=cluster_map)
Example:
Tag: ["Johnny Bravo", Parent Tags: "Cartoon Network (TV)", "Character"] maps to:
"Cartoon Network" -> Johnny Bravo,
"Character" -> "Johnny Bravo",
"TV" -> Johnny Bravo"
"""
hierarchy_tags = self.lib.get_tag_hierarchy(t.id for t in tags)

# Initialize all categories from parents.
tags_ = {t for tid in exhausted if (t := self.lib.get_tag(tid)) is not None}
for tag in tags_:
categories: dict[Tag | None, set[Tag]] = {None: set()}
for tag in hierarchy_tags.values():
if tag.is_category:
cats[tag] = set()
logger.info("[FieldContainers] Blank Tag Categories", cats=cats)

# Add tags to any applicable categories.
added_ids: set[int] = set()
for key in cats:
logger.info("[FieldContainers] Checking category tag key", key=key)

if key:
logger.info(
"[FieldContainers] Key cluster:", key=key, cluster=cluster_map.get(key.id)
)

if final_tags := cluster_map.get(key.id, set()).union([key.id]):
cats[key] = {
t
for tid in final_tags
if tid in base_tag_ids and (t := self.lib.get_tag(tid)) is not None
}
added_ids = added_ids.union({tid for tid in final_tags if tid in base_tag_ids})

# Add remaining tags to None key (general case).
cats[None] = {
t
for tid in base_tag_ids
if tid not in added_ids and (t := self.lib.get_tag(tid)) is not None
}
logger.info(
"[FieldContainers] Key cluster: None, general case!",
general_tags=cats[None],
added=added_ids,
base_tag_ids=base_tag_ids,
)

# Remove unused categories
empty: list[Tag | None] = []
for k, v in list(cats.items()):
if not v:
empty.append(k)
for key in empty:
cats.pop(key, None)
categories[tag] = set()
for tag in tags:
tag = hierarchy_tags[tag.id]
has_category_parent = False
parent_tags = tag.parent_tags
while len(parent_tags) > 0:
grandparent_tags: set[Tag] = set()
for parent_tag in parent_tags:
if parent_tag in categories:
categories[parent_tag].add(tag)
has_category_parent = True
grandparent_tags.update(parent_tag.parent_tags)
parent_tags = grandparent_tags
if tag.is_category:
categories[tag].add(tag)
elif not has_category_parent:
categories[None].add(tag)

logger.info("[FieldContainers] Tag Categories", categories=cats)
return cats
return dict((c, d) for c, d in categories.items() if len(d) > 0)

def remove_field_prompt(self, name: str) -> str:
return Translations.format("library.field.confirm_remove", name=name)
Expand Down
2 changes: 1 addition & 1 deletion src/tagstudio/qt/widgets/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def set_tag(self, tag: Tag | None) -> None:
)

if self.lib:
self.bg_button.setText(escape_text(self.lib.tag_display_name(tag.id)))
self.bg_button.setText(escape_text(self.lib.tag_display_name(tag)))
else:
self.bg_button.setText(escape_text(tag.name))

Expand Down
4 changes: 2 additions & 2 deletions src/tagstudio/qt/widgets/tag_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
self.set_tags(self.tags)

def set_tags(self, tags: typing.Iterable[Tag]):
tags_ = sorted(list(tags), key=lambda tag: self.driver.lib.tag_display_name(tag.id))
tags_ = sorted(list(tags), key=lambda tag: self.driver.lib.tag_display_name(tag))
logger.info("[TagBoxWidget] Tags:", tags=tags)
while self.base_layout.itemAt(0):
self.base_layout.takeAt(0).widget().deleteLater()
Expand Down Expand Up @@ -79,7 +79,7 @@ def edit_tag(self, tag: Tag):

self.edit_modal = PanelModal(
build_tag_panel,
self.driver.lib.tag_display_name(tag.id),
self.driver.lib.tag_display_name(tag),
"Edit Tag",
done_callback=lambda: self.driver.preview_panel.update_widgets(update_preview=False),
has_save=True,
Expand Down