Skip to content

Commit 3ec8a7a

Browse files
committed
Merge branch 'shared-multivariate' of https://github.com/jessegrabowski/pymc-extras into shared-multivariate
pull remote updated
2 parents c553ef4 + f6bf481 commit 3ec8a7a

File tree

7 files changed

+598
-159
lines changed

7 files changed

+598
-159
lines changed

pymc_extras/statespace/models/structural/components/cycle.py

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ class CycleComponent(Component):
4343
Names of the observed state variables. For univariate time series, defaults to ``["data"]``.
4444
For multivariate time series, specify a list of names for each endogenous variable.
4545
46+
share_states: bool, default False
47+
Whether latent states are shared across the observed states. If True, there will be only one set of latent
48+
states, which are observed by all observed states. If False, each observed state has its own set of
49+
latent states. This argument has no effect if `k_endog` is 1.
50+
4651
Notes
4752
-----
4853
The cycle component is very similar in implementation to the frequency domain seasonal component, expect that it
@@ -155,6 +160,7 @@ def __init__(
155160
dampen: bool = False,
156161
innovations: bool = True,
157162
observed_state_names: list[str] | None = None,
163+
share_states: bool = False,
158164
):
159165
if observed_state_names is None:
160166
observed_state_names = ["data"]
@@ -167,6 +173,7 @@ def __init__(
167173
cycle = int(cycle_length) if cycle_length is not None else "Estimate"
168174
name = f"Cycle[s={cycle}, dampen={dampen}, innovations={innovations}]"
169175

176+
self.share_states = share_states
170177
self.estimate_cycle_length = estimate_cycle_length
171178
self.cycle_length = cycle_length
172179
self.innovations = innovations
@@ -175,8 +182,8 @@ def __init__(
175182

176183
k_endog = len(observed_state_names)
177184

178-
k_states = 2 * k_endog
179-
k_posdef = 2 * k_endog
185+
k_states = 2 if share_states else 2 * k_endog
186+
k_posdef = 2 if share_states else 2 * k_endog
180187

181188
obs_state_idx = np.zeros(k_states)
182189
obs_state_idx[slice(0, k_states, 2)] = 1
@@ -193,18 +200,22 @@ def __init__(
193200
)
194201

195202
def make_symbolic_graph(self) -> None:
203+
k_endog = self.k_endog
204+
k_endog_effective = 1 if self.share_states else k_endog
205+
196206
Z = np.array([1.0, 0.0]).reshape((1, -1))
197-
design_matrix = block_diag(*[Z for _ in range(self.k_endog)])
207+
design_matrix = block_diag(*[Z for _ in range(k_endog_effective)])
198208
self.ssm["design", :, :] = pt.as_tensor_variable(design_matrix)
199209

200210
# selection matrix R defines structure of innovations (always identity for cycle components)
201211
# when innovations=False, state cov Q=0, hence R @ Q @ R.T = 0
202212
R = np.eye(2) # 2x2 identity for each cycle component
203-
selection_matrix = block_diag(*[R for _ in range(self.k_endog)])
213+
selection_matrix = block_diag(*[R for _ in range(k_endog_effective)])
204214
self.ssm["selection", :, :] = pt.as_tensor_variable(selection_matrix)
205215

206216
init_state = self.make_and_register_variable(
207-
f"{self.name}", shape=(self.k_endog, 2) if self.k_endog > 1 else (self.k_states,)
217+
f"{self.name}",
218+
shape=(k_endog_effective, 2) if k_endog_effective > 1 else (self.k_states,),
208219
)
209220
self.ssm["initial_state", :] = init_state.ravel()
210221

@@ -219,37 +230,45 @@ def make_symbolic_graph(self) -> None:
219230
rho = 1
220231

221232
T = rho * _frequency_transition_block(lamb, j=1)
222-
transition = block_diag(*[T for _ in range(self.k_endog)])
233+
transition = block_diag(*[T for _ in range(k_endog_effective)])
223234
self.ssm["transition"] = pt.specify_shape(transition, (self.k_states, self.k_states))
224235

225236
if self.innovations:
226-
if self.k_endog == 1:
237+
if k_endog_effective == 1:
227238
sigma_cycle = self.make_and_register_variable(f"sigma_{self.name}", shape=())
228239
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_cycle**2
229240
else:
230241
sigma_cycle = self.make_and_register_variable(
231-
f"sigma_{self.name}", shape=(self.k_endog,)
242+
f"sigma_{self.name}", shape=(k_endog_effective,)
232243
)
233244
state_cov = block_diag(
234-
*[pt.eye(2) * sigma_cycle[i] ** 2 for i in range(self.k_endog)]
245+
*[pt.eye(2) * sigma_cycle[i] ** 2 for i in range(k_endog_effective)]
235246
)
236247
self.ssm["state_cov"] = pt.specify_shape(state_cov, (self.k_states, self.k_states))
237248
else:
238249
# explicitly set state cov to 0 when no innovations
239250
self.ssm["state_cov", :, :] = pt.zeros((self.k_posdef, self.k_posdef))
240251

241252
def populate_component_properties(self):
242-
self.state_names = [
243-
f"{f}_{self.name}[{var_name}]" if self.k_endog > 1 else f"{f}_{self.name}"
244-
for var_name in self.observed_state_names
245-
for f in ["Cos", "Sin"]
246-
]
253+
k_endog = self.k_endog
254+
k_endog_effective = 1 if self.share_states else k_endog
255+
256+
base_names = [f"{f}_{self.name}" for f in ["Cos", "Sin"]]
257+
258+
if self.share_states:
259+
self.state_names = [f"{name}[shared]" for name in base_names]
260+
else:
261+
self.state_names = [
262+
f"{name}[{var_name}]" if k_endog_effective > 1 else name
263+
for var_name in self.observed_state_names
264+
for name in base_names
265+
]
247266

248267
self.param_names = [f"{self.name}"]
249268

250-
if self.k_endog == 1:
269+
if k_endog_effective == 1:
251270
self.param_dims = {self.name: (f"state_{self.name}",)}
252-
self.coords = {f"state_{self.name}": self.state_names}
271+
self.coords = {f"state_{self.name}": base_names}
253272
self.param_info = {
254273
f"{self.name}": {
255274
"shape": (2,),
@@ -265,7 +284,7 @@ def populate_component_properties(self):
265284
}
266285
self.param_info = {
267286
f"{self.name}": {
268-
"shape": (self.k_endog, 2),
287+
"shape": (k_endog_effective, 2),
269288
"constraints": None,
270289
"dims": (f"endog_{self.name}", f"state_{self.name}"),
271290
}
@@ -274,22 +293,22 @@ def populate_component_properties(self):
274293
if self.estimate_cycle_length:
275294
self.param_names += [f"length_{self.name}"]
276295
self.param_info[f"length_{self.name}"] = {
277-
"shape": () if self.k_endog == 1 else (self.k_endog,),
296+
"shape": () if k_endog_effective == 1 else (k_endog_effective,),
278297
"constraints": "Positive, non-zero",
279-
"dims": None if self.k_endog == 1 else f"endog_{self.name}",
298+
"dims": None if k_endog_effective == 1 else f"endog_{self.name}",
280299
}
281300

282301
if self.dampen:
283302
self.param_names += [f"dampening_factor_{self.name}"]
284303
self.param_info[f"dampening_factor_{self.name}"] = {
285-
"shape": () if self.k_endog == 1 else (self.k_endog,),
304+
"shape": () if k_endog_effective == 1 else (k_endog_effective,),
286305
"constraints": "0 < x ≤ 1",
287-
"dims": None if self.k_endog == 1 else (f"endog_{self.name}",),
306+
"dims": None if k_endog_effective == 1 else (f"endog_{self.name}",),
288307
}
289308

290309
if self.innovations:
291310
self.param_names += [f"sigma_{self.name}"]
292-
if self.k_endog == 1:
311+
if k_endog_effective == 1:
293312
self.param_info[f"sigma_{self.name}"] = {
294313
"shape": (),
295314
"constraints": "Positive",
@@ -298,7 +317,7 @@ def populate_component_properties(self):
298317
else:
299318
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
300319
self.param_info[f"sigma_{self.name}"] = {
301-
"shape": (self.k_endog,),
320+
"shape": (k_endog_effective,),
302321
"constraints": "Positive",
303322
"dims": (f"endog_{self.name}",),
304323
}

0 commit comments

Comments
 (0)