@@ -2175,61 +2175,59 @@ def test_mixed_col_index_dtype(string_dtype_no_object):
2175
2175
tm .assert_frame_equal (result , expected )
2176
2176
2177
2177
2178
- @pytest .mark .parametrize ("op" , ["add" , "sub" , "mul" , "div" , "mod" , "truediv" , "pow" ])
2179
- def test_df_fill_value_operations (op ):
2180
- # GH 61581
2181
- input_data = np .arange (50 ).reshape (10 , 5 )
2182
- fill_val = 5
2183
- columns = list ("ABCDE" )
2184
- df = DataFrame (input_data , columns = columns )
2185
- for i in range (5 ):
2186
- df .iat [i , i ] = np .nan
2187
- df .iat [i + 1 , i ] = np .nan
2188
- df .iat [i + 4 , i ] = np .nan
2189
-
2190
- df_base = df .iloc [:, :- 1 ]
2191
- df_mult = df .iloc [:, - 1 ]
2192
- mask = df .isna ().values
2193
- mask = mask [:, :- 1 ] & mask [:, [- 1 ]]
2194
-
2195
- df_result = getattr (df_base , op )(df_mult , axis = 0 , fill_value = fill_val )
2196
- df_expected = getattr (df_base .fillna (fill_val ), op )(
2197
- df_mult .fillna (fill_val ), axis = 0
2198
- ).mask (mask , np .nan )
2199
-
2200
- tm .assert_frame_equal (df_result , df_expected )
2201
-
2202
-
2203
2178
dt_params = [
2204
- (tm .ALL_INT_NUMPY_DTYPES , 5 ),
2205
- (tm .ALL_INT_EA_DTYPES , 5 ),
2206
- (tm .FLOAT_NUMPY_DTYPES , 4.9 ),
2207
- (tm .FLOAT_EA_DTYPES , 4.9 ),
2179
+ (tm .ALL_INT_NUMPY_DTYPES [ 0 ] , 5 ),
2180
+ (tm .ALL_INT_EA_DTYPES [ 0 ] , 5 ),
2181
+ (tm .FLOAT_NUMPY_DTYPES [ 0 ] , 4.9 ),
2182
+ (tm .FLOAT_EA_DTYPES [ 0 ] , 4.9 ),
2208
2183
]
2209
2184
2210
- dt_param_flat = [( dt , val ) for lst , val in dt_params for dt in lst ]
2185
+ axes = [0 , 1 ]
2211
2186
2212
2187
2213
- @pytest .mark .parametrize ("data_type, fill_val" , dt_param_flat )
2214
- def test_df_fill_value_dtype (data_type , fill_val ):
2188
+ @pytest .mark .parametrize (
2189
+ "data_type,fill_val, axis" ,
2190
+ [(dt , val , axis ) for axis in axes for dt , val in dt_params ],
2191
+ )
2192
+ def test_df_fill_value_dtype (data_type , fill_val , axis ):
2215
2193
# GH 61581
2216
- base_data = np .arange (50 ).reshape (10 , 5 )
2217
- df_data = pd .array (base_data , dtype = data_type )
2194
+ base_data = np .arange (25 ).reshape (5 , 5 )
2195
+ mult_list = [1 , np .nan , 5 , np .nan , 3 ]
2196
+ np_int_flag = 0
2197
+
2198
+ try :
2199
+ mult_data = pd .array (mult_list , dtype = data_type )
2200
+ except ValueError as e :
2201
+ # Numpy int type cannot represent NaN, it will end up here
2202
+ if "cannot convert float NaN to integer" in str (e ):
2203
+ mult_data = np .asarray (mult_list )
2204
+ np_int_flag = 1
2205
+
2218
2206
columns = list ("ABCDE" )
2219
- df = DataFrame (df_data , columns = columns )
2220
- for i in range (5 ):
2221
- df .iat [i , i ] = np .nan
2222
- df .iat [i + 1 , i ] = pd .NA
2223
- df .iat [i + 4 , i ] = pd .NA
2224
-
2225
- df_base = df .iloc [:, :- 1 ]
2226
- df_mult = df .iloc [:, - 1 ]
2227
- mask = df .isna ().values
2228
- mask = mask [:, :- 1 ] & mask [:, [- 1 ]]
2229
-
2230
- df_result = df_base .mul (df_mult , axis = 0 , fill_value = fill_val )
2231
- df_expected = (df_base .fillna (fill_val ).mul (df_mult .fillna (fill_val ), axis = 0 )).mask (
2232
- mask , np .nan
2233
- )
2207
+ df = DataFrame (base_data , columns = columns )
2208
+
2209
+ for i in range (df .shape [0 ]):
2210
+ try :
2211
+ df .iat [i , i ] = np .nan
2212
+ df .iat [i + 1 , i ] = pd .NA
2213
+ df .iat [i + 3 , i ] = pd .NA
2214
+ except IndexError :
2215
+ pass
2216
+
2217
+ mult_mat = np .broadcast_to (mult_data , df .shape )
2218
+ if axis == 0 :
2219
+ mask = np .isnan (mult_mat ).T
2220
+ else :
2221
+ mask = np .isnan (mult_mat )
2222
+ mask = df .isna ().values & mask
2223
+
2224
+ df_result = df .mul (mult_data , axis = axis , fill_value = fill_val )
2225
+ if np_int_flag == 1 :
2226
+ mult_np = np .nan_to_num (mult_data , nan = fill_val )
2227
+ df_expected = (df .fillna (fill_val ).mul (mult_np , axis = axis )).mask (mask , np .nan )
2228
+ else :
2229
+ df_expected = (
2230
+ df .fillna (fill_val ).mul (mult_data .fillna (fill_val ), axis = axis )
2231
+ ).mask (mask , np .nan )
2234
2232
2235
2233
tm .assert_frame_equal (df_result , df_expected )
0 commit comments