Skip to content

Commit 81a02be

Browse files
Add shared_state argument to CycleComponent
1 parent f5054a3 commit 81a02be

File tree

2 files changed

+152
-23
lines changed
  • pymc_extras/statespace/models/structural/components
  • tests/statespace/models/structural/components

2 files changed

+152
-23
lines changed

pymc_extras/statespace/models/structural/components/cycle.py

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ class CycleComponent(Component):
4343
Names of the observed state variables. For univariate time series, defaults to ``["data"]``.
4444
For multivariate time series, specify a list of names for each endogenous variable.
4545
46+
share_states: bool, default False
47+
Whether latent states are shared across the observed states. If True, there will be only one set of latent
48+
states, which are observed by all observed states. If False, each observed state has its own set of
49+
latent states. This argument has no effect if `k_endog` is 1.
50+
4651
Notes
4752
-----
4853
The cycle component is very similar in implementation to the frequency domain seasonal component, expect that it
@@ -155,6 +160,7 @@ def __init__(
155160
dampen: bool = False,
156161
innovations: bool = True,
157162
observed_state_names: list[str] | None = None,
163+
share_states: bool = False,
158164
):
159165
if observed_state_names is None:
160166
observed_state_names = ["data"]
@@ -167,6 +173,7 @@ def __init__(
167173
cycle = int(cycle_length) if cycle_length is not None else "Estimate"
168174
name = f"Cycle[s={cycle}, dampen={dampen}, innovations={innovations}]"
169175

176+
self.share_states = share_states
170177
self.estimate_cycle_length = estimate_cycle_length
171178
self.cycle_length = cycle_length
172179
self.innovations = innovations
@@ -175,8 +182,8 @@ def __init__(
175182

176183
k_endog = len(observed_state_names)
177184

178-
k_states = 2 * k_endog
179-
k_posdef = 2 * k_endog
185+
k_states = 2 if share_states else 2 * k_endog
186+
k_posdef = 2 if share_states else 2 * k_endog
180187

181188
obs_state_idx = np.zeros(k_states)
182189
obs_state_idx[slice(0, k_states, 2)] = 1
@@ -193,18 +200,22 @@ def __init__(
193200
)
194201

195202
def make_symbolic_graph(self) -> None:
203+
k_endog = self.k_endog
204+
k_endog_effective = 1 if self.share_states else k_endog
205+
196206
Z = np.array([1.0, 0.0]).reshape((1, -1))
197-
design_matrix = block_diag(*[Z for _ in range(self.k_endog)])
207+
design_matrix = block_diag(*[Z for _ in range(k_endog_effective)])
198208
self.ssm["design", :, :] = pt.as_tensor_variable(design_matrix)
199209

200210
# selection matrix R defines structure of innovations (always identity for cycle components)
201211
# when innovations=False, state cov Q=0, hence R @ Q @ R.T = 0
202212
R = np.eye(2) # 2x2 identity for each cycle component
203-
selection_matrix = block_diag(*[R for _ in range(self.k_endog)])
213+
selection_matrix = block_diag(*[R for _ in range(k_endog_effective)])
204214
self.ssm["selection", :, :] = pt.as_tensor_variable(selection_matrix)
205215

206216
init_state = self.make_and_register_variable(
207-
f"{self.name}", shape=(self.k_endog, 2) if self.k_endog > 1 else (self.k_states,)
217+
f"{self.name}",
218+
shape=(k_endog_effective, 2) if k_endog_effective > 1 else (self.k_states,),
208219
)
209220
self.ssm["initial_state", :] = init_state.ravel()
210221

@@ -219,37 +230,45 @@ def make_symbolic_graph(self) -> None:
219230
rho = 1
220231

221232
T = rho * _frequency_transition_block(lamb, j=1)
222-
transition = block_diag(*[T for _ in range(self.k_endog)])
233+
transition = block_diag(*[T for _ in range(k_endog_effective)])
223234
self.ssm["transition"] = pt.specify_shape(transition, (self.k_states, self.k_states))
224235

225236
if self.innovations:
226-
if self.k_endog == 1:
237+
if k_endog_effective == 1:
227238
sigma_cycle = self.make_and_register_variable(f"sigma_{self.name}", shape=())
228239
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_cycle**2
229240
else:
230241
sigma_cycle = self.make_and_register_variable(
231-
f"sigma_{self.name}", shape=(self.k_endog,)
242+
f"sigma_{self.name}", shape=(k_endog_effective,)
232243
)
233244
state_cov = block_diag(
234-
*[pt.eye(2) * sigma_cycle[i] ** 2 for i in range(self.k_endog)]
245+
*[pt.eye(2) * sigma_cycle[i] ** 2 for i in range(k_endog_effective)]
235246
)
236247
self.ssm["state_cov"] = pt.specify_shape(state_cov, (self.k_states, self.k_states))
237248
else:
238249
# explicitly set state cov to 0 when no innovations
239250
self.ssm["state_cov", :, :] = pt.zeros((self.k_posdef, self.k_posdef))
240251

241252
def populate_component_properties(self):
242-
self.state_names = [
243-
f"{f}_{self.name}[{var_name}]" if self.k_endog > 1 else f"{f}_{self.name}"
244-
for var_name in self.observed_state_names
245-
for f in ["Cos", "Sin"]
246-
]
253+
k_endog = self.k_endog
254+
k_endog_effective = 1 if self.share_states else k_endog
255+
256+
base_names = [f"{f}_{self.name}" for f in ["Cos", "Sin"]]
257+
258+
if self.share_states:
259+
self.state_names = [f"{name}[shared]" for name in base_names]
260+
else:
261+
self.state_names = [
262+
f"{name}[{var_name}]" if k_endog_effective > 1 else name
263+
for var_name in self.observed_state_names
264+
for name in base_names
265+
]
247266

248267
self.param_names = [f"{self.name}"]
249268

250-
if self.k_endog == 1:
269+
if k_endog_effective == 1:
251270
self.param_dims = {self.name: (f"state_{self.name}",)}
252-
self.coords = {f"state_{self.name}": self.state_names}
271+
self.coords = {f"state_{self.name}": base_names}
253272
self.param_info = {
254273
f"{self.name}": {
255274
"shape": (2,),
@@ -265,7 +284,7 @@ def populate_component_properties(self):
265284
}
266285
self.param_info = {
267286
f"{self.name}": {
268-
"shape": (self.k_endog, 2),
287+
"shape": (k_endog_effective, 2),
269288
"constraints": None,
270289
"dims": (f"endog_{self.name}", f"state_{self.name}"),
271290
}
@@ -274,22 +293,22 @@ def populate_component_properties(self):
274293
if self.estimate_cycle_length:
275294
self.param_names += [f"length_{self.name}"]
276295
self.param_info[f"length_{self.name}"] = {
277-
"shape": () if self.k_endog == 1 else (self.k_endog,),
296+
"shape": () if k_endog_effective == 1 else (k_endog_effective,),
278297
"constraints": "Positive, non-zero",
279-
"dims": None if self.k_endog == 1 else f"endog_{self.name}",
298+
"dims": None if k_endog_effective == 1 else f"endog_{self.name}",
280299
}
281300

282301
if self.dampen:
283302
self.param_names += [f"dampening_factor_{self.name}"]
284303
self.param_info[f"dampening_factor_{self.name}"] = {
285-
"shape": () if self.k_endog == 1 else (self.k_endog,),
304+
"shape": () if k_endog_effective == 1 else (k_endog_effective,),
286305
"constraints": "0 < x ≤ 1",
287-
"dims": None if self.k_endog == 1 else (f"endog_{self.name}",),
306+
"dims": None if k_endog_effective == 1 else (f"endog_{self.name}",),
288307
}
289308

290309
if self.innovations:
291310
self.param_names += [f"sigma_{self.name}"]
292-
if self.k_endog == 1:
311+
if k_endog_effective == 1:
293312
self.param_info[f"sigma_{self.name}"] = {
294313
"shape": (),
295314
"constraints": "Positive",
@@ -298,7 +317,7 @@ def populate_component_properties(self):
298317
else:
299318
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
300319
self.param_info[f"sigma_{self.name}"] = {
301-
"shape": (self.k_endog,),
320+
"shape": (k_endog_effective,),
302321
"constraints": "Positive",
303322
"dims": (f"endog_{self.name}",),
304323
}

tests/statespace/models/structural/components/test_cycle.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from numpy.testing import assert_allclose
55
from pytensor import config
6+
from pytensor.graph.basic import explicit_graph_inputs
7+
from scipy import linalg
68

79
from pymc_extras.statespace.models import structural as st
810
from pymc_extras.statespace.models.structural.utils import _frequency_transition_block
@@ -105,6 +107,27 @@ def test_cycle_multivariate_deterministic(rng):
105107
np.testing.assert_allclose(R, expected_R)
106108

107109

110+
def test_multivariate_cycle_with_shared(rng):
111+
cycle = st.CycleComponent(
112+
name="cycle",
113+
cycle_length=12,
114+
estimate_cycle_length=False,
115+
innovations=False,
116+
observed_state_names=["data_1", "data_2", "data_3"],
117+
share_states=True,
118+
)
119+
120+
assert cycle.state_names == ["Cos_cycle[shared]", "Sin_cycle[shared]"]
121+
assert cycle.shock_names == []
122+
assert cycle.param_names == ["cycle"]
123+
124+
params = {"cycle": np.array([1.0, 2.0], dtype=config.floatX)}
125+
x, y = simulate_from_numpy_model(cycle, rng, params, steps=12 * 12)
126+
127+
np.testing.assert_allclose(y[:, 0], y[:, 1], atol=ATOL, rtol=RTOL)
128+
np.testing.assert_allclose(y[:, 0], y[:, 2], atol=ATOL, rtol=RTOL)
129+
130+
108131
def test_cycle_multivariate_with_dampening(rng):
109132
"""Test multivariate cycle component with dampening."""
110133
cycle = st.CycleComponent(
@@ -286,3 +309,90 @@ def test_add_multivariate_cycle_components_with_different_observed():
286309
for i in range(4):
287310
expected_R[2 * i : 2 * i + 2, 2 * i : 2 * i + 2] = np.eye(2)
288311
assert_allclose(R, expected_R)
312+
313+
314+
def test_add_multivariate_shared_and_not_shared():
315+
cycle_shared = st.CycleComponent(
316+
name="shared_cycle",
317+
cycle_length=12,
318+
estimate_cycle_length=False,
319+
innovations=True,
320+
observed_state_names=["gdp", "inflation", "unemployment"],
321+
share_states=True,
322+
)
323+
cycle_individual = st.CycleComponent(
324+
name="individual_cycle",
325+
estimate_cycle_length=True,
326+
innovations=False,
327+
observed_state_names=["gdp", "inflation", "unemployment"],
328+
dampen=True,
329+
)
330+
mod = (cycle_shared + cycle_individual).build(verbose=False)
331+
332+
assert mod.k_endog == 3
333+
assert mod.k_states == 2 + 3 * 2
334+
assert mod.k_posdef == 2 + 3 * 2
335+
336+
expected_states = [
337+
"Cos_shared_cycle[shared]",
338+
"Sin_shared_cycle[shared]",
339+
"Cos_individual_cycle[gdp]",
340+
"Sin_individual_cycle[gdp]",
341+
"Cos_individual_cycle[inflation]",
342+
"Sin_individual_cycle[inflation]",
343+
"Cos_individual_cycle[unemployment]",
344+
"Sin_individual_cycle[unemployment]",
345+
]
346+
347+
assert mod.state_names == expected_states
348+
assert mod.shock_names == expected_states[:2]
349+
350+
assert mod.param_names == [
351+
"shared_cycle",
352+
"sigma_shared_cycle",
353+
"individual_cycle",
354+
"length_individual_cycle",
355+
"dampening_factor_individual_cycle",
356+
"P0",
357+
]
358+
359+
assert "endog_shared_cycle" not in mod.coords
360+
assert mod.coords["state_shared_cycle"] == ["Cos_shared_cycle", "Sin_shared_cycle"]
361+
assert mod.coords["state_individual_cycle"] == ["Cos_individual_cycle", "Sin_individual_cycle"]
362+
assert mod.coords["endog_individual_cycle"] == ["gdp", "inflation", "unemployment"]
363+
364+
assert mod.param_info["shared_cycle"]["dims"] == ("state_shared_cycle",)
365+
assert mod.param_info["shared_cycle"]["shape"] == (2,)
366+
367+
assert mod.param_info["sigma_shared_cycle"]["dims"] is None
368+
assert mod.param_info["sigma_shared_cycle"]["shape"] == ()
369+
370+
assert mod.param_info["individual_cycle"]["dims"] == (
371+
"endog_individual_cycle",
372+
"state_individual_cycle",
373+
)
374+
assert mod.param_info["individual_cycle"]["shape"] == (3, 2)
375+
376+
params = {
377+
"length_individual_cycle": 12.0,
378+
"dampening_factor_individual_cycle": 0.95,
379+
}
380+
outputs = [mod.ssm["transition"], mod.ssm["design"], mod.ssm["selection"]]
381+
T, Z, R = pytensor.function(
382+
list(explicit_graph_inputs(outputs)),
383+
outputs,
384+
mode="FAST_COMPILE",
385+
)(**params)
386+
387+
lamb = 2 * np.pi / 12 # dampening factor for individual cycle
388+
transition_block = np.array(
389+
[[np.cos(lamb), np.sin(lamb)], [-np.sin(lamb), np.cos(lamb)]], dtype=config.floatX
390+
)
391+
T_expected = linalg.block_diag(transition_block, *[0.95 * transition_block] * 3)
392+
np.testing.assert_allclose(T, T_expected)
393+
394+
np.testing.assert_allclose(
395+
Z, np.array([[1, 0, 1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 1, 0]])
396+
)
397+
398+
np.testing.assert_allclose(R, np.eye(8, dtype=config.floatX))

0 commit comments

Comments
 (0)