Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/basekernels/maha.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

Mahalanobis distance-based kernel given by
```math
κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'*inv(P)*(x-y)
κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'* P *(x-y)
```
where the matrix P is the metric.

Expand All @@ -20,4 +20,8 @@ kappa(κ::MahalanobisKernel, d::T) where {T<:Real} = exp(-d)

metric(κ::MahalanobisKernel) = SqMahalanobis(κ.P)

function dot_perslice(A::AbstractMatrix, B::AbstractMatrix; dims=2)
return reshape(sum(A .* B, dims=3-dims), :)
end

Base.show(io::IO, κ::MahalanobisKernel) = print(io, "Mahalanobis Kernel (size(P) = ", size(κ.P), ")")
28 changes: 28 additions & 0 deletions src/basekernels/nn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,32 @@ function (κ::NeuralNetworkKernel)(x, y)
return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
end

function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs)
validate_inputs(x, y)
X_2 = sum(x.X .* x.X, dims=1)
Y_2 = sum(y.X .* y.X, dims=1)
XY = x.X' * y.X
return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1)))
end

function kernelmatrix(::NeuralNetworkKernel, x::ColVecs)
X_2_1 = sum(x.X .* x.X, dims=1) .+ 1
XX = x.X' * x.X
return asin.(XX ./ sqrt.(X_2_1' * X_2_1))
end

function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
validate_inputs(x, y)
X_2 = sum(x.X .* x.X, dims=2)
Y_2 = sum(y.X .* y.X, dims=2)
XY = x.X * y.X'
return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1)))
end

function kernelmatrix(::NeuralNetworkKernel, x::RowVecs)
X_2_1 = sum(x.X .* x.X, dims=2) .+ 1
XX = x.X * x.X'
return asin.(XX ./ sqrt.(X_2_1' * X_2_1))
end

Base.show(io::IO, κ::NeuralNetworkKernel) = print(io, "Neural Network Kernel")
10 changes: 10 additions & 0 deletions src/zygote_adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,13 @@ end
@adjoint function Base.map(t::Transform, X::RowVecs)
pullback(_map, t, X)
end

@adjoint function (dist::Distances.SqMahalanobis)(a, b)
function back(Δ::Real)
B_B_inv = dist.qmat + transpose(dist.qmat)
a_b = a - b
δa = B_B_inv * a_b
return (qmat = a_b * a_b',), δa, -δa
end
return evaluate(dist::SqMahalanobis, a, b), back
end
3 changes: 1 addition & 2 deletions test/basekernels/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
@test metric(GammaExponentialKernel(γ=2.0)) == SqEuclidean()
@test repr(k) == "Gamma Exponential Kernel (γ = $(γ))"
@test KernelFunctions.iskroncompatible(k) == true
test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Zygote gradient given γ"
test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ])
#Coherence :
@test GammaExponentialKernel(γ=1.0)(v1,v2) ≈ SqExponentialKernel()(v1,v2)
@test GammaExponentialKernel(γ=0.5)(v1,v2) ≈ ExponentialKernel()(v1,v2)
Expand Down
4 changes: 2 additions & 2 deletions test/basekernels/fbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@
@test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] k(x1, x2) atol=1e-5

@test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))"
test_ADs(FBMKernel, ADs = [:ReverseDiff])
@test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff and Zygote"
test_ADs(FBMKernel, ADs = [:ReverseDiff, :Zygote])
@test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff"
end
3 changes: 1 addition & 2 deletions test/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
@test k.ell 1.0 atol=1e-5
@test k.p 1.0 atol=1e-5
@test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)"
#test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p])#, ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Tests failing for Zygote on differentiating through ell and p"
test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p], ADs = [:Zygote])
# Tests are also failing randomly for ForwardDiff and ReverseDiff but randomly
end
2 changes: 1 addition & 1 deletion test/basekernels/maha.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
@test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P))
@test kappa(ExponentialKernel(), x) == kappa(k, x)
@test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))"
# test_ADs(P -> MahalanobisKernel(P), P)
test_ADs(P -> MahalanobisKernel(P), P, ADs=[:Zygote])
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"
end
3 changes: 1 addition & 2 deletions test/basekernels/nn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,5 @@
@test_throws DimensionMismatch kernelmatrix!(A5, k, ones(4,3), ones(3,4))

@test k([x1], [x2]) ≈ k(x1, x2) atol=1e-5
test_ADs(NeuralNetworkKernel, ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Zygote uncompatible with BaseKernel"
test_ADs(NeuralNetworkKernel)
end