Skip to content

Commit 8919730

Browse files
lkirkjeromekelleher
authored andcommitted
C and Python API for two-way two-locus stats
This PR implements the C and Python API for computing two-way two-locus statistics. The algorithm is identical to the python version, except during testing I uncovered a small issue with normalisation. We need to handle the case where sample sets are of different sizes. The fix for this was to combine the normalisation factor for each sample set. Test coverage has been added to cover C, low-level python and some high-level tests.
1 parent 154168f commit 8919730

File tree

7 files changed

+1364
-234
lines changed

7 files changed

+1364
-234
lines changed

c/tests/test_stats.c

Lines changed: 329 additions & 109 deletions
Large diffs are not rendered by default.

c/tskit/trees.c

Lines changed: 239 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2225,15 +2225,15 @@ get_allele_samples(const tsk_site_t *site, const tsk_bit_array_t *state,
22252225
}
22262226

22272227
static int
2228-
norm_hap_weighted(tsk_size_t state_dim, const double *hap_weights,
2228+
norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights,
22292229
tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params)
22302230
{
22312231
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
22322232
const double *weight_row;
22332233
double n;
22342234
tsk_size_t k;
22352235

2236-
for (k = 0; k < state_dim; k++) {
2236+
for (k = 0; k < result_dim; k++) {
22372237
weight_row = GET_2D_ROW(hap_weights, 3, k);
22382238
n = (double) args.sample_set_sizes[k];
22392239
// TODO: what to do when n = 0
@@ -2243,12 +2243,38 @@ norm_hap_weighted(tsk_size_t state_dim, const double *hap_weights,
22432243
}
22442244

22452245
static int
2246-
norm_total_weighted(tsk_size_t state_dim, const double *TSK_UNUSED(hap_weights),
2246+
norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights,
2247+
tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params)
2248+
{
2249+
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
2250+
const double *weight_row;
2251+
double ni, nj, wAB_i, wAB_j;
2252+
tsk_id_t i, j;
2253+
tsk_size_t k;
2254+
2255+
for (k = 0; k < result_dim; k++) {
2256+
i = args.set_indexes[2 * k];
2257+
j = args.set_indexes[2 * k + 1];
2258+
ni = (double) args.sample_set_sizes[i];
2259+
nj = (double) args.sample_set_sizes[j];
2260+
weight_row = GET_2D_ROW(hap_weights, 3, i);
2261+
wAB_i = weight_row[0];
2262+
weight_row = GET_2D_ROW(hap_weights, 3, j);
2263+
wAB_j = weight_row[0];
2264+
2265+
result[k] = (wAB_i + wAB_j) / (ni + nj);
2266+
}
2267+
2268+
return 0;
2269+
}
2270+
2271+
static int
2272+
norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights),
22472273
tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params))
22482274
{
22492275
tsk_size_t k;
22502276

2251-
for (k = 0; k < state_dim; k++) {
2277+
for (k = 0; k < result_dim; k++) {
22522278
result[k] = 1 / (double) (n_a * n_b);
22532279
}
22542280
return 0;
@@ -2268,9 +2294,6 @@ get_all_samples_bits(tsk_bit_array_t *all_samples, tsk_size_t n)
22682294
}
22692295
}
22702296

2271-
typedef int norm_func_t(tsk_size_t state_dim, const double *hap_weights, tsk_size_t n_a,
2272-
tsk_size_t n_b, double *result, void *params);
2273-
22742297
static int
22752298
compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
22762299
const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles,
@@ -2290,14 +2313,15 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
22902313
// a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
22912314
// a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
22922315
tsk_size_t k, mut_a, mut_b;
2293-
tsk_size_t row_len = num_b_alleles * state_dim;
2316+
tsk_size_t result_row_len = num_b_alleles * result_dim;
22942317
tsk_size_t w_A = 0, w_B = 0, w_AB = 0;
22952318
uint8_t polarised_val = polarised ? 1 : 0;
22962319
double *hap_weight_row;
22972320
double *result_tmp_row;
22982321
double *weights = tsk_malloc(3 * state_dim * sizeof(*weights));
2299-
double *norm = tsk_malloc(state_dim * sizeof(*norm));
2300-
double *result_tmp = tsk_malloc(row_len * num_a_alleles * sizeof(*result_tmp));
2322+
double *norm = tsk_malloc(result_dim * sizeof(*norm));
2323+
double *result_tmp
2324+
= tsk_malloc(result_row_len * num_a_alleles * sizeof(*result_tmp));
23012325

