Skip to content

Commit f5b0cc0

Browse files
committed
rework multipop r2 stat to avoid nans
1 parent ce7a9a8 commit f5b0cc0

File tree

3 files changed

+154
-43
lines changed

3 files changed

+154
-43
lines changed

c/tests/test_stats.c

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2637,10 +2637,13 @@ test_paper_ex_two_site(void)
26372637

26382638
tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
26392639
paper_ex_mutations, paper_ex_individuals, NULL, 0);
2640-
double truth_three_index_tuples[27] = { 1, 1, NAN, 0.1111111111111111,
2641-
0.1111111111111111, NAN, 0.1111111111111111, 0.1111111111111111, NAN,
2642-
0.1111111111111111, 0.1111111111111111, NAN, 1, 1, 1, 1, 1, 1,
2643-
0.1111111111111111, 0.1111111111111111, NAN, 1, 1, 1, 1, 1, 1 };
2640+
double truth_three_index_tuples[27] = { 1, 1, 0.71111111111111114,
2641+
0.1111111111111111, 0.1111111111111111, -0.0074074074074074042,
2642+
0.1111111111111111, 0.1111111111111111, -0.0074074074074074042,
2643+
0.1111111111111111, 0.1111111111111111, -0.0074074074074074042, 1, 1,
2644+
0.70833333333333326, 1, 1, 0.70833333333333326, 0.1111111111111111,
2645+
0.1111111111111111, -0.0074074074074074042, 1, 1, 0.70833333333333326, 1, 1,
2646+
0.70833333333333326 };
26442647

26452648
tsk_size_t sample_set_sizes[3], num_index_tuples;
26462649
tsk_id_t sample_sets[ts.num_samples * 3], index_tuples[2 * 3] = { 0, 1, 0, 0, 0, 2 };

c/tskit/trees.c

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4941,37 +4941,38 @@ r2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state,
49414941
tsk_size_t result_dim, double *result, void *params)
49424942
{
49434943
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
4944-
double n;
49454944
const double *state_row;
49464945
tsk_size_t k;
49474946
tsk_id_t i, j;
4948-
double p_AB, p_Ab, p_aB, p_A, p_B, D_i, D_j, denom_i, denom_j;
4947+
double ni, w_AB_i, w_Ab_i, w_aB_i, w_A_i, w_B_i, D_i;
4948+
double nj, w_AB_j, w_Ab_j, w_aB_j, w_A_j, w_B_j, D_j;
4949+
double p_A, p_B;
49494950

49504951
for (k = 0; k < result_dim; k++) {
49514952
i = args.set_indexes[2 * k];
49524953
j = args.set_indexes[2 * k + 1];
49534954

4954-
n = (double) args.sample_set_sizes[i];
4955+
ni = (double) args.sample_set_sizes[i];
49554956
state_row = GET_2D_ROW(state, 3, i);
4956-
p_AB = state_row[0] / n;
4957-
p_Ab = state_row[1] / n;
4958-
p_aB = state_row[2] / n;
4959-
p_A = p_AB + p_Ab;
4960-
p_B = p_AB + p_aB;
4961-
D_i = p_AB - (p_A * p_B);
4962-
denom_i = sqrt(p_A * p_B * (1 - p_A) * (1 - p_B));
4957+
w_AB_i = state_row[0];
4958+
w_Ab_i = state_row[1];
4959+
w_aB_i = state_row[2];
4960+
w_A_i = w_AB_i + w_Ab_i;
4961+
w_B_i = w_AB_i + w_aB_i;
4962+
D_i = (ni * w_AB_i - (w_A_i * w_B_i)) / (ni * ni);
49634963

4964-
n = (double) args.sample_set_sizes[j];
4964+
nj = (double) args.sample_set_sizes[j];
49654965
state_row = GET_2D_ROW(state, 3, j);
4966-
p_AB = state_row[0] / n;
4967-
p_Ab = state_row[1] / n;
4968-
p_aB = state_row[2] / n;
4969-
p_A = p_AB + p_Ab;
4970-
p_B = p_AB + p_aB;
4971-
D_j = p_AB - (p_A * p_B);
4972-
denom_j = sqrt(p_A * p_B * (1 - p_A) * (1 - p_B));
4973-
4974-
result[k] = (D_i * D_j) / (denom_i * denom_j);
4966+
w_AB_j = state_row[0];
4967+
w_Ab_j = state_row[1];
4968+
w_aB_j = state_row[2];
4969+
w_A_j = w_AB_j + w_Ab_j;
4970+
w_B_j = w_AB_j + w_aB_j;
4971+
D_j = (nj * w_AB_j - (w_A_j * w_B_j)) / (nj * nj);
4972+
4973+
p_A = (w_A_i + w_A_j) / (ni + nj);
4974+
p_B = (w_B_i + w_B_j) / (ni + nj);
4975+
result[k] = (D_i * D_j) / (p_A * (1 - p_A) * p_B * (1 - p_B));
49754976
}
49764977
return 0;
49774978
}

