36
36
from tsdate .date import (SpansBySamples , PriorParams , LIN , LOG ,
37
37
ConditionalCoalescentTimes , fill_prior , Likelihoods ,
38
38
LogLikelihoods , LogLikelihoodsStreaming , InOutAlgorithms ,
39
- Prior , gamma_approx , constrain_ages_topo ) # NOQA
39
+ NodeGridValues , gamma_approx , constrain_ages_topo ) # NOQA
40
40
41
41
from tests import utility_functions
42
42
@@ -690,32 +690,24 @@ def test_logsumexp_streaming(self):
690
690
np .log (ll_sum )))
691
691
692
692
693
- class TestPriorClass (unittest .TestCase ):
693
+ class TestNodeGridValuesClass (unittest .TestCase ):
694
694
def test_init (self ):
695
- nodetimes = np .ones (5 )
696
695
nonfixed_ids = np .array ([3 , 2 ])
697
696
timepoints = np .array (range (10 ))
698
- store = Prior (
699
- timepoints , nodetimes = nodetimes , gridnodes = nonfixed_ids , fill_value = 6 )
697
+ store = NodeGridValues (timepoints , gridnodes = nonfixed_ids , fill_value = 6 )
700
698
self .assertEquals (store .grid_data .shape , (len (nonfixed_ids ), len (timepoints )))
701
- self .assertEquals (len (store .fixed_times ), (len (nodetimes )- len (nonfixed_ids )))
702
699
self .assertTrue (np .all (store .grid_data == 6 ))
703
- self .assertTrue (np .all (store .fixed_times == 1 ))
704
- for i in range (len (nodetimes )):
700
+ for i in range (np .max (nonfixed_ids )+ 1 ):
705
701
if i in nonfixed_ids :
706
702
self .assertTrue (np .all (store [i ] == 6 ))
707
- self .assertRaises (IndexError , store .fixed_time , i )
708
703
else :
709
- self .assertEqual (store .fixed_time (i ), 1 )
710
704
with self .assertRaises (IndexError ):
711
705
_ = store [i ]
712
706
713
707
def test_probability_spaces (self ):
714
- nodetimes = np .ones (5 )
715
708
nonfixed_ids = np .array ([3 , 4 ])
716
709
timepoints = np .array (range (10 ))
717
- store = Prior (
718
- timepoints , nodetimes = nodetimes , gridnodes = nonfixed_ids , fill_value = 0.5 )
710
+ store = NodeGridValues (timepoints , gridnodes = nonfixed_ids , fill_value = 0.5 )
719
711
self .assertTrue (np .all (store .grid_data == 0.5 ))
720
712
store .force_probability_space (LIN )
721
713
self .assertTrue (np .all (store .grid_data == 0.5 ))
@@ -728,13 +720,12 @@ def test_probability_spaces(self):
728
720
self .assertRaises (ValueError , store .force_probability_space , "foobar" )
729
721
730
722
def test_set_and_get (self ):
731
- nodetimes = np .ones (5 )
732
723
timepoints = [0 , 1.1 ]
733
724
fill = {}
734
725
for nonfixed_ids in ([3 , 4 ], [0 ]):
735
726
np .random .seed (1 )
736
- store = Prior (timepoints , nodetimes = nodetimes , gridnodes = nonfixed_ids )
737
- for i in range (len ( nodetimes ) ):
727
+ store = NodeGridValues (timepoints , gridnodes = nonfixed_ids )
728
+ for i in range (5 ):
738
729
fill [i ] = np .random .random (len (store .timepoints ))
739
730
if i in nonfixed_ids :
740
731
store [i ] = fill [i ]
@@ -746,33 +737,21 @@ def test_set_and_get(self):
746
737
747
738
def test_bad_init (self ):
748
739
timepoints = [0 , 1.2 , 2 ]
749
- nodetimes = np .ones (5 )
750
740
nonfixed_ids = [4 , 0 ]
751
- Prior (timepoints , nodetimes = nodetimes , gridnodes = nonfixed_ids )
752
- # ids > nodetimes
753
- self .assertRaises (
754
- ValueError , Prior , timepoints , nodetimes = nodetimes , gridnodes = [4 , 5 ])
741
+ NodeGridValues (timepoints , gridnodes = nonfixed_ids )
755
742
# duplicate ids
756
- self .assertRaises (
757
- ValueError , Prior , timepoints , nodetimes = nodetimes , gridnodes = [4 , 4 , 0 ])
743
+ self .assertRaises (ValueError , NodeGridValues , timepoints , gridnodes = [4 , 4 , 0 ])
758
744
# bad ids
759
745
self .assertRaises (
760
- ValueError , Prior , timepoints , nodetimes = nodetimes ,
761
- gridnodes = np .array ([[1 , 4 ], [2 , 0 ]]))
762
- self .assertRaises (
763
- OverflowError , Prior , timepoints , nodetimes = nodetimes , gridnodes = [- 1 , 4 ])
746
+ ValueError , NodeGridValues , timepoints , gridnodes = np .array ([[1 , 4 ], [2 , 0 ]]))
747
+ self .assertRaises (OverflowError , NodeGridValues , timepoints , gridnodes = [- 1 , 4 ])
764
748
# bad timepoint
765
- self .assertRaises (
766
- ValueError , Prior , [], nodetimes = nodetimes , gridnodes = nonfixed_ids )
767
- # bad nodetimes
768
- self .assertRaises (
769
- ValueError , Prior , timepoints , nodetimes = [], gridnodes = nonfixed_ids )
749
+ self .assertRaises (ValueError , NodeGridValues , [], gridnodes = nonfixed_ids )
770
750
771
751
def test_clone (self ):
772
752
timepoints = [0 , 1 ]
773
- nodetimes = np .ones (5 )
774
753
nonfixed_ids = [3 , 4 ]
775
- orig = Prior (timepoints , nodetimes = nodetimes , gridnodes = nonfixed_ids )
754
+ orig = NodeGridValues (timepoints , gridnodes = nonfixed_ids )
776
755
orig [3 ] = np .array ([1 , 2 ])
777
756
orig [4 ] = np .array ([4 , 3 ])
778
757
# test with np.zeros
@@ -785,17 +764,16 @@ def test_clone(self):
785
764
self .assertTrue (np .all (clone .grid_data == 5 ))
786
765
787
766
clone = orig .clone_grid_with_new_data (np .array ([[1 , 2 ], [4 , 3 ]]))
788
- for i in range (len ( nodetimes ) ):
767
+ for i in range (np . max ( nonfixed_ids ) + 1 ):
789
768
if i in nonfixed_ids :
790
769
self .assertTrue (np .all (clone [i ] == orig [i ]))
791
770
else :
792
771
self .assertRaises (IndexError , clone .__getitem__ , i )
793
772
794
773
def test_bad_clone (self ):
795
- nodetimes = np .zeros (10 )
796
774
ids = np .array ([3 , 4 ])
797
775
timepoints = np .array ([0 , 1.2 ])
798
- orig = Prior (timepoints , nodetimes = nodetimes , gridnodes = ids )
776
+ orig = NodeGridValues (timepoints , gridnodes = ids )
799
777
self .assertRaises (
800
778
ValueError , orig .clone_grid_with_new_data , np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]]))
801
779
0 commit comments