Skip to content

Commit 2e07d21

Browse files
Authored by: devmotion, github-actions[bot], sethaxen
Use EvoTrees instead of XGBoost in documentation (#57)
* Use EvoTrees instead of XGBoost
* Update runtests.jl
* Try to fix use of MLJ interface
* Apply suggestions from code review (Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>)
* Update src/rstar.jl
* Apply suggestions from code review (Co-authored-by: Seth Axen <[email protected]>)
* Fix RNG of `EvoTreeClassifier` in tests
* Update rstar.jl
* Update rstar.jl
* Update test/rstar.jl (Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>)
* Bump compat entries
* Apply suggested changes
* Use a more advanced example
* Test XGBoost as well
* Update rstar.jl
* Update rstar.jl
* Update EvoTrees dependency
* Update documentation
* Use MLJ traits
* Update src/rstar.jl (Co-authored-by: Seth Axen <[email protected]>)
* Some refactoring and additional tests
* Update Project.toml

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Seth Axen <[email protected]>
1 parent 1799e79 commit 2e07d21

File tree

9 files changed

+207
-67
lines changed

9 files changed

+207
-67
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
fail-fast: false
1515
matrix:
1616
version:
17-
- '1.3'
17+
- '1.6'
1818
- '1'
1919
- 'nightly'
2020
os:

.gitignore

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
*.jl.*.cov
22
*.jl.cov
33
*.jl.mem
4-
/Manifest.toml
5-
/test/Manifest.toml
6-
/test/rstar/Manifest.toml
4+
Manifest.toml
75
/docs/build/

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MCMCDiagnosticTools"
22
uuid = "be115224-59cd-429b-ad48-344e309966f0"
33
authors = ["David Widmann"]
4-
version = "0.2.5"
4+
version = "0.2.6"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -27,7 +27,7 @@ SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
2727
StatsBase = "0.33"
2828
StatsFuns = "1"
2929
Tables = "1"
30-
julia = "1.3"
30+
julia = "1.6"
3131

3232
[extras]
3333
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

docs/Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3+
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
34
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
45
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
5-
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
6+
MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
67
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
78
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
89

910
[compat]
1011
Documenter = "0.27"
12+
EvoTrees = "0.14.7"
1113
MCMCDiagnosticTools = "0.2"
1214
MLJBase = "0.19, 0.20, 0.21"
13-
MLJXGBoostInterface = "0.1, 0.2, 0.3"
14-
julia = "1.3"
15+
MLJIteration = "0.5"
16+
julia = "1.6"

src/MCMCDiagnosticTools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using AbstractFFTs: AbstractFFTs
44
using DataAPI: DataAPI
55
using DataStructures: DataStructures
66
using Distributions: Distributions
7-
using MLJModelInterface: MLJModelInterface
7+
using MLJModelInterface: MLJModelInterface as MMI
88
using SpecialFunctions: SpecialFunctions
99
using StatsBase: StatsBase
1010
using StatsFuns: StatsFuns

src/rstar.jl

Lines changed: 119 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
rstar(
33
rng::Random.AbstractRNG=Random.default_rng(),
4-
classifier::MLJModelInterface.Supervised,
4+
classifier,
55
samples,
66
chain_indices::AbstractVector{Int};
77
subset::Real=0.7,
@@ -21,53 +21,97 @@ This method supports ragged chains, i.e. chains of nonequal lengths.
2121
"""
2222
function rstar(
2323
rng::Random.AbstractRNG,
24-
classifier::MLJModelInterface.Supervised,
24+
classifier,
2525
x,
2626
y::AbstractVector{Int};
2727
subset::Real=0.7,
2828
split_chains::Int=2,
2929
verbosity::Int=0,
3030
)
31-
# checks
32-
MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch())
31+
# check the arguments
32+
_check_model_supports_continuous_inputs(classifier)
33+
_check_model_supports_multiclass_targets(classifier)
34+
_check_model_supports_multiclass_predictions(classifier)
35+
MMI.nrows(x) != length(y) && throw(DimensionMismatch())
3336
0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))
3437

35-
ysplit = split_chain_indices(y, split_chains)
36-
3738
# randomly sub-select training and testing set
39+
ysplit = split_chain_indices(y, split_chains)
3840
train_ids, test_ids = shuffle_split_stratified(rng, ysplit, subset)
3941
0 < length(train_ids) < length(y) ||
4042
throw(ArgumentError("training and test data subsets must not be empty"))
4143

4244
xtable = _astable(x)
45+
ycategorical = MMI.categorical(ysplit)
46+
xdata, ydata = MMI.reformat(classifier, xtable, ycategorical)
4347

4448
# train classifier on training data
45-
ycategorical = MLJModelInterface.categorical(ysplit)
46-
xtrain = MLJModelInterface.selectrows(xtable, train_ids)
47-
fitresult, _ = MLJModelInterface.fit(
48-
classifier, verbosity, xtrain, ycategorical[train_ids]
49-
)
49+
xtrain, ytrain = MMI.selectrows(classifier, train_ids, xdata, ydata)
50+
fitresult, _ = MMI.fit(classifier, verbosity, xtrain, ytrain)
5051

5152
# compute predictions on test data
52-
xtest = MLJModelInterface.selectrows(xtable, test_ids)
53+
xtest, = MMI.selectrows(classifier, test_ids, xdata)
54+
ytest = ycategorical[test_ids]
5355
predictions = _predict(classifier, fitresult, xtest)
5456

5557
# compute statistic
56-
ytest = ycategorical[test_ids]
57-
result = _rstar(predictions, ytest)
58+
result = _rstar(MMI.scitype(predictions), predictions, ytest)
5859

5960
return result
6061
end
6162

63+
# check that the model supports the inputs and targets, and has predictions of the desired form
64+
function _check_model_supports_continuous_inputs(classifier)
65+
# ideally we would not allow MMI.Unknown but some models do not implement the traits
66+
input_scitype_classifier = MMI.input_scitype(classifier)
67+
if input_scitype_classifier !== MMI.Unknown &&
68+
!(MMI.Table(MMI.Continuous) <: input_scitype_classifier)
69+
throw(
70+
ArgumentError(
71+
"classifier does not support tables of continuous values as inputs"
72+
),
73+
)
74+
end
75+
return nothing
76+
end
77+
function _check_model_supports_multiclass_targets(classifier)
78+
target_scitype_classifier = MMI.target_scitype(classifier)
79+
if target_scitype_classifier !== MMI.Unknown &&
80+
!(AbstractVector{<:MMI.Finite} <: target_scitype_classifier)
81+
throw(
82+
ArgumentError(
83+
"classifier does not support vectors of multi-class labels as targets"
84+
),
85+
)
86+
end
87+
return nothing
88+
end
89+
function _check_model_supports_multiclass_predictions(classifier)
90+
if !(
91+
MMI.predict_scitype(classifier) <: Union{
92+
MMI.Unknown,
93+
AbstractVector{<:MMI.Finite},
94+
AbstractVector{<:MMI.Density{<:MMI.Finite}},
95+
}
96+
)
97+
throw(
98+
ArgumentError(
99+
"classifier does not support vectors of multi-class labels or their densities as predictions",
100+
),
101+
)
102+
end
103+
return nothing
104+
end
105+
62106
_astable(x::AbstractVecOrMat) = Tables.table(x)
63107
_astable(x) = Tables.istable(x) ? x : throw(ArgumentError("Argument is not a valid table"))
64108

65109
# Workaround for https://github.com/JuliaAI/MLJBase.jl/issues/863
66110
# `MLJModelInterface.predict` sometimes returns predictions and sometimes predictions + additional information
67111
# TODO: Remove once the upstream issue is fixed
68-
function _predict(model::MLJModelInterface.Model, fitresult, x)
69-
y = MLJModelInterface.predict(model, fitresult, x)
70-
return if :predict in MLJModelInterface.reporting_operations(model)
112+
function _predict(model::MMI.Model, fitresult, x)
113+
y = MMI.predict(model, fitresult, x)
114+
return if :predict in MMI.reporting_operations(model)
71115
first(y)
72116
else
73117
y
@@ -77,7 +121,7 @@ end
77121
"""
78122
rstar(
79123
rng::Random.AbstractRNG=Random.default_rng(),
80-
classifier::MLJModelInterface.Supervised,
124+
classifier,
81125
samples::AbstractArray{<:Real,3};
82126
subset::Real=0.7,
83127
split_chains::Int=2,
@@ -109,77 +153,111 @@ is returned (algorithm 2).
109153
# Examples
110154
111155
```jldoctest rstar; setup = :(using Random; Random.seed!(101))
112-
julia> using MLJBase, MLJXGBoostInterface, Statistics
156+
julia> using MLJBase, MLJIteration, EvoTrees, Statistics
113157
114158
julia> samples = fill(4.0, 100, 3, 2);
115159
```
116160
117-
One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the
161+
One can compute the distribution of the ``R^*`` statistic (algorithm 2) with a
118162
probabilistic classifier.
163+
For instance, we can use a gradient-boosted trees model with `nrounds = 100` sequentially stacked trees and learning rate `eta = 0.05`:
119164
120165
```jldoctest rstar
121-
julia> distribution = rstar(XGBoostClassifier(), samples);
166+
julia> model = EvoTreeClassifier(; nrounds=100, eta=0.05);
122167
123-
julia> isapprox(mean(distribution), 1; atol=0.1)
124-
true
168+
julia> distribution = rstar(model, samples);
169+
170+
julia> round(mean(distribution); digits=2)
171+
1.0f0
172+
```
173+
174+
Note, however, that it is recommended to determine `nrounds` based on early-stopping.
175+
With the MLJ framework, this can be achieved in the following way (see the [MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/) for additional explanations):
176+
177+
```jldoctest rstar
178+
julia> model = IteratedModel(;
179+
model=EvoTreeClassifier(; eta=0.05),
180+
iteration_parameter=:nrounds,
181+
resampling=Holdout(),
182+
measures=log_loss,
183+
controls=[Step(5), Patience(2), NumberLimit(100)],
184+
retrain=true,
185+
);
186+
187+
julia> distribution = rstar(model, samples);
188+
189+
julia> round(mean(distribution); digits=2)
190+
1.0f0
125191
```
126192
127193
For deterministic classifiers, a single ``R^*`` statistic (algorithm 1) is returned.
128194
Deterministic classifiers can also be derived from probabilistic classifiers by e.g.
129195
predicting the mode. In MLJ this corresponds to a pipeline of models.
130196
131197
```jldoctest rstar
132-
julia> xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode);
198+
julia> evotree_deterministic = Pipeline(model; operation=predict_mode);
133199
134-
julia> value = rstar(xgboost_deterministic, samples);
200+
julia> value = rstar(evotree_deterministic, samples);
135201
136-
julia> isapprox(value, 1; atol=0.2)
137-
true
202+
julia> round(value; digits=2)
203+
1.0
138204
```
139205
140206
# References
141207
142208
Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic with uncertainty using decision tree classifiers.
143209
"""
144-
function rstar(
145-
rng::Random.AbstractRNG,
146-
classifier::MLJModelInterface.Supervised,
147-
x::AbstractArray{<:Any,3};
148-
kwargs...,
149-
)
210+
function rstar(rng::Random.AbstractRNG, classifier, x::AbstractArray{<:Any,3}; kwargs...)
150211
samples = reshape(x, :, size(x, 3))
151212
chain_inds = repeat(axes(x, 2); inner=size(x, 1))
152213
return rstar(rng, classifier, samples, chain_inds; kwargs...)
153214
end
154215

155-
function rstar(classif::MLJModelInterface.Supervised, x, y::AbstractVector{Int}; kwargs...)
156-
return rstar(Random.default_rng(), classif, x, y; kwargs...)
216+
function rstar(classifier, x, y::AbstractVector{Int}; kwargs...)
217+
return rstar(Random.default_rng(), classifier, x, y; kwargs...)
157218
end
158219

159-
function rstar(classif::MLJModelInterface.Supervised, x::AbstractArray{<:Any,3}; kwargs...)
160-
return rstar(Random.default_rng(), classif, x; kwargs...)
220+
function rstar(classifier, x::AbstractArray{<:Any,3}; kwargs...)
221+
return rstar(Random.default_rng(), classifier, x; kwargs...)
161222
end
162223

163224
# R⋆ for deterministic predictions (algorithm 1)
164-
function _rstar(predictions::AbstractVector{T}, ytest::AbstractVector{T}) where {T}
225+
function _rstar(
226+
::Type{<:AbstractVector{<:MMI.Finite}},
227+
predictions::AbstractVector,
228+
ytest::AbstractVector,
229+
)
165230
length(predictions) == length(ytest) ||
166231
error("numbers of predictions and targets must be equal")
167232
mean_accuracy = Statistics.mean(p == y for (p, y) in zip(predictions, ytest))
168-
nclasses = length(MLJModelInterface.classes(ytest))
233+
nclasses = length(MMI.classes(ytest))
169234
return nclasses * mean_accuracy
170235
end
171236

172237
# R⋆ for probabilistic predictions (algorithm 2)
173-
function _rstar(predictions::AbstractVector, ytest::AbstractVector)
238+
function _rstar(
239+
::Type{<:AbstractVector{<:MMI.Density{<:MMI.Finite}}},
240+
predictions::AbstractVector,
241+
ytest::AbstractVector,
242+
)
174243
length(predictions) == length(ytest) ||
175244
error("numbers of predictions and targets must be equal")
176245

177246
# create Poisson binomial distribution with support `0:length(predictions)`
178247
distribution = Distributions.PoissonBinomial(map(Distributions.pdf, predictions, ytest))
179248

180249
# scale distribution to support in `[0, nclasses]`
181-
nclasses = length(MLJModelInterface.classes(ytest))
250+
nclasses = length(MMI.classes(ytest))
182251
scaled_distribution = (nclasses//length(predictions)) * distribution
183252

184253
return scaled_distribution
185254
end
255+
256+
# unsupported types of predictions and targets
257+
function _rstar(::Any, predictions, targets)
258+
throw(
259+
ArgumentError(
260+
"unsupported types of predictions ($(typeof(predictions))) and targets ($(typeof(targets)))",
261+
),
262+
)
263+
end

test/Project.toml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
[deps]
22
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
33
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
4+
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
45
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
56
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
67
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
78
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
89
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
10+
MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
911
MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52"
12+
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
1013
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
1114
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
12-
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1315
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1416
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1517
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -19,14 +21,17 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1921
[compat]
2022
Distributions = "0.25"
2123
DynamicHMC = "3"
24+
EvoTrees = "0.14.7"
2225
FFTW = "1.1"
2326
LogDensityProblems = "0.12, 1, 2"
2427
LogExpFunctions = "0.3"
2528
MCMCDiagnosticTools = "0.2"
2629
MLJBase = "0.19, 0.20, 0.21"
27-
MLJLIBSVMInterface = "0.1, 0.2"
28-
MLJXGBoostInterface = "0.1, 0.2, 0.3"
30+
MLJIteration = "0.5"
31+
MLJLIBSVMInterface = "0.2"
32+
MLJModels = "0.16"
33+
MLJXGBoostInterface = "0.3"
2934
OffsetArrays = "1"
3035
StatsBase = "0.33"
3136
Tables = "1"
32-
julia = "1.3"
37+
julia = "1.6"

0 commit comments

Comments (0)