"""
    rstar(
        rng::Random.AbstractRNG=Random.default_rng(),
        classifier,
        samples,
        chain_indices::AbstractVector{Int};
        subset::Real=0.7,
@@ -21,53 +21,97 @@ This method supports ragged chains, i.e. chains of nonequal lengths.
2121"""
function rstar(
    rng::Random.AbstractRNG,
    classifier,
    x,
    y::AbstractVector{Int};
    subset::Real=0.7,
    split_chains::Int=2,
    verbosity::Int=0,
)
    # check the arguments
    _check_model_supports_continuous_inputs(classifier)
    _check_model_supports_multiclass_targets(classifier)
    _check_model_supports_multiclass_predictions(classifier)
    MMI.nrows(x) != length(y) && throw(DimensionMismatch())
    0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))

    # randomly sub-select training and testing set
    ysplit = split_chain_indices(y, split_chains)
    train_ids, test_ids = shuffle_split_stratified(rng, ysplit, subset)
    0 < length(train_ids) < length(y) ||
        throw(ArgumentError("training and test data subsets must not be empty"))

    # convert data to a model-specific representation once, up front
    xtable = _astable(x)
    ycategorical = MMI.categorical(ysplit)
    xdata, ydata = MMI.reformat(classifier, xtable, ycategorical)

    # train classifier on training data
    xtrain, ytrain = MMI.selectrows(classifier, train_ids, xdata, ydata)
    fitresult, _ = MMI.fit(classifier, verbosity, xtrain, ytrain)

    # compute predictions on test data
    xtest, = MMI.selectrows(classifier, test_ids, xdata)
    ytest = ycategorical[test_ids]
    predictions = _predict(classifier, fitresult, xtest)

    # compute statistic; dispatch on the scitype of the predictions so that
    # deterministic (algorithm 1) and probabilistic (algorithm 2) classifiers are handled
    result = _rstar(MMI.scitype(predictions), predictions, ytest)

    return result
end
6162
# check that the model supports the inputs and targets, and has predictions of the desired form
function _check_model_supports_continuous_inputs(classifier)
    # ideally we would not allow MMI.Unknown but some models do not implement the traits
    input_scitype_classifier = MMI.input_scitype(classifier)
    if input_scitype_classifier !== MMI.Unknown &&
        !(MMI.Table(MMI.Continuous) <: input_scitype_classifier)
        throw(
            ArgumentError(
                "classifier does not support tables of continuous values as inputs"
            ),
        )
    end
    return nothing
end
# check that the model accepts vectors of multi-class (finite) labels as targets
function _check_model_supports_multiclass_targets(classifier)
    # `Unknown` is tolerated because some models do not implement the trait
    target_scitype_classifier = MMI.target_scitype(classifier)
    if target_scitype_classifier !== MMI.Unknown &&
        !(AbstractVector{<:MMI.Finite} <: target_scitype_classifier)
        throw(
            ArgumentError(
                "classifier does not support vectors of multi-class labels as targets"
            ),
        )
    end
    return nothing
end
# check that the model predicts either class labels or class densities
# (the two prediction forms `_rstar` can consume); `Unknown` is tolerated
function _check_model_supports_multiclass_predictions(classifier)
    if !(
        MMI.predict_scitype(classifier) <: Union{
            MMI.Unknown,
            AbstractVector{<:MMI.Finite},
            AbstractVector{<:MMI.Density{<:MMI.Finite}},
        }
    )
        throw(
            ArgumentError(
                "classifier does not support vectors of multi-class labels or their densities as predictions",
            ),
        )
    end
    return nothing
end
105+
# normalize the sample input to a Tables.jl-compatible table:
# vectors/matrices are wrapped, existing tables are passed through unchanged
_astable(x::AbstractVecOrMat) = Tables.table(x)
_astable(x) = Tables.istable(x) ? x : throw(ArgumentError("Argument is not a valid table"))
64108
# Workaround for https://github.com/JuliaAI/MLJBase.jl/issues/863
# `MLJModelInterface.predict` sometimes returns predictions and sometimes predictions + additional information
# TODO: Remove once the upstream issue is fixed
function _predict(model::MMI.Model, fitresult, x)
    y = MMI.predict(model, fitresult, x)
    return if :predict in MMI.reporting_operations(model)
        # reporting models return `(predictions, report)`; keep only the predictions
        first(y)
    else
        y
    end
