@@ -449,6 +449,11 @@ class FrequencySeasonality(Component):
449
449
observed_state_names: list[str] | None, default None
450
450
List of strings for observed state labels. If None, defaults to ["data"].
451
451
452
+ share_states: bool, default False
453
+ Whether latent states are shared across the observed states. If True, there will be only one set of latent
454
+ states, which are observed by all observed states. If False, each observed state has its own set of
455
+ latent states. This argument has no effect if `k_endog` is 1.
456
+
452
457
Notes
453
458
-----
454
459
A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
@@ -480,15 +485,17 @@ class FrequencySeasonality(Component):
480
485
481
486
def __init__ (
482
487
self ,
483
- season_length ,
484
- n = None ,
485
- name = None ,
486
- innovations = True ,
488
+ season_length : int ,
489
+ n : int | None = None ,
490
+ name : str | None = None ,
491
+ innovations : bool = True ,
487
492
observed_state_names : list [str ] | None = None ,
493
+ share_states : bool = False ,
488
494
):
489
495
if observed_state_names is None :
490
496
observed_state_names = ["data" ]
491
497
498
+ self .share_states = share_states
492
499
k_endog = len (observed_state_names )
493
500
494
501
if n is None :
@@ -504,18 +511,20 @@ def __init__(
504
511
# If the model is completely saturated (n = s // 2), the last state will not be identified, so it shouldn't
505
512
# get a parameter assigned to it and should just be fixed to zero.
506
513
# Test this way (rather than n == s // 2) to catch cases when n is non-integer.
507
- self .last_state_not_identified = self .season_length / self .n == 2.0
514
+ self .last_state_not_identified = ( self .season_length / self .n ) == 2.0
508
515
self .n_coefs = k_states - int (self .last_state_not_identified )
509
516
510
517
obs_state_idx = np .zeros (k_states )
511
518
obs_state_idx [slice (0 , k_states , 2 )] = 1
512
- obs_state_idx = np .tile (obs_state_idx , k_endog )
519
+ obs_state_idx = np .tile (obs_state_idx , 1 if share_states else k_endog )
513
520
514
521
super ().__init__ (
515
522
name = name ,
516
523
k_endog = k_endog ,
517
- k_states = k_states * k_endog ,
518
- k_posdef = k_states * int (self .innovations ) * k_endog ,
524
+ k_states = k_states if share_states else k_states * k_endog ,
525
+ k_posdef = k_states * int (self .innovations )
526
+ if share_states
527
+ else k_states * int (self .innovations ) * k_endog ,
519
528
observed_state_names = observed_state_names ,
520
529
measurement_error = False ,
521
530
combine_hidden_states = True ,
@@ -524,13 +533,15 @@ def __init__(
524
533
525
534
def make_symbolic_graph (self ) -> None :
526
535
k_endog = self .k_endog
527
- k_states = self .k_states // k_endog
528
- k_posdef = self .k_posdef // k_endog
536
+ k_endog_effective = 1 if self .share_states else k_endog
537
+
538
+ k_states = self .k_states // k_endog_effective
539
+ k_posdef = self .k_posdef // k_endog_effective
529
540
n_coefs = self .n_coefs
530
541
531
542
Z = pt .zeros ((1 , k_states ))[0 , slice (0 , k_states , 2 )].set (1.0 )
532
543
533
- self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog )])
544
+ self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog_effective )])
534
545
535
546
init_state = self .make_and_register_variable (
536
547
f"params_{ self .name } " , shape = (n_coefs ,) if k_endog == 1 else (k_endog , n_coefs )
@@ -539,7 +550,7 @@ def make_symbolic_graph(self) -> None:
539
550
init_state_idx = np .concatenate (
540
551
[
541
552
np .arange (k_states * i , (i + 1 ) * k_states , dtype = int )[:n_coefs ]
542
- for i in range (k_endog )
553
+ for i in range (k_endog_effective )
543
554
],
544
555
axis = 0 ,
545
556
)
@@ -548,11 +559,11 @@ def make_symbolic_graph(self) -> None:
548
559
549
560
T_mats = [_frequency_transition_block (self .season_length , j + 1 ) for j in range (self .n )]
550
561
T = pt .linalg .block_diag (* T_mats )
551
- self .ssm ["transition" , :, :] = pt .linalg .block_diag (* [T for _ in range (k_endog )])
562
+ self .ssm ["transition" , :, :] = pt .linalg .block_diag (* [T for _ in range (k_endog_effective )])
552
563
553
564
if self .innovations :
554
565
sigma_season = self .make_and_register_variable (
555
- f"sigma_{ self .name } " , shape = () if k_endog == 1 else (k_endog ,)
566
+ f"sigma_{ self .name } " , shape = () if k_endog_effective == 1 else (k_endog_effective ,)
556
567
)
557
568
self .ssm ["selection" , :, :] = pt .eye (self .k_states )
558
569
self .ssm ["state_cov" , :, :] = pt .eye (self .k_posdef ) * pt .repeat (
@@ -561,35 +572,35 @@ def make_symbolic_graph(self) -> None:
561
572
562
573
def populate_component_properties (self ):
563
574
k_endog = self .k_endog
575
+ k_endog_effective = 1 if self .share_states else k_endog
564
576
n_coefs = self .n_coefs
565
577
566
- self .state_names = [
567
- f"{ f } _{ i } _{ self .name } [{ obs_state_name } ]"
568
- for obs_state_name in self .observed_state_names
569
- for i in range (self .n )
570
- for f in ["Cos" , "Sin" ]
571
- ]
572
- # determine which state names correspond to parameters
573
- # all endog variables use same state structure, so we just need
574
- # the first n_coefs state names (which may be less than total if saturated)
575
- param_state_names = [f"{ f } _{ i } _{ self .name } " for i in range (self .n ) for f in ["Cos" , "Sin" ]][
576
- :n_coefs
577
- ]
578
+ base_names = [f"{ f } _{ i } _{ self .name } " for i in range (self .n ) for f in ["Cos" , "Sin" ]]
578
579
579
- self .param_names = [f"params_{ self .name } " ]
580
+ if self .share_states :
581
+ self .state_names = [f"{ name } [shared]" for name in base_names ]
582
+ else :
583
+ self .state_names = [
584
+ f"{ name } [{ obs_state_name } ]"
585
+ for obs_state_name in self .observed_state_names
586
+ for name in base_names
587
+ ]
580
588
589
+ # Trim state names if the model is saturated
590
+ param_state_names = base_names [:n_coefs ]
591
+
592
+ self .param_names = [f"params_{ self .name } " ]
581
593
self .param_dims = {
582
594
f"params_{ self .name } " : (f"state_{ self .name } " ,)
583
- if k_endog == 1
595
+ if k_endog_effective == 1
584
596
else (f"endog_{ self .name } " , f"state_{ self .name } " )
585
597
}
586
-
587
598
self .param_info = {
588
599
f"params_{ self .name } " : {
589
- "shape" : (n_coefs ,) if k_endog == 1 else (k_endog , n_coefs ),
600
+ "shape" : (n_coefs ,) if k_endog_effective == 1 else (k_endog_effective , n_coefs ),
590
601
"constraints" : None ,
591
602
"dims" : (f"state_{ self .name } " ,)
592
- if k_endog == 1
603
+ if k_endog_effective == 1
593
604
else (f"endog_{ self .name } " , f"state_{ self .name } " ),
594
605
}
595
606
}
@@ -607,9 +618,9 @@ def populate_component_properties(self):
607
618
self .param_names += [f"sigma_{ self .name } " ]
608
619
self .shock_names = self .state_names .copy ()
609
620
self .param_info [f"sigma_{ self .name } " ] = {
610
- "shape" : () if k_endog == 1 else (k_endog , ),
621
+ "shape" : () if k_endog_effective == 1 else (k_endog_effective , n_coefs ),
611
622
"constraints" : "Positive" ,
612
- "dims" : None if k_endog == 1 else (f"endog_{ self .name } " ,),
623
+ "dims" : None if k_endog_effective == 1 else (f"endog_{ self .name } " ,),
613
624
}
614
- if k_endog > 1 :
625
+ if k_endog_effective > 1 :
615
626
self .param_dims [f"sigma_{ self .name } " ] = (f"endog_{ self .name } " ,)
0 commit comments