Skip to content

Commit 942ca6d

Browse files
trahflow and ToucheSir authored
drop adjoints for [i,r,b]fft() (#1386)
* drop adjoints for [i,r,b]fft()

  Partially addresses #1377. ChainRules for these have been added in JuliaMath/AbstractFFTs.jl#58

* add back gradient test for *fft without dims argument
* increase compat constraint for AbstractFFTs to 1.3.1
* fix typo

  Co-authored-by: Brian Chen <[email protected]>

---------

Co-authored-by: Brian Chen <[email protected]>
1 parent fb93ba5 commit 942ca6d

File tree

3 files changed

+48
-179
lines changed

3 files changed

+48
-179
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2727
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2828

2929
[compat]
30-
AbstractFFTs = "0.5, 1.0"
30+
AbstractFFTs = "1.3.1"
3131
ChainRules = "1.44.1"
3232
ChainRulesCore = "1.9"
3333
ChainRulesTestUtils = "1"

src/lib/array.jl

Lines changed: 0 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -665,12 +665,6 @@ AbstractFFTs.brfft(x::Fill, d, dims...) = AbstractFFTs.brfft(collect(x), d, dims
665665

666666
# the adjoint jacobian of an FFT with respect to its input is the reverse FFT of the
667667
# gradient of its inputs, but with different normalization factor
668-
@adjoint function fft(xs)
669-
return AbstractFFTs.fft(xs), function(Δ)
670-
return (AbstractFFTs.bfft(Δ),)
671-
end
672-
end
673-
674668
@adjoint function *(P::AbstractFFTs.Plan, xs)
675669
return P * xs, function(Δ)
676670
N = prod(size(xs)[[P.region...]])
@@ -685,123 +679,6 @@ end
685679
end
686680
end
687681

688-
# all of the plans normalize their inverse, while we need the unnormalized one.
689-
@adjoint function ifft(xs)
690-
return AbstractFFTs.ifft(xs), function(Δ)
691-
N = length(xs)
692-
return (AbstractFFTs.fft(Δ)/N,)
693-
end
694-
end
695-
696-
@adjoint function bfft(xs)
697-
return AbstractFFTs.bfft(xs), function(Δ)
698-
return (AbstractFFTs.fft(Δ),)
699-
end
700-
end
701-
702-
@adjoint function fftshift(x)
703-
return fftshift(x), function(Δ)
704-
return (ifftshift(Δ),)
705-
end
706-
end
707-
708-
@adjoint function ifftshift(x)
709-
return ifftshift(x), function(Δ)
710-
return (fftshift(Δ),)
711-
end
712-
end
713-
714-
715-
# to actually use rfft, one needs to insure that everything
716-
# that happens in the Fourier domain could've been done in
717-
# the space domain with real numbers. This means enforcing
718-
# conjugate symmetry along all transformed dimensions besides
719-
# the first. Otherwise this is going to result in *very* weird
720-
# behavior.
721-
@adjoint function rfft(xs::AbstractArray{<:Real})
722-
return AbstractFFTs.rfft(xs), function(Δ)
723-
N = length(Δ)
724-
originalSize = size(xs,1)
725-
return (AbstractFFTs.brfft(Δ, originalSize),)
726-
end
727-
end
728-
729-
@adjoint function irfft(xs, d)
730-
return AbstractFFTs.irfft(xs, d), function(Δ)
731-
total = length(Δ)
732-
fullTransform = AbstractFFTs.rfft(real.(Δ))/total
733-
return (fullTransform, nothing)
734-
end
735-
end
736-
737-
@adjoint function brfft(xs, d)
738-
return AbstractFFTs.brfft(xs, d), function(Δ)
739-
fullTransform = AbstractFFTs.rfft(real.(Δ))
740-
return (fullTransform, nothing)
741-
end
742-
end
743-
744-
745-
# if we're specifying the dimensions
746-
@adjoint function fft(xs, dims)
747-
return AbstractFFTs.fft(xs, dims), function(Δ)
748-
# dims can be int, array or tuple,
749-
# convert to collection for use as index
750-
dims = collect(dims)
751-
return (AbstractFFTs.bfft(Δ, dims), nothing)
752-
end
753-
end
754-
755-
@adjoint function bfft(xs, dims)
756-
return AbstractFFTs.ifft(xs, dims), function(Δ)
757-
dims = collect(dims)
758-
return (AbstractFFTs.fft(Δ, dims),nothing)
759-
end
760-
end
761-
762-
@adjoint function ifft(xs, dims)
763-
return AbstractFFTs.ifft(xs, dims), function(Δ)
764-
dims = collect(dims)
765-
N = prod(collect(size(xs))[dims])
766-
return (AbstractFFTs.fft(Δ, dims)/N,nothing)
767-
end
768-
end
769-
770-
@adjoint function rfft(xs, dims)
771-
return AbstractFFTs.rfft(xs, dims), function(Δ)
772-
dims = collect(dims)
773-
N = prod(collect(size(xs))[dims])
774-
return (N * AbstractFFTs.irfft(Δ, size(xs,dims[1]), dims), nothing)
775-
end
776-
end
777-
778-
@adjoint function irfft(xs, d, dims)
779-
return AbstractFFTs.irfft(xs, d, dims), function(Δ)
780-
dims = collect(dims)
781-
N = prod(collect(size(xs))[dims])
782-
return (AbstractFFTs.rfft(real.(Δ), dims)/N, nothing, nothing)
783-
end
784-
end
785-
@adjoint function brfft(xs, d, dims)
786-
return AbstractFFTs.brfft(xs, d, dims), function(Δ)
787-
dims = collect(dims)
788-
return (AbstractFFTs.rfft(real.(Δ), dims), nothing, nothing)
789-
end
790-
end
791-
792-
793-
@adjoint function fftshift(x, dims)
794-
return fftshift(x), function(Δ)
795-
return (ifftshift(Δ, dims), nothing)
796-
end
797-
end
798-
799-
@adjoint function ifftshift(x, dims)
800-
return ifftshift(x), function(Δ)
801-
return (fftshift(Δ, dims), nothing)
802-
end
803-
end
804-
805682
# FillArray functionality
806683
# =======================
807684

test/gradcheck.jl

Lines changed: 47 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,16 +1621,15 @@ end
16211621

16221622
@testset "AbstractFFTs" begin
16231623

1624-
# Many of these tests check a complex gradient to a function with real input. This is now
1625-
# clamped to real by ProjectTo, but to run the old tests, use here the old gradient function:
1626-
function oldgradient(f, args...)
1627-
y, back = Zygote.pullback(f, args...)
1628-
back(Zygote.sensitivity(y))
1629-
end
1630-
# Eventually these rules and tests will be moved to ChainRules.jl, at which point the tests
1631-
# can be updated to use real / complex consistently.
1624+
# Eventually these rules and tests will be moved to AbstractFFTs.jl
1625+
# Rules for direct invocation of [i,r,b]fft have already been defined in
16321626
# https://github.com/JuliaMath/AbstractFFTs.jl/pull/58
16331627

1628+
# ChainRules involving AbstractFFTs.Plan are not yet part of AbstractFFTs,
1629+
# but there is a WIP PR:
1630+
# https://github.com/JuliaMath/AbstractFFTs.jl/pull/67
1631+
# After the above is merged, this testset can probably be removed entirely.
1632+
16341633
findicateMat(i,j,n1,n2) = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:n1,
16351634
l=1:n2]
16361635
mirrorIndex(i,N) = i - 2*max(0,i - (N>>1+1))
@@ -1643,45 +1642,41 @@ end
16431642
indicateMat = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:size(X, 1),
16441643
l=1:size(X,2)]
16451644
# gradient of ifft(fft) must be (approximately) 1 (for various cases)
1646-
@test oldgradient((X)->real.(ifft(fft(X))[i, j]), X)[1] ≈ indicateMat
1645+
@test gradient((X)->real.(ifft(fft(X))[i, j]), X)[1] ≈ indicateMat
16471646
# same for the inverse
1648-
@test oldgradient((X̂)->real.(fft(ifft(X̂))[i, j]), X̂)[1] ≈ indicateMat
1647+
@test gradient((X̂)->real.(fft(ifft(X̂))[i, j]), X̂)[1] ≈ indicateMat
16491648
# same for rfft(irfft)
1650-
@test oldgradient((X)->real.(irfft(rfft(X), size(X,1)))[i, j], X)[1] ≈ real.(indicateMat)
1651-
# rfft isn't actually surjective, so rffft(irfft) can't really be tested this way.
1649+
@test gradient((X)->real.(irfft(rfft(X), size(X,1)))[i, j], X)[1] ≈ real.(indicateMat)
1650+
# rfft isn't actually surjective, so rfft(irfft) can't really be tested this way.
16521651