end
"""
    rstar(
        rng::Random.AbstractRNG=Random.default_rng(),
        classifier,
        samples::AbstractArray{<:Real,3};
        subset::Real=0.7,
        split_chains::Int=2,
@@ -109,77 +153,111 @@ is returned (algorithm 2).
# Examples

```jldoctest rstar; setup = :(using Random; Random.seed!(101))
julia> using MLJBase, MLJIteration, EvoTrees, Statistics

julia> samples = fill(4.0, 100, 3, 2);
```
116160
One can compute the distribution of the ``R^*`` statistic (algorithm 2) with a
probabilistic classifier.
For instance, we can use a gradient-boosted trees model with `nrounds = 100` sequentially stacked trees and learning rate `eta = 0.05`:
119164
```jldoctest rstar
julia> model = EvoTreeClassifier(; nrounds=100, eta=0.05);

julia> distribution = rstar(model, samples);

julia> round(mean(distribution); digits=2)
1.0f0
```
173+
Note, however, that it is recommended to determine `nrounds` based on early-stopping.
With the MLJ framework, this can be achieved in the following way (see the [MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/) for additional explanations):

```jldoctest rstar
julia> model = IteratedModel(;
           model=EvoTreeClassifier(; eta=0.05),
           iteration_parameter=:nrounds,
           resampling=Holdout(),
           measures=log_loss,
           controls=[Step(5), Patience(2), NumberLimit(100)],
           retrain=true,
       );

julia> distribution = rstar(model, samples);

julia> round(mean(distribution); digits=2)
1.0f0
```
126192
For deterministic classifiers, a single ``R^*`` statistic (algorithm 1) is returned.
Deterministic classifiers can also be derived from probabilistic classifiers by e.g.
predicting the mode. In MLJ this corresponds to a pipeline of models.
130196
```jldoctest rstar
julia> evotree_deterministic = Pipeline(model; operation=predict_mode);

julia> value = rstar(evotree_deterministic, samples);

julia> round(value; digits=2)
1.0
```
139205
# References

Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic with uncertainty using decision tree classifiers.
"""
# flatten a (draws, chains, params) array into a (draws*chains, params) sample matrix
# with matching chain indices, then dispatch to the ragged-chain method
function rstar(rng::Random.AbstractRNG, classifier, x::AbstractArray{<:Any,3}; kwargs...)
    samples = reshape(x, :, size(x, 3))
    chain_inds = repeat(axes(x, 2); inner=size(x, 1))
    return rstar(rng, classifier, samples, chain_inds; kwargs...)
end
154215
# convenience method: default RNG for ragged-chain inputs
function rstar(classifier, x, y::AbstractVector{Int}; kwargs...)
    return rstar(Random.default_rng(), classifier, x, y; kwargs...)
end
158219
# convenience method: default RNG for 3-dimensional sample arrays
function rstar(classifier, x::AbstractArray{<:Any,3}; kwargs...)
    return rstar(Random.default_rng(), classifier, x; kwargs...)
end
162223
# R⋆ for deterministic predictions (algorithm 1):
# R⋆ = nclasses * mean classification accuracy on the test set
function _rstar(
    ::Type{<:AbstractVector{<:MMI.Finite}},
    predictions::AbstractVector,
    ytest::AbstractVector,
)
    length(predictions) == length(ytest) ||
        error("numbers of predictions and targets must be equal")
    mean_accuracy = Statistics.mean(p == y for (p, y) in zip(predictions, ytest))
    nclasses = length(MMI.classes(ytest))
    return nclasses * mean_accuracy
end
171236
# R⋆ for probabilistic predictions (algorithm 2):
# returns the distribution of R⋆ induced by the predicted class probabilities
function _rstar(
    ::Type{<:AbstractVector{<:MMI.Density{<:MMI.Finite}}},
    predictions::AbstractVector,
    ytest::AbstractVector,
)
    length(predictions) == length(ytest) ||
        error("numbers of predictions and targets must be equal")

    # create Poisson binomial distribution with support `0:length(predictions)`
    distribution = Distributions.PoissonBinomial(map(Distributions.pdf, predictions, ytest))

    # scale distribution to support in `[0, nclasses]`
    nclasses = length(MMI.classes(ytest))
    scaled_distribution = (nclasses//length(predictions)) * distribution

    return scaled_distribution
end
255+
# fallback for unsupported types of predictions and targets: fail loudly with
# the offending types so misconfigured classifiers are easy to diagnose
function _rstar(::Any, predictions, targets)
    throw(
        ArgumentError(
            "unsupported types of predictions ($(typeof(predictions))) and targets ($(typeof(targets)))",
        ),
    )
end
0 commit comments