Skip to content

Commit 8f647af

Browse files
Add shared_state argument to TimeSeasonality
1 parent 48ac660 commit 8f647af

File tree

2 files changed

+149
-24
lines changed

2 files changed

+149
-24
lines changed

pymc_extras/statespace/models/structural/components/seasonality.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ class TimeSeasonality(Component):
4444
observed_state_names: list[str] | None, default None
4545
List of strings for observed state labels. If None, defaults to ["data"].
4646
47+
share_states: bool, default False
48+
Whether latent states are shared across the observed states. If True, there will be only one set of latent
49+
states, which are observed by all observed states. If False, each observed state has its own set of
50+
latent states. This argument has no effect if `k_endog` is 1.
51+
4752
Notes
4853
-----
4954
A seasonal effect is any pattern that repeats at fixed intervals. There are several ways to model such effects;
@@ -235,6 +240,7 @@ def __init__(
235240
state_names: list | None = None,
236241
remove_first_state: bool = True,
237242
observed_state_names: list[str] | None = None,
243+
share_states: bool = False,
238244
):
239245
if observed_state_names is None:
240246
observed_state_names = ["data"]
@@ -261,6 +267,7 @@ def __init__(
261267
)
262268
state_names = state_names.copy()
263269

270+
self.share_states = share_states
264271
self.innovations = innovations
265272
self.duration = duration
266273
self.remove_first_state = remove_first_state
@@ -281,44 +288,53 @@ def __init__(
281288
super().__init__(
282289
name=name,
283290
k_endog=k_endog,
284-
k_states=k_states * k_endog,
285-
k_posdef=k_posdef * k_endog,
291+
k_states=k_states if share_states else k_states * k_endog,
292+
k_posdef=k_posdef if share_states else k_posdef * k_endog,
286293
observed_state_names=observed_state_names,
287294
measurement_error=False,
288295
combine_hidden_states=True,
289-
obs_state_idxs=np.tile(np.array([1.0] + [0.0] * (k_states - 1)), k_endog),
296+
obs_state_idxs=np.tile(
297+
np.array([1.0] + [0.0] * (k_states - 1)), 1 if share_states else k_endog
298+
),
290299
)
291300

292301
def populate_component_properties(self):
293-
k_states = self.k_states // self.k_endog
294302
k_endog = self.k_endog
303+
k_endog_effective = 1 if self.share_states else k_endog
295304

296-
self.state_names = [
297-
f"{state_name}[{endog_name}]"
298-
for endog_name in self.observed_state_names
299-
for state_name in self.provided_state_names
300-
]
305+
k_states = self.k_states // k_endog_effective
306+
307+
if self.share_states:
308+
self.state_names = [
309+
f"{state_name}[{self.name}_shared]" for state_name in self.provided_state_names
310+
]
311+
else:
312+
self.state_names = [
313+
f"{state_name}[{endog_name}]"
314+
for endog_name in self.observed_state_names
315+
for state_name in self.provided_state_names
316+
]
301317
self.param_names = [f"coefs_{self.name}"]
302318

303319
self.param_info = {
304320
f"coefs_{self.name}": {
305-
"shape": (k_states,) if k_endog == 1 else (k_endog, k_states),
321+
"shape": (k_states,) if k_endog_effective == 1 else (k_endog_effective, k_states),
306322
"constraints": None,
307323
"dims": (f"state_{self.name}",)
308-
if k_endog == 1
324+
if k_endog_effective == 1
309325
else (f"endog_{self.name}", f"state_{self.name}"),
310326
}
311327
}
312328

313329
self.param_dims = {
314330
f"coefs_{self.name}": (f"state_{self.name}",)
315-
if k_endog == 1
331+
if k_endog_effective == 1
316332
else (f"endog_{self.name}", f"state_{self.name}")
317333
}
318334

319335
self.coords = (
320336
{f"state_{self.name}": self.provided_state_names}
321-
if k_endog == 1
337+
if k_endog_effective == 1
322338
else {
323339
f"endog_{self.name}": self.observed_state_names,
324340
f"state_{self.name}": self.provided_state_names,
@@ -332,14 +348,19 @@ def populate_component_properties(self):
332348
"constraints": "Positive",
333349
"dims": None,
334350
}
335-
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
351+
if self.share_states:
352+
self.shock_names = [f"{self.name}[shared]"]
353+
else:
354+
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
336355

337356
def make_symbolic_graph(self) -> None:
338-
k_states = self.k_states // self.k_endog
357+
k_endog = self.k_endog
358+
k_endog_effective = 1 if self.share_states else k_endog
359+
k_states = self.k_states // k_endog_effective
339360
duration = self.duration
361+
340362
k_unique_states = k_states // duration
341-
k_posdef = self.k_posdef // self.k_endog
342-
k_endog = self.k_endog
363+
k_posdef = self.k_posdef // k_endog_effective
343364

344365
if self.remove_first_state:
345366
# In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
@@ -371,16 +392,18 @@ def make_symbolic_graph(self) -> None:
371392
T = pt.eye(k_states, k=1)
372393
T = pt.set_subtensor(T[-1, 0], 1)
373394

374-
self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])
395+
self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog_effective)])
375396

376397
Z = pt.zeros((1, k_states))[0, 0].set(1)
377-
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
398+
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog_effective)])
378399

