Skip to content

C and Python API for two-way two-locus stats #3243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
438 changes: 329 additions & 109 deletions c/tests/test_stats.c

Large diffs are not rendered by default.

263 changes: 239 additions & 24 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -2225,15 +2225,15 @@ get_allele_samples(const tsk_site_t *site, const tsk_bit_array_t *state,
}

static int
norm_hap_weighted(tsk_size_t state_dim, const double *hap_weights,
norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights,
tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params)
{
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
const double *weight_row;
double n;
tsk_size_t k;

for (k = 0; k < state_dim; k++) {
for (k = 0; k < result_dim; k++) {
weight_row = GET_2D_ROW(hap_weights, 3, k);
n = (double) args.sample_set_sizes[k];
// TODO: what to do when n = 0
Expand All @@ -2243,12 +2243,38 @@ norm_hap_weighted(tsk_size_t state_dim, const double *hap_weights,
}

static int
norm_total_weighted(tsk_size_t state_dim, const double *TSK_UNUSED(hap_weights),
norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights,
tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params)
{
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
const double *weight_row;
double ni, nj, wAB_i, wAB_j;
tsk_id_t i, j;
tsk_size_t k;

for (k = 0; k < result_dim; k++) {
i = args.set_indexes[2 * k];
j = args.set_indexes[2 * k + 1];
ni = (double) args.sample_set_sizes[i];
nj = (double) args.sample_set_sizes[j];
weight_row = GET_2D_ROW(hap_weights, 3, i);
wAB_i = weight_row[0];
weight_row = GET_2D_ROW(hap_weights, 3, j);
wAB_j = weight_row[0];

result[k] = (wAB_i + wAB_j) / (ni + nj);
}

return 0;
}

static int
norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights),
tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params))
{
tsk_size_t k;

for (k = 0; k < state_dim; k++) {
for (k = 0; k < result_dim; k++) {
result[k] = 1 / (double) (n_a * n_b);
}
return 0;
Expand All @@ -2268,9 +2294,6 @@ get_all_samples_bits(tsk_bit_array_t *all_samples, tsk_size_t n)
}
}

typedef int norm_func_t(tsk_size_t state_dim, const double *hap_weights, tsk_size_t n_a,
tsk_size_t n_b, double *result, void *params);

static int
compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles,
Expand All @@ -2290,14 +2313,15 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
// a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
// a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
tsk_size_t k, mut_a, mut_b;
tsk_size_t row_len = num_b_alleles * state_dim;
tsk_size_t result_row_len = num_b_alleles * result_dim;
tsk_size_t w_A = 0, w_B = 0, w_AB = 0;
uint8_t polarised_val = polarised ? 1 : 0;
double *hap_weight_row;
double *result_tmp_row;
double *weights = tsk_malloc(3 * state_dim * sizeof(*weights));
double *norm = tsk_malloc(state_dim * sizeof(*norm));
double *result_tmp = tsk_malloc(row_len * num_a_alleles * sizeof(*result_tmp));
double *norm = tsk_malloc(result_dim * sizeof(*norm));
double *result_tmp
= tsk_malloc(result_row_len * num_a_alleles * sizeof(*result_tmp));

tsk_memset(&ss_A_samples, 0, sizeof(ss_A_samples));
tsk_memset(&ss_B_samples, 0, sizeof(ss_B_samples));
Expand Down Expand Up @@ -2327,7 +2351,7 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
}

