Skip to content

Commit bf18a3e

Browse files
Dekermanjianjessegrabowski
authored andcommitted
Add shared_state argument to RegressionComponent
1 parent 126d52c commit bf18a3e

File tree

2 files changed

+148
-18
lines changed

2 files changed

+148
-18
lines changed

pymc_extras/statespace/models/structural/components/regression.py

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ class RegressionComponent(Component):
3131
Whether to include stochastic innovations in the regression coefficients,
3232
allowing them to vary over time. If True, coefficients follow a random walk.
3333
34+
share_states: bool, default False
35+
Whether latent states are shared across the observed states. If True, there will be only one set of latent
36+
states, which are observed by all observed states. If False, each observed state has its own set of
37+
latent states.
38+
3439
Notes
3540
-----
3641
This component implements regression with exogenous variables in a structural time series
@@ -107,7 +112,10 @@ def __init__(
107112
state_names: list[str] | None = None,
108113
observed_state_names: list[str] | None = None,
109114
innovations=False,
115+
share_states: bool = False,
110116
):
117+
self.share_states = share_states
118+
111119
if observed_state_names is None:
112120
observed_state_names = ["data"]
113121

@@ -121,8 +129,8 @@ def __init__(
121129
super().__init__(
122130
name=name,
123131
k_endog=k_endog,
124-
k_states=k_states * k_endog,
125-
k_posdef=k_posdef * k_endog,
132+
k_states=k_states * k_endog if not share_states else k_states,
133+
k_posdef=k_posdef * k_endog if not share_states else k_posdef,
126134
state_names=self.state_names,
127135
observed_state_names=observed_state_names,
128136
measurement_error=False,
@@ -153,54 +161,74 @@ def _handle_input_data(self, k_exog: int, state_names: list[str] | None, name) -
153161

154162
def make_symbolic_graph(self) -> None:
155163
k_endog = self.k_endog
156-
k_states = self.k_states // k_endog
164+
k_endog_effective = 1 if self.share_states else k_endog
165+
166+
k_states = self.k_states // k_endog_effective
157167

158168
betas = self.make_and_register_variable(
159-
f"beta_{self.name}", shape=(k_endog, k_states) if k_endog > 1 else (k_states,)
169+
f"beta_{self.name}", shape=(k_endog, k_states) if k_endog_effective > 1 else (k_states,)
160170
)
161171
regression_data = self.make_and_register_data(f"data_{self.name}", shape=(None, k_states))
162172

163173
self.ssm["initial_state", :] = betas.ravel()
164174
self.ssm["transition", :, :] = pt.eye(self.k_states)
165175
self.ssm["selection", :, :] = pt.eye(self.k_states)
166176

167-
Z = pt.linalg.block_diag(*[pt.expand_dims(regression_data, 1) for _ in range(k_endog)])
168-
self.ssm["design"] = pt.specify_shape(
169-
Z, (None, k_endog, regression_data.type.shape[1] * k_endog)
170-
)
177+
if self.share_states:
178+
self.ssm["design"] = pt.specify_shape(
179+
pt.join(1, *[pt.expand_dims(regression_data, 1) for _ in range(k_endog)]),
180+
(None, k_endog, self.k_states),
181+
)
182+
else:
183+
Z = pt.linalg.block_diag(*[pt.expand_dims(regression_data, 1) for _ in range(k_endog)])
184+
self.ssm["design"] = pt.specify_shape(
185+
Z, (None, k_endog, regression_data.type.shape[1] * k_endog)
186+
)
171187

172188
if self.innovations:
173189
sigma_beta = self.make_and_register_variable(
174-
f"sigma_beta_{self.name}", (k_states,) if k_endog == 1 else (k_endog, k_states)
190+
f"sigma_beta_{self.name}",
191+
(k_states,) if k_endog_effective == 1 else (k_endog, k_states),
175192
)
176193
row_idx, col_idx = np.diag_indices(self.k_states)
177194
self.ssm["state_cov", row_idx, col_idx] = sigma_beta.ravel() ** 2
178195

179196
def populate_component_properties(self) -> None:
180197
k_endog = self.k_endog
181-
k_states = self.k_states // k_endog
198+
k_endog_effective = 1 if self.share_states else k_endog
199+
200+
k_states = self.k_states // k_endog_effective
182201

183-
self.shock_names = self.state_names
202+
if self.share_states:
203+
self.shock_names = [f"{state_name}_shared" for state_name in self.state_names]
204+
else:
205+
self.shock_names = self.state_names
184206

185207
self.param_names = [f"beta_{self.name}"]
186208
self.data_names = [f"data_{self.name}"]
187209
self.param_dims = {
188210
f"beta_{self.name}": (f"endog_{self.name}", f"state_{self.name}")
189-
if k_endog > 1
211+
if k_endog_effective > 1
190212
else (f"state_{self.name}",)
191213
}
192214

193215
base_names = self.state_names
194-
self.state_names = [
195-
f"{name}[{obs_name}]" for obs_name in self.observed_state_names for name in base_names
196-
]
216+
217+
if self.share_states:
218+
self.state_names = [f"{name}[{self.name}_shared]" for name in base_names]
219+
else:
220+
self.state_names = [
221+
f"{name}[{obs_name}]"
222+
for obs_name in self.observed_state_names
223+
for name in base_names
224+
]
197225

198226
self.param_info = {
199227
f"beta_{self.name}": {
200-
"shape": (k_endog, k_states) if k_endog > 1 else (k_states,),
228+
"shape": (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,),
201229
"constraints": None,
202230
"dims": (f"endog_{self.name}", f"state_{self.name}")
203-
if k_endog > 1
231+
if k_endog_effective > 1
204232
else (f"state_{self.name}",),
205233
},
206234
}
@@ -223,6 +251,6 @@ def populate_component_properties(self) -> None:
223251
"shape": (k_states,),
224252
"constraints": "Positive",
225253
"dims": (f"state_{self.name}",)
226-
if k_endog == 1
254+
if k_endog_effective == 1
227255
else (f"endog_{self.name}", f"state_{self.name}"),
228256
}

tests/statespace/models/structural/components/test_regression.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pytensor import config
88
from pytensor import tensor as pt
99
from pytensor.graph.basic import explicit_graph_inputs
10+
from scipy.linalg import block_diag
1011

1112
from pymc_extras.statespace.models import structural as st
1213
from tests.statespace.models.structural.conftest import _assert_basic_coords_correct
@@ -235,3 +236,104 @@ def test_filter_scans_time_varying_design_matrix(self, rng, time_series_data, in
235236
if innovations:
236237
# Check that sigma_beta parameter is included in the prior
237238
assert "sigma_beta_exog" in prior.prior.data_vars
239+
240+
241+
def test_regression_multiple_shared_construction():
242+
rc = st.RegressionComponent(
243+
state_names=["A"],
244+
observed_state_names=["data_1", "data_2"],
245+
innovations=True,
246+
share_states=True,
247+
)
248+
mod = rc.build(verbose=False)
249+
250+
assert mod.k_endog == 2
251+
assert mod.k_states == 1
252+
assert mod.k_posdef == 1
253+
254+
assert mod.coords["state_regression"] == ["A"]
255+
assert mod.coords["endog_regression"] == ["data_1", "data_2"]
256+
257+
assert mod.state_names == [
258+
"A[regression_shared]",
259+
]
260+
261+
assert mod.shock_names == ["A_shared"]
262+
263+
data = np.random.standard_normal(size=(10, 1))
264+
Z = mod.ssm["design"].eval({"data_regression": data})
265+
T = mod.ssm["transition"].eval()
266+
R = mod.ssm["selection"].eval()
267+
268+
np.testing.assert_allclose(
269+
Z,
270+
np.hstack(
271+
[
272+
data,
273+
data,
274+
]
275+
)[:, :, np.newaxis],
276+
)
277+
278+
np.testing.assert_allclose(T, np.array([[1.0]]))
279+
np.testing.assert_allclose(R, np.array([[1.0]]))
280+
281+
282+
def test_regression_multiple_shared_observed(rng):
283+
mod = st.RegressionComponent(
284+
state_names=["A"],
285+
observed_state_names=["data_1", "data_2", "data_3"],
286+
innovations=False,
287+
share_states=True,
288+
)
289+
data = np.random.standard_normal(size=(10, 1))
290+
291+
params = {"beta_regression": np.array([1.0])}
292+
data_dict = {"data_regression": data}
293+
x, y = simulate_from_numpy_model(mod, rng, params, data_dict, steps=data.shape[0])
294+
np.testing.assert_allclose(y[:, 0], y[:, 1])
295+
np.testing.assert_allclose(y[:, 0], y[:, 2])
296+
297+
298+
def test_regression_mixed_shared_and_not_shared():
299+
mod_1 = st.RegressionComponent(
300+
name="individual",
301+
state_names=["A"],
302+
observed_state_names=["data_1", "data_2"],
303+
)
304+
mod_2 = st.RegressionComponent(
305+
name="joint",
306+
state_names=["B", "C"],
307+
observed_state_names=["data_1", "data_2"],
308+
share_states=True,
309+
)
310+
311+
mod = (mod_1 + mod_2).build(verbose=False)
312+
313+
assert mod.k_endog == 2
314+
assert mod.k_states == 4
315+
assert mod.k_posdef == 4
316+
317+
assert mod.state_names == ["A[data_1]", "A[data_2]", "B[joint_shared]", "C[joint_shared]"]
318+
assert mod.shock_names == ["A", "B_shared", "C_shared"]
319+
320+
data_joint = np.random.standard_normal(size=(10, 2))
321+
data_individual = np.random.standard_normal(size=(10, 1))
322+
Z = mod.ssm["design"].eval({"data_joint": data_joint, "data_individual": data_individual})
323+
T = mod.ssm["transition"].eval()
324+
R = mod.ssm["selection"].eval()
325+
326+
np.testing.assert_allclose(
327+
Z,
328+
np.concat(
329+
(
330+
block_diag(*[data_individual[:, np.newaxis] for _ in range(mod.k_endog)]),
331+
np.concat((data_joint[:, np.newaxis], data_joint[:, np.newaxis]), axis=1),
332+
),
333+
axis=2,
334+
),
335+
)
336+
337+
np.testing.assert_allclose(T, np.eye(mod.k_states))
338+
339+
np.testing.assert_allclose(R, np.eye(mod.k_states))

0 commit comments

Comments
 (0)