Skip to content

Commit 62a8c21

Browse files
REGR: fix string contains/match methods with compiled regex with flags
1 parent 1feacde commit 62a8c21

File tree

4 files changed

+128
-23
lines changed

4 files changed

+128
-23
lines changed

pandas/core/arrays/_arrow_string_mixins.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,7 @@ def _str_match(
316316
flags: int = 0,
317317
na: Scalar | lib.NoDefault = lib.no_default,
318318
):
319-
if isinstance(pat, re.Pattern):
320-
# GH#61952
321-
pat = pat.pattern
322-
if isinstance(pat, str) and not pat.startswith("^"):
319+
if not pat.startswith("^"):
323320
pat = f"^{pat}"
324321
return self._str_contains(pat, case, flags, na, regex=True)
325322

@@ -330,10 +327,7 @@ def _str_fullmatch(
330327
flags: int = 0,
331328
na: Scalar | lib.NoDefault = lib.no_default,
332329
):
333-
if isinstance(pat, re.Pattern):
334-
# GH#61952
335-
pat = pat.pattern
336-
if isinstance(pat, str) and (not pat.endswith("$") or pat.endswith("\\$")):
330+
if not pat.endswith("$") or pat.endswith("\\$"):
337331
pat = f"{pat}$"
338332
return self._str_match(pat, case, flags, na)
339333

pandas/core/arrays/string_arrow.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
ArrayLike,
5656
Dtype,
5757
NpDtype,
58+
Scalar,
5859
npt,
5960
)
6061

@@ -333,8 +334,6 @@ def astype(self, dtype, copy: bool = True):
333334
_str_startswith = ArrowStringArrayMixin._str_startswith
334335
_str_endswith = ArrowStringArrayMixin._str_endswith
335336
_str_pad = ArrowStringArrayMixin._str_pad
336-
_str_match = ArrowStringArrayMixin._str_match
337-
_str_fullmatch = ArrowStringArrayMixin._str_fullmatch
338337
_str_lower = ArrowStringArrayMixin._str_lower
339338
_str_upper = ArrowStringArrayMixin._str_upper
340339
_str_strip = ArrowStringArrayMixin._str_strip
@@ -349,6 +348,19 @@ def astype(self, dtype, copy: bool = True):
349348
_str_len = ArrowStringArrayMixin._str_len
350349
_str_slice = ArrowStringArrayMixin._str_slice
351350

