56 changes: 46 additions & 10 deletions python/pyspark/pandas/series.py
@@ -104,6 +104,7 @@
from pyspark.pandas.plot import PandasOnSparkPlotAccessor
from pyspark.pandas.utils import (
combine_frames,
is_ansi_mode_enabled,
is_name_like_tuple,
is_name_like_value,
name_like_string,
@@ -5081,33 +5082,68 @@ def replace(
)
)
to_replace = {k: v for k, v in zip(to_replace, value)}

spark_session = self._internal.spark_frame.sparkSession
ansi_mode = is_ansi_mode_enabled(spark_session)
col_type = self.spark.data_type

if isinstance(to_replace, dict):
is_start = True
if len(to_replace) == 0:
current = self.spark.column
else:
for to_replace_, value in to_replace.items():
cond = (
(F.isnan(self.spark.column) | self.spark.column.isNull())
if pd.isna(to_replace_)
else (self.spark.column == F.lit(to_replace_))
)
if pd.isna(to_replace_):
if ansi_mode and isinstance(col_type, NumericType):
cond = F.isnan(self.spark.column) | self.spark.column.isNull()
else:
cond = self.spark.column.isNull()
else:
to_replace_lit = (
F.lit(to_replace_).try_cast(col_type)
if ansi_mode
else F.lit(to_replace_)
)
cond = self.spark.column == to_replace_lit
value_expr = F.lit(value).try_cast(col_type) if ansi_mode else F.lit(value)
if is_start:
current = F.when(cond, value)
current = F.when(cond, value_expr)
is_start = False
else:
current = current.when(cond, value)
current = current.when(cond, value_expr)
current = current.otherwise(self.spark.column)
else:
if regex:
# to_replace must be a string
cond = self.spark.column.rlike(cast(str, to_replace))
else:
cond = self.spark.column.isin(to_replace)
if ansi_mode:
to_replace_values = (
[to_replace]
if not is_list_like(to_replace) or isinstance(to_replace, str)
else to_replace
)
to_replace_values = cast(List[Any], to_replace_values)
Comment on lines +5121 to +5126

Member:
Could you try:

to_replace_values: List[Any] = (
    ...
)

to see if mypy is happy with it? If not, it's fine as-is.

Member Author:

                    to_replace_values: List[Any] = (
                        [to_replace]
                        if not is_list_like(to_replace) or isinstance(to_replace, str)
                        else to_replace
                    )

causes

python/pyspark/pandas/series.py:5122: error: Incompatible types in assignment (expression has type "list[Any] | int | float", variable has type "list[Any]")  [assignment]

I'm afraid we might have to keep to_replace_values = cast(List[Any], to_replace_values).
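
To make the mypy point concrete, here is a hypothetical, stripped-down reproduction (the `is_list_like` stub and `normalize` wrapper are placeholders, not project code): because `is_list_like` is a plain `bool` predicate rather than a `TypeGuard`, mypy cannot narrow `to_replace`, so the inline annotation is rejected while the assign-then-`cast` form type-checks.

```python
from typing import Any, List, Union, cast


def is_list_like(obj: Any) -> bool:
    # Plain bool predicate (no TypeGuard), so mypy cannot narrow the argument.
    return isinstance(obj, (list, tuple, set))


def normalize(to_replace: Union[int, float, str, List[Any]]) -> List[Any]:
    # An inline annotation such as
    #     to_replace_values: List[Any] = [to_replace] if ... else to_replace
    # is rejected because the conditional expression's inferred type still
    # contains the scalar alternatives. Assigning first and cast()-ing, as
    # concluded in the thread above, passes mypy.
    to_replace_values = (
        [to_replace]
        if not is_list_like(to_replace) or isinstance(to_replace, str)
        else to_replace
    )
    return cast(List[Any], to_replace_values)
```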

literals = [F.lit(v).try_cast(col_type) for v in to_replace_values]
cond = self.spark.column.isin(literals)
else:
cond = self.spark.column.isin(to_replace)
# to_replace may be a scalar
if np.array(pd.isna(to_replace)).any():
cond = cond | F.isnan(self.spark.column) | self.spark.column.isNull()
current = F.when(cond, value).otherwise(self.spark.column)
if ansi_mode:
if isinstance(col_type, NumericType):
cond = cond | F.isnan(self.spark.column) | self.spark.column.isNull()
else:
cond = cond | self.spark.column.isNull()
else:
cond = cond | F.isnan(self.spark.column) | self.spark.column.isNull()

if ansi_mode:
value_expr = F.lit(value).try_cast(col_type)
current = F.when(cond, value_expr).otherwise(self.spark.column.try_cast(col_type))

else:
current = F.when(cond, value).otherwise(self.spark.column)

return self._with_new_scol(current) # TODO: dtype?
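
Not part of the diff, but a minimal standalone sketch (illustrative session and column names) of why the patch swaps plain literals for `F.lit(...).try_cast(col_type)`: with `spark.sql.ansi.enabled=true`, casting an incompatible literal fails at runtime, while `try_cast` evaluates to NULL, so the `when` condition simply never matches and `otherwise` keeps the original value.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", "true")

df = spark.range(3).select(F.col("id").cast(IntegerType()).alias("x"))

# Under ANSI mode, F.lit("oops").cast(IntegerType()) would raise when evaluated;
# try_cast yields NULL instead, so `x == NULL` is never true and the
# otherwise() branch keeps the original values.
safe_lit = F.lit("oops").try_cast(IntegerType())
df.select(
    F.when(F.col("x") == safe_lit, F.lit(-1)).otherwise(F.col("x")).alias("x")
).show()
```

The `otherwise(self.spark.column.try_cast(col_type))` on the list path follows the same idea, presumably to keep both branches of the resulting CASE WHEN expression at the column's type without tripping ANSI cast errors.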

@@ -274,7 +274,6 @@ def test_fillna(self):
pdf.fillna({("x", "a"): -1, ("x", "b"): -2, ("y", "c"): -5}),
)

@unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
def test_replace(self):
pdf = pd.DataFrame(
{
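
With the `skipIf` guard removed, `test_replace` now also runs under ANSI mode. A hypothetical smoke check of the enabled behavior (not the project's test; the session setup is illustrative) could look like:

```python
import pandas as pd
import pyspark.pandas as ps
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", "true")

pser = pd.Series([1, 2, 3, 4], name="x")
psser = ps.from_pandas(pser)

# pandas and pandas-on-Spark should agree even with ANSI mode on, since the
# patch routes literals through try_cast instead of a failing CAST.
assert psser.replace(2, 20).to_pandas().to_dict() == pser.replace(2, 20).to_dict()
assert psser.replace([3, 4], -1).to_pandas().to_dict() == pser.replace([3, 4], -1).to_dict()
```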