diff --git a/docs/.gitignore b/docs/.gitignore
index 90fb8a549..909052372 100644
--- a/docs/.gitignore
+++ b/docs/.gitignore
@@ -1,4 +1,3 @@
 build/
 site/
-
-#Temp to avoid to many changes
+src/examples/
diff --git a/docs/Project.toml b/docs/Project.toml
index 0e0cc849e..eee2727f7 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,4 +1,6 @@
 [deps]
+AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
diff --git a/docs/examples/automaticstatistician.jl b/docs/examples/automaticstatistician.jl
new file mode 100644
index 000000000..e69de29bb
diff --git a/docs/examples/deepkernellearning.jl b/docs/examples/deepkernellearning.jl
new file mode 100644
index 000000000..f90dc07dd
--- /dev/null
+++ b/docs/examples/deepkernellearning.jl
@@ -0,0 +1,78 @@
+# # Deep Kernel Learning with Flux
+# ## Package loading
+# We use a couple of useful packages to plot the results and optimize
+# the different hyper-parameters
+using KernelFunctions
+using Flux
+using Distributions, LinearAlgebra
+using Plots
+using ProgressMeter
+using AbstractGPs
+pyplot(); default(legendfontsize = 15.0, linewidth = 3.0)
+
+# ## Data creation
+# We create a simple 1D problem with very different variations
+xmin = -3; xmax = 3 # Limits
+N = 150
+noise = 0.01
+x_train = collect(eachrow(rand(Uniform(xmin, xmax), N))) # Training dataset
+target_f(x) = sinc(abs(x) ^ abs(x)) # We use sinc with a highly varying value
+target_f(x::AbstractArray) = target_f(first(x))
+y_train = target_f.(x_train) + randn(N) * noise
+x_test = collect(eachrow(range(xmin, xmax, length=200))) # Testing dataset
+# ## Model definition
+# We create a neural net with 3 layers, mapping the 1-dimensional input to 5 features.
+# The data is passed through the NN before being used in the kernel
+neuralnet = Chain(Dense(1, 20), Dense(20, 30), Dense(30, 5))
+# On top of these features we use a squared exponential kernel
+k = transform(SqExponentialKernel(), FunctionTransform(neuralnet))
+
+# We use AbstractGPs.jl to define our model
+gpprior = GP(k) # GP Prior
+fx = AbstractGPs.FiniteGP(gpprior, x_train, noise) # Prior on f
+fp = posterior(fx, y_train) # Posterior of f
+
+# This computes the negative log evidence of `y`,
+# which is going to be used as the objective
+loss(y) = -logpdf(fx, y)
+
+@info "Init Loss = $(loss(y_train))"
+
+# Flux will automatically extract all the parameters of the kernel
+ps = Flux.params(k)
+
+# We show the initial prediction with the untrained model
+p_init = Plots.plot(vcat(x_test...), target_f, lab = "true f", title = "Loss = $(loss(y_train))")
+Plots.scatter!(vcat(x_train...), y_train, lab = "data")
+pred = marginals(fp(x_test))
+Plots.plot!(vcat(x_test...), mean.(pred), ribbon = std.(pred), lab = "Prediction")
+# ## Training
+anim = Animation()
+nmax = 1000
+opt = Flux.ADAM(0.1)
+@showprogress for i = 1:nmax
+    global grads = gradient(ps) do
+        loss(y_train)
+    end
+    Flux.Optimise.update!(opt, ps, grads)
+    if i % 100 == 0
+        @info "$i/$nmax"
+        L = loss(y_train)
+        # @info "Loss = $L"
+        p = Plots.plot(vcat(x_test...), target_f, lab = "true f", title = "Loss = $(loss(y_train))")
+        p = Plots.scatter!(vcat(x_train...), y_train, lab = "data")
+        pred = marginals(posterior(fx, y_train)(x_test))
+        Plots.plot!(vcat(x_test...), mean.(pred), ribbon = std.(pred), lab = "Prediction")
+        frame(anim)
+        display(p)
+    end
+end
+gif(anim, fps = 5)
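+
+# ## Final model
+# A sketch of how one might inspect the fit once the loop above has run,
+# reusing only objects already defined in this example:
+pred = marginals(posterior(fx, y_train)(x_test))
+Plots.plot(vcat(x_test...), target_f, lab = "true f", title = "Final loss = $(loss(y_train))")
+Plots.scatter!(vcat(x_train...), y_train, lab = "data")
+Plots.plot!(vcat(x_test...), mean.(pred), ribbon = std.(pred), lab = "Prediction")
diff --git a/docs/examples/kernelridgeregression.jl b/docs/examples/kernelridgeregression.jl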
new file mode 100644
index 000000000..2efa474bd
--- /dev/null
+++ b/docs/examples/kernelridgeregression.jl
@@ -0,0 +1,77 @@
+# # Kernel Ridge Regression
+# ## Package loading
+using KernelFunctions
+using LinearAlgebra
+using Distributions
+using Plots; default(lw = 2.0, legendfontsize = 15.0)
+using Flux: Optimise
+using ForwardDiff
+using Random: seed!
+seed!(42)
+
+# ## Data Generation
+# We generate data in 1 dimension
+xmin = -3; xmax = 3 # Bounds of the data
+N = 50 # Number of samples
+x_train = rand(Uniform(xmin, xmax), N) # We sample N random inputs
+σ = 0.1
+y_train = sinc.(x_train) + randn(N) * σ # We evaluate a known function and add some noise
+x_test = range(xmin-0.1, xmax+0.1, length=300)
+
+# Plot the data
+scatter(x_train, y_train, lab = "data")
+plot!(x_test, sinc, lab = "true function")
+
+# ## Kernel training
+# To train the kernel parameters via ForwardDiff.jl
+# we need a function creating a kernel from an array of parameters
+kernelcall(θ) = transform(
+    exp(θ[1]) * SqExponentialKernel() + exp(θ[2]) * Matern32Kernel(),
+    exp(θ[3]),
+)
+
+# From theory we know the prediction for a test set x given
+# the kernel parameters and regularization constant
+function f(x, x_train, y_train, θ)
+    k = kernelcall(θ[1:3])
+    kernelmatrix(k, x, x_train) *
+    inv(kernelmatrix(k, x_train) + exp(θ[4]) * I) * y_train
+end
+
+# Let's see what the prediction looks like
+# with the starting parameters [1.0, 1.0, 1.0, 1.0]:
+ŷ = f(x_test, x_train, y_train, log.(ones(4)))
+scatter(x_train, y_train, lab = "data")
+plot!(x_test, sinc, lab = "true function")
+plot!(x_test, ŷ, lab = "prediction")
+
+# We define the loss as the squared residuals
+# plus an L2 regularization term on the prediction
+function loss(θ)
+    ŷ = f(x_train, x_train, y_train, θ)
+    sum(abs2, y_train - ŷ) + exp(θ[4]) * norm(ŷ)
+end
+
+# The loss with our starting point:
+loss(log.(ones(4)))
+
+# ## Training the model
+θ = vcat(log.([1.0, 1.0, 0.01]), log(0.001)) # Initial parameters (on log scale)
+anim = Animation()
+opt = Optimise.ADAGrad(0.5)
+for i = 1:30
+    grads = ForwardDiff.gradient(x -> loss(x), θ) # We compute the gradients given the kernel parameters and regularization
+    Δ = Optimise.apply!(opt, θ, grads)
+    θ .-= Δ # We apply one step of ADAGrad
+    p = scatter(x_train, y_train, lab = "data", title = "i = $(i), Loss = $(round(loss(θ), digits = 4))")
+    plot!(x_test, sinc, lab = "true function")
+    plot!(x_test, f(x_test, x_train, y_train, θ), lab = "Prediction", lw = 3.0)
+    frame(anim)
+end
+gif(anim)
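+
+# As a final check (a sketch reusing only what is defined above), we plot
+# the fit obtained with the optimized parameters:
+scatter(x_train, y_train, lab = "data", title = "Final loss = $(round(loss(θ), digits = 4))")
+plot!(x_test, sinc, lab = "true function")
+plot!(x_test, f(x_test, x_train, y_train, θ), lab = "Prediction")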
label = "y = 1", title = "Data") +scatter!(getindex.(x[y .== -1], 1), getindex.(x[y .== -1], 2), label = "y = 2") + +# ## Model Definition +# TODO Write theory here +# ### We create a kernel k +k = SqExponentialKernel() # SqExponentialKernel/RBFKernel +λ = 1.0 # Regularization parameter + +# ### We create a function to return the optimal prediction for a +# test data `x_new` +function f(x_new, x, y, k, λ) + kernelmatrix(k, x_new, x) * inv(kernelmatrix(k, x) + λ * I) * y # Optimal prediction f +end + +# ### We also compute the total loss of the model that we want to minimize +hingeloss(y, ŷ) = maximum(zero(ŷ), 1 - y * ŷ) # hingeloss function +function reg_hingeloss(k, x, y, λ) + ŷ = f(x, x, y, k, λ) + return sum(hingeloss.(y, ŷ)) - λ * norm(ŷ) # Total svm loss with regularisation +end +# ### We create a 2D grid based on the maximum values of the data +N_test = 100 # Size of the grid +xgrid = range(extrema(vcat(x...)).*1.1..., length=N_test) # Create a 1D grid +xgrid_v = vec(collect.(Iterators.product(xgrid, xgrid))) #Combine into a 2D grid +# ### We predict the value of y on this grid on plot it against the data +y_grid = f(xgrid_v, x, y, k, λ) #Compute prediction on a grid +contourf(xgrid, xgrid, reshape(y_grid, N_test, N_test)', label = "Predictions", title="Trained model") +scatter!(getindex.(x[y .== 1], 1), getindex.(x[y .== 1], 2), label = "y = 1") +scatter!(getindex.(x[y .== -1], 1), getindex.(x[y .== -1], 2), label = "y = 2") +xlims!(extrema(xgrid)) +ylims!(extrema(xgrid)) diff --git a/docs/make.jl b/docs/make.jl index 2949cb7e1..8b8991ec4 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -7,6 +7,21 @@ end using KernelFunctions +if ispath(joinpath(@__DIR__, "src", "examples")) + rm(joinpath(@__DIR__, "src", "examples"), recursive=true) +end + +for filename in readdir(joinpath(@__DIR__, "..", "examples")) + endswith(filename, ".jl") || continue + name = splitext(filename)[1] + Literate.markdown( + joinpath(@__DIR__, "..", "examples", filename), + joinpath(@__DIR__, "src", "examples"), + name = name, + documenter = true, + ) +end + DocMeta.setdocmeta!( KernelFunctions, :DocTestSetup, diff --git a/docs/src/create_kernel.md b/docs/src/create_kernel.md index d7a10c4db..2b415e4d1 100644 --- a/docs/src/create_kernel.md +++ b/docs/src/create_kernel.md @@ -8,15 +8,15 @@ Here are a few ways depending on how complicated your kernel is: ### SimpleKernel for kernel functions depending on a metric -If your kernel function is of the form `k(x, y) = f(d(x, y))` where `d(x, y)` is a `PreMetric`, -you can construct your custom kernel by defining `kappa` and `metric` for your kernel. +If your kernel function is of the form `k(x, y) = f(binary_op(x, y))` where `binary_op(x, y)` is a `PreMetric` or another function/instance implementing `pairwise` and `evaluate` from `Distances.jl`, +you can construct your custom kernel by defining `kappa` and `binary_op` for your kernel. 
 
 ### Kernel for more complex kernels
diff --git a/docs/src/index.md b/docs/src/index.md
index 719e6bb79..bf0f002a2 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -11,7 +11,7 @@ The main goals of this package compared to its predecessors/concurrents in [MLKe
 The methodology of how kernels are computed is quite simple and is done in three phases :
 - A `Transform` object is applied sample-wise on every sample
-- The pairwise matrix is computed using [Distances.jl](https://github.com/JuliaStats/Distances.jl) by using a `Metric` proper to each kernel
+- The pairwise matrix is computed using [Distances.jl](https://github.com/JuliaStats/Distances.jl), using the `BinaryOp` (typically a `Metric`) associated with each kernel
 - The `Kernel` function is applied element-wise on the pairwise matrix
 
 For a quick introduction on how to use it go to [User guide](@ref)
diff --git a/docs/src/metrics.md b/docs/src/metrics.md
index 12260adbc..0bff836fc 100644
--- a/docs/src/metrics.md
+++ b/docs/src/metrics.md
@@ -1,31 +1,55 @@
-# Metrics
+# Binary Operations
 
 `SimpleKernel` implementations rely on [Distances.jl](https://github.com/JuliaStats/Distances.jl) for efficiently computing the pairwise matrix.
 This requires a distance measure or metric, such as the commonly used `SqEuclidean` and `Euclidean`.
 
-The metric used by a given kernel type is specified as
+The binary operation used by a given kernel type is specified as
 ```julia
-KernelFunctions.metric(::CustomKernel) = SqEuclidean()
+KernelFunctions.binary_op(::CustomKernel) = SqEuclidean()
 ```
 
 However, there are kernels that can be implemented efficiently using "metrics" that do not respect all the definitions expected by Distances.jl. For this reason, KernelFunctions.jl provides additional "metrics" such as `DotProduct` ($\langle x, y \rangle$) and `Delta` ($\delta(x,y)$).
 
-## Adding a new metric
+## Adding a new binary operation
 
-If you want to create a new "metric" just implement the following:
+If you want to create a new binary operation, there are two options.
+If your operation satisfies all [`PreMetric` conditions](https://en.wikipedia.org/wiki/Metric_(mathematics)#Generalized_metrics), you can subtype `Distances.PreMetric` and implement the following:
 
 ```julia
-struct Delta <: Distances.PreMetric
-end
+struct Delta <: Distances.PreMetric end
 
-@inline function Distances._evaluate(::Delta,a::AbstractVector{T},b::AbstractVector{T}) where {T}
+@inline function Distances._evaluate(
+    ::Delta,
+    a::AbstractVector{T},
+    b::AbstractVector{T}
+    ) where {T}
     @boundscheck if length(a) != length(b)
         throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
     end
     return a==b
 end
 
-@inline (dist::Delta)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist,a,b)
-@inline (dist::Delta)(a::Number,b::Number) = a==b
+@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist,a,b)
+@inline (dist::Delta)(a::Number, b::Number) = a==b
 ```
+
+However, if it does not satisfy some of these conditions (for instance `d(x, y) ≥ 0` fails for the `DotProduct`), you can subtype `AbstractBinaryOp`, which provides a similar backend:
+```julia
+struct DotProduct <: KernelFunctions.AbstractBinaryOp end
+(d::DotProduct)(a, b) = evaluate(d, a, b)
+Distances.evaluate(::DotProduct, a, b) = dot(a, b)
+```
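+
+Such a binary operation can then be used like any other metric. As a small
+sketch (assuming the `DotProduct` definition above), a custom linear kernel
+could be defined with:
+
+```julia
+struct MyLinearKernel <: KernelFunctions.SimpleKernel end
+
+KernelFunctions.kappa(::MyLinearKernel, xᵀy::Real) = xᵀy + 1
+KernelFunctions.binary_op(::MyLinearKernel) = KernelFunctions.DotProduct()
+
+kernelmatrix(MyLinearKernel(), ColVecs(rand(3, 5))) # 5×5 matrix of pairwise dot products + 1
+```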
diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl
index 7a6ec8a6c..eecb871fa 100644
--- a/src/KernelFunctions.jl
+++ b/src/KernelFunctions.jl
@@ -67,11 +67,15 @@ using TensorCore
 abstract type Kernel end
 abstract type SimpleKernel <: Kernel end
 
+abstract type AbstractBinaryOp end
+const BinaryOp = Union{Distances.PreMetric, AbstractBinaryOp}
+
 include("utils.jl")
-include(joinpath("distances", "pairwise.jl"))
-include(joinpath("distances", "dotproduct.jl"))
-include(joinpath("distances", "delta.jl"))
-include(joinpath("distances", "sinus.jl"))
+
+include(joinpath("binary_op", "abstractbinaryop.jl"))
+include(joinpath("binary_op", "dotproduct.jl"))
+include(joinpath("binary_op", "delta.jl"))
+include(joinpath("binary_op", "sinus.jl"))
 
 include(joinpath("transform", "transform.jl"))
 include(joinpath("transform", "scaletransform.jl"))
diff --git a/src/basekernels/constant.jl b/src/basekernels/constant.jl
index 5996546f1..883bf12cb 100644
--- a/src/basekernels/constant.jl
+++ b/src/basekernels/constant.jl
@@ -17,7 +17,7 @@ struct ZeroKernel <: SimpleKernel end
 
 kappa(κ::ZeroKernel, d::T) where {T<:Real} = zero(T)
 
-metric(::ZeroKernel) = Delta()
+binary_op(::ZeroKernel) = Delta()
 
 Base.show(io::IO, ::ZeroKernel) = print(io, "Zero Kernel")
 
@@ -44,7 +44,7 @@ const EyeKernel = WhiteKernel
 
 kappa(κ::WhiteKernel, δₓₓ::Real) = δₓₓ
 
-metric(::WhiteKernel) = Delta()
+binary_op(::WhiteKernel) = Delta()
 
 Base.show(io::IO, ::WhiteKernel) = print(io, "White Kernel")
 
@@ -75,6 +75,6 @@ end
 
 kappa(κ::ConstantKernel, x::Real) = first(κ.c) * one(x)
 
-metric(::ConstantKernel) = Delta()
+binary_op(::ConstantKernel) = Delta()
 
 Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", first(κ.c), ")")
diff --git a/src/basekernels/cosine.jl b/src/basekernels/cosine.jl
index 361b06b81..0a084e29c 100644
--- a/src/basekernels/cosine.jl
+++ b/src/basekernels/cosine.jl
@@ -14,6 +14,6 @@ struct CosineKernel <: SimpleKernel end
 
 kappa(::CosineKernel, d::Real) = cospi(d)
 
-metric(::CosineKernel) = Euclidean()
+binary_op(::CosineKernel) = Euclidean()
 
 Base.show(io::IO, ::CosineKernel) = print(io, "Cosine Kernel")
diff --git a/src/basekernels/exponential.jl b/src/basekernels/exponential.jl
index e0546a989..78831a421 100644
--- a/src/basekernels/exponential.jl
+++ b/src/basekernels/exponential.jl @@ -16,7 +16,7 @@ struct SqExponentialKernel <: SimpleKernel end kappa(::SqExponentialKernel, d²::Real) = exp(-d² / 2) -metric(::SqExponentialKernel) = SqEuclidean() +binary_op(::SqExponentialKernel) = SqEuclidean() iskroncompatible(::SqExponentialKernel) = true @@ -63,7 +63,7 @@ struct ExponentialKernel <: SimpleKernel end kappa(::ExponentialKernel, d::Real) = exp(-d) -metric(::ExponentialKernel) = Euclidean() +binary_op(::ExponentialKernel) = Euclidean() iskroncompatible(::ExponentialKernel) = true @@ -114,7 +114,7 @@ end kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^first(κ.γ)) -metric(::GammaExponentialKernel) = Euclidean() +binary_op(::GammaExponentialKernel) = Euclidean() iskroncompatible(::GammaExponentialKernel) = true diff --git a/src/basekernels/exponentiated.jl b/src/basekernels/exponentiated.jl index 0b360ceb6..afb3f387a 100644 --- a/src/basekernels/exponentiated.jl +++ b/src/basekernels/exponentiated.jl @@ -14,7 +14,7 @@ struct ExponentiatedKernel <: SimpleKernel end kappa(::ExponentiatedKernel, xᵀy::Real) = exp(xᵀy) -metric(::ExponentiatedKernel) = DotProduct() +binary_op(::ExponentiatedKernel) = DotProduct() iskroncompatible(::ExponentiatedKernel) = true diff --git a/src/basekernels/fbm.jl b/src/basekernels/fbm.jl index fad6d70f2..61d7485fb 100644 --- a/src/basekernels/fbm.jl +++ b/src/basekernels/fbm.jl @@ -46,26 +46,26 @@ _mod(x::RowVecs) = vec(sum(abs2, x.X; dims=2)) function kernelmatrix(κ::FBMKernel, x::AbstractVector) modx = _mod(x) - modxx = pairwise(SqEuclidean(), x) + modxx = Distances.pairwise(SqEuclidean(), x) return _fbm.(modx, modx', modxx, κ.h) end function kernelmatrix!(K::AbstractMatrix, κ::FBMKernel, x::AbstractVector) modx = _mod(x) - pairwise!(K, SqEuclidean(), x) + Distances.pairwise!(K, SqEuclidean(), x) K .= _fbm.(modx, modx', K, κ.h) return K end function kernelmatrix(κ::FBMKernel, x::AbstractVector, y::AbstractVector) - modxy = pairwise(SqEuclidean(), x, y) + modxy = Distances.pairwise(SqEuclidean(), x, y) return _fbm.(_mod(x), _mod(y)', modxy, κ.h) end function kernelmatrix!( K::AbstractMatrix, κ::FBMKernel, x::AbstractVector, y::AbstractVector ) - pairwise!(K, SqEuclidean(), x, y) + Distances.pairwise!(K, SqEuclidean(), x, y) K .= _fbm.(_mod(x), _mod(y)', K, κ.h) return K end diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index bc597ca1f..c5c8eb8af 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -38,7 +38,7 @@ function _matern(ν::Real, d::Real) return exp((one(d) - ν) * logtwo - loggamma(ν) + ν * log(y) + log(besselk(ν, y))) end -metric(::MaternKernel) = Euclidean() +binary_op(::MaternKernel) = Euclidean() Base.show(io::IO, κ::MaternKernel) = print(io, "Matern Kernel (ν = ", first(κ.ν), ")") @@ -62,7 +62,7 @@ struct Matern32Kernel <: SimpleKernel end kappa(::Matern32Kernel, d::Real) = (1 + sqrt(3) * d) * exp(-sqrt(3) * d) -metric(::Matern32Kernel) = Euclidean() +binary_op(::Matern32Kernel) = Euclidean() Base.show(io::IO, ::Matern32Kernel) = print(io, "Matern 3/2 Kernel") @@ -85,6 +85,6 @@ struct Matern52Kernel <: SimpleKernel end kappa(::Matern52Kernel, d::Real) = (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d) -metric(::Matern52Kernel) = Euclidean() +binary_op(::Matern52Kernel) = Euclidean() Base.show(io::IO, ::Matern52Kernel) = print(io, "Matern 5/2 Kernel") diff --git a/src/basekernels/periodic.jl b/src/basekernels/periodic.jl index 2758d7f94..00d88a9ea 100644 --- a/src/basekernels/periodic.jl +++ b/src/basekernels/periodic.jl @@ -32,7 +32,7 @@ 
 PeriodicKernel(T::DataType, dims::Int=1) = PeriodicKernel(; r=ones(T, dims))
 
 @functor PeriodicKernel
 
-metric(κ::PeriodicKernel) = Sinus(κ.r)
+binary_op(κ::PeriodicKernel) = Sinus(κ.r)
 
 kappa(::PeriodicKernel, d::Real) = exp(-0.5d)
diff --git a/src/basekernels/piecewisepolynomial.jl b/src/basekernels/piecewisepolynomial.jl
index 8969f8267..f7f985457 100644
--- a/src/basekernels/piecewisepolynomial.jl
+++ b/src/basekernels/piecewisepolynomial.jl
@@ -76,7 +76,7 @@ end
 
 kappa(κ::PiecewisePolynomialKernel, r) = max(1 - r, 0)^κ.alpha * evalpoly(r, κ.coeffs)
 
-metric(::PiecewisePolynomialKernel) = Euclidean()
+binary_op(::PiecewisePolynomialKernel) = Euclidean()
 
 function Base.show(io::IO, κ::PiecewisePolynomialKernel{D}) where {D}
     return print(
diff --git a/src/basekernels/polynomial.jl b/src/basekernels/polynomial.jl
index 4bbc6658d..53da7d3fd 100644
--- a/src/basekernels/polynomial.jl
+++ b/src/basekernels/polynomial.jl
@@ -26,7 +26,7 @@ end
 
 kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + first(κ.c)
 
-metric(::LinearKernel) = DotProduct()
+binary_op(::LinearKernel) = DotProduct()
 
 Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", first(κ.c), ")")
 
@@ -76,7 +76,7 @@ end
 
 kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + first(κ.c))^κ.degree
 
-metric(::PolynomialKernel) = DotProduct()
+binary_op(::PolynomialKernel) = DotProduct()
 
 function Base.show(io::IO, κ::PolynomialKernel)
     return print(io, "Polynomial Kernel (c = ", first(κ.c), ", degree = ", κ.degree, ")")
diff --git a/src/basekernels/rationalquad.jl b/src/basekernels/rationalquad.jl
index 2d0780fe7..861c2989e 100644
--- a/src/basekernels/rationalquad.jl
+++ b/src/basekernels/rationalquad.jl
@@ -29,7 +29,7 @@ function kappa(κ::RationalQuadraticKernel, d²::T) where {T<:Real}
     return (one(T) + d² / (2 * first(κ.α)))^(-first(κ.α))
 end
 
-metric(::RationalQuadraticKernel) = SqEuclidean()
+binary_op(::RationalQuadraticKernel) = SqEuclidean()
 
 function Base.show(io::IO, κ::RationalQuadraticKernel)
     return print(io, "Rational Quadratic Kernel (α = $(first(κ.α)))")
@@ -70,7 +70,7 @@ function kappa(κ::GammaRationalQuadraticKernel, d::Real)
     return (one(d) + d^first(κ.γ) / first(κ.α))^(-first(κ.α))
 end
 
-metric(::GammaRationalQuadraticKernel) = Euclidean()
+binary_op(::GammaRationalQuadraticKernel) = Euclidean()
 
 function Base.show(io::IO, κ::GammaRationalQuadraticKernel)
     return print(
diff --git a/src/binary_op/abstractbinaryop.jl b/src/binary_op/abstractbinaryop.jl
new file mode 100644
index 000000000..0598e3ef0
--- /dev/null
+++ b/src/binary_op/abstractbinaryop.jl
@@ -0,0 +1,63 @@
+## AbstractBinaryOp shadows the implementation of Distances.jl functions and types
+## for types which are not metrics by definition but benefit from all the
+## pairwise machinery
+
+## pairwise functions for matrices
+function Distances.pairwise(d::AbstractBinaryOp, a::AbstractMatrix, b::AbstractMatrix=a; dims=1)
+    dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
+    m = size(a, dims)
+    n = size(b, dims)
+    P = Matrix{Distances.result_type(d, a, b)}(undef, m, n)
+    if dims == 1
+        return Distances._pairwise!(P, d, transpose(a), transpose(b))
+    else
+        return Distances._pairwise!(P, d, a, b)
+    end
+end
+
+function Distances.pairwise!(P::AbstractMatrix, d::AbstractBinaryOp, a::AbstractMatrix, b::AbstractMatrix=a; dims=1)
+    dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
+    if dims == 1
+        na, ma = size(a)
+        nb, mb = size(b)
+        ma == mb || throw(DimensionMismatch("The numbers of columns in a and b " *
a and b " * + "must match (got $ma and $mb).")) + else + ma, na = size(a) + mb, nb = size(b) + ma == mb || throw(DimensionMismatch("The numbers of rows in a and b " * + "must match (got $ma and $mb).")) + end + size(P) == (na, nb) || + throw(DimensionMismatch("Incorrect size of P (got $(size(P)), expected $((na, nb))).")) + if dims == 1 + _pairwise!(P, d, transpose(a), transpose(b)) + else + _pairwise!(P, d, a, b) + end + return P +end + +function Distances._pairwise!(P::AbstractMatrix, d::AbstractBinaryOp, a::AbstractMatrix, b::AbstractMatrix) + for ij in CartesianIndices(P) + P[ij] = @views d(a[:, ij[1]], b[:, ij[2]]) + end + return P +end + +## pairwise function for vectors +function Distances.pairwise(d::AbstractBinaryOp, X::AbstractVector, Y::AbstractVector=X) + return broadcast(d, X, permutedims(Y)) +end + +function Distances.pairwise!( + out::AbstractMatrix, + d::AbstractBinaryOp, + X::AbstractVector, + Y::AbstractVector=X, +) + broadcast!(d, out, X, permutedims(Y)) +end + +## Additional needed Helpers +Distances.result_type(::AbstractBinaryOp, Ta::Type, Tb::Type) = promote_type(Ta, Tb) \ No newline at end of file diff --git a/src/distances/delta.jl b/src/binary_op/delta.jl similarity index 100% rename from src/distances/delta.jl rename to src/binary_op/delta.jl diff --git a/src/binary_op/dotproduct.jl b/src/binary_op/dotproduct.jl new file mode 100644 index 000000000..ab2dc5234 --- /dev/null +++ b/src/binary_op/dotproduct.jl @@ -0,0 +1,23 @@ +struct DotProduct <: AbstractBinaryOp end +# struct DotProduct <: Distances.UnionSemiMetric end + + +(d::DotProduct)(a, b) = Distances.evaluate(d, a, b) +Distances.evaluate(::DotProduct, a, b) = dot(a, b) + +function Distances._pairwise!(P::AbstractMatrix, ::KernelFunctions.DotProduct, a::AbstractMatrix, b::AbstractMatrix=a) + return mul!(P, transpose(a), b) +end + +# @inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector) +# @boundscheck if length(a) != length(b) +# throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) +# end +# return dot(a,b) +# end + +# Distances.result_type(::DotProduct, Ta::Type, Tb::Type) = promote_type(Ta, Tb) + +# @inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b +# @inline (dist::DotProduct)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist, a, b) +# @inline (dist::DotProduct)(a::Number,b::Number) = a * b diff --git a/src/distances/sinus.jl b/src/binary_op/sinus.jl similarity index 100% rename from src/distances/sinus.jl rename to src/binary_op/sinus.jl diff --git a/src/distances/dotproduct.jl b/src/distances/dotproduct.jl deleted file mode 100644 index ef0f64b28..000000000 --- a/src/distances/dotproduct.jl +++ /dev/null @@ -1,21 +0,0 @@ -struct DotProduct <: Distances.PreMetric end -# struct DotProduct <: Distances.UnionSemiMetric end - -@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector) - @boundscheck if length(a) != length(b) - throw( - DimensionMismatch( - "first array has length $(length(a)) which does not match the length of the second, $(length(b)).", - ), - ) - end - return dot(a, b) -end - -Distances.result_type(::DotProduct, Ta::Type, Tb::Type) = promote_type(Ta, Tb) - -@inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b -@inline function (dist::DotProduct)(a::AbstractArray, b::AbstractArray) - return Distances._evaluate(dist, a, b) -end -@inline (dist::DotProduct)(a::Number, b::Number) = a * b diff --git 
a/src/distances/pairwise.jl b/src/distances/pairwise.jl
deleted file mode 100644
index 7a555222f..000000000
--- a/src/distances/pairwise.jl
+++ /dev/null
@@ -1,31 +0,0 @@
-# Add our own pairwise function to be able to apply it on vectors
-
-function pairwise(d::PreMetric, X::AbstractVector, Y::AbstractVector)
-    return broadcast(d, X, permutedims(Y))
-end
-
-pairwise(d::PreMetric, X::AbstractVector) = pairwise(d, X, X)
-
-function pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector, Y::AbstractVector)
-    return broadcast!(d, out, X, Y')
-end
-
-pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X)
-
-function pairwise(d::PreMetric, x::AbstractVector{<:Real})
-    return Distances.pairwise(d, reshape(x, :, 1); dims=1)
-end
-
-function pairwise(d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
-    return Distances.pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
-end
-
-function pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real})
-    return Distances.pairwise!(out, d, reshape(x, :, 1); dims=1)
-end
-
-function pairwise!(
-    out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}
-)
-    return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
-end
diff --git a/src/generic.jl b/src/generic.jl
index ef8762fef..ed7b4ab9d 100644
--- a/src/generic.jl
+++ b/src/generic.jl
@@ -6,4 +6,4 @@ Base.iterate(k::Kernel, ::Any) = nothing
 printshifted(io::IO, o, shift::Int) = print(io, o)
 
 # Fallback implementation of evaluate for `SimpleKernel`s.
-(k::SimpleKernel)(x, y) = kappa(k, evaluate(metric(k), x, y))
+(k::SimpleKernel)(x, y) = kappa(k, evaluate(binary_op(k), x, y))
diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl
index 5daf38361..a631dd225 100644
--- a/src/kernels/transformedkernel.jl
+++ b/src/kernels/transformedkernel.jl
@@ -25,21 +25,23 @@ end
 (k::TransformedKernel)(x, y) = k.kernel(k.transform(x), k.transform(y))
 
 # Optimizations for scale transforms of simple kernels to save allocations:
-# Instead of a multiplying every element of the inputs before evaluating the metric,
-# we perform a scalar multiplcation of the distance of the original inputs, if possible.
+# Instead of multiplying every element of the inputs before evaluating the binary_op,
+# we perform a scalar multiplication of the distance of the original inputs, if possible
+# (e.g. d(s * x, s * y) = s * d(x, y) for the `Euclidean` distance, while `SqEuclidean`
+# and `DotProduct` scale with s^2).
function (k::TransformedKernel{<:SimpleKernel,<:ScaleTransform})( x::AbstractVector{<:Real}, y::AbstractVector{<:Real} ) - return kappa(k.kernel, _scale(k.transform, metric(k.kernel), x, y)) + return kappa(k.kernel, _scale(k.transform, binary_op(k.kernel), x, y)) end -function _scale(t::ScaleTransform, metric::Euclidean, x, y) - return first(t.s) * evaluate(metric, x, y) +function _scale(t::ScaleTransform, binary_op::Euclidean, x, y) + return first(t.s) * evaluate(binary_op, x, y) end -function _scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y) - return first(t.s)^2 * evaluate(metric, x, y) +function _scale(t::ScaleTransform, binary_op::Union{SqEuclidean,DotProduct}, x, y) + return first(t.s)^2 * evaluate(binary_op, x, y) end -_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y)) +_scale(t::ScaleTransform, binary_op, x, y) = evaluate(binary_op, t(x), t(y)) """ transform(k::Kernel, t::Transform) diff --git a/src/matrix/kernelmatrix.jl b/src/matrix/kernelmatrix.jl index f619368a0..65212ddea 100644 --- a/src/matrix/kernelmatrix.jl +++ b/src/matrix/kernelmatrix.jl @@ -79,7 +79,7 @@ kernelmatrix_diag(κ::Kernel, x::AbstractVector, y::AbstractVector) = map(κ, x, function kernelmatrix!(K::AbstractMatrix, κ::SimpleKernel, x::AbstractVector) validate_inplace_dims(K, x) - pairwise!(K, metric(κ), x) + Distances.pairwise!(K, binary_op(κ), x) return map!(d -> kappa(κ, d), K, K) end @@ -87,17 +87,17 @@ function kernelmatrix!( K::AbstractMatrix, κ::SimpleKernel, x::AbstractVector, y::AbstractVector ) validate_inplace_dims(K, x, y) - pairwise!(K, metric(κ), x, y) + Distances.pairwise!(K, binary_op(κ), x, y) return map!(d -> kappa(κ, d), K, K) end function kernelmatrix(κ::SimpleKernel, x::AbstractVector) - return map(d -> kappa(κ, d), pairwise(metric(κ), x)) + return map(d -> kappa(κ, d), Distances.pairwise(binary_op(κ), x)) end function kernelmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) validate_inputs(x, y) - return map(d -> kappa(κ, d), pairwise(metric(κ), x, y)) + return map(d -> kappa(κ, d), Distances.pairwise(binary_op(κ), x, y)) end # diff --git a/src/utils.jl b/src/utils.jl index 7e28c82ef..a868a08d1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -52,18 +52,18 @@ Base.setindex!(D::ColVecs, v::AbstractVector, i) = setindex!(D.X, v, :, i) dim(x::ColVecs) = size(x.X, 1) -pairwise(d::PreMetric, x::ColVecs) = Distances.pairwise(d, x.X; dims=2) -pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances.pairwise(d, x.X, y.X; dims=2) -function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs) +Distances.pairwise(d::BinaryOp, x::ColVecs) = Distances.pairwise(d, x.X; dims=2) +Distances.pairwise(d::BinaryOp, x::ColVecs, y::ColVecs) = Distances.pairwise(d, x.X, y.X; dims=2) +function Distances.pairwise(d::BinaryOp, x::AbstractVector, y::ColVecs) return Distances.pairwise(d, reduce(hcat, x), y.X; dims=2) end -function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector) +function Distances.pairwise(d::BinaryOp, x::ColVecs, y::AbstractVector) return Distances.pairwise(d, x.X, reduce(hcat, y); dims=2) end -function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs) +function Distances.pairwise!(out::AbstractMatrix, d::BinaryOp, x::ColVecs) return Distances.pairwise!(out, d, x.X; dims=2) end -function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs, y::ColVecs) +function Distances.pairwise!(out::AbstractMatrix, d::BinaryOp, x::ColVecs, y::ColVecs) return Distances.pairwise!(out, d, x.X, y.X; dims=2) end @@ -91,18 +91,18 @@ 
Base.setindex!(D::RowVecs, v::AbstractVector, i) = setindex!(D.X, v, i, :) dim(x::RowVecs) = size(x.X, 2) -pairwise(d::PreMetric, x::RowVecs) = Distances.pairwise(d, x.X; dims=1) -pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances.pairwise(d, x.X, y.X; dims=1) -function pairwise(d::PreMetric, x::AbstractVector, y::RowVecs) +Distances.pairwise(d::BinaryOp, x::RowVecs) = Distances.pairwise(d, x.X; dims=1) +Distances.pairwise(d::BinaryOp, x::RowVecs, y::RowVecs) = Distances.pairwise(d, x.X, y.X; dims=1) +function Distances.pairwise(d::BinaryOp, x::AbstractVector, y::RowVecs) return Distances.pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1) end -function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector) +function Distances.pairwise(d::BinaryOp, x::RowVecs, y::AbstractVector) return Distances.pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1) end -function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs) +function Distances.pairwise!(out::AbstractMatrix, d::BinaryOp, x::RowVecs) return Distances.pairwise!(out, d, x.X; dims=1) end -function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs, y::RowVecs) +function Distances.pairwise!(out::AbstractMatrix, d::BinaryOp, x::RowVecs, y::RowVecs) return Distances.pairwise!(out, d, x.X, y.X; dims=1) end diff --git a/test/Project.toml b/test/Project.toml index 43bd03769..dcf7ca1b9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -17,7 +17,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AxisArrays = "0.4.3" -Distances = "0.9, 0.10" +Distances = "0.10" Documenter = "0.25, 0.26" FiniteDifferences = "0.10.8, 0.11, 0.12" Flux = "0.10, 0.11" diff --git a/test/basekernels/constant.jl b/test/basekernels/constant.jl index e18df9419..2b84efe76 100644 --- a/test/basekernels/constant.jl +++ b/test/basekernels/constant.jl @@ -3,7 +3,7 @@ k = ZeroKernel() @test eltype(k) == Any @test kappa(k, 2.0) == 0.0 - @test KernelFunctions.metric(ZeroKernel()) == KernelFunctions.Delta() + @test binary_op(ZeroKernel()) == KernelFunctions.Delta() @test repr(k) == "Zero Kernel" # Standardised tests. @@ -16,7 +16,7 @@ @test kappa(k, 1.0) == 1.0 @test kappa(k, 0.0) == 0.0 @test EyeKernel == WhiteKernel - @test metric(WhiteKernel()) == KernelFunctions.Delta() + @test binary_op(WhiteKernel()) == KernelFunctions.Delta() @test repr(k) == "White Kernel" # Standardised tests. 
@@ -29,8 +29,8 @@ @test eltype(k) == Any @test kappa(k, 1.0) == c @test kappa(k, 0.5) == c - @test metric(ConstantKernel()) == KernelFunctions.Delta() - @test metric(ConstantKernel(; c=2.0)) == KernelFunctions.Delta() + @test binary_op(ConstantKernel()) == KernelFunctions.Delta() + @test binary_op(ConstantKernel(; c=2.0)) == KernelFunctions.Delta() @test repr(k) == "Constant Kernel (c = $(c))" test_params(k, ([c],)) diff --git a/test/basekernels/exponential.jl b/test/basekernels/exponential.jl index 252a36635..40e258df2 100644 --- a/test/basekernels/exponential.jl +++ b/test/basekernels/exponential.jl @@ -8,7 +8,7 @@ @test kappa(k, x) ≈ exp(-x / 2) @test k(v1, v2) ≈ exp(-norm(v1 - v2)^2 / 2) @test kappa(SqExponentialKernel(), x) == kappa(k, x) - @test metric(SqExponentialKernel()) == SqEuclidean() + @test binary_op(SqExponentialKernel()) == SqEuclidean() @test RBFKernel == SqExponentialKernel @test GaussianKernel == SqExponentialKernel @test SEKernel == SqExponentialKernel @@ -24,7 +24,7 @@ @test kappa(k, x) ≈ exp(-x) @test k(v1, v2) ≈ exp(-norm(v1 - v2)) @test kappa(ExponentialKernel(), x) == kappa(k, x) - @test metric(ExponentialKernel()) == Euclidean() + @test binary_op(ExponentialKernel()) == Euclidean() @test repr(k) == "Exponential Kernel" @test LaplacianKernel == ExponentialKernel @test KernelFunctions.iskroncompatible(k) == true @@ -39,8 +39,8 @@ @test k(v1, v2) ≈ exp(-norm(v1 - v2)^γ) @test kappa(GammaExponentialKernel(), x) == kappa(k, x) @test GammaExponentialKernel(; gamma=γ).γ == [γ] - @test metric(GammaExponentialKernel()) == Euclidean() - @test metric(GammaExponentialKernel(; γ=2.0)) == Euclidean() + @test binary_op(GammaExponentialKernel()) == Euclidean() + @test binary_op(GammaExponentialKernel(; γ=2.0)) == Euclidean() @test repr(k) == "Gamma Exponential Kernel (γ = $(γ))" @test KernelFunctions.iskroncompatible(k) == true diff --git a/test/basekernels/exponentiated.jl b/test/basekernels/exponentiated.jl index 1d209dc49..f394c047f 100644 --- a/test/basekernels/exponentiated.jl +++ b/test/basekernels/exponentiated.jl @@ -8,7 +8,7 @@ @test kappa(k, x) ≈ exp(x) @test kappa(k, -x) ≈ exp(-x) @test k(v1, v2) ≈ exp(dot(v1, v2)) - @test metric(ExponentiatedKernel()) == KernelFunctions.DotProduct() + @test binary_op(ExponentiatedKernel()) == KernelFunctions.DotProduct() @test repr(k) == "Exponentiated Kernel" # Standardised tests. This kernel appears to be fairly numerically unstable. diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index f5d86f3eb..f364ea806 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -11,8 +11,8 @@ @test kappa(k, x) ≈ matern(x, ν) @test kappa(k, 0.0) == 1.0 @test kappa(MaternKernel(; ν=ν), x) == kappa(k, x) - @test metric(MaternKernel()) == Euclidean() - @test metric(MaternKernel(; ν=2.0)) == Euclidean() + @test binary_op(MaternKernel()) == Euclidean() + @test binary_op(MaternKernel(; ν=2.0)) == Euclidean() @test repr(k) == "Matern Kernel (ν = $(ν))" # test_ADs(x->MaternKernel(nu=first(x)),[ν]) @test_broken "All fails (because of logabsgamma for ForwardDiff and ReverseDiff and because of nu for Zygote)" @@ -26,7 +26,7 @@ @test kappa(k, x) ≈ (1 + sqrt(3) * x)exp(-sqrt(3) * x) @test k(v1, v2) ≈ (1 + sqrt(3) * norm(v1 - v2))exp(-sqrt(3) * norm(v1 - v2)) @test kappa(Matern32Kernel(), x) == kappa(k, x) - @test metric(Matern32Kernel()) == Euclidean() + @test binary_op(Matern32Kernel()) == Euclidean() @test repr(k) == "Matern 3/2 Kernel" # Standardised tests. 
@@ -41,7 +41,7 @@
         1 + sqrt(5) * norm(v1 - v2) + 5 / 3 * norm(v1 - v2)^2
     )exp(-sqrt(5) * norm(v1 - v2))
     @test kappa(Matern52Kernel(), x) == kappa(k, x)
-    @test metric(Matern52Kernel()) == Euclidean()
+    @test binary_op(Matern52Kernel()) == Euclidean()
     @test repr(k) == "Matern 5/2 Kernel"
 
     # Standardised tests.
diff --git a/test/basekernels/polynomial.jl b/test/basekernels/polynomial.jl
index 7622fd253..8ae02336e 100644
--- a/test/basekernels/polynomial.jl
+++ b/test/basekernels/polynomial.jl
@@ -9,8 +9,8 @@
     @test kappa(k, x) ≈ x
     @test k(v1, v2) ≈ dot(v1, v2)
    @test kappa(LinearKernel(), x) == kappa(k, x)
-    @test metric(LinearKernel()) == KernelFunctions.DotProduct()
-    @test metric(LinearKernel(; c=c)) == KernelFunctions.DotProduct()
+    @test binary_op(LinearKernel()) == KernelFunctions.DotProduct()
+    @test binary_op(LinearKernel(; c=c)) == KernelFunctions.DotProduct()
     @test repr(k) == "Linear Kernel (c = 0.0)"
 
     # Errors.
@@ -29,10 +29,10 @@
     @test repr(k) == "Polynomial Kernel (c = 0.0, degree = 2)"
 
     # Coherence tests.
     @test kappa(PolynomialKernel(; degree=1, c=c), x) ≈ kappa(LinearKernel(; c=c), x)
-    @test metric(PolynomialKernel()) == KernelFunctions.DotProduct()
-    @test metric(PolynomialKernel(; degree=3)) == KernelFunctions.DotProduct()
-    @test metric(PolynomialKernel(; degree=3, c=c)) == KernelFunctions.DotProduct()
+    @test binary_op(PolynomialKernel()) == KernelFunctions.DotProduct()
+    @test binary_op(PolynomialKernel(; degree=3)) == KernelFunctions.DotProduct()
+    @test binary_op(PolynomialKernel(; degree=3, c=c)) == KernelFunctions.DotProduct()
 
     # Deprecations.
     k = @test_deprecated PolynomialKernel(; d=1)
diff --git a/test/basekernels/rationalquad.jl b/test/basekernels/rationalquad.jl
index ddb2ea03f..5f80de80a 100644
--- a/test/basekernels/rationalquad.jl
+++ b/test/basekernels/rationalquad.jl
@@ -17,8 +17,8 @@
         )
     end
 
-    @test metric(RationalQuadraticKernel()) == SqEuclidean()
-    @test metric(RationalQuadraticKernel(; α=2.0)) == SqEuclidean()
+    @test binary_op(RationalQuadraticKernel()) == SqEuclidean()
+    @test binary_op(RationalQuadraticKernel(; α=2.0)) == SqEuclidean()
     @test repr(k) == "Rational Quadratic Kernel (α = $(α))"
 
     # Standardised tests.
@@ -78,9 +78,9 @@
         )
     end
 
-    @test metric(GammaRationalQuadraticKernel()) == Euclidean()
-    @test metric(GammaRationalQuadraticKernel(; γ=2.0)) == Euclidean()
-    @test metric(GammaRationalQuadraticKernel(; γ=2.0, α=3.0)) == Euclidean()
+    @test binary_op(GammaRationalQuadraticKernel()) == Euclidean()
+    @test binary_op(GammaRationalQuadraticKernel(; γ=2.0)) == Euclidean()
+    @test binary_op(GammaRationalQuadraticKernel(; γ=2.0, α=3.0)) == Euclidean()
 
     # Standardised tests.
     TestUtils.test_interface(k, Float64)
diff --git a/test/binary_op/abstractbinaryop.jl b/test/binary_op/abstractbinaryop.jl
new file mode 100644
index 000000000..dce31dc68
--- /dev/null
+++ b/test/binary_op/abstractbinaryop.jl
@@ -0,0 +1,36 @@
+using KernelFunctions: AbstractBinaryOp
+
+# Structs and method definitions need to live at the top level, not inside the @testset.
+struct Max <: AbstractBinaryOp end
+(d::Max)(a, b) = Distances.evaluate(d, a, b)
+Distances.evaluate(::Max, a, b) = maximum(abs, a - b)
+
+@testset "abstractbinaryop" begin
+    rng = MersenneTwister(123456)
+    d = Max()
+    Ns = (4, 5)
+    D = 3
+    x = [randn(rng, D) for _ in 1:Ns[1]]
+    y = [randn(rng, D) for _ in 1:Ns[2]]
+    X = hcat(x...)
+    Y = hcat(y...)
+    K = zeros(Ns)
+
+    @test pairwise(d, x, y) ≈ pairwise(d, X, Y; dims=2)
+    @test pairwise(d, x) ≈ pairwise(d, X; dims=2)
+    pairwise!(K, d, x, y)
+    @test K ≈ pairwise(d, X, Y; dims=2)
+    K = zeros(Ns[1], Ns[1])
+    pairwise!(K, d, x)
+    @test K ≈ pairwise(d, X; dims=2)
+
+    x = randn(rng, 10)
+    X = reshape(x, :, 1)
+    y = randn(rng, 11)
+    Y = reshape(y, :, 1)
+    K = zeros(10, 11)
+    @test pairwise(d, x, y) ≈ pairwise(d, X, Y; dims=1)
+    @test pairwise(d, x) ≈ pairwise(d, X; dims=1)
+    pairwise!(K, d, x, y)
+    @test K ≈ pairwise(d, X, Y; dims=1)
+end
diff --git a/test/distances/delta.jl b/test/binary_op/delta.jl
similarity index 100%
rename from test/distances/delta.jl
rename to test/binary_op/delta.jl
diff --git a/test/distances/dotproduct.jl b/test/binary_op/dotproduct.jl
similarity index 68%
rename from test/distances/dotproduct.jl
rename to test/binary_op/dotproduct.jl
index b6e3691ed..958aa448b 100644
--- a/test/distances/dotproduct.jl
+++ b/test/binary_op/dotproduct.jl
@@ -2,7 +2,7 @@
     A = rand(10, 5)
     B = rand(20, 5)
     d = KernelFunctions.DotProduct()
-    @test diag(pairwise(d, A; dims=2)) == [dot(A[:, i], A[:, i]) for i in 1:size(A, 2)]
+    @test diag(pairwise(d, A; dims=1)) == [dot(A[i, :], A[i, :]) for i in 1:size(A, 1)]
     @test_throws DimensionMismatch d(rand(3), rand(4))
     @test d(3.0, 2.0) == 6.0
 end
diff --git a/test/distances/sinus.jl b/test/binary_op/sinus.jl
similarity index 100%
rename from test/distances/sinus.jl
rename to test/binary_op/sinus.jl
diff --git a/test/distances/pairwise.jl b/test/distances/pairwise.jl
deleted file mode 100644
index 486097f52..000000000
--- a/test/distances/pairwise.jl
+++ /dev/null
@@ -1,29 +0,0 @@
-@testset "pairwise" begin
-    rng = MersenneTwister(123456)
-    d = SqEuclidean()
-    Ns = (4, 5)
-    D = 3
-    x = [randn(rng, D) for _ in 1:Ns[1]]
-    y = [randn(rng, D) for _ in 1:Ns[2]]
-    X = hcat(x...)
-    Y = hcat(y...)
-    K = zeros(Ns)
-
-    @test KernelFunctions.pairwise(d, x, y) ≈ pairwise(d, X, Y; dims=2)
-    @test KernelFunctions.pairwise(d, x) ≈ pairwise(d, X; dims=2)
-    KernelFunctions.pairwise!(K, d, x, y)
-    @test K ≈ pairwise(d, X, Y; dims=2)
-    K = zeros(Ns[1], Ns[1])
-    KernelFunctions.pairwise!(K, d, x)
-    @test K ≈ pairwise(d, X; dims=2)
-
-    x = randn(rng, 10)
-    X = reshape(x, :, 1)
-    y = randn(rng, 11)
-    Y = reshape(y, :, 1)
-    K = zeros(10, 11)
-    @test KernelFunctions.pairwise(d, x, y) ≈ pairwise(d, X, Y; dims=1)
-    @test KernelFunctions.pairwise(d, x) ≈ pairwise(d, X; dims=1)
-    KernelFunctions.pairwise!(K, d, x, y)
-    @test K ≈ pairwise(d, X, Y; dims=1)
-end
diff --git a/test/matrix/kernelmatrix.jl b/test/matrix/kernelmatrix.jl
index f6227bba7..5f89fd403 100644
--- a/test/matrix/kernelmatrix.jl
+++ b/test/matrix/kernelmatrix.jl
@@ -7,7 +7,7 @@ struct BaseSE <: KernelFunctions.Kernel end
 # are implemented in the package. That this happens to be an exponentiated quadratic kernel
 # is a complete coincidence.
struct ToySimpleKernel <: SimpleKernel end -KernelFunctions.metric(::ToySimpleKernel) = SqEuclidean() +KernelFunctions.binary_op(::ToySimpleKernel) = SqEuclidean() KernelFunctions.kappa(::ToySimpleKernel, d) = exp(-d / 2) @testset "kernelmatrix" begin diff --git a/test/runtests.jl b/test/runtests.jl index 7ad679905..ae24e1bab 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,7 +14,7 @@ using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff using FiniteDifferences: FiniteDifferences -using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils +using KernelFunctions: SimpleKernel, binary_op, kappa, ColVecs, RowVecs, TestUtils using KernelFunctions.TestUtils: test_interface @@ -55,13 +55,13 @@ include("test_utils.jl") @testset "KernelFunctions" begin include("utils.jl") - @testset "distances" begin - include(joinpath("distances", "pairwise.jl")) - include(joinpath("distances", "dotproduct.jl")) - include(joinpath("distances", "delta.jl")) - include(joinpath("distances", "sinus.jl")) + @testset "binary_op" begin + include(joinpath("binary_op", "abstractbinaryop.jl")) + include(joinpath("binary_op", "dotproduct.jl")) + include(joinpath("binary_op", "delta.jl")) + include(joinpath("binary_op", "sinus.jl")) end - @info "Ran tests on Distances" + @info "Ran tests on binary_op" @testset "transform" begin include(joinpath("transform", "transform.jl")) diff --git a/test/utils.jl b/test/utils.jl index 8bdd16330..fce5478f8 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -29,15 +29,15 @@ Y = randn(rng, D, N + 1) DY = ColVecs(Y) - @test KernelFunctions.pairwise(SqEuclidean(), DX) ≈ + @test pairwise(SqEuclidean(), DX) ≈ pairwise(SqEuclidean(), X; dims=2) - @test KernelFunctions.pairwise(SqEuclidean(), DX, DY) ≈ + @test pairwise(SqEuclidean(), DX, DY) ≈ pairwise(SqEuclidean(), X, Y; dims=2) K = zeros(N, N) - KernelFunctions.pairwise!(K, SqEuclidean(), DX) + pairwise!(K, SqEuclidean(), DX) @test K ≈ pairwise(SqEuclidean(), X; dims=2) K = zeros(N, N + 1) - KernelFunctions.pairwise!(K, SqEuclidean(), DX, DY) + pairwise!(K, SqEuclidean(), DX, DY) @test K ≈ pairwise(SqEuclidean(), X, Y; dims=2) let @@ -68,15 +68,15 @@ Y = randn(rng, D + 1, N) DY = RowVecs(Y) - @test KernelFunctions.pairwise(SqEuclidean(), DX) ≈ + @test pairwise(SqEuclidean(), DX) ≈ pairwise(SqEuclidean(), X; dims=1) - @test KernelFunctions.pairwise(SqEuclidean(), DX, DY) ≈ + @test pairwise(SqEuclidean(), DX, DY) ≈ pairwise(SqEuclidean(), X, Y; dims=1) K = zeros(D, D) - KernelFunctions.pairwise!(K, SqEuclidean(), DX) + pairwise!(K, SqEuclidean(), DX) @test K ≈ pairwise(SqEuclidean(), X; dims=1) K = zeros(D, D + 1) - KernelFunctions.pairwise!(K, SqEuclidean(), DX, DY) + pairwise!(K, SqEuclidean(), DX, DY) @test K ≈ pairwise(SqEuclidean(), X, Y; dims=1) let