23022326
tsk_memset(&ss_A_samples, 0, sizeof(ss_A_samples));
23032327
tsk_memset(&ss_B_samples, 0, sizeof(ss_B_samples));
@@ -2327,7 +2351,7 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
23272351
}
23282352

23292353
for (mut_a = polarised_val; mut_a < num_a_alleles; mut_a++) {
2330-
result_tmp_row = GET_2D_ROW(result_tmp, row_len, mut_a);
2354+
result_tmp_row = GET_2D_ROW(result_tmp, result_row_len, mut_a);
23312355
for (mut_b = polarised_val; mut_b < num_b_alleles; mut_b++) {
23322356
tsk_bit_array_get_row(site_a_state, mut_a, &A_samples);
23332357
tsk_bit_array_get_row(site_b_state, mut_b, &B_samples);
@@ -2352,15 +2376,15 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
23522376
if (ret != 0) {
23532377
goto out;
23542378
}
2355-
ret = norm_f(state_dim, weights, num_a_alleles - polarised_val,
2379+
ret = norm_f(result_dim, weights, num_a_alleles - polarised_val,
23562380
num_b_alleles - polarised_val, norm, f_params);
23572381
if (ret != 0) {
23582382
goto out;
23592383
}
2360-
for (k = 0; k < state_dim; k++) {
2384+
for (k = 0; k < result_dim; k++) {
23612385
result[k] += result_tmp_row[k] * norm[k];
23622386
}
2363-
result_tmp_row += state_dim; // Advance to the next column
2387+
result_tmp_row += result_dim; // Advance to the next column
23642388
}
23652389
}
23662390

@@ -2538,8 +2562,8 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
25382562
get_site_row_col_indices(
25392563
n_rows, row_sites, n_cols, col_sites, sites, &n_sites, row_idx, col_idx);
25402564

2541-
// We rely on n_sites to allocate these arrays, they're initialized to NULL for safe
2542-
// deallocation if the previous allocation fails
2565+
// We rely on n_sites to allocate these arrays, which are initialized
2566+
// to NULL for safe deallocation if the previous allocation fails
25432567
num_alleles = tsk_malloc(n_sites * sizeof(*num_alleles));
25442568
site_offsets = tsk_malloc(n_sites * sizeof(*site_offsets));
25452569
if (num_alleles == NULL || site_offsets == NULL) {
@@ -3195,7 +3219,7 @@ tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_di
31953219
return ret;
31963220
}
31973221

3198-
static int
3222+
int
31993223
tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
32003224
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
32013225
tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f,
@@ -3209,7 +3233,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
32093233
tsk_bit_array_t sample_sets_bits;
32103234
bool stat_site = !!(options & TSK_STAT_SITE);
32113235
bool stat_branch = !!(options & TSK_STAT_BRANCH);
3212-
// double default_windows[] = { 0, self->tables->sequence_length };
32133236
tsk_size_t state_dim = num_sample_sets;
32143237
sample_count_stat_params_t f_params = { .sample_sets = sample_sets,
32153238
.num_sample_sets = num_sample_sets,
@@ -3232,17 +3255,15 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
32323255
ret = tsk_trace_error(TSK_ERR_MULTIPLE_STAT_MODES);
32333256
goto out;
32343257
}
3235-
// TODO: impossible until we implement branch/windows
3236-
// if (result_dim < 1) {
3237-
// ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS);
3238-
// goto out;
3239-
// }
32403258
ret = tsk_treeseq_check_sample_sets(
32413259
self, num_sample_sets, sample_set_sizes, sample_sets);
32423260
if (ret != 0) {
32433261
goto out;
32443262
}
3245-
tsk_bug_assert(state_dim > 0);
3263+
if (result_dim < 1) {
3264+
ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS);
3265+
goto out;
3266+
}
32463267
ret = sample_sets_to_bit_array(
32473268
self, sample_set_sizes, sample_sets, num_sample_sets, &sample_sets_bits);
32483269
if (ret != 0) {
@@ -4781,6 +4802,200 @@ tsk_treeseq_f2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
47814802
return ret;
47824803
}
47834804

4805+
static int
4806+
D2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state,
4807+
tsk_size_t result_dim, double *result, void *params)
4808+
{
4809+
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
4810+
const double *state_row;
4811+
double n;
4812+
tsk_size_t k;
4813+
tsk_id_t i, j;
4814+
double p_A, p_B, p_AB, p_Ab, p_aB, D_i, D_j;
4815+
4816+
for (k = 0; k < result_dim; k++) {
4817+
i = args.set_indexes[2 * k];
4818+
j = args.set_indexes[2 * k + 1];
4819+
4820+
n = (double) args.sample_set_sizes[i];
4821+
state_row = GET_2D_ROW(state, 3, i);
4822+
p_AB = state_row[0] / n;
4823+
p_Ab = state_row[1] / n;
4824+
p_aB = state_row[2] / n;
4825+
p_A = p_AB + p_Ab;
4826+
p_B = p_AB + p_aB;
4827+
D_i = p_AB - (p_A * p_B);
4828+
4829+
n = (double) args.sample_set_sizes[j];
4830+
state_row = GET_2D_ROW(state, 3, j);
4831+
p_AB = state_row[0] / n;
4832+
p_Ab = state_row[1] / n;
4833+
p_aB = state_row[2] / n;
4834+
p_A = p_AB + p_Ab;
4835+
p_B = p_AB + p_aB;
4836+
D_j = p_AB - (p_A * p_B);
4837+
4838+
result[k] = D_i * D_j;
4839+
}
4840+
4841+
return 0;
4842+
}
4843+
4844+
int
4845+
tsk_treeseq_D2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
4846+
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
4847+
tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows,
4848+
const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols,
4849+
const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options,
4850+
double *result)
4851+
{
4852+
int ret = 0;
4853+
ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples);
4854+
if (ret != 0) {
4855+
goto out;
4856+
}
4857+
ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes,
4858+
sample_sets, num_index_tuples, index_tuples, D2_ij_summary_func,
4859+
norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites,
4860+
col_positions, options, result);
4861+
out:
4862+
return ret;
4863+
}
4864+
4865+
static int
4866+
D2_ij_unbiased_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state,
4867+
tsk_size_t result_dim, double *result, void *params)
4868+
{
4869+
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
4870+
const double *state_row;
4871+
tsk_size_t k;
4872+
tsk_id_t i, j;
4873+
double n_i, n_j;
4874+
double w_AB_i, w_Ab_i, w_aB_i, w_ab_i;
4875+
double w_AB_j, w_Ab_j, w_aB_j, w_ab_j;
4876+
4877+
for (k = 0; k < result_dim; k++) {
4878+
i = args.set_indexes[2 * k];
4879+
j = args.set_indexes[2 * k + 1];
4880+
if (i == j) {
4881+
// We require disjoint sample sets because we test equality here
4882+
n_i = (double) args.sample_set_sizes[i];
4883+
state_row = GET_2D_ROW(state, 3, i);
4884+
w_AB_i = state_row[0];
4885+
w_Ab_i = state_row[1];
4886+
w_aB_i = state_row[2];
4887+
w_ab_i = n_i - (w_AB_i + w_Ab_i + w_aB_i);
4888+
result[k] = (w_AB_i * (w_AB_i - 1) * w_ab_i * (w_ab_i - 1)
4889+
+ w_Ab_i * (w_Ab_i - 1) * w_aB_i * (w_aB_i - 1)
4890+
- 2 * w_AB_i * w_Ab_i * w_aB_i * w_ab_i)
4891+
/ n_i / (n_i - 1) / (n_i - 2) / (n_i - 3);
4892+
}
4893+
4894+
else {
4895+
n_i = (double) args.sample_set_sizes[i];
4896+
state_row = GET_2D_ROW(state, 3, i);
4897+
w_AB_i = state_row[0];
4898+
w_Ab_i = state_row[1];
4899+
w_aB_i = state_row[2];
4900+
w_ab_i = n_i - (w_AB_i + w_Ab_i + w_aB_i);
4901+
4902+
n_j = (double) args.sample_set_sizes[j];
4903+
state_row = GET_2D_ROW(state, 3, j);
4904+
w_AB_j = state_row[0];
4905+
w_Ab_j = state_row[1];
4906+
w_aB_j = state_row[2];
4907+
w_ab_j = n_j - (w_AB_j + w_Ab_j + w_aB_j);
4908+
4909+
result[k] = (w_Ab_i * w_aB_i - w_AB_i * w_ab_i)
4910+
* (w_Ab_j * w_aB_j - w_AB_j * w_ab_j) / n_i / (n_i - 1) / n_j
4911+
/ (n_j - 1);
4912+
}
4913+
}
4914+
4915+
return 0;
4916+
}
4917+
4918+
int
4919+
tsk_treeseq_D2_ij_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
4920+
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
4921+
tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows,
4922+
const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols,
4923+
const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options,
4924+
double *result)
4925+
{
4926+
int ret = 0;
4927+
ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples);
4928+
if (ret != 0) {
4929+
goto out;
4930+
}
4931+
ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes,
4932+
sample_sets, num_index_tuples, index_tuples, D2_ij_unbiased_summary_func,
4933+
norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites,
4934+
col_positions, options, result);
4935+
out:
4936+
return ret;
4937+
}
4938+
4939+
static int
4940+
r2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state,
4941+
tsk_size_t result_dim, double *result, void *params)
4942+
{
4943+
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
4944+
const double *state_row;
4945+
tsk_size_t k;
4946+
tsk_id_t i, j;
4947+
double n, pAB, pAb, paB, pA, pB, D_i, D_j, denom_i, denom_j;
4948+
4949+
for (k = 0; k < result_dim; k++) {
4950+
i = args.set_indexes[2 * k];
4951+
j = args.set_indexes[2 * k + 1];
4952+
4953+
n = (double) args.sample_set_sizes[i];
4954+
state_row = GET_2D_ROW(state, 3, i);
4955+
pAB = state_row[0] / n;
4956+
pAb = state_row[1] / n;
4957+
paB = state_row[2] / n;
4958+
pA = pAB + pAb;
4959+
pB = pAB + paB;
4960+
D_i = pAB - (pA * pB);
4961+
denom_i = sqrt(pA * (1 - pA) * pB * (1 - pB));
4962+
4963+
n = (double) args.sample_set_sizes[j];
4964+
state_row = GET_2D_ROW(state, 3, j);
4965+
pAB = state_row[0] / n;
4966+
pAb = state_row[1] / n;
4967+
paB = state_row[2] / n;
4968+
pA = pAB + pAb;
4969+
pB = pAB + paB;
4970+
D_j = pAB - (pA * pB);
4971+
denom_j = sqrt(pA * (1 - pA) * pB * (1 - pB));
4972+
4973+
result[k] = (D_i * D_j) / (denom_i * denom_j);
4974+
}
4975+
return 0;
4976+
}
4977+
4978+
int
4979+
tsk_treeseq_r2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
4980+
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
4981+
tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows,
4982+
const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols,
4983+
const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options,
4984+
double *result)
4985+
{
4986+
int ret = 0;
4987+
ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples);
4988+
if (ret != 0) {
4989+
goto out;
4990+
}
4991+
ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes,
4992+
sample_sets, num_index_tuples, index_tuples, r2_ij_summary_func,
4993+
norm_hap_weighted_ij, num_rows, row_sites, row_positions, num_cols, col_sites,
4994+
col_positions, options, result);
4995+
out:
4996+
return ret;
4997+
}
4998+
47844999
/***********************************
47855000
* Three way stats
47865001
***********************************/

0 commit comments

Comments
 (0)