379400
initial_states = self.make_and_register_variable(
380401
f"coefs_{self.name}",
381-
shape=(k_unique_states,) if k_endog == 1 else (k_endog, k_unique_states),
402+
shape=(k_unique_states,)
403+
if k_endog_effective == 1
404+
else (k_endog_effective, k_unique_states),
382405
)
383-
if k_endog == 1:
406+
if k_endog_effective == 1:
384407
self.ssm["initial_state", :] = pt.extra_ops.repeat(initial_states, duration, axis=0)
385408
else:
386409
self.ssm["initial_state", :] = pt.extra_ops.repeat(
@@ -389,11 +412,11 @@ def make_symbolic_graph(self) -> None:
389412

390413
if self.innovations:
391414
R = pt.zeros((k_states, k_posdef))[0, 0].set(1.0)
392-
self.ssm["selection", :, :] = pt.join(0, *[R for _ in range(k_endog)])
415+
self.ssm["selection", :, :] = pt.join(0, *[R for _ in range(k_endog_effective)])
393416
season_sigma = self.make_and_register_variable(
394-
f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
417+
f"sigma_{self.name}", shape=() if k_endog_effective == 1 else (k_endog_effective,)
395418
)
396-
cov_idx = ("state_cov", *np.diag_indices(k_posdef * k_endog))
419+
cov_idx = ("state_cov", *np.diag_indices(k_posdef * k_endog_effective))
397420
self.ssm[cov_idx] = season_sigma**2
398421

399422

tests/statespace/models/structural/components/test_seasonality.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,108 @@ def test_time_seasonality_multiple_observed(rng, d, remove_first_state):
146146
np.testing.assert_allclose(matrix, expected)
147147

148148

149+
def test_time_seasonality_shared_states():
150+
mod = st.TimeSeasonality(
151+
season_length=3,
152+
duration=1,
153+
innovations=True,
154+
name="season",
155+
state_names=["season_1", "season_2", "season_3"],
156+
observed_state_names=["data_1", "data_2"],
157+
remove_first_state=False,
158+
share_states=True,
159+
)
160+
161+
assert mod.k_endog == 2
162+
assert mod.k_states == 3
163+
assert mod.k_posdef == 1
164+
165+
assert mod.coords["state_season"] == ["season_1", "season_2", "season_3"]
166+
167+
assert mod.state_names == [
168+
"season_1[season_shared]",
169+
"season_2[season_shared]",
170+
"season_3[season_shared]",
171+
]
172+
assert mod.shock_names == ["season[shared]"]
173+
174+
Z, T, R = pytensor.function(
175+
[], [mod.ssm["design"], mod.ssm["transition"], mod.ssm["selection"]], mode="FAST_COMPILE"
176+
)()
177+
178+
np.testing.assert_allclose(np.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]), Z)
179+
180+
np.testing.assert_allclose(np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]), T)
181+
182+
np.testing.assert_allclose(np.array([[1.0], [0.0], [0.0]]), R)
183+
184+
185+
def test_add_mixed_shared_not_shared_time_seasonality():
186+
shared_season = st.TimeSeasonality(
187+
season_length=3,
188+
duration=1,
189+
innovations=True,
190+
name="shared",
191+
state_names=["season_1", "season_2", "season_3"],
192+
observed_state_names=["data_1", "data_2"],
193+
remove_first_state=False,
194+
share_states=True,
195+
)
196+
individual_season = st.TimeSeasonality(
197+
season_length=3,
198+
duration=1,
199+
innovations=False,
200+
name="individual",
201+
state_names=["season_1", "season_2", "season_3"],
202+
observed_state_names=["data_1", "data_2"],
203+
remove_first_state=True,
204+
share_states=False,
205+
)
206+
mod = (shared_season + individual_season).build(verbose=False)
207+
208+
assert mod.k_endog == 2
209+
assert mod.k_states == 7
210+
assert mod.k_posdef == 1
211+
212+
assert mod.coords["state_shared"] == ["season_1", "season_2", "season_3"]
213+
assert mod.coords["state_individual"] == ["season_2", "season_3"]
214+
215+
assert mod.state_names == [
216+
"season_1[shared_shared]",
217+
"season_2[shared_shared]",
218+
"season_3[shared_shared]",
219+
"season_2[data_1]",
220+
"season_3[data_1]",
221+
"season_2[data_2]",
222+
"season_3[data_2]",
223+
]
224+
225+
Z, T, R = pytensor.function(
226+
[], [mod.ssm["design"], mod.ssm["transition"], mod.ssm["selection"]], mode="FAST_COMPILE"
227+
)()
228+
229+
np.testing.assert_allclose(
230+
np.array([[1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]]), Z
231+
)
232+
233+
np.testing.assert_allclose(
234+
np.array(
235+
[
236+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
237+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
238+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
239+
[0.0, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0],
240+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
241+
[0.0, 0.0, 0.0, 0.0, 0.0, -1.0, -1.0],
242+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
243+
]
244+
),
245+
T,
246+
)
247+
248+
np.testing.assert_allclose(np.array([[1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]), R)
249+
250+
149251
@pytest.mark.parametrize("d1, d2", [(1, 1), (1, 3), (3, 1), (3, 3)])
150252
def test_add_two_time_seasonality_different_observed(rng, d1, d2):
151253
mod1 = st.TimeSeasonality(

0 commit comments

Comments
 (0)