Status: Open
Labels: bug (Something isn't working)
Description
Describe the bug
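When using Polars data (converted to NumPy before both fit and predict), BaseSLearner.predict raises a ValueError at `X_new[:, 0] = 1` (line 118 of slearner.py).
Cause: CatBoostRegressor (and CatBoostClassifier) clear the `writeable` flag on the arrays they receive (see the flags printed in cases 3 and 4 of the reproduction), so the later in-place assignment to `X_new` fails.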
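A minimal NumPy-only sketch of the failure mode, assuming (as the printed flags suggest) that the model's predict leaves its input array read-only. `fake_predict` and `slearner_predict_step` are hypothetical stand-ins written for illustration; they are not CatBoost or causalml code, although the second one mirrors the "predict, overwrite column 0, predict again" pattern of BaseSLearner.predict.

```python
import numpy as np


def fake_predict(X: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for CatBoostRegressor.predict: it mimics the
    # reported side effect of clearing the writeable flag on its input.
    X.flags.writeable = False
    return X.sum(axis=1)


def slearner_predict_step(X: np.ndarray) -> np.ndarray:
    # Build the design matrix with the treatment column set to 0 (control).
    X_new = np.hstack((np.zeros((X.shape[0], 1)), X))
    yhat_c = fake_predict(X_new)
    # Reuse the same array for the treated prediction; this in-place write
    # raises ValueError because X_new is now read-only.
    X_new[:, 0] = 1
    yhat_t = fake_predict(X_new)
    return yhat_t - yhat_c


X = np.arange(6, dtype=float).reshape(3, 2)
try:
    slearner_predict_step(X)
except ValueError as err:
    print(f"ValueError: {err}")
```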
To Reproduce
Steps to reproduce the behavior:
```python
import numpy as np
import pandas as pd
import polars as pl
from catboost import CatBoostRegressor
from causalml.inference.meta import BaseSLearner

n_train_rows = 10_000
n_test_rows = 5_000
n_features = 10
feature_names = [f'X_{i}' for i in range(n_features)]
X_np = np.random.randint(0, 100, (n_train_rows, n_features))
y_np = np.random.randint(0, 100, (n_train_rows, 1))
X_test_np = np.random.randint(0, 100, (n_test_rows, n_features))
X_pd = pd.DataFrame(np.random.randint(0, 100, (n_train_rows, n_features)), columns=feature_names)
y_pd = pd.DataFrame(np.random.randint(0, 100, (n_train_rows, 1)), columns=['y'])
X_test_pd = pd.DataFrame(np.random.randint(0, 100, (n_test_rows, n_features)), columns=feature_names)
X_df = pl.DataFrame(np.random.randint(0, 100, (n_train_rows, n_features)), schema=feature_names)
y_df = pl.DataFrame(np.random.randint(0, 100, (n_train_rows, 1)), schema=['y'])
X_test_df = pl.DataFrame(np.random.randint(0, 100, (n_test_rows, n_features)), schema=feature_names)
# BaseSLearner.fit expects (X, treatment, y); a random binary treatment is enough here
treatment = np.random.randint(0, 2, n_train_rows)

# 1. NumPy (no errors, all writeable flags preserved)
X = X_np
y = y_np
X_test = X_test_np
print(
    'Before',
    f'{X.flags.writeable=}',       # True
    f'{y.flags.writeable=}',       # True
    f'{X_test.flags.writeable=}',  # True
    sep='\n',
)
model = CatBoostRegressor(iterations=10)  # few iterations to keep the repro fast
model.fit(X, y)
model.predict(X_test)
print(
    'After',
    f'{X.flags.writeable=}',       # True
    f'{y.flags.writeable=}',       # True
    f'{X_test.flags.writeable=}',  # True
    sep='\n',
)
s_model = BaseSLearner(CatBoostRegressor(iterations=10))
s_model.fit(X, treatment, y)
s_model.predict(X_test)

# 2. pandas -> NumPy (no errors, all writeable flags preserved)
X = X_pd.to_numpy()
y = y_pd.to_numpy()
X_test = X_test_pd.to_numpy()
print(
    'Before',
    f'{X.flags.writeable=}',       # True
    f'{y.flags.writeable=}',       # True
    f'{X_test.flags.writeable=}',  # True
    sep='\n',
)
model = CatBoostRegressor(iterations=10)
model.fit(X, y)
model.predict(X_test)
print(
    'After',
    f'{X.flags.writeable=}',       # True
    f'{y.flags.writeable=}',       # True
    f'{X_test.flags.writeable=}',  # True
    sep='\n',
)
s_model = BaseSLearner(CatBoostRegressor(iterations=10))
s_model.fit(X, treatment, y)
s_model.predict(X_test)

# 3. Polars -> NumPy (ValueError, writeable flags are cleared)
X = X_df.to_numpy()
y = y_df.to_numpy()
X_test = X_test_df.to_numpy()
print(
    'Before',
    f'{X.flags.writeable=}',       # True
    f'{y.flags.writeable=}',       # False
    f'{X_test.flags.writeable=}',  # True
    sep='\n',
)
model = CatBoostRegressor(iterations=10)
model.fit(X, y)
model.predict(X_test)
print(
    'After',
    f'{X.flags.writeable=}',       # False (was True)
    f'{y.flags.writeable=}',       # False
    f'{X_test.flags.writeable=}',  # False (was True)
    sep='\n',
)
s_model = BaseSLearner(CatBoostRegressor(iterations=10))
s_model.fit(X, treatment, y)
try:
    s_model.predict(X_test)  # ValueError
except ValueError as err:
    print(f'ValueError: {err}')

# 4. Polars -> pandas -> NumPy (ValueError, writeable flags are cleared)
X = X_df.to_pandas().to_numpy()
y = y_df.to_pandas().to_numpy()
X_test = X_test_df.to_pandas().to_numpy()
print(
    'Before',
    f'{X.flags.writeable=}',       # True
    f'{y.flags.writeable=}',       # True
    f'{X_test.flags.writeable=}',  # True
    sep='\n',
)
model = CatBoostRegressor(iterations=10)
model.fit(X, y)
model.predict(X_test)
print(
    'After',
    f'{X.flags.writeable=}',       # False (was True)
    f'{y.flags.writeable=}',       # True
    f'{X_test.flags.writeable=}',  # False (was True)
    sep='\n',
)
s_model = BaseSLearner(CatBoostRegressor(iterations=10))
s_model.fit(X, treatment, y)
try:
    s_model.predict(X_test)  # ValueError
except ValueError as err:
    print(f'ValueError: {err}')
```
Expected behavior
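BaseSLearner.predict should succeed for Polars-derived input, as it does for NumPy and pandas input. Adding `X_new.flags.writeable = True` before `X_new[:, 0] = 1` in slearner.py fixes the error.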
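A sketch of two possible patches to that step, using the same kind of hypothetical `fake_predict` stand-in (a model whose predict leaves its input read-only). The first applies the fix proposed above; re-enabling the flag is allowed here because `X_new` is created by `np.hstack` and owns its memory. The second avoids the in-place write entirely by building separate control and treated arrays, so it does not depend on what the model does to its input. Both are illustrations under these assumptions, not the actual causalml patch.

```python
import numpy as np


def fake_predict(X: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for a model whose predict() leaves its input read-only.
    X.flags.writeable = False
    return X.sum(axis=1)


def predict_restore_flag(X: np.ndarray) -> np.ndarray:
    # Option 1 (proposed in this issue): re-enable writes before reusing X_new.
    X_new = np.hstack((np.zeros((X.shape[0], 1)), X))
    yhat_c = fake_predict(X_new)
    # Works because X_new owns its data; NumPy refuses to set this flag on
    # views of read-only buffers.
    X_new.flags.writeable = True
    X_new[:, 0] = 1
    yhat_t = fake_predict(X_new)
    return yhat_t - yhat_c


def predict_fresh_arrays(X: np.ndarray) -> np.ndarray:
    # Option 2: never mutate an array after it has been passed to the model.
    n = X.shape[0]
    X_control = np.hstack((np.zeros((n, 1)), X))
    X_treated = np.hstack((np.ones((n, 1)), X))
    return fake_predict(X_treated) - fake_predict(X_control)


X = np.arange(6, dtype=float).reshape(3, 2)
print(predict_restore_flag(X))  # [1. 1. 1.]
print(predict_fresh_arrays(X))  # [1. 1. 1.]
```

Option 2 costs one extra array allocation per treatment group but is robust to any model that freezes or otherwise reuses its input.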
Environment:
Linux-4.18.0-305.72.1.el8_4.x86_64-x86_64-with-glibc2.28
Python 3.11.9 (main, Jun 19 2024, 10:02:06) [GCC 8.5.0 20210514 (Red Hat 8.5.0-22)]
causalml==0.15.5
catboost==1.2.8
numpy==2.2.6
pandas==2.2.3
polars==1.32.2