-
Notifications
You must be signed in to change notification settings - Fork 41
Fix gradient issues with kernelmatrix_diag and use ChainRulesCore #208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
50 commits
Select commit
Hold shift + click to select a range
e525614
Use broadcasting instead of map for kerneldiagmatrix
theogf e56492a
Removed method for transformedkernel
theogf 35a6306
Restored functions and applied suggestions
theogf 25e5efd
Added tests for diagmatrix
theogf 2f85ebc
Put changes to the right file and removed utils_AD.jl
theogf cae225f
Apply suggestions from code review
theogf 3f16f07
Added colwise and fixed kerneldiagmatrix
theogf 8c0d0a2
Added colwise for RowVecs and ColVecs
theogf 13a10fd
Removed definition relying on Distances.colwise!
theogf 78a2078
Merge branch 'master' into fix_diagmat
theogf 5ca94e7
Readapt to kernelmatrix_diag
theogf 2c60abd
Fixes for Zygote
theogf 9214211
Remove type piracy
theogf 87edbc8
Adding some adjoints (not everything fixed yet)
theogf f65556b
Fixed adjoint for polynomials
theogf 48e2dcb
Add ChainRulesCore for defining rrule
theogf 6cc803d
Replace broadcast by map
theogf 0e30941
Missing return for style
theogf 61869b1
Fixing ZygoteRules
theogf 06bd4f0
Renamed zygote_adjoints to chainrules
theogf 8e1e516
Apply formatting suggestions
theogf aaa16de
Added forward rule for Euclidean distance
theogf 52b1ae5
Corrected rules for Row/ColVecs constructors
theogf 4067a42
Added ZygoteRules back for the "map hack"
theogf 641ebee
Corrected the rrules
theogf 13d1e39
Type stable frule
theogf 4675c2f
Corrected tests
theogf 0b97c1a
Adapted the use of Distances.jl
theogf ad9838e
Added methods to make nn work
theogf 650dc08
Missing kernelmatrix_diag
theogf 1703db1
Formatting suggestions
theogf e2cd167
Added methods for FBM
theogf 01ffac0
Last fix on Delta
theogf 9bfb6eb
Potential fix for Euclidean
theogf f3fa4bc
Missing Distances.
theogf a0c2a64
Wrong file naming
theogf ff5a66b
Correct formatting
theogf 8157b4c
Better error message
theogf e6bfdb1
Moar formatting
theogf db5e7b8
Applied suggestions
theogf a44a762
Fixed the dims issue with pairwise
theogf 72889dd
Fixed formatting
theogf 25549c1
Missing @thunk
theogf bbe5c7c
Putting back Composite to Any
theogf e08dbf4
add @thunk for -delta a
theogf 48bd681
Update src/chainrules.jl
theogf 3298d34
Update KernelFunctions.jl
theogf 0b99771
Apply suggestions from code review
theogf c26edf3
Update Project.toml
theogf 647862a
Merge branch 'master' into fix_diagmat
theogf File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
|
||
const FDM = FiniteDifferences.central_fdm(5, 1) | ||
devmotion marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
gradient(f, s::Symbol, args) = gradient(f, Val(s), args) | ||
|
||
function gradient(f, ::Val{:Zygote}, args) | ||
g = first(Zygote.gradient(f, args)) | ||
if isnothing(g) | ||
if args isa AbstractArray{<:Real} | ||
return zeros(size(args)) # To respect the same output as other ADs | ||
else | ||
return zeros.(size.(args)) | ||
end | ||
else | ||
return g | ||
end | ||
end | ||
|
||
function gradient(f, ::Val{:ForwardDiff}, args) | ||
ForwardDiff.gradient(f, args) | ||
end | ||
|
||
function gradient(f, ::Val{:ReverseDiff}, args) | ||
ReverseDiff.gradient(f, args) | ||
end | ||
|
||
function gradient(f, ::Val{:FiniteDiff}, args) | ||
first(FiniteDifferences.grad(FDM, f, args)) | ||
end | ||
|
||
function compare_gradient(f, AD::Symbol, args) | ||
grad_AD = gradient(f, AD, args) | ||
grad_FD = gradient(f, :FiniteDiff, args) | ||
@test grad_AD ≈ grad_FD atol=1e-8 rtol=1e-5 | ||
end | ||
|
||
testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B, obsdim = dim)) | ||
testfunction(k, A, dim) = sum(kernelmatrix(k, A, obsdim = dim)) | ||
testdiagfunction(k, A, dim) = sum(kerneldiagmatrix(k, A, obsdim = dim)) | ||
testdiagfunction(k, A, B, dim) = sum(kerneldiagmatrix(k, A, B, obsdim = dim)) | ||
|
||
function test_ADs(kernelfunction, args = nothing; ADs = [:Zygote, :ForwardDiff, :ReverseDiff], dims = [3, 3]) | ||
test_fd = test_FiniteDiff(kernelfunction, args, dims) | ||
if !test_fd.anynonpass | ||
for AD in ADs | ||
test_AD(AD, kernelfunction, args, dims) | ||
end | ||
end | ||
end | ||
|
||
function test_FiniteDiff(kernelfunction, args = nothing, dims = [3, 3]) | ||
# Init arguments : | ||
k = if args === nothing | ||
kernelfunction() | ||
else | ||
kernelfunction(args) | ||
end | ||
rng = MersenneTwister(42) | ||
@testset "FiniteDifferences" begin | ||
if k isa SimpleKernel | ||
for d in log.([eps(), rand(rng)]) | ||
@test_nowarn gradient(:FiniteDiff, [d]) do x | ||
kappa(k, exp(first(x))) | ||
end | ||
end | ||
end | ||
## Testing Kernel Functions | ||
x = rand(rng, dims[1]) | ||
y = rand(rng, dims[1]) | ||
@test_nowarn gradient(:FiniteDiff, x) do x | ||
k(x, y) | ||
end | ||
if !(args === nothing) | ||
@test_nowarn gradient(:FiniteDiff, args) do p | ||
kernelfunction(p)(x, y) | ||
end | ||
end | ||
## Testing Kernel Matrices | ||
A = rand(rng, dims...) | ||
B = rand(rng, dims...) | ||
for dim in 1:2 | ||
@test_nowarn gradient(:FiniteDiff, A) do a | ||
testfunction(k, a, dim) | ||
end | ||
@test_nowarn gradient(:FiniteDiff , A) do a | ||
testfunction(k, a, B, dim) | ||
end | ||
@test_nowarn gradient(:FiniteDiff, B) do b | ||
testfunction(k, A, b, dim) | ||
end | ||
if !(args === nothing) | ||
@test_nowarn gradient(:FiniteDiff, args) do p | ||
testfunction(kernelfunction(p), A, B, dim) | ||
end | ||
end | ||
|
||
@test_nowarn gradient(:FiniteDiff, A) do a | ||
testdiagfunction(k, a, dim) | ||
end | ||
@test_nowarn gradient(:FiniteDiff , A) do a | ||
testdiagfunction(k, a, B, dim) | ||
end | ||
@test_nowarn gradient(:FiniteDiff, B) do b | ||
testdiagfunction(k, A, b, dim) | ||
end | ||
if !(args === nothing) | ||
@test_nowarn gradient(:FiniteDiff, args) do p | ||
testdiagfunction(kernelfunction(p), A, B, dim) | ||
end | ||
end | ||
end | ||
end | ||
end | ||
|
||
function test_AD(AD::Symbol, kernelfunction, args = nothing, dims = [3, 3]) | ||
@testset "$(AD)" begin | ||
# Test kappa function | ||
k = if args === nothing | ||
kernelfunction() | ||
else | ||
kernelfunction(args) | ||
end | ||
rng = MersenneTwister(42) | ||
if k isa SimpleKernel | ||
for d in log.([eps(), rand(rng)]) | ||
compare_gradient(AD, [d]) do x | ||
kappa(k, exp(x[1])) | ||
end | ||
end | ||
end | ||
# Testing kernel evaluations | ||
x = rand(rng, dims[1]) | ||
y = rand(rng, dims[1]) | ||
compare_gradient(AD, x) do x | ||
k(x, y) | ||
end | ||
compare_gradient(AD, y) do y | ||
k(x, y) | ||
end | ||
if !(args === nothing) | ||
compare_gradient(AD, args) do p | ||
kernelfunction(p)(x,y) | ||
end | ||
end | ||
# Testing kernel matrices | ||
A = rand(rng, dims...) | ||
B = rand(rng, dims...) | ||
for dim in 1:2 | ||
compare_gradient(AD, A) do a | ||
testfunction(k, a, dim) | ||
end | ||
compare_gradient(AD, A) do a | ||
testfunction(k, a, B, dim) | ||
end | ||
compare_gradient(AD, B) do b | ||
testfunction(k, A, b, dim) | ||
end | ||
if !(args === nothing) | ||
compare_gradient(AD, args) do p | ||
testfunction(kernelfunction(p), A, dim) | ||
end | ||
end | ||
|
||
compare_gradient(AD, A) do a | ||
testdiagfunction(k, a, dim) | ||
end | ||
compare_gradient(AD, A) do a | ||
testdiagfunction(k, a, B, dim) | ||
end | ||
compare_gradient(AD, B) do b | ||
testdiagfunction(k, A, b, dim) | ||
end | ||
if !(args === nothing) | ||
compare_gradient(AD, args) do p | ||
testdiagfunction(kernelfunction(p), A, dim) | ||
end | ||
end | ||
end | ||
end | ||
end |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.