Skip to content

Commit 48ac660

Browse files
Add shared_state argument to LevelTrendComponent
1 parent 2d16ad0 commit 48ac660

File tree

2 files changed

+185
-29
lines changed

2 files changed

+185
-29
lines changed

pymc_extras/statespace/models/structural/components/level_trend.py

Lines changed: 60 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@ class LevelTrendComponent(Component):
1313
Parameters
1414
----------
1515
order : int
16-
1716
Number of time derivatives of the trend to include in the model. For example, when order=3, the trend will
1817
be of the form ``y = a + b * t + c * t ** 2``, where the coefficients ``a, b, c`` come from the initial
1918
state values.
2019
2120
innovations_order : int or sequence of int, optional
22-
2321
The number of stochastic innovations to include in the model. By default, ``innovations_order = order``
2422
2523
name : str, default "level_trend"
@@ -28,6 +26,11 @@ class LevelTrendComponent(Component):
2826
observed_state_names : list[str] | None, default None
2927
List of strings for observed state labels. If None, defaults to ["data"].
3028
29+
share_states: bool, default False
30+
Whether latent states are shared across the observed states. If True, there will be only one set of latent
31+
states, which are observed by all observed states. If False, each observed state has its own set of
32+
latent states. This argument has no effect if `k_endog` is 1.
33+
3134
Notes
3235
-----
3336
This class implements the level and trend components of the general structural time series model. In the most
@@ -120,7 +123,10 @@ def __init__(
120123
innovations_order: int | list[int] | None = None,
121124
name: str = "level_trend",
122125
observed_state_names: list[str] | None = None,
126+
share_states: bool = False,
123127
):
128+
self.share_states = share_states
129+
124130
if innovations_order is None:
125131
innovations_order = order
126132

@@ -156,37 +162,50 @@ def __init__(
156162
super().__init__(
157163
name,
158164
k_endog=k_endog,
159-
k_states=k_states * k_endog,
160-
k_posdef=k_posdef * k_endog,
165+
k_states=k_states * k_endog if not share_states else k_states,
166+
k_posdef=k_posdef * k_endog if not share_states else k_posdef,
161167
observed_state_names=observed_state_names,
162168
measurement_error=False,
163169
combine_hidden_states=False,
164-
obs_state_idxs=np.tile(np.array([1.0] + [0.0] * (k_states - 1)), k_endog),
170+
obs_state_idxs=np.tile(
171+
np.array([1.0] + [0.0] * (k_states - 1)), k_endog if not share_states else 1
172+
),
165173
)
166174

167175
def populate_component_properties(self):
168176
k_endog = self.k_endog
169-
k_states = self.k_states // k_endog
170-
k_posdef = self.k_posdef // k_endog
177+
k_endog_effective = 1 if self.share_states else k_endog
178+
179+
k_states = self.k_states // k_endog_effective
180+
k_posdef = self.k_posdef // k_endog_effective
171181

172182
name_slice = POSITION_DERIVATIVE_NAMES[:k_states]
173183
self.param_names = [f"initial_{self.name}"]
174184
base_names = [name for name, mask in zip(name_slice, self._order_mask) if mask]
175-
self.state_names = [
176-
f"{name}[{obs_name}]" for obs_name in self.observed_state_names for name in base_names
177-
]
185+
186+
if self.share_states:
187+
self.state_names = [f"{name}[{self.name}_shared]" for name in base_names]
188+
else:
189+
self.state_names = [
190+
f"{name}[{obs_name}]"
191+
for obs_name in self.observed_state_names
192+
for name in base_names
193+
]
194+
178195
self.param_dims = {f"initial_{self.name}": (f"state_{self.name}",)}
179196
self.coords = {f"state_{self.name}": base_names}
180197

181198
if k_endog > 1:
199+
self.coords[f"endog_{self.name}"] = self.observed_state_names
200+
201+
if k_endog_effective > 1:
182202
self.param_dims[f"state_{self.name}"] = (
183203
f"endog_{self.name}",
184204
f"state_{self.name}",
185205
)
186206
self.param_dims = {f"initial_{self.name}": (f"endog_{self.name}", f"state_{self.name}")}
187-
self.coords[f"endog_{self.name}"] = self.observed_state_names
188207

189-
shape = (k_endog, k_states) if k_endog > 1 else (k_states,)
208+
shape = (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,)
190209
self.param_info = {f"initial_{self.name}": {"shape": shape, "constraints": None}}
191210

192211
if self.k_posdef > 0:
@@ -196,20 +215,23 @@ def populate_component_properties(self):
196215
name for name, mask in zip(name_slice, self.innovations_order) if mask
197216
]
198217

