Skip to content

Commit b5aa9bf

Browse files
jessegrabowskiDekermanjianAlexAndorra
authored
Allow multivariate components to share latent states (#558)
* Add shared_state argument to LevelTrendComponent * Add shared_state argument to TimeSeasonality * Add shared_state argument to CycleComponent * Add shared_state argument to FrequencySeasonality * Add shared_state argument to RegressionComponent * Add shared_state argument to AutoregressiveComponent * Add shared_state argument to MeasurementError * Pass `share_states` to `super` calls * Add shared_states flag in core.py * Fix cycle tests after param renaming --------- Co-authored-by: Jonathan Dekermanjian <[email protected]> Co-authored-by: Alexandre Andorra <[email protected]>
1 parent e375978 commit b5aa9bf

File tree

14 files changed

+1030
-177
lines changed

14 files changed

+1030
-177
lines changed

pymc_extras/statespace/models/structural/components/autoregressive.py

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ class AutoregressiveComponent(Component):
2323
observed_state_names: list[str] | None, default None
2424
List of strings for observed state labels. If None, defaults to ["data"].
2525
26+
share_states: bool, default False
27+
Whether latent states are shared across the observed states. If True, there will be only one set of latent
28+
states, which are observed by all observed states. If False, each observed state has its own set of
29+
latent states. This argument has no effect if `k_endog` is 1.
30+
2631
Notes
2732
-----
2833
An autoregressive component can be thought of as a way o introducing serially correlated errors into the model.
@@ -73,45 +78,59 @@ def __init__(
7378
order: int = 1,
7479
name: str = "auto_regressive",
7580
observed_state_names: list[str] | None = None,
81+
share_states: bool = False,
7682
):
7783
if observed_state_names is None:
7884
observed_state_names = ["data"]
7985

80-
k_posdef = k_endog = len(observed_state_names)
86+
k_endog = len(observed_state_names)
87+
k_endog_effective = k_posdef = 1 if share_states else k_endog
8188

8289
order = order_to_mask(order)
8390
ar_lags = np.flatnonzero(order).ravel().astype(int) + 1
8491
k_states = len(order)
8592

93+
self.share_states = share_states
8694
self.order = order
8795
self.ar_lags = ar_lags
8896

8997
super().__init__(
9098
name=name,
9199
k_endog=k_endog,
92-
k_states=k_states * k_endog,
100+
k_states=k_states * k_endog_effective,
93101
k_posdef=k_posdef,
94102
measurement_error=True,
95103
combine_hidden_states=True,
96104
observed_state_names=observed_state_names,
97-
obs_state_idxs=np.tile(np.r_[[1.0], np.zeros(k_states - 1)], k_endog),
105+
obs_state_idxs=np.tile(np.r_[[1.0], np.zeros(k_states - 1)], k_endog_effective),
106+
share_states=share_states,
98107
)
99108

100109
def populate_component_properties(self):
101-
k_states = self.k_states // self.k_endog # this is also the number of AR lags
110+
k_endog = self.k_endog
111+
k_endog_effective = 1 if self.share_states else k_endog
102112

103-
self.state_names = [
104-
f"L{i + 1}[{state_name}]"
105-
for state_name in self.observed_state_names
106-
for i in range(k_states)
107-
]
113+
k_states = self.k_states // k_endog_effective # this is also the number of AR lags
114+
base_names = [f"L{i + 1}_{self.name}" for i in range(k_states)]
115+
116+
if self.share_states:
117+
self.state_names = [f"{name}[shared]" for name in base_names]
118+
self.shock_names = [f"{self.name}[shared]"]
119+
else:
120+
self.state_names = [
121+
f"{name}[{state_name}]"
122+
for state_name in self.observed_state_names
123+
for name in base_names
124+
]
125+
self.shock_names = [
126+
f"{self.name}[{obs_name}]" for obs_name in self.observed_state_names
127+
]
108128

109-
self.shock_names = [f"{self.name}[{obs_name}]" for obs_name in self.observed_state_names]
110129
self.param_names = [f"params_{self.name}", f"sigma_{self.name}"]
111130
self.param_dims = {f"params_{self.name}": (f"lag_{self.name}",)}
112131
self.coords = {f"lag_{self.name}": self.ar_lags.tolist()}
113132

114-
if self.k_endog > 1:
133+
if k_endog_effective > 1:
115134
self.param_dims[f"params_{self.name}"] = (
116135
f"endog_{self.name}",
117136
f"lag_{self.name}",
@@ -140,26 +159,29 @@ def populate_component_properties(self):
140159

141160
def make_symbolic_graph(self) -> None:
142161
k_endog = self.k_endog
143-
k_states = self.k_states // k_endog
162+
k_endog_effective = 1 if self.share_states else k_endog
163+
164+
k_states = self.k_states // k_endog_effective
144165
k_posdef = self.k_posdef
145166

146167
k_nonzero = int(sum(self.order))
147168
ar_params = self.make_and_register_variable(
148-
f"params_{self.name}", shape=(k_nonzero,) if k_endog == 1 else (k_endog, k_nonzero)
169+
f"params_{self.name}",
170+
shape=(k_nonzero,) if k_endog_effective == 1 else (k_endog_effective, k_nonzero),
149171
)
150172
sigma_ar = self.make_and_register_variable(
151-
f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
173+
f"sigma_{self.name}", shape=() if k_endog_effective == 1 else (k_endog_effective,)
152174
)
153175

154-
if k_endog == 1:
176+
if k_endog_effective == 1:
155177
T = pt.eye(k_states, k=-1)
156178
ar_idx = (np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0])
157179
T = T[ar_idx].set(ar_params)
158180

159181
else:
160182
transition_matrices = []
161183

162-
for i in range(k_endog):
184+
for i in range(k_endog_effective):
163185
T = pt.eye(k_states, k=-1)
164186
ar_idx = (np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0])
165187
T = T[ar_idx].set(ar_params[i])
@@ -171,18 +193,21 @@ def make_symbolic_graph(self) -> None:
171193
self.ssm["transition", :, :] = T
172194