16531652
# the gradients are actually just evaluating the inverse transform on the
16541653
# indicator matrix
16551654
mirrorI = mirrorIndex(i,sizeX[1])
16561655
FreqIndMat = findicateMat(mirrorI, j, size(X̂r,1), sizeX[2])
1657-
listOfSols = [(fft, bfft(indicateMat), bfft(indicateMat*im),
1658-
plan_fft(X), i, X),
1659-
(ifft, 1/N*fft(indicateMat), 1/N*fft(indicateMat*im),
1660-
plan_fft(X), i, X),
1661-
(bfft, fft(indicateMat), fft(indicateMat*im), nothing, i,
1662-
X),
1663-
(rfft, real.(brfft(FreqIndMat, sizeX[1])),
1664-
real.(brfft(FreqIndMat*im, sizeX[1])), plan_rfft(X),
1665-
mirrorI, X),
1666-
((K)->(irfft(K,sizeX[1])), 1/N * rfft(indicateMat),
1667-
zeros(size(X̂r)), plan_rfft(X), i, X̂r)]
1668-
for (trans, solRe, solIm, P, mI, evalX) in listOfSols
1669-
@test oldgradient((X)->real.(trans(X))[mI, j], evalX)[1] ≈
1656+
listOfSols = [(X -> fft(X, (1, 2)), real(bfft(indicateMat)), real(bfft(indicateMat*im)),
1657+
plan_fft(X), i, X, true),
1658+
(K -> ifft(K, (1, 2)), 1/N*real(fft(indicateMat)), 1/N*real(fft(indicateMat*im)),
1659+
plan_fft(X), i, X, false),
1660+
(X -> bfft(X, (1, 2)), real(fft(indicateMat)), real(fft(indicateMat*im)), nothing, i,
1661+
X, false),
1662+
]
1663+
for (trans, solRe, solIm, P, mI, evalX, fft_or_rfft) in listOfSols
1664+
@test gradient((X)->real.(trans(X))[mI, j], evalX)[1] ≈
16701665
solRe
1671-
@test oldgradient((X)->imag.(trans(X))[mI, j], evalX)[1] ≈
1666+
@test gradient((X)->imag.(trans(X))[mI, j], evalX)[1] ≈
16721667
solIm
1673-
if typeof(P) <:AbstractFFTs.Plan && maximum(trans .== [fft,rfft])
1674-
@test oldgradient((X)->real.(P * X)[mI, j], evalX)[1] ≈
1668+
if typeof(P) <:AbstractFFTs.Plan && fft_or_rfft
1669+
@test gradient((X)->real.(P * X)[mI, j], evalX)[1] ≈
16751670
solRe
1676-
@test oldgradient((X)->imag.(P * X)[mI, j], evalX)[1] ≈
1671+
@test gradient((X)->imag.(P * X)[mI, j], evalX)[1] ≈
16771672
solIm
16781673
elseif typeof(P) <: AbstractFFTs.Plan
1679-
@test oldgradient((X)->real.(P \ X)[mI, j], evalX)[1] ≈
1674+
@test gradient((X)->real.(P \ X)[mI, j], evalX)[1] ≈
16801675
solRe
16811676
# for whatever reason the rfft_plan doesn't handle this case well,
16821677
# even though irfft does
16831678
if eltype(evalX) <: Real
1684-
@test oldgradient((X)->imag.(P \ X)[mI, j], evalX)[1] ≈
1679+
@test gradient((X)->imag.(P \ X)[mI, j], evalX)[1] ≈
16851680
solIm
16861681
end
16871682
end
@@ -1692,47 +1687,44 @@ end
16921687
x = [-0.353213 -0.789656 -0.270151; -0.95719 -1.27933 0.223982]
16931688
# check ffts for individual dimensions
16941689
for trans in (fft, ifft, bfft)
1695-
@test oldgradient((x)->sum(abs.(trans(x))), x)[1] ≈
1696-
oldgradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
1690+
@test gradient((x)->sum(abs.(trans(x, (1, 2)))), x)[1] ≈
1691+
gradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
16971692
# switch sum abs order
1698-
@test oldgradient((x)->abs(sum((trans(x)))),x)[1] ≈
1699-
oldgradient( (x) -> abs(sum(trans(trans(x,1),2))), x)[1]
1693+
@test gradient((x)->abs(sum((trans(x)))),x)[1] ≈
1694+
gradient( (x) -> abs(sum(trans(trans(x,1),2))), x)[1]
17001695
# dims parameter for the function
1701-
@test oldgradient((x, dims)->sum(abs.(trans(x,dims))), x, (1,2))[1] ≈
1702-
oldgradient( (x) -> sum(abs.(trans(x))), x)[1]
1703-
# (1,2) should be the same as no index
1704-
@test oldgradient( (x) -> sum(abs.(trans(x,(1,2)))), x)[1] ≈
1705-
oldgradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
1706-
@test gradcheck(x->sum(abs.(trans(x))), x)
1696+
@test gradient((x, dims)->sum(abs.(trans(x,dims))), x, (1,2))[1] ≈
1697+
gradient( (x) -> sum(abs.(trans(x, (1, 2)))), x)[1]
1698+
@test gradcheck(x->sum(abs.(trans(x, (1, 2)))), x)
17071699
@test gradcheck(x->sum(abs.(trans(x, 2))), x)
17081700
end
17091701

