Skip to content

Commit ac19376

Browse files
committed
Remove fixed values (can be got from the TS)
1 parent 4bf5db8 commit ac19376

File tree

1 file changed

+15
-37
lines changed

1 file changed

+15
-37
lines changed

tests/test_functions.py

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from tsdate.date import (SpansBySamples, PriorParams, LIN, LOG,
3737
ConditionalCoalescentTimes, fill_prior, Likelihoods,
3838
LogLikelihoods, LogLikelihoodsStreaming, InOutAlgorithms,
39-
Prior, gamma_approx, constrain_ages_topo) # NOQA
39+
NodeGridValues, gamma_approx, constrain_ages_topo) # NOQA
4040

4141
from tests import utility_functions
4242

@@ -690,32 +690,24 @@ def test_logsumexp_streaming(self):
690690
np.log(ll_sum)))
691691

692692

693-
class TestPriorClass(unittest.TestCase):
693+
class TestNodeGridValuesClass(unittest.TestCase):
694694
def test_init(self):
695-
nodetimes = np.ones(5)
696695
nonfixed_ids = np.array([3, 2])
697696
timepoints = np.array(range(10))
698-
store = Prior(
699-
timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids, fill_value=6)
697+
store = NodeGridValues(timepoints, gridnodes=nonfixed_ids, fill_value=6)
700698
self.assertEquals(store.grid_data.shape, (len(nonfixed_ids), len(timepoints)))
701-
self.assertEquals(len(store.fixed_times), (len(nodetimes)-len(nonfixed_ids)))
702699
self.assertTrue(np.all(store.grid_data == 6))
703-
self.assertTrue(np.all(store.fixed_times == 1))
704-
for i in range(len(nodetimes)):
700+
for i in range(np.max(nonfixed_ids)+1):
705701
if i in nonfixed_ids:
706702
self.assertTrue(np.all(store[i] == 6))
707-
self.assertRaises(IndexError, store.fixed_time, i)
708703
else:
709-
self.assertEqual(store.fixed_time(i), 1)
710704
with self.assertRaises(IndexError):
711705
_ = store[i]
712706

713707
def test_probability_spaces(self):
714-
nodetimes = np.ones(5)
715708
nonfixed_ids = np.array([3, 4])
716709
timepoints = np.array(range(10))
717-
store = Prior(
718-
timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids, fill_value=0.5)
710+
store = NodeGridValues(timepoints, gridnodes=nonfixed_ids, fill_value=0.5)
719711
self.assertTrue(np.all(store.grid_data == 0.5))
720712
store.force_probability_space(LIN)
721713
self.assertTrue(np.all(store.grid_data == 0.5))
@@ -728,13 +720,12 @@ def test_probability_spaces(self):
728720
self.assertRaises(ValueError, store.force_probability_space, "foobar")
729721

730722
def test_set_and_get(self):
731-
nodetimes = np.ones(5)
732723
timepoints = [0, 1.1]
733724
fill = {}
734725
for nonfixed_ids in ([3, 4], [0]):
735726
np.random.seed(1)
736-
store = Prior(timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids)
737-
for i in range(len(nodetimes)):
727+
store = NodeGridValues(timepoints, gridnodes=nonfixed_ids)
728+
for i in range(5):
738729
fill[i] = np.random.random(len(store.timepoints))
739730
if i in nonfixed_ids:
740731
store[i] = fill[i]
@@ -746,33 +737,21 @@ def test_set_and_get(self):
746737

747738
def test_bad_init(self):
748739
timepoints = [0, 1.2, 2]
749-
nodetimes = np.ones(5)
750740
nonfixed_ids = [4, 0]
751-
Prior(timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids)
752-
# ids > nodetimes
753-
self.assertRaises(
754-
ValueError, Prior, timepoints, nodetimes=nodetimes, gridnodes=[4, 5])
741+
NodeGridValues(timepoints, gridnodes=nonfixed_ids)
755742
# duplicate ids
756-
self.assertRaises(
757-
ValueError, Prior, timepoints, nodetimes=nodetimes, gridnodes=[4, 4, 0])
743+
self.assertRaises(ValueError, NodeGridValues, timepoints, gridnodes=[4, 4, 0])
758744
# bad ids
759745
self.assertRaises(
760-
ValueError, Prior, timepoints, nodetimes=nodetimes,
761-
gridnodes=np.array([[1, 4], [2, 0]]))
762-
self.assertRaises(
763-
OverflowError, Prior, timepoints, nodetimes=nodetimes, gridnodes=[-1, 4])
746+
ValueError, NodeGridValues, timepoints, gridnodes=np.array([[1, 4], [2, 0]]))
747+
self.assertRaises(OverflowError, NodeGridValues, timepoints, gridnodes=[-1, 4])
764748
# bad timepoint
765-
self.assertRaises(
766-
ValueError, Prior, [], nodetimes=nodetimes, gridnodes=nonfixed_ids)
767-
# bad nodetimes
768-
self.assertRaises(
769-
ValueError, Prior, timepoints, nodetimes=[], gridnodes=nonfixed_ids)
749+
self.assertRaises(ValueError, NodeGridValues, [], gridnodes=nonfixed_ids)
770750

771751
def test_clone(self):
772752
timepoints = [0, 1]
773-
nodetimes = np.ones(5)
774753
nonfixed_ids = [3, 4]
775-
orig = Prior(timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids)
754+
orig = NodeGridValues(timepoints, gridnodes=nonfixed_ids)
776755
orig[3] = np.array([1, 2])
777756
orig[4] = np.array([4, 3])
778757
# test with np.zeros
@@ -785,17 +764,16 @@ def test_clone(self):
785764
self.assertTrue(np.all(clone.grid_data == 5))
786765

787766
clone = orig.clone_grid_with_new_data(np.array([[1, 2], [4, 3]]))
788-
for i in range(len(nodetimes)):
767+
for i in range(np.max(nonfixed_ids)+1):
789768
if i in nonfixed_ids:
790769
self.assertTrue(np.all(clone[i] == orig[i]))
791770
else:
792771
self.assertRaises(IndexError, clone.__getitem__, i)
793772

794773
def test_bad_clone(self):
795-
nodetimes = np.zeros(10)
796774
ids = np.array([3, 4])
797775
timepoints = np.array([0, 1.2])
798-
orig = Prior(timepoints, nodetimes=nodetimes, gridnodes=ids)
776+
orig = NodeGridValues(timepoints, gridnodes=ids)
799777
self.assertRaises(
800778
ValueError, orig.clone_grid_with_new_data, np.array([[1, 2, 3], [4, 5, 6]]))
801779

0 commit comments

Comments
 (0)