@@ -44,6 +44,11 @@ class TimeSeasonality(Component):
44
44
observed_state_names: list[str] | None, default None
45
45
List of strings for observed state labels. If None, defaults to ["data"].
46
46
47
+ share_states: bool, default False
48
+ Whether latent states are shared across the observed states. If True, there will be only one set of latent
49
+ states, which are observed by all observed states. If False, each observed state has its own set of
50
+ latent states. This argument has no effect if `k_endog` is 1.
51
+
47
52
Notes
48
53
-----
49
54
A seasonal effect is any pattern that repeats at fixed intervals. There are several ways to model such effects;
@@ -235,6 +240,7 @@ def __init__(
235
240
state_names : list | None = None ,
236
241
remove_first_state : bool = True ,
237
242
observed_state_names : list [str ] | None = None ,
243
+ share_states : bool = False ,
238
244
):
239
245
if observed_state_names is None :
240
246
observed_state_names = ["data" ]
@@ -261,6 +267,7 @@ def __init__(
261
267
)
262
268
state_names = state_names .copy ()
263
269
270
+ self .share_states = share_states
264
271
self .innovations = innovations
265
272
self .duration = duration
266
273
self .remove_first_state = remove_first_state
@@ -281,44 +288,53 @@ def __init__(
281
288
super ().__init__ (
282
289
name = name ,
283
290
k_endog = k_endog ,
284
- k_states = k_states * k_endog ,
285
- k_posdef = k_posdef * k_endog ,
291
+ k_states = k_states if share_states else k_states * k_endog ,
292
+ k_posdef = k_posdef if share_states else k_posdef * k_endog ,
286
293
observed_state_names = observed_state_names ,
287
294
measurement_error = False ,
288
295
combine_hidden_states = True ,
289
- obs_state_idxs = np .tile (np .array ([1.0 ] + [0.0 ] * (k_states - 1 )), k_endog ),
296
+ obs_state_idxs = np .tile (
297
+ np .array ([1.0 ] + [0.0 ] * (k_states - 1 )), 1 if share_states else k_endog
298
+ ),
290
299
)
291
300
292
301
def populate_component_properties (self ):
293
- k_states = self .k_states // self .k_endog
294
302
k_endog = self .k_endog
303
+ k_endog_effective = 1 if self .share_states else k_endog
295
304
296
- self .state_names = [
297
- f"{ state_name } [{ endog_name } ]"
298
- for endog_name in self .observed_state_names
299
- for state_name in self .provided_state_names
300
- ]
305
+ k_states = self .k_states // k_endog_effective
306
+
307
+ if self .share_states :
308
+ self .state_names = [
309
+ f"{ state_name } [{ self .name } _shared]" for state_name in self .provided_state_names
310
+ ]
311
+ else :
312
+ self .state_names = [
313
+ f"{ state_name } [{ endog_name } ]"
314
+ for endog_name in self .observed_state_names
315
+ for state_name in self .provided_state_names
316
+ ]
301
317
self .param_names = [f"coefs_{ self .name } " ]
302
318
303
319
self .param_info = {
304
320
f"coefs_{ self .name } " : {
305
- "shape" : (k_states ,) if k_endog == 1 else (k_endog , k_states ),
321
+ "shape" : (k_states ,) if k_endog_effective == 1 else (k_endog_effective , k_states ),
306
322
"constraints" : None ,
307
323
"dims" : (f"state_{ self .name } " ,)
308
- if k_endog == 1
324
+ if k_endog_effective == 1
309
325
else (f"endog_{ self .name } " , f"state_{ self .name } " ),
310
326
}
311
327
}
312
328
313
329
self .param_dims = {
314
330
f"coefs_{ self .name } " : (f"state_{ self .name } " ,)
315
- if k_endog == 1
331
+ if k_endog_effective == 1
316
332
else (f"endog_{ self .name } " , f"state_{ self .name } " )
317
333
}
318
334
319
335
self .coords = (
320
336
{f"state_{ self .name } " : self .provided_state_names }
321
- if k_endog == 1
337
+ if k_endog_effective == 1
322
338
else {
323
339
f"endog_{ self .name } " : self .observed_state_names ,
324
340
f"state_{ self .name } " : self .provided_state_names ,
@@ -332,14 +348,19 @@ def populate_component_properties(self):
332
348
"constraints" : "Positive" ,
333
349
"dims" : None ,
334
350
}
335
- self .shock_names = [f"{ self .name } [{ name } ]" for name in self .observed_state_names ]
351
+ if self .share_states :
352
+ self .shock_names = [f"{ self .name } [shared]" ]
353
+ else :
354
+ self .shock_names = [f"{ self .name } [{ name } ]" for name in self .observed_state_names ]
336
355
337
356
def make_symbolic_graph (self ) -> None :
338
- k_states = self .k_states // self .k_endog
357
+ k_endog = self .k_endog
358
+ k_endog_effective = 1 if self .share_states else k_endog
359
+ k_states = self .k_states // k_endog_effective
339
360
duration = self .duration
361
+
340
362
k_unique_states = k_states // duration
341
- k_posdef = self .k_posdef // self .k_endog
342
- k_endog = self .k_endog
363
+ k_posdef = self .k_posdef // k_endog_effective
343
364
344
365
if self .remove_first_state :
345
366
# In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
@@ -371,16 +392,18 @@ def make_symbolic_graph(self) -> None:
371
392
T = pt .eye (k_states , k = 1 )
372
393
T = pt .set_subtensor (T [- 1 , 0 ], 1 )
373
394
374
- self .ssm ["transition" , :, :] = pt .linalg .block_diag (* [T for _ in range (k_endog )])
395
+ self .ssm ["transition" , :, :] = pt .linalg .block_diag (* [T for _ in range (k_endog_effective )])
375
396
376
397
Z = pt .zeros ((1 , k_states ))[0 , 0 ].set (1 )
377
- self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog )])
398
+ self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog_effective )])
378
399
379
400
initial_states = self .make_and_register_variable (
380
401
f"coefs_{ self .name } " ,
381
- shape = (k_unique_states ,) if k_endog == 1 else (k_endog , k_unique_states ),
402
+ shape = (k_unique_states ,)
403
+ if k_endog_effective == 1
404
+ else (k_endog_effective , k_unique_states ),
382
405
)
383
- if k_endog == 1 :
406
+ if k_endog_effective == 1 :
384
407
self .ssm ["initial_state" , :] = pt .extra_ops .repeat (initial_states , duration , axis = 0 )
385
408
else :
386
409
self .ssm ["initial_state" , :] = pt .extra_ops .repeat (
@@ -389,11 +412,11 @@ def make_symbolic_graph(self) -> None:
389
412
390
413
if self .innovations :
391
414
R = pt .zeros ((k_states , k_posdef ))[0 , 0 ].set (1.0 )
392
- self .ssm ["selection" , :, :] = pt .join (0 , * [R for _ in range (k_endog )])
415
+ self .ssm ["selection" , :, :] = pt .join (0 , * [R for _ in range (k_endog_effective )])
393
416
season_sigma = self .make_and_register_variable (
394
- f"sigma_{ self .name } " , shape = () if k_endog == 1 else (k_endog ,)
417
+ f"sigma_{ self .name } " , shape = () if k_endog_effective == 1 else (k_endog_effective ,)
395
418
)
396
- cov_idx = ("state_cov" , * np .diag_indices (k_posdef * k_endog ))
419
+ cov_idx = ("state_cov" , * np .diag_indices (k_posdef * k_endog_effective ))
397
420
self .ssm [cov_idx ] = season_sigma ** 2
398
421
399
422
0 commit comments