173195
R = np.eye(k_states)
174-
R_mask = np.full((k_states), False)
196+
R_mask = np.full((k_states,), False)
175197
R_mask[0] = True
176198
R = R[:, R_mask]
177199

178200
self.ssm["selection", :, :] = pt.specify_shape(
179-
pt.linalg.block_diag(*[R for _ in range(k_endog)]), (self.k_states, self.k_posdef)
201+
pt.linalg.block_diag(*[R for _ in range(k_endog_effective)]), (self.k_states, k_posdef)
180202
)
181203

182-
Z = pt.zeros((1, k_states))[0, 0].set(1.0)
183-
self.ssm["design", :, :] = pt.specify_shape(
184-
pt.linalg.block_diag(*[Z for _ in range(k_endog)]), (self.k_endog, self.k_states)
185-
)
204+
Zs = [pt.zeros((1, k_states))[0, 0].set(1.0) for _ in range(k_endog)]
205+
206+
if self.share_states:
207+
Z = pt.join(0, *Zs)
208+
else:
209+
Z = pt.linalg.block_diag(*Zs)
210+
self.ssm["design", :, :] = pt.specify_shape(Z, (k_endog, self.k_states))
186211

187212
cov_idx = ("state_cov", *np.diag_indices(k_posdef))
188213
self.ssm[cov_idx] = sigma_ar**2

