Skip to content

Commit 09ae207

Browse files
Speed up diploid LS tests by reducing parameter space checked
1 parent 124cb51 commit 09ae207

File tree

1 file changed

+32
-28
lines changed

1 file changed

+32
-28
lines changed

python/tests/test_genotype_matching.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,6 @@ def process_site(
524524
s = self.output.normalisation_factor[site.id]
525525
for st1 in self.T:
526526
if st1.tree_node != tskit.NULL:
527-
528527
for st2 in st1.value_list:
529528
st2.value = (
530529
((self.rho[site.id] / self.ts.num_samples) ** 2)
@@ -1198,7 +1197,6 @@ def genotype_emission(self, mu, m):
11981197
return e
11991198

12001199
def example_genotypes(self, ts):
1201-
12021200
H = ts.genotype_matrix()
12031201
s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0])
12041202
H = H[:, 2:]
@@ -1247,9 +1245,8 @@ def example_parameters_genotypes(self, ts, seed=42):
12471245
rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)]
12481246
mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33]
12491247

1250-
e = self.genotype_emission(mu, m)
1251-
1252-
for s, r, mu in itertools.product(genotypes, rs, mus):
1248+
s = genotypes[0]
1249+
for r, mu in itertools.product(rs, mus):
12531250
r[0] = 0
12541251
e = self.genotype_emission(mu, m)
12551252
yield n, m, G, s, e, r, mu
@@ -1267,36 +1264,44 @@ def test_simple_n_10_no_recombination(self):
12671264
assert ts.num_sites > 3
12681265
self.verify(ts)
12691266

1270-
def test_simple_n_10_no_recombination_high_mut(self):
1271-
ts = msprime.simulate(10, recombination_rate=0, mutation_rate=3, random_seed=42)
1272-
assert ts.num_sites > 3
1273-
self.verify(ts)
1274-
1275-
def test_simple_n_10_no_recombination_higher_mut(self):
1276-
ts = msprime.simulate(20, recombination_rate=0, mutation_rate=3, random_seed=42)
1277-
assert ts.num_sites > 3
1278-
self.verify(ts)
1279-
12801267
def test_simple_n_6(self):
12811268
ts = msprime.simulate(6, recombination_rate=2, mutation_rate=7, random_seed=42)
12821269
assert ts.num_sites > 5
12831270
self.verify(ts)
12841271

1285-
def test_simple_n_8(self):
1286-
ts = msprime.simulate(8, recombination_rate=2, mutation_rate=5, random_seed=42)
1287-
assert ts.num_sites > 5
1288-
self.verify(ts)
1289-
12901272
def test_simple_n_8_high_recombination(self):
12911273
ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=42)
12921274
assert ts.num_trees > 15
12931275
assert ts.num_sites > 5
12941276
self.verify(ts)
12951277

1296-
def test_simple_n_16(self):
1297-
ts = msprime.simulate(16, recombination_rate=2, mutation_rate=5, random_seed=42)
1298-
assert ts.num_sites > 5
1299-
self.verify(ts)
1278+
# FIXME Reducing the number of test cases here as they take a long time to run,
1279+
# and we will want to refactor the test infrastructure when implementing these
1280+
# diploid methods in the library.
1281+
1282+
# def test_simple_n_10_no_recombination_high_mut(self):
1283+
# ts = msprime.simulate(
1284+
# 10, recombination_rate=0, mutation_rate=3, random_seed=42)
1285+
# assert ts.num_sites > 3
1286+
# self.verify(ts)
1287+
1288+
# def test_simple_n_10_no_recombination_higher_mut(self):
1289+
# ts = msprime.simulate(
1290+
# 20, recombination_rate=0, mutation_rate=3, random_seed=42)
1291+
# assert ts.num_sites > 3
1292+
# self.verify(ts)
1293+
1294+
# def test_simple_n_8(self):
1295+
# ts = msprime.simulate(
1296+
# 8, recombination_rate=2, mutation_rate=5, random_seed=42)
1297+
# assert ts.num_sites > 5
1298+
# self.verify(ts)
1299+
1300+
# def test_simple_n_16(self):
1301+
# ts = msprime.simulate(
1302+
# 16, recombination_rate=2, mutation_rate=5, random_seed=42)
1303+
# assert ts.num_sites > 5
1304+
# self.verify(ts)
13001305

13011306
def verify(self, ts):
13021307
raise NotImplementedError()
@@ -1436,7 +1441,6 @@ class TestTreeViterbiDip(VitAlgorithmBase):
14361441
"""
14371442

14381443
def verify(self, ts):
1439-
14401444
for n, m, _, s, _, r, mu in self.example_parameters_genotypes(ts):
14411445
# Note, need to remove the first sample from the ts, and ensure that
14421446
# invariant sites aren't removed.
@@ -1450,14 +1454,14 @@ def verify(self, ts):
14501454
)
14511455
ts_check = ts.simplify(range(1, n + 1), filter_sites=False)
14521456
phased_path, ll = ls.viterbi(
1453-
G_check, s, r, mutation_rate=mu, scale_mutation_based_on_n_alleles=False
1457+
G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
14541458
)
14551459
path_ll_matrix = ls.path_ll(
14561460
G_check,
14571461
s,
14581462
phased_path,
14591463
r,
1460-
mutation_rate=mu,
1464+
p_mutation=mu,
14611465
scale_mutation_based_on_n_alleles=False,
14621466
)
14631467

@@ -1472,7 +1476,7 @@ def verify(self, ts):
14721476
s,
14731477
np.transpose(path_tree_dict),
14741478
r,
1475-
mutation_rate=mu,
1479+
p_mutation=mu,
14761480
scale_mutation_based_on_n_alleles=False,
14771481
)
14781482

0 commit comments

Comments
 (0)