Skip to content

Commit 940c425

Browse files
Add share_states argument to AutoregressiveComponent
1 parent bf18a3e commit 940c425

File tree

2 files changed

+201
-42
lines changed

2 files changed

+201
-42
lines changed

pymc_extras/statespace/models/structural/components/autoregressive.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ class AutoregressiveComponent(Component):
2323
observed_state_names: list[str] | None, default None
2424
List of strings for observed state labels. If None, defaults to ["data"].
2525
26+
share_states: bool, default False
27+
Whether latent states are shared across the observed states. If True, there will be only one set of latent
28+
states, which are observed by all observed states. If False, each observed state has its own set of
29+
latent states. This argument has no effect if `k_endog` is 1.
30+
2631
Notes
2732
-----
2833
An autoregressive component can be thought of as a way of introducing serially correlated errors into the model.
@@ -73,45 +78,58 @@ def __init__(
7378
order: int = 1,
7479
name: str = "auto_regressive",
7580
observed_state_names: list[str] | None = None,
81+
share_states: bool = False,
7682
):
7783
if observed_state_names is None:
7884
observed_state_names = ["data"]
7985

80-
k_posdef = k_endog = len(observed_state_names)
86+
k_endog = len(observed_state_names)
87+
k_endog_effective = k_posdef = 1 if share_states else k_endog
8188

8289
order = order_to_mask(order)
8390
ar_lags = np.flatnonzero(order).ravel().astype(int) + 1
8491
k_states = len(order)
8592

93+
self.share_states = share_states
8694
self.order = order
8795
self.ar_lags = ar_lags
8896

8997
super().__init__(
9098
name=name,
9199
k_endog=k_endog,
92-
k_states=k_states * k_endog,
100+
k_states=k_states * k_endog_effective,
93101
k_posdef=k_posdef,
94102
measurement_error=True,
95103
combine_hidden_states=True,
96104
observed_state_names=observed_state_names,
97-
obs_state_idxs=np.tile(np.r_[[1.0], np.zeros(k_states - 1)], k_endog),
105+
obs_state_idxs=np.tile(np.r_[[1.0], np.zeros(k_states - 1)], k_endog_effective),
98106
)
99107

100108
def populate_component_properties(self):
101-
k_states = self.k_states // self.k_endog # this is also the number of AR lags
109+
k_endog = self.k_endog
110+
k_endog_effective = 1 if self.share_states else k_endog
102111

103-
self.state_names = [
104-
f"L{i + 1}[{state_name}]"
105-
for state_name in self.observed_state_names
106-
for i in range(k_states)
107-
]
112+
k_states = self.k_states // k_endog_effective # this is also the number of AR lags
113+
base_names = [f"L{i + 1}_{self.name}" for i in range(k_states)]
114+
115+
if self.share_states:
116+
self.state_names = [f"{name}[shared]" for name in base_names]
117+
self.shock_names = [f"{self.name}[shared]"]
118+
else:
119+
self.state_names = [
120+
f"{name}[{state_name}]"
121+
for state_name in self.observed_state_names
122+
for name in base_names
123+
]
124+
self.shock_names = [
125+
f"{self.name}[{obs_name}]" for obs_name in self.observed_state_names
126+
]
108127

109-
self.shock_names = [f"{self.name}[{obs_name}]" for obs_name in self.observed_state_names]
110128
self.param_names = [f"params_{self.name}", f"sigma_{self.name}"]
111129
self.param_dims = {f"params_{self.name}": (f"lag_{self.name}",)}
112130
self.coords = {f"lag_{self.name}": self.ar_lags.tolist()}
113131