351+
@staticmethod
352+
def _preprocess_re_pattern(pat: re.Pattern, case: bool):
353+
flags = pat.flags
354+
pat = pat.pattern
355+
# flags is not supported by pyarrow, but `case` is -> extract and remove
356+
if flags & re.IGNORECASE:
357+
case = False
358+
flags = flags & ~re.IGNORECASE
359+
# when creating a pattern with re.compile and a string, it automatically
360+
# gets a UNICODE flag, while pyarrow assumes unicode for strings anyway
361+
flags = flags & ~re.UNICODE
362+
return pat, case, flags
363+
352364
def _str_contains(
353365
self,
354366
pat,
@@ -360,10 +372,44 @@ def _str_contains(
360372
if flags:
361373
return super()._str_contains(pat, case, flags, na, regex)
362374
if isinstance(pat, re.Pattern):
363-
pat = pat.pattern
375+
pat, case, flags = self._preprocess_re_pattern(pat, case)
376+
if flags:
377+
return super()._str_contains(pat, case, flags, na, regex)
364378

365379
return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)
366380

381+
def _str_match(
382+
self,
383+
pat: str | re.Pattern,
384+
case: bool = True,
385+
flags: int = 0,
386+
na: Scalar | lib.NoDefault = lib.no_default,
387+
):
388+
if flags:
389+
return super()._str_match(pat, case, flags, na)
390+
if isinstance(pat, re.Pattern):
391+
pat, case, flags = self._preprocess_re_pattern(pat, case)
392+
if flags:
393+
return super()._str_match(pat, case, flags, na)
394+
395+
return ArrowStringArrayMixin._str_match(self, pat, case, flags, na)
396+
397+
def _str_fullmatch(
398+
self,
399+
pat: str | re.Pattern,
400+
case: bool = True,
401+
flags: int = 0,
402+
na: Scalar | lib.NoDefault = lib.no_default,
403+
):
404+
if flags:
405+
return super()._str_fullmatch(pat, case, flags, na)
406+
if isinstance(pat, re.Pattern):
407+
pat, case, flags = self._preprocess_re_pattern(pat, case)
408+
if flags:
409+
return super()._str_fullmatch(pat, case, flags, na)
410+
411+
return ArrowStringArrayMixin._str_fullmatch(self, pat, case, flags, na)
412+
367413
def _str_replace(
368414
self,
369415
pat: str | re.Pattern,

pandas/core/strings/object_array.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,7 @@ def _str_match(
262262
):
263263
if not case:
264264
flags |= re.IGNORECASE
265-
if isinstance(pat, re.Pattern):
266-
pat = pat.pattern
265+
267266
regex = re.compile(pat, flags=flags)
268267

269268
f = lambda x: regex.match(x) is not None
@@ -278,8 +277,7 @@ def _str_fullmatch(
278277
):
279278
if not case:
280279
flags |= re.IGNORECASE
281-
if isinstance(pat, re.Pattern):
282-
pat = pat.pattern
280+
283281
regex = re.compile(pat, flags=flags)
284282

285283
f = lambda x: regex.fullmatch(x) is not None

pandas/tests/strings/test_find_replace.py

Lines changed: 75 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -283,16 +283,39 @@ def test_contains_nan(any_string_dtype):
283283

284284
def test_contains_compiled_regex(any_string_dtype):
285285
# GH#61942
286-
ser = Series(["foo", "bar", "baz"], dtype=any_string_dtype)
287-
pat = re.compile("ba.")
288-
result = ser.str.contains(pat)
289-
290286
expected_dtype = (
291287
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
292288
)
289+
290+
ser = Series(["foo", "bar", "Baz"], dtype=any_string_dtype)
291+
292+
pat = re.compile("ba.")
293+
result = ser.str.contains(pat)
294+
expected = Series([False, True, False], dtype=expected_dtype)
295+
tm.assert_series_equal(result, expected)
296+
297+
# TODO this currently works for pyarrow-backed dtypes but raises for python
298+
if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow":
299+
result = ser.str.contains(pat, case=False)
300+
expected = Series([False, True, True], dtype=expected_dtype)
301+
tm.assert_series_equal(result, expected)
302+
else:
303+
with pytest.raises(
304+
ValueError, match="cannot process flags argument with a compiled pattern"
305+
):
306+
ser.str.contains(pat, case=False)
307+
308+
pat = re.compile("ba.", flags=re.IGNORECASE)
309+
result = ser.str.contains(pat)
293310
expected = Series([False, True, True], dtype=expected_dtype)
294311
tm.assert_series_equal(result, expected)
295312

313+
# TODO should this be supported?
314+
with pytest.raises(
315+
ValueError, match="cannot process flags argument with a compiled pattern"
316+
):
317+
ser.str.contains(pat, flags=re.IGNORECASE)
318+
296319

297320
# --------------------------------------------------------------------------------------
298321
# str.startswith
@@ -833,14 +856,36 @@ def test_match_case_kwarg(any_string_dtype):
833856

834857
def test_match_compiled_regex(any_string_dtype):
835858
# GH#61952
836-
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
837-
result = values.str.match(re.compile(r"ab"), case=False)
838859
expected_dtype = (
839860
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
840861
)
862+
863+
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
864+
865+
result = values.str.match(re.compile("ab"))
866+
expected = Series([True, False, True, False], dtype=expected_dtype)
867+
tm.assert_series_equal(result, expected)
868+
869+
# TODO this currently works for pyarrow-backed dtypes but raises for python
870+
if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow":
871+
result = values.str.match(re.compile("ab"), case=False)
872+
expected = Series([True, True, True, True], dtype=expected_dtype)
873+
tm.assert_series_equal(result, expected)
874+
else:
875+
with pytest.raises(
876+
ValueError, match="cannot process flags argument with a compiled pattern"
877+
):
878+
values.str.match(re.compile("ab"), case=False)
879+
880+
result = values.str.match(re.compile("ab", flags=re.IGNORECASE))
841881
expected = Series([True, True, True, True], dtype=expected_dtype)
842882
tm.assert_series_equal(result, expected)
843883

884+
with pytest.raises(
885+
ValueError, match="cannot process flags argument with a compiled pattern"
886+
):
887+
values.str.match(re.compile("ab"), flags=re.IGNORECASE)
888+
844889

845890
# --------------------------------------------------------------------------------------
846891
# str.fullmatch
@@ -913,14 +958,36 @@ def test_fullmatch_case_kwarg(any_string_dtype):
913958

914959
def test_fullmatch_compiled_regex(any_string_dtype):
915960
# GH#61952
916-
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
917-
result = values.str.fullmatch(re.compile(r"ab"), case=False)
918961
expected_dtype = (
919962
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
920963
)
964+
965+
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
966+
967+
result = values.str.fullmatch(re.compile("ab"))
968+
expected = Series([True, False, False, False], dtype=expected_dtype)
969+
tm.assert_series_equal(result, expected)
970+
971+
# TODO this currently works for pyarrow-backed dtypes but raises for python
972+
if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow":
973+
result = values.str.fullmatch(re.compile("ab"), case=False)
974+
expected = Series([True, True, False, False], dtype=expected_dtype)
975+
tm.assert_series_equal(result, expected)
976+
else:
977+
with pytest.raises(
978+
ValueError, match="cannot process flags argument with a compiled pattern"
979+
):
980+
values.str.fullmatch(re.compile("ab"), case=False)
981+
982+
result = values.str.fullmatch(re.compile("ab", flags=re.IGNORECASE))
921983
expected = Series([True, True, False, False], dtype=expected_dtype)
922984
tm.assert_series_equal(result, expected)
923985

986+
with pytest.raises(
987+
ValueError, match="cannot process flags argument with a compiled pattern"
988+
):
989+
values.str.fullmatch(re.compile("ab"), flags=re.IGNORECASE)
990+
924991

925992
# --------------------------------------------------------------------------------------
926993
# str.findall

0 commit comments

Comments
 (0)