diff --git a/notebooks/structural_components_dataclass.ipynb b/notebooks/structural_components_dataclass.ipynb new file mode 100644 index 000000000..611d76767 --- /dev/null +++ b/notebooks/structural_components_dataclass.ipynb @@ -0,0 +1,583 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ab70a522", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n" + ] + } + ], + "source": [ + "from pymc_extras.statespace.models.structural import (\n", + " RegressionComponent,\n", + " RegressionComponentDataClass,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "17021aa3", + "metadata": {}, + "outputs": [], + "source": [ + "# Current way\n", + "reg = RegressionComponent(\n", + " name=\"regression\",\n", + " state_names=[\"a\", \"b\"],\n", + " observed_state_names=[\"y\"],\n", + " innovations=True,\n", + " share_states=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "219eb5da", + "metadata": {}, + "outputs": [], + "source": [ + "# Proposed way\n", + "reg_dataclass = RegressionComponentDataClass(\n", + " name=\"regression\",\n", + " state_names=[\"a\", \"b\"],\n", + " observed_state_names=[\"y\"],\n", + " innovations=True,\n", + " share_states=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7ff76653", + "metadata": {}, + "source": [ + "# Reminder of current implementation" + ] + }, + { + "cell_type": "markdown", + "id": "c05f86f6", + "metadata": {}, + "source": [ + "Currently state names are a list of string that only contain the names of the states" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7e37e574", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['a[regression_shared]', 'b[regression_shared]']" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "reg.state_names" + ] + }, + { + "cell_type": "markdown", + "id": "0d484b59", + "metadata": {}, + "source": [ + "In the proposed dataclass implementation each state is a `StateProperty` and all the states are `StateProporties` dataclasses." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "dee62a66", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "states: ['a[regression_shared]', 'b[regression_shared]']\n", + "observed: [True, True]\n" + ] + } + ], + "source": [ + "print(reg_dataclass.state_names)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "cebd72af", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: a[regression_shared]\n", + "observed: True\n", + "shared: True\n" + ] + } + ], + "source": [ + "print(reg_dataclass.state_names[\"a[regression_shared]\"]) # state name is the key" + ] + }, + { + "cell_type": "markdown", + "id": "1b8690a1", + "metadata": {}, + "source": [ + "Similarly with shock names we now have a shock_info that is a `ShockProperties` dataclass composed of `ShockProperty` dataclasses" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "1320adac", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['a_shared', 'b_shared']" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.shock_names" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6c905946", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shocks: ['a_shared', 'b_shared']\n" + ] + } + ], + "source": [ + "print(reg_dataclass.shock_info)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ff60922", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: a_shared\n" + ] + } + 
], + "source": [ + "print(reg_dataclass.shock_info[\"a_shared\"])" + ] + }, + { + "cell_type": "markdown", + "id": "bdbe8f7c", + "metadata": {}, + "source": [ + "This pattern continues for data, parameters, and coords as shown below" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ead54287", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['data_regression']" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.data_names" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ba784a4a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'data_regression': {'shape': (None, 2), 'dims': ('time', 'state_regression')}}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.data_info" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "521382b9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "data: ['data_regression']\n", + "needs exogenous data: True\n" + ] + } + ], + "source": [ + "print(reg_dataclass.data_info)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "85b7e774", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: data_regression\n", + "shape: (None, 2)\n", + "dims: ('time', 'state_regression')\n", + "is_exogenous: True\n" + ] + } + ], + "source": [ + "print(reg_dataclass.data_info[\"data_regression\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "e1ed9d7a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'beta_regression': {'shape': (2,),\n", + " 'constraints': None,\n", + " 'dims': ('state_regression',)},\n", + " 'sigma_beta_regression': {'shape': (2,),\n", + " 'constraints': 'Positive',\n", + " 'dims': ('state_regression',)}}" + ] + }, + "execution_count": 14, + 
"metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.param_info" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "8d194fe2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['beta_regression', 'sigma_beta_regression']" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.param_names" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "7fccad81", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'beta_regression': ('state_regression',),\n", + " 'sigma_beta_regression': ('state_regression',)}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.param_dims" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "9787c813", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "parameters: ['beta_regression', 'sigma_beta_regression']\n" + ] + } + ], + "source": [ + "print(reg_dataclass.param_info)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "914e97da", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: beta_regression\n", + "shape: (2,)\n", + "dims: ('state_regression',)\n" + ] + } + ], + "source": [ + "print(reg_dataclass.param_info[\"beta_regression\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "98875fd1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: sigma_beta_regression\n", + "shape: (2,)\n", + "dims: ('state_regression',)\n", + "constraints: Positive\n" + ] + } + ], + "source": [ + "print(reg_dataclass.param_info[\"sigma_beta_regression\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "a195cec5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'state_regression': 
['a', 'b'], 'endog_regression': ['y']}" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.coords" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "62622777", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "coordinates:\n", + " dimension: state_regression\n", + " labels: ['a', 'b']\n", + "\n", + " dimension: endog_regression\n", + " labels: ['y']\n", + "\n" + ] + } + ], + "source": [ + "print(reg_dataclass.coords)" + ] + }, + { + "cell_type": "markdown", + "id": "a79b845c", + "metadata": {}, + "source": [ + "# Mapping between items" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "9484c709", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "parameters: ['beta_regression', 'sigma_beta_regression']\n" + ] + } + ], + "source": [ + "# Important to be able to map between parameters -> dimensions -> dimension labels\n", + "print(reg_dataclass.param_info)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "85573fa2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: beta_regression\n", + "shape: (2,)\n", + "dims: ('state_regression',)\n" + ] + } + ], + "source": [ + "print(reg_dataclass.param_info[\"beta_regression\"]) # Key is parameter name" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "32f56fd4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dimension: state_regression\n", + "labels: ['a', 'b']\n" + ] + } + ], + "source": [ + "# dimension for parameter beta_regression is state_regression. 
Let's map to dimension labels\n", + "print(\n", + " reg_dataclass.coords[\n", + " reg_dataclass.param_info[\"beta_regression\"].dims[0] # Key is dimension name\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "35ae00a6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dimension: state_regression\n", + "labels: ['a', 'b']\n" + ] + } + ], + "source": [ + "# Equivalently\n", + "print(reg_dataclass.coords[\"state_regression\"])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-extras", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pymc_extras/statespace/core/properties.py b/pymc_extras/statespace/core/properties.py new file mode 100644 index 000000000..35c1394d0 --- /dev/null +++ b/pymc_extras/statespace/core/properties.py @@ -0,0 +1,287 @@ +import warnings + +from collections.abc import Iterator +from copy import deepcopy +from dataclasses import dataclass, fields +from typing import Generic, Self, TypeVar + +from pymc_extras.statespace.core import PyMCStateSpace +from pymc_extras.statespace.utils.constants import ( + ALL_STATE_AUX_DIM, + ALL_STATE_DIM, + OBS_STATE_AUX_DIM, + OBS_STATE_DIM, + SHOCK_AUX_DIM, + SHOCK_DIM, +) + + +@dataclass(frozen=True) +class Property: + def __str__(self) -> str: + return "\n".join(f"{f.name}: {getattr(self, f.name)}" for f in fields(self)) + + +T = TypeVar("T", bound=Property) + + +@dataclass(frozen=True) +class Info(Generic[T]): + items: tuple[T, ...] 
+ key_field: str = "name" + _index: dict[str, T] | None = None + + def __post_init__(self): + index = {} + missing_attr = [] + for item in self.items: + if not hasattr(item, self.key_field): + missing_attr.append(item) + continue + key = getattr(item, self.key_field) + # if key in index: + # raise ValueError(f"Duplicate {self.key_field} '{key}' detected.") # This needs to be possible for shared states + index[key] = item + if missing_attr: + raise AttributeError(f"Items missing attribute '{self.key_field}': {missing_attr}") + object.__setattr__(self, "_index", index) + + def _key(self, item: T) -> str: + return getattr(item, self.key_field) + + def get(self, key: str, default=None) -> T | None: + return self._index.get(key, default) + + def __getitem__(self, key: str) -> T: + try: + return self._index[key] + except KeyError as e: + available = ", ".join(self._index.keys()) + raise KeyError(f"No {self.key_field} '{key}'. Available: [{available}]") from e + + def __contains__(self, key: object) -> bool: + return key in self._index + + def __iter__(self) -> Iterator[str]: + return iter(self.items) + + def __len__(self) -> int: + return len(self.items) + + def __str__(self) -> str: + return f"{self.key_field}s: {list(self._index.keys())}" + + @property + def names(self) -> tuple[str, ...]: + return tuple(self._index.keys()) + + def copy(self) -> "Info[T]": + return deepcopy(self) + + +@dataclass(frozen=True) +class Parameter(Property): + name: str + shape: tuple[int, ...] + dims: tuple[str, ...] 
@dataclass(frozen=True)
class ParameterInfo(Info[Parameter]):
    """Collection of ``Parameter`` records, looked up by their ``name``."""

    def __init__(self, parameters: list[Parameter]):
        super().__init__(items=tuple(parameters), key_field="name")

    def add(self, parameter: Parameter) -> "ParameterInfo":
        """Return a new ``ParameterInfo`` with ``parameter`` appended; ``self`` is untouched."""
        extended = list(self.items)
        extended.append(parameter)
        return ParameterInfo(parameters=extended)

    def merge(self, other: "ParameterInfo", allow_duplicates: bool = False) -> "ParameterInfo":
        """Concatenate the parameters of ``self`` and ``other`` into a new object.

        Raises
        ------
        TypeError
            If ``other`` is not a ``ParameterInfo``.
        ValueError
            If the two objects share a parameter name and ``allow_duplicates``
            is false.
        """
        if not isinstance(other, ParameterInfo):
            raise TypeError(f"Cannot merge {type(other).__name__} with ParameterInfo")

        overlapping = set(self.names).intersection(other.names)
        if not allow_duplicates and overlapping:
            raise ValueError(f"Duplicate parameter names found: {overlapping}")

        return ParameterInfo(parameters=[*self.items, *other.items])
@dataclass(frozen=True)
class DataInfo(Info[Data]):
    """Collection of ``Data`` records, looked up by their ``name``."""

    def __init__(self, data: list[Data]):
        super().__init__(items=tuple(data), key_field="name")

    @property
    def needs_exogenous_data(self) -> bool:
        """True when at least one record is flagged ``is_exogenous``."""
        for entry in self.items:
            if entry.is_exogenous:
                return True
        return False

    @property
    def exogenous_names(self) -> tuple[str, ...]:
        """Names of the records flagged ``is_exogenous``, in stored order."""
        exogenous = filter(lambda entry: entry.is_exogenous, self.items)
        return tuple(entry.name for entry in exogenous)

    def __str__(self) -> str:
        return f"data: {[d.name for d in self.items]}\nneeds exogenous data: {self.needs_exogenous_data}"

    def add(self, data: Data) -> "DataInfo":
        """Return a new ``DataInfo`` with ``data`` appended; ``self`` is untouched."""
        extended = list(self.items)
        extended.append(data)
        return DataInfo(data=extended)

    def merge(self, other: "DataInfo", allow_duplicates: bool = False) -> "DataInfo":
        """Concatenate the records of ``self`` and ``other`` into a new object.

        Raises
        ------
        TypeError
            If ``other`` is not a ``DataInfo``.
        ValueError
            If the two objects share a data name and ``allow_duplicates`` is
            false.
        """
        if not isinstance(other, DataInfo):
            raise TypeError(f"Cannot merge {type(other).__name__} with DataInfo")

        overlapping = set(self.names).intersection(other.names)
        if not allow_duplicates and overlapping:
            raise ValueError(f"Duplicate data names found: {overlapping}")

        return DataInfo(data=[*self.items, *other.items])
@dataclass(frozen=True)
class StateInfo(Info[State]):
    """Collection of ``State`` records, looked up by their ``name``."""

    def __init__(self, states: list[State]):
        super().__init__(items=tuple(states), key_field="name")

    def __str__(self) -> str:
        return (
            f"states: {[s.name for s in self.items]}\nobserved: {[s.observed for s in self.items]}"
        )

    @property
    def observed_states(self) -> tuple[State, ...]:
        """The ``State`` records whose ``observed`` flag is set."""
        return tuple(s for s in self.items if s.observed)

    @property
    def observed_state_names(self) -> tuple[str, ...]:
        """Names of the records whose ``observed`` flag is set."""
        return tuple(s.name for s in self.items if s.observed)

    @property
    def unobserved_state_names(self) -> tuple[str, ...]:
        """Names of the records whose ``observed`` flag is not set."""
        return tuple(s.name for s in self.items if not s.observed)

    def add(self, state: State) -> "StateInfo":
        """Return a new ``StateInfo`` with ``state`` appended; ``self`` is untouched."""
        return StateInfo(states=[*self.items, state])

    def merge(self, other: "StateInfo", allow_duplicates: bool = False) -> "StateInfo":
        """Combine the states of ``self`` and ``other`` into a new ``StateInfo``.

        Overlapping names with ``allow_duplicates`` false do not raise here:
        a ``UserWarning`` is emitted and, for every overlapping name, only the
        record from ``self`` is kept. With ``allow_duplicates`` true the two
        item tuples are simply concatenated.

        Raises
        ------
        TypeError
            If ``other`` is not a ``StateInfo``.
        """
        if not isinstance(other, StateInfo):
            raise TypeError(f"Cannot merge {type(other).__name__} with StateInfo")

        overlapping = set(self.names) & set(other.names)
        if overlapping and not allow_duplicates:
            # This is necessary for shared states
            warnings.warn(
                f"Duplicate state names found: {overlapping}. Merge will ONLY retain unique states",
                UserWarning,
                stacklevel=2,  # attribute the warning to the caller of merge()
            )
            return StateInfo(
                states=[*self.items, *(item for item in other.items if item.name not in overlapping)]
            )

        return StateInfo(states=[*self.items, *other.items])
def _set_parameters(self) -> None:
    """Build ``self.param_info`` and ``self.param_names`` for this component.

    A ``beta_<name>`` parameter is always registered. When
    ``self.innovations`` is true a ``sigma_beta_<name>`` parameter with the
    "Positive" constraint is registered after it.
    """
    k_endog = self.k_endog
    n_effective = 1 if self.share_states else k_endog
    k_states = self.k_states // n_effective
    state_dim = f"state_{self.name}"

    # beta gains a leading endog axis only when there is more than one effective endog.
    if n_effective > 1:
        beta_shape = (n_effective, k_states)
        beta_dims = (f"endog_{self.name}", state_dim)
    else:
        beta_shape = (k_states,)
        beta_dims = (state_dim,)

    parameters = [
        Parameter(
            name=f"beta_{self.name}",
            shape=beta_shape,
            dims=beta_dims,
            constraints=None,
        )
    ]

    if self.innovations:
        parameters.append(
            Parameter(
                name=f"sigma_beta_{self.name}",
                shape=(k_states,),
                dims=(state_dim,),
                constraints="Positive",
            )
        )

    self.param_info = ParameterInfo(parameters=parameters)
    self.param_names = self.param_info.names
self.state_names if self.share_states: - self.state_names = [f"{name}[{self.name}_shared]" for name in base_names] + state_names = [f"{name}[{self.name}_shared]" for name in self.base_names] + self.state_info = StateInfo( + states=[State(name=name, observed=False, shared=True) for name in state_names] + ) + self.state_info = self.state_info.merge( + StateInfo( + states=[ + State(name=name, observed=True, shared=False) + for name in self.observed_state_names + ] + ) + ) + self.state_names = self.state_info.unobserved_state_names else: - self.state_names = [ + state_names = [ f"{name}[{obs_name}]" for obs_name in self.observed_state_names - for name in base_names + for name in self.base_names ] + self.state_info = StateInfo( + states=[State(name=name, observed=False, shared=False) for name in state_names] + ) + self.state_info = self.state_info.merge( + StateInfo( + states=[ + State(name=name, observed=True, shared=False) + for name in self.observed_state_names + ] + ) + ) + self.state_names = self.state_info.unobserved_state_names - self.param_info = { - f"beta_{self.name}": { - "shape": (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,), - "constraints": None, - "dims": (f"endog_{self.name}", f"state_{self.name}") - if k_endog_effective > 1 - else (f"state_{self.name}",), - }, - } - - self.data_info = { - f"data_{self.name}": { - "shape": (None, k_states), - "dims": (TIME_DIM, f"state_{self.name}"), - }, - } - self.coords = { - f"state_{self.name}": base_names, - f"endog_{self.name}": self.observed_state_names, - } + def _set_coords(self) -> None: + regression_state_coord = Coord( + dimension=f"state_{self.name}", labels=[state for state in self.base_names] + ) + endogenous_state_coord = Coord( + dimension=f"endog_{self.name}", labels=[state for state in self.observed_state_names] + ) - if self.innovations: - self.param_names += [f"sigma_beta_{self.name}"] - self.param_dims[f"sigma_beta_{self.name}"] = (f"state_{self.name}",) - 
self.param_info[f"sigma_beta_{self.name}"] = { - "shape": (k_states,), - "constraints": "Positive", - "dims": (f"state_{self.name}",) - if k_endog_effective == 1 - else (f"endog_{self.name}", f"state_{self.name}"), - } + self.coords_info = CoordInfo(coords=[regression_state_coord, endogenous_state_coord]) + + def populate_component_properties(self) -> None: + self._set_parameters() + self._set_data() + self._set_shocks() + self._set_states() + self._set_coords() diff --git a/pymc_extras/statespace/models/structural/core.py b/pymc_extras/statespace/models/structural/core.py index a2718251b..8d777bc88 100644 --- a/pymc_extras/statespace/models/structural/core.py +++ b/pymc_extras/statespace/models/structural/core.py @@ -2,6 +2,7 @@ import logging from collections.abc import Sequence +from dataclasses import is_dataclass from itertools import pairwise from typing import Any @@ -12,11 +13,21 @@ from pytensor import tensor as pt from pymc_extras.statespace.core import PyMCStateSpace, PytensorRepresentation +from pymc_extras.statespace.core.properties import ( + CoordInfo, + Data, + DataInfo, + Parameter, + ParameterInfo, + Shock, + ShockInfo, + State, + StateInfo, +) from pymc_extras.statespace.models.utilities import ( add_tensors_by_dim_labels, conform_time_varying_and_time_invariant_matrices, join_tensors_by_dim_labels, - make_default_coords, ) from pymc_extras.statespace.utils.constants import ( ALL_STATE_AUX_DIM, @@ -140,16 +151,11 @@ def __init__( self, ssm: PytensorRepresentation, name: str, - state_names: list[str], - observed_state_names: list[str], - data_names: list[str], - shock_names: list[str], - param_names: list[str], - exog_names: list[str], - param_dims: dict[str, tuple[int]], - coords: dict[str, Sequence], - param_info: dict[str, dict[str, Any]], - data_info: dict[str, dict[str, Any]], + coords_info: CoordInfo, + param_info: ParameterInfo, + data_info: DataInfo, + shock_info: ShockInfo, + state_info: StateInfo, component_info: dict[str, dict[str, 
Any]], measurement_error: bool, name_to_variable: dict[str, Variable], @@ -161,27 +167,32 @@ def __init__( name = "StructuralTimeSeries" if name is None else name self._name = name - self._observed_state_names = observed_state_names + self._observed_state_names = state_info.observed_state_names k_states, k_posdef, k_endog = ssm.k_states, ssm.k_posdef, ssm.k_endog param_names, param_dims, param_info = self._add_inital_state_cov_to_properties( - param_names, param_dims, param_info, k_states + param_info, k_states ) - self._state_names = self._strip_data_names_if_unambiguous(state_names, k_endog) - self._data_names = self._strip_data_names_if_unambiguous(data_names, k_endog) - self._shock_names = self._strip_data_names_if_unambiguous(shock_names, k_endog) + self._state_names = self._strip_data_names_if_unambiguous( + state_info.unobserved_state_names, k_endog + ) + self._data_names = self._strip_data_names_if_unambiguous( + [d.name for d in data_info if not d.is_exogenous], k_endog + ) + self._shock_names = self._strip_data_names_if_unambiguous(shock_info.names, k_endog) self._param_names = self._strip_data_names_if_unambiguous(param_names, k_endog) self._param_dims = param_dims - default_coords = make_default_coords(self) - coords.update(default_coords) + default_coords = coords_info.default_coords_from_model(self) + coords_info = coords_info.merge(default_coords) - self._coords = { - k: self._strip_data_names_if_unambiguous(v, k_endog) for k, v in coords.items() - } + # TODO: discuss if copying is still needed since these are now immutable + self._coord_info = coords_info.copy() self._param_info = param_info.copy() self._data_info = data_info.copy() + self._shock_info = shock_info.copy() + self._state_info = state_info.copy() self.measurement_error = measurement_error super().__init__( @@ -210,8 +221,8 @@ def __init__( self._name_to_variable = name_to_variable.copy() self._name_to_data = name_to_data.copy() - self._exog_names = exog_names.copy() - 
@staticmethod
def _add_inital_state_cov_to_properties(param_info, k_states):
    """Register the initial state covariance ``P0`` as a model parameter.

    Parameters
    ----------
    param_info
        Existing parameter collection. If it is a dataclass instance, ``P0``
        is appended through its ``add`` method. Any other value is replaced
        by a fresh ``ParameterInfo`` holding only ``P0``.
    k_states
        Number of states; ``P0`` has shape ``(k_states, k_states)``.

    Returns
    -------
    tuple
        ``(param_names, param_dims, param_info)``, where ``param_dims`` maps
        each parameter name to its dims tuple.
    """
    initial_state_cov_param = Parameter(
        name="P0",
        shape=(k_states, k_states),
        dims=(ALL_STATE_DIM, ALL_STATE_AUX_DIM),
        constraints="Positive semi-definite",
    )

    if is_dataclass(param_info):
        param_info = param_info.add(initial_state_cov_param)
    else:
        param_info = ParameterInfo(parameters=[initial_state_cov_param])

    # Key the dims by parameter name (the mapping shape of the code this
    # replaces) instead of returning a positional list.
    param_dims = {p.name: p.dims for p in param_info}

    return param_info.names, param_dims, param_info
make_symbolic_graph(self) -> None: """ @@ -525,23 +536,43 @@ def __init__( self.k_posdef = k_posdef self.measurement_error = measurement_error - self.state_names = list(state_names) if state_names is not None else [] - self.observed_state_names = ( - list(observed_state_names) if observed_state_names is not None else [] + self.param_info = ParameterInfo( + parameters=[ + Parameter(name=n, shape=(1,), dims=(f"{n}_placeholder")) + for n in (param_names or []) + ] + ) + self.data_info = DataInfo( + data=[ + Data(name=n, shape=(None, 1), dims=(f"{n}_placeholder"), is_exogenous=False) + for n in (data_names or []) + ] + + [ + Data(name=n, shape=(None, 1), dims=(f"{n}_placeholder"), is_exogenous=True) + for n in (exog_names or []) + ] ) - self.data_names = list(data_names) if data_names is not None else [] - self.shock_names = list(shock_names) if shock_names is not None else [] - self.param_names = list(param_names) if param_names is not None else [] - self.exog_names = list(exog_names) if exog_names is not None else [] + self.shock_info = ShockInfo(shocks=[Shock(name=n) for n in (shock_names or [])]) + self.state_info = StateInfo( + states=[State(name=n, observed=False, shared=share_states) for n in (state_names or [])] + + [ + State(name=n, observed=True, shared=share_states) + for n in (observed_state_names or []) + ] + ) + self.coord_info = CoordInfo(coords=[]) - self.needs_exog_data = len(self.exog_names) > 0 - self.coords = {} - self.param_dims = {} + self.state_names = self.state_info.unobserved_state_names + self.observed_state_names = self.state_info.observed_state_names + self.param_names = self.param_info.names + self.data_names = [d.name for d in self.data_info if not d.is_exogenous] + self.exog_names = self.data_info.exogenous_names + self.shock_names = self.shock_info.names - self.param_info = {} - self.data_info = {} + self.coords = self.coord_info.to_dict() + self.param_dims = [p.dims for p in self.param_info] - self.param_counts = {} + 
self.needs_exog_data = self.data_info.needs_exogenous_data if representation is None: self.ssm = PytensorRepresentation(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef) @@ -595,7 +626,7 @@ def make_and_register_variable(self, name, shape, dtype=floatX) -> Variable: An error is raised if the provided name has already been registered, or if the name is not present in the ``param_names`` property. """ - if name not in self.param_names: + if name not in self.param_info: raise ValueError( f"{name} is not a model parameter. All placeholder variables should correspond to model " f"parameters." @@ -632,7 +663,7 @@ def make_and_register_data(self, name, shape, dtype=floatX) -> Variable: An error is raised if the provided name has already been registered, or if the name is not present in the ``data_names`` property. """ - if name not in self.data_names: + if name not in self.data_info: raise ValueError( f"{name} is not a model parameter. All placeholder variables should correspond to model " f"parameters." @@ -648,6 +679,21 @@ def make_and_register_data(self, name, shape, dtype=floatX) -> Variable: self._name_to_data[name] = placeholder return placeholder + def _set_parameters(self) -> None: + raise NotImplementedError + + def _set_data(self) -> None: + raise NotImplementedError + + def _set_shocks(self) -> None: + raise NotImplementedError + + def _set_states(self) -> None: + raise NotImplementedError + + def _set_coords(self) -> None: + raise NotImplementedError + def make_symbolic_graph(self) -> None: raise NotImplementedError @@ -659,10 +705,8 @@ def _get_combined_shapes(self, other): k_posdef = self.k_posdef + other.k_posdef # To count endog states, we have to count unique names between the two components. 
- combined_states = self._combine_property( - other, "observed_state_names", allow_duplicates=False - ) - k_endog = len(combined_states) + combined_states = self._combine_property(other, "state_info", allow_duplicates=False) + k_endog = len(combined_states.observed_state_names) return k_states, k_posdef, k_endog @@ -770,20 +814,13 @@ def _combine_property(self, other, name, allow_duplicates=True): f"{type(self_prop)} for {self} and {type(other_prop)} for {other}'" ) - if not isinstance(self_prop, list | dict): + if not is_dataclass(self_prop): raise TypeError( - f"All component properties are expected to be lists or dicts, but found {type(self_prop)}" + f"All component properties are expected to be dataclasses, but found {type(self_prop)}" f"for property {name} of {self} and {type(other_prop)} for {other}'" ) - if isinstance(self_prop, list) and allow_duplicates: - return self_prop + other_prop - elif isinstance(self_prop, list) and not allow_duplicates: - return self_prop + [x for x in other_prop if x not in self_prop] - elif isinstance(self_prop, dict): - new_prop = self_prop.copy() - new_prop.update(other_prop) - return new_prop + return self_prop.merge(other_prop, allow_duplicates) def _combine_component_info(self, other): combined_info = {} @@ -807,22 +844,23 @@ def _make_combined_name(self): return name def __add__(self, other): - state_names = self._combine_property(other, "state_names") - data_names = self._combine_property(other, "data_names") - observed_state_names = self._combine_property( - other, "observed_state_names", allow_duplicates=False - ) - - param_names = self._combine_property(other, "param_names") - shock_names = self._combine_property(other, "shock_names") param_info = self._combine_property(other, "param_info") data_info = self._combine_property(other, "data_info") - param_dims = self._combine_property(other, "param_dims") - coords = self._combine_property(other, "coords") - exog_names = self._combine_property(other, "exog_names") - - 
_name_to_variable = self._combine_property(other, "_name_to_variable") - _name_to_data = self._combine_property(other, "_name_to_data") + shock_info = self._combine_property(other, "shock_info") + state_info = self._combine_property(other, "state_info") + coords_info = self._combine_property(other, "coords_info") + + state_names = state_info.unobserved_state_names + observed_state_names = state_info.observed_state_names + data_names = [d.name for d in data_info if not d.is_exogenous] + exog_names = data_info.exogenous_names + param_names = param_info.names + shock_names = shock_info.names + param_dims = [p.dims for p in param_info] + + # TODO: Figure out how to handle these items in dataclasses + # _name_to_variable = self._combine_property(other, "_name_to_variable") + # _name_to_data = self._combine_property(other, "_name_to_data") measurement_error = any([self.measurement_error, other.measurement_error]) @@ -849,14 +887,15 @@ def __add__(self, other): ("data_names", data_names), ("param_names", param_names), ("shock_names", shock_names), - ("param_dims", param_dims), - ("coords", coords), - ("param_dims", param_dims), + ("coords_info", coords_info), ("param_info", param_info), ("data_info", data_info), + ("shock_info", shock_info), + ("state_info", state_info), ("exog_names", exog_names), - ("_name_to_variable", _name_to_variable), - ("_name_to_data", _name_to_data), + ("param_dims", param_dims), + # ("_name_to_variable", _name_to_variable), # TODO: Need to figure out how to handle these objects + # ("_name_to_data", _name_to_data), ] for prop, value in names_and_props: @@ -899,18 +938,13 @@ def build( return StructuralTimeSeries( self.ssm, name=name, - state_names=self.state_names, - observed_state_names=self.observed_state_names, - data_names=self.data_names, - shock_names=self.shock_names, - param_names=self.param_names, - param_dims=self.param_dims, - coords=self.coords, + coords_info=self.coords_info, param_info=self.param_info, data_info=self.data_info, + 
shock_info=self.shock_info, + state_info=self.state_info, component_info=self._component_info, measurement_error=self.measurement_error, - exog_names=self.exog_names, name_to_variable=self._name_to_variable, name_to_data=self._name_to_data, filter_type=filter_type, diff --git a/pymc_extras/statespace/models/utilities.py b/pymc_extras/statespace/models/utilities.py index 33be8d47d..cab8f3b3c 100644 --- a/pymc_extras/statespace/models/utilities.py +++ b/pymc_extras/statespace/models/utilities.py @@ -5,6 +5,7 @@ from pytensor.tensor import TensorVariable +from pymc_extras.statespace.core.properties import Coord, CoordInfo from pymc_extras.statespace.utils.constants import ( ALL_STATE_AUX_DIM, ALL_STATE_DIM, @@ -19,14 +20,23 @@ def make_default_coords(ss_mod): - coords = { - ALL_STATE_DIM: ss_mod.state_names, - ALL_STATE_AUX_DIM: ss_mod.state_names, - OBS_STATE_DIM: ss_mod.observed_states, - OBS_STATE_AUX_DIM: ss_mod.observed_states, - SHOCK_DIM: ss_mod.shock_names, - SHOCK_AUX_DIM: ss_mod.shock_names, - } + ALL_STATE_COORD = Coord(dimension=ALL_STATE_DIM, labels=ss_mod.state_names) + ALL_STATE_AUX_COORD = Coord(dimension=ALL_STATE_AUX_DIM, labels=ss_mod.state_names) + OBS_STATE_COORD = Coord(dimension=OBS_STATE_DIM, labels=ss_mod.observed_states) + OBS_STATE_AUX_COORD = Coord(dimension=OBS_STATE_AUX_DIM, labels=ss_mod.observed_states) + SHOCK_COORD = Coord(dimension=SHOCK_DIM, labels=ss_mod.shock_names) + SHOCK_AUX_COORD = Coord(dimension=SHOCK_AUX_DIM, labels=ss_mod.shock_names) + + coords = CoordInfo( + coords=[ + ALL_STATE_COORD, + ALL_STATE_AUX_COORD, + OBS_STATE_COORD, + OBS_STATE_AUX_COORD, + SHOCK_COORD, + SHOCK_AUX_COORD, + ] + ) return coords diff --git a/tests/statespace/core/test_properties.py b/tests/statespace/core/test_properties.py new file mode 100644 index 000000000..7f7cb8ae3 --- /dev/null +++ b/tests/statespace/core/test_properties.py @@ -0,0 +1,119 @@ +import pytest + +from pymc_extras.statespace.core.properties import ( + CoordInfo, + Data, + 
DataInfo, + Parameter, + ParameterInfo, + Shock, + ShockInfo, + State, + StateInfo, +) +from pymc_extras.statespace.utils.constants import ( + ALL_STATE_AUX_DIM, + ALL_STATE_DIM, + OBS_STATE_AUX_DIM, + OBS_STATE_DIM, + SHOCK_AUX_DIM, + SHOCK_DIM, +) + + +def test_property_str_formats_fields(): + p = Parameter(name="alpha", shape=(2,), dims=("param",)) + s = str(p).splitlines() + assert s == [ + "name: alpha", + "shape: (2,)", + "dims: ('param',)", + "constraints: None", + ] + + +def test_info_lookup_contains_and_missing_key(): + params = [ + Parameter("a", (1,), ("d",)), + Parameter("b", (2,), ("d",)), + Parameter("c", (3,), ("d",)), + ] + info = ParameterInfo(params) + + assert info.get("b").name == "b" + assert info["a"].shape == (1,) + assert "c" in info + + with pytest.raises(KeyError) as e: + _ = info["missing"] + assert "No name 'missing'" in str(e.value) + + +def test_data_info_needs_exogenous_and_str(): + data = [ + Data("price", (10,), ("time",), is_exogenous=False), + Data("x", (10,), ("time",), is_exogenous=True), + ] + info = DataInfo(data) + + assert info.needs_exogenous_data is True + s = str(info) + assert "data: ['price', 'x']" in s + assert "needs exogenous data: True" in s + + no_exog = DataInfo([Data("y", (10,), ("time",), is_exogenous=False)]) + assert no_exog.needs_exogenous_data is False + + +def test_coord_info_make_defaults_from_component_and_types(): + class DummyComponent: + state_names = ["x1", "x2"] + observed_state_names = ["x2"] + shock_names = ["eps1"] + + ci = CoordInfo.default_coords_from_model(DummyComponent()) + + expected = [ + (ALL_STATE_DIM, ("x1", "x2")), + (ALL_STATE_AUX_DIM, ("x1", "x2")), + (OBS_STATE_DIM, ("x2",)), + (OBS_STATE_AUX_DIM, ("x2",)), + (SHOCK_DIM, ("eps1",)), + (SHOCK_AUX_DIM, ("eps1",)), + ] + + assert len(ci.items) == 6 + for dim, labels in expected: + assert dim in ci + assert ci[dim].labels == labels + assert isinstance(ci[dim].labels, tuple) + + +def test_state_info_and_shockinfo_basic(): + states = [ + 
State("x1", observed=True, shared=False), + State("x2", observed=False, shared=True), + ] + state_info = StateInfo(states) + assert state_info["x1"].observed is True + s = str(state_info) + + assert "states: ['x1', 'x2']" in s + assert "observed: [True, False]" in s + + shocks = [Shock("s1"), Shock("s2")] + shock_info = ShockInfo(shocks) + + assert "s1" in shock_info + assert shock_info["s2"].name == "s2" + + +def test_info_is_iterable_and_unpackable(): + items = [Parameter("p1", (1,), ("d",)), Parameter("p2", (2,), ("d",))] + info = ParameterInfo(items) + + names = info.names + assert names == ("p1", "p2") + + a, b = info.items + assert a.name == "p1" and b.name == "p2" diff --git a/tests/statespace/models/structural/components/test_regression.py b/tests/statespace/models/structural/components/test_regression.py index c1732997d..7af48e2f4 100644 --- a/tests/statespace/models/structural/components/test_regression.py +++ b/tests/statespace/models/structural/components/test_regression.py @@ -252,11 +252,9 @@ def test_regression_multiple_shared_construction(): assert mod.coords["state_regression"] == ["A"] assert mod.coords["endog_regression"] == ["data_1", "data_2"] - assert mod.state_names == [ - "A[regression_shared]", - ] + assert mod.state_names == ("A[regression_shared]",) - assert mod.shock_names == ["A_shared"] + assert mod.shock_names == ("A_shared",) data = np.random.standard_normal(size=(10, 1)) Z = mod.ssm["design"].eval({"data_regression": data}) @@ -293,6 +291,7 @@ def test_regression_multiple_shared_observed(rng): np.testing.assert_allclose(y[:, 0], y[:, 2]) +@pytest.mark.filterwarnings("ignore::UserWarning") def test_regression_mixed_shared_and_not_shared(): mod_1 = st.RegressionComponent( name="individual", @@ -312,8 +311,8 @@ def test_regression_mixed_shared_and_not_shared(): assert mod.k_states == 4 assert mod.k_posdef == 4 - assert mod.state_names == ["A[data_1]", "A[data_2]", "B[joint_shared]", "C[joint_shared]"] - assert mod.shock_names == ["A", 
"B_shared", "C_shared"] + assert mod.state_names == ("A[data_1]", "A[data_2]", "B[joint_shared]", "C[joint_shared]") + assert mod.shock_names == ("A", "B_shared", "C_shared") data_joint = np.random.standard_normal(size=(10, 2)) data_individual = np.random.standard_normal(size=(10, 1)) diff --git a/tests/statespace/models/structural/conftest.py b/tests/statespace/models/structural/conftest.py index b9e58ca68..a395b528c 100644 --- a/tests/statespace/models/structural/conftest.py +++ b/tests/statespace/models/structural/conftest.py @@ -23,7 +23,7 @@ def _assert_basic_coords_correct(mod): assert mod.coords[ALL_STATE_AUX_DIM] == mod.state_names assert mod.coords[SHOCK_DIM] == mod.shock_names assert mod.coords[SHOCK_AUX_DIM] == mod.shock_names - expected_obs = mod.observed_state_names if hasattr(mod, "observed_state_names") else ["data"] + expected_obs = mod.observed_state_names if hasattr(mod, "observed_state_names") else ("data",) assert mod.coords[OBS_STATE_DIM] == expected_obs assert mod.coords[OBS_STATE_AUX_DIM] == expected_obs