for (mut_a = polarised_val; mut_a < num_a_alleles; mut_a++) {
result_tmp_row = GET_2D_ROW(result_tmp, row_len, mut_a);
result_tmp_row = GET_2D_ROW(result_tmp, result_row_len, mut_a);
for (mut_b = polarised_val; mut_b < num_b_alleles; mut_b++) {
tsk_bit_array_get_row(site_a_state, mut_a, &A_samples);
tsk_bit_array_get_row(site_b_state, mut_b, &B_samples);
Expand All @@ -2352,15 +2376,15 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
if (ret != 0) {
goto out;
}
ret = norm_f(state_dim, weights, num_a_alleles - polarised_val,
ret = norm_f(result_dim, weights, num_a_alleles - polarised_val,
num_b_alleles - polarised_val, norm, f_params);
if (ret != 0) {
goto out;
}
for (k = 0; k < state_dim; k++) {
for (k = 0; k < result_dim; k++) {
result[k] += result_tmp_row[k] * norm[k];
}
result_tmp_row += state_dim; // Advance to the next column
result_tmp_row += result_dim; // Advance to the next column
}
}

Expand Down Expand Up @@ -2538,8 +2562,8 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
get_site_row_col_indices(
n_rows, row_sites, n_cols, col_sites, sites, &n_sites, row_idx, col_idx);

// We rely on n_sites to allocate these arrays, they're initialized to NULL for safe
// deallocation if the previous allocation fails
// We rely on n_sites to allocate these arrays, which are initialized
// to NULL for safe deallocation if the previous allocation fails
num_alleles = tsk_malloc(n_sites * sizeof(*num_alleles));
site_offsets = tsk_malloc(n_sites * sizeof(*site_offsets));
if (num_alleles == NULL || site_offsets == NULL) {
Expand Down Expand Up @@ -3195,7 +3219,7 @@ tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_di
return ret;
}

static int
int
tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f,
Expand All @@ -3209,7 +3233,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
tsk_bit_array_t sample_sets_bits;
bool stat_site = !!(options & TSK_STAT_SITE);
bool stat_branch = !!(options & TSK_STAT_BRANCH);
// double default_windows[] = { 0, self->tables->sequence_length };
tsk_size_t state_dim = num_sample_sets;
sample_count_stat_params_t f_params = { .sample_sets = sample_sets,
.num_sample_sets = num_sample_sets,
Expand All @@ -3232,17 +3255,15 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
ret = tsk_trace_error(TSK_ERR_MULTIPLE_STAT_MODES);
goto out;
}
// TODO: impossible until we implement branch/windows
// if (result_dim < 1) {
// ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS);
// goto out;
// }
ret = tsk_treeseq_check_sample_sets(
self, num_sample_sets, sample_set_sizes, sample_sets);
if (ret != 0) {
goto out;
}
tsk_bug_assert(state_dim > 0);
if (result_dim < 1) {
ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS);
goto out;
}
ret = sample_sets_to_bit_array(
self, sample_set_sizes, sample_sets, num_sample_sets, &sample_sets_bits);
if (ret != 0) {
Expand Down Expand Up @@ -4781,6 +4802,200 @@ tsk_treeseq_f2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
return ret;
}

static int
D2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state,
tsk_size_t result_dim, double *result, void *params)
{
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
const double *state_row;
double n;
tsk_size_t k;
tsk_id_t i, j;
double p_A, p_B, p_AB, p_Ab, p_aB, D_i, D_j;

for (k = 0; k < result_dim; k++) {
i = args.set_indexes[2 * k];
j = args.set_indexes[2 * k + 1];

n = (double) args.sample_set_sizes[i];
state_row = GET_2D_ROW(state, 3, i);
p_AB = state_row[0] / n;
p_Ab = state_row[1] / n;
p_aB = state_row[2] / n;
p_A = p_AB + p_Ab;
p_B = p_AB + p_aB;
D_i = p_AB - (p_A * p_B);

n = (double) args.sample_set_sizes[j];
state_row = GET_2D_ROW(state, 3, j);
p_AB = state_row[0] / n;
p_Ab = state_row[1] / n;
p_aB = state_row[2] / n;
p_A = p_AB + p_Ab;
p_B = p_AB + p_aB;
D_j = p_AB - (p_A * p_B);

result[k] = D_i * D_j;
}

return 0;
}

int
tsk_treeseq_D2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows,
const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols,
const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options,
double *result)
{
int ret = 0;
ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples);
if (ret != 0) {
goto out;
}
ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes,
sample_sets, num_index_tuples, index_tuples, D2_ij_summary_func,
norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites,
col_positions, options, result);
out:
return ret;
}

