diff --git a/doc/source/user_guide/indexing.rst b/doc/source/user_guide/indexing.rst
index 605f9501c5b23..47ff92c163b01 100644
--- a/doc/source/user_guide/indexing.rst
+++ b/doc/source/user_guide/indexing.rst
@@ -700,7 +700,7 @@ to have different probabilities, you can pass the ``sample`` function sampling w
 
     s = pd.Series([0, 1, 2, 3, 4, 5])
     example_weights = [0, 0, 0.2, 0.2, 0.2, 0.4]
-    s.sample(n=3, weights=example_weights)
+    s.sample(n=2, weights=example_weights)
 
     # Weights will be re-normalized automatically
     example_weights2 = [0.5, 0, 0, 0, 0, 0]
@@ -714,7 +714,7 @@ as a string.
 
     df2 = pd.DataFrame({'col1': [9, 8, 7, 6],
                         'weight_column': [0.5, 0.4, 0.1, 0]})
-    df2.sample(n=3, weights='weight_column')
+    df2.sample(n=2, weights='weight_column')
 
 ``sample`` also allows users to sample columns instead of rows using the ``axis`` argument.
 
diff --git a/doc/source/whatsnew/v0.16.1.rst b/doc/source/whatsnew/v0.16.1.rst
index b376530358f53..c15f56ba61447 100644
--- a/doc/source/whatsnew/v0.16.1.rst
+++ b/doc/source/whatsnew/v0.16.1.rst
@@ -196,7 +196,7 @@ facilitate replication. (:issue:`2419`)
 
     # weights are accepted.
     example_weights = [0, 0, 0.2, 0.2, 0.2, 0.4]
-    example_series.sample(n=3, weights=example_weights)
+    example_series.sample(n=2, weights=example_weights)
 
     # weights will also be normalized if they do not sum to one,
     # and missing values will be treated as zeros.
@@ -210,7 +210,7 @@ when sampling from rows.
 .. ipython:: python
 
     df = pd.DataFrame({"col1": [9, 8, 7, 6], "weight_column": [0.5, 0.4, 0.1, 0]})
-    df.sample(n=3, weights="weight_column")
+    df.sample(n=2, weights="weight_column")
 
 .. _whatsnew_0161.enhancements.string:
 
diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst
index 4e0e497379fa2..d0f4a2b3fc0a1 100644
--- a/doc/source/whatsnew/v3.0.0.rst
+++ b/doc/source/whatsnew/v3.0.0.rst
@@ -910,6 +910,7 @@ Other
 - Bug in :meth:`DataFrame.query` where using duplicate column names led to a ``TypeError``. (:issue:`59950`)
 - Bug in :meth:`DataFrame.query` which raised an exception or produced incorrect results when expressions contained backtick-quoted column names containing the hash character ``#``, backticks, or characters that fall outside the ASCII range (U+0001..U+007F). (:issue:`59285`) (:issue:`49633`)
 - Bug in :meth:`DataFrame.query` which raised an exception when querying integer column names using backticks. (:issue:`60494`)
+- Bug in :meth:`DataFrame.sample` which returned biased results when ``replace=False`` and ``(n * max(weights) / sum(weights)) > 1``; a ``ValueError`` is now raised instead. (:issue:`61516`)
 - Bug in :meth:`DataFrame.shift` where passing a ``freq`` on a DataFrame with no columns did not shift the index correctly. (:issue:`60102`)
 - Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning a :class:`RangeIndex` columns (:issue:`57293`)
 - Bug in :meth:`DataFrame.sort_values` where sorting by a column explicitly named ``None`` raised a ``KeyError`` instead of sorting by the column as expected. (:issue:`61512`)
diff --git a/pandas/core/generic.py b/pandas/core/generic.py
index 7f1ccc482f70f..8708de68c0860 100644
--- a/pandas/core/generic.py
+++ b/pandas/core/generic.py
@@ -5814,6 +5814,8 @@ def sample(
             If weights do not sum to 1, they will be normalized to sum to 1.
             Missing values in the weights column will be treated as zero.
             Infinite values not allowed.
+            When ``replace=False``, ``(n * max(weights) / sum(weights)) > 1`` is not allowed
+            in order to avoid biased results. See the Notes below for more details.
         random_state : int, array-like, BitGenerator, np.random.RandomState, np.random.Generator, optional
             If int, array-like, or BitGenerator, seed for random number generator.
             If np.random.RandomState or np.random.Generator, use as given.
@@ -5850,6 +5852,11 @@ def sample(
         -----
         If `frac` > 1, `replacement` should be set to `True`.
 
+        When ``replace=False``, ``(n * max(weights) / sum(weights)) > 1`` is not allowed, since that
+        would bias the results: e.g. sampling 2 items without replacement with weights ``[100, 1, 1]``
+        would include each of the two last items in about half of the draws, despite a weight of only 1/102.
+        This is similar to specifying ``n=4`` without replacement on a Series with 3 elements.
+
         Examples
         --------
         >>> df = pd.DataFrame(
diff --git a/pandas/core/sample.py b/pandas/core/sample.py
index 4f12563e3c5e2..4f476540cf406 100644
--- a/pandas/core/sample.py
+++ b/pandas/core/sample.py
@@ -150,6 +150,14 @@ def sample(
     else:
         raise ValueError("Invalid weights: weights sum to zero")
 
+    assert weights is not None  # for mypy
+    if not replace and size * weights.max() > 1:
+        raise ValueError(
+            "Weighted sampling cannot be achieved with replace=False. Either "
+            "set replace=True or use smaller weights. See the docstring of "
+            "sample for details."
+        )
+
     return random_state.choice(obj_len, size=size, replace=replace, p=weights).astype(
         np.intp, copy=False
     )
diff --git a/pandas/tests/frame/methods/test_sample.py b/pandas/tests/frame/methods/test_sample.py
index a9d56cbfd2b46..9b6660778508e 100644
--- a/pandas/tests/frame/methods/test_sample.py
+++ b/pandas/tests/frame/methods/test_sample.py
@@ -113,9 +113,6 @@ def test_sample_invalid_weight_lengths(self, obj):
         with pytest.raises(ValueError, match=msg):
             obj.sample(n=3, weights=[0.5] * 11)
 
-        with pytest.raises(ValueError, match="Fewer non-zero entries in p than size"):
-            obj.sample(n=4, weights=Series([0, 0, 0.2]))
-
     def test_sample_negative_weights(self, obj):
         # Check won't accept negative weights
         bad_weights = [-0.1] * 10
@@ -137,6 +134,33 @@ def test_sample_inf_weights(self, obj):
         with pytest.raises(ValueError, match=msg):
             obj.sample(n=3, weights=weights_with_ninf)
 
+    def test_sample_unit_probabilities_raises(self, obj):
+        # GH#61516
+        high_variance_weights = [1] * 10
+        high_variance_weights[0] = 100
+        msg = (
+            "Weighted sampling cannot be achieved with replace=False. Either "
+            "set replace=True or use smaller weights. See the docstring of "
+            "sample for details."
+        )
+        with pytest.raises(ValueError, match=msg):
+            obj.sample(n=2, weights=high_variance_weights, replace=False)
+
+    def test_sample_unit_probabilities_edge_case_do_not_raise(self, obj):
+        # GH#61516
+        # edge case, n * max(weights) / sum(weights) == 1
+        edge_variance_weights = [1] * 10
+        edge_variance_weights[0] = 9
+        # should not raise
+        obj.sample(n=2, weights=edge_variance_weights, replace=False)
+
+    def test_sample_unit_normal_probabilities_do_not_raise(self, obj):
+        # GH#61516
+        low_variance_weights = [1] * 10
+        low_variance_weights[0] = 8
+        # should not raise
+        obj.sample(n=2, weights=low_variance_weights, replace=False)
+
     def test_sample_zero_weights(self, obj):
         # All zeros raises errors
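
Illustration (not part of the patch): the short sketch below re-implements the new guard from pandas/core/sample.py under the hypothetical name demo_weight_guard, to show which weight/size combinations are rejected and why. The commented-out DataFrame/Series ``sample`` call only raises on a build that includes this change; everything else runs on current pandas releases as well.

import numpy as np
import pandas as pd


def demo_weight_guard(n: int, weights, replace: bool) -> None:
    # Hypothetical helper (not pandas API): mirrors the check added in
    # pandas/core/sample.py, applied after weights have been normalized.
    w = np.asarray(weights, dtype=np.float64)
    w = w / w.sum()  # pandas normalizes weights that do not sum to 1
    if not replace and n * w.max() > 1:
        raise ValueError(
            "Weighted sampling cannot be achieved with replace=False. Either "
            "set replace=True or use smaller weights. See the docstring of "
            "sample for details."
        )


s = pd.Series([10, 20, 30])

# 2 * 100 / 102 > 1: without replacement, the two low-weight rows would be drawn
# far more often than their weights imply, so the request is rejected.
try:
    demo_weight_guard(2, [100, 1, 1], replace=False)
except ValueError as err:
    print(err)

# On a pandas build that includes this patch, the equivalent public call raises
# the same ValueError:
#     s.sample(n=2, weights=[100, 1, 1], replace=False)

# The boundary case n * max(weights) / sum(weights) == 1 is still accepted,
# as is sampling with replacement.
print(pd.Series(range(10)).sample(n=2, weights=[9] + [1] * 9, random_state=0))
print(s.sample(n=2, weights=[100, 1, 1], replace=True, random_state=0))

Because the check runs on normalized weights, scaling every weight by the same constant does not change whether a request is accepted; only the ratio n * max(weights) / sum(weights) matters.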