Skip to content

Commit 06bb009

Browse files
carl-offerfitCarl Goldkbattocchi
authored
Alternative scoring metrics (#988)
* Allow specification of an alternative sklearn score function for double ML models
* Add score_nuisances function

Signed-off-by: Carl Gold <[email protected]>
Signed-off-by: Keith Battocchi <[email protected]>
Co-authored-by: Carl Gold <[email protected]>
Co-authored-by: Keith Battocchi <[email protected]>
1 parent 641c1ac commit 06bb009

File tree

5 files changed

+293
-25
lines changed

5 files changed

+293
-25
lines changed

econml/_ortho_learner.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,7 +1049,7 @@ def effect_inference(self, X=None, *, T0=0, T1=1):
10491049

10501050
effect_inference.__doc__ = LinearCateEstimator.effect_inference.__doc__
10511051

1052-
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
1052+
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, scoring=None):
10531053
"""
10541054
Score the fitted CATE model on a new data set.
10551055
@@ -1077,6 +1077,9 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
10771077
Weights for each samples
10781078
groups: (n,) vector, optional
10791079
All rows corresponding to the same group will be kept together during splitting.
1080+
scoring: name of an sklearn scoring function to use instead of the default, optional
1081+
Supports f1_score, log_loss, mean_absolute_error, mean_squared_error, r2_score,
1082+
and roc_auc_score.
10801083
10811084
Returns
10821085
-------
@@ -1135,9 +1138,24 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
11351138

11361139
accumulated_nuisances += nuisances
11371140

1141+
score_kwargs = {
1142+
'X': X,
1143+
'W': W,
1144+
'Z': Z,
1145+
'sample_weight': sample_weight,
1146+
'groups': groups
1147+
}
1148+
# If using an _rlearner, the scoring parameter can be passed along, if provided
1149+
if scoring is not None:
1150+
# Cannot import in header, or circular imports
1151+
from .dml._rlearner import _ModelFinal
1152+
if isinstance(self._ortho_learner_model_final, _ModelFinal):
1153+
score_kwargs['scoring'] = scoring
1154+
else:
1155+
raise NotImplementedError("scoring parameter only implemented for "
1156+
"_rlearner._ModelFinal")
11381157
return self._ortho_learner_model_final.score(Y, T, nuisances=accumulated_nuisances,
1139-
**filter_none_kwargs(X=X, W=W, Z=Z,
1140-
sample_weight=sample_weight, groups=groups))
1158+
**filter_none_kwargs(**score_kwargs))
11411159

11421160
@property
11431161
def ortho_learner_model_final_(self):

econml/dml/_rlearner.py

Lines changed: 158 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,16 @@
2727

2828
from abc import abstractmethod
2929
import numpy as np
30-
30+
import pandas as pd
31+
from sklearn.metrics import (
32+
get_scorer,
33+
get_scorer_names
34+
)
35+
from typing import Callable, Union
3136
from ..sklearn_extensions.model_selection import ModelSelector
3237
from ..utilities import (filter_none_kwargs)
3338
from .._ortho_learner import _OrthoLearner
3439

35-
3640
class _ModelNuisance(ModelSelector):
3741
"""
3842
RLearner nuisance model.
@@ -54,10 +58,13 @@ def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight
5458
filter_none_kwargs(sample_weight=sample_weight, groups=groups))
5559
return self
5660

57-
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None,
          y_scoring=None, t_scoring=None, t_score_by_dim=False):
    """
    Score the outcome and treatment nuisance models.

    :param Y: Targets handed to the outcome model's ``score``
    :param T: Targets handed to the treatment model's ``score``
    :param X: Forwarded to both nuisance models
    :param W: Forwarded to both nuisance models
    :param Z: Unused
    :param sample_weight: Optional sample weights; forwarded only when not None
    :param groups: Unused (groups only matter when fitting)
    :param y_scoring: Forwarded as ``scoring`` to the outcome model's ``score``
    :param t_scoring: Forwarded as ``scoring`` to the treatment model's ``score``
    :param t_score_by_dim: Forwarded as ``score_by_dim`` to the treatment model's ``score``
    :return: Tuple ``(Y_score, T_score)``
    """
    # groups are intentionally not forwarded: they are only used for fitting
    weight_kwargs = filter_none_kwargs(sample_weight=sample_weight)
    # The treatment model is scored first, matching the original call order
    treatment_score = self._model_t.score(X, W, T, scoring=t_scoring,
                                          score_by_dim=t_score_by_dim, **weight_kwargs)
    outcome_score = self._model_y.score(X, W, Y, scoring=y_scoring, **weight_kwargs)
    return outcome_score, treatment_score