1710-
@test oldgradient((x)->sum(abs.(rfft(x))), x)[1] ≈
1711-
oldgradient( (x) -> sum(abs.(fft(rfft(x,1),2))), x)[1]
1712-
@test oldgradient((x, dims)->sum(abs.(rfft(x,dims))), x, (1,2))[1] ≈
1713-
oldgradient( (x) -> sum(abs.(rfft(x))), x)[1]
1702+
@test gradient((x)->sum(abs.(rfft(x, (1, 2)))), x)[1] ≈
1703+
gradient( (x) -> sum(abs.(fft(rfft(x,1),2))), x)[1]
1704+
@test gradient((x, dims)->sum(abs.(rfft(x,dims))), x, (1,2))[1] ≈
1705+
gradient( (x) -> sum(abs.(rfft(x, (1, 2)))), x)[1]
17141706

17151707
# Test type stability of fft
17161708

17171709
x = randn(Float64,16)
17181710
P = plan_fft(x)
1719-
@test typeof(oldgradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float64},1}
1720-
@test typeof(oldgradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float64},1}
1721-
@test typeof(oldgradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float64,1}
1711+
@test typeof(gradient(x->sum(abs2,ifft(fft(x, 1), 1)),x)[1]) == Array{Float64,1}
1712+
@test typeof(gradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Float64,1}
1713+
@test typeof(gradient(x->sum(abs2,irfft(rfft(x, 1),16, 1)),x)[1]) == Array{Float64,1}
17221714

17231715
x = randn(Float64,16,16)
1724-
@test typeof(oldgradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2}
1725-
@test typeof(oldgradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
1716+
@test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Float64,2}
1717+
@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
17261718

17271719
x = randn(Float32,16)
17281720
P = plan_fft(x)
1729-
@test typeof(oldgradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float32},1}
1730-
@test typeof(oldgradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float32},1}
1731-
@test typeof(oldgradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1}
1721+
@test typeof(gradient(x->sum(abs2,ifft(fft(x, 1), 1)),x)[1]) == Array{Float32,1}
1722+
@test typeof(gradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Float32,1}
1723+
@test typeof(gradient(x->sum(abs2,irfft(rfft(x, 1),16, 1)),x)[1]) == Array{Float32,1}
17321724

17331725
x = randn(Float32,16,16)
1734-
@test typeof(oldgradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
1735-
@test typeof(oldgradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2}
1726+
@test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Float32,2}
1727+
@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2}
17361728
end
17371729

17381730
@testset "FillArrays" begin

0 commit comments

Comments
 (0)