199-
self.shock_names = [
200-
f"{name}[{obs_name}]"
201-
for obs_name in self.observed_state_names
202-
for name in base_shock_names
203-
]
218+
if self.share_states:
219+
self.shock_names = [f"{name}[{self.name}_shared]" for name in base_shock_names]
220+
else:
221+
self.shock_names = [
222+
f"{name}[{obs_name}]"
223+
for obs_name in self.observed_state_names
224+
for name in base_shock_names
225+
]
204226

205227
self.param_dims[f"sigma_{self.name}"] = (
206228
(f"shock_{self.name}",)
207-
if k_endog == 1
229+
if k_endog_effective == 1
208230
else (f"endog_{self.name}", f"shock_{self.name}")
209231
)
210232
self.coords[f"shock_{self.name}"] = base_shock_names
211233
self.param_info[f"sigma_{self.name}"] = {
212-
"shape": (k_posdef,) if k_endog == 1 else (k_endog, k_posdef),
234+
"shape": (k_posdef,) if k_endog_effective == 1 else (k_endog_effective, k_posdef),
213235
"constraints": "Positive",
214236
}
215237

@@ -218,40 +240,49 @@ def populate_component_properties(self):
218240

219241
def make_symbolic_graph(self) -> None:
220242
k_endog = self.k_endog
221-
k_states = self.k_states // k_endog
222-
k_posdef = self.k_posdef // k_endog
243+
k_endog_effective = 1 if self.share_states else k_endog
244+
245+
k_states = self.k_states // k_endog_effective
246+
k_posdef = self.k_posdef // k_endog_effective
223247

224248
initial_trend = self.make_and_register_variable(
225249
f"initial_{self.name}",
226-
shape=(k_states,) if k_endog == 1 else (k_endog, k_states),
250+
shape=(k_states,) if k_endog_effective == 1 else (k_endog, k_states),
227251
)
228252
self.ssm["initial_state", :] = initial_trend.ravel()
229253

230254
triu_idx = pt.triu_indices(k_states)
231255
T = pt.zeros((k_states, k_states))[triu_idx[0], triu_idx[1]].set(1)
232256

233257
self.ssm["transition", :, :] = pt.specify_shape(
234-
pt.linalg.block_diag(*[T for _ in range(k_endog)]), (self.k_states, self.k_states)
258+
pt.linalg.block_diag(*[T for _ in range(k_endog_effective)]),
259+
(self.k_states, self.k_states),
235260
)
236261

237262
R = np.eye(k_states)
238263
R = R[:, self.innovations_order]
239264

240265
self.ssm["selection", :, :] = pt.specify_shape(
241-
pt.linalg.block_diag(*[R for _ in range(k_endog)]), (self.k_states, self.k_posdef)
266+
pt.linalg.block_diag(*[R for _ in range(k_endog_effective)]),
267+
(self.k_states, self.k_posdef),
242268
)
243269

244270
Z = np.array([1.0] + [0.0] * (k_states - 1)).reshape((1, -1))
245271

246-
self.ssm["design", :, :] = pt.specify_shape(
247-
pt.linalg.block_diag(*[Z for _ in range(k_endog)]), (self.k_endog, self.k_states)
248-
)
272+
if self.share_states:
273+
self.ssm["design", :, :] = pt.specify_shape(
274+
pt.join(0, *[Z for _ in range(k_endog)]), (self.k_endog, self.k_states)
275+
)
276+
else:
277+
self.ssm["design", :, :] = pt.specify_shape(
278+
pt.linalg.block_diag(*[Z for _ in range(k_endog)]), (self.k_endog, self.k_states)
279+
)
249280

250281
if k_posdef > 0:
251282
sigma_trend = self.make_and_register_variable(
252283
f"sigma_{self.name}",
253-
shape=(k_posdef,) if k_endog == 1 else (k_endog, k_posdef),
284+
shape=(k_posdef,) if k_endog_effective == 1 else (k_endog, k_posdef),
254285
)
255-
diag_idx = np.diag_indices(k_posdef * k_endog)
286+
diag_idx = np.diag_indices(k_posdef * k_endog_effective)
256287
idx = np.s_["state_cov", diag_idx[0], diag_idx[1]]
257288
self.ssm[idx] = (sigma_trend**2).ravel()

tests/statespace/models/structural/components/test_level_trend.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,44 @@ def test_level_trend_multiple_observed_construction():
9191
)
9292

9393

