82 commits
f7cf7d6
First draft, needs tests & fixes
cakedev0 Sep 2, 2025
7061ff6
Merge remote-tracking branch 'upstream/main' into mae-split-optim
cakedev0 Sep 2, 2025
f4edaa2
fixed compilation errors
cakedev0 Sep 3, 2025
01fd9b2
fixed compilation errors
cakedev0 Sep 3, 2025
3f87b99
Moved AE computation in external helper to be able to unit-test it; a…
cakedev0 Sep 3, 2025
e8adf96
WIP some additional tests that helped me, some will be kept in my fin…
cakedev0 Sep 3, 2025
4ed868e
tests cleanup
cakedev0 Sep 3, 2025
83d89a4
cleanup
cakedev0 Sep 3, 2025
1ca34bf
cleanup
cakedev0 Sep 3, 2025
43692f7
Merge remote-tracking branch 'upstream/main' into mae-split-optim
cakedev0 Sep 3, 2025
d463558
WIP fixing linting issues
cakedev0 Sep 3, 2025
fa993d4
fixed linting
cakedev0 Sep 3, 2025
cbf5405
fix spelling
cakedev0 Sep 3, 2025
a4bd310
Added test that would fail before this PR
cakedev0 Sep 4, 2025
f4a0e07
added changed logs
cakedev0 Sep 4, 2025
a86a190
cleanup
cakedev0 Sep 4, 2025
092af65
comments & cleanups
cakedev0 Sep 4, 2025
4a12dea
slight refactor of class inheritance
cakedev0 Sep 4, 2025
b44fb2b
Merge remote-tracking branch 'upstream/main' into mae-split-optim
cakedev0 Sep 4, 2025
81728c2
adressed PR comments; simplified dimension of left/right abs errors a…
cakedev0 Sep 9, 2025
7477f4c
removed print
cakedev0 Sep 9, 2025
8f035d0
heap methods docstring; test: split assertion
cakedev0 Sep 10, 2025
e6bf43b
unit test for heap
cakedev0 Sep 10, 2025
eb2ccf5
fix comment
cakedev0 Sep 10, 2025
66a2cb6
Merge remote-tracking branch 'upstream/main' into mae-split-optim
cakedev0 Sep 12, 2025
d13a2c5
Apply suggestions from code review
cakedev0 Sep 13, 2025
4fc78f4
comments & naming
cakedev0 Sep 13, 2025
220c34f
parameters docstring
cakedev0 Sep 13, 2025
d9b3c35
Update doc about MAE criterion speed
cakedev0 Sep 14, 2025
72e15b5
move precompute
cakedev0 Sep 14, 2025
debf965
minimal changes
cakedev0 Sep 14, 2025
6e267d5
AE to pinball loss
cakedev0 Sep 14, 2025
aa91439
doing typos is my signature move, sorry for taht
cakedev0 Sep 14, 2025
450290a
Update doc/modules/tree.rst
cakedev0 Sep 15, 2025
bc7685e
Add docstring for test_cython_weighted_heap_vs_heapq
cakedev0 Sep 15, 2025
390731a
Update comment about mem footprint
cakedev0 Sep 15, 2025
1153cb5
PERF: Decision trees: improve prefs by ~20% with very simple changes …
cakedev0 Sep 15, 2025
0f6d896
:lock: :robot: CI Update lock files for main CI build(s) :lock: :robo…
scikit-learn-bot Sep 15, 2025
f48a2a4
:lock: :robot: CI Update lock files for array-api CI build(s) :lock: …
scikit-learn-bot Sep 15, 2025
b32df28
:lock: :robot: CI Update lock files for free-threaded CI build(s) :lo…
scikit-learn-bot Sep 15, 2025
6cdacd1
TST Fix the error message in test_min_dependencies_readme (#32149)
jeremiedbb Sep 15, 2025
3a85d5c
Revert "API make murmurhash3_32 private (#32103)" (#32131)
jeremiedbb Sep 15, 2025
85b12c9
fix docstring
cakedev0 Sep 15, 2025
7996ed6
Merge remote-tracking branch 'upstream/main' into mae-split-optim
cakedev0 Sep 15, 2025
22c843e
addressed comments around test_absolute_errors_precomputation_function
cakedev0 Sep 15, 2025
319523a
update docstring
cakedev0 Sep 15, 2025
d7f5157
update docstring; again
cakedev0 Sep 15, 2025
592e74a
Pass down pinball_alpha
cakedev0 Sep 17, 2025
3c59ae7
Merge remote-tracking branch 'upstream/main' into mae-split-optim
cakedev0 Sep 17, 2025
a2f3a85
Merge branch 'mae-split-optim' into quantile-regression
cakedev0 Sep 17, 2025
ecd2f15
small changes
cakedev0 Sep 19, 2025
075243c
fixes
cakedev0 Sep 23, 2025
3819c50
Merge remote-tracking branch 'upstream/main' into mae-split-optim
cakedev0 Sep 26, 2025
77dcb19
new test and fix
cakedev0 Sep 26, 2025
14014f5
fix typo
cakedev0 Sep 26, 2025
ad16ae0
remove np.pow
cakedev0 Sep 26, 2025
1e9c74f
Apply suggestion from @ogrisel
cakedev0 Sep 26, 2025
b21040e
Apply suggestion from @cakedev0
cakedev0 Sep 26, 2025
e557f9e
added explanation test; more tests with integer weights
cakedev0 Sep 29, 2025
f920379
Merge branch 'mae-split-optim' of github.com:cakedev0/scikit-learn in…
cakedev0 Sep 29, 2025
c204c20
Merge branch 'mae-split-optim' into quantile-regression
cakedev0 Sep 29, 2025
c842e59
Merge remote-tracking branch 'upstream/main' into mae-split-optim
cakedev0 Oct 2, 2025
bec926a
Merge branch 'main' into mae-split-optim
cakedev0 Oct 3, 2025
0cdeaaf
Merge branch 'mae-split-optim' into quantile-regression
cakedev0 Oct 7, 2025
bf0007f
Merge branch 'main' into quantile-regression
cakedev0 Dec 15, 2025
19bf4a6
cleanup, comments updates, renamings, ...
cakedev0 Dec 15, 2025
aaa4b2a
remove old changelog
cakedev0 Dec 15, 2025
4023c02
Added simple changelog
cakedev0 Dec 15, 2025
e9c424d
Merge remote-tracking branch 'upstream/main' into quantile-regression
cakedev0 Jan 10, 2026
10e7398
Merge remote-tracking branch 'upstream/main' into quantile-regression
cakedev0 Jan 27, 2026
a718a68
renaming & public API
cakedev0 Jan 27, 2026
23f0382
userguide
cakedev0 Jan 27, 2026
30418a4
added tests
cakedev0 Jan 27, 2026
2cfc7f2
support in RF/ExtraTrees
cakedev0 Jan 27, 2026
e06bd6c
add a test with quantile criterion for forests
cakedev0 Jan 27, 2026
a9b26d6
fix docstring
cakedev0 Jan 28, 2026
6316b6b
update changelog
cakedev0 Jan 28, 2026
19a46b8
cleanup
cakedev0 Jan 28, 2026
c1a9af9
Merge remote-tracking branch 'upstream/main' into quantile-regression
cakedev0 Jan 28, 2026
3f81df5
Merge remote-tracking branch 'upstream/main' into quantile-regression
cakedev0 Jan 28, 2026
ff2df12
fix __reduce__ for MAE criterion
cakedev0 Jan 28, 2026
1553d6c
minor public doc udpate
cakedev0 Jan 28, 2026
14 changes: 14 additions & 0 deletions doc/modules/tree.rst
@@ -634,6 +634,20 @@ Mean Absolute Error:

Note that it is 3–6× slower to fit than the MSE criterion as of version 1.8.

Quantile (pinball loss):

.. math::

q_{\alpha}(y)_m = \underset{y \in Q_m}{\mathrm{quantile}_{\alpha}}(y)

H(Q_m) = \frac{1}{n_m} \sum_{y \in Q_m}
\left(\alpha \max(y - q_{\alpha}(y)_m, 0) +
(1-\alpha) \max(q_{\alpha}(y)_m - y, 0)\right)

Use ``criterion="quantile"`` together with the ``quantile`` parameter to
choose :math:`\alpha \in (0, 1)`. The special case ``quantile=0.5`` corresponds
to the median.
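The property that the :math:`\alpha`-quantile minimizes the pinball loss can be checked numerically. A minimal NumPy sketch, independent of scikit-learn (the ``pinball_loss`` helper below is illustrative, not part of the library):

```python
import numpy as np

def pinball_loss(y, q, alpha):
    """Mean pinball loss of predicting the constant q for targets y."""
    diff = y - q
    return np.mean(alpha * np.maximum(diff, 0) + (1 - alpha) * np.maximum(-diff, 0))

rng = np.random.default_rng(0)
y = rng.normal(size=1001)
alpha = 0.8

# The empirical alpha-quantile minimizes the pinball loss over constants.
q_star = np.quantile(y, alpha)
grid_losses = [pinball_loss(y, q, alpha) for q in np.linspace(y.min(), y.max(), 201)]
assert pinball_loss(y, q_star, alpha) <= min(grid_losses) + 1e-12
```

This is the per-leaf objective the criterion optimizes: each terminal node predicts the :math:`\alpha`-quantile of the targets it contains.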

.. _tree_missing_value_support:

Missing Values Support
6 changes: 6 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.tree/32903.feature.rst
@@ -0,0 +1,6 @@
- :class:`tree.DecisionTreeRegressor`, :class:`tree.ExtraTreeRegressor`,
:class:`ensemble.RandomForestRegressor`, and :class:`ensemble.ExtraTreesRegressor`
now support `criterion="quantile"` together with the `quantile` parameter to
  optimize the pinball loss (also known as the quantile loss). This makes it
  possible to perform quantile regression.
By :user:`Arthur Lacote <cakedev0>`
46 changes: 38 additions & 8 deletions sklearn/ensemble/_forest.py
@@ -345,7 +345,10 @@ def fit(self, X, y, sample_weight=None):
# will raise an error if the underlying tree base estimator can't handle missing
# values. Only the criterion is required to determine if the tree supports
# missing values.
estimator = type(self.estimator)(criterion=self.criterion)
estimator_kwargs = {"criterion": self.criterion}
if self.criterion == "quantile":
estimator_kwargs["quantile"] = self.quantile
estimator = type(self.estimator)(**estimator_kwargs)
missing_values_in_feature_mask = (
estimator._compute_missing_values_in_feature_mask(
X, estimator_name=self.__class__.__name__
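The conditional-kwargs pattern in the hunk above can be sketched standalone; the ``Est`` class and ``make_estimator`` helper below are illustrative stand-ins, not scikit-learn code:

```python
class Est:
    # Stand-in estimator: `quantile` only matters for one criterion.
    def __init__(self, criterion="squared_error", quantile=0.5):
        self.criterion = criterion
        self.quantile = quantile

def make_estimator(criterion, quantile=0.5):
    # Only forward `quantile` when the criterion actually uses it, so
    # estimator classes lacking that parameter stay constructible.
    kwargs = {"criterion": criterion}
    if criterion == "quantile":
        kwargs["quantile"] = quantile
    return Est(**kwargs)

est = make_estimator("quantile", 0.9)
assert est.quantile == 0.9
```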
@@ -697,7 +700,10 @@ def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# Only the criterion is required to determine if the tree supports
# missing values
estimator = type(self.estimator)(criterion=self.criterion)
estimator_kwargs = {"criterion": self.criterion}
if self.criterion == "quantile":
estimator_kwargs["quantile"] = self.quantile
estimator = type(self.estimator)(**estimator_kwargs)
tags.input_tags.allow_nan = get_tags(estimator).input_tags.allow_nan
return tags

@@ -1609,14 +1615,16 @@ class RandomForestRegressor(ForestRegressor):
The default value of ``n_estimators`` changed from 10 to 100
in 0.22.

criterion : {"squared_error", "absolute_error", "poisson"}, default="squared_error"
criterion : {"squared_error", "absolute_error", "quantile", "poisson"}, \
default="squared_error"
The function to measure the quality of a split. Supported criteria
are "squared_error" for the mean squared error, which is equal to
variance reduction as feature selection criterion and minimizes the L2
loss using the mean of each terminal node, "absolute_error" for the mean
absolute error, which minimizes the L1 loss using the median of each terminal
node, and "poisson" which uses reduction in Poisson deviance to find splits,
also using the mean of each terminal node.
node, "quantile" which minimizes the pinball loss using the quantile of each
terminal node (controlled by ``quantile``), and "poisson" which uses reduction
in Poisson deviance to find splits, also using the mean of each terminal node.

.. versionadded:: 0.18
Mean Absolute Error (MAE) criterion.
@@ -1627,6 +1635,9 @@ class RandomForestRegressor(ForestRegressor):
.. versionchanged:: 1.9
Criterion `"friedman_mse"` was deprecated.

.. versionadded:: 1.9
Quantile/pinball loss criterion.

max_depth : int, default=None
The maximum depth of the tree. If None, then nodes are expanded until
all leaves are pure or until all leaves contain less than
@@ -1786,6 +1797,10 @@ class RandomForestRegressor(ForestRegressor):

.. versionadded:: 1.4

quantile : float, default=0.5
The quantile to predict when ``criterion="quantile"``. It must be strictly
between 0 and 1.

Attributes
----------
estimator_ : :class:`~sklearn.tree.DecisionTreeRegressor`
@@ -1913,6 +1928,7 @@ def __init__(
ccp_alpha=0.0,
max_samples=None,
monotonic_cst=None,
quantile=0.5,
):
super().__init__(
estimator=DecisionTreeRegressor(),
@@ -1929,6 +1945,7 @@
"random_state",
"ccp_alpha",
"monotonic_cst",
"quantile",
),
bootstrap=bootstrap,
oob_score=oob_score,
@@ -1959,6 +1976,7 @@ def __init__(
self.min_impurity_decrease = min_impurity_decrease
self.ccp_alpha = ccp_alpha
self.monotonic_cst = monotonic_cst
self.quantile = quantile


class ExtraTreesClassifier(ForestClassifier):
@@ -2378,21 +2396,26 @@ class ExtraTreesRegressor(ForestRegressor):
The default value of ``n_estimators`` changed from 10 to 100
in 0.22.

criterion : {"squared_error", "absolute_error", "poisson"}, default="squared_error"
criterion : {"squared_error", "absolute_error", "quantile", "poisson"}, \
default="squared_error"
The function to measure the quality of a split. Supported criteria
are "squared_error" for the mean squared error, which is equal to
variance reduction as feature selection criterion and minimizes the L2
loss using the mean of each terminal node, "absolute_error" for the mean
absolute error, which minimizes the L1 loss using the median of each terminal
node, and "poisson" which uses reduction in Poisson deviance to find splits,
also using the mean of each terminal node.
node, "quantile" which minimizes the pinball loss using the quantile of each
terminal node (controlled by ``quantile``), and "poisson" which uses reduction
in Poisson deviance to find splits, also using the mean of each terminal node.

.. versionadded:: 0.18
Mean Absolute Error (MAE) criterion.

.. versionchanged:: 1.9
Criterion `"friedman_mse"` was deprecated.

.. versionadded:: 1.9
Quantile/pinball loss criterion.

max_depth : int, default=None
The maximum depth of the tree. If None, then nodes are expanded until
all leaves are pure or until all leaves contain less than
@@ -2556,6 +2579,10 @@ class ExtraTreesRegressor(ForestRegressor):

.. versionadded:: 1.4

quantile : float, default=0.5
The quantile to predict when ``criterion="quantile"``. It must be strictly
between 0 and 1.

Attributes
----------
estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor`
@@ -2667,6 +2694,7 @@ def __init__(
ccp_alpha=0.0,
max_samples=None,
monotonic_cst=None,
quantile=0.5,
):
super().__init__(
estimator=ExtraTreeRegressor(),
@@ -2683,6 +2711,7 @@
"random_state",
"ccp_alpha",
"monotonic_cst",
"quantile",
),
bootstrap=bootstrap,
oob_score=oob_score,
@@ -2713,6 +2742,7 @@ def __init__(
self.min_impurity_decrease = min_impurity_decrease
self.ccp_alpha = ccp_alpha
self.monotonic_cst = monotonic_cst
self.quantile = quantile


class RandomTreesEmbedding(TransformerMixin, BaseForest):
1 change: 1 addition & 0 deletions sklearn/ensemble/_gb.py
@@ -378,6 +378,7 @@ class BaseGradientBoosting(BaseEnsemble, metaclass=ABCMeta):
}
_parameter_constraints.pop("splitter")
_parameter_constraints.pop("monotonic_cst")
_parameter_constraints.pop("quantile")

@abstractmethod
def __init__(
4 changes: 2 additions & 2 deletions sklearn/ensemble/tests/test_forest.py
@@ -294,12 +294,12 @@ def test_probability(name):
"name, criterion",
itertools.chain(
product(FOREST_CLASSIFIERS, ["gini", "log_loss"]),
product(FOREST_REGRESSORS, ["squared_error", "absolute_error"]),
product(FOREST_REGRESSORS, ["squared_error", "absolute_error", "quantile"]),
),
)
def test_importances(dtype, name, criterion):
tolerance = 0.01
if name in FOREST_REGRESSORS and criterion == "absolute_error":
if name in FOREST_REGRESSORS and criterion in {"absolute_error", "quantile"}:
tolerance = 0.05

# cast as dtype
47 changes: 39 additions & 8 deletions sklearn/tree/_classes.py
@@ -76,6 +76,7 @@
"squared_error": _criterion.MSE,
"absolute_error": _criterion.MAE,
"poisson": _criterion.Poisson,
"quantile": _criterion.Pinball,
}

DENSE_SPLITTERS = {"best": _splitter.BestSplitter, "random": _splitter.RandomSplitter}
@@ -382,7 +383,14 @@ def _fit(
self.n_outputs_, self.n_classes_
)
else:
criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples)
args = (self.n_outputs_, n_samples)
if self.criterion == "quantile":
args = (*args, self.quantile)
if self.criterion == "absolute_error":
# FIXME: this is coupled with code at a much lower level
# because of the inheritance behavior of __cinit__
args = (*args, 0.5)
criterion = CRITERIA_REG[self.criterion](*args)
else:
# Make a deepcopy in case the criterion has mutable attributes that
# might be shared and modified concurrently during parallel fitting
@@ -1117,14 +1125,16 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):

Parameters
----------
criterion : {"squared_error", "absolute_error", "poisson"}, default="squared_error"
criterion : {"squared_error", "absolute_error", "quantile", "poisson"}, \
default="squared_error"
The function to measure the quality of a split. Supported criteria
are "squared_error" for the mean squared error, which is equal to
variance reduction as feature selection criterion and minimizes the L2
loss using the mean of each terminal node, "absolute_error" for the mean
absolute error, which minimizes the L1 loss using the median of each terminal
node, and "poisson" which uses reduction in Poisson deviance to find splits,
also using the mean of each terminal node.
node, "quantile" which minimizes the pinball loss using the quantile of each
terminal node (controlled by ``quantile``), and "poisson" which uses reduction
in Poisson deviance to find splits, also using the mean of each terminal node.

.. versionadded:: 0.18
Mean Absolute Error (MAE) criterion.
@@ -1135,6 +1145,9 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
.. versionchanged:: 1.9
Criterion `"friedman_mse"` was deprecated.

.. versionadded:: 1.9
Quantile/pinball loss criterion.

splitter : {"best", "random"}, default="best"
The strategy used to choose the split at each node. Supported
strategies are "best" to choose the best split and "random" to choose
@@ -1255,6 +1268,10 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):

.. versionadded:: 1.4

quantile : float, default=0.5
The quantile to predict when ``criterion="quantile"``. It must be strictly
between 0 and 1. If 0.5 (default), the model predicts the median.
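The median special case follows from the loss itself: at ``alpha = 0.5`` the pinball loss is exactly half the absolute error for every candidate prediction, so minimizing it selects the median, matching ``criterion="absolute_error"``. A quick NumPy check (the ``pinball`` helper is illustrative, not library code):

```python
import numpy as np

def pinball(y, q, alpha):
    d = y - q
    return np.mean(alpha * np.maximum(d, 0) + (1 - alpha) * np.maximum(-d, 0))

y = np.array([1.0, 3.0, 4.0, 10.0, 20.0])
med = np.median(y)

# With alpha = 0.5, pinball(y, q) == 0.5 * mean |y - q| for any q,
# so the minimizing constant is the median.
assert np.isclose(pinball(y, med, 0.5), 0.5 * np.mean(np.abs(y - med)))
```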

Attributes
----------
feature_importances_ : ndarray of shape (n_features,)
@@ -1338,9 +1355,10 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
_parameter_constraints: dict = {
**BaseDecisionTree._parameter_constraints,
"criterion": [
StrOptions({"squared_error", "absolute_error", "poisson"}),
StrOptions({"squared_error", "absolute_error", "poisson", "quantile"}),
Hidden(Criterion),
],
"quantile": [Interval(RealNotInt, 0.0, 1.0, closed="neither")],
}

def __init__(
@@ -1358,6 +1376,7 @@
min_impurity_decrease=0.0,
ccp_alpha=0.0,
monotonic_cst=None,
quantile=0.5,
):
if isinstance(criterion, str) and criterion == "friedman_mse":
# TODO(1.11): remove support of "friedman_mse" criterion.
@@ -1383,6 +1402,7 @@
ccp_alpha=ccp_alpha,
monotonic_cst=monotonic_cst,
)
self.quantile = quantile

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y, sample_weight=None, check_input=True):
@@ -1767,14 +1787,16 @@ class ExtraTreeRegressor(DecisionTreeRegressor):

Parameters
----------
criterion : {"squared_error", "absolute_error", "poisson"}, default="squared_error"
criterion : {"squared_error", "absolute_error", "quantile", "poisson"}, \
default="squared_error"
The function to measure the quality of a split. Supported criteria
are "squared_error" for the mean squared error, which is equal to
variance reduction as feature selection criterion and minimizes the L2
loss using the mean of each terminal node, "absolute_error" for the mean
absolute error, which minimizes the L1 loss using the median of each terminal
node, and "poisson" which uses reduction in Poisson deviance to find splits,
also using the mean of each terminal node.
node, "quantile" which minimizes the pinball loss using the quantile of each
terminal node (controlled by ``quantile``), and "poisson" which uses reduction
in Poisson deviance to find splits, also using the mean of each terminal node.

.. versionadded:: 0.18
Mean Absolute Error (MAE) criterion.
@@ -1785,6 +1807,9 @@
.. versionchanged:: 1.9
Criterion `"friedman_mse"` was deprecated.

.. versionadded:: 1.9
Quantile/pinball loss criterion.

splitter : {"random", "best"}, default="random"
The strategy used to choose the split at each node. Supported
strategies are "best" to choose the best split and "random" to choose
@@ -1897,6 +1922,10 @@ class ExtraTreeRegressor(DecisionTreeRegressor):

.. versionadded:: 1.4

quantile : float, default=0.5
The quantile to predict when ``criterion="quantile"``. It must be strictly
between 0 and 1. If 0.5 (default), the model predicts the median.

Attributes
----------
max_features_ : int
@@ -1981,6 +2010,7 @@ def __init__(
max_leaf_nodes=None,
ccp_alpha=0.0,
monotonic_cst=None,
quantile=0.5,
):
super().__init__(
criterion=criterion,
Expand All @@ -1995,6 +2025,7 @@ def __init__(
random_state=random_state,
ccp_alpha=ccp_alpha,
monotonic_cst=monotonic_cst,
quantile=quantile,
)

def __sklearn_tags__(self):