6269

6370
def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
@@ -98,18 +105,92 @@ def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None,
98105
def predict(self, X=None):
99106
return self._model_final.predict(X)
100107

101-
def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, groups=None,
          scoring='mean_squared_error'):
    """
    Score how well the final model reproduces the outcome residuals.

    The final model's predicted effects are contracted with the treatment residuals to
    obtain predicted outcome residuals, which are then compared with the actual outcome
    residuals using ``scoring``. The default, "mean_squared_error", reproduces the score
    this method returned before the ``scoring`` option existed.

    :param Y: Unused
    :param T: Unused
    :param X: Passed unchanged to ``_model_final.predict``
    :param W: Unused
    :param Z: Unused
    :param nuisances: Tuple ``(Y_res, T_res)`` of outcome and treatment residuals
    :param sample_weight: Optional weighting on the samples
    :param groups: Unused
    :param scoring: Scoring metric handed to ``_ModelFinal._wrap_scoring``
    :return: Whatever the selected scoring function returns (a float for the default)
    """
    outcome_res, treatment_res = nuisances
    # Promote 1-D residual vectors to single-column matrices so the einsum below is uniform
    if outcome_res.ndim == 1:
        outcome_res = outcome_res.reshape((-1, 1))
    if treatment_res.ndim == 1:
        treatment_res = treatment_res.reshape((-1, 1))
    n_outcomes = outcome_res.shape[1]
    n_treatments = treatment_res.shape[1]
    # effects[i, j, k]: predicted effect of treatment k on outcome j for sample i
    effects = self._model_final.predict(X).reshape((-1, n_outcomes, n_treatments))
    predicted_res = np.einsum('ijk,ik->ij', effects, treatment_res).reshape(outcome_res.shape)
    return _ModelFinal._wrap_scoring(Y_true=outcome_res, Y_pred=predicted_res,
                                     scoring=scoring, sample_weight=sample_weight)
