diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index a92da4f252..37eb492119 100644 --- a/python/tests/test_tree_stats.py +++ b/python/tests/test_tree_stats.py @@ -2311,6 +2311,29 @@ def test_shapes(self, proportion): class TestBranchGeneticRelatedness(TestGeneticRelatedness, TopologyExamplesMixin): mode = "branch" + def test_single_sample_set_self_comparison(self, ts_12_highrecomb_fixture): + # Test for issue #3055 - self-comparisons with single sample set + ts = ts_12_highrecomb_fixture + # Single sample set with self-comparison + result = ts.genetic_relatedness([[0]], indexes=[(0, 0)], mode="branch") + assert result.shape == (1,) + # Should work for multiple samples in single set too + result = ts.genetic_relatedness([[0, 1, 2]], indexes=[(0, 0)], mode="branch") + assert result.shape == (1,) + + def test_single_sample_set_invalid_indexes(self, ts_12_highrecomb_fixture): + # Test that invalid indexes raise ValueError with single sample set + ts = ts_12_highrecomb_fixture + # Index out of bounds (only have 1 sample set, but trying to access index 1) + with pytest.raises(ValueError, match="Index out of bounds"): + ts.genetic_relatedness([[0]], indexes=[(0, 1)], mode="branch") + # Negative index + with pytest.raises(ValueError, match="Index out of bounds"): + ts.genetic_relatedness([[0]], indexes=[(-1, 0)], mode="branch") + # Both indexes out of bounds + with pytest.raises(ValueError, match="Index out of bounds"): + ts.genetic_relatedness([[0, 1]], indexes=[(2, 2)], mode="branch") + @pytest.mark.parametrize("polarised", [True, False]) def test_simple_tree_noncentred(self, polarised): # 2.00┊ 4 ┊ @@ -2365,10 +2388,61 @@ def test_simple_tree_noncentred(self, polarised): class TestNodeGeneticRelatedness(TestGeneticRelatedness, TopologyExamplesMixin): mode = "node" + def test_single_sample_set_self_comparison(self, ts_12_highrecomb_fixture): + # Test for issue #3055 - self-comparisons with single sample set + ts = ts_12_highrecomb_fixture + # Single sample set with self-comparison + result = ts.genetic_relatedness([[0]], indexes=[(0, 0)], mode="node") + assert result.shape == (ts.num_nodes, 1) + # Should work for multiple samples in single set too + result = ts.genetic_relatedness([[0, 1, 2]], indexes=[(0, 0)], mode="node") + assert result.shape == (ts.num_nodes, 1) + + def test_single_sample_set_invalid_indexes(self, ts_12_highrecomb_fixture): + # Test that invalid indexes raise ValueError with single sample set + ts = ts_12_highrecomb_fixture + # Index out of bounds (only have 1 sample set, but trying to access index 1) + with pytest.raises(ValueError, match="Index out of bounds"): + ts.genetic_relatedness([[0]], indexes=[(0, 1)], mode="node") + # Negative index + with pytest.raises(ValueError, match="Index out of bounds"): + ts.genetic_relatedness([[0]], indexes=[(-1, 0)], mode="node") + # Both indexes out of bounds + with pytest.raises(ValueError, match="Index out of bounds"): + ts.genetic_relatedness([[0, 1]], indexes=[(2, 2)], mode="node") + class TestSiteGeneticRelatedness(TestGeneticRelatedness, MutatedTopologyExamplesMixin): mode = "site" + def test_single_sample_set_self_comparison(self, ts_12_highrecomb_fixture): + # Test for issue #3055 - self-comparisons with single sample set + ts = ts_12_highrecomb_fixture + # Single sample set with self-comparison + result = ts.genetic_relatedness([[0]], indexes=[(0, 0)], mode="site") + assert result.shape == (1,) + # Should work for multiple samples in single set too + result = ts.genetic_relatedness([[0, 1, 2]], indexes=[(0, 0)], mode="site") + assert result.shape == (1,) + # Test with multiple self-comparisons + result = ts.genetic_relatedness( + [[0], [1]], indexes=[(0, 0), (1, 1)], mode="site" + ) + assert result.shape == (2,) + + def test_single_sample_set_invalid_indexes(self, ts_12_highrecomb_fixture): + # Test that invalid indexes raise ValueError with single sample set + ts = ts_12_highrecomb_fixture + # Index out of bounds (only have 1 sample set, but trying to access index 1) + with pytest.raises(ValueError, match="Index out of bounds"): + ts.genetic_relatedness([[0]], indexes=[(0, 1)], mode="site") + # Negative index + with pytest.raises(ValueError, match="Index out of bounds"): + ts.genetic_relatedness([[0]], indexes=[(-1, 0)], mode="site") + # Both indexes out of bounds + with pytest.raises(ValueError, match="Index out of bounds"): + ts.genetic_relatedness([[0, 1]], indexes=[(2, 2)], mode="site") + def test_match_K_c0(self): # This test checks that ts.genetic_relatedness() matches K_c0 # from Speed & Balding (2014) https://www.nature.com/articles/nrg3821 diff --git a/python/tskit/trees.py b/python/tskit/trees.py index e740181f81..dc942f10f7 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8106,6 +8106,7 @@ def __k_way_sample_set_stat( span_normalise=True, polarised=False, centre=True, + allow_self_comparisons=False, ): sample_set_sizes = np.array( [len(sample_set) for sample_set in sample_sets], dtype=np.uint32 @@ -8132,6 +8133,24 @@ def __k_way_sample_set_stat( "Indexes must be convertable to a 2D numpy array with {} " "columns".format(k) ) + # For genetic_relatedness, we allow self-comparisons with a single sample set + if allow_self_comparisons and len(sample_sets) < k: + # Check that all indexes are valid + if np.any(indexes >= len(sample_sets)) or np.any(indexes < 0): + raise ValueError("Index out of bounds") + # Find which sample sets we actually need + unique_indexes = np.unique(indexes) + if np.max(unique_indexes) < len(sample_sets): + # we need to pad with dummy sets to satisfy the C side + # requirement of having at least k sets + sample_sets = list(sample_sets) + sample_set_sizes = list(sample_set_sizes) + while len(sample_sets) < k: + # Add a dummy sample set that won't be used + sample_sets.append(np.array([sample_sets[0][0]], dtype=np.int32)) + sample_set_sizes.append(1) + sample_set_sizes = np.array(sample_set_sizes, dtype=np.uint32) + flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) stat = self.__run_windowed_stat( windows, ll_method, @@ -8702,6 +8721,7 @@ def genetic_relatedness( span_normalise=span_normalise, polarised=polarised, centre=centre, + allow_self_comparisons=True, ) if proportion: # TODO this should be done in C also