114-
if self.k_endog > 1:
132+
if k_endog_effective > 1:
115133
self.param_dims[f"params_{self.name}"] = (
116134
f"endog_{self.name}",
117135
f"lag_{self.name}",
@@ -140,26 +158,29 @@ def populate_component_properties(self):
140158

141159
def make_symbolic_graph(self) -> None:
142160
k_endog = self.k_endog
143-
k_states = self.k_states // k_endog
161+
k_endog_effective = 1 if self.share_states else k_endog
162+
163+
k_states = self.k_states // k_endog_effective
144164
k_posdef = self.k_posdef
145165

146166
k_nonzero = int(sum(self.order))
147167
ar_params = self.make_and_register_variable(
148-
f"params_{self.name}", shape=(k_nonzero,) if k_endog == 1 else (k_endog, k_nonzero)
168+
f"params_{self.name}",
169+
shape=(k_nonzero,) if k_endog_effective == 1 else (k_endog_effective, k_nonzero),
149170
)
150171
sigma_ar = self.make_and_register_variable(
151-
f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
172+
f"sigma_{self.name}", shape=() if k_endog_effective == 1 else (k_endog_effective,)
152173
)
153174

154-
if k_endog == 1:
175+
if k_endog_effective == 1:
155176
T = pt.eye(k_states, k=-1)
156177
ar_idx = (np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0])
157178
T = T[ar_idx].set(ar_params)
158179

159180
else:
160181
transition_matrices = []
161182

162-
for i in range(k_endog):
183+
for i in range(k_endog_effective):
163184
T = pt.eye(k_states, k=-1)
164185
ar_idx = (np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0])
165186
T = T[ar_idx].set(ar_params[i])
@@ -171,18 +192,21 @@ def make_symbolic_graph(self) -> None:
171192
self.ssm["transition", :, :] = T
172193

173194
R = np.eye(k_states)
174-
R_mask = np.full((k_states), False)
195+
R_mask = np.full((k_states,), False)
175196
R_mask[0] = True
176197
R = R[:, R_mask]
177198

178199
self.ssm["selection", :, :] = pt.specify_shape(
179-
pt.linalg.block_diag(*[R for _ in range(k_endog)]), (self.k_states, self.k_posdef)
200+
pt.linalg.block_diag(*[R for _ in range(k_endog_effective)]), (self.k_states, k_posdef)
180201
)
181202

182-
Z = pt.zeros((1, k_states))[0, 0].set(1.0)
183-
self.ssm["design", :, :] = pt.specify_shape(
184-
pt.linalg.block_diag(*[Z for _ in range(k_endog)]), (self.k_endog, self.k_states)
185-
)
203+
Zs = [pt.zeros((1, k_states))[0, 0].set(1.0) for _ in range(k_endog)]
204+
205+
if self.share_states:
206+
Z = pt.join(0, *Zs)
207+
else:
208+
Z = pt.linalg.block_diag(*Zs)
209+
self.ssm["design", :, :] = pt.specify_shape(Z, (k_endog, self.k_states))
186210

187211
cov_idx = ("state_cov", *np.diag_indices(k_posdef))
188212
self.ssm[cov_idx] = sigma_ar**2

tests/statespace/models/structural/components/test_autoregressive.py

Lines changed: 155 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
def test_autoregressive_model(order, rng):
1616
ar = st.AutoregressiveComponent(order=order).build(verbose=False)
1717

18-
# Check coords
1918
_assert_basic_coords_correct(ar)
2019

2120
lags = np.arange(len(order) if isinstance(order, list) else order, dtype="int") + 1
@@ -25,34 +24,34 @@ def test_autoregressive_model(order, rng):
2524

2625

2726
def test_autoregressive_multiple_observed_build(rng):
28-
ar = st.AutoregressiveComponent(order=3, observed_state_names=["data_1", "data_2"])
27+
ar = st.AutoregressiveComponent(order=3, name="ar", observed_state_names=["data_1", "data_2"])
2928
mod = ar.build(verbose=False)
3029

3130
assert mod.k_endog == 2
3231
assert mod.k_states == 6
3332
assert mod.k_posdef == 2
3433

