Skip to content

Have Styles and RenderStyles keep a weak reference to node #6017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 41 additions & 16 deletions src/textual/css/styles.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

from dataclasses import dataclass, field
from functools import partial
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Literal, cast
from weakref import ReferenceType, ref

import rich.repr
from rich.style import Style
Expand Down Expand Up @@ -858,25 +858,35 @@ def partial_rich_style(self) -> Style:


@rich.repr.auto
@dataclass
class Styles(StylesBase):
node: DOMNode | None = None
_rules: RulesMap = field(default_factory=RulesMap)
_updates: int = 0
def __init__(
self, node: DOMNode | None = None, rules: RulesMap | None = None
) -> None:
self._node_ref: ReferenceType[DOMNode] | None = (
ref(node) if node is not None else None
)
self._rules: RulesMap = rules if rules is not None else RulesMap()
self._updates: int = 0
self.important: set[str] = set()
self.get_rule: Callable[[str, object], object] = self._rules.get
self.has_rule: Callable[[str], bool] = self._rules.__contains__

important: set[str] = field(default_factory=set)
@property
def node(self) -> DOMNode | None:
"""Get the associated node. If there is no associated node or the associated
node has been garbage collected, return None."""
return self._node_ref() if self._node_ref is not None else None

def __post_init__(self) -> None:
self.get_rule: Callable[[str, object], object] = self._rules.get # type: ignore[assignment]
self.has_rule: Callable[[str], bool] = self._rules.__contains__ # type: ignore[assignment]
@node.setter
def node(self, node: DOMNode | None) -> None:
"""Set the associated node. A weak reference to the node is stored."""
self._node_ref = ref(node) if node is not None else None

def copy(self) -> Styles:
"""Get a copy of this Styles object."""
return Styles(
node=self.node,
_rules=self.get_rules(),
important=self.important,
)
other = Styles(self.node, self.get_rules())
other.important = self.important
return other

def clear_rule(self, rule_name: str) -> bool:
"""Removes the rule from the Styles object, as if it had never been set.
Expand Down Expand Up @@ -1308,8 +1318,12 @@ def css(self) -> str:
class RenderStyles(StylesBase):
"""Presents a combined view of two Styles object: a base Styles and inline Styles."""

def __init__(self, node: DOMNode, base: Styles, inline_styles: Styles) -> None:
self.node = node
def __init__(
self, node: DOMNode | None, base: Styles, inline_styles: Styles
) -> None:
self._node_ref: ReferenceType[DOMNode] | None = (
ref(node) if node is not None else None
)
self._base_styles = base
self._inline_styles = inline_styles
self._animate: BoundAnimator | None = None
Expand All @@ -1334,6 +1348,17 @@ def _cache_key(self) -> int:
"""
return self._updates + self._base_styles._updates + self._inline_styles._updates

@property
def node(self) -> DOMNode | None:
"""Get the associated node. If there is no associated node or the associated
node has been garbage collected, return None."""
return self._node_ref() if self._node_ref is not None else None

@node.setter
def node(self, node: DOMNode | None) -> None:
"""Set the associated node. A weak reference to the node is stored."""
self._node_ref = ref(node) if node is not None else None

@property
def base(self) -> Styles:
"""Quick access to base (css) style."""
Expand Down
13 changes: 7 additions & 6 deletions src/textual/css/stylesheet.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,19 +616,20 @@ def _process_component_classes(self, node: DOMNode) -> None:
if component_classes:
# Create virtual nodes that exist to extract styles
refresh_node = False
old_component_styles = node._component_styles.copy()
node._component_styles.clear()
old_component_styles_nodes = node._component_styles_nodes.copy()
node._component_styles_nodes.clear()
for component in sorted(component_classes):
virtual_node = DOMNode(classes=component)
virtual_node._attach(node)
self.apply(virtual_node, animate=False)
if (
not refresh_node
and old_component_styles.get(component) != virtual_node.styles
if not refresh_node and (
component not in old_component_styles_nodes
or old_component_styles_nodes[component].styles
!= virtual_node.styles
):
# If the styles have changed we want to refresh the node
refresh_node = True
node._component_styles[component] = virtual_node.styles
node._component_styles_nodes[component] = virtual_node
if refresh_node:
node.refresh()

Expand Down
9 changes: 5 additions & 4 deletions src/textual/dom.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,9 @@ def __init__(
self.styles: RenderStyles = RenderStyles(
self, self._css_styles, self._inline_styles
)
# A mapping of class names to Styles set in COMPONENT_CLASSES
self._component_styles: dict[str, RenderStyles] = {}
# A mapping of class names to virtual nodes whose 'styles' attribute
# corresponds to Styles set in COMPONENT_CLASSES
self._component_styles_nodes: dict[str, DOMNode] = {}

self._auto_refresh: float | None = None
self._auto_refresh_timer: Timer | None = None
Expand Down Expand Up @@ -579,9 +580,9 @@ def get_component_styles(self, *names: str) -> RenderStyles:
styles = RenderStyles(self, Styles(), Styles())

for name in names:
if name not in self._component_styles:
if name not in self._component_styles_nodes:
raise KeyError(f"No {name!r} key in COMPONENT_CLASSES")
component_styles = self._component_styles[name]
component_styles = self._component_styles_nodes[name].styles
styles.node = component_styles.node
styles.base.merge(component_styles.base)
styles.inline.merge(component_styles.inline)
Expand Down
2 changes: 1 addition & 1 deletion src/textual/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -4338,7 +4338,7 @@ async def _message_loop_exit(self) -> None:
self._arrangement_cache.clear()
self._nodes._clear()
self._render_cache = _RenderCache(NULL_SIZE, [])
self._component_styles.clear()
self._component_styles_nodes.clear()
self._query_one_cache.clear()

async def _on_idle(self, event: events.Idle) -> None:
Expand Down
Loading