pymc_extras/statespace/models/structural/components/cycle.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ class CycleComponent(Component):
4343
Names of the observed state variables. For univariate time series, defaults to ``["data"]``.
4444
For multivariate time series, specify a list of names for each endogenous variable.
4545
46+
share_states: bool, default False
47+
Whether latent states are shared across the observed states. If True, there will be only one set of latent
48+
states, which are observed by all observed states. If False, each observed state has its own set of
49+
latent states. This argument has no effect if `k_endog` is 1.
50+
4651
Notes
4752
-----
4853
The cycle component is very similar in implementation to the frequency domain seasonal component, expect that it
@@ -155,6 +160,7 @@ def __init__(
155160
dampen: bool = False,
156161
innovations: bool = True,
157162
observed_state_names: list[str] | None = None,
163+
share_states: bool = False,
158164
):
159165
if observed_state_names is None:
160166
observed_state_names = ["data"]
@@ -167,6 +173,7 @@ def __init__(
167173
cycle = int(cycle_length) if cycle_length is not None else "Estimate"
168174
name = f"Cycle[s={cycle}, dampen={dampen}, innovations={innovations}]"
169175

176+
self.share_states = share_states
170177
self.estimate_cycle_length = estimate_cycle_length
171178
self.cycle_length = cycle_length
172179
self.innovations = innovations
@@ -175,8 +182,8 @@ def __init__(
175182

176183
k_endog = len(observed_state_names)
177184

178-
k_states = 2 * k_endog
179-
k_posdef = 2 * k_endog
185+
k_states = 2 if share_states else 2 * k_endog
186+
k_posdef = 2 if share_states else 2 * k_endog
180187

181188
obs_state_idx = np.zeros(k_states)
182189
obs_state_idx[slice(0, k_states, 2)] = 1
@@ -190,21 +197,26 @@ def __init__(
190197
combine_hidden_states=True,
191198
obs_state_idxs=obs_state_idx,
192199
observed_state_names=observed_state_names,
200+
share_states=share_states,
193201
)
194202

195203
def make_symbolic_graph(self) -> None:
204+
k_endog = self.k_endog
205+
k_endog_effective = 1 if self.share_states else k_endog
206+
196207
Z = np.array([1.0, 0.0]).reshape((1, -1))
197-
design_matrix = block_diag(*[Z for _ in range(self.k_endog)])
208+
design_matrix = block_diag(*[Z for _ in range(k_endog_effective)])
198209
self.ssm["design", :, :] = pt.as_tensor_variable(design_matrix)
199210

200211
# selection matrix R defines structure of innovations (always identity for cycle components)
201212
# when innovations=False, state cov Q=0, hence R @ Q @ R.T = 0
202213
R = np.eye(2) # 2x2 identity for each cycle component
203-
selection_matrix = block_diag(*[R for _ in range(self.k_endog)])
214+
selection_matrix = block_diag(*[R for _ in range(k_endog_effective)])
204215
self.ssm["selection", :, :] = pt.as_tensor_variable(selection_matrix)
205216

206217
init_state = self.make_and_register_variable(
207-
f"params_{self.name}", shape=(self.k_endog, 2) if self.k_endog > 1 else (self.k_states,)
218+
f"params_{self.name}",
219+
shape=(k_endog_effective, 2) if k_endog_effective > 1 else (self.k_states,),
208220
)
209221
self.ssm["initial_state", :] = init_state.ravel()
210222

@@ -219,37 +231,45 @@ def make_symbolic_graph(self) -> None:
219231
rho = 1
220232

221233
T = rho * _frequency_transition_block(lamb, j=1)
222-
transition = block_diag(*[T for _ in range(self.k_endog)])
234+
transition = block_diag(*[T for _ in range(k_endog_effective)])
223235
self.ssm["transition"] = pt.specify_shape(transition, (self.k_states, self.k_states))
224236

225237
if self.innovations:
226-
if self.k_endog == 1:
238+
if k_endog_effective == 1:
227239
sigma_cycle = self.make_and_register_variable(f"sigma_{self.name}", shape=())
228240
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_cycle**2
229241
else:
230242
sigma_cycle = self.make_and_register_variable(
231-
f"sigma_{self.name}", shape=(self.k_endog,)
243+
f"sigma_{self.name}", shape=(k_endog_effective,)
232244
)
233245
state_cov = block_diag(
234-
*[pt.eye(2) * sigma_cycle[i] ** 2 for i in range(self.k_endog)]
246+
*[pt.eye(2) * sigma_cycle[i] ** 2 for i in range(k_endog_effective)]
235247
)
236248
self.ssm["state_cov"] = pt.specify_shape(state_cov, (self.k_states, self.k_states))
237249
else:
238250
# explicitly set state cov to 0 when no innovations
239251
self.ssm["state_cov", :, :] = pt.zeros((self.k_posdef, self.k_posdef))
240252

241253
def populate_component_properties(self):
242-
self.state_names = [
243-
f"{f}_{self.name}[{var_name}]" if self.k_endog > 1 else f"{f}_{self.name}"
244-
for var_name in self.observed_state_names
245-
for f in ["Cos", "Sin"]
246-
]
254+
k_endog = self.k_endog
255+
k_endog_effective = 1 if self.share_states else k_endog
256+
257+
base_names = [f"{f}_{self.name}" for f in ["Cos", "Sin"]]
258+
259+
if self.share_states:
260+
self.state_names = [f"{name}[shared]" for name in base_names]
261+
else:
262+
self.state_names = [
263+
f"{name}[{var_name}]" if k_endog_effective > 1 else name
264+
for var_name in self.observed_state_names
265+
for name in base_names
266+
]
247267

248268
self.param_names = [f"params_{self.name}"]
249269

250-
if self.k_endog == 1:
270+
if k_endog_effective == 1:
251271
self.param_dims = {f"params_{self.name}": (f"state_{self.name}",)}
252-
self.coords = {f"state_{self.name}": self.state_names}
272+
self.coords = {f"state_{self.name}": base_names}
253273
self.param_info = {
254274
f"params_{self.name}": {
255275
"shape": (2,),
@@ -265,7 +285,7 @@ def populate_component_properties(self):
265285
}
266286
self.param_info = {
267287
f"params_{self.name}": {
268-
"shape": (self.k_endog, 2),
288+
"shape": (k_endog_effective, 2),
269289
"constraints": None,
270290
"dims": (f"endog_{self.name}", f"state_{self.name}"),
271291
}
@@ -274,22 +294,22 @@ def populate_component_properties(self):
274294
if self.estimate_cycle_length:
275295
self.param_names += [f"length_{self.name}"]
276296
self.param_info[f"length_{self.name}"] = {
277-
"shape": () if self.k_endog == 1 else (self.k_endog,),
297+
"shape": () if k_endog_effective == 1 else (k_endog_effective,),
278298
"constraints": "Positive, non-zero",
279-
"dims": None if self.k_endog == 1 else (f"endog_{self.name}",),
299+
"dims": None if k_endog_effective == 1 else (f"endog_{self.name}",),
280300
}
281301

282302
if self.dampen:
283303
self.param_names += [f"dampening_factor_{self.name}"]
284304
self.param_info[f"dampening_factor_{self.name}"] = {
285-
"shape": () if self.k_endog == 1 else (self.k_endog,),
305+
"shape": () if k_endog_effective == 1 else (k_endog_effective,),
286306
"constraints": "0 < x ≤ 1",
287-
"dims": None if self.k_endog == 1 else (f"endog_{self.name}",),
307+
"dims": None if k_endog_effective == 1 else (f"endog_{self.name}",),
288308
}
289309

290310
if self.innovations:
291311
self.param_names += [f"sigma_{self.name}"]
292-
if self.k_endog == 1:
312+
if k_endog_effective == 1:
293313
self.param_info[f"sigma_{self.name}"] = {
294314
"shape": (),
295315
"constraints": "Positive",
@@ -298,7 +318,7 @@ def populate_component_properties(self):
298318
else:
299319
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
300320
self.param_info[f"sigma_{self.name}"] = {
301-
"shape": (self.k_endog,),
321+
"shape": (k_endog_effective,),
302322
"constraints": "Positive",
303323
"dims": (f"endog_{self.name}",),
304324
}

0 commit comments

Comments
 (0)