python/tests/test_ld_matrix.py

Lines changed: 126 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def norm_hap_weighted_ij(
276276
nj = sample_set_sizes[j]
277277
wAB_i = hap_weights[0, i]
278278
wAB_j = hap_weights[0, j]
279-
result[k] = (wAB_i / ni / 2) + (wAB_j / nj / 2)
279+
result[k] = (wAB_i + wAB_j) / (ni + nj)
280280

281281

282282
def norm_total_weighted(
@@ -1034,26 +1034,26 @@ def r2_ij_summary_func(
10341034
for k in range(result_dim):
10351035
i = set_indexes[k][0]
10361036
j = set_indexes[k][1]
1037-
n = sample_set_sizes[i]
1038-
p_AB = state[0, i] / n
1039-
p_Ab = state[1, i] / n
1040-
p_aB = state[2, i] / n
1041-
p_A = p_AB + p_Ab
1042-
p_B = p_AB + p_aB
1043-
D_i = p_AB - (p_A * p_B)
1044-
denom_i = np.sqrt(p_A * p_B * (1 - p_A) * (1 - p_B))
1045-
1046-
n = sample_set_sizes[j]
1047-
p_AB = state[0, j] / n
1048-
p_Ab = state[1, j] / n
1049-
p_aB = state[2, j] / n
1050-
p_A = p_AB + p_Ab
1051-
p_B = p_AB + p_aB
1052-
D_j = p_AB - (p_A * p_B)
1053-
denom_j = np.sqrt(p_A * p_B * (1 - p_A) * (1 - p_B))
1037+
ni = sample_set_sizes[i]
1038+
w_AB_i = state[0, i]
1039+
w_Ab_i = state[1, i]
1040+
w_aB_i = state[2, i]
1041+
w_A_i = w_AB_i + w_Ab_i
1042+
w_B_i = w_AB_i + w_aB_i
1043+
D_i = (ni * w_AB_i - (w_A_i * w_B_i)) / (ni * ni)
10541044

1045+
nj = sample_set_sizes[j]
1046+
w_AB_j = state[0, j]
1047+
w_Ab_j = state[1, j]
1048+
w_aB_j = state[2, j]
1049+
w_A_j = w_AB_j + w_Ab_j
1050+
w_B_j = w_AB_j + w_aB_j
1051+
D_j = (nj * w_AB_j - (w_A_j * w_B_j)) / (nj * nj)
1052+
1053+
p_A = (w_A_i + w_A_j) / (ni + nj)
1054+
p_B = (w_B_i + w_B_j) / (ni + nj)
10551055
with suppress_overflow_div0_warning():
1056-
result[k] = (D_i * D_j) / (denom_i * denom_j)
1056+
result[k] = (D_i * D_j) / (p_A * (1 - p_A) * p_B * (1 - p_B))
10571057

10581058

10591059
def D_summary_func(
@@ -2298,3 +2298,110 @@ def test_two_way_site_ld_matrix(ts, stat):
22982298
ld_matrix(ts, stat=stat, sample_sets=ss, indexes=[(0, 0), (0, 1), (1, 1)]),
22992299
ts.ld_matrix(stat=stat, sample_sets=ss, indexes=[(0, 0), (0, 1), (1, 1)]),
23002300
)
2301+
2302+
2303+
@pytest.mark.parametrize(
2304+
"genotypes,sample_sets,expected",
2305+
[
2306+
(
2307+
# these genotypes are rows from a genotype matrix (sites x samples)
2308+
correlated := np.array(
2309+
[
2310+
[0, 1, 1, 0, 2, 2, 1, 0, 2, 0, 1, 2],
2311+
[1, 2, 2, 1, 0, 0, 2, 1, 0, 1, 2, 0],
2312+
],
2313+
),
2314+
(np.array([0, 1, 2, 3, 4, 5]), np.array([6, 7, 8, 9, 10, 11])),
2315+
np.float64(1.0),
2316+
),
2317+
(
2318+
correlated,
2319+
(np.array([0, 1, 2, 3, 4, 5]), np.array([6, 7, 8, 9, 10])),
2320+
np.float64(0.9708352229780801),
2321+
),
2322+
(
2323+
correlated,
2324+
(np.array([0, 1, 2, 3, 4, 5]), np.array([6, 7, 8, 9])),
2325+
np.float64(0.9526958931720837),
2326+
),
2327+
(
2328+
correlated,
2329+
(np.array([0, 1, 2, 3, 4, 5]), np.array([6, 7, 8])),
2330+
np.float64(1.0),
2331+
),
2332+
(
2333+
correlated,
2334+
(np.array([0, 1, 2, 3, 4, 5]), np.array([6, 7])),
2335+
np.float64(0.7585185185185186),
2336+
),
2337+
(
2338+
correlated,
2339+
(np.array([0, 1, 2, 3, 4, 5]), np.array([6])),
2340+
np.float64(0.0),
2341+
),
2342+
(
2343+
anticorrelated := np.array(
2344+
[
2345+
[0, 0, 0, 0, 2, 2, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3],
2346+
[1, 1, 1, 1, 3, 3, 3, 3, 0, 0, 0, 0, 2, 2, 2, 2],
2347+
]
2348+
),
2349+
(
2350+
np.array([0, 2, 4, 6, 8, 10, 12, 14]),
2351+
np.array([1, 3, 5, 7, 9, 11, 13, 15]),
2352+
),
2353+
np.float64(1.0),
2354+
),
2355+
(
2356+
anticorrelated,
2357+
(np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1, 3, 5, 7, 9, 11, 13])),
2358+
np.float64(0.9798566895766568),
2359+
),
2360+
(
2361+
anticorrelated,
2362+
(np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1, 3, 5, 7, 9, 11])),
2363+
np.float64(0.8574999999999999),
2364+
),
2365+
(
2366+
anticorrelated,
2367+
(np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1, 3, 5, 7, 9])),
2368+
np.float64(0.8299777777777777),
2369+
),
2370+
(
2371+
anticorrelated,
2372+
(np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1, 3, 5, 7])),
2373+
np.float64(0.6328124999999999),
2374+
),
2375+
(
2376+
anticorrelated,
2377+
(np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1, 3, 5])),
2378+
np.float64(0.57179616638322),
2379+
),
2380+
(
2381+
anticorrelated,
2382+
(np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1, 3])),
2383+
np.float64(0.0),
2384+
),
2385+
(
2386+
anticorrelated,
2387+
(np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1])),
2388+
np.float64(0.0),
2389+
),
2390+
],
2391+
)
2392+
def test_multipopulation_r2_varying_unequal_set_sizes(genotypes, sample_sets, expected):
2393+
a, b = genotypes
2394+
state_dim = len(sample_sets)
2395+
state = np.zeros((3, state_dim), dtype=int)
2396+
result = np.zeros((max(a) + 1, max(b) + 1, 1))
2397+
norm = np.zeros_like(result)
2398+
params = dict(sample_set_sizes=list(map(len, sample_sets)), set_indexes=[(0, 1)])
2399+
for i, j in np.ndindex(result.shape[:2]):
2400+
for k, ss in enumerate(sample_sets):
2401+
A = a[ss] == i
2402+
B = b[ss] == j
2403+
state[:, k] = (A & B).sum(), (A & ~B).sum(), (~A & B).sum()
2404+
r2_ij_summary_func(state_dim, state, 1, result[i, j], params)
2405+
norm_hap_weighted_ij(1, state, max(a) + 1, max(b) + 1, norm[i, j], params)
2406+
2407+
np.testing.assert_allclose(expected, (result * norm).sum())

0 commit comments

Comments
 (0)