diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 0b9ff088e5..04ed2b1ceb 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -10775,6 +10775,85 @@ test_check_integrity_bad_mutation_parent_topology(void) tsk_table_collection_free(&tables); } +static void +test_table_collection_compute_mutation_parents_tolerates_invalid_input(void) +{ + int ret; + tsk_id_t ret_id; + tsk_table_collection_t tables; + tsk_id_t site; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1.0; + + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 1.0, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_node_table_add_row( + &tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&tables.edges, 0.0, 1.0, 0, 1, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + site = tsk_site_table_add_row(&tables.sites, 0.0, "A", 1, NULL, 0); + CU_ASSERT_FATAL(site >= 0); + ret_id = tsk_mutation_table_add_row( + &tables.mutations, site, 1, TSK_NULL, TSK_UNKNOWN_TIME, "C", 1, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + + ret = tsk_table_collection_build_index(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.mutations.parent[0] = 42; + + ret = tsk_table_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_FATAL(tables.mutations.parent[0] == TSK_NULL); + + tsk_table_collection_free(&tables); +} + +static void +test_table_collection_compute_mutation_parents_restores_on_error(void) +{ + int ret; + tsk_id_t ret_id; + tsk_table_collection_t tables; + tsk_id_t site; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1.0; + + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 1.0, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_node_table_add_row( + &tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&tables.edges, 0.0, 1.0, 0, 1, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + site = tsk_site_table_add_row(&tables.sites, 0.5, "A", 1, NULL, 0); + CU_ASSERT_FATAL(site >= 0); + + ret_id = tsk_mutation_table_add_row( + &tables.mutations, site, 1, TSK_NULL, TSK_UNKNOWN_TIME, "C", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row( + &tables.mutations, site, 0, TSK_NULL, TSK_UNKNOWN_TIME, "G", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + + ret = tsk_table_collection_build_index(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.mutations.parent[0] = 111; + tables.mutations.parent[1] = 222; + + ret = tsk_table_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MUTATION_PARENT_AFTER_CHILD); + CU_ASSERT_EQUAL(tables.mutations.parent[0], 111); + CU_ASSERT_EQUAL(tables.mutations.parent[1], 222); + + tsk_table_collection_free(&tables); +} + static void test_table_collection_subset_with_options(tsk_flags_t options) { @@ -11934,6 +12013,10 @@ main(int argc, char **argv) test_table_collection_check_integrity_bad_indexes }, { "test_check_integrity_bad_mutation_parent_topology", test_check_integrity_bad_mutation_parent_topology }, + { "test_table_collection_compute_mutation_parents_tolerates_invalid_input", + test_table_collection_compute_mutation_parents_tolerates_invalid_input }, + { "test_table_collection_compute_mutation_parents_restores_on_error", + test_table_collection_compute_mutation_parents_restores_on_error }, { "test_table_collection_subset", test_table_collection_subset }, { "test_table_collection_subset_unsorted", test_table_collection_subset_unsorted }, diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 7106300f3d..4a256df27f 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -12436,8 +12436,27 @@ tsk_table_collection_compute_mutation_parents( tsk_table_collection_t *self, tsk_flags_t options) { int ret = 0; + tsk_mutation_table_t *mutations = &self->mutations; + tsk_id_t *parent_backup = NULL; + bool restore_parents = false; if (!(options & TSK_NO_CHECK_INTEGRITY)) { + if (mutations->num_rows > 0) { + /* We need to wipe the parent column before computing, as otherwise invalid + * parents can cause integrity checks to fail. We take a copy to restore on + * error */ + parent_backup = tsk_malloc(mutations->num_rows * sizeof(*parent_backup)); + if (parent_backup == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + tsk_memcpy(parent_backup, mutations->parent, + mutations->num_rows * sizeof(*parent_backup)); + /* Set the parent pointers to TSK_NULL */ + tsk_memset(mutations->parent, 0xff, + mutations->num_rows * sizeof(*mutations->parent)); + restore_parents = true; + } /* Safe to cast here as we're not counting trees */ ret = (int) tsk_table_collection_check_integrity(self, TSK_CHECK_TREES); if (ret < 0) { @@ -12452,6 +12471,11 @@ tsk_table_collection_compute_mutation_parents( } out: + if (ret != 0 && restore_parents) { + tsk_memcpy(mutations->parent, parent_backup, + mutations->num_rows * sizeof(*parent_backup)); + } + tsk_safe_free(parent_backup); return ret; } diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 6971780b47..dc9dfb4c97 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -3540,6 +3540,34 @@ def test_table_references(self): assert str(populations) == before_populations assert str(provenances) == before_provenances + def test_compute_mutation_parents_ignores_existing_values(self): + tables = tskit.TableCollection(sequence_length=1.0) + parent = tables.nodes.add_row(time=1.0) + child = tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE) + tables.edges.add_row(left=0.0, right=1.0, parent=parent, child=child) + site = tables.sites.add_row(position=0.0, ancestral_state="A") + tables.mutations.add_row(site=site, node=child, derived_state="C") + tables.build_index() + tables.mutations.parent[:] = 42 + tables.compute_mutation_parents() + assert tables.mutations.parent[0] == tskit.NULL + + def test_compute_mutation_parents_restores_on_index_error(self): + tables = tskit.TableCollection(sequence_length=1.0) + parent = tables.nodes.add_row(time=1.0) + child = tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE) + tables.edges.add_row(left=0.0, right=1.0, parent=parent, child=child) + site = tables.sites.add_row(position=0.0, ancestral_state="A") + tables.mutations.add_row(site=site, node=child, derived_state="C") + + mutation_columns = tables.mutations.asdict() + mutation_columns["parent"] = np.array([123], dtype=np.int32) + tables.mutations.set_columns(**mutation_columns) + + with pytest.raises(tskit.LibraryError, match="TSK_ERR_TABLES_NOT_INDEXED"): + tables.compute_mutation_parents() + assert tables.mutations.parent[0] == 123 + def test_str(self): ts = msprime.simulate(10, random_seed=1) tables = ts.tables