Skip to content

Commit f972ff4

Browse files
authored
add chainrules (#9)
1 parent f3c03c4 commit f972ff4

File tree

4 files changed

+58
-4
lines changed

4 files changed

+58
-4
lines changed

Project.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
name = "FastChebInterp"
22
uuid = "cf66c380-9a80-432c-aff8-4f9c79c0bdde"
3-
version = "1.0"
3+
version = "1.1"
44

55
[deps]
66
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
77
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89

910
[compat]
11+
ChainRulesCore = "1"
12+
ChainRulesTestUtils = "1"
1013
FFTW = "1.0"
1114
StaticArrays = "0.12, 1.0"
1215
julia = "1.3"
1316

1417
[extras]
18+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
19+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1520
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1621

1722
[targets]
18-
test = ["Test"]
23+
test = ["ChainRulesTestUtils", "Random", "Test"]

src/FastChebInterp.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ A multidimensional Chebyshev-polynomial interpolation object.
3434
Given a `c::ChebPoly`, you can evaluate it at a point `x`
3535
with `c(x)`, where `x` is a vector (or a scalar if `c` is 1d).
3636
"""
37-
struct ChebPoly{N,T,Td<:Real}
37+
struct ChebPoly{N,T,Td<:Real} <: Function
3838
coefs::Array{T,N} # chebyshev coefficients
3939
lb::SVector{N,Td} # lower/upper bounds
4040
ub::SVector{N,Td} # of the domain
@@ -48,9 +48,11 @@ function Base.show(io::IO, c::ChebPoly)
4848
end
4949
end
5050
Base.ndims(c::ChebPoly) = ndims(c.coefs)
51+
Base.zero(c::ChebPoly{N,T,Td}) where {N,T,Td} = ChebPoly{N,T,Td}(zero(c.coefs), c.lb, c.ub)
5152

5253
include("interp.jl")
5354
include("regression.jl")
5455
include("eval.jl")
56+
include("chainrules.jl")
5557

5658
end # module

src/chainrules.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import ChainRulesCore
2+
using ChainRulesCore: ProjectTo, NoTangent, @not_implemented
3+
4+
function ChainRulesCore.rrule(c::ChebPoly{1}, x::Real)
5+
project_x = ProjectTo(x)
6+
y, ∇y = chebgradient(c, x)
7+
chebpoly_pullback(∂y) = @not_implemented("no rrule for changes in ChebPoly itself"), project_x(real(∇y' * ∂y))
8+
y, chebpoly_pullback
9+
end
10+
11+
function ChainRulesCore.rrule(c::ChebPoly, x::AbstractVector{<:Real})
12+
project_x = ProjectTo(x)
13+
y, J = chebjacobian(c, x)
14+
chebpoly_pullback(Δy) = @not_implemented("no rrule for changes in ChebPoly itself"), project_x(vec(real(J' * Δy)))
15+
y, chebpoly_pullback
16+
end
17+
18+
ChainRulesCore.frule((Δself, Δx), c::ChebPoly{1}, x::Real) =
19+
ChainRulesCore.frule((Δself, SVector{1}(Δx)), c, SVector{1}(x))
20+
21+
function ChainRulesCore.frule((Δself, Δx), c::ChebPoly, x::AbstractVector)
22+
y, J = chebjacobian(c, x)
23+
if Δself isa ChainRulesCore.AbstractZero # Δself == 0
24+
Δy = J * Δx
25+
return y, y isa Number ? Δy[1] : Δy
26+
else # need derivatives with respect to changes in c
27+
# additional Δx from changes in bound:
28+
# --- recall x0 = @. (x - c.lb) * 2 / (c.ub - c.lb) - 1,
29+
# but note that J already includes 2 / (c.ub - c.lb)
30+
d2 = @. (x - c.lb) / (c.ub - c.lb)
31+
Δx′ = @. Δx + (d2 - 1) * Δself.lb - d2 * Δself.ub
32+
Δy = J * Δx′
33+
34+
# dependence on coefs is linear
35+
Δcoefs = typeof(c)(Δself.coefs, c.lb, c.ub)
36+
37+
return y, (y isa Number ? Δy[1] : Δy) + Δcoefs(x)
38+
end
39+
end

test/runtests.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
using Test, FastChebInterp, StaticArrays
1+
using Test, FastChebInterp, StaticArrays, Random, ChainRulesTestUtils
22

33
# similar to ≈, but acts elementwise on tuples
44
′(a::Tuple, b::Tuple; kws...) where {N} = length(a) == length(b) && all(xy -> isapprox(xy[1],xy[2]; kws...), zip(a,b))
55

6+
Random.seed!(314159) # make chainrules tests deterministic
7+
68
@testset "1d test" begin
79
lb,ub = -0.3, 0.9
810
f(x) = exp(x) / (1 + 2x^2)
@@ -14,6 +16,8 @@ using Test, FastChebInterp, StaticArrays
1416
x1 = 0.2
1517
@test interp(x1) f(x1)
1618
@test chebgradient(interp, x1) ′ (f(x1), f′(x1))
19+
test_frule(interp, x1)
20+
test_rrule(interp, x1)
1721
end
1822

1923
@testset "2d test" begin
@@ -29,6 +33,8 @@ end
2933
@test interp(x1) interp0(x1) rtol=1e-15
3034
@test all(n -> n[1] < n[2], zip(size(interp.coefs), size(interp0.coefs)))
3135
@test chebgradient(interp, x1) ′ (f(x1), ∇f(x1))
36+
test_frule(interp, x1)
37+
test_rrule(interp, x1)
3238

3339
# univariate function in 2d should automatically drop down to univariate polynomial
3440
f1(x) = exp(x[1]) / (1 + 2x[1]^2)
@@ -42,6 +48,8 @@ end
4248
interp2 = chebinterp(f2.(x), lb, ub)
4349
@test interp2(x1) f2(x1)
4450
@test chebjacobian(interp2, x1) ′ (f2(x1), ∇f2(x1))
51+
test_frule(interp2, x1)
52+
test_rrule(interp2, x1)
4553

4654
# chebinterp_v1
4755
av1 = Array{ComplexF64}(undef, 2, size(x)...)

0 commit comments

Comments
 (0)