Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions _tsinfermodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,19 @@ static PyObject *
AncestorBuilder_add_site(AncestorBuilder *self, PyObject *args, PyObject *kwds)
{
int err;
static char *kwlist[] = {"time", "genotypes", NULL};
static char *kwlist[] = {"time", "genotypes", "terminal", NULL};
PyObject *ret = NULL;
double time;
PyObject *genotypes = NULL;
PyArrayObject *genotypes_array = NULL;
npy_intp *shape;
int terminal = 0;

if (AncestorBuilder_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(args, kwds, "dO", kwlist,
&time, &genotypes)) {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "dO|p", kwlist,
&time, &genotypes, &terminal)) {
goto out;
}
genotypes_array = (PyArrayObject *) PyArray_FROM_OTF(genotypes, NPY_INT8,
Expand All @@ -166,7 +167,7 @@ AncestorBuilder_add_site(AncestorBuilder *self, PyObject *args, PyObject *kwds)
}
Py_BEGIN_ALLOW_THREADS
err = ancestor_builder_add_site(self->builder, time,
(allele_t *) PyArray_DATA(genotypes_array));
(allele_t *) PyArray_DATA(genotypes_array), terminal);
Py_END_ALLOW_THREADS
if (err != 0) {
handle_library_error(err);
Expand Down
30 changes: 20 additions & 10 deletions lib/ancestor_builder.c
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,9 @@ ancestor_builder_compute_ancestral_states(const ancestor_builder_t *self, int di
/* (int) min_sample_set_size); */
for (l = focal_site + direction; l >= 0 && l < (int64_t) num_sites; l += direction) {
/* printf("\tl = %d\n", (int) l); */
if (sites[l].terminal) {
break;
}
ancestor[l] = 0;
last_site = (tsk_id_t) l;

Expand Down Expand Up @@ -653,7 +656,8 @@ ancestor_builder_allocate_genotypes(ancestor_builder_t *self)
}

int WARN_UNUSED
ancestor_builder_add_site(ancestor_builder_t *self, double time, allele_t *genotypes)
ancestor_builder_add_site(
ancestor_builder_t *self, double time, allele_t *genotypes, bool terminal)
{
int ret = 0;
site_t *site;
Expand All @@ -665,21 +669,30 @@ ancestor_builder_add_site(ancestor_builder_t *self, double time, allele_t *genot
avl_tree_t *pattern_map;
tsk_id_t site_id = (tsk_id_t) self->num_sites;
size_t derived_count, j;
time_map_t *time_map = ancestor_builder_get_time_map(self, time);
time_map_t *time_map = NULL;

if (self->num_sites == self->max_sites) {
ret = TSI_ERR_TOO_MANY_SITES;
goto out;
}
derived_count = 0;
for (j = 0; j < (size_t) self->num_samples; j++) {
if (genotypes[j] == 1) {
derived_count++;
}
}

if (time_map == NULL) {
ret = TSI_ERR_NO_MEMORY;
site = &self->sites[site_id];
site->time = time;
site->derived_count = derived_count;
site->terminal = terminal;
if (terminal) {
site->encoded_genotypes = NULL;
self->num_sites++;
goto out;
}
if (self->num_sites == self->max_sites) {
ret = TSI_ERR_TOO_MANY_SITES;
time_map = ancestor_builder_get_time_map(self, time);
if (time_map == NULL) {
ret = TSI_ERR_NO_MEMORY;
goto out;
}
ret = ancestor_builder_encode_genotypes(self, genotypes, encoded_genotypes);
Expand All @@ -688,9 +701,6 @@ ancestor_builder_add_site(ancestor_builder_t *self, double time, allele_t *genot
}
self->num_sites++;
pattern_map = &time_map->pattern_map;
site = &self->sites[site_id];
site->time = time;
site->derived_count = derived_count;

search.encoded_genotypes = encoded_genotypes;
search.encoded_genotypes_size = self->encoded_genotypes_size;
Expand Down
10 changes: 5 additions & 5 deletions lib/tests/tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ run_random_data(size_t num_samples, size_t num_sites, int seed,
genotypes[k] = samples[k][j];
time += genotypes[k];
}
ret = ancestor_builder_add_site(&ancestor_builder, time, genotypes);
ret = ancestor_builder_add_site(&ancestor_builder, time, genotypes, false);
CU_ASSERT_EQUAL_FATAL(ret, 0);
}
/* ancestor_builder_print_state(&ancestor_builder, stdout); */
Expand Down Expand Up @@ -478,15 +478,15 @@ test_ancestor_builder_errors(void)
ret = ancestor_builder_alloc(&ancestor_builder, 2, 0, -1, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
CU_ASSERT_EQUAL_FATAL(ancestor_builder.num_sites, 0);
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_ones);
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_ones, false);
CU_ASSERT_EQUAL_FATAL(ret, TSI_ERR_TOO_MANY_SITES);
ancestor_builder_free(&ancestor_builder);

ret = ancestor_builder_alloc(&ancestor_builder, 4, 2, -1, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_zeros);
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_zeros, false);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_ones);
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_ones, false);
CU_ASSERT_EQUAL_FATAL(ret, 0);
CU_ASSERT_EQUAL_FATAL(ancestor_builder.num_sites, 2);
ret = ancestor_builder_finalise(&ancestor_builder);
Expand All @@ -509,7 +509,7 @@ test_ancestor_builder_one_site(void)

ret = ancestor_builder_alloc(&ancestor_builder, 4, 1, -1, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes);
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes, false);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = ancestor_builder_finalise(&ancestor_builder);
CU_ASSERT_EQUAL_FATAL(ret, 0);
Expand Down
3 changes: 2 additions & 1 deletion lib/tsinfer.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ typedef struct {
double time;
uint8_t *encoded_genotypes;
tsk_size_t derived_count;
bool terminal;
} site_t;

