diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index b3515ef2c5..74aa961aa4 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -2637,12 +2637,15 @@ test_paper_ex_two_site(void) tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); + double truth_three_index_tuples[27] = { 1, 1, NAN, 0.1111111111111111, + 0.1111111111111111, NAN, 0.1111111111111111, 0.1111111111111111, NAN, + 0.1111111111111111, 0.1111111111111111, NAN, 1, 1, 1, 1, 1, 1, + 0.1111111111111111, 0.1111111111111111, NAN, 1, 1, 1, 1, 1, 1 }; - tsk_size_t sample_set_sizes[3]; - tsk_id_t sample_sets[ts.num_samples * 3]; + tsk_size_t sample_set_sizes[3], num_index_tuples; + tsk_id_t sample_sets[ts.num_samples * 3], index_tuples[2 * 3] = { 0, 1, 0, 0, 0, 2 }; tsk_size_t num_sites = ts.tables->sites.num_rows; - tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); - tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); + tsk_id_t *sites = tsk_malloc(num_sites * sizeof(*sites)); // First sample set contains all of the samples sample_set_sizes[0] = ts.num_samples; @@ -2651,14 +2654,13 @@ test_paper_ex_two_site(void) sample_sets[s] = (tsk_id_t) s; } for (s = 0; s < num_sites; s++) { - row_sites[s] = (tsk_id_t) s; - col_sites[s] = (tsk_id_t) s; + sites[s] = (tsk_id_t) s; } result_size = num_sites * num_sites; tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_one_set); @@ -2672,7 +2674,7 @@ test_paper_ex_two_site(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_two_sets); @@ -2686,15 +2688,48 @@ test_paper_ex_two_site(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal_nan( result_size * num_sample_sets, result, truth_three_sets); + // Two-way stats: we'll reuse all sample sets from the first 3 tests + num_sample_sets = 3; + + num_index_tuples = 1; + // We'll compute r2 between sample set 0 and sample set 1 + tsk_memset(result, 0, sizeof(*result) * result_size * num_index_tuples); + ret = tsk_treeseq_r2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_sites, sites, NULL, num_sites, sites, NULL, + 0, result); + + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size * num_index_tuples, result, truth_one_set); + + // Compare sample sets [(0, 1), (0, 0)] + num_index_tuples = 2; + tsk_memset(result, 0, sizeof(*result) * result_size * num_index_tuples); + ret = tsk_treeseq_r2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_sites, sites, NULL, num_sites, sites, NULL, + 0, result); + + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size * num_index_tuples, result, truth_two_sets); + + // Compare sample sets [(0, 1), (0, 0), (0, 2)] + num_index_tuples = 3; + tsk_memset(result, 0, sizeof(*result) * result_size * num_index_tuples); + ret = tsk_treeseq_r2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_sites, sites, NULL, num_sites, sites, NULL, + 0, result); + + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal_nan( + result_size * num_index_tuples, result, truth_three_index_tuples); + tsk_treeseq_free(&ts); - tsk_safe_free(row_sites); - tsk_safe_free(col_sites); + tsk_safe_free(sites); } static void @@ -2705,42 +2740,44 @@ test_paper_ex_two_branch(void) double result[27]; tsk_size_t i, result_size, num_sample_sets; tsk_flags_t options = 0; - double truth_one_set[9] - = { 0.001066666666666695, -0.00012666666666665688, -0.0001266666666666534, - -0.00012666666666665688, 6.016666666665456e-05, 6.016666666665629e-05, - -0.0001266666666666534, 6.016666666665629e-05, 6.016666666665629e-05 }; - double truth_two_sets[18] - = { 0.001066666666666695, 0.001066666666666695, -0.00012666666666665688, - -0.00012666666666665688, -0.0001266666666666534, -0.0001266666666666534, - -0.00012666666666665688, -0.00012666666666665688, 6.016666666665456e-05, - 6.016666666665456e-05, 6.016666666665629e-05, 6.016666666665629e-05, - -0.0001266666666666534, -0.0001266666666666534, 6.016666666665629e-05, - 6.016666666665629e-05, 6.016666666665629e-05, 6.016666666665629e-05 }; - double truth_three_sets[27] = { 0.001066666666666695, 0.001066666666666695, NAN, - -0.00012666666666665688, -0.00012666666666665688, NAN, -0.0001266666666666534, - -0.0001266666666666534, NAN, -0.00012666666666665688, -0.00012666666666665688, - NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN, 6.016666666665629e-05, - 6.016666666665629e-05, NAN, -0.0001266666666666534, -0.0001266666666666534, NAN, - 6.016666666665629e-05, 6.016666666665629e-05, NAN, 6.016666666665629e-05, - 6.016666666665629e-05, NAN }; - double truth_positions_subset_1[12] = { 0.001066666666666695, 0.001066666666666695, - NAN, 0.001066666666666695, 0.001066666666666695, NAN, 0.001066666666666695, - 0.001066666666666695, NAN, 0.001066666666666695, 0.001066666666666695, NAN }; - double truth_positions_subset_2[12] = { 6.016666666665456e-05, 6.016666666665456e-05, - NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN, 6.016666666665456e-05, - 6.016666666665456e-05, NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN }; - double truth_positions_subset_3[12] = { 6.016666666665456e-05, 6.016666666665456e-05, - NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN, 6.016666666665456e-05, - 6.016666666665456e-05, NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN }; + double truth_one_set[9] = { 0.008890640625, 0.004624203125, 0.005215703125, + 0.004624203125, 0.003737578125, 0.004377078125, 0.005215703125, + 0.004377078124999999, 0.005160578124999998 }; + double truth_two_sets[18] = { 0.008890640625, 0.008890640625, 0.004624203125, + 0.004624203125, 0.005215703125, 0.005215703125, 0.004624203125, 0.004624203125, + 0.003737578125, 0.003737578125, 0.004377078125, 0.004377078125, 0.005215703125, + 0.005215703125, 0.004377078124999999, 0.004377078124999999, 0.005160578124999998, + 0.005160578124999998 }; + double truth_three_sets[27] + = { 0.008890640625, 0.008890640625, 0.007225, 0.004624203125000001, + 0.004624203125, 0.007225, 0.005215703125000002, 0.005215703125, 0.008585, + 0.004624203125, 0.004624203125, 0.007225, 0.003737578125, 0.003737578125, + 0.007225, 0.004377078125, 0.004377078125, 0.008585, 0.005215703125, + 0.005215703125, 0.008585, 0.004377078124999999, 0.004377078124999999, + 0.008585, 0.005160578124999998, 0.005160578124999998, 0.010201 }; + double truth_positions_subset_1[12] = { 0.008890640625, 0.008890640625, 0.007225, + 0.008890640625, 0.008890640625, 0.007225, 0.008890640625, 0.008890640625, + 0.007225, 0.008890640625, 0.008890640625, 0.007225 }; + double truth_positions_subset_2[12] = { 0.003737578125, 0.003737578125, 0.007225, + 0.003737578125, 0.003737578125, 0.007225, 0.003737578125, 0.003737578125, + 0.007225, 0.003737578125, 0.003737578125, 0.007225 }; + double truth_positions_subset_3[12] = { 0.005160578125, 0.005160578125, 0.010201, + 0.005160578125, 0.005160578125, 0.010201, 0.005160578125, 0.005160578125, + 0.010201, 0.005160578125, 0.005160578125, 0.010201 }; + double truth_three_index_tuples[27] = { 0.008890640625, 0.008890640625, 0.0039125, + 0.004624203125, 0.004624203125, 0.0038125, 0.005215703125, 0.005215703125, + 0.0045725, 0.004624203125, 0.004624203125, 0.0038125, 0.003737578125, + 0.003737578125, 0.0040125, 0.004377078125, 0.004377078125, 0.0048525, + 0.005215703125, 0.005215703125, 0.0045725, 0.004377078125, 0.004377078125, + 0.0048525, 0.005160578125, 0.005160578125, 0.0058845 }; tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - tsk_size_t sample_set_sizes[3]; - tsk_id_t sample_sets[ts.num_samples * 3]; + tsk_size_t sample_set_sizes[3], num_index_tuples; + tsk_id_t sample_sets[ts.num_samples * 3], index_tuples[2 * 3] = { 0, 1, 0, 0, 0, 2 }; tsk_size_t num_trees = ts.num_trees; - double *row_positions = tsk_malloc(num_trees * sizeof(*row_positions)); - double *col_positions = tsk_malloc(num_trees * sizeof(*col_positions)); + double *positions = tsk_malloc(num_trees * sizeof(*positions)); double positions_subset_1[2] = { 0., 0.1 }; double positions_subset_2[2] = { 2., 6. }; double positions_subset_3[2] = { 9., 9.999 }; @@ -2752,16 +2789,15 @@ test_paper_ex_two_branch(void) sample_sets[i] = (tsk_id_t) i; } for (i = 0; i < num_trees; i++) { - row_positions[i] = ts.breakpoints[i]; - col_positions[i] = ts.breakpoints[i]; + positions[i] = ts.breakpoints[i]; } options |= TSK_STAT_BRANCH; result_size = num_trees * num_trees * num_sample_sets; tsk_memset(result, 0, sizeof(*result) * result_size); - ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_trees, NULL, row_positions, num_trees, NULL, col_positions, options, result); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_trees, + NULL, positions, num_trees, NULL, positions, options, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_one_set); @@ -2774,8 +2810,8 @@ test_paper_ex_two_branch(void) result_size = num_trees * num_trees * num_sample_sets; tsk_memset(result, 0, sizeof(*result) * result_size); - ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_trees, NULL, row_positions, num_trees, NULL, col_positions, options, result); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_trees, + NULL, positions, num_trees, NULL, positions, options, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_two_sets); @@ -2788,35 +2824,69 @@ test_paper_ex_two_branch(void) result_size = num_trees * num_trees * num_sample_sets; tsk_memset(result, 0, sizeof(*result) * result_size); - ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_trees, NULL, row_positions, num_trees, NULL, col_positions, options, result); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_trees, + NULL, positions, num_trees, NULL, positions, options, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal_nan(result_size, result, truth_three_sets); result_size = 4 * num_sample_sets; tsk_memset(result, 0, sizeof(*result) * result_size); - ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, - NULL, positions_subset_1, 2, NULL, positions_subset_1, options, result); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, NULL, + positions_subset_1, 2, NULL, positions_subset_1, options, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal_nan(result_size, result, truth_positions_subset_1); result_size = 4 * num_sample_sets; tsk_memset(result, 0, sizeof(*result) * result_size); - ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, - NULL, positions_subset_2, 2, NULL, positions_subset_2, options, result); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, NULL, + positions_subset_2, 2, NULL, positions_subset_2, options, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal_nan(result_size, result, truth_positions_subset_2); result_size = 4 * num_sample_sets; tsk_memset(result, 0, sizeof(*result) * result_size); - ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, - NULL, positions_subset_3, 2, NULL, positions_subset_3, options, result); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, NULL, + positions_subset_3, 2, NULL, positions_subset_3, options, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal_nan(result_size, result, truth_positions_subset_3); + // Two-way stats: we'll reuse all sample sets from the first 3 tests + num_sample_sets = 3; + result_size = num_trees * num_trees; + + num_index_tuples = 1; + // We'll compute D2 between sample set 0 and sample set 1 + tsk_memset(result, 0, sizeof(*result) * result_size * num_index_tuples); + ret = tsk_treeseq_D2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_trees, NULL, positions, num_trees, NULL, + positions, options, result); + + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size * num_index_tuples, result, truth_one_set); + + // Compare sample sets [(0, 1), (0, 0)] + num_index_tuples = 2; + tsk_memset(result, 0, sizeof(*result) * result_size * num_index_tuples); + ret = tsk_treeseq_D2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_trees, NULL, positions, num_trees, NULL, + positions, options, result); + + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size * num_index_tuples, result, truth_two_sets); + + // Compare sample sets [(0, 1), (0, 0), (0, 2)] + num_index_tuples = 3; + tsk_memset(result, 0, sizeof(*result) * result_size * num_index_tuples); + ret = tsk_treeseq_D2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_trees, NULL, positions, num_trees, NULL, + positions, options, result); + + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal_nan( + result_size * num_index_tuples, result, truth_three_index_tuples); + tsk_treeseq_free(&ts); - tsk_safe_free(row_positions); - tsk_safe_free(col_positions); + tsk_safe_free(positions); } static void @@ -2853,8 +2923,8 @@ test_two_site_correlated_multiallelic(void) "0 10 16 13\n" "0 10 16 15\n" "10 20 16 15\n"; - const char *sites = "7 A\n" - "13 G\n"; + const char *tree_sites = "7 A\n" + "13 G\n"; const char *mutations = "0 15 T -1\n" "0 14 G 0\n" "1 15 T -1\n" @@ -2877,71 +2947,133 @@ test_two_site_correlated_multiallelic(void) 0.003387017561686057, 0.003387017561686057 }; double truth_pi2[4] = { 0.04579247743399549, 0.04579247743399549, 0.04579247743399549, 0.0457924774339955 }; + double truth_D2_unbiased[4] = { 0.026455026455026454, 0.026455026455026454, + 0.026455026455026454, 0.026455026455026454 }; + double truth_Dz_unbiased[4] = { -0.008818342151675485, -0.008818342151675485, + -0.008818342151675485, -0.008818342151675485 }; + double truth_pi2_unbiased[4] = { 0.0582010582010582, 0.0582010582010582, + 0.0582010582010582, 0.0582010582010582 }; + double truth_D2_unbiased_disjoint[4] = { 0.007407407407407407, 0.007407407407407407, + 0.007407407407407407, 0.007407407407407407 }; - tsk_treeseq_from_text(&ts, 20, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + tsk_treeseq_from_text( + &ts, 20, nodes, edges, NULL, tree_sites, mutations, NULL, NULL, 0); tsk_size_t num_sample_sets = 1; - tsk_size_t sample_set_sizes[1] = { ts.num_samples }; - tsk_id_t sample_sets[ts.num_samples]; + tsk_size_t sample_set_sizes[2] = { ts.num_samples, ts.num_samples }; + tsk_id_t sample_sets[ts.num_samples * 2]; tsk_size_t num_sites = ts.tables->sites.num_rows; - tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); - tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); + tsk_id_t *sites = tsk_malloc(num_sites * sizeof(*sites)); result_size = num_sites * num_sites; double result[result_size]; + // Two sample sets for multipop at the bottom, only presenting one to single pop + // results for (s = 0; s < ts.num_samples; s++) { sample_sets[s] = (tsk_id_t) s; + sample_sets[s + ts.num_samples] = (tsk_id_t) s; } for (s = 0; s < num_sites; s++) { - row_sites[s] = (tsk_id_t) s; - col_sites[s] = (tsk_id_t) s; + sites[s] = (tsk_id_t) s; } tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_sites, row_sites, NULL, num_sites, col_sites, NULL, 0, result); + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D_prime); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_Dz); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_pi2); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2_unbiased); + + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_Dz_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_Dz_unbiased); + + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_pi2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_pi2_unbiased); + + // We'll compute r2 between sample set 0 and sample set 1 + num_sample_sets = 2; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_r2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, 1, + (tsk_id_t[2]){ 0, 0 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_r2); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, 1, + (tsk_id_t[2]){ 0, 0 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2); + + // perfectly overlapping sample sets will produce a result equal to the single + // population case + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_ij_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + 1, (tsk_id_t[2]){ 0, 0 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2_unbiased); + + // two disjoint sample sets with 5 and 4 samples {0,1,2,3,4}{5,6,7,8} + sample_set_sizes[0] = 5; + sample_set_sizes[1] = 4; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_ij_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + 1, (tsk_id_t[2]){ 0, 1 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2_unbiased_disjoint); + tsk_treeseq_free(&ts); - tsk_safe_free(row_sites); - tsk_safe_free(col_sites); + tsk_safe_free(sites); } static void @@ -2988,8 +3120,8 @@ test_two_site_uncorrelated_multiallelic(void) "10 20 23 18,20\n" "0 10 16 14,15\n" "10 20 24 22,23\n"; - const char *sites = "7 A\n" - "13 G\n"; + const char *tree_sites = "7 A\n" + "13 G\n"; const char *mutations = "0 15 T -1\n" "0 12 G 0\n" "1 23 T -1\n" @@ -3007,72 +3139,134 @@ test_two_site_uncorrelated_multiallelic(void) double truth_Dz[4] = { 0.0, 0.0, 0.0, 0.0 }; double truth_pi2[4] = { 0.04938271604938272, 0.04938271604938272, 0.04938271604938272, 0.04938271604938272 }; + double truth_D2_unbiased[4] = { 0.027777777777777776, -0.009259259259259259, + -0.009259259259259259, 0.027777777777777776 }; + double truth_Dz_unbiased[4] = { -0.015873015873015872, 0.005291005291005289, + 0.005291005291005289, -0.015873015873015872 }; + double truth_pi2_unbiased[4] = { 0.06349206349206349, 0.06216931216931215, + 0.06216931216931215, 0.06349206349206349 }; + double truth_D2_unbiased_disjoint[4] = { 0.008333333333333333, + -0.0027777777777777775, -0.0027777777777777775, 0.03518518518518518 }; - tsk_treeseq_from_text(&ts, 20, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + tsk_treeseq_from_text( + &ts, 20, nodes, edges, NULL, tree_sites, mutations, NULL, NULL, 0); tsk_size_t s; tsk_size_t num_sample_sets = 1; tsk_size_t num_sites = ts.tables->sites.num_rows; - tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); - tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); - tsk_size_t sample_set_sizes[1] = { ts.num_samples }; - tsk_id_t sample_sets[ts.num_samples]; + tsk_id_t *sites = tsk_malloc(num_sites * sizeof(*sites)); + tsk_size_t sample_set_sizes[2] = { ts.num_samples, ts.num_samples }; + tsk_id_t sample_sets[ts.num_samples * 2]; tsk_size_t result_size = num_sites * num_sites; double result[result_size]; + // Two sample sets for multipop at the bottom, only presenting one to single pop + // results for (s = 0; s < ts.num_samples; s++) { sample_sets[s] = (tsk_id_t) s; + sample_sets[s + ts.num_samples] = (tsk_id_t) s; } for (s = 0; s < num_sites; s++) { - row_sites[s] = (tsk_id_t) s; - col_sites[s] = (tsk_id_t) s; + sites[s] = (tsk_id_t) s; } tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_sites, row_sites, NULL, num_sites, col_sites, NULL, 0, result); + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D_prime); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_Dz); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_pi2); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2_unbiased); + + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_Dz_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_Dz_unbiased); + + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_pi2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_pi2_unbiased); + + // We'll compute r2 between sample set 0 and sample set 1 + num_sample_sets = 2; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_r2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, 1, + (tsk_id_t[2]){ 0, 0 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_r2); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, 1, + (tsk_id_t[2]){ 0, 0 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2); + + // perfectly overlapping sample sets will produce a result equal to the single + // population case + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_ij_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + 1, (tsk_id_t[2]){ 0, 0 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2_unbiased); + + // two disjoint sample sets with 5 and 4 samples {0,1,2,3,4}{5,6,7,8} + sample_set_sizes[0] = 5; + sample_set_sizes[1] = 4; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_ij_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + 1, (tsk_id_t[2]){ 0, 1 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2_unbiased_disjoint); + tsk_treeseq_free(&ts); - tsk_safe_free(row_sites); - tsk_safe_free(col_sites); + tsk_safe_free(sites); } static void @@ -3145,7 +3339,7 @@ test_two_site_backmutation(void) } static void -test_two_locus_site_all_stats(void) +test_two_locus_branch_all_stats(void) { int ret; tsk_treeseq_t ts; @@ -3167,12 +3361,7 @@ test_two_locus_site_all_stats(void) "0 2 20 13\n0 2 20 16\n2 10 21 13\n6 10 21 18\n0 6 21 19\n" "0 2 21 20\n"; - double truth_D[16] = { -6.938893903907228e-18, 5.551115123125783e-17, - 4.85722573273506e-17, 2.7755575615628914e-17, 1.0408340855860843e-17, - 8.326672684688674e-17, 7.979727989493313e-17, 6.938893903907228e-17, - -2.42861286636753e-17, 4.163336342344337e-17, 2.42861286636753e-17, - 4.163336342344337e-17, 1.3877787807814457e-17, 5.551115123125783e-17, - 2.0816681711721685e-17, 2.7755575615628914e-17 }; + double truth_D[16] = { 0 }; double truth_D2[16] = { 0.21949755999999998, 0.1867003599999999, 0.18798699999999988, 0.18941379999999983, 0.18670035999999995, 0.21159555999999993, 0.21257979999999996, 0.21222580000000005, 0.187987, 0.21257979999999996, @@ -3355,9 +3544,11 @@ test_two_locus_stat_input_errors(void) tsk_size_t num_sites = ts.tables->sites.num_rows; tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); - tsk_size_t sample_set_sizes[1] = { ts.num_samples }; + tsk_size_t sample_set_sizes[2] = { ts.num_samples, ts.num_samples }; tsk_size_t num_sample_sets = 1; - tsk_id_t sample_sets[ts.num_samples]; + tsk_id_t index_tuples[2] = { 0 }; + tsk_size_t num_index_tuples = 1; + tsk_id_t sample_sets[ts.num_samples * 2]; // need 2 sample sets for multipop double positions[10] = { 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9 }; double bad_col_positions[2] = { 0., 0. }; // used in 1 test to cover column check double result[100]; @@ -3365,17 +3556,25 @@ test_two_locus_stat_input_errors(void) for (s = 0; s < ts.num_samples; s++) { sample_sets[s] = (tsk_id_t) s; + sample_sets[s + ts.num_samples] = (tsk_id_t) s; } for (s = 0; s < num_sites; s++) { row_sites[s] = (tsk_id_t) s; col_sites[s] = (tsk_id_t) s; } + // begin with the happy path + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, NULL, num_sites, col_sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); - sample_set_sizes[0] = ts.num_samples; - num_sample_sets = 1; - for (s = 0; s < ts.num_samples; s++) { - sample_sets[s] = (tsk_id_t) s; - } + ret = tsk_treeseq_two_locus_count_stat(&ts, num_sample_sets, sample_set_sizes, + sample_sets, 0, NULL, NULL, NULL, num_sites, row_sites, NULL, num_sites, + col_sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_RESULT_DIMS); + + ret = tsk_treeseq_r2(&ts, 1, sample_set_sizes, sample_sets, num_sites, row_sites, + NULL, num_sites, col_sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); sample_sets[1] = 0; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, @@ -3478,6 +3677,27 @@ test_two_locus_stat_input_errors(void) positions, 10, NULL, positions, TSK_STAT_NODE, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); + num_sample_sets = 2; + num_index_tuples = 0; + ret = tsk_treeseq_r2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_sites, row_sites, NULL, num_sites, col_sites, + NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_INDEX_TUPLES); + + num_sample_sets = 1; + num_index_tuples = 1; + ret = tsk_treeseq_D2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_sites, row_sites, NULL, num_sites, col_sites, + NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_SAMPLE_SETS); + + num_sample_sets = 2; + index_tuples[0] = 2; + ret = tsk_treeseq_D2_ij_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_sites, row_sites, NULL, num_sites, col_sites, + NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SAMPLE_SET_INDEX); + tsk_treeseq_free(&ts); tsk_safe_free(row_sites); tsk_safe_free(col_sites); @@ -3817,7 +4037,7 @@ main(int argc, char **argv) { "test_two_site_uncorrelated_multiallelic", test_two_site_uncorrelated_multiallelic }, { "test_two_site_backmutation", test_two_site_backmutation }, - { "test_two_locus_site_all_stats", test_two_locus_site_all_stats }, + { "test_two_locus_site_all_stats", test_two_locus_branch_all_stats }, { "test_paper_ex_two_site_subset", test_paper_ex_two_site_subset }, { "test_two_locus_stat_input_errors", test_two_locus_stat_input_errors }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 2b778f195b..adc8034d4e 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2225,7 +2225,7 @@ get_allele_samples(const tsk_site_t *site, const tsk_bit_array_t *state, } static int -norm_hap_weighted(tsk_size_t state_dim, const double *hap_weights, +norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) { sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; @@ -2233,7 +2233,7 @@ norm_hap_weighted(tsk_size_t state_dim, const double *hap_weights, double n; tsk_size_t k; - for (k = 0; k < state_dim; k++) { + for (k = 0; k < result_dim; k++) { weight_row = GET_2D_ROW(hap_weights, 3, k); n = (double) args.sample_set_sizes[k]; // TODO: what to do when n = 0 @@ -2243,12 +2243,38 @@ norm_hap_weighted(tsk_size_t state_dim, const double *hap_weights, } static int -norm_total_weighted(tsk_size_t state_dim, const double *TSK_UNUSED(hap_weights), +norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights, + tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *weight_row; + double ni, nj, wAB_i, wAB_j; + tsk_id_t i, j; + tsk_size_t k; + + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + ni = (double) args.sample_set_sizes[i]; + nj = (double) args.sample_set_sizes[j]; + weight_row = GET_2D_ROW(hap_weights, 3, i); + wAB_i = weight_row[0]; + weight_row = GET_2D_ROW(hap_weights, 3, j); + wAB_j = weight_row[0]; + + result[k] = (wAB_i + wAB_j) / (ni + nj); + } + + return 0; +} + +static int +norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights), tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) { tsk_size_t k; - for (k = 0; k < state_dim; k++) { + for (k = 0; k < result_dim; k++) { result[k] = 1 / (double) (n_a * n_b); } return 0; @@ -2268,9 +2294,6 @@ get_all_samples_bits(tsk_bit_array_t *all_samples, tsk_size_t n) } } -typedef int norm_func_t(tsk_size_t state_dim, const double *hap_weights, tsk_size_t n_a, - tsk_size_t n_b, double *result, void *params); - static int compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles, @@ -2290,14 +2313,15 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, // a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] // a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] tsk_size_t k, mut_a, mut_b; - tsk_size_t row_len = num_b_alleles * state_dim; + tsk_size_t result_row_len = num_b_alleles * result_dim; tsk_size_t w_A = 0, w_B = 0, w_AB = 0; uint8_t polarised_val = polarised ? 1 : 0; double *hap_weight_row; double *result_tmp_row; double *weights = tsk_malloc(3 * state_dim * sizeof(*weights)); - double *norm = tsk_malloc(state_dim * sizeof(*norm)); - double *result_tmp = tsk_malloc(row_len * num_a_alleles * sizeof(*result_tmp)); + double *norm = tsk_malloc(result_dim * sizeof(*norm)); + double *result_tmp + = tsk_malloc(result_row_len * num_a_alleles * sizeof(*result_tmp)); tsk_memset(&ss_A_samples, 0, sizeof(ss_A_samples)); tsk_memset(&ss_B_samples, 0, sizeof(ss_B_samples)); @@ -2327,7 +2351,7 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, } for (mut_a = polarised_val; mut_a < num_a_alleles; mut_a++) { - result_tmp_row = GET_2D_ROW(result_tmp, row_len, mut_a); + result_tmp_row = GET_2D_ROW(result_tmp, result_row_len, mut_a); for (mut_b = polarised_val; mut_b < num_b_alleles; mut_b++) { tsk_bit_array_get_row(site_a_state, mut_a, &A_samples); tsk_bit_array_get_row(site_b_state, mut_b, &B_samples); @@ -2352,15 +2376,15 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, if (ret != 0) { goto out; } - ret = norm_f(state_dim, weights, num_a_alleles - polarised_val, + ret = norm_f(result_dim, weights, num_a_alleles - polarised_val, num_b_alleles - polarised_val, norm, f_params); if (ret != 0) { goto out; } - for (k = 0; k < state_dim; k++) { + for (k = 0; k < result_dim; k++) { result[k] += result_tmp_row[k] * norm[k]; } - result_tmp_row += state_dim; // Advance to the next column + result_tmp_row += result_dim; // Advance to the next column } } @@ -2538,8 +2562,8 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, get_site_row_col_indices( n_rows, row_sites, n_cols, col_sites, sites, &n_sites, row_idx, col_idx); - // We rely on n_sites to allocate these arrays, they're initialized to NULL for safe - // deallocation if the previous allocation fails + // We rely on n_sites to allocate these arrays, which are initialized + // to NULL for safe deallocation if the previous allocation fails num_alleles = tsk_malloc(n_sites * sizeof(*num_alleles)); site_offsets = tsk_malloc(n_sites * sizeof(*site_offsets)); if (num_alleles == NULL || site_offsets == NULL) { @@ -3195,7 +3219,7 @@ tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_di return ret; } -static int +int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, @@ -3209,7 +3233,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl tsk_bit_array_t sample_sets_bits; bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); - // double default_windows[] = { 0, self->tables->sequence_length }; tsk_size_t state_dim = num_sample_sets; sample_count_stat_params_t f_params = { .sample_sets = sample_sets, .num_sample_sets = num_sample_sets, @@ -3232,17 +3255,15 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl ret = tsk_trace_error(TSK_ERR_MULTIPLE_STAT_MODES); goto out; } - // TODO: impossible until we implement branch/windows - // if (result_dim < 1) { - // ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS); - // goto out; - // } ret = tsk_treeseq_check_sample_sets( self, num_sample_sets, sample_set_sizes, sample_sets); if (ret != 0) { goto out; } - tsk_bug_assert(state_dim > 0); + if (result_dim < 1) { + ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS); + goto out; + } ret = sample_sets_to_bit_array( self, sample_set_sizes, sample_sets, num_sample_sets, &sample_sets_bits); if (ret != 0) { @@ -4781,6 +4802,200 @@ tsk_treeseq_f2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, return ret; } +static int +D2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *state_row; + double n; + tsk_size_t k; + tsk_id_t i, j; + double p_A, p_B, p_AB, p_Ab, p_aB, D_i, D_j; + + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + + n = (double) args.sample_set_sizes[i]; + state_row = GET_2D_ROW(state, 3, i); + p_AB = state_row[0] / n; + p_Ab = state_row[1] / n; + p_aB = state_row[2] / n; + p_A = p_AB + p_Ab; + p_B = p_AB + p_aB; + D_i = p_AB - (p_A * p_B); + + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + p_AB = state_row[0] / n; + p_Ab = state_row[1] / n; + p_aB = state_row[2] / n; + p_A = p_AB + p_Ab; + p_B = p_AB + p_aB; + D_j = p_AB - (p_A * p_B); + + result[k] = D_i * D_j; + } + + return 0; +} + +int +tsk_treeseq_D2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, D2_ij_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); +out: + return ret; +} + +static int +D2_ij_unbiased_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *state_row; + tsk_size_t k; + tsk_id_t i, j; + double n_i, n_j; + double w_AB_i, w_Ab_i, w_aB_i, w_ab_i; + double w_AB_j, w_Ab_j, w_aB_j, w_ab_j; + + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + if (i == j) { + // We require disjoint sample sets because we test equality here + n_i = (double) args.sample_set_sizes[i]; + state_row = GET_2D_ROW(state, 3, i); + w_AB_i = state_row[0]; + w_Ab_i = state_row[1]; + w_aB_i = state_row[2]; + w_ab_i = n_i - (w_AB_i + w_Ab_i + w_aB_i); + result[k] = (w_AB_i * (w_AB_i - 1) * w_ab_i * (w_ab_i - 1) + + w_Ab_i * (w_Ab_i - 1) * w_aB_i * (w_aB_i - 1) + - 2 * w_AB_i * w_Ab_i * w_aB_i * w_ab_i) + / n_i / (n_i - 1) / (n_i - 2) / (n_i - 3); + } + + else { + n_i = (double) args.sample_set_sizes[i]; + state_row = GET_2D_ROW(state, 3, i); + w_AB_i = state_row[0]; + w_Ab_i = state_row[1]; + w_aB_i = state_row[2]; + w_ab_i = n_i - (w_AB_i + w_Ab_i + w_aB_i); + + n_j = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + w_AB_j = state_row[0]; + w_Ab_j = state_row[1]; + w_aB_j = state_row[2]; + w_ab_j = n_j - (w_AB_j + w_Ab_j + w_aB_j); + + result[k] = (w_Ab_i * w_aB_i - w_AB_i * w_ab_i) + * (w_Ab_j * w_aB_j - w_AB_j * w_ab_j) / n_i / (n_i - 1) / n_j + / (n_j - 1); + } + } + + return 0; +} + +int +tsk_treeseq_D2_ij_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, D2_ij_unbiased_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); +out: + return ret; +} + +static int +r2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *state_row; + tsk_size_t k; + tsk_id_t i, j; + double n, pAB, pAb, paB, pA, pB, D_i, D_j, denom_i, denom_j; + + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + + n = (double) args.sample_set_sizes[i]; + state_row = GET_2D_ROW(state, 3, i); + pAB = state_row[0] / n; + pAb = state_row[1] / n; + paB = state_row[2] / n; + pA = pAB + pAb; + pB = pAB + paB; + D_i = pAB - (pA * pB); + denom_i = sqrt(pA * (1 - pA) * pB * (1 - pB)); + + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + pAB = state_row[0] / n; + pAb = state_row[1] / n; + paB = state_row[2] / n; + pA = pAB + pAb; + pB = pAB + paB; + D_j = pAB - (pA * pB); + denom_j = sqrt(pA * (1 - pA) * pB * (1 - pB)); + + result[k] = (D_i * D_j) / (denom_i * denom_j); + } + return 0; +} + +int +tsk_treeseq_r2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, r2_ij_summary_func, + norm_hap_weighted_ij, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); +out: + return ret; +} + /*********************************** * Three way stats ***********************************/ diff --git a/c/tskit/trees.h b/c/tskit/trees.h index ac4100f7b0..040e88f86a 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -981,15 +981,17 @@ typedef int general_stat_func_t(tsk_size_t state_dim, const double *state, int tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t K, const double *W, tsk_size_t M, general_stat_func_t *f, void *f_params, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); -// TODO: expose this externally? -/* int tsk_treeseq_two_locus_general_stat(const tsk_treeseq_t *self, */ -/* tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, */ -/* const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, - */ -/* general_stat_func_t *f, norm_func_t *norm_f, tsk_size_t num_left_windows, */ -/* const double *left_windows, tsk_size_t num_right_windows, */ -/* const double *right_windows, tsk_flags_t options, tsk_size_t num_result, */ -/* double *result); */ + +typedef int norm_func_t(tsk_size_t result_dim, const double *hap_weights, tsk_size_t n_a, + tsk_size_t n_b, double *result, void *params); + +int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, + general_stat_func_t *f, norm_func_t *norm_f, tsk_size_t out_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t out_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); /* One way weighted stats */ @@ -1063,24 +1065,6 @@ typedef int general_sample_stat_method(const tsk_treeseq_t *self, const tsk_id_t *sample_sets, tsk_size_t num_indexes, const tsk_id_t *indexes, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); -int tsk_treeseq_divergence(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result); -int tsk_treeseq_Y2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result); -int tsk_treeseq_f2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result); -int tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, - tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, - const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, - const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, - tsk_flags_t options, double *result); - typedef int two_locus_count_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites, @@ -1138,6 +1122,51 @@ int tsk_treeseq_pi2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_se const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result); +typedef int k_way_two_locus_count_stat_method(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, tsk_size_t num_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t num_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result); + +/* Two way sample set stats */ + +int tsk_treeseq_divergence(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_Y2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_f2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, + tsk_flags_t options, double *result); +int tsk_treeseq_D2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_D2_ij_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_r2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); + /* Three way sample set stats */ int tsk_treeseq_Y3(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index cab445e5d0..10a136be27 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -10507,17 +10507,25 @@ TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, PyObject *ret = NULL; static char *kwlist[] = { "sample_set_sizes", "sample_sets", "row_sites", "col_sites", "row_positions", "column_positions", "mode", NULL }; - - PyObject *row_sites = NULL, *col_sites = NULL, *row_positions = NULL, - *col_positions = NULL, *sample_set_sizes = NULL, *sample_sets = NULL; - PyArrayObject *row_sites_array = NULL, *col_sites_array = NULL, - *row_positions_array = NULL, *col_positions_array = NULL, - *sample_sets_array = NULL, *sample_set_sizes_array = NULL, - *result_matrix = NULL; - tsk_id_t *row_sites_parsed = NULL, *col_sites_parsed = NULL; - double *row_positions_parsed = NULL, *col_positions_parsed = NULL; - npy_intp result_dim[3] = { 0, 0, 0 }; + PyObject *sample_set_sizes = NULL; + PyObject *sample_sets = NULL; + PyObject *row_sites = NULL; + PyObject *col_sites = NULL; + PyObject *row_positions = NULL; + PyObject *col_positions = NULL; char *mode = NULL; + PyArrayObject *sample_set_sizes_array = NULL; + PyArrayObject *sample_sets_array = NULL; + PyArrayObject *row_sites_array = NULL; + PyArrayObject *col_sites_array = NULL; + PyArrayObject *row_positions_array = NULL; + PyArrayObject *col_positions_array = NULL; + PyArrayObject *result_matrix = NULL; + tsk_id_t *row_sites_parsed = NULL; + tsk_id_t *col_sites_parsed = NULL; + double *row_positions_parsed = NULL; + double *col_positions_parsed = NULL; + npy_intp result_dim[3] = { 0, 0, 0 }; tsk_size_t num_sample_sets; tsk_flags_t options = 0; int err; @@ -10659,6 +10667,148 @@ TreeSequence_Dz_unbiased_matrix(TreeSequence *self, PyObject *args, PyObject *kw return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_Dz_unbiased); } +static PyObject * +TreeSequence_k_way_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, + npy_intp tuple_size, k_way_two_locus_count_stat_method *method) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "sample_set_sizes", "sample_sets", "indexes", "row_sites", + "col_sites", "row_positions", "column_positions", "mode", NULL }; + PyObject *sample_set_sizes = NULL; + PyObject *sample_sets = NULL; + PyObject *indexes = NULL; + PyObject *row_sites = NULL; + PyObject *col_sites = NULL; + PyObject *row_positions = NULL; + PyObject *col_positions = NULL; + char *mode = NULL; + PyArrayObject *sample_set_sizes_array = NULL; + PyArrayObject *sample_sets_array = NULL; + PyArrayObject *indexes_array = NULL; + PyArrayObject *row_sites_array = NULL; + PyArrayObject *col_sites_array = NULL; + PyArrayObject *row_positions_array = NULL; + PyArrayObject *col_positions_array = NULL; + PyArrayObject *result_matrix = NULL; + tsk_id_t *row_sites_parsed = NULL; + tsk_id_t *col_sites_parsed = NULL; + double *row_positions_parsed = NULL; + double *col_positions_parsed = NULL; + tsk_size_t num_sample_sets; + tsk_size_t num_set_index_tuples; + npy_intp *shape, result_dim[3] = { 0, 0, 0 }; + tsk_flags_t options = 0; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OOOOs", kwlist, &sample_set_sizes, + &sample_sets, &indexes, &row_sites, &col_sites, &row_positions, + &col_positions, &mode)) { + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets, + &sample_sets_array, &num_sample_sets) + != 0) { + goto out; + } + + if (options & TSK_STAT_SITE) { + if (row_positions != Py_None || col_positions != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify positions in site mode"); + goto out; + } + row_sites_array = parse_sites(self, row_sites, &(result_dim[0])); + col_sites_array = parse_sites(self, col_sites, &(result_dim[1])); + if (row_sites_array == NULL || col_sites_array == NULL) { + goto out; + } + row_sites_parsed = PyArray_DATA(row_sites_array); + col_sites_parsed = PyArray_DATA(col_sites_array); + } else if (options & TSK_STAT_BRANCH) { + if (row_sites != Py_None || col_sites != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify sites in branch mode"); + goto out; + } + row_positions_array = parse_positions(self, row_positions, &(result_dim[0])); + col_positions_array = parse_positions(self, col_positions, &(result_dim[1])); + if (col_positions_array == NULL || row_positions_array == NULL) { + goto out; + } + row_positions_parsed = PyArray_DATA(row_positions_array); + col_positions_parsed = PyArray_DATA(col_positions_array); + } + + indexes_array = (PyArrayObject *) PyArray_FROMANY( + indexes, NPY_INT32, 2, 2, NPY_ARRAY_IN_ARRAY); + if (indexes_array == NULL) { + goto out; + } + shape = PyArray_DIMS(indexes_array); + if (shape[0] < 1 || shape[1] != tuple_size) { + PyErr_Format( + PyExc_ValueError, "indexes must be a k x %d array.", (int) tuple_size); + goto out; + } + num_set_index_tuples = shape[0]; + + result_dim[2] = num_set_index_tuples; + result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_dim, NPY_FLOAT64, 0); + if (result_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + + // clang-format off + Py_BEGIN_ALLOW_THREADS + err = method(self->tree_sequence, num_sample_sets, + PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), + num_set_index_tuples, PyArray_DATA(indexes_array), result_dim[0], + row_sites_parsed, row_positions_parsed, result_dim[1], col_sites_parsed, + col_positions_parsed, options, PyArray_DATA(result_matrix)); + Py_END_ALLOW_THREADS + // clang-format on + if (err != 0) + { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_matrix; + result_matrix = NULL; +out: + Py_XDECREF(row_sites_array); + Py_XDECREF(col_sites_array); + Py_XDECREF(row_positions_array); + Py_XDECREF(col_positions_array); + Py_XDECREF(sample_sets_array); + Py_XDECREF(sample_set_sizes_array); + Py_XDECREF(indexes_array); + Py_XDECREF(result_matrix); + return ret; +} + +static PyObject * +TreeSequence_D2_ij_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_ld_matrix(self, args, kwds, 2, tsk_treeseq_D2_ij); +} + +static PyObject * +TreeSequence_D2_ij_unbiased_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_ld_matrix(self, args, kwds, 2, tsk_treeseq_D2_ij_unbiased); +} + +static PyObject * +TreeSequence_r2_ij_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_ld_matrix(self, args, kwds, 2, tsk_treeseq_r2_ij); +} + static PyObject * TreeSequence_get_num_mutations(TreeSequence *self) { @@ -11765,6 +11915,18 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_pi2_unbiased_matrix, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes the unbiased pi2 matrix." }, + { .ml_name = "D2_ij_matrix", + .ml_meth = (PyCFunction) TreeSequence_D2_ij_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the two-way D^2 matrix." }, + { .ml_name = "D2_ij_unbiased_matrix", + .ml_meth = (PyCFunction) TreeSequence_D2_ij_unbiased_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the two-way unbiased D^2 matrix." }, + { .ml_name = "r2_ij_matrix", + .ml_meth = (PyCFunction) TreeSequence_r2_ij_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the two-way r^2 matrix." }, { NULL } /* Sentinel */ }; diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index be8036d5a6..a3eece1e2d 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -244,6 +244,42 @@ def norm_hap_weighted( result[k] = hap_weights[0, k] / n +def norm_hap_weighted_ij( + result_dim: int, + hap_weights: np.ndarray, + n_a: int, + n_b: int, + result: np.ndarray, + params: Dict[str, Any], +) -> None: + """ + Create a vector of normalizing coefficients, length of the number of + index tuples. Each allele's statistic will be weighted by the average + of the proportion of AB haplotypes in each population present in the + index tuple. + + :param result_dim: Number of dimensions in output. Dependent on arity of stat. + :param hap_weights: Proportion of each two-locus haplotype. + :param n_a: Number of alleles at the A locus. + :param n_b: Number of alleles at the B locus. + :param result: Result vector to store the normalizing coefficients in. + :param params: Params of summary function. + """ + del n_a, n_b # handle unused params + sample_set_sizes = params["sample_set_sizes"] + set_indexes = params["set_indexes"] + + for k in range(result_dim): + i = set_indexes[k][0] + j = set_indexes[k][1] + ni = sample_set_sizes[i] + nj = sample_set_sizes[j] + wAB_i = hap_weights[0, i] + wAB_j = hap_weights[0, j] + result[k] = (wAB_i + wAB_j) / (ni + nj) + # result[k] = (wAB_i / ni / 2) + (wAB_j / nj / 2) + + def norm_total_weighted( result_dim: int, hap_weights: np.ndarray, @@ -523,7 +559,6 @@ def compute_general_two_site_stat_result( result_tmp = np.zeros(result_dim, np.float64) polarised_val = 1 if polarised else 0 - for mut_a in range(polarised_val, num_row_alleles): a = int(mut_a + row_site_offset) for mut_b in range(polarised_val, num_col_alleles): @@ -1001,22 +1036,22 @@ def r2_ij_summary_func( i = set_indexes[k][0] j = set_indexes[k][1] n = sample_set_sizes[i] - p_AB = state[0, k] / n - p_Ab = state[1, k] / n - p_aB = state[2, k] / n - p_A = p_AB + p_Ab - p_B = p_AB + p_aB - D_i = p_AB - (p_A * p_B) - denom_i = np.sqrt(p_A * p_B * (1 - p_A) * (1 - p_B)) + pAB = state[0, i] / n + pAb = state[1, i] / n + paB = state[2, i] / n + pA = pAB + pAb + pB = pAB + paB + D_i = pAB - pA * pB + denom_i = np.sqrt(pA * (1 - pA) * pB * (1 - pB)) n = sample_set_sizes[j] - p_AB = state[0, k] / n - p_Ab = state[1, k] / n - p_aB = state[2, k] / n - p_A = p_AB + p_Ab - p_B = p_AB + p_aB - D_j = p_AB - (p_A * p_B) - denom_j = np.sqrt(p_A * p_B * (1 - p_A) * (1 - p_B)) + pAB = state[0, j] / n + pAb = state[1, j] / n + paB = state[2, j] / n + pA = pAB + pAb + pB = pAB + paB + D_j = pAB - pA * pB + denom_j = np.sqrt(pA * (1 - pA) * pB * (1 - pB)) with suppress_overflow_div0_warning(): result[k] = (D_i * D_j) / (denom_i * denom_j) @@ -1249,17 +1284,17 @@ def D2_ij_summary_func( j = set_indexes[k][1] n = sample_set_sizes[i] - p_AB = state[0, k] / n - p_Ab = state[1, k] / n - p_aB = state[2, k] / n + p_AB = state[0, i] / n + p_Ab = state[1, i] / n + p_aB = state[2, i] / n p_A = p_AB + p_Ab p_B = p_AB + p_aB D_i = p_AB - (p_A * p_B) n = sample_set_sizes[j] - p_AB = state[0, k] / n - p_Ab = state[1, k] / n - p_aB = state[2, k] / n + p_AB = state[0, j] / n + p_Ab = state[1, j] / n + p_aB = state[2, j] / n p_A = p_AB + p_Ab p_B = p_AB + p_aB D_j = p_AB - (p_A * p_B) @@ -1287,17 +1322,18 @@ def D2_ij_unbiased_summary_func( w_Ab = state[1, i] w_aB = state[2, i] w_ab = n - (w_AB + w_Ab + w_aB) - result[k] = ( - ( - w_AB * (w_AB - 1) * w_ab * (w_ab - 1) - + w_Ab * (w_Ab - 1) * w_aB * (w_aB - 1) - - 2 * w_AB * w_Ab * w_aB * w_ab + with suppress_overflow_div0_warning(): + result[k] = ( + ( + w_AB * (w_AB - 1) * w_ab * (w_ab - 1) + + w_Ab * (w_Ab - 1) * w_aB * (w_aB - 1) + - 2 * w_AB * w_Ab * w_aB * w_ab + ) + / n + / (n - 1) + / (n - 2) + / (n - 3) ) - / n - / (n - 1) - / (n - 2) - / (n - 3) - ) else: n_i = sample_set_sizes[i] w_AB_i = state[0, i] @@ -1311,14 +1347,15 @@ def D2_ij_unbiased_summary_func( w_aB_j = state[2, j] w_ab_j = n_j - (w_AB_j + w_Ab_j + w_aB_j) - result[k] = ( - (w_Ab_i * w_aB_i - w_AB_i * w_ab_i) - * (w_Ab_j * w_aB_j - w_AB_j * w_ab_j) - / n_i - / (n_i - 1) - / n_j - / (n_j - 1) - ) + with suppress_overflow_div0_warning(): + result[k] = ( + (w_Ab_i * w_aB_i - w_AB_i * w_ab_i) + * (w_Ab_j * w_aB_j - w_AB_j * w_ab_j) + / n_i + / (n_i - 1) + / n_j + / (n_j - 1) + ) SUMMARY_FUNCS = { @@ -1351,7 +1388,7 @@ def D2_ij_unbiased_summary_func( D2_unbiased_summary_func: norm_total_weighted, Dz_unbiased_summary_func: norm_total_weighted, pi2_unbiased_summary_func: norm_total_weighted, - r2_ij_summary_func: norm_hap_weighted, + r2_ij_summary_func: norm_hap_weighted_ij, D2_ij_summary_func: norm_total_weighted, D2_ij_unbiased_summary_func: norm_total_weighted, } @@ -1373,9 +1410,11 @@ def D2_ij_unbiased_summary_func( } -def check_set_indexes(num_sets: int, num_set_indexes: int, set_indexes: np.ndarray): - for i in range(len(set_indexes)): - for j in range(num_set_indexes): +def check_set_indexes( + num_sets: int, num_set_indexes: int, tuple_size: int, set_indexes: np.ndarray +): + for i in range(num_set_indexes): + for j in range(tuple_size): if set_indexes[i, j] < 0 or set_indexes[i, j] >= num_sets: raise ValueError(f"Bad sample set index: {set_indexes[i, j]}") @@ -1393,7 +1432,7 @@ def check_sample_stat_inputs( ) if num_index_tuples < 1: raise ValueError(f"Insufficient number of index tuples: {num_index_tuples}") - check_set_indexes(num_sample_sets, num_index_tuples, index_tuples) + check_set_indexes(num_sample_sets, num_index_tuples, tuple_size, index_tuples) def ld_matrix( @@ -1761,6 +1800,17 @@ def test_input_validation(): with pytest.raises(ValueError, match="must be a length 1 or 2 list"): ts.ld_matrix(positions=[], mode="branch") + with pytest.raises( + ValueError, match="Sample sets must contain at least one element" + ): + ts.ld_matrix(sample_sets=[[1, 2, 3], []], indexes=[]) + with pytest.raises( + ValueError, match="Indexes must be convertable to a 2D numpy array" + ): + ts.ld_matrix( + sample_sets=[ts.samples(), ts.samples()], indexes=[[1, 2, 3], [2, 3, 4]] + ) + @dataclass class TreeState: @@ -2152,6 +2202,49 @@ def test_branch_ld_matrix_2pop_sample_sets_unbiased(ts, sample_set, stat): ) +def gen_dims_test_cases(ts, mode): + ss = ts.samples() + dim = ts.num_sites if mode == "site" else ts.num_trees + base = (dim, dim) + return [ + {"name": f"{mode}_default", "ld_params": {"mode": mode}, "shape": base}, + { + "name": f"{mode}_dim_drop", + "ld_params": {"mode": mode, "sample_sets": ss}, + "shape": base, + }, + { + "name": f"{mode}_no_dim_drop", + "ld_params": {"mode": mode, "sample_sets": [ss]}, + "shape": (1, *base), + }, + { + "name": f"{mode}_two_sample_sets", + "ld_params": {"mode": mode, "sample_sets": [ss, ss]}, + "shape": (2, *base), + }, + { + "name": f"{mode}_two_way_dim_drop", + "ld_params": {"mode": mode, "sample_sets": [ss, ss], "indexes": (0, 1)}, + "shape": base, + }, + { + "name": f"{mode}_two_way_no_dim_drop", + "ld_params": {"mode": mode, "sample_sets": [ss, ss], "indexes": [(0, 1)]}, + "shape": (1, *base), + }, + { + "name": f"{mode}_two_way_three_set_indexes", + "ld_params": { + "mode": mode, + "sample_sets": [ss, ss], + "indexes": [(0, 0), (0, 1), (1, 1)], + }, + "shape": (3, *base), + }, + ] + + def get_test_dims_test_cases(): test_cases = { "empty_tree", @@ -2161,16 +2254,155 @@ def get_test_dims_test_cases(): "internal_nodes_samples", "mixed_internal_leaf_samples", } - return [t for t in get_example_tree_sequences() if t.id in test_cases] + for ts_case in [t for t in get_example_tree_sequences() if t.id in test_cases]: + ts = ts_case.values[0] + for dim_case in gen_dims_test_cases(ts, "site"): + name = "_".join([dim_case["name"], ts_case.id]) + yield pytest.param(ts, dim_case["ld_params"], dim_case["shape"], id=name) + for dim_case in gen_dims_test_cases(ts, "branch"): + name = "_".join([dim_case["name"], ts_case.id]) + yield pytest.param(ts, dim_case["ld_params"], dim_case["shape"], id=name) -@pytest.mark.parametrize("ts", get_test_dims_test_cases()) -def test_dims(ts): - ss = ts.samples() - assert ld_matrix(ts).ndim == 2 - assert ld_matrix(ts, sample_sets=ss).ndim == 2 - assert ld_matrix(ts, sample_sets=[ss]).ndim == 3 - assert ld_matrix(ts, sample_sets=[ss, ss]).ndim == 3 - assert ld_matrix(ts, sample_sets=[ss, ss], indexes=(0, 0)).ndim == 2 - assert ld_matrix(ts, sample_sets=[ss, ss], indexes=[(0, 0)]).ndim == 3 - assert ld_matrix(ts, sample_sets=[ss, ss], indexes=[(0, 0), (0, 1)]).ndim == 3 +@pytest.mark.parametrize("ts,params,shape", get_test_dims_test_cases()) +def test_dims(ts, params, shape): + assert ts.ld_matrix(**params).shape == ld_matrix(ts, **params).shape == shape + + +@pytest.mark.parametrize("ts,sample_sets", get_test_branch_2pop_test_cases()) +@pytest.mark.parametrize("stat", sorted(TWO_WAY_SUMMARY_FUNCS.keys())) +def test_two_way_branch_ld_matrix(ts, sample_sets, stat): + np.testing.assert_array_almost_equal( + ld_matrix(ts, sample_sets=sample_sets, indexes=[(0, 0), (0, 1), (1, 1)]), + ts.ld_matrix(sample_sets=sample_sets, indexes=[(0, 0), (0, 1), (1, 1)]), + ) + + +@pytest.mark.parametrize( + "ts", + [ + ts + for ts in get_example_tree_sequences() + if ts.id not in {"no_samples", "empty_ts"} + ], +) +@pytest.mark.parametrize( + "stat", + sorted(TWO_WAY_SUMMARY_FUNCS.keys()), +) +def test_two_way_site_ld_matrix(ts, stat): + np.testing.assert_array_almost_equal( + ld_matrix(ts, stat=stat), ts.ld_matrix(stat=stat) + ) + ss = [ts.samples()] * 3 + np.testing.assert_array_almost_equal( + ld_matrix(ts, stat=stat, sample_sets=ss, indexes=[(0, 0), (0, 1), (1, 1)]), + ts.ld_matrix(stat=stat, sample_sets=ss, indexes=[(0, 0), (0, 1), (1, 1)]), + ) + + +@pytest.mark.parametrize( + "genotypes,sample_sets,expected", + [ + ( + # these genotypes are rows from a genotype matrix (sites x samples) + correlated := np.array( + [ + [0, 1, 1, 0, 2, 2, 1, 0, 2, 0, 1, 2], + [1, 2, 2, 1, 0, 0, 2, 1, 0, 1, 2, 0], + ], + ), + (np.array([0, 1, 2, 3, 4, 5]), np.array([6, 7, 8, 9, 10, 11])), + np.float64(1.0), + ), + ( + correlated, + (np.array([0, 1, 2, 3, 4, 5]), np.array([6, 7, 8, 9, 10])), + np.float64(1.0), + ), + ( + correlated, + (np.array([0, 1, 2, 3, 4, 5]), np.array([6, 7, 8, 9])), + np.float64(1.0), + ), + ( + correlated, + (np.array([0, 1, 2, 3, 4, 5]), np.array([6, 7, 8])), + np.float64(1.0), + ), + ( + correlated, + (np.array([0, 1, 2, 3, 4, 5]), np.array([6, 7])), + np.float64(np.nan), + ), + ( + correlated, + (np.array([0, 1, 2, 3, 4, 5]), np.array([6])), + np.float64(np.nan), + ), + ( + anticorrelated := np.array( + [ + [0, 0, 0, 0, 2, 2, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3], + [1, 1, 1, 1, 3, 3, 3, 3, 0, 0, 0, 0, 2, 2, 2, 2], + ] + ), + ( + np.array([0, 2, 4, 6, 8, 10, 12, 14]), + np.array([1, 3, 5, 7, 9, 11, 13, 15]), + ), + np.float64(1.0), + ), + ( + anticorrelated, + (np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1, 3, 5, 7, 9, 11, 13])), + np.float64(1.0), + ), + ( + anticorrelated, + (np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1, 3, 5, 7, 9, 11])), + np.float64(np.nan), + ), + ( + anticorrelated, + (np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1, 3, 5, 7, 9])), + np.float64(np.nan), + ), + ( + anticorrelated, + (np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1, 3, 5, 7])), + np.float64(np.nan), + ), + ( + anticorrelated, + (np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1, 3, 5])), + np.float64(np.nan), + ), + ( + anticorrelated, + (np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1, 3])), + np.float64(np.nan), + ), + ( + anticorrelated, + (np.array([0, 2, 4, 6, 8, 10, 12, 14]), np.array([1])), + np.float64(np.nan), + ), + ], +) +def test_multipopulation_r2_varying_unequal_set_sizes(genotypes, sample_sets, expected): + a, b = genotypes + state_dim = len(sample_sets) + state = np.zeros((3, state_dim), dtype=int) + result = np.zeros((max(a) + 1, max(b) + 1, 1)) + norm = np.zeros_like(result) + params = dict(sample_set_sizes=list(map(len, sample_sets)), set_indexes=[(0, 1)]) + for i, j in np.ndindex(result.shape[:2]): + for k, ss in enumerate(sample_sets): + A = a[ss] == i + B = b[ss] == j + state[:, k] = (A & B).sum(), (A & ~B).sum(), (~A & B).sum() + r2_ij_summary_func(state_dim, state, 1, result[i, j], params) + norm_hap_weighted_ij(1, state, max(a) + 1, max(b) + 1, norm[i, j], params) + + np.testing.assert_allclose((result * norm).sum(), expected) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 9cc8206cb1..d6416c22b8 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1711,6 +1711,214 @@ def test_ld_matrix(self, stat_method_name): with pytest.raises(_tskit.LibraryError, match="TSK_ERR_UNSUPPORTED_STAT_MODE"): stat_method(ss_sizes, ss, col_sites, row_sites, None, None, "node") + @pytest.mark.parametrize( + "stat_method_name", + [ + "D2_ij_matrix", + "r2_ij_matrix", + "D2_ij_unbiased_matrix", + ], + ) + def test_ld_matrix_multipop(self, stat_method_name): + ts = self.get_example_tree_sequence(10) + stat_method = getattr(ts, stat_method_name) + + num_samples = len(ts.get_samples()) + ss = np.hstack([ts.get_samples(), ts.get_samples()]) # sample sets + ss_sizes = np.array([num_samples, num_samples], dtype=np.uint32) + indexes = [(0, 0), (0, 1)] + row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) + col_sites = row_sites + row_pos = ts.get_breakpoints()[:-1] + col_pos = row_pos + row_pos_list = list(map(float, ts.get_breakpoints()[:-1])) + col_pos_list = row_pos_list + row_sites_list = list(range(ts.get_num_sites())) + col_sites_list = row_sites_list + + # happy path + a = stat_method(ss_sizes, ss, indexes, row_sites, col_sites, None, None, "site") + assert a.shape == (10, 10, 2) + a = stat_method( + ss_sizes, ss, indexes, row_sites_list, col_sites_list, None, None, "site" + ) + assert a.shape == (10, 10, 2) + a = stat_method(ss_sizes, ss, indexes, None, None, None, None, "site") + assert a.shape == (10, 10, 2) + + a = stat_method(ss_sizes, ss, indexes, None, None, row_pos, col_pos, "branch") + assert a.shape == (2, 2, 2) + a = stat_method( + ss_sizes, ss, indexes, None, None, row_pos_list, col_pos_list, "branch" + ) + assert a.shape == (2, 2, 2) + a = stat_method(ss_sizes, ss, indexes, None, None, None, None, "branch") + assert a.shape == (2, 2, 2) + + # CPython API errors + with pytest.raises(ValueError, match="Sum of sample_set_sizes"): + bad_ss = np.array([], dtype=np.int32) + stat_method( + ss_sizes, bad_ss, indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises(TypeError, match="cast array data"): + bad_ss = np.array(ts.get_samples(), dtype=np.uint32) + stat_method( + ss_sizes, bad_ss, indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises(ValueError, match="Unrecognised stats mode"): + stat_method(ss_sizes, ss, indexes, row_sites, col_sites, None, None, "bla") + with pytest.raises(TypeError, match="at most"): + stat_method( + ss_sizes, ss, indexes, row_sites, col_sites, None, None, "site", "abc" + ) + with pytest.raises(ValueError, match="invalid literal"): + bad_sites = ["abadsite", 0, 3, 2] + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(TypeError): + bad_sites = [None, 0, 3, 2] + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(TypeError): + bad_sites = [{}, 0, 3, 2] + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(TypeError, match="Cannot cast array data"): + bad_sites = np.array([0, 1, 2], dtype=np.uint32) + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(ValueError, match="invalid literal"): + bad_sites = ["abadsite", 0, 3, 2] + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(TypeError): + bad_sites = [None, 0, 3, 2] + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(TypeError): + bad_sites = [{}, 0, 3, 2] + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(TypeError, match="Cannot cast array data"): + bad_sites = np.array([0, 1, 2], dtype=np.uint32) + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0.1, 0.2, 2.0] + stat_method(ss_sizes, ss, indexes, None, None, bad_pos, col_pos, "branch") + with pytest.raises(TypeError): + bad_pos = [{}, 0.1, 0.2, 2.0] + stat_method(ss_sizes, ss, indexes, None, None, bad_pos, col_pos, "branch") + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0, 3, 2] + stat_method(ss_sizes, ss, indexes, None, None, row_pos, bad_pos, "branch") + with pytest.raises(TypeError): + bad_pos = [{}, 0, 3, 2] + stat_method(ss_sizes, ss, indexes, None, None, row_pos, bad_pos, "branch") + with pytest.raises(ValueError, match="Cannot specify sites in branch mode"): + stat_method( + ss_sizes, ss, indexes, row_sites, col_sites, None, None, "branch" + ) + with pytest.raises(ValueError, match="Cannot specify positions in site mode"): + stat_method(ss_sizes, ss, indexes, None, None, row_pos, col_pos, "site") + # C API errors + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): + bad_sites = np.array([1, 0, 2], dtype=np.int32) + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): + bad_sites = np.array([1, 0, 2], dtype=np.int32) + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, indexes, None, None, bad_pos, col_pos, "branch") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, indexes, None, None, row_pos, bad_pos, "branch") + with pytest.raises( + tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS" + ): + bad_pos = np.array([0.7, 0.7, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, indexes, None, None, bad_pos, col_pos, "branch") + with pytest.raises( + tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS" + ): + bad_pos = np.array([0.7, 0.7, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, indexes, None, None, row_pos, bad_pos, "branch") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, indexes, None, None, bad_pos, col_pos, "branch") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, indexes, None, None, row_pos, bad_pos, "branch") + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_INSUFFICIENT_SAMPLE_SETS" + ): + bad_ss = np.array([], dtype=np.int32) + bad_ss_sizes = np.array([], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_INSUFFICIENT_SAMPLE_SETS" + ): + bad_ss = np.array([], dtype=np.int32) + bad_ss_sizes = np.array([], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, None, None, row_pos, col_pos, "branch" + ) + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_INSUFFICIENT_SAMPLE_SETS" + ): + bad_ss = np.array([], dtype=np.int32) + bad_ss_sizes = np.array([0], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_INSUFFICIENT_SAMPLE_SETS" + ): + bad_ss = np.array([], dtype=np.int32) + bad_ss_sizes = np.array([0], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, None, None, row_pos, col_pos, "branch" + ) + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"): + bad_ss = np.array([1000, 1000], dtype=np.int32) + bad_ss_sizes = np.array([1, 1], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"): + bad_ss = np.array([1000, 1000], dtype=np.int32) + bad_ss_sizes = np.array([1, 1], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, None, None, row_pos, col_pos, "branch" + ) + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_DUPLICATE_SAMPLE"): + bad_ss = np.array([1, 1, 2, 3], dtype=np.int32) + bad_ss_sizes = np.array([2, 2], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_DUPLICATE_SAMPLE"): + bad_ss = np.array([1, 1, 2, 3], dtype=np.int32) + bad_ss_sizes = np.array([2, 2], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, None, None, row_pos, col_pos, "branch" + ) + with pytest.raises(ValueError, match="indexes must be a"): + bad_indexes = np.array([[0, 0, 1, 1], [0, 0, 1, 1]], dtype=np.int32) + stat_method( + ss_sizes, ss, bad_indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_UNSUPPORTED_STAT_MODE"): + stat_method(ss_sizes, ss, indexes, col_sites, row_sites, None, None, "node") + def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) with pytest.raises(TypeError): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index e740181f81..77ccbf4838 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8095,6 +8095,52 @@ def __two_locus_sample_set_stat( return result + def __k_way_two_locus_sample_set_stat( + self, + ll_method, + k, + sample_sets, + indexes=None, + sites=None, + positions=None, + mode=None, + ): + sample_set_sizes = np.array( + [len(sample_set) for sample_set in sample_sets], dtype=np.uint32 + ) + if np.any(sample_set_sizes == 0): + raise ValueError("Sample sets must contain at least one element") + flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) + drop_dimension = False + indexes = util.safe_np_int_cast(indexes, np.int32) + if len(indexes.shape) == 1: + indexes = indexes.reshape((1, indexes.shape[0])) + drop_dimension = True + if len(indexes.shape) != 2 or indexes.shape[1] != k: + raise ValueError( + "Indexes must be convertable to a 2D numpy array with {} " + "columns".format(k) + ) + result = ll_method( + sample_set_sizes, + flattened, + indexes, + row_sites, + col_sites, + row_positions, + col_positions, + mode, + ) + if drop_dimension: + result = result.reshape(result.shape[:2]) + else: + # Orient the data so that the first dimension is the sample set. + # With this orientation, we get one LD matrix per sample set. + result = result.swapaxes(0, 2).swapaxes(1, 2) + return result + def __k_way_sample_set_stat( self, ll_method, @@ -10627,9 +10673,15 @@ def impute_unknown_mutations_time( return mutations_time def ld_matrix( - self, sample_sets=None, sites=None, positions=None, mode="site", stat="r2" + self, + sample_sets=None, + sites=None, + positions=None, + mode="site", + stat="r2", + indexes=None, ): - stats = { + one_way_stats = { "D": self._ll_tree_sequence.D_matrix, "D2": self._ll_tree_sequence.D2_matrix, "r2": self._ll_tree_sequence.r2_matrix, @@ -10641,20 +10693,32 @@ def ld_matrix( "D2_unbiased": self._ll_tree_sequence.D2_unbiased_matrix, "pi2_unbiased": self._ll_tree_sequence.pi2_unbiased_matrix, } - + two_way_stats = { + "D2": self._ll_tree_sequence.D2_ij_matrix, + "D2_unbiased": self._ll_tree_sequence.D2_ij_unbiased_matrix, + "r2": self._ll_tree_sequence.r2_ij_matrix, + } + stats = one_way_stats if indexes is None else two_way_stats try: - two_locus_stat = stats[stat] + stat_func = stats[stat] except KeyError: raise ValueError( f"Unknown two-locus statistic '{stat}', we support: {list(stats.keys())}" ) + if indexes is not None: + return self.__k_way_two_locus_sample_set_stat( + stat_func, + 2, + sample_sets, + indexes=indexes, + sites=sites, + positions=positions, + mode=mode, + ) + return self.__two_locus_sample_set_stat( - two_locus_stat, - sample_sets, - sites=sites, - positions=positions, - mode=mode, + stat_func, sample_sets, sites=sites, positions=positions, mode=mode ) def sample_nodes_by_ploidy(self, ploidy):