Skip to content

Commit 3d4fc51

Browse files
astheeggeggsmergify[bot]
authored andcommitted
added forwards backwards testing
Added forwards backwards testing and now include missingness appropriately added missingness to diploid LS added some fixes for flake errors added missingness to diploid viterbi changed test_genotype_matching_fb.py remove stray print removed caps for bool EQUAL_BOTH_HOM etc Removed caps for EQUAL_BOTH_HOM etc in Viterbi removed unused imported function
1 parent 93cd81f commit 3d4fc51

File tree

2 files changed

+133
-69
lines changed

2 files changed

+133
-69
lines changed

python/tests/test_genotype_matching_fb.py

Lines changed: 83 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# Simulation
21
import copy
32
import itertools
43

@@ -14,6 +13,8 @@
1413
REF_HOM_OBS_HET = 1
1514
REF_HET_OBS_HOM = 2
1615

16+
MISSING = -1
17+
1718

1819
def mirror_coordinates(ts):
1920
"""
@@ -411,6 +412,7 @@ def update_probabilities(self, site, genotype_state):
411412
]
412413

413414
query_is_het = genotype_state == 1
415+
query_is_missing = genotype_state == MISSING
414416

415417
for st1 in T:
416418
u1 = st1.tree_node
@@ -444,6 +446,7 @@ def update_probabilities(self, site, genotype_state):
444446
match,
445447
template_is_het,
446448
query_is_het,
449+
query_is_missing,
447450
)
448451

449452
# This will ensure that allelic_state[:n] is filled
@@ -561,7 +564,14 @@ def compute_normalisation_factor_dict(self):
561564
raise NotImplementedError()
562565

563566
def compute_next_probability_dict(
564-
self, site_id, p_last, inner_summation, is_match, template_is_het, query_is_het
567+
self,
568+
site_id,
569+
p_last,
570+
inner_summation,
571+
is_match,
572+
template_is_het,
573+
query_is_het,
574+
query_is_missing,
565575
):
566576
raise NotImplementedError()
567577

@@ -670,41 +680,45 @@ def compute_next_probability_dict(
670680
is_match,
671681
template_is_het,
672682
query_is_het,
683+
query_is_missing,
673684
):
674685
rho = self.rho[site_id]
675686
mu = self.mu[site_id]
676687
n = self.ts.num_samples
677688

678-
template_is_hom = np.logical_not(template_is_het)
679-
query_is_hom = np.logical_not(query_is_het)
680-
681-
EQUAL_BOTH_HOM = np.logical_and(
682-
np.logical_and(is_match, template_is_hom), query_is_hom
683-
)
684-
UNEQUAL_BOTH_HOM = np.logical_and(
685-
np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom
686-
)
687-
BOTH_HET = np.logical_and(template_is_het, query_is_het)
688-
REF_HOM_OBS_HET = np.logical_and(template_is_hom, query_is_het)
689-
REF_HET_OBS_HOM = np.logical_and(template_is_het, query_is_hom)
690-
691689
p_t = (
692690
(rho / n) ** 2
693691
+ ((1 - rho) * (rho / n)) * inner_normalisation_factor
694692
+ (1 - rho) ** 2 * p_last
695693
)
696-
p_e = (
697-
EQUAL_BOTH_HOM * (1 - mu) ** 2
698-
+ UNEQUAL_BOTH_HOM * (mu**2)
699-
+ REF_HOM_OBS_HET * (2 * mu * (1 - mu))
700-
+ REF_HET_OBS_HOM * (mu * (1 - mu))
701-
+ BOTH_HET * ((1 - mu) ** 2 + mu**2)
702-
)
694+
695+
if query_is_missing:
696+
p_e = 1
697+
else:
698+
query_is_hom = np.logical_not(query_is_het)
699+
template_is_hom = np.logical_not(template_is_het)
700+
701+
equal_both_hom = np.logical_and(
702+
np.logical_and(is_match, template_is_hom), query_is_hom
703+
)
704+
unequal_both_hom = np.logical_and(
705+
np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom
706+
)
707+
both_het = np.logical_and(template_is_het, query_is_het)
708+
ref_hom_obs_het = np.logical_and(template_is_hom, query_is_het)
709+
ref_het_obs_hom = np.logical_and(template_is_het, query_is_hom)
710+
711+
p_e = (
712+
equal_both_hom * (1 - mu) ** 2
713+
+ unequal_both_hom * (mu**2)
714+
+ ref_hom_obs_het * (2 * mu * (1 - mu))
715+
+ ref_het_obs_hom * (mu * (1 - mu))
716+
+ both_het * ((1 - mu) ** 2 + mu**2)
717+
)
703718

704719
return p_t * p_e
705720

706721

707-
# DEV: Sort this
708722
class BackwardAlgorithm(LsHmmAlgorithm):
709723
"""Runs the Li and Stephens forward algorithm."""
710724

@@ -737,29 +751,35 @@ def compute_next_probability_dict(
737751
is_match,
738752
template_is_het,
739753
query_is_het,
754+
query_is_missing,
740755
):
741756
mu = self.mu[site_id]
742757

743758
template_is_hom = np.logical_not(template_is_het)
744-
query_is_hom = np.logical_not(query_is_het)
745759

746-
EQUAL_BOTH_HOM = np.logical_and(
747-
np.logical_and(is_match, template_is_hom), query_is_hom
748-
)
749-
UNEQUAL_BOTH_HOM = np.logical_and(
750-
np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom
751-
)
752-
BOTH_HET = np.logical_and(template_is_het, query_is_het)
753-
REF_HOM_OBS_HET = np.logical_and(template_is_hom, query_is_het)
754-
REF_HET_OBS_HOM = np.logical_and(template_is_het, query_is_hom)
755-
756-
p_e = (
757-
EQUAL_BOTH_HOM * (1 - mu) ** 2
758-
+ UNEQUAL_BOTH_HOM * (mu**2)
759-
+ REF_HOM_OBS_HET * (2 * mu * (1 - mu))
760-
+ REF_HET_OBS_HOM * (mu * (1 - mu))
761-
+ BOTH_HET * ((1 - mu) ** 2 + mu**2)
762-
)
760+
if query_is_missing:
761+
p_e = 1
762+
else:
763+
query_is_hom = np.logical_not(query_is_het)
764+
765+
equal_both_hom = np.logical_and(
766+
np.logical_and(is_match, template_is_hom), query_is_hom
767+
)
768+
unequal_both_hom = np.logical_and(
769+
np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom
770+
)
771+
both_het = np.logical_and(template_is_het, query_is_het)
772+
ref_hom_obs_het = np.logical_and(template_is_hom, query_is_het)
773+
ref_het_obs_hom = np.logical_and(template_is_het, query_is_hom)
774+
775+
p_e = (
776+
equal_both_hom * (1 - mu) ** 2
777+
+ unequal_both_hom * (mu**2)
778+
+ ref_hom_obs_het * (2 * mu * (1 - mu))
779+
+ ref_het_obs_hom * (mu * (1 - mu))
780+
+ both_het * ((1 - mu) ** 2 + mu**2)
781+
)
782+
763783
return p_next * p_e
764784

765785

@@ -797,18 +817,33 @@ def example_genotypes(self, ts):
797817
s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0])
798818
H = H[:, 2:]
799819

820+
genotypes = [
821+
s,
822+
H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]),
823+
]
824+
825+
s_tmp = s.copy()
826+
s_tmp[0, -1] = MISSING
827+
genotypes.append(s_tmp)
828+
s_tmp = s.copy()
829+
s_tmp[0, ts.num_sites // 2] = MISSING
830+
genotypes.append(s_tmp)
831+
s_tmp = s.copy()
832+
s_tmp[0, :] = MISSING
833+
genotypes.append(s_tmp)
834+
800835
m = ts.get_num_sites()
801836
n = H.shape[1]
802837

803838
G = np.zeros((m, n, n))
804839
for i in range(m):
805840
G[i, :, :] = np.add.outer(H[i, :], H[i, :])
806841

807-
return H, G, s
842+
return H, G, genotypes
808843

809844
def example_parameters_genotypes(self, ts, seed=42):
810845
np.random.seed(seed)
811-
H, G, s = self.example_genotypes(ts)
846+
H, G, genotypes = self.example_genotypes(ts)
812847
n = H.shape[1]
813848
m = ts.get_num_sites()
814849

@@ -819,13 +854,16 @@ def example_parameters_genotypes(self, ts, seed=42):
819854

820855
e = self.genotype_emission(mu, m)
821856

822-
yield n, m, G, s, e, r, mu
857+
for s in genotypes:
858+
yield n, m, G, s, e, r, mu
823859

824860
# Mixture of random and extremes
825861
rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)]
826862
mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33]
827863

828-
for r, mu in itertools.product(rs, mus):
864+
e = self.genotype_emission(mu, m)
865+
866+
for s, r, mu in itertools.product(genotypes, rs, mus):
829867
r[0] = 0
830868
e = self.genotype_emission(mu, m)
831869
yield n, m, G, s, e, r, mu

python/tests/test_genotype_matching_viterbi.py

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
REF_HOM_OBS_HET = 1
1414
REF_HET_OBS_HOM = 2
1515

16+
MISSING = -1
17+
1618

1719
class ValueTransition:
1820
"""Simple struct holding value transition values."""
@@ -390,6 +392,7 @@ def update_probabilities(self, site, genotype_state):
390392
]
391393

392394
query_is_het = genotype_state == 1
395+
query_is_missing = genotype_state == MISSING
393396

394397
for st1 in T:
395398
u1 = st1.tree_node
@@ -423,6 +426,7 @@ def update_probabilities(self, site, genotype_state):
423426
match,
424427
template_is_het,
425428
query_is_het,
429+
query_is_missing,
426430
u1,
427431
u2,
428432
)
@@ -486,6 +490,7 @@ def compute_next_probability_dict(
486490
is_match,
487491
template_is_het,
488492
query_is_het,
493+
query_is_missing,
489494
node_1,
490495
node_2,
491496
):
@@ -830,6 +835,7 @@ def compute_next_probability_dict(
830835
is_match,
831836
template_is_het,
832837
query_is_het,
838+
query_is_missing,
833839
node_1,
834840
node_2,
835841
):
@@ -841,26 +847,28 @@ def compute_next_probability_dict(
841847
double_recombination_required = False
842848
single_recombination_required = False
843849

844-
template_is_hom = np.logical_not(template_is_het)
845-
query_is_hom = np.logical_not(query_is_het)
846-
847-
EQUAL_BOTH_HOM = np.logical_and(
848-
np.logical_and(is_match, template_is_hom), query_is_hom
849-
)
850-
UNEQUAL_BOTH_HOM = np.logical_and(
851-
np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom
852-
)
853-
BOTH_HET = np.logical_and(template_is_het, query_is_het)
854-
REF_HOM_OBS_HET = np.logical_and(template_is_hom, query_is_het)
855-
REF_HET_OBS_HOM = np.logical_and(template_is_het, query_is_hom)
856-
857-
p_e = (
858-
EQUAL_BOTH_HOM * (1 - mu) ** 2
859-
+ UNEQUAL_BOTH_HOM * (mu**2)
860-
+ REF_HOM_OBS_HET * (2 * mu * (1 - mu))
861-
+ REF_HET_OBS_HOM * (mu * (1 - mu))
862-
+ BOTH_HET * ((1 - mu) ** 2 + mu**2)
863-
)
850+
if query_is_missing:
851+
p_e = 1
852+
else:
853+
template_is_hom = np.logical_not(template_is_het)
854+
query_is_hom = np.logical_not(query_is_het)
855+
equal_both_hom = np.logical_and(
856+
np.logical_and(is_match, template_is_hom), query_is_hom
857+
)
858+
unequal_both_hom = np.logical_and(
859+
np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom
860+
)
861+
both_het = np.logical_and(template_is_het, query_is_het)
862+
ref_hom_obs_het = np.logical_and(template_is_hom, query_is_het)
863+
ref_het_obs_hom = np.logical_and(template_is_het, query_is_hom)
864+
865+
p_e = (
866+
equal_both_hom * (1 - mu) ** 2
867+
+ unequal_both_hom * (mu**2)
868+
+ ref_hom_obs_het * (2 * mu * (1 - mu))
869+
+ ref_het_obs_hom * (mu * (1 - mu))
870+
+ both_het * ((1 - mu) ** 2 + mu**2)
871+
)
864872

865873
no_switch = (1 - r) ** 2 + 2 * (r_n * (1 - r)) + r_n**2
866874
single_switch = r_n * (1 - r) + r_n**2
@@ -919,18 +927,33 @@ def example_genotypes(self, ts):
919927
s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0])
920928
H = H[:, 2:]
921929

930+
genotypes = [
931+
s,
932+
H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]),
933+
]
934+
935+
s_tmp = s.copy()
936+
s_tmp[0, -1] = MISSING
937+
genotypes.append(s_tmp)
938+
s_tmp = s.copy()
939+
s_tmp[0, ts.num_sites // 2] = MISSING
940+
genotypes.append(s_tmp)
941+
s_tmp = s.copy()
942+
s_tmp[0, :] = MISSING
943+
genotypes.append(s_tmp)
944+
922945
m = ts.get_num_sites()
923946
n = H.shape[1]
924947

925948
G = np.zeros((m, n, n))
926949
for i in range(m):
927950
G[i, :, :] = np.add.outer(H[i, :], H[i, :])
928951

929-
return H, G, s
952+
return H, G, genotypes
930953

931954
def example_parameters_genotypes(self, ts, seed=42):
932955
np.random.seed(seed)
933-
H, G, s = self.example_genotypes(ts)
956+
H, G, genotypes = self.example_genotypes(ts)
934957
n = H.shape[1]
935958
m = ts.get_num_sites()
936959

@@ -941,13 +964,16 @@ def example_parameters_genotypes(self, ts, seed=42):
941964

942965
e = self.genotype_emission(mu, m)
943966

944-
yield n, m, G, s, e, r, mu
967+
for s in genotypes:
968+
yield n, m, G, s, e, r, mu
945969

946970
# Mixture of random and extremes
947971
rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)]
948972
mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33]
949973

950-
for r, mu in itertools.product(rs, mus):
974+
e = self.genotype_emission(mu, m)
975+
976+
for s, r, mu in itertools.product(genotypes, rs, mus):
951977
r[0] = 0
952978
e = self.genotype_emission(mu, m)
953979
yield n, m, G, s, e, r, mu

0 commit comments

Comments
 (0)