typedef struct {
Expand Down Expand Up @@ -251,7 +252,7 @@ int ancestor_builder_alloc(ancestor_builder_t *self, size_t num_samples,
int ancestor_builder_free(ancestor_builder_t *self);
int ancestor_builder_print_state(ancestor_builder_t *self, FILE *out);
int ancestor_builder_add_site(
ancestor_builder_t *self, double time, allele_t *genotypes);
ancestor_builder_t *self, double time, allele_t *genotypes, bool terminal);
int ancestor_builder_finalise(ancestor_builder_t *self);
int ancestor_builder_make_ancestor(const ancestor_builder_t *self,
size_t num_focal_sites, const tsk_id_t *focal_sites, tsk_id_t *start, tsk_id_t *end,
Expand Down
25 changes: 15 additions & 10 deletions tsinfer/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class Site:
id = attr.ib()
time = attr.ib()
derived_count = attr.ib()
terminal = attr.ib()


class AncestorBuilder:
Expand Down Expand Up @@ -137,21 +138,23 @@ def store_site_genotypes(self, site_id, genotypes):
stop = start + self.encoded_genotypes_size
self.genotype_store[start:stop] = genotypes

def add_site(self, time, genotypes):
def add_site(self, time, genotypes, terminal):
"""
Adds a new site at the specified ID to the builder.
"""
site_id = len(self.sites)
derived_count = np.sum(genotypes == 1)
self.store_site_genotypes(site_id, genotypes)
self.sites.append(Site(site_id, time, derived_count))
sites_at_fixed_timepoint = self.time_map[time]
# Sites with an identical variant distribution (i.e. with the same
# genotypes.tobytes() value) and at the same time, are put into the same ancestor
# to which we allocate a unique ID (just use the genotypes value)
ancestor_uid = tuple(genotypes)
# Add each site to the list for this ancestor_uid at this timepoint
sites_at_fixed_timepoint[ancestor_uid].append(site_id)
self.sites.append(Site(site_id, time, derived_count, terminal))
if not terminal:
self.store_site_genotypes(site_id, genotypes)
sites_at_fixed_timepoint = self.time_map[time]
# Sites with an identical variant distribution (i.e. with the same
# genotypes.tobytes() value) and at the same time, are put into the
# same ancestor to which we allocate a unique ID (just use the genotypes
# value)
ancestor_uid = tuple(genotypes)
# Add each site to the list for this ancestor_uid at this timepoint
sites_at_fixed_timepoint[ancestor_uid].append(site_id)

def print_state(self):
print("Ancestor builder")
Expand Down Expand Up @@ -221,6 +224,8 @@ def compute_ancestral_states(self, a, focal_site, sites):
disagree = np.zeros(self.num_samples, dtype=bool)

for site_index in sites:
if self.sites[site_index].terminal:
break
a[site_index] = 0
last_site = site_index
g_l = self.get_site_genotypes(site_index)
Expand Down
36 changes: 29 additions & 7 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3093,7 +3093,14 @@ class AncestorData(DataContainer):
FORMAT_NAME = "tsinfer-ancestor-data"
FORMAT_VERSION = (3, 0)

def __init__(self, position, sequence_length, chunk_size_sites=None, **kwargs):
def __init__(
self,
inference_position,
terminal_position,
sequence_length,
chunk_size_sites=None,
**kwargs,
):
super().__init__(**kwargs)
self._last_time = 0
self.inference_sites_set = False
Expand All @@ -3111,15 +3118,22 @@ def __init__(self, position, sequence_length, chunk_size_sites=None, **kwargs):
self.create_dataset("sample_end", dtype=np.int32)
self.create_dataset("sample_time", dtype=np.float64)
self.create_dataset("sample_focal_sites", dtype="array:i4")

variant_position = np.concatenate([inference_position, terminal_position])
self.create_dataset(
"variant_position",
data=position,
shape=position.shape,
data=variant_position,
shape=variant_position.shape,
chunks=self._chunk_size_sites,
dtype=np.float64,
dimensions=["variants"],
)
self.create_dataset(
"terminal_position",
data=terminal_position,
shape=terminal_position.shape,
dtype=np.float64,
dimensions=["terminal_sites"],
)

# We have to include a ploidy dimension for sgkit compatibility
a = self.create_dataset(
Expand Down Expand Up @@ -3277,10 +3291,17 @@ def num_sites(self):
@property
def sites_position(self):
"""
The positions of the inference sites used to generate the ancestors
The positions of the inference and terminal sites used to generate the ancestors
"""
return self.data["variant_position"]

@property
def terminal_position(self):
"""
The positions of the terminal sites used to generate the ancestors
"""
return self.data["terminal_position"]

@property
def ancestors_start(self):
return self.data["sample_start"]
Expand Down Expand Up @@ -3314,10 +3335,10 @@ def ancestors_length(self):
"""
# Ancestor start and end are half-closed. The last site is assumed
# to cover the region up to sequence length.
pos = np.hstack([self.sites_position[:], [self.sequence_length]])

start = self.ancestors_start[:]
end = self.ancestors_end[:]
return pos[end] - pos[start]
return self.sites_position[end] - self.sites_position[start]

def insert_proxy_samples(
self,
Expand Down Expand Up @@ -3683,6 +3704,7 @@ def add_ancestor(self, start, end, time, focal_sites, haplotype):
if start < 0:
raise ValueError("Start must be >= 0")
if end > self.num_sites:
print(f"[INFO] {end}, {self.num_sites}")
raise ValueError("end must be <= num_sites")
if start >= end:
raise ValueError("start must be < end")
Expand Down
Loading
Loading