static int
D2_ij_unbiased_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state,
tsk_size_t result_dim, double *result, void *params)
{
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
const double *state_row;
tsk_size_t k;
tsk_id_t i, j;
double n_i, n_j;
double w_AB_i, w_Ab_i, w_aB_i, w_ab_i;
double w_AB_j, w_Ab_j, w_aB_j, w_ab_j;

for (k = 0; k < result_dim; k++) {
i = args.set_indexes[2 * k];
j = args.set_indexes[2 * k + 1];
if (i == j) {
// We require disjoint sample sets because we test equality here
n_i = (double) args.sample_set_sizes[i];
state_row = GET_2D_ROW(state, 3, i);
w_AB_i = state_row[0];
w_Ab_i = state_row[1];
w_aB_i = state_row[2];
w_ab_i = n_i - (w_AB_i + w_Ab_i + w_aB_i);
result[k] = (w_AB_i * (w_AB_i - 1) * w_ab_i * (w_ab_i - 1)
+ w_Ab_i * (w_Ab_i - 1) * w_aB_i * (w_aB_i - 1)
- 2 * w_AB_i * w_Ab_i * w_aB_i * w_ab_i)
/ n_i / (n_i - 1) / (n_i - 2) / (n_i - 3);
}

else {
n_i = (double) args.sample_set_sizes[i];
state_row = GET_2D_ROW(state, 3, i);
w_AB_i = state_row[0];
w_Ab_i = state_row[1];
w_aB_i = state_row[2];
w_ab_i = n_i - (w_AB_i + w_Ab_i + w_aB_i);

n_j = (double) args.sample_set_sizes[j];
state_row = GET_2D_ROW(state, 3, j);
w_AB_j = state_row[0];
w_Ab_j = state_row[1];
w_aB_j = state_row[2];
w_ab_j = n_j - (w_AB_j + w_Ab_j + w_aB_j);

result[k] = (w_Ab_i * w_aB_i - w_AB_i * w_ab_i)
* (w_Ab_j * w_aB_j - w_AB_j * w_ab_j) / n_i / (n_i - 1) / n_j
/ (n_j - 1);
}
}

return 0;
}

int
tsk_treeseq_D2_ij_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows,
const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols,
const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options,
double *result)
{
int ret = 0;
ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples);
if (ret != 0) {
goto out;
}
ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes,
sample_sets, num_index_tuples, index_tuples, D2_ij_unbiased_summary_func,
norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites,
col_positions, options, result);
out:
return ret;
}

static int
r2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state,
tsk_size_t result_dim, double *result, void *params)
{
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
const double *state_row;
tsk_size_t k;
tsk_id_t i, j;
double n, pAB, pAb, paB, pA, pB, D_i, D_j, denom_i, denom_j;

for (k = 0; k < result_dim; k++) {
i = args.set_indexes[2 * k];
j = args.set_indexes[2 * k + 1];

n = (double) args.sample_set_sizes[i];
state_row = GET_2D_ROW(state, 3, i);
pAB = state_row[0] / n;
pAb = state_row[1] / n;
paB = state_row[2] / n;
pA = pAB + pAb;
pB = pAB + paB;
D_i = pAB - (pA * pB);
denom_i = sqrt(pA * (1 - pA) * pB * (1 - pB));

n = (double) args.sample_set_sizes[j];
state_row = GET_2D_ROW(state, 3, j);
pAB = state_row[0] / n;
pAb = state_row[1] / n;
paB = state_row[2] / n;
pA = pAB + pAb;
pB = pAB + paB;
D_j = pAB - (pA * pB);
denom_j = sqrt(pA * (1 - pA) * pB * (1 - pB));

result[k] = (D_i * D_j) / (denom_i * denom_j);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will spuriously produce a NAN if the allele is not variable in one of the two sample sets. From our discussion earlier this week, I think we can calculate result = (D_i * D_j) / denom, where denom = pA * (1 - pA) * pB * (1 - pB), and pA is (pA_i + pA_j) / (n_i + n_j) and likewise for pB.

Copy link
Contributor Author

@lkirk lkirk Jul 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we can combine the allele frequencies this way. Consider this example:

>>> Ai = 0.3; Bi = 0.5; Aj = 0.7; Bj = 0.2; ni = 37; nj = 39
>>> A = (Ai + Aj) / (ni + nj)
>>> B = (Bi + Bj) / (ni + nj)
>>> A * (1-A) * B * (1-B)
0.00011849496867350617
>>> (Ai*(1-Ai)*Bi*(1-Bi)) * (Aj*(1-Aj)*Bj*(1-Bj))
0.0017640000000000006

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I miswrote that. I meant pA is (nA_i + nA_j) / (n_i + n_j)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I added this change. I also added some test coverage for unequal sample set sizes, so that we can see the results of this change and how the stats produced vary -- it can be quite a bit in some cases.

}
return 0;
}

int
tsk_treeseq_r2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows,
const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols,
const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options,
double *result)
{
int ret = 0;
ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples);
if (ret != 0) {
goto out;
}
ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes,
sample_sets, num_index_tuples, index_tuples, r2_ij_summary_func,
norm_hap_weighted_ij, num_rows, row_sites, row_positions, num_cols, col_sites,
col_positions, options, result);
out:
return ret;
}

/***********************************
* Three way stats
***********************************/
Expand Down
Loading
Loading