136+
137+
138+
@staticmethod
139+
def _wrap_scoring(scoring:Union[str, Callable], Y_true, Y_pred, sample_weight=None):
140+
"""
141+
Pull the scoring function from sklearn.get_scorer and call it with Y_true, Y_pred.
142+
143+
Standard score names like "mean_squared_error" are present in sklearn scoring as
144+
"neg_..." so score names are accepted either with or without the "neg_" prefix.
145+
The function _score_func is called directly because the scorer objects from get_scorer()
146+
do not accept a sample_weight parameter. The _score_func member has been available in
147+
sklearn scorers since before sklearn 1.0. Note that custom callable score functions
148+
are allowed but they are not validated before use; any errors will be raised.
149+
150+
151+
:param scoring: A string name of a scoring function from sklearn, or any callable that will
152+
function as thes core.
153+
:param Y_true: True Y values
154+
:param Y_pred: Predicted Y values
155+
:param sample_weight: Optional weighting on the examples
156+
:return: Float score
157+
"""
158+
if isinstance(scoring,str) and scoring in get_scorer_names():
159+
score_fn = get_scorer(scoring)._score_func
160+
elif isinstance(scoring,str) and 'neg_' + scoring in get_scorer_names():
161+
score_fn = get_scorer('neg_' + scoring)._score_func
162+
elif callable(scoring):
163+
score_fn = scoring
164+
else:
165+
raise NotImplementedError(f"_wrap_scoring does not support '{scoring}'" )
166+
167+
# Some score like functions are partial to np.array and not np.ndarray with shape (N,1)
168+
Y_true = Y_true.squeeze() if len(Y_true.shape)==2 and Y_true.shape[1]==1 else Y_true
169+
Y_pred = Y_pred.squeeze() if len(Y_pred.shape)==2 and Y_pred.shape[1]==1 else Y_pred
109170
if sample_weight is not None:
110-
return np.mean(np.average((Y_res - Y_res_pred) ** 2, weights=sample_weight, axis=0))
171+
res = score_fn(Y_true, Y_pred, sample_weight=sample_weight)
111172
else:
112-
return np.mean((Y_res - Y_res_pred) ** 2)
173+
res = score_fn(Y_true, Y_pred)
174+
175+
return res
176+
177+
178+
@staticmethod
179+
def wrap_scoring(scoring, Y_true, Y_pred, sample_weight=None, score_by_dim=False):
    """
    Score predictions, optionally producing one score per output column.

    With ``score_by_dim=False`` this is a direct call to ``_ModelFinal._wrap_scoring``.
    With ``score_by_dim=True`` each column of ``Y_true``/``Y_pred`` is scored separately
    with the same ``scoring`` and ``sample_weight``; 1-D inputs are treated as a single
    column, so the result is then a one-element list.

    :param scoring: Scorer name or callable, as accepted by ``_ModelFinal._wrap_scoring``
    :param Y_true: True values
    :param Y_pred: Predicted values; must have the same shape as ``Y_true`` when scoring
        by dimension
    :param sample_weight: Optional weighting on the examples
    :param score_by_dim: If True, return a list with one score per column
    :return: A single score, or a list of per-column scores when ``score_by_dim`` is True
    :raises ValueError: if ``score_by_dim`` is True and the shapes differ
    """
    if not score_by_dim:
        return _ModelFinal._wrap_scoring(scoring, Y_true, Y_pred, sample_weight)

    # Work on plain arrays so positional column slicing also works for DataFrame input
    true_cols = np.asarray(Y_true)
    pred_cols = np.asarray(Y_pred)
    # Explicit raise instead of assert: asserts are stripped when Python runs with -O
    if true_cols.shape != pred_cols.shape:
        raise ValueError("Mismatch shape in wrap_scoring")
    if true_cols.ndim == 1:
        # e.g. label vectors of a discrete target: a single dimension to score
        true_cols = true_cols.reshape(-1, 1)
        pred_cols = pred_cols.reshape(-1, 1)
    return [
        _ModelFinal._wrap_scoring(scoring, true_cols[:, col], pred_cols[:, col], sample_weight)
        for col in range(true_cols.shape[1])
    ]
113194

114195

115196
class _RLearner(_OrthoLearner):
@@ -255,13 +336,13 @@ def _gen_rlearner_model_final(self):
255336
>>> est.effect(np.ones((1,1)), T0=0, T1=10)
256337
array([9.996314...])
257338
>>> est.score(y, X[:, 0], X=np.ones((X.shape[0], 1)), W=X[:, 1:])
258-
np.float64(9.73638006...e-05)
339+
9.73638006...e-05
259340
>>> est.rlearner_model_final_.model
260341
LinearRegression(fit_intercept=False)
261342
>>> est.rlearner_model_final_.model.coef_
262343
array([0.999631...])
263344
>>> est.score_
264-
np.float64(9.82623204...e-05)
345+
9.82623204...e-05
265346
>>> [mdl._model for mdls in est.models_y for mdl in mdls]
266347
[LinearRegression(), LinearRegression()]
267348
>>> [mdl._model for mdls in est.models_t for mdl in mdls]
@@ -422,7 +503,7 @@ def fit(self, Y, T, *, X=None, W=None, sample_weight=None, freq_weight=None, sam
422503
cache_values=cache_values,
423504
inference=inference)
424505

