From 57be075144cbbdb287cba00b2a250f615516cc13 Mon Sep 17 00:00:00 2001 From: ST John Date: Fri, 1 Apr 2022 19:53:32 +0300 Subject: [PATCH 1/3] support ScalMat in FiniteGP observation covariance --- src/LaplaceApproximationModule.jl | 8 ++++---- src/SparseVariationalApproximationModule.jl | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/LaplaceApproximationModule.jl b/src/LaplaceApproximationModule.jl index e9e968b6..beb8a257 100644 --- a/src/LaplaceApproximationModule.jl +++ b/src/LaplaceApproximationModule.jl @@ -179,16 +179,16 @@ function _check_laplace_inputs( end struct LaplaceCache{ - Tm<:AbstractMatrix,Tv<:AbstractVector,Td<:Diagonal,Tf<:Real,Tc<:Cholesky + Tm<:AbstractMatrix,Tv1<:AbstractVector,Tv2<:AbstractVector,Tv3<:AbstractVector,Td<:Diagonal,Tf<:Real,Tc<:Cholesky } K::Tm # kernel matrix - f::Tv # mode of posterior p(f | y) + f::Tv1 # mode of posterior p(f | y) W::Td # diagonal matrix of ∂²/∂fᵢ² loglik Wsqrt::Td # sqrt(W) loglik::Tf # ∑ᵢlog p(yᵢ|fᵢ) - d_loglik::Tv # ∂/∂fᵢloglik + d_loglik::Tv2 # ∂/∂fᵢloglik B_ch::Tc # cholesky(I + Wsqrt * K * Wsqrt) - a::Tv # K⁻¹ f + a::Tv3 # K⁻¹ f end function _laplace_train_intermediates(dist_y_given_f, ys, K, f) diff --git a/src/SparseVariationalApproximationModule.jl b/src/SparseVariationalApproximationModule.jl index 132de1ab..2a134b4b 100644 --- a/src/SparseVariationalApproximationModule.jl +++ b/src/SparseVariationalApproximationModule.jl @@ -9,7 +9,7 @@ using LinearAlgebra using Statistics using StatsBase using FillArrays: Fill -using PDMats: chol_lower +using PDMats: chol_lower, ScalMat using AbstractGPs: AbstractGPs using AbstractGPs: @@ -306,7 +306,7 @@ Statistics. PMLR, 2015. """ function AbstractGPs.elbo( sva::SparseVariationalApproximation, - fx::FiniteGP{<:AbstractGP,<:AbstractVector,<:Diagonal{<:Real,<:Fill}}, + fx::FiniteGP{<:AbstractGP,<:AbstractVector,<:Union{Diagonal{<:Real,<:Fill},ScalMat}}, y::AbstractVector{<:Real}; num_data=length(y), quadrature=DefaultExpectationMethod(), From 693147dc5a400762145368f26f091aae6fbba25f Mon Sep 17 00:00:00 2001 From: ST John Date: Fri, 1 Apr 2022 19:53:46 +0300 Subject: [PATCH 2/3] WIP: ProjectTo type piracy AD fix --- src/ApproximateGPs.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/ApproximateGPs.jl b/src/ApproximateGPs.jl index 029be317..1e1e01db 100644 --- a/src/ApproximateGPs.jl +++ b/src/ApproximateGPs.jl @@ -23,4 +23,10 @@ include("deprecations.jl") include("TestUtils.jl") +import ChainRulesCore: ProjectTo, Tangent +using PDMats: ScalMat +ProjectTo(x::T) where T <: ScalMat = ProjectTo{T}(; dim=x.dim, value=ProjectTo(x.value)) +(pr::ProjectTo{<:ScalMat})(dx::ScalMat) = ScalMat(pr.dim, pr.value(dx.value)) +(pr::ProjectTo{<:ScalMat})(dx::Tangent{<:ScalMat}) = ScalMat(pr.dim, pr.value(dx.value)) + end From 09044de4bb4a4de1aa3428943f67b13c0d2b9067 Mon Sep 17 00:00:00 2001 From: st-- Date: Sat, 2 Apr 2022 14:50:32 +0300 Subject: [PATCH 3/3] Update src/ApproximateGPs.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/ApproximateGPs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ApproximateGPs.jl b/src/ApproximateGPs.jl index 1e1e01db..97a9d29b 100644 --- a/src/ApproximateGPs.jl +++ b/src/ApproximateGPs.jl @@ -25,7 +25,7 @@ include("TestUtils.jl") import ChainRulesCore: ProjectTo, Tangent using PDMats: ScalMat -ProjectTo(x::T) where T <: ScalMat = ProjectTo{T}(; dim=x.dim, value=ProjectTo(x.value)) +ProjectTo(x::T) where {T<:ScalMat} = ProjectTo{T}(; dim=x.dim, value=ProjectTo(x.value)) (pr::ProjectTo{<:ScalMat})(dx::ScalMat) = ScalMat(pr.dim, pr.value(dx.value)) (pr::ProjectTo{<:ScalMat})(dx::Tangent{<:ScalMat}) = ScalMat(pr.dim, pr.value(dx.value))