Skip to content

Commit f6bf481

Browse files
Add share_states argument to FrequencySeasonality
1 parent 81a02be commit f6bf481

File tree

2 files changed

+114
-34
lines changed

2 files changed

+114
-34
lines changed

pymc_extras/statespace/models/structural/components/seasonality.py

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,11 @@ class FrequencySeasonality(Component):
449449
observed_state_names: list[str] | None, default None
450450
List of strings for observed state labels. If None, defaults to ["data"].
451451
452+
share_states: bool, default False
453+
Whether latent states are shared across the observed states. If True, there will be only one set of latent
454+
states, which are observed by all observed states. If False, each observed state has its own set of
455+
latent states. This argument has no effect if `k_endog` is 1.
456+
452457
Notes
453458
-----
454459
A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
@@ -480,15 +485,17 @@ class FrequencySeasonality(Component):
480485

481486
def __init__(
482487
self,
483-
season_length,
484-
n=None,
485-
name=None,
486-
innovations=True,
488+
season_length: int,
489+
n: int | None = None,
490+
name: str | None = None,
491+
innovations: bool = True,
487492
observed_state_names: list[str] | None = None,
493+
share_states: bool = False,
488494
):
489495
if observed_state_names is None:
490496
observed_state_names = ["data"]
491497

498+
self.share_states = share_states
492499
k_endog = len(observed_state_names)
493500

494501
if n is None:
@@ -504,18 +511,20 @@ def __init__(
504511
# If the model is completely saturated (n = s // 2), the last state will not be identified, so it shouldn't
505512
# get a parameter assigned to it and should just be fixed to zero.
506513
# Test this way (rather than n == s // 2) to catch cases when n is non-integer.
507-
self.last_state_not_identified = self.season_length / self.n == 2.0
514+
self.last_state_not_identified = (self.season_length / self.n) == 2.0
508515
self.n_coefs = k_states - int(self.last_state_not_identified)
509516

510517
obs_state_idx = np.zeros(k_states)
511518
obs_state_idx[slice(0, k_states, 2)] = 1
512-
obs_state_idx = np.tile(obs_state_idx, k_endog)
519+
obs_state_idx = np.tile(obs_state_idx, 1 if share_states else k_endog)
513520

514521
super().__init__(
515522
name=name,
516523
k_endog=k_endog,
517-
k_states=k_states * k_endog,
518-
k_posdef=k_states * int(self.innovations) * k_endog,
524+
k_states=k_states if share_states else k_states * k_endog,
525+
k_posdef=k_states * int(self.innovations)
526+
if share_states
527+
else k_states * int(self.innovations) * k_endog,
519528
observed_state_names=observed_state_names,
520529
measurement_error=False,
521530
combine_hidden_states=True,
@@ -524,13 +533,15 @@ def __init__(
524533

525534
def make_symbolic_graph(self) -> None:
526535
k_endog = self.k_endog
527-
k_states = self.k_states // k_endog
528-
k_posdef = self.k_posdef // k_endog
536+
k_endog_effective = 1 if self.share_states else k_endog
537+
538+
k_states = self.k_states // k_endog_effective
539+
k_posdef = self.k_posdef // k_endog_effective
529540
n_coefs = self.n_coefs
530541

531542
Z = pt.zeros((1, k_states))[0, slice(0, k_states, 2)].set(1.0)
532543

533-
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
544+
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog_effective)])
534545

535546
init_state = self.make_and_register_variable(
536547
f"params_{self.name}", shape=(n_coefs,) if k_endog == 1 else (k_endog, n_coefs)
@@ -539,7 +550,7 @@ def make_symbolic_graph(self) -> None:
539550
init_state_idx = np.concatenate(
540551
[
541552
np.arange(k_states * i, (i + 1) * k_states, dtype=int)[:n_coefs]
542-
for i in range(k_endog)
553+
for i in range(k_endog_effective)
543554
],
544555
axis=0,
545556
)
@@ -548,11 +559,11 @@ def make_symbolic_graph(self) -> None:
548559

549560
T_mats = [_frequency_transition_block(self.season_length, j + 1) for j in range(self.n)]
550561
T = pt.linalg.block_diag(*T_mats)
551-
self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])
562+
self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog_effective)])
552563

553564
if self.innovations:
554565
sigma_season = self.make_and_register_variable(
555-
f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
566+
f"sigma_{self.name}", shape=() if k_endog_effective == 1 else (k_endog_effective,)
556567
)
557568
self.ssm["selection", :, :] = pt.eye(self.k_states)
558569
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * pt.repeat(
@@ -561,35 +572,35 @@ def make_symbolic_graph(self) -> None:
561572

562573
def populate_component_properties(self):
563574
k_endog = self.k_endog
575+
k_endog_effective = 1 if self.share_states else k_endog
564576
n_coefs = self.n_coefs
565577

566-
self.state_names = [
567-
f"{f}_{i}_{self.name}[{obs_state_name}]"
568-
for obs_state_name in self.observed_state_names
569-
for i in range(self.n)
570-
for f in ["Cos", "Sin"]
571-
]
572-
# determine which state names correspond to parameters
573-
# all endog variables use same state structure, so we just need
574-
# the first n_coefs state names (which may be less than total if saturated)
575-
param_state_names = [f"{f}_{i}_{self.name}" for i in range(self.n) for f in ["Cos", "Sin"]][
576-
:n_coefs
577-
]
578+
base_names = [f"{f}_{i}_{self.name}" for i in range(self.n) for f in ["Cos", "Sin"]]
578579

579-
self.param_names = [f"params_{self.name}"]
580+
if self.share_states:
581+
self.state_names = [f"{name}[shared]" for name in base_names]
582+
else:
583+
self.state_names = [
584+
f"{name}[{obs_state_name}]"
585+
for obs_state_name in self.observed_state_names
586+
for name in base_names
587+
]
580588

589+
# Trim state names if the model is saturated
590+
param_state_names = base_names[:n_coefs]
591+
592+
self.param_names = [f"params_{self.name}"]
581593
self.param_dims = {
582594
f"params_{self.name}": (f"state_{self.name}",)
583-
if k_endog == 1
595+
if k_endog_effective == 1
584596
else (f"endog_{self.name}", f"state_{self.name}")
585597
}
586-
587598
self.param_info = {
588599
f"params_{self.name}": {
589-
"shape": (n_coefs,) if k_endog == 1 else (k_endog, n_coefs),
600+
"shape": (n_coefs,) if k_endog_effective == 1 else (k_endog_effective, n_coefs),
590601
"constraints": None,
591602
"dims": (f"state_{self.name}",)
592-
if k_endog == 1
603+
if k_endog_effective == 1
593604
else (f"endog_{self.name}", f"state_{self.name}"),
594605
}
595606
}
@@ -607,9 +618,9 @@ def populate_component_properties(self):
607618
self.param_names += [f"sigma_{self.name}"]
608619
self.shock_names = self.state_names.copy()
609620
self.param_info[f"sigma_{self.name}"] = {
610-
"shape": () if k_endog == 1 else (k_endog,),
621+
"shape": () if k_endog_effective == 1 else (k_endog_effective, n_coefs),
611622
"constraints": "Positive",
612-
"dims": None if k_endog == 1 else (f"endog_{self.name}",),
623+
"dims": None if k_endog_effective == 1 else (f"endog_{self.name}",),
613624
}
614-
if k_endog > 1:
625+
if k_endog_effective > 1:
615626
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)

tests/statespace/models/structural/components/test_seasonality.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,39 @@ def test_frequency_seasonality_multiple_observed(rng):
474474
np.testing.assert_allclose(Q_diag, expected_Q_diag, atol=ATOL, rtol=RTOL)
475475

476476

477+
def test_frequency_seasonality_multivariate_shared_states():
478+
mod = st.FrequencySeasonality(
479+
season_length=4,
480+
n=1,
481+
name="season",
482+
innovations=True,
483+
observed_state_names=["data_1", "data_2"],
484+
share_states=True,
485+
)
486+
487+
assert mod.k_endog == 2
488+
assert mod.k_states == 2
489+
assert mod.k_posdef == 2
490+
491+
assert mod.state_names == ["Cos_0_season[shared]", "Sin_0_season[shared]"]
492+
assert mod.shock_names == ["Cos_0_season[shared]", "Sin_0_season[shared]"]
493+
494+
assert mod.coords["state_season"] == ["Cos_0_season", "Sin_0_season"]
495+
496+
Z, T, R = pytensor.function(
497+
[], [mod.ssm["design"], mod.ssm["transition"], mod.ssm["selection"]], mode="FAST_COMPILE"
498+
)()
499+
500+
np.testing.assert_allclose(np.array([[1.0, 0.0], [1.0, 0.0]]), Z)
501+
502+
np.testing.assert_allclose(np.array([[1.0, 0.0], [0.0, 1.0]]), R)
503+
504+
lam = 2 * np.pi * 1 / 4
505+
np.testing.assert_allclose(
506+
np.array([[np.cos(lam), np.sin(lam)], [-np.sin(lam), np.cos(lam)]]), T
507+
)
508+
509+
477510
def test_add_two_frequency_seasonality_different_observed(rng):
478511
mod1 = st.FrequencySeasonality(
479512
season_length=4,
@@ -561,6 +594,42 @@ def test_add_two_frequency_seasonality_different_observed(rng):
561594
np.testing.assert_allclose(expected_T, T_v, atol=ATOL, rtol=RTOL)
562595

563596

597+
def test_add_frequency_seasonality_shared_and_not_shared():
598+
shared_season = st.FrequencySeasonality(
599+
season_length=4,
600+
n=1,
601+
name="shared_season",
602+
innovations=True,
603+
observed_state_names=["data_1", "data_2"],
604+
share_states=True,
605+
)
606+
607+
individual_season = st.FrequencySeasonality(
608+
season_length=4,
609+
n=2,
610+
name="individual_season",
611+
innovations=True,
612+
observed_state_names=["data_1", "data_2"],
613+
share_states=False,
614+
)
615+
616+
mod = (shared_season + individual_season).build(verbose=False)
617+
618+
assert mod.k_endog == 2
619+
assert mod.k_states == 10
620+
assert mod.k_posdef == 10
621+
622+
assert mod.coords["state_shared_season"] == [
623+
"Cos_0_shared_season",
624+
"Sin_0_shared_season",
625+
]
626+
assert mod.coords["state_individual_season"] == [
627+
"Cos_0_individual_season",
628+
"Sin_0_individual_season",
629+
"Cos_1_individual_season",
630+
]
631+
632+
564633
@pytest.mark.parametrize(
565634
"test_case",
566635
[

0 commit comments

Comments
 (0)