diff --git a/src/tagstudio/core/library/alchemy/library.py b/src/tagstudio/core/library/alchemy/library.py index c01c20b76..e7d3800c2 100644 --- a/src/tagstudio/core/library/alchemy/library.py +++ b/src/tagstudio/core/library/alchemy/library.py @@ -7,7 +7,7 @@ import shutil import time import unicodedata -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from dataclasses import dataclass from datetime import UTC, datetime from os import makedirs @@ -42,6 +42,7 @@ contains_eager, joinedload, make_transient, + noload, selectinload, ) @@ -314,13 +315,12 @@ def get_field_name_from_id(self, field_id: int) -> _FieldID: return f return None - def tag_display_name(self, tag_id: int) -> str: - with Session(self.engine) as session: - tag = session.scalar(select(Tag).where(Tag.id == tag_id)) - if not tag: - return "" + def tag_display_name(self, tag: Tag | None) -> str: + if not tag: + return "" - if tag.disambiguation_id: + if tag.disambiguation_id: + with Session(self.engine) as session: disam_tag = session.scalar(select(Tag).where(Tag.id == tag.disambiguation_id)) if not disam_tag: return "" @@ -328,8 +328,8 @@ def tag_display_name(self, tag_id: int) -> str: if not disam_name: disam_name = disam_tag.name return f"{tag.name} ({disam_name})" - else: - return tag.name + else: + return tag.name def open_library(self, library_dir: Path, storage_path: Path | None = None) -> LibraryStatus: is_new: bool = True @@ -1457,6 +1457,36 @@ def get_tag_color(self, slug: str, namespace: str) -> TagColorGroup | None: return session.scalar(statement) + def get_tag_hierarchy(self, tag_ids: Iterable[int]) -> dict[int, Tag]: + """Get a dictionary containing tags in `tag_ids` and all of their ancestor tags.""" + current_tag_ids: set[int] = set(tag_ids) + all_tag_ids: set[int] = set() + all_tags: dict[int, Tag] = {} + all_tag_parents: dict[int, list[int]] = {} + + with Session(self.engine) as session: + while len(current_tag_ids) > 0: + all_tag_ids.update(current_tag_ids) + statement = select(TagParent).where(TagParent.parent_id.in_(current_tag_ids)) + tag_parents = session.scalars(statement).fetchall() + current_tag_ids.clear() + for tag_parent in tag_parents: + all_tag_parents.setdefault(tag_parent.parent_id, []).append(tag_parent.child_id) + current_tag_ids.add(tag_parent.child_id) + current_tag_ids = current_tag_ids.difference(all_tag_ids) + + statement = select(Tag).where(Tag.id.in_(all_tag_ids)) + statement = statement.options( + noload(Tag.parent_tags), selectinload(Tag.aliases), joinedload(Tag.color) + ) + tags = session.scalars(statement).fetchall() + for tag in tags: + all_tags[tag.id] = tag + for tag in all_tags.values(): + tag.parent_tags = {all_tags[p] for p in all_tag_parents.get(tag.id, [])} + + return all_tags + def add_parent_tag(self, parent_id: int, child_id: int) -> bool: if parent_id == child_id: return False diff --git a/src/tagstudio/core/library/alchemy/models.py b/src/tagstudio/core/library/alchemy/models.py index f85a02a44..b157e8c6d 100644 --- a/src/tagstudio/core/library/alchemy/models.py +++ b/src/tagstudio/core/library/alchemy/models.py @@ -156,6 +156,9 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.__str__() + def __hash__(self) -> int: + return hash(self.id) + def __lt__(self, other) -> bool: return self.name < other.name diff --git a/src/tagstudio/qt/modals/tag_database.py b/src/tagstudio/qt/modals/tag_database.py index 467fa7870..8eec4a733 100644 --- a/src/tagstudio/qt/modals/tag_database.py +++ b/src/tagstudio/qt/modals/tag_database.py @@ -63,7 +63,7 @@ def delete_tag(self, tag: Tag): message_box = QMessageBox( QMessageBox.Question, # type: ignore Translations["tag.remove"], - Translations.format("tag.confirm_delete", tag_name=self.lib.tag_display_name(tag.id)), + Translations.format("tag.confirm_delete", tag_name=self.lib.tag_display_name(tag)), QMessageBox.Ok | QMessageBox.Cancel, # type: ignore ) diff --git a/src/tagstudio/qt/modals/tag_search.py b/src/tagstudio/qt/modals/tag_search.py index 6b2ed4d76..cb55e66a7 100644 --- a/src/tagstudio/qt/modals/tag_search.py +++ b/src/tagstudio/qt/modals/tag_search.py @@ -362,7 +362,7 @@ def callback(btp: build_tag.BuildTagPanel): self.edit_modal = PanelModal( build_tag_panel, - self.lib.tag_display_name(tag.id), + self.lib.tag_display_name(tag), done_callback=(self.update_tags(self.search_field.text())), has_save=True, ) diff --git a/src/tagstudio/qt/widgets/preview/field_containers.py b/src/tagstudio/qt/widgets/preview/field_containers.py index f93b5432a..2cd1d6ee7 100644 --- a/src/tagstudio/qt/widgets/preview/field_containers.py +++ b/src/tagstudio/qt/widgets/preview/field_containers.py @@ -159,98 +159,38 @@ def hide_containers(self): c.setHidden(True) def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]: - """Get a dictionary of category tags mapped to their respective tags.""" - cats: dict[Tag | None, set[Tag]] = {} - cats[None] = set() - - base_tag_ids: set[int] = {x.id for x in tags} - exhausted: set[int] = set() - cluster_map: dict[int, set[int]] = {} - - def add_to_cluster(tag_id: int, p_ids: list[int] | None = None): - """Maps a Tag's child tags' IDs back to it's parent tag's ID. - - Example: - Tag: ["Johnny Bravo", Parent Tags: "Cartoon Network (TV)", "Character"] maps to: - "Cartoon Network" -> Johnny Bravo, - "Character" -> "Johnny Bravo", - "TV" -> Johnny Bravo" - """ - tag_obj = self.lib.get_tag(tag_id) # Get full object - if p_ids is None: - assert tag_obj is not None - p_ids = tag_obj.parent_ids - - for p_id in p_ids: - if cluster_map.get(p_id) is None: - cluster_map[p_id] = set() - # If the p_tag has p_tags of its own, recursively link those to the original Tag. - if tag_id not in cluster_map[p_id]: - cluster_map[p_id].add(tag_id) - p_tag = self.lib.get_tag(p_id) # Get full object - assert p_tag is not None - if p_tag.parent_ids: - add_to_cluster( - tag_id, - [sub_id for sub_id in p_tag.parent_ids if sub_id != tag_id], - ) - exhausted.add(p_id) - exhausted.add(tag_id) - - for tag in tags: - add_to_cluster(tag.id) + """Get a dictionary of category tags mapped to their respective tags. - logger.info("[FieldContainers] Entry Cluster", entry_cluster=exhausted) - logger.info("[FieldContainers] Cluster Map", cluster_map=cluster_map) + Example: + Tag: ["Johnny Bravo", Parent Tags: "Cartoon Network (TV)", "Character"] maps to: + "Cartoon Network" -> Johnny Bravo, + "Character" -> "Johnny Bravo", + "TV" -> Johnny Bravo" + """ + hierarchy_tags = self.lib.get_tag_hierarchy(t.id for t in tags) - # Initialize all categories from parents. - tags_ = {t for tid in exhausted if (t := self.lib.get_tag(tid)) is not None} - for tag in tags_: + categories: dict[Tag | None, set[Tag]] = {None: set()} + for tag in hierarchy_tags.values(): if tag.is_category: - cats[tag] = set() - logger.info("[FieldContainers] Blank Tag Categories", cats=cats) - - # Add tags to any applicable categories. - added_ids: set[int] = set() - for key in cats: - logger.info("[FieldContainers] Checking category tag key", key=key) - - if key: - logger.info( - "[FieldContainers] Key cluster:", key=key, cluster=cluster_map.get(key.id) - ) - - if final_tags := cluster_map.get(key.id, set()).union([key.id]): - cats[key] = { - t - for tid in final_tags - if tid in base_tag_ids and (t := self.lib.get_tag(tid)) is not None - } - added_ids = added_ids.union({tid for tid in final_tags if tid in base_tag_ids}) - - # Add remaining tags to None key (general case). - cats[None] = { - t - for tid in base_tag_ids - if tid not in added_ids and (t := self.lib.get_tag(tid)) is not None - } - logger.info( - "[FieldContainers] Key cluster: None, general case!", - general_tags=cats[None], - added=added_ids, - base_tag_ids=base_tag_ids, - ) - - # Remove unused categories - empty: list[Tag | None] = [] - for k, v in list(cats.items()): - if not v: - empty.append(k) - for key in empty: - cats.pop(key, None) + categories[tag] = set() + for tag in tags: + tag = hierarchy_tags[tag.id] + has_category_parent = False + parent_tags = tag.parent_tags + while len(parent_tags) > 0: + grandparent_tags: set[Tag] = set() + for parent_tag in parent_tags: + if parent_tag in categories: + categories[parent_tag].add(tag) + has_category_parent = True + grandparent_tags.update(parent_tag.parent_tags) + parent_tags = grandparent_tags + if tag.is_category: + categories[tag].add(tag) + elif not has_category_parent: + categories[None].add(tag) - logger.info("[FieldContainers] Tag Categories", categories=cats) - return cats + return dict((c, d) for c, d in categories.items() if len(d) > 0) def remove_field_prompt(self, name: str) -> str: return Translations.format("library.field.confirm_remove", name=name) diff --git a/src/tagstudio/qt/widgets/tag.py b/src/tagstudio/qt/widgets/tag.py index ca124070f..608bc7c1f 100644 --- a/src/tagstudio/qt/widgets/tag.py +++ b/src/tagstudio/qt/widgets/tag.py @@ -254,7 +254,7 @@ def set_tag(self, tag: Tag | None) -> None: ) if self.lib: - self.bg_button.setText(escape_text(self.lib.tag_display_name(tag.id))) + self.bg_button.setText(escape_text(self.lib.tag_display_name(tag))) else: self.bg_button.setText(escape_text(tag.name)) diff --git a/src/tagstudio/qt/widgets/tag_box.py b/src/tagstudio/qt/widgets/tag_box.py index dbeea052b..52f4ce365 100644 --- a/src/tagstudio/qt/widgets/tag_box.py +++ b/src/tagstudio/qt/widgets/tag_box.py @@ -47,7 +47,7 @@ def __init__( self.set_tags(self.tags) def set_tags(self, tags: typing.Iterable[Tag]): - tags_ = sorted(list(tags), key=lambda tag: self.driver.lib.tag_display_name(tag.id)) + tags_ = sorted(list(tags), key=lambda tag: self.driver.lib.tag_display_name(tag)) logger.info("[TagBoxWidget] Tags:", tags=tags) while self.base_layout.itemAt(0): self.base_layout.takeAt(0).widget().deleteLater() @@ -79,7 +79,7 @@ def edit_tag(self, tag: Tag): self.edit_modal = PanelModal( build_tag_panel, - self.driver.lib.tag_display_name(tag.id), + self.driver.lib.tag_display_name(tag), "Edit Tag", done_callback=lambda: self.driver.preview_panel.update_widgets(update_preview=False), has_save=True,