Skip to content

Commit e122455

Browse files
committed
API: rank with nullable dtypes preserve NA
1 parent 3ea783e commit e122455

File tree

3 files changed

+59
-3
lines changed

3 files changed

+59
-3
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ Other enhancements
9595
- Support passing a :class:`Iterable[Hashable]` input to :meth:`DataFrame.drop_duplicates` (:issue:`59237`)
9696
- Support reading Stata 102-format (Stata 1) dta files (:issue:`58978`)
9797
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)
98+
- :meth:`Series.rank` and :meth:`DataFrame.rank` with numpy-nullable dtypes preserve ``NA`` values and return ``UInt64`` dtype where appropriate instead of casting ``NA`` to ``NaN`` with ``float64`` dtype (:issue:`??`)
9899

99100
.. ---------------------------------------------------------------------------
100101
.. _whatsnew_300.notable_bug_fixes:

pandas/core/arrays/masked.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import numpy as np
1313

1414
from pandas._libs import (
15+
algos as libalgos,
1516
lib,
1617
missing as libmissing,
1718
)
@@ -992,6 +993,49 @@ def copy(self) -> Self:
992993
mask = self._mask.copy()
993994
return self._simple_new(data, mask)
994995

996+
def _rank(
997+
self,
998+
*,
999+
axis: AxisInt = 0,
1000+
method: str = "average",
1001+
na_option: str = "keep",
1002+
ascending: bool = True,
1003+
pct: bool = False,
1004+
):
1005+
# Avoid going through copy-making ensure_data in algorithms.rank
1006+
if axis != 0 or self.ndim != 1:
1007+
raise NotImplementedError
1008+
1009+
from pandas.core.arrays import FloatingArray
1010+
1011+
data = self._data
1012+
if data.dtype.kind == "b":
1013+
data = data.view("uint8")
1014+
1015+
result = libalgos.rank_1d(
1016+
data,
1017+
is_datetimelike=False,
1018+
ties_method=method,
1019+
ascending=ascending,
1020+
na_option=na_option,
1021+
pct=pct,
1022+
mask=self.isna(),
1023+
)
1024+
if na_option in ["top", "bottom"]:
1025+
mask = np.zeros(self.shape, dtype=bool)
1026+
else:
1027+
mask = self._mask.copy()
1028+
1029+
if method != "average" and not pct:
1030+
if na_option not in ["top", "bottom"]:
1031+
result[self._mask] = 0 # avoid warning on casting
1032+
result = result.astype("uint64", copy=False)
1033+
from pandas.core.arrays import IntegerArray
1034+
1035+
return IntegerArray(result, mask=mask)
1036+
1037+
return FloatingArray(result, mask=mask)
1038+
9951039
@doc(ExtensionArray.duplicated)
9961040
def duplicated(
9971041
self, keep: Literal["first", "last", False] = "first"

pandas/tests/series/methods/test_rank.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ def expected_dtype(dtype, method, pct=False):
6868
exp_dtype = "double[pyarrow]"
6969
else:
7070
exp_dtype = "uint64[pyarrow]"
71+
elif dtype in ["Float64", "Int64"]:
72+
if method == "average" or pct:
73+
exp_dtype = "Float64"
74+
else:
75+
exp_dtype = "UInt64"
7176

7277
return exp_dtype
7378

@@ -257,7 +262,7 @@ def test_rank_nullable_integer(self):
257262
exp = Series([None, 2, None, 3, 3, 2, 3, 1], dtype="Int64")
258263
result = exp.rank(na_option="keep")
259264

260-
expected = Series([np.nan, 2.5, np.nan, 5.0, 5.0, 2.5, 5.0, 1.0])
265+
expected = Series([None, 2.5, None, 5.0, 5.0, 2.5, 5.0, 1.0], dtype="Float64")
261266

262267
tm.assert_series_equal(result, expected)
263268

@@ -302,6 +307,11 @@ def test_rank_tie_methods_on_infs_nans(
302307
exp_dtype = "float64[pyarrow]"
303308
else:
304309
exp_dtype = "uint64[pyarrow]"
310+
elif dtype == "Float64":
311+
if rank_method == "average":
312+
exp_dtype = "Float64"
313+
else:
314+
exp_dtype = "UInt64"
305315
else:
306316
exp_dtype = "float64"
307317

@@ -327,7 +337,8 @@ def test_rank_tie_methods_on_infs_nans(
327337
result = iseries.rank(
328338
method=rank_method, na_option=na_option, ascending=ascending
329339
)
330-
tm.assert_series_equal(result, Series(expected, dtype=exp_dtype))
340+
exp_ser = Series(expected, dtype=exp_dtype)
341+
tm.assert_series_equal(result, exp_ser)
331342

332343
def test_rank_desc_mix_nans_infs(self):
333344
# GH 19538
@@ -439,7 +450,7 @@ def test_rank_ea_small_values(self):
439450
dtype="Float64",
440451
)
441452
result = ser.rank(method="min")
442-
expected = Series([4, 1, 3, np.nan, 2])
453+
expected = Series([4, 1, 3, NA, 2], dtype="UInt64")
443454
tm.assert_series_equal(result, expected)
444455

445456

0 commit comments

Comments
 (0)