94+
def test_level_trend_multiple_shared_construction():
95+
mod = st.LevelTrendComponent(
96+
order=2, innovations_order=1, observed_state_names=["data_1", "data_2"], share_states=True
97+
)
98+
mod = mod.build(verbose=False)
99+
100+
assert mod.k_endog == 2
101+
assert mod.k_states == 2
102+
assert mod.k_posdef == 1
103+
104+
assert mod.coords["state_level_trend"] == ["level", "trend"]
105+
assert mod.coords["endog_level_trend"] == ["data_1", "data_2"]
106+
107+
assert mod.state_names == [
108+
"level[level_trend_shared]",
109+
"trend[level_trend_shared]",
110+
]
111+
assert mod.shock_names == ["level[level_trend_shared]"]
112+
113+
Z, T, R = pytensor.function(
114+
[], [mod.ssm["design"], mod.ssm["transition"], mod.ssm["selection"]], mode="FAST_COMPILE"
115+
)()
116+
117+
np.testing.assert_allclose(
118+
Z,
119+
np.array(
120+
[
121+
[1.0, 0.0],
122+
[1.0, 0.0],
123+
]
124+
),
125+
)
126+
127+
np.testing.assert_allclose(T, np.array([[1.0, 1.0], [0.0, 1.0]]))
128+
129+
np.testing.assert_allclose(R, np.array([[1.0], [0.0]]))
130+
131+
94132
def test_level_trend_multiple_observed(rng):
95133
mod = st.LevelTrendComponent(
96134
order=2, innovations_order=0, observed_state_names=["data_1", "data_2", "data_3"]
@@ -102,6 +140,19 @@ def test_level_trend_multiple_observed(rng):
102140
assert (np.diff(x, axis=0) == np.array([[1.0, 0.0, 2.0, 0.0, 3.0, 0.0]])).all().all()
103141

104142

143+
def test_level_trend_multiple_shared_observed(rng):
144+
mod = st.LevelTrendComponent(
145+
order=2,
146+
innovations_order=0,
147+
observed_state_names=["data_1", "data_2", "data_3"],
148+
share_states=True,
149+
)
150+
params = {"initial_level_trend": np.array([10.0, 0.1])}
151+
x, y = simulate_from_numpy_model(mod, rng, params)
152+
np.testing.assert_allclose(y[:, 0], y[:, 1])
153+
np.testing.assert_allclose(y[:, 0], y[:, 2])
154+
155+
105156
def test_add_level_trend_with_different_observed():
106157
mod_1 = st.LevelTrendComponent(
107158
name="ll", order=2, innovations_order=[0, 1], observed_state_names=["data_1"]
@@ -156,3 +207,77 @@ def test_add_level_trend_with_different_observed():
156207
]
157208
),
158209
)
210+
211+
212+
def test_mixed_shared_and_not_shared():
213+
mod_1 = st.LevelTrendComponent(
214+
name="individual",
215+
order=2,
216+
innovations_order=[0, 1],
217+
observed_state_names=["data_1", "data_2"],
218+
)
219+
mod_2 = st.LevelTrendComponent(
220+
name="joint",
221+
order=2,
222+
innovations_order=[1, 1],
223+
observed_state_names=["data_1", "data_2"],
224+
share_states=True,
225+
)
226+
227+
mod = (mod_1 + mod_2).build(verbose=False)
228+
229+
assert mod.k_endog == 2
230+
assert mod.k_states == 6
231+
assert mod.k_posdef == 4
232+
233+
assert mod.state_names == [
234+
"level[data_1]",
235+
"trend[data_1]",
236+
"level[data_2]",
237+
"trend[data_2]",
238+
"level[joint_shared]",
239+
"trend[joint_shared]",
240+
]
241+
242+
assert mod.shock_names == [
243+
"trend[data_1]",
244+
"trend[data_2]",
245+
"level[joint_shared]",
246+
"trend[joint_shared]",
247+
]
248+
249+
Z, T, R = pytensor.function(
250+
[], [mod.ssm["design"], mod.ssm["transition"], mod.ssm["selection"]], mode="FAST_COMPILE"
251+
)()
252+
253+
np.testing.assert_allclose(
254+
Z, np.array([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 0.0, 1.0, 0.0]])
255+
)
256+
257+
np.testing.assert_allclose(
258+
T,
259+
np.array(
260+
[
261+
[1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
262+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
263+
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
264+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
265+
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
266+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
267+
]
268+
),
269+
)
270+
271+
np.testing.assert_allclose(
272+
R,
273+
np.array(
274+
[
275+
[0.0, 0.0, 0.0, 0.0],
276+
[1.0, 0.0, 0.0, 0.0],
277+
[0.0, 0.0, 0.0, 0.0],
278+
[0.0, 1.0, 0.0, 0.0],
279+
[0.0, 0.0, 1.0, 0.0],
280+
[0.0, 0.0, 0.0, 1.0],
281+
]
282+
),
283+
)

0 commit comments

Comments
 (0)