
Commit fb63aaa

add missing examples.md file; plus some tweaks

1 parent 5836481 commit fb63aaa

5 files changed: +224 −14 lines


docs/src/anatomy_of_an_implementation.md

Lines changed: 4 additions & 4 deletions

```diff
@@ -324,12 +324,12 @@ using LearnTestAPI

 ## [Other data patterns](@id di)

-Here are some important remarks for implementations wanting to deviate in their
+Here are some important remarks for implementations deviating in their
 assumptions about data from those made above.

 - New implementations of `fit`, `predict`, etc, always have a *single* `data` argument as
-  above. For convenience, a signature such as `fit(learner, X, y)`, calling `fit(learner,
-  (X, y))`, can be added, but the LearnAPI.jl specification is silent on the meaning or
+  above. For convenience, a signature such as `fit(learner, table, formula)`, calling `fit(learner,
+  (table, formula))`, can be added, but the LearnAPI.jl specification is silent on the meaning or
   existence of signatures with extra arguments.

 - If the `data` object consumed by `fit`, `predict`, or `transform` is not a suitable

@@ -415,7 +415,7 @@ The [`obs`](@ref) methods exist to:
 !!! important

     While many new learner implementations will want to adopt a canned data front end, such as those provided by [LearnDataFrontEnds.jl](https://juliaai.github.io/LearnAPI.jl/dev/), we
-    focus here on a self-contained implemementation of `obs` for the ridge example above, to show
+    focus here on a self-contained implementation of `obs` for the ridge example above, to show
     how it works.

 In the typical case, where [`LearnAPI.data_interface`](@ref) is not overloaded, the
```

docs/src/examples.md

Lines changed: 192 additions & 0 deletions (new file, shown in full below)

# [Code for ridge example](@id code)

Below is the complete source code for the ridge implementations described in the tutorial,
[Anatomy of an Implementation](@ref).

- [Basic implementation](@ref)
- [Implementation with data front end](@ref)

## Basic implementation
```julia
using LearnAPI
using LinearAlgebra, Tables

struct Ridge{T<:Real}
    lambda::T
end

"""
    Ridge(; lambda=0.1)

Instantiate a ridge regression learner, with regularization of `lambda`.

"""
Ridge(; lambda=0.1) = Ridge(lambda)
LearnAPI.constructor(::Ridge) = Ridge

# struct for output of `fit`:
struct RidgeFitted{T,F}
    learner::Ridge
    coefficients::Vector{T}
    named_coefficients::F
end

function LearnAPI.fit(learner::Ridge, data; verbosity=1)
    X, y = data

    # data preprocessing:
    table = Tables.columntable(X)
    names = Tables.columnnames(table) |> collect
    A = Tables.matrix(table, transpose=true)

    lambda = learner.lambda

    # apply core algorithm:
    coefficients = (A*A' + lambda*I)\(A*y) # vector

    # determine named coefficients:
    named_coefficients = [names[j] => coefficients[j] for j in eachindex(names)]

    # make some noise, if allowed:
    verbosity > 0 && @info "Coefficients: $named_coefficients"

    return RidgeFitted(learner, coefficients, named_coefficients)
end

LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =
    Tables.matrix(Xnew)*model.coefficients

# accessor functions:
LearnAPI.learner(model::RidgeFitted) = model.learner
LearnAPI.coefficients(model::RidgeFitted) = model.named_coefficients
LearnAPI.strip(model::RidgeFitted) =
    RidgeFitted(model.learner, model.coefficients, nothing)

@trait(
    Ridge,
    constructor = Ridge,
    kinds_of_proxy = (Point(),),
    tags = ("regression",),
    functions = (
        :(LearnAPI.fit),
        :(LearnAPI.learner),
        :(LearnAPI.clone),
        :(LearnAPI.strip),
        :(LearnAPI.obs),
        :(LearnAPI.features),
        :(LearnAPI.target),
        :(LearnAPI.predict),
        :(LearnAPI.coefficients),
    ),
)

# convenience method:
LearnAPI.fit(learner::Ridge, X, y; kwargs...) = fit(learner, (X, y); kwargs...)
```
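For orientation, here is a hypothetical usage sketch for the basic implementation above. The data values and the `lambda` setting are invented for illustration, and we assume the code above (and so LearnAPI.jl) is in scope:

```julia
# a Tables.jl-compatible feature table (here a column table) and a target vector:
X = (age = [25.0, 30.0, 35.0, 40.0], gain = [1.2, 0.9, 1.5, 1.1])
y = [2.0, 3.0, 4.0, 5.0]

learner = Ridge(lambda=0.5)
model = fit(learner, (X, y); verbosity=0)

predict(model, Point(), X)    # a 4-element Vector{Float64} of predictions
LearnAPI.coefficients(model)  # named coefficients, as `Symbol => value` pairs

# the convenience signature forwards to the single-`data` form:
model2 = fit(learner, X, y; verbosity=0)
```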

## Implementation with data front end

```julia
using LearnAPI
using LinearAlgebra, Tables

struct Ridge{T<:Real}
    lambda::T
end

Ridge(; lambda=0.1) = Ridge(lambda)

# struct for output of `fit`:
struct RidgeFitted{T,F}
    learner::Ridge
    coefficients::Vector{T}
    named_coefficients::F
end

# struct for internal representation of training data:
struct RidgeFitObs{T,M<:AbstractMatrix{T}}
    A::M                  # `p` x `n` matrix
    names::Vector{Symbol} # features
    y::Vector{T}          # target
end

# implementation of `RandomAccess()` data interface for such representation:
Base.getindex(data::RidgeFitObs, I) =
    RidgeFitObs(data.A[:,I], data.names, data.y[I])
Base.length(data::RidgeFitObs) = length(data.y)

# data front end for `fit`:
function LearnAPI.obs(::Ridge, data)
    X, y = data
    table = Tables.columntable(X)
    names = Tables.columnnames(table) |> collect
    return RidgeFitObs(Tables.matrix(table)', names, y)
end
LearnAPI.obs(::Ridge, observations::RidgeFitObs) = observations

function LearnAPI.fit(learner::Ridge, observations::RidgeFitObs; verbosity=1)

    lambda = learner.lambda

    A = observations.A
    names = observations.names
    y = observations.y

    # apply core learner:
    coefficients = (A*A' + lambda*I)\(A*y) # vector

    # determine named coefficients:
    named_coefficients = [names[j] => coefficients[j] for j in eachindex(names)]

    # make some noise, if allowed:
    verbosity > 0 && @info "Coefficients: $named_coefficients"

    return RidgeFitted(learner, coefficients, named_coefficients)

end

LearnAPI.fit(learner::Ridge, data; kwargs...) =
    fit(learner, obs(learner, data); kwargs...)

# data front end for `predict`:
LearnAPI.obs(::RidgeFitted, Xnew) = Tables.matrix(Xnew)'
LearnAPI.obs(::RidgeFitted, observations::AbstractArray) = observations # involutivity

LearnAPI.predict(model::RidgeFitted, ::Point, observations::AbstractMatrix) =
    observations'*model.coefficients

LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =
    predict(model, Point(), obs(model, Xnew))

# methods to deconstruct training data:
LearnAPI.features(::Ridge, observations::RidgeFitObs) = observations.A
LearnAPI.target(::Ridge, observations::RidgeFitObs) = observations.y
LearnAPI.features(learner::Ridge, data) = LearnAPI.features(learner, obs(learner, data))
LearnAPI.target(learner::Ridge, data) = LearnAPI.target(learner, obs(learner, data))

# accessor functions:
LearnAPI.learner(model::RidgeFitted) = model.learner
LearnAPI.coefficients(model::RidgeFitted) = model.named_coefficients
LearnAPI.strip(model::RidgeFitted) =
    RidgeFitted(model.learner, model.coefficients, nothing)

@trait(
    Ridge,
    constructor = Ridge,
    kinds_of_proxy = (Point(),),
    tags = ("regression",),
    functions = (
        :(LearnAPI.fit),
        :(LearnAPI.learner),
        :(LearnAPI.clone),
        :(LearnAPI.strip),
        :(LearnAPI.obs),
        :(LearnAPI.features),
        :(LearnAPI.target),
        :(LearnAPI.predict),
        :(LearnAPI.coefficients),
    ),
)
```
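And a hypothetical workflow sketch for the data front end version, showing how pre-processed observations can be reused without re-converting the original data. The data values are invented, and we assume the code above is in scope:

```julia
X = (age = [25.0, 30.0, 35.0, 40.0], gain = [1.2, 0.9, 1.5, 1.1])
y = [2.0, 3.0, 4.0, 5.0]

learner = Ridge(lambda=0.5)

# convert the data once, up front:
observations = obs(learner, (X, y))   # a `RidgeFitObs`

# train on all observations, or on a subset, via the `RandomAccess()` interface:
model = fit(learner, observations; verbosity=0)
model_on_subset = fit(learner, observations[1:3]; verbosity=0)

# `predict` similarly accepts pre-converted observations:
predict(model, Point(), obs(model, X))
```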

docs/src/index.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -47,7 +47,7 @@ Suppose `forest` is some object encapsulating the hyperparameters of the [random
 algorithm](https://en.wikipedia.org/wiki/Random_forest) (the number of trees, etc.). Then,
 a LearnAPI.jl interface can be implemented, for objects with the type of `forest`, to
 enable the basic workflow below. In this case data is presented following the
-"scikit-learn" `X, y` pattern, although LearnAPI.jl supports other patterns as well.
+"scikit-learn" `X, y` pattern, although LearnAPI.jl supports other data patterns.

 ```julia
 # `X` is some training features
````

docs/src/reference.md

Lines changed: 15 additions & 7 deletions

````diff
@@ -38,9 +38,9 @@ number of user-specified *hyperparameters*, such as the number of trees in a ran
 forest. Hyperparameters are understood in a rather broad sense. For example, one is
 allowed to have hyperparameters that are not data-generic. For example, a class weight
 dictionary, which will only make sense for a target taking values in the set of specified
-dictionary keys, should be given as a hyperparameter. For simplicity, LearnAPI.jl
-discourages "run time" parameters (extra arguments to `fit`) such as acceleration
-options (cpu/gpu/multithreading/multiprocessing). These should be included as
+dictionary keys, should be given as a hyperparameter. For simplicity and composability,
+LearnAPI.jl discourages "run time" parameters (extra arguments to `fit`) such as
+acceleration options (cpu/gpu/multithreading/multiprocessing). These should be included as
 hyperparameters as far as possible. An exception is the compulsory `verbosity` keyword
 argument of `fit`.

@@ -102,7 +102,7 @@ generally requires overloading `Base.==` for the struct.
 !!! important

     No LearnAPI.jl method is permitted to mutate a learner. In particular, one should make
-    deep copies of RNG hyperparameters before using them in a new implementation of
+    deep copies of RNG hyperparameters before using them in an implementation of
     [`fit`](@ref).

 #### Composite learners (wrappers)

@@ -114,9 +114,6 @@ properties that are not in [`LearnAPI.learners(learner)`](@ref). Instead, these
 learner-valued properties can have a `nothing` default, with the constructor throwing an
 error if the constructor call does not explicitly specify a new value.

-Any object `learner` for which [`LearnAPI.functions(learner)`](@ref) is non-empty is
-understood to have a valid implementation of the LearnAPI.jl interface.
-
 #### Example

 Below is an example of a learner type with a valid constructor:

@@ -139,6 +136,14 @@ GradientRidgeRegressor(; learning_rate=0.01, epochs=10, l2_regularization=0.01)
 LearnAPI.constructor(::GradientRidgeRegressor) = GradientRidgeRegressor
 ```

+#### Testing something is a learner
+
+Any object `object` for which [`LearnAPI.functions(object)`](@ref) is non-empty is
+understood to have a valid implementation of the LearnAPI.jl interface. You can test this
+with the convenience method [`LearnAPI.is_learner(object)`](@ref), which is never
+explicitly overloaded.
+
 ## Documentation

 Attach public LearnAPI.jl-related documentation for a learner to its *constructor*,

@@ -200,11 +205,14 @@ Most learners will also implement [`predict`](@ref) and/or [`transform`](@ref).

 ## Utilities

+- [`LearnAPI.is_learner`](@ref): for testing that an object is a learner
 - [`clone`](@ref): for cloning a learner with specified hyperparameter replacements.
 - [`@trait`](@ref): for simultaneously declaring multiple traits
 - [`@functions`](@ref): for listing functions available for use with a learner

 ```@docs
+LearnAPI.is_learner
 clone
 @trait
 @functions
````
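As a hypothetical illustration of the new trait, assuming the `Ridge` learner from the new `examples.md` file is in scope, and assuming the usual empty fallback for the `functions` trait:

```julia
using LearnAPI

LearnAPI.is_learner(Ridge())  # true: `LearnAPI.functions(Ridge())` is non-empty
LearnAPI.is_learner("junk")   # false: `functions` falls back to an empty return value
```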

src/traits.jl

Lines changed: 12 additions & 2 deletions

```diff
@@ -74,8 +74,8 @@ reference functions not owned by LearnAPI.jl.
 The understanding is that `learner` is a LearnAPI-compliant object whenever the return
 value is non-empty.

-Do `LearnAPI.functions()` to list all possible elements of the return value owned by
-LearnAPI.jl.
+Do `LearnAPI.functions()` to list all possible elements of the return value representing
+functions owned by LearnAPI.jl.

 # Extended help

@@ -513,6 +513,16 @@ This trait should not be overloaded. Instead overload [`LearnAPI.nonlearners`](@

 """
 learners(learner) = setdiff(propertynames(learner), nonlearners(learner))
+
+"""
+    LearnAPI.is_learner(object)
+
+Return `true` if `object` has a valid implementation of the LearnAPI.jl
+interface. Equivalent to non-emptiness of [`LearnAPI.functions(object)`](@ref).
+
+This trait should never be explicitly overloaded.
+
+"""
 is_learner(learner) = !isempty(functions(learner))
 preferred_kind_of_proxy(learner) = first(kinds_of_proxy(learner))
 target(learner) = :(LearnAPI.target) in functions(learner)
```
