From 4064481cc0f6c7e4345c7ba5ee7fa36026930b93 Mon Sep 17 00:00:00 2001 From: Bob Bobs Date: Wed, 28 May 2025 17:06:00 -0700 Subject: [PATCH 1/6] perf: optimize mapping of category->tags --- src/tagstudio/core/library/alchemy/library.py | 28 +++++ .../qt/widgets/preview/field_containers.py | 107 +++++------------- 2 files changed, 57 insertions(+), 78 deletions(-) diff --git a/src/tagstudio/core/library/alchemy/library.py b/src/tagstudio/core/library/alchemy/library.py index 81e259d44..cde9dbd72 100644 --- a/src/tagstudio/core/library/alchemy/library.py +++ b/src/tagstudio/core/library/alchemy/library.py @@ -1456,6 +1456,34 @@ def get_tag_color(self, slug: str, namespace: str) -> TagColorGroup | None: return session.scalar(statement) + def get_tag_hierarchy( + self, tag_ids: Iterable[int] + ) -> tuple[dict[int, list[int]], dict[int, Tag]]: + current_tag_ids: set[int] = set(tag_ids) + all_tag_ids: set[int] = set() + all_tags: dict[int, Tag] = {} + all_tag_parents: dict[int, list[int]] = {} + + with Session(self.engine) as session: + while len(current_tag_ids) > 0: + all_tag_ids.update(current_tag_ids) + statement = select(TagParent).where(TagParent.parent_id.in_(current_tag_ids)) + tag_parents = session.scalars(statement).fetchall() + current_tag_ids.clear() + for tag_parent in tag_parents: + all_tag_parents.setdefault(tag_parent.parent_id, []).append(tag_parent.child_id) + current_tag_ids.add(tag_parent.child_id) + current_tag_ids = current_tag_ids.difference(all_tag_ids) + + statement = select(Tag).where(Tag.id.in_(all_tag_ids)).options(noload(Tag.parent_tags)) + tags = session.scalars(statement).fetchall() + for tag in tags: + all_tags[tag.id] = tag + for tag in all_tags.values(): + tag.parent_tags = {all_tags[p] for p in all_tag_parents.get(tag.id, [])} + + return all_tag_parents, all_tags + def add_parent_tag(self, parent_id: int, child_id: int) -> bool: if parent_id == child_id: return False diff --git a/src/tagstudio/qt/widgets/preview/field_containers.py b/src/tagstudio/qt/widgets/preview/field_containers.py index 76dc48600..b2bc47653 100644 --- a/src/tagstudio/qt/widgets/preview/field_containers.py +++ b/src/tagstudio/qt/widgets/preview/field_containers.py @@ -158,86 +158,37 @@ def hide_containers(self): c.setHidden(True) def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]: - """Get a dictionary of category tags mapped to their respective tags.""" - cats: dict[Tag | None, set[Tag]] = {} - cats[None] = set() - - base_tag_ids: set[int] = {x.id for x in tags} - exhausted: set[int] = set() - cluster_map: dict[int, set[int]] = {} - - def add_to_cluster(tag_id: int, p_ids: list[int] | None = None): - """Maps a Tag's child tags' IDs back to it's parent tag's ID. - - Example: - Tag: ["Johnny Bravo", Parent Tags: "Cartoon Network (TV)", "Character"] maps to: - "Cartoon Network" -> Johnny Bravo, - "Character" -> "Johnny Bravo", - "TV" -> Johnny Bravo" - """ - tag_obj = self.lib.get_tag(tag_id) # Get full object - if p_ids is None: - p_ids = tag_obj.parent_ids - - for p_id in p_ids: - if cluster_map.get(p_id) is None: - cluster_map[p_id] = set() - # If the p_tag has p_tags of its own, recursively link those to the original Tag. - if tag_id not in cluster_map[p_id]: - cluster_map[p_id].add(tag_id) - p_tag = self.lib.get_tag(p_id) # Get full object - if p_tag.parent_ids: - add_to_cluster( - tag_id, - [sub_id for sub_id in p_tag.parent_ids if sub_id != tag_id], - ) - exhausted.add(p_id) - exhausted.add(tag_id) - - for tag in tags: - add_to_cluster(tag.id) - - logger.info("[FieldContainers] Entry Cluster", entry_cluster=exhausted) - logger.info("[FieldContainers] Cluster Map", cluster_map=cluster_map) + """Get a dictionary of category tags mapped to their respective tags. + Example: + Tag: ["Johnny Bravo", Parent Tags: "Cartoon Network (TV)", "Character"] maps to: + "Cartoon Network" -> Johnny Bravo, + "Character" -> "Johnny Bravo", + "TV" -> Johnny Bravo" + """ + tag_parents, hierarchy_tags = self.lib.get_tag_hierarchy(t.id for t in tags) - # Initialize all categories from parents. - tags_ = {self.lib.get_tag(x) for x in exhausted} - for tag in tags_: + categories: dict[int | None, set[int]] = {None: set()} + for tag in hierarchy_tags.values(): if tag.is_category: - cats[tag] = set() - logger.info("[FieldContainers] Blank Tag Categories", cats=cats) - - # Add tags to any applicable categories. - added_ids: set[int] = set() - for key in cats: - logger.info("[FieldContainers] Checking category tag key", key=key) - - if key: - logger.info( - "[FieldContainers] Key cluster:", key=key, cluster=cluster_map.get(key.id) - ) - - if final_tags := cluster_map.get(key.id, set()).union([key.id]): - cats[key] = {self.lib.get_tag(x) for x in final_tags if x in base_tag_ids} - added_ids = added_ids.union({x for x in final_tags if x in base_tag_ids}) - - # Add remaining tags to None key (general case). - cats[None] = {self.lib.get_tag(x) for x in base_tag_ids if x not in added_ids} - logger.info( - f"[FieldContainers] [{key}] Key cluster: None, general case!", - general_tags=cats[key], - added=added_ids, - base_tag_ids=base_tag_ids, - ) - - # Remove unused categories - empty: list[Tag] = [] - for k, v in list(cats.items()): - if not v: - empty.append(k) - for key in empty: - cats.pop(key, None) - + categories[tag.id] = set() + for tag in tags: + has_category_parent = False + parent_ids = tag_parents.get(tag.id, []) + while len(parent_ids) > 0: + grandparent_ids = set() + for parent_id in parent_ids: + if parent_id in categories: + categories[parent_id].add(tag.id) + has_category_parent = True + grandparent_ids.update(tag_parents.get(parent_id, [])) + parent_ids = grandparent_ids + if not has_category_parent: + categories[None].add(tag.id) + + cats = {} + for category_id, descendent_ids in categories.items(): + key = None if category_id is None else hierarchy_tags[category_id] + cats[key] = {hierarchy_tags[d] for d in descendent_ids} logger.info("[FieldContainers] Tag Categories", categories=cats) return cats From b3bd7722232012c0fe9dffca3348dba146dd5ba1 Mon Sep 17 00:00:00 2001 From: Bob Bobs Date: Wed, 28 May 2025 17:09:32 -0700 Subject: [PATCH 2/6] perf: one less db call for Library.tag_display_name --- src/tagstudio/core/library/alchemy/library.py | 15 +++++++-------- src/tagstudio/qt/modals/tag_database.py | 2 +- src/tagstudio/qt/modals/tag_search.py | 2 +- src/tagstudio/qt/widgets/tag.py | 2 +- src/tagstudio/qt/widgets/tag_box.py | 4 ++-- 5 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/tagstudio/core/library/alchemy/library.py b/src/tagstudio/core/library/alchemy/library.py index cde9dbd72..8f1e23a4a 100644 --- a/src/tagstudio/core/library/alchemy/library.py +++ b/src/tagstudio/core/library/alchemy/library.py @@ -314,13 +314,12 @@ def get_field_name_from_id(self, field_id: int) -> _FieldID: return f return None - def tag_display_name(self, tag_id: int) -> str: - with Session(self.engine) as session: - tag = session.scalar(select(Tag).where(Tag.id == tag_id)) - if not tag: - return "" + def tag_display_name(self, tag: Tag | None) -> str: + if not tag: + return "" - if tag.disambiguation_id: + if tag.disambiguation_id: + with Session(self.engine) as session: disam_tag = session.scalar(select(Tag).where(Tag.id == tag.disambiguation_id)) if not disam_tag: return "" @@ -328,8 +327,8 @@ def tag_display_name(self, tag_id: int) -> str: if not disam_name: disam_name = disam_tag.name return f"{tag.name} ({disam_name})" - else: - return tag.name + else: + return tag.name def open_library(self, library_dir: Path, storage_path: Path | None = None) -> LibraryStatus: is_new: bool = True diff --git a/src/tagstudio/qt/modals/tag_database.py b/src/tagstudio/qt/modals/tag_database.py index 423df9933..99b0fc3f3 100644 --- a/src/tagstudio/qt/modals/tag_database.py +++ b/src/tagstudio/qt/modals/tag_database.py @@ -63,7 +63,7 @@ def delete_tag(self, tag: Tag): message_box = QMessageBox( QMessageBox.Question, # type: ignore Translations["tag.remove"], - Translations.format("tag.confirm_delete", tag_name=self.lib.tag_display_name(tag.id)), + Translations.format("tag.confirm_delete", tag_name=self.lib.tag_display_name(tag)), QMessageBox.Ok | QMessageBox.Cancel, # type: ignore ) diff --git a/src/tagstudio/qt/modals/tag_search.py b/src/tagstudio/qt/modals/tag_search.py index 721b55ce2..fbd7747f0 100644 --- a/src/tagstudio/qt/modals/tag_search.py +++ b/src/tagstudio/qt/modals/tag_search.py @@ -364,7 +364,7 @@ def callback(btp: build_tag.BuildTagPanel): self.edit_modal = PanelModal( build_tag_panel, - self.lib.tag_display_name(tag.id), + self.lib.tag_display_name(tag), done_callback=(self.update_tags(self.search_field.text())), has_save=True, ) diff --git a/src/tagstudio/qt/widgets/tag.py b/src/tagstudio/qt/widgets/tag.py index ca124070f..608bc7c1f 100644 --- a/src/tagstudio/qt/widgets/tag.py +++ b/src/tagstudio/qt/widgets/tag.py @@ -254,7 +254,7 @@ def set_tag(self, tag: Tag | None) -> None: ) if self.lib: - self.bg_button.setText(escape_text(self.lib.tag_display_name(tag.id))) + self.bg_button.setText(escape_text(self.lib.tag_display_name(tag))) else: self.bg_button.setText(escape_text(tag.name)) diff --git a/src/tagstudio/qt/widgets/tag_box.py b/src/tagstudio/qt/widgets/tag_box.py index 68ac0fc2a..26c88c7fd 100644 --- a/src/tagstudio/qt/widgets/tag_box.py +++ b/src/tagstudio/qt/widgets/tag_box.py @@ -47,7 +47,7 @@ def __init__( self.set_tags(self.tags) def set_tags(self, tags: typing.Iterable[Tag]): - tags_ = sorted(list(tags), key=lambda tag: self.driver.lib.tag_display_name(tag.id)) + tags_ = sorted(list(tags), key=lambda tag: self.driver.lib.tag_display_name(tag)) logger.info("[TagBoxWidget] Tags:", tags=tags) while self.base_layout.itemAt(0): self.base_layout.takeAt(0).widget().deleteLater() @@ -81,7 +81,7 @@ def edit_tag(self, tag: Tag): self.edit_modal = PanelModal( build_tag_panel, - self.driver.lib.tag_display_name(tag.id), + self.driver.lib.tag_display_name(tag), "Edit Tag", done_callback=lambda: self.driver.preview_panel.update_widgets(update_preview=False), has_save=True, From d44ec5d24258810de0008c459aa05146d5ee7c11 Mon Sep 17 00:00:00 2001 From: Bob Bobs Date: Wed, 28 May 2025 17:23:22 -0700 Subject: [PATCH 3/6] fix: include joins in Library.get_tag_hierarchy --- src/tagstudio/core/library/alchemy/library.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/tagstudio/core/library/alchemy/library.py b/src/tagstudio/core/library/alchemy/library.py index 8f1e23a4a..9977a1d8b 100644 --- a/src/tagstudio/core/library/alchemy/library.py +++ b/src/tagstudio/core/library/alchemy/library.py @@ -1474,7 +1474,12 @@ def get_tag_hierarchy( current_tag_ids.add(tag_parent.child_id) current_tag_ids = current_tag_ids.difference(all_tag_ids) - statement = select(Tag).where(Tag.id.in_(all_tag_ids)).options(noload(Tag.parent_tags)) + statement = select(Tag).where(Tag.id.in_(all_tag_ids)) + statement = statement.options( + noload(Tag.parent_tags), + selectinload(Tag.aliases), + joinedload(Tag.color) + ) tags = session.scalars(statement).fetchall() for tag in tags: all_tags[tag.id] = tag From 59460e470c81c3907571cda84f643c4d7e53f91d Mon Sep 17 00:00:00 2001 From: Bob Bobs Date: Wed, 28 May 2025 18:08:11 -0700 Subject: [PATCH 4/6] fix: remove category if empty in preview panel --- src/tagstudio/qt/widgets/preview/field_containers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tagstudio/qt/widgets/preview/field_containers.py b/src/tagstudio/qt/widgets/preview/field_containers.py index b2bc47653..67c72a99c 100644 --- a/src/tagstudio/qt/widgets/preview/field_containers.py +++ b/src/tagstudio/qt/widgets/preview/field_containers.py @@ -187,6 +187,8 @@ def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]: cats = {} for category_id, descendent_ids in categories.items(): + if len(descendent_ids) == 0: + continue key = None if category_id is None else hierarchy_tags[category_id] cats[key] = {hierarchy_tags[d] for d in descendent_ids} logger.info("[FieldContainers] Tag Categories", categories=cats) From 2f5a71fca4825cc77e2a9e788bd508e8e893d934 Mon Sep 17 00:00:00 2001 From: Bob Bobs Date: Thu, 29 May 2025 07:50:55 -0700 Subject: [PATCH 5/6] fix: add missing imports and remove unneeded dict --- src/tagstudio/core/library/alchemy/library.py | 9 +++-- src/tagstudio/core/library/alchemy/models.py | 3 ++ .../qt/widgets/preview/field_containers.py | 36 ++++++++----------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/tagstudio/core/library/alchemy/library.py b/src/tagstudio/core/library/alchemy/library.py index 9977a1d8b..4fcb65ed0 100644 --- a/src/tagstudio/core/library/alchemy/library.py +++ b/src/tagstudio/core/library/alchemy/library.py @@ -12,7 +12,7 @@ from datetime import UTC, datetime from os import makedirs from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable from uuid import uuid4 from warnings import catch_warnings @@ -42,6 +42,7 @@ contains_eager, joinedload, make_transient, + noload, selectinload, ) @@ -1457,7 +1458,8 @@ def get_tag_color(self, slug: str, namespace: str) -> TagColorGroup | None: def get_tag_hierarchy( self, tag_ids: Iterable[int] - ) -> tuple[dict[int, list[int]], dict[int, Tag]]: + ) -> dict[int, Tag]: + """Get a dictionary containing tags in `tag_ids` and all of their ancestor tags.""" current_tag_ids: set[int] = set(tag_ids) all_tag_ids: set[int] = set() all_tags: dict[int, Tag] = {} @@ -1486,7 +1488,8 @@ def get_tag_hierarchy( for tag in all_tags.values(): tag.parent_tags = {all_tags[p] for p in all_tag_parents.get(tag.id, [])} - return all_tag_parents, all_tags + return all_tags + def add_parent_tag(self, parent_id: int, child_id: int) -> bool: if parent_id == child_id: diff --git a/src/tagstudio/core/library/alchemy/models.py b/src/tagstudio/core/library/alchemy/models.py index f85a02a44..b157e8c6d 100644 --- a/src/tagstudio/core/library/alchemy/models.py +++ b/src/tagstudio/core/library/alchemy/models.py @@ -156,6 +156,9 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.__str__() + def __hash__(self) -> int: + return hash(self.id) + def __lt__(self, other) -> bool: return self.name < other.name diff --git a/src/tagstudio/qt/widgets/preview/field_containers.py b/src/tagstudio/qt/widgets/preview/field_containers.py index 67c72a99c..8410bde24 100644 --- a/src/tagstudio/qt/widgets/preview/field_containers.py +++ b/src/tagstudio/qt/widgets/preview/field_containers.py @@ -165,34 +165,28 @@ def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]: "Character" -> "Johnny Bravo", "TV" -> Johnny Bravo" """ - tag_parents, hierarchy_tags = self.lib.get_tag_hierarchy(t.id for t in tags) + hierarchy_tags = self.lib.get_tag_hierarchy(t.id for t in tags) - categories: dict[int | None, set[int]] = {None: set()} + categories: dict[Tag | None, set[Tag]] = {None: set()} for tag in hierarchy_tags.values(): if tag.is_category: - categories[tag.id] = set() + categories[tag] = set() for tag in tags: + tag = hierarchy_tags[tag.id] has_category_parent = False - parent_ids = tag_parents.get(tag.id, []) - while len(parent_ids) > 0: - grandparent_ids = set() - for parent_id in parent_ids: - if parent_id in categories: - categories[parent_id].add(tag.id) + parent_tags = tag.parent_tags + while len(parent_tags) > 0: + grandparent_tags: set[Tag] = set() + for parent_tag in parent_tags: + if parent_tag in categories: + categories[parent_tag].add(tag) has_category_parent = True - grandparent_ids.update(tag_parents.get(parent_id, [])) - parent_ids = grandparent_ids + grandparent_tags.update(parent_tag.parent_tags) + parent_tags = grandparent_tags if not has_category_parent: - categories[None].add(tag.id) - - cats = {} - for category_id, descendent_ids in categories.items(): - if len(descendent_ids) == 0: - continue - key = None if category_id is None else hierarchy_tags[category_id] - cats[key] = {hierarchy_tags[d] for d in descendent_ids} - logger.info("[FieldContainers] Tag Categories", categories=cats) - return cats + categories[None].add(tag) + + return dict((c, d) for c, d in categories.items() if len(d) > 0) def remove_field_prompt(self, name: str) -> str: return Translations.format("library.field.confirm_remove", name=name) From 123d570c2aa7b99280b76a567d8458a4b447b6e5 Mon Sep 17 00:00:00 2001 From: Bob Bobs Date: Wed, 4 Jun 2025 07:03:50 -0700 Subject: [PATCH 6/6] fix: add tags that are categories to their own category --- src/tagstudio/core/library/alchemy/library.py | 13 ++++--------- .../qt/widgets/preview/field_containers.py | 5 ++++- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/tagstudio/core/library/alchemy/library.py b/src/tagstudio/core/library/alchemy/library.py index 4fcb65ed0..043b004be 100644 --- a/src/tagstudio/core/library/alchemy/library.py +++ b/src/tagstudio/core/library/alchemy/library.py @@ -7,12 +7,12 @@ import shutil import time import unicodedata -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from dataclasses import dataclass from datetime import UTC, datetime from os import makedirs from pathlib import Path -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING from uuid import uuid4 from warnings import catch_warnings @@ -1456,9 +1456,7 @@ def get_tag_color(self, slug: str, namespace: str) -> TagColorGroup | None: return session.scalar(statement) - def get_tag_hierarchy( - self, tag_ids: Iterable[int] - ) -> dict[int, Tag]: + def get_tag_hierarchy(self, tag_ids: Iterable[int]) -> dict[int, Tag]: """Get a dictionary containing tags in `tag_ids` and all of their ancestor tags.""" current_tag_ids: set[int] = set(tag_ids) all_tag_ids: set[int] = set() @@ -1478,9 +1476,7 @@ def get_tag_hierarchy( statement = select(Tag).where(Tag.id.in_(all_tag_ids)) statement = statement.options( - noload(Tag.parent_tags), - selectinload(Tag.aliases), - joinedload(Tag.color) + noload(Tag.parent_tags), selectinload(Tag.aliases), joinedload(Tag.color) ) tags = session.scalars(statement).fetchall() for tag in tags: @@ -1490,7 +1486,6 @@ def get_tag_hierarchy( return all_tags - def add_parent_tag(self, parent_id: int, child_id: int) -> bool: if parent_id == child_id: return False diff --git a/src/tagstudio/qt/widgets/preview/field_containers.py b/src/tagstudio/qt/widgets/preview/field_containers.py index 8410bde24..ba1f92613 100644 --- a/src/tagstudio/qt/widgets/preview/field_containers.py +++ b/src/tagstudio/qt/widgets/preview/field_containers.py @@ -159,6 +159,7 @@ def hide_containers(self): def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]: """Get a dictionary of category tags mapped to their respective tags. + Example: Tag: ["Johnny Bravo", Parent Tags: "Cartoon Network (TV)", "Character"] maps to: "Cartoon Network" -> Johnny Bravo, @@ -183,7 +184,9 @@ def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]: has_category_parent = True grandparent_tags.update(parent_tag.parent_tags) parent_tags = grandparent_tags - if not has_category_parent: + if tag.is_category: + categories[tag].add(tag) + elif not has_category_parent: categories[None].add(tag) return dict((c, d) for c, d in categories.items() if len(d) > 0)