425-
def score(self, Y, T, X=None, W=None, sample_weight=None):
506+
def score(self, Y, T, X=None, W=None, sample_weight=None, scoring=None):
426507
"""
427508
Score the fitted CATE model on a new data set.
428509
@@ -453,7 +534,7 @@ def score(self, Y, T, X=None, W=None, sample_weight=None):
453534
The MSE of the final CATE model on the new data.
454535
"""
455536
# Replacing score from _OrthoLearner, to enforce Z=None and improve the docstring
456-
return super().score(Y, T, X=X, W=W, sample_weight=sample_weight)
537+
return super().score(Y, T, X=X, W=W, sample_weight=sample_weight, scoring=scoring)
457538

458539
@property
459540
def rlearner_model_final_(self):
@@ -493,3 +574,68 @@ def residuals_(self):
493574
"Set to `True` to enable residual storage.")
494575
Y_res, T_res = self._cached_values.nuisances
495576
return Y_res, T_res, self._cached_values.X, self._cached_values.W
577+
578+
@staticmethod
579+
def scoring_name(scoring: Union[str, Callable, None]) -> str:
    """
    Return a human-readable label for a scoring specification.

    :param scoring: None, a scorer name, or a callable score function
    :return: 'default_score' for None, the string itself for a name, and for a callable
        its ``__name__``. Callables without ``__name__`` (e.g. ``functools.partial``
        objects or callable instances) fall back to the wrapped function's name when a
        ``func`` attribute is present, otherwise to the callable's class name.
    :raises ValueError: if ``scoring`` is none of the above
    """
    if scoring is None:
        return 'default_score'
    if isinstance(scoring, str):
        return scoring
    if callable(scoring):
        name = getattr(scoring, '__name__', None)
        if name is None:
            # functools.partial exposes the wrapped callable as .func
            wrapped = getattr(scoring, 'func', None)
            name = getattr(wrapped, '__name__', None) or type(scoring).__name__
        return name
    raise ValueError("Scoring should be str|Callable|None")
588+
589+
590+
def score_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, y_scoring=None,
                    t_scoring=None, t_score_by_dim=False):
    """
    Score the fitted nuisance models on arbitrary data, optionally with alternative metrics.

    Parameters
    ----------
    Y: (n, d_y) matrix or vector of length n
        Outcomes for each sample
    T: (n, d_t) matrix or vector of length n
        Treatments for each sample
    X: (n, d_x) matrix, optional
        Features for each sample
    W: (n, d_w) matrix, optional
        Controls for each sample
    Z: (n, d_z) matrix, optional
        Instruments for each sample
    sample_weight:(n,) vector, optional
        Weights for each sample
    t_scoring: str or callable, optional
        Scoring to use instead of the default for model_t. It is forwarded unchanged to
        the nuisance model's ``score``; string names are expected to come from
        ``sklearn.metrics.get_scorer_names()``.
    y_scoring: str or callable, optional
        Scoring to use instead of the default for model_y; handled like ``t_scoring``.
    t_score_by_dim: bool, default=False
        Score prediction of treatment dimensions separately

    Returns
    -------
    score_dict : dict[str,list[float]]
        A dictionary with one key ``'Y_<scoring name>'`` and one key ``'T_<scoring name>'``
        (the name is ``'default_score'`` when no scoring is given). Each value is a list
        with one score per model in ``self._models_nuisance[0]``, i.e. one per CV fold
        model.
    """
    def _one_hot_single_column(values):
        # Only a vector or a single column can be one-hot encoded here. Flatten first:
        # pd.get_dummies rejects an (n, 1) ndarray as "not 1-dimensional".
        if np.ndim(values) == 1 or np.shape(values)[1] == 1:
            return pd.get_dummies(np.ravel(values))
        return values

    y_key = f'Y_{self.scoring_name(y_scoring)}'
    t_key = f'T_{self.scoring_name(t_scoring)}'
    score_dict = {y_key: [], t_key: []}

    # Discrete outcomes/treatments have to be one-hot encoded before scoring.
    # NOTE(review): pd.get_dummies keeps a column for every category; confirm the
    # nuisance models' score expects a full (not drop-first) encoding.
    Y_2_score = _one_hot_single_column(Y) if self.discrete_outcome else Y
    T_2_score = _one_hot_single_column(T) if self.discrete_treatment else T

    # NOTE(review): only the first entry of _models_nuisance is scored — presumably the
    # first cross-fitting repetition; confirm this is intended.
    for m in self._models_nuisance[0]:
        Y_score, T_score = m.score(Y_2_score, T_2_score, X=X, W=W, Z=Z,
                                   sample_weight=sample_weight,
                                   y_scoring=y_scoring, t_scoring=t_scoring,
                                   t_score_by_dim=t_score_by_dim)
        score_dict[y_key].append(Y_score)
        score_dict[t_key].append(T_score)
    return score_dict

