diff --git a/src/textual/css/styles.py b/src/textual/css/styles.py index d306d42768..024e92184e 100644 --- a/src/textual/css/styles.py +++ b/src/textual/css/styles.py @@ -1,9 +1,9 @@ from __future__ import annotations -from dataclasses import dataclass, field from functools import partial from operator import attrgetter from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Literal, cast +from weakref import ReferenceType, ref import rich.repr from rich.style import Style @@ -858,25 +858,35 @@ def partial_rich_style(self) -> Style: @rich.repr.auto -@dataclass class Styles(StylesBase): - node: DOMNode | None = None - _rules: RulesMap = field(default_factory=RulesMap) - _updates: int = 0 + def __init__( + self, node: DOMNode | None = None, rules: RulesMap | None = None + ) -> None: + self._node_ref: ReferenceType[DOMNode] | None = ( + ref(node) if node is not None else None + ) + self._rules: RulesMap = rules if rules is not None else RulesMap() + self._updates: int = 0 + self.important: set[str] = set() + self.get_rule: Callable[[str, object], object] = self._rules.get + self.has_rule: Callable[[str], bool] = self._rules.__contains__ - important: set[str] = field(default_factory=set) + @property + def node(self) -> DOMNode | None: + """Get the associated node. If there is no associated node or the associated + node has been garbage collected, return None.""" + return self._node_ref() if self._node_ref is not None else None - def __post_init__(self) -> None: - self.get_rule: Callable[[str, object], object] = self._rules.get # type: ignore[assignment] - self.has_rule: Callable[[str], bool] = self._rules.__contains__ # type: ignore[assignment] + @node.setter + def node(self, node: DOMNode | None) -> None: + """Set the associated node. A weak reference to the node is stored.""" + self._node_ref = ref(node) if node is not None else None def copy(self) -> Styles: """Get a copy of this Styles object.""" - return Styles( - node=self.node, - _rules=self.get_rules(), - important=self.important, - ) + other = Styles(self.node, self.get_rules()) + other.important = self.important + return other def clear_rule(self, rule_name: str) -> bool: """Removes the rule from the Styles object, as if it had never been set. @@ -1308,8 +1318,12 @@ def css(self) -> str: class RenderStyles(StylesBase): """Presents a combined view of two Styles object: a base Styles and inline Styles.""" - def __init__(self, node: DOMNode, base: Styles, inline_styles: Styles) -> None: - self.node = node + def __init__( + self, node: DOMNode | None, base: Styles, inline_styles: Styles + ) -> None: + self._node_ref: ReferenceType[DOMNode] | None = ( + ref(node) if node is not None else None + ) self._base_styles = base self._inline_styles = inline_styles self._animate: BoundAnimator | None = None @@ -1334,6 +1348,17 @@ def _cache_key(self) -> int: """ return self._updates + self._base_styles._updates + self._inline_styles._updates + @property + def node(self) -> DOMNode | None: + """Get the associated node. If there is no associated node or the associated + node has been garbage collected, return None.""" + return self._node_ref() if self._node_ref is not None else None + + @node.setter + def node(self, node: DOMNode | None) -> None: + """Set the associated node. A weak reference to the node is stored.""" + self._node_ref = ref(node) if node is not None else None + @property def base(self) -> Styles: """Quick access to base (css) style.""" diff --git a/src/textual/css/stylesheet.py b/src/textual/css/stylesheet.py index 7d97cda9f0..9af0623147 100644 --- a/src/textual/css/stylesheet.py +++ b/src/textual/css/stylesheet.py @@ -616,19 +616,20 @@ def _process_component_classes(self, node: DOMNode) -> None: if component_classes: # Create virtual nodes that exist to extract styles refresh_node = False - old_component_styles = node._component_styles.copy() - node._component_styles.clear() + old_component_styles_nodes = node._component_styles_nodes.copy() + node._component_styles_nodes.clear() for component in sorted(component_classes): virtual_node = DOMNode(classes=component) virtual_node._attach(node) self.apply(virtual_node, animate=False) - if ( - not refresh_node - and old_component_styles.get(component) != virtual_node.styles + if not refresh_node and ( + component not in old_component_styles_nodes + or old_component_styles_nodes[component].styles + != virtual_node.styles ): # If the styles have changed we want to refresh the node refresh_node = True - node._component_styles[component] = virtual_node.styles + node._component_styles_nodes[component] = virtual_node if refresh_node: node.refresh() diff --git a/src/textual/dom.py b/src/textual/dom.py index 2dba90b015..f5f474b8b9 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -207,8 +207,9 @@ def __init__( self.styles: RenderStyles = RenderStyles( self, self._css_styles, self._inline_styles ) - # A mapping of class names to Styles set in COMPONENT_CLASSES - self._component_styles: dict[str, RenderStyles] = {} + # A mapping of class names to virtual nodes whose 'styles' attribute + # corresponds to Styles set in COMPONENT_CLASSES + self._component_styles_nodes: dict[str, DOMNode] = {} self._auto_refresh: float | None = None self._auto_refresh_timer: Timer | None = None @@ -579,9 +580,9 @@ def get_component_styles(self, *names: str) -> RenderStyles: styles = RenderStyles(self, Styles(), Styles()) for name in names: - if name not in self._component_styles: + if name not in self._component_styles_nodes: raise KeyError(f"No {name!r} key in COMPONENT_CLASSES") - component_styles = self._component_styles[name] + component_styles = self._component_styles_nodes[name].styles styles.node = component_styles.node styles.base.merge(component_styles.base) styles.inline.merge(component_styles.inline) diff --git a/src/textual/widget.py b/src/textual/widget.py index 8543706dad..3fcef31098 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -4338,7 +4338,7 @@ async def _message_loop_exit(self) -> None: self._arrangement_cache.clear() self._nodes._clear() self._render_cache = _RenderCache(NULL_SIZE, []) - self._component_styles.clear() + self._component_styles_nodes.clear() self._query_one_cache.clear() async def _on_idle(self, event: events.Idle) -> None: