[WIP] Use Tullio for pairwise distances #385

Closed
wants to merge 24 commits into from
24 commits
888d752
Correction of the docstring
theogf Apr 16, 2021
a8f1183
Merge branch 'master' into tg/correct_doc_spectral_mixture
st-- Jul 9, 2021
31adb2b
Fixes documentations
theogf Jul 25, 2021
966a726
Merge remote-tracking branch 'origin/tg/correct_doc_spectral_mixture'…
theogf Jul 25, 2021
f721a58
Merge branch 'master' into tg/correct_doc_spectral_mixture
theogf Jul 29, 2021
7cb39d3
Merge branch 'master' into tg/correct_doc_spectral_mixture
theogf Aug 30, 2021
ca3a76c
Merge branch 'tg/correct_doc_spectral_mixture' of github.com:JuliaGau…
theogf Aug 31, 2021
de1c84a
Fix multiple docs
theogf Aug 31, 2021
50d5471
Version to test
theogf Aug 31, 2021
e86f572
Wrote more detailed tests
theogf Aug 31, 2021
01cfaaf
Add a bunch of cdot
theogf Aug 31, 2021
ffba539
changed spectral_mixture_kernel to a struct
theogf Aug 31, 2021
cb27209
formatting
theogf Aug 31, 2021
4f5c025
revert change lineartransform
theogf Aug 31, 2021
5f8e7f5
Add functor
theogf Aug 31, 2021
911cb73
Update the docs
theogf Sep 1, 2021
64e3641
Update docs
theogf Sep 1, 2021
6f9cae8
Update src/basekernels/sm.jl
theogf Oct 15, 2021
6fde691
Fixed docstrings and added constructors checks
theogf Oct 15, 2021
27db614
Merge branch 'master' into tg/correct_doc_spectral_mixture
theogf Oct 15, 2021
38f9eed
Update src/basekernels/sm.jl
theogf Oct 15, 2021
a07704b
Fix checks
theogf Oct 15, 2021
f37b566
improve computations using broadcaster
theogf Oct 19, 2021
1b1ea99
Introduction of Tullio to perform pairwise operations
theogf Oct 19, 2021
1 change: 1 addition & 0 deletions Project.toml
@@ -18,6 +18,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
2 changes: 1 addition & 1 deletion docs/src/kernels.md
@@ -100,7 +100,7 @@ GammaRationalKernel
### Spectral Mixture Kernels

```@docs
spectral_mixture_kernel
SpectralMixtureKernel
spectral_mixture_product_kernel
```

13 changes: 12 additions & 1 deletion src/KernelFunctions.jl
@@ -15,6 +15,7 @@ export LinearKernel, PolynomialKernel
export RationalKernel, RationalQuadraticKernel, GammaRationalKernel
export PiecewisePolynomialKernel
export PeriodicKernel, NeuralNetworkKernel
export SpectralMixtureKernel
export KernelSum, KernelProduct, KernelTensorProduct
export TransformedKernel, ScaledKernel, NormalizedKernel
export GibbsKernel
@@ -33,7 +34,7 @@ export with_lengthscale
export NystromFact, nystrom

export gaborkernel
export spectral_mixture_kernel, spectral_mixture_product_kernel
export spectral_mixture_product_kernel

export ColVecs, RowVecs

@@ -58,6 +59,7 @@ using IrrationalConstants: logtwo, twoπ, invsqrt2
using LogExpFunctions: softplus
using StatsBase
using TensorCore
using Tullio
using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield

# Hack to work around Zygote type inference problems.
@@ -67,6 +69,13 @@ abstract type Kernel end
abstract type SimpleKernel <: Kernel end

include("utils.jl")

const VecOfVecs = Union{ColVecs,RowVecs}

# A general binary op type not respecting Distances metric rules
abstract type AbstractBinaryOp end
const BinaryOp = Union{AbstractBinaryOp,Distances.PreMetric}

include("distances/pairwise.jl")
include("distances/dotproduct.jl")
include("distances/delta.jl")
@@ -120,6 +129,8 @@ include("mokernels/lmm.jl")
include("chainrules.jl")
include("zygoterules.jl")

include("deprecated.jl")

include("test_utils.jl")

function __init__()
175 changes: 121 additions & 54 deletions src/basekernels/sm.jl
@@ -1,24 +1,39 @@
"""
spectral_mixture_kernel(
@doc raw"""
SpectralMixtureKernel(
h::Kernel=SqExponentialKernel(),
αs::AbstractVector{<:Real},
γs::AbstractMatrix{<:Real},
ωs::AbstractMatrix{<:Real},
α::AbstractVector{<:Real},
γ::AbstractMatrix{<:Real},
ω::AbstractMatrix{<:Real},
)
SpectralMixtureKernel(
h::Kernel=SqExponentialKernel(),
α::AbstractVector{<:Real},
γ::AbstractVector{<:AbstractVecOrMat{<:Real}},
ω::AbstractVector{<:AbstractVecOrMat{<:Real}},
)

where αs are the weights of dimension (A, ), γs is the covariance matrix of
dimension (D, A) and ωs are the mean vectors and is of dimension (D, A).
Here, D is input dimension and A is the number of spectral components.

`h` is the kernel, which defaults to [`SqExponentialKernel`](@ref) if not specified.
Generalised Spectral Mixture kernel function as described in [1] in equation 6.
This family of functions is dense in the family of stationary real-valued kernels with respect to pointwise convergence [1].

Generalised Spectral Mixture kernel function. This family of functions is dense
in the family of stationary real-valued kernels with respect to the pointwise convergence.[1]
## Definition

For inputs ``x, x′ \in \mathbb{R}^D``, the spectral mixture kernel ``\tilde{k}`` with ``K`` mixture components, mixture weights ``\alpha \in \mathbb{R}^K``, linear transformations ``\gamma_1, \ldots, \gamma_K \in \mathbb{R}^D``, and frequencies ``\omega_1, \ldots, \omega_K \in \mathbb{R}^D`` derived from a translation-invariant kernel ``k`` is defined as
```math
κ(x, y) = αs' (h(-(γs' * t)^2) .* cos(π * ωs' * t), t = x - y
\tilde{k}(x, x'; \alpha, \gamma_1, \ldots, \gamma_K, \omega_1, \ldots, \omega_K, k) = \sum_{i=1}^K \alpha_i k(\gamma_i \odot x, \gamma_i \odot x') \cos(2\pi \omega_i^\top (x-x')).
```

## Arguments
- `h`: Stationary kernel (translation invariant), [`SqExponentialKernel`](@ref) by default
- `α`: Weight vector of each mixture component (should be non-negative)
- `γ`: Element-wise scaling of the input for `h` (one vector per component).
- `ω`: Frequencies for the cosine function (should be non-negative).

`γ` and `ω` can each be either
- an `AbstractMatrix` of dimension `D x K`, where `D` is the dimension of the inputs
and `K` is the number of components, or
- an `AbstractVector` of `K` `D`-dimensional `AbstractVector`s.


# References:
[1] Generalized Spectral Kernels, by Yves-Laurent Kom Samo and Stephen J. Roberts
[2] SM: Gaussian Process Kernels for Pattern Discovery and Extrapolation,
@@ -29,77 +44,129 @@ in the family of stationary real-valued kernels with respect to the pointwise co
[4] http://www.cs.cmu.edu/~andrewgw/pattern/.

"""
function spectral_mixture_kernel(
struct SpectralMixtureKernel{
K<:Kernel,Tα<:AbstractVector,Tγ<:AbstractVector,Tω<:AbstractVector
} <: Kernel
kernel::K
α::Tα
γ::Tγ
ω::Tω
function SpectralMixtureKernel(
h::Kernel,
α::AbstractVector{<:Real},
γ::AbstractVector{<:AbstractVector},
ω::AbstractVector{<:AbstractVector},
)
(length(α) == length(γ) == length(ω)) ||
throw(DimensionMismatch("The dimensions of α, γ, and ω do not match"))
any(<(0), α) && throw(ArgumentError("At least one element of α is negative"))
any(any.(<(0), ω)) && throw(ArgumentError("At least one element of ω is negative"))
return new{typeof(h),typeof(α),typeof(γ),typeof(ω)}(h, α, γ, ω)
end
end

@functor SpectralMixtureKernel

function SpectralMixtureKernel(
h::Kernel,
αs::AbstractVector{<:Real},
γs::AbstractMatrix{<:Real},
ωs::AbstractMatrix{<:Real},
α::AbstractVector{<:Real},
γ::AbstractMatrix{<:Real},
ω::AbstractMatrix{<:Real},
)
if !(size(αs, 1) == size(γs, 2) == size(ωs, 2))
throw(DimensionMismatch("The dimensions of αs, γs, ans ωs do not match"))
end
if size(γs) != size(ωs)
throw(DimensionMismatch("The dimensions of γs ans ωs do not match"))
end
size(γ) == size(ω) || throw(DimensionMismatch("γ and ω have different dimensions"))
return SpectralMixtureKernel(h, α, ColVecs(γ), ColVecs(ω))
end

return sum(zip(αs, eachcol(γs), eachcol(ωs))) do (α, γ, ω)
a = TransformedKernel(h, LinearTransform(γ'))
b = TransformedKernel(CosineKernel(), LinearTransform(ω'))
return α * a * b
function SpectralMixtureKernel(
αs::AbstractVector{<:Real}, γs::AbstractVecOrMat, ωs::AbstractVecOrMat
)
return SpectralMixtureKernel(SqExponentialKernel(), αs, γs, ωs)
end

function (κ::SpectralMixtureKernel)(x, y)
xy = x - y
# use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
broadcasted = Broadcast.broadcasted(κ.α, κ.γ, κ.ω) do α, γ, ω
k = TransformedKernel(κ.kernel, ARDTransform(γ))
return α * k(x, y) * cospi(2 * dot(ω, xy))
end
return sum(Broadcast.instantiate(broadcasted))
end

function spectral_mixture_kernel(
αs::AbstractVector{<:Real}, γs::AbstractMatrix{<:Real}, ωs::AbstractMatrix{<:Real}
)
return spectral_mixture_kernel(SqExponentialKernel(), αs, γs, ωs)
function Base.show(io::IO, κ::SpectralMixtureKernel)
return print(
io,
"SpectralMixtureKernel Kernel (kernel = ",
κ.kernel,
", # components = ",
length(κ.α),
")",
)
end

"""
@doc raw"""
spectral_mixture_product_kernel(
h::Kernel=SqExponentialKernel(),
αs::AbstractMatrix{<:Real},
γs::AbstractMatrix{<:Real},
ωs::AbstractMatrix{<:Real},
α::AbstractMatrix{<:Real},
γ::AbstractMatrix{<:Real},
ω::AbstractMatrix{<:Real},
)

where αs are the weights of dimension (D, A), γs is the covariance matrix of
dimension (D, A) and ωs are the mean vectors and is of dimension (D, A).
Here, D is input dimension and A is the number of spectral components.

Spectral Mixture Product Kernel. With enough components A, the SMP kernel
The spectral mixture product kernel is the tensor product of spectral mixture kernels
applied to each input dimension, as described in [1] in equations 13 and 14.
With enough components, the SMP kernel
can model any product kernel to arbitrary precision, and is flexible even
with a small number of components [1]
with a small number of components.

## Definition

`h` is the kernel, which defaults to [`SqExponentialKernel`](@ref) if not specified.
For inputs ``x, x' \in \mathbb{R}^D``, the spectral mixture product kernel ``\tilde{k}`` with ``K`` mixture components, mixture weights ``\alpha_1, \ldots, \alpha_K \in \mathbb{R}^D``, linear transformations ``\gamma_1, \ldots, \gamma_K \in \mathbb{R}^D``, and frequencies ``\omega_1, \ldots, \omega_K \in \mathbb{R}^D`` derived from a translation-invariant kernel ``h`` is defined as

```math
κ(x, y) = Πᵢ₌₁ᴷ Σ(αsᵢᵀ .* (h(-(γsᵢᵀ * tᵢ)²) .* cos(ωsᵢᵀ * tᵢ))), tᵢ = xᵢ - yᵢ
\tilde{k}(x, x'; \alpha_1, \ldots, \alpha_K, \gamma_1, \ldots, \gamma_K, \omega_1, \ldots, \omega_K, h) = \prod_{i=1}^D \sum_{k=1}^K \alpha_{ik} \cdot h(\gamma_{ik} \cdot x_i, \gamma_{ik} \cdot x'_i) \cdot \cos(2\pi \cdot \omega_{ik} \cdot (x_i - x'_i))
```

## Arguments
- `h`: Stationary kernel (translation invariant), [`SqExponentialKernel`](@ref) by default
- `α`: Weight of each mixture component for each dimension
- `γ`: Linear transformation of the input for `h`.
- `ω`: Frequencies for the cosine function.

`α`, `γ` and `ω` can each be either
- an `AbstractMatrix` of dimension `D x K`, where `D` is the dimension of the inputs
and `K` is the number of components, or
- an `AbstractVector` of `D` `K`-dimensional `AbstractVector`s.


# References:
[1] GPatt: Fast Multidimensional Pattern Extrapolation with GPs,
arXiv 1310.5288, 2013, by Andrew Gordon Wilson, Elad Gilboa,
Arye Nehorai and John P. Cunningham
"""
function spectral_mixture_product_kernel(
h::Kernel,
αs::AbstractMatrix{<:Real},
γs::AbstractMatrix{<:Real},
ωs::AbstractMatrix{<:Real},
α::AbstractMatrix{<:Real},
γ::AbstractMatrix{<:Real},
ω::AbstractMatrix{<:Real},
)
if !(size(αs) == size(γs) == size(ωs))
throw(DimensionMismatch("The dimensions of αs, γs, ans ωs do not match"))
(size(α) == size(γ) == size(ω)) ||
throw(DimensionMismatch("α, γ and ω have different dimensions"))
return spectral_mixture_product_kernel(h, RowVecs(α), RowVecs(γ), RowVecs(ω))
end

function spectral_mixture_product_kernel(
h::Kernel,
α::AbstractVector{<:AbstractVector{<:Real}},
γ::AbstractVector{<:AbstractVector{<:Real}},
ω::AbstractVector{<:AbstractVector{<:Real}},
)
return mapreduce(⊗, α, γ, ω) do αᵢ, γᵢ, ωᵢ
return SpectralMixtureKernel(h, αᵢ, permutedims(γᵢ), permutedims(ωᵢ))
end
return KernelTensorProduct(
spectral_mixture_kernel(h, α, reshape(γ, 1, :), reshape(ω, 1, :)) for
(α, γ, ω) in zip(eachrow(αs), eachrow(γs), eachrow(ωs))
)
end

function spectral_mixture_product_kernel(
αs::AbstractMatrix{<:Real}, γs::AbstractMatrix{<:Real}, ωs::AbstractMatrix{<:Real}
α::AbstractVecOrMat, γ::AbstractVecOrMat, ω::AbstractVecOrMat
)
return spectral_mixture_product_kernel(SqExponentialKernel(), αs, γs, ωs)
return spectral_mixture_product_kernel(SqExponentialKernel(), α, γ, ω)
end
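
As a sanity check on the two docstring formulas in `src/basekernels/sm.jl`, here is a minimal sketch that evaluates them with plain Julia, without KernelFunctions.jl. It assumes a squared exponential base kernel with unit lengthscale, k(x, y) = exp(-‖x - y‖² / 2). The names `sqexp`, `sm_kernel` and `smp_kernel` are hypothetical helpers for illustration and are not part of this PR.

```julia
using LinearAlgebra

# Assumed base kernel: squared exponential with unit lengthscale.
sqexp(x, y) = exp(-sum(abs2, x .- y) / 2)

# Spectral mixture kernel: sum over the K components of
# α[i] * k(γᵢ ⊙ x, γᵢ ⊙ y) * cos(2π ωᵢ' (x - y)),
# where γ and ω are D x K matrices (one column per component).
function sm_kernel(α, γ, ω, x, y)
    return sum(eachindex(α)) do i
        γᵢ, ωᵢ = view(γ, :, i), view(ω, :, i)
        α[i] * sqexp(γᵢ .* x, γᵢ .* y) * cospi(2 * dot(ωᵢ, x .- y))
    end
end

# Spectral mixture product kernel: product over the D input dimensions of a
# one-dimensional spectral mixture kernel built from row d of A, γ and ω.
function smp_kernel(A, γ, ω, x, y)
    return prod(eachindex(x)) do d
        sm_kernel(A[d, :], γ[d:d, :], ω[d:d, :], x[d:d], y[d:d])
    end
end

α = [0.5, 1.5]             # K = 2 weights for the mixture kernel
A = [0.5 1.5; 1.0 2.0]     # D x K weights for the product kernel
γ = [1.0 0.5; 2.0 1.0]     # D = 2, K = 2
ω = [0.1 0.3; 0.2 0.4]
x, y = [0.3, -0.2], [1.0, 0.4]

k_xx = sm_kernel(α, γ, ω, x, x)    # reduces to sum(α): k(x, x) = 1 and cos(0) = 1
k_xy = sm_kernel(α, γ, ω, x, y)
k_yx = sm_kernel(α, γ, ω, y, x)    # same as k_xy, because cos is even
p_xx = smp_kernel(A, γ, ω, x, x)   # reduces to the product of the row sums of A
```

The two limiting cases (`x == y` and swapped arguments) are cheap properties to compare against `SpectralMixtureKernel` and `spectral_mixture_product_kernel` when testing.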
1 change: 1 addition & 0 deletions src/deprecated.jl
@@ -0,0 +1 @@
@deprecate spectral_mixture_kernel SpectralMixtureKernel
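
The single line in `src/deprecated.jl` relies on `Base.@deprecate`. A minimal sketch of how the name-only form of that macro behaves, using the hypothetical names `old_square` and `new_square` rather than the package's own functions:

```julia
# Hypothetical replacement function.
new_square(x) = x^2

# Define `old_square` so that every call is forwarded to `new_square`.
# A deprecation warning is only printed when Julia runs with `--depwarn=yes`
# (for example during package tests), so ordinary callers see no extra output.
@deprecate old_square new_square

result = old_square(3)  # same value as new_square(3)
```

Existing code calling `spectral_mixture_kernel(...)` would keep working in the same way, with the call forwarded to the `SpectralMixtureKernel` constructor.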
Empty file added src/distances/binaryop.jl
6 changes: 2 additions & 4 deletions src/distances/delta.jl
@@ -1,6 +1,6 @@
# Delta is not following the PreMetric rules since d(x, x) == 1
struct Delta <: Distances.UnionPreMetric end
struct Delta <: AbstractBinaryOp end

# Basic definitions
(dist::Delta)(a::Number, b::Number) = a == b
Base.@propagate_inbounds function (dist::Delta)(
a::AbstractArray{<:Number}, b::AbstractArray{<:Number}
@@ -14,5 +14,3 @@ Base.@propagate_inbounds function (dist::Delta)(
end
return a == b
end

Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool
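
To illustrate the idea behind moving `Delta` from `Distances.UnionPreMetric` to `AbstractBinaryOp`, here is a small self-contained sketch. `AbstractBinaryOp` and `Delta` are redefined locally as stand-ins, and `pairwise_naive` is a hypothetical fallback written for this note; the PR's own `pairwise` lives in `src/distances/pairwise.jl`, which is not shown in this diff.

```julia
# Local stand-in for the abstract type introduced in src/KernelFunctions.jl.
abstract type AbstractBinaryOp end

# Delta returns `true` for equal inputs, so d(x, x) == 1 rather than 0.
# That is why it does not qualify as a pre-metric.
struct Delta <: AbstractBinaryOp end
(::Delta)(a, b) = a == b

# Hypothetical generic fallback: apply the op to every pair of elements.
pairwise_naive(op::AbstractBinaryOp, x::AbstractVector, y::AbstractVector=x) =
    [op(xi, yj) for xi in x, yj in y]

x = [[1.0, 2.0], [3.0, 4.0], [1.0, 2.0]]
K = pairwise_naive(Delta(), x)   # 3 x 3 Bool matrix, true where x[i] == x[j]
```

Because the op is just a callable struct, no `Distances.result_type` method is needed in this sketch; the element type of `K` follows from what the op returns.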
31 changes: 16 additions & 15 deletions src/distances/dotproduct.jl
@@ -1,21 +1,22 @@
## DotProduct is not following the PreMetric rules since d(x, x) != 0 in general and d(x, y) >= 0 does not hold for all x, y
struct DotProduct <: Distances.UnionPreMetric end
struct DotProduct <: AbstractBinaryOp end

@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector)
@boundscheck if length(a) != length(b)
throw(
DimensionMismatch(
"first array has length $(length(a)) which does not match the length of the second, $(length(b)).",
),
)
end
return dot(a, b)
(::DotProduct)(a::AbstractVector, b::AbstractVector) = dot(a, b)

(::DotProduct)(a::Number, b::Number) = a * b

function pairwise(::DotProduct, x::ColVecs, y::ColVecs)
return @tullio out[i, j] := x.X[k, i] * y.X[k, j]
end

Distances.result_type(::DotProduct, Ta::Type, Tb::Type) = promote_type(Ta, Tb)
function pairwise(::DotProduct, x::RowVecs, y::RowVecs)
return @tullio out[i, j] := x.X[i, k] * y.X[j, k]
end

@inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b
@inline function (dist::DotProduct)(a::AbstractArray, b::AbstractArray)
return Distances._evaluate(dist, a, b)
function colwise(::DotProduct, x::RowVecs, y::RowVecs=x)
return @tullio out[i] := x.X[i, k] * y.X[i, k]
end
@inline (dist::DotProduct)(a::Number, b::Number) = a * b

function colwise(::DotProduct, x::ColVecs, y::ColVecs=x)
return @tullio out[i] := x.X[k, i] * y.X[k, i]
end
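
For readers unfamiliar with Tullio's index notation, the following sketch spells out the `ColVecs` methods above as plain loops. It does not use Tullio; `pairwise_dot_cols` and `colwise_dot_cols` are hypothetical names, and the matrices `X` and `Y` play the role of `x.X` and `y.X`.

```julia
using LinearAlgebra

# Plain-loop reading of `@tullio out[i, j] := X[k, i] * Y[k, j]`:
# indices on the left-hand side (i, j) are kept, and the index that appears
# only on the right-hand side (k) is summed over.
function pairwise_dot_cols(X::AbstractMatrix, Y::AbstractMatrix)
    size(X, 1) == size(Y, 1) ||
        throw(DimensionMismatch("X and Y need the same number of rows"))
    out = zeros(promote_type(eltype(X), eltype(Y)), size(X, 2), size(Y, 2))
    for j in axes(Y, 2), i in axes(X, 2), k in axes(X, 1)
        out[i, j] += X[k, i] * Y[k, j]
    end
    return out
end

# Plain-loop reading of `@tullio out[i] := X[k, i] * Y[k, i]`.
function colwise_dot_cols(X::AbstractMatrix, Y::AbstractMatrix=X)
    size(X) == size(Y) || throw(DimensionMismatch("X and Y need the same size"))
    return [dot(view(X, :, i), view(Y, :, i)) for i in axes(X, 2)]
end

X = [1.0 2.0 3.0; 4.0 5.0 6.0]   # D = 2, three column vectors
Y = [1.0 0.0; 0.0 1.0]           # D = 2, two column vectors

P = pairwise_dot_cols(X, Y)      # 3 x 2, same values as X' * Y
c = colwise_dot_cols(X)          # squared norm of each column of X
```

The `RowVecs` methods follow the same pattern with the two indices of each matrix swapped, which corresponds to `X * Y'` for the pairwise case.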
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end