econml/dml/dml.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from sklearn.preprocessing import (FunctionTransformer)
1111
from sklearn.utils import check_random_state
1212

13+
1314
from .._ortho_learner import _OrthoLearner
14-
from ._rlearner import _RLearner
15+
from ._rlearner import _RLearner, _ModelFinal
1516
from .._cate_estimator import (DebiasedLassoCateEstimatorMixin,
1617
LinearModelFinalCateEstimatorMixin,
1718
StatsModelsCateEstimatorMixin,
@@ -54,20 +55,42 @@ def predict(self, X, W):
5455
raise AttributeError("Cannot use a classifier as a first stage model when the target is continuous!")
5556
return self._model.predict(_combine(X, W, n_samples))
5657

57-
def score(self, X, W, Target, sample_weight=None, scoring=None, score_by_dim=False):
    """
    Score the first stage model on provided data.

    :param X: Features; combined with ``W`` to form the model input
    :param W: Controls; combined with ``X`` to form the model input
    :param Target: The true targets. For a discrete target this is the one-hot encoding,
        which is converted back to labels before scoring.
    :param sample_weight: optional sample weights
    :param scoring: optional alternative scoring (name or callable), handled by
        ``_FirstStageWrapper._wrap_scoring``
    :param score_by_dim: If a multi-dimension target, score each dimension separately.
    :return: None when the wrapped model has no ``score`` attribute; the model's own
        ``score`` when neither ``scoring`` nor ``score_by_dim`` is requested; otherwise
        the result of ``_FirstStageWrapper._wrap_scoring``.
    """
    model_input = _combine(X, W, Target.shape[0])
    if self._discrete_target:
        # The classifier works on labels, so undo the one-hot encoding of the target
        Target = inverse_onehot(Target)

    if not hasattr(self._model, 'score'):
        return None

    if scoring is not None or score_by_dim:
        # NOTE(review): with score_by_dim=True and scoring=None, None is passed on as the
        # scoring; confirm that is intended.
        return _FirstStageWrapper._wrap_scoring(scoring, Y_true=Target, X=model_input, est=self._model,
                                                sample_weight=sample_weight, score_by_dim=score_by_dim)

    # Standard default model scoring
    if sample_weight is None:
        return self._model.score(model_input, Target)
    return self._model.score(model_input, Target, sample_weight=sample_weight)
7088

89+
@staticmethod
90+
def _wrap_scoring(scoring, Y_true, X, est, sample_weight=None, score_by_dim=False):
    """
    Predict with ``est`` on ``X`` and score the predictions via ``_ModelFinal.wrap_scoring``.

    :param scoring: Scoring specification forwarded to ``_ModelFinal.wrap_scoring``
    :param Y_true: True target values
    :param X: Input handed to ``est.predict``
    :param est: Estimator exposing ``predict``
    :param sample_weight: Optional sample weights
    :param score_by_dim: Forwarded to ``_ModelFinal.wrap_scoring``
    :return: The result of ``_ModelFinal.wrap_scoring``
    """
    predictions = est.predict(X)
    return _ModelFinal.wrap_scoring(scoring, Y_true, predictions, sample_weight,
                                    score_by_dim=score_by_dim)
7194

7295
class _FirstStageSelector(SingleModelSelector):
7396
def __init__(self, model: SingleModelSelector, discrete_target):

0 commit comments

Comments
 (0)