3534
assert mod.state_names == [
36-
"L1[data_1]",
37-
"L2[data_1]",
38-
"L3[data_1]",
39-
"L1[data_2]",
40-
"L2[data_2]",
41-
"L3[data_2]",
35+
"L1_ar[data_1]",
36+
"L2_ar[data_1]",
37+
"L3_ar[data_1]",
38+
"L1_ar[data_2]",
39+
"L2_ar[data_2]",
40+
"L3_ar[data_2]",
4241
]
4342

44-
assert mod.shock_names == ["auto_regressive[data_1]", "auto_regressive[data_2]"]
43+
assert mod.shock_names == ["ar[data_1]", "ar[data_2]"]
4544

4645
params = {
47-
"params_auto_regressive": np.full(
46+
"params_ar": np.full(
4847
(
4948
2,
5049
sum(ar.order),
5150
),
5251
0.5,
5352
dtype=config.floatX,
5453
),
55-
"sigma_auto_regressive": np.array([0.05, 0.12]),
54+
"sigma_ar": np.array([0.05, 0.12]),
5655
}
5756
_, _, _, _, T, Z, R, _, Q = mod._unpack_statespace_with_placeholders()
5857
input_vars = explicit_graph_inputs([T, Z, R, Q])
@@ -89,6 +88,33 @@ def test_autoregressive_multiple_observed_build(rng):
8988
np.testing.assert_allclose(Q, np.diag([0.05**2, 0.12**2]))
9089

9190

91+
def test_autoregressive_multiple_observed_shared():
92+
ar = st.AutoregressiveComponent(
93+
order=1,
94+
name="latent",
95+
observed_state_names=["data_1", "data_2", "data_3"],
96+
share_states=True,
97+
)
98+
mod = ar.build(verbose=False)
99+
100+
assert mod.k_endog == 3
101+
assert mod.k_states == 1
102+
assert mod.k_posdef == 1
103+
104+
assert mod.state_names == ["L1_latent[shared]"]
105+
assert mod.shock_names == ["latent[shared]"]
106+
assert mod.coords["lag_latent"] == [1]
107+
assert "endog_latent" not in mod.coords
108+
109+
outputs = [mod.ssm["transition"], mod.ssm["design"]]
110+
params = {"params_latent": np.array([0.9])}
111+
T, Z = pytensor.function(list(explicit_graph_inputs(outputs)), outputs)(**params)
112+
113+
np.testing.assert_allclose(np.array([[1.0], [1.0], [1.0]]), Z)
114+
115+
np.testing.assert_allclose(np.array([[0.9]]), T)
116+
117+
92118
def test_autoregressive_multiple_observed_data(rng):
93119
ar = st.AutoregressiveComponent(order=1, observed_state_names=["data_1", "data_2", "data_3"])
94120
mod = ar.build(verbose=False)
@@ -112,21 +138,130 @@ def test_add_autoregressive_different_observed():
112138

113139
mod = (mod_1 + mod_2).build(verbose=False)
114140

115-
print(mod.coords)
116-
117141
assert mod.k_endog == 2
118142
assert mod.k_states == 7
119143
assert mod.k_posdef == 2
120144
assert mod.state_names == [
121-
"L1[data_1]",
122-
"L1[data_2]",
123-
"L2[data_2]",
124-
"L3[data_2]",
125-
"L4[data_2]",
126-
"L5[data_2]",
127-
"L6[data_2]",
145+
f"L1_{mod_1.name}[data_1]",
146+
f"L1_{mod_2.name}[data_2]",
147+
f"L2_{mod_2.name}[data_2]",
148+
f"L3_{mod_2.name}[data_2]",
149+
f"L4_{mod_2.name}[data_2]",
150+
f"L5_{mod_2.name}[data_2]",
151+
f"L6_{mod_2.name}[data_2]",
128152
]
129153

130154
assert mod.shock_names == ["ar1[data_1]", "ar6[data_2]"]
131155
assert mod.coords["lag_ar1"] == [1]
132156
assert mod.coords["lag_ar6"] == [1, 2, 3, 4, 5, 6]
157+
158+
159+
def test_autoregressive_shared_and_not_shared():
160+
shared = st.AutoregressiveComponent(
161+
order=3,
162+
name="shared_ar",
163+
observed_state_names=["data_1", "data_2", "data_3"],
164+
share_states=True,
165+
)
166+
individual = st.AutoregressiveComponent(
167+
order=3,
168+
name="individual_ar",
169+
observed_state_names=["data_1", "data_2", "data_3"],
170+
share_states=False,
171+
)
172+
173+
mod = (shared + individual).build(verbose=False)
174+
175+
assert mod.k_endog == 3
176+
assert mod.k_states == 3 + 3 * 3
177+
assert mod.k_posdef == 4
178+
179+
assert mod.state_names == [
180+
"L1_shared_ar[shared]",
181+
"L2_shared_ar[shared]",
182+
"L3_shared_ar[shared]",
183+
"L1_individual_ar[data_1]",
184+
"L2_individual_ar[data_1]",
185+
"L3_individual_ar[data_1]",
186+
"L1_individual_ar[data_2]",
187+
"L2_individual_ar[data_2]",
188+
"L3_individual_ar[data_2]",
189+
"L1_individual_ar[data_3]",
190+
"L2_individual_ar[data_3]",
191+
"L3_individual_ar[data_3]",
192+
]
193+
194+
assert mod.shock_names == [
195+
"shared_ar[shared]",
196+
"individual_ar[data_1]",
197+
"individual_ar[data_2]",
198+
"individual_ar[data_3]",
199+
]
200+
assert mod.coords["lag_shared_ar"] == [1, 2, 3]
201+
assert mod.coords["lag_individual_ar"] == [1, 2, 3]
202+
203+
outputs = [mod.ssm["transition"], mod.ssm["design"], mod.ssm["selection"], mod.ssm["state_cov"]]
204+
T, Z, R, Q = pytensor.function(
205+
list(explicit_graph_inputs(outputs)),
206+
outputs,
207+
)(
208+
**{
209+
"params_shared_ar": np.array([0.9, 0.8, 0.7]),
210+
"params_individual_ar": np.full((3, 3), 0.5),
211+
"sigma_shared_ar": np.array(0.1),
212+
"sigma_individual_ar": np.array([0.05, 0.12, 0.22]),
213+
}
214+
)
215+
216+
np.testing.assert_allclose(
217+
T,
218+
np.array(
219+
[
220+
[0.9, 0.8, 0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
221+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
222+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
223+
[0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
224+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
225+
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
226+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.0, 0.0, 0.0],
227+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
228+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
229+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5],
230+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
231+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
232+
]
233+
),
234+
)
235+
236+
np.testing.assert_allclose(
237+
Z,
238+
np.array(
239+
[
240+
[1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
241+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
242+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
243+
]
244+
),
245+
)
246+
247+
np.testing.assert_allclose(
248+
R,
249+
np.array(
250+
[
251+
[1.0, 0.0, 0.0, 0.0],
252+
[0.0, 0.0, 0.0, 0.0],
253+
[0.0, 0.0, 0.0, 0.0],
254+
[0.0, 1.0, 0.0, 0.0],
255+
[0.0, 0.0, 0.0, 0.0],
256+
[0.0, 0.0, 0.0, 0.0],
257+
[0.0, 0.0, 1.0, 0.0],
258+
[0.0, 0.0, 0.0, 0.0],
259+
[0.0, 0.0, 0.0, 0.0],
260+
[0.0, 0.0, 0.0, 1.0],
261+
[0.0, 0.0, 0.0, 0.0],
262+
[0.0, 0.0, 0.0, 0.0],
263+
]
264+
),
265+
)
266+
267+
np.testing.assert_allclose(Q, np.diag([0.1, 0.05, 0.12, 0.22]) ** 2)

0 commit comments

Comments
 (0)