Allow Series.str.match / Series.str.fullmatch to accept a compiled regex (GH#61952).

A compiled pattern is unwrapped to its source text so the existing string
handling (anchoring with "^"/"$", recompilation with ``flags``) still applies.
The flags compiled into the pattern are OR-ed into ``flags`` first, so that
e.g. ``re.compile("ab", re.IGNORECASE)`` is not silently matched
case-sensitively. ``re.UNICODE`` is masked out because it is the implicit
default for str patterns.

NOTE(review): the ``index`` lines are carried over from the original patch;
their post-image hashes do not reflect the flag-handling lines added here.

diff --git a/pandas/core/arrays/_arrow_string_mixins.py b/pandas/core/arrays/_arrow_string_mixins.py
index 07cbf489cfe1c..5dda2d914366c 100644
--- a/pandas/core/arrays/_arrow_string_mixins.py
+++ b/pandas/core/arrays/_arrow_string_mixins.py
@@ -304,23 +304,33 @@ def _str_contains(
 
     def _str_match(
         self,
-        pat: str,
+        pat: str | re.Pattern,
         case: bool = True,
         flags: int = 0,
         na: Scalar | lib.NoDefault = lib.no_default,
     ):
-        if not pat.startswith("^"):
+        if isinstance(pat, re.Pattern):
+            # GH#61952: forward the compiled flags; re.UNICODE is the implicit
+            # default for str patterns, so it is masked to keep ``flags`` at 0
+            flags |= pat.flags & ~re.UNICODE
+            pat = pat.pattern
+        if isinstance(pat, str) and not pat.startswith("^"):
             pat = f"^{pat}"
         return self._str_contains(pat, case, flags, na, regex=True)
 
     def _str_fullmatch(
         self,
-        pat,
+        pat: str | re.Pattern,
         case: bool = True,
         flags: int = 0,
         na: Scalar | lib.NoDefault = lib.no_default,
     ):
-        if not pat.endswith("$") or pat.endswith("\\$"):
+        if isinstance(pat, re.Pattern):
+            # GH#61952: forward the compiled flags; re.UNICODE is the implicit
+            # default for str patterns, so it is masked to keep ``flags`` at 0
+            flags |= pat.flags & ~re.UNICODE
+            pat = pat.pattern
+        if isinstance(pat, str) and (not pat.endswith("$") or pat.endswith("\\$")):
             pat = f"{pat}$"
         return self._str_match(pat, case, flags, na)
 
diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py
index c108808905dc7..21e6e2efbe778 100644
--- a/pandas/core/strings/accessor.py
+++ b/pandas/core/strings/accessor.py
@@ -1361,8 +1361,9 @@ def match(self, pat: str, case: bool = True, flags: int = 0, na=lib.no_default):
 
         Parameters
         ----------
-        pat : str
-            Character sequence.
+        pat : str or compiled regex
+            Character sequence or regular expression. Flags set on a compiled
+            regex are combined with ``flags``.
         case : bool, default True
             If True, case sensitive.
         flags : int, default 0 (no flags)
diff --git a/pandas/core/strings/object_array.py b/pandas/core/strings/object_array.py
index 0adb7b51cf2b7..c1d81fc3d7223 100644
--- a/pandas/core/strings/object_array.py
+++ b/pandas/core/strings/object_array.py
@@ -248,14 +248,18 @@ def rep(x, r):
 
     def _str_match(
         self,
-        pat: str,
+        pat: str | re.Pattern,
         case: bool = True,
         flags: int = 0,
         na: Scalar | lib.NoDefault = lib.no_default,
     ):
         if not case:
             flags |= re.IGNORECASE
-
+        if isinstance(pat, re.Pattern):
+            # GH#61952: recompile from the source text, keeping the compiled
+            # flags (re.UNICODE is implicit for str patterns, so it is masked)
+            flags |= pat.flags & ~re.UNICODE
+            pat = pat.pattern
         regex = re.compile(pat, flags=flags)
 
         f = lambda x: regex.match(x) is not None
@@ -270,7 +274,11 @@ def _str_fullmatch(
     ):
         if not case:
             flags |= re.IGNORECASE
-
+        if isinstance(pat, re.Pattern):
+            # GH#61952: recompile from the source text, keeping the compiled
+            # flags (re.UNICODE is implicit for str patterns, so it is masked)
+            flags |= pat.flags & ~re.UNICODE
+            pat = pat.pattern
         regex = re.compile(pat, flags=flags)
 
         f = lambda x: regex.fullmatch(x) is not None
diff --git a/pandas/tests/strings/test_find_replace.py b/pandas/tests/strings/test_find_replace.py
index 30e6ebf0eed13..567ef315366b1 100644
--- a/pandas/tests/strings/test_find_replace.py
+++ b/pandas/tests/strings/test_find_replace.py
@@ -818,6 +818,28 @@ def test_match_case_kwarg(any_string_dtype):
     tm.assert_series_equal(result, expected)
 
 
+def test_match_compiled_regex(any_string_dtype):
+    # GH#61952
+    values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
+    result = values.str.match(re.compile(r"ab"), case=False)
+    expected_dtype = (
+        np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
+    )
+    expected = Series([True, True, True, True], dtype=expected_dtype)
+    tm.assert_series_equal(result, expected)
+
+
+def test_match_compiled_regex_flags(any_string_dtype):
+    # GH#61952: flags compiled into the pattern must not be dropped
+    values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
+    result = values.str.match(re.compile(r"ab", flags=re.IGNORECASE))
+    expected_dtype = (
+        np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
+    )
+    expected = Series([True, True, True, True], dtype=expected_dtype)
+    tm.assert_series_equal(result, expected)
+
+
 # --------------------------------------------------------------------------------------
 # str.fullmatch
 # --------------------------------------------------------------------------------------
@@ -887,6 +909,28 @@ def test_fullmatch_case_kwarg(any_string_dtype):
     tm.assert_series_equal(result, expected)
 
 
+def test_fullmatch_compiled_regex(any_string_dtype):
+    # GH#61952
+    values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
+    result = values.str.fullmatch(re.compile(r"ab"), case=False)
+    expected_dtype = (
+        np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
+    )
+    expected = Series([True, True, False, False], dtype=expected_dtype)
+    tm.assert_series_equal(result, expected)
+
+
+def test_fullmatch_compiled_regex_flags(any_string_dtype):
+    # GH#61952: flags compiled into the pattern must not be dropped
+    values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
+    result = values.str.fullmatch(re.compile(r"ab", flags=re.IGNORECASE))
+    expected_dtype = (
+        np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
+    )
+    expected = Series([True, True, False, False], dtype=expected_dtype)
+    tm.assert_series_equal(result, expected)
+
+
 # --------------------------------------------------------------------------------------
 # str.findall
 # --------------------------------------------------------------------------------------