From e2d94329fa48f54e2cbb678589413153b8a9158c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 12 Sep 2023 11:44:16 +0100 Subject: [PATCH 1/4] make adjoint error message a bit more informative --- src/chainrules.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 3b52860dd..66ec2d0ca 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -150,8 +150,9 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) return error( "Pullback on AbstractVector{<:AbstractVector}.\n" * - "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * - "To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`", + "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" * + "or because some other external computation has acted on `ColVecs` to produce a vector of vectors." * + "If it is the former, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`", ) end return ColVecs(X), ColVecs_pullback @@ -162,8 +163,9 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) return error( "Pullback on AbstractVector{<:AbstractVector}.\n" * - "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * - "To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`", + "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" * + "or because some other external computation has acted on `RowVecs` to produce a vector of vectors." * + "If it is the former, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`", ) end return RowVecs(X), RowVecs_pullback From a8dbcba9e49d210c3dd76c85e6d48680163d3d92 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 12 Sep 2023 11:50:45 +0100 Subject: [PATCH 2/4] Apply suggestions from code review --- src/chainrules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 66ec2d0ca..5292ca51c 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -151,7 +151,7 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) return error( "Pullback on AbstractVector{<:AbstractVector}.\n" * "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" * - "or because some other external computation has acted on `ColVecs` to produce a vector of vectors." * + "or because some external computation has acted on `ColVecs` to produce a vector of vectors." * "If it is the former, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`", ) end @@ -164,7 +164,7 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) return error( "Pullback on AbstractVector{<:AbstractVector}.\n" * "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" * - "or because some other external computation has acted on `RowVecs` to produce a vector of vectors." * + "or because some external computation has acted on `RowVecs` to produce a vector of vectors." * "If it is the former, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`", ) end From d6699508d7c637cfce50135f3c6bc4081f589479 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 12 Sep 2023 12:25:36 +0100 Subject: [PATCH 3/4] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 617f6e56c..2830de211 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.56" +version = "0.10.57" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 36d149ab04f7a674c9aa6307e1325c4d021d71c2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 27 Sep 2023 15:08:21 +0100 Subject: [PATCH 4/4] Update src/chainrules.jl --- src/chainrules.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 5292ca51c..eebdf95b5 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -152,7 +152,9 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) "Pullback on AbstractVector{<:AbstractVector}.\n" * "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" * "or because some external computation has acted on `ColVecs` to produce a vector of vectors." * - "If it is the former, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`", + "In the former case, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`." * + "In the latter case, one needs to track down the `rrule` whose pullback returns a `Vector{Vector{T}}`," * + " rather than a `Tangent`, as the cotangent / gradient for `ColVecs` input, and circumvent it." ) end return ColVecs(X), ColVecs_pullback