Skip to content

Commit a5940d5

Browse files
committed
addressed changes brought up in PR, converted test cases to not use non-1D EAs
1 parent 98fb07f commit a5940d5

File tree

4 files changed

+48
-167
lines changed

4 files changed

+48
-167
lines changed

pandas/core/frame.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8444,12 +8444,7 @@ def _maybe_align_series_as_frame(self, series: Series, axis: AxisInt):
84448444
"""
84458445
rvalues = series._values
84468446
if not isinstance(rvalues, np.ndarray):
8447-
# TODO(EA2D): no need to special-case with 2D EAs
8448-
if rvalues.dtype in ("datetime64[ns]", "timedelta64[ns]"):
8449-
# We can losslessly+cheaply cast to ndarray
8450-
rvalues = np.asarray(rvalues)
8451-
else:
8452-
return series
8447+
rvalues = np.asarray(rvalues)
84538448

84548449
if axis == 0:
84558450
rvalues = rvalues.reshape(-1, 1)

pandas/tests/frame/test_arithmetic.py

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,61 +2175,59 @@ def test_mixed_col_index_dtype(string_dtype_no_object):
21752175
tm.assert_frame_equal(result, expected)
21762176

21772177

2178-
@pytest.mark.parametrize("op", ["add", "sub", "mul", "div", "mod", "truediv", "pow"])
2179-
def test_df_fill_value_operations(op):
2180-
# GH 61581
2181-
input_data = np.arange(50).reshape(10, 5)
2182-
fill_val = 5
2183-
columns = list("ABCDE")
2184-
df = DataFrame(input_data, columns=columns)
2185-
for i in range(5):
2186-
df.iat[i, i] = np.nan
2187-
df.iat[i + 1, i] = np.nan
2188-
df.iat[i + 4, i] = np.nan
2189-
2190-
df_base = df.iloc[:, :-1]
2191-
df_mult = df.iloc[:, -1]
2192-
mask = df.isna().values
2193-
mask = mask[:, :-1] & mask[:, [-1]]
2194-
2195-
df_result = getattr(df_base, op)(df_mult, axis=0, fill_value=fill_val)
2196-
df_expected = getattr(df_base.fillna(fill_val), op)(
2197-
df_mult.fillna(fill_val), axis=0
2198-
).mask(mask, np.nan)
2199-
2200-
tm.assert_frame_equal(df_result, df_expected)
2201-
2202-
22032178
dt_params = [
2204-
(tm.ALL_INT_NUMPY_DTYPES, 5),
2205-
(tm.ALL_INT_EA_DTYPES, 5),
2206-
(tm.FLOAT_NUMPY_DTYPES, 4.9),
2207-
(tm.FLOAT_EA_DTYPES, 4.9),
2179+
(tm.ALL_INT_NUMPY_DTYPES[0], 5),
2180+
(tm.ALL_INT_EA_DTYPES[0], 5),
2181+
(tm.FLOAT_NUMPY_DTYPES[0], 4.9),
2182+
(tm.FLOAT_EA_DTYPES[0], 4.9),
22082183
]
22092184

2210-
dt_param_flat = [(dt, val) for lst, val in dt_params for dt in lst]
2185+
axes = [0, 1]
22112186

22122187

2213-
@pytest.mark.parametrize("data_type, fill_val", dt_param_flat)
2214-
def test_df_fill_value_dtype(data_type, fill_val):
2188+
@pytest.mark.parametrize(
2189+
"data_type,fill_val, axis",
2190+
[(dt, val, axis) for axis in axes for dt, val in dt_params],
2191+
)
2192+
def test_df_fill_value_dtype(data_type, fill_val, axis):
22152193
# GH 61581
2216-
base_data = np.arange(50).reshape(10, 5)
2217-
df_data = pd.array(base_data, dtype=data_type)
2194+
base_data = np.arange(25).reshape(5, 5)
2195+
mult_list = [1, np.nan, 5, np.nan, 3]
2196+
np_int_flag = 0
2197+
2198+
try:
2199+
mult_data = pd.array(mult_list, dtype=data_type)
2200+
except ValueError as e:
2201+
# Numpy int type cannot represent NaN, it will end up here
2202+
if "cannot convert float NaN to integer" in str(e):
2203+
mult_data = np.asarray(mult_list)
2204+
np_int_flag = 1
2205+
22182206
columns = list("ABCDE")
2219-
df = DataFrame(df_data, columns=columns)
2220-
for i in range(5):
2221-
df.iat[i, i] = np.nan
2222-
df.iat[i + 1, i] = pd.NA
2223-
df.iat[i + 4, i] = pd.NA
2224-
2225-
df_base = df.iloc[:, :-1]
2226-
df_mult = df.iloc[:, -1]
2227-
mask = df.isna().values
2228-
mask = mask[:, :-1] & mask[:, [-1]]
2229-
2230-
df_result = df_base.mul(df_mult, axis=0, fill_value=fill_val)
2231-
df_expected = (df_base.fillna(fill_val).mul(df_mult.fillna(fill_val), axis=0)).mask(
2232-
mask, np.nan
2233-
)
2207+
df = DataFrame(base_data, columns=columns)
2208+
2209+
for i in range(df.shape[0]):
2210+
try:
2211+
df.iat[i, i] = np.nan
2212+
df.iat[i + 1, i] = pd.NA
2213+
df.iat[i + 3, i] = pd.NA
2214+
except IndexError:
2215+
pass
2216+
2217+
mult_mat = np.broadcast_to(mult_data, df.shape)
2218+
if axis == 0:
2219+
mask = np.isnan(mult_mat).T
2220+
else:
2221+
mask = np.isnan(mult_mat)
2222+
mask = df.isna().values & mask
2223+
2224+
df_result = df.mul(mult_data, axis=axis, fill_value=fill_val)
2225+
if np_int_flag == 1:
2226+
mult_np = np.nan_to_num(mult_data, nan=fill_val)
2227+
df_expected = (df.fillna(fill_val).mul(mult_np, axis=axis)).mask(mask, np.nan)
2228+
else:
2229+
df_expected = (
2230+
df.fillna(fill_val).mul(mult_data.fillna(fill_val), axis=axis)
2231+
).mask(mask, np.nan)
22342232

22352233
tm.assert_frame_equal(df_result, df_expected)

test.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

test2.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

0 commit comments

Comments
 (0)