@@ -1621,16 +1621,15 @@ end
1621
1621
1622
1622
@testset " AbstractFFTs" begin
1623
1623
1624
- # Many of these tests check a complex gradient to a function with real input. This is now
1625
- # clamped to real by ProjectTo, but to run the old tests, use here the old gradient function:
1626
- function oldgradient (f, args... )
1627
- y, back = Zygote. pullback (f, args... )
1628
- back (Zygote. sensitivity (y))
1629
- end
1630
- # Eventually these rules and tests will be moved to ChainRules.jl, at which point the tests
1631
- # can be updated to use real / complex consistently.
1624
+ # Eventually these rules and tests will be moved to AbstractFFTs.jl
1625
+ # Rules for direct invocation of [i,r,b]fft have already been defined in
1632
1626
# https://github.com/JuliaMath/AbstractFFTs.jl/pull/58
1633
1627
1628
+ # ChainRules involving AbstractFFTs.Plan are not yet part of AbstractFFTs,
1629
+ # but there is a WIP PR:
1630
+ # https://github.com/JuliaMath/AbstractFFTs.jl/pull/67
1631
+ # After the above is merged, this testset can probably be removed entirely.
1632
+
1634
1633
findicateMat (i,j,n1,n2) = [(k== i) && (l== j) ? 1.0 : 0.0 for k= 1 : n1,
1635
1634
l= 1 : n2]
1636
1635
mirrorIndex (i,N) = i - 2 * max (0 ,i - (N>> 1 + 1 ))
@@ -1643,45 +1642,41 @@ end
1643
1642
indicateMat = [(k== i) && (l== j) ? 1.0 : 0.0 for k= 1 : size (X, 1 ),
1644
1643
l= 1 : size (X,2 )]
1645
1644
# gradient of ifft(fft) must be (approximately) 1 (for various cases)
1646
- @test oldgradient ((X)-> real .(ifft (fft (X))[i, j]), X)[1 ] ≈ indicateMat
1645
+ @test gradient ((X)-> real .(ifft (fft (X))[i, j]), X)[1 ] ≈ indicateMat
1647
1646
# same for the inverse
1648
- @test oldgradient ((X̂)-> real .(fft (ifft (X̂))[i, j]), X̂)[1 ] ≈ indicateMat
1647
+ @test gradient ((X̂)-> real .(fft (ifft (X̂))[i, j]), X̂)[1 ] ≈ indicateMat
1649
1648
# same for rfft(irfft)
1650
- @test oldgradient ((X)-> real .(irfft (rfft (X), size (X,1 )))[i, j], X)[1 ] ≈ real .(indicateMat)
1651
- # rfft isn't actually surjective, so rffft (irfft) can't really be tested this way.
1649
+ @test gradient ((X)-> real .(irfft (rfft (X), size (X,1 )))[i, j], X)[1 ] ≈ real .(indicateMat)
1650
+ # rfft isn't actually surjective, so rfft (irfft) can't really be tested this way.
1652
1651
1653
1652
# the gradients are actually just evaluating the inverse transform on the
1654
1653
# indicator matrix
1655
1654
mirrorI = mirrorIndex (i,sizeX[1 ])
1656
1655
FreqIndMat = findicateMat (mirrorI, j, size (X̂r,1 ), sizeX[2 ])
1657
- listOfSols = [(fft, bfft (indicateMat), bfft (indicateMat* im),
1658
- plan_fft (X), i, X),
1659
- (ifft, 1 / N* fft (indicateMat), 1 / N* fft (indicateMat* im),
1660
- plan_fft (X), i, X),
1661
- (bfft, fft (indicateMat), fft (indicateMat* im), nothing , i,
1662
- X),
1663
- (rfft, real .(brfft (FreqIndMat, sizeX[1 ])),
1664
- real .(brfft (FreqIndMat* im, sizeX[1 ])), plan_rfft (X),
1665
- mirrorI, X),
1666
- ((K)-> (irfft (K,sizeX[1 ])), 1 / N * rfft (indicateMat),
1667
- zeros (size (X̂r)), plan_rfft (X), i, X̂r)]
1668
- for (trans, solRe, solIm, P, mI, evalX) in listOfSols
1669
- @test oldgradient ((X)-> real .(trans (X))[mI, j], evalX)[1 ] ≈
1656
+ listOfSols = [(X -> fft (X, (1 , 2 )), real (bfft (indicateMat)), real (bfft (indicateMat* im)),
1657
+ plan_fft (X), i, X, true ),
1658
+ (K -> ifft (K, (1 , 2 )), 1 / N* real (fft (indicateMat)), 1 / N* real (fft (indicateMat* im)),
1659
+ plan_fft (X), i, X, false ),
1660
+ (X -> bfft (X, (1 , 2 )), real (fft (indicateMat)), real (fft (indicateMat* im)), nothing , i,
1661
+ X, false ),
1662
+ ]
1663
+ for (trans, solRe, solIm, P, mI, evalX, fft_or_rfft) in listOfSols
1664
+ @test gradient ((X)-> real .(trans (X))[mI, j], evalX)[1 ] ≈
1670
1665
solRe
1671
- @test oldgradient ((X)-> imag .(trans (X))[mI, j], evalX)[1 ] ≈
1666
+ @test gradient ((X)-> imag .(trans (X))[mI, j], evalX)[1 ] ≈
1672
1667
solIm
1673
- if typeof (P) <: AbstractFFTs.Plan && maximum (trans .== [fft,rfft])
1674
- @test oldgradient ((X)-> real .(P * X)[mI, j], evalX)[1 ] ≈
1668
+ if typeof (P) <: AbstractFFTs.Plan && fft_or_rfft
1669
+ @test gradient ((X)-> real .(P * X)[mI, j], evalX)[1 ] ≈
1675
1670
solRe
1676
- @test oldgradient ((X)-> imag .(P * X)[mI, j], evalX)[1 ] ≈
1671
+ @test gradient ((X)-> imag .(P * X)[mI, j], evalX)[1 ] ≈
1677
1672
solIm
1678
1673
elseif typeof (P) <: AbstractFFTs.Plan
1679
- @test oldgradient ((X)-> real .(P \ X)[mI, j], evalX)[1 ] ≈
1674
+ @test gradient ((X)-> real .(P \ X)[mI, j], evalX)[1 ] ≈
1680
1675
solRe
1681
1676
# for whatever reason the rfft_plan doesn't handle this case well,
1682
1677
# even though irfft does
1683
1678
if eltype (evalX) <: Real
1684
- @test oldgradient ((X)-> imag .(P \ X)[mI, j], evalX)[1 ] ≈
1679
+ @test gradient ((X)-> imag .(P \ X)[mI, j], evalX)[1 ] ≈
1685
1680
solIm
1686
1681
end
1687
1682
end
@@ -1692,47 +1687,44 @@ end
1692
1687
x = [- 0.353213 - 0.789656 - 0.270151 ; - 0.95719 - 1.27933 0.223982 ]
1693
1688
# check ffts for individual dimensions
1694
1689
for trans in (fft, ifft, bfft)
1695
- @test oldgradient ((x)-> sum (abs .(trans (x))), x)[1 ] ≈
1696
- oldgradient ( (x) -> sum (abs .(trans (trans (x,1 ),2 ))), x)[1 ]
1690
+ @test gradient ((x)-> sum (abs .(trans (x, ( 1 , 2 ) ))), x)[1 ] ≈
1691
+ gradient ( (x) -> sum (abs .(trans (trans (x,1 ),2 ))), x)[1 ]
1697
1692
# switch sum abs order
1698
- @test oldgradient ((x)-> abs (sum ((trans (x)))),x)[1 ] ≈
1699
- oldgradient ( (x) -> abs (sum (trans (trans (x,1 ),2 ))), x)[1 ]
1693
+ @test gradient ((x)-> abs (sum ((trans (x)))),x)[1 ] ≈
1694
+ gradient ( (x) -> abs (sum (trans (trans (x,1 ),2 ))), x)[1 ]
1700
1695
# dims parameter for the function
1701
- @test oldgradient ((x, dims)-> sum (abs .(trans (x,dims))), x, (1 ,2 ))[1 ] ≈
1702
- oldgradient ( (x) -> sum (abs .(trans (x))), x)[1 ]
1703
- # (1,2) should be the same as no index
1704
- @test oldgradient ( (x) -> sum (abs .(trans (x,(1 ,2 )))), x)[1 ] ≈
1705
- oldgradient ( (x) -> sum (abs .(trans (trans (x,1 ),2 ))), x)[1 ]
1706
- @test gradcheck (x-> sum (abs .(trans (x))), x)
1696
+ @test gradient ((x, dims)-> sum (abs .(trans (x,dims))), x, (1 ,2 ))[1 ] ≈
1697
+ gradient ( (x) -> sum (abs .(trans (x, (1 , 2 )))), x)[1 ]
1698
+ @test gradcheck (x-> sum (abs .(trans (x, (1 , 2 )))), x)
1707
1699
@test gradcheck (x-> sum (abs .(trans (x, 2 ))), x)
1708
1700
end
1709
1701
1710
- @test oldgradient ((x)-> sum (abs .(rfft (x))), x)[1 ] ≈
1711
- oldgradient ( (x) -> sum (abs .(fft (rfft (x,1 ),2 ))), x)[1 ]
1712
- @test oldgradient ((x, dims)-> sum (abs .(rfft (x,dims))), x, (1 ,2 ))[1 ] ≈
1713
- oldgradient ( (x) -> sum (abs .(rfft (x))), x)[1 ]
1702
+ @test gradient ((x)-> sum (abs .(rfft (x, ( 1 , 2 ) ))), x)[1 ] ≈
1703
+ gradient ( (x) -> sum (abs .(fft (rfft (x,1 ),2 ))), x)[1 ]
1704
+ @test gradient ((x, dims)-> sum (abs .(rfft (x,dims))), x, (1 ,2 ))[1 ] ≈
1705
+ gradient ( (x) -> sum (abs .(rfft (x, ( 1 , 2 ) ))), x)[1 ]
1714
1706
1715
1707
# Test type stability of fft
1716
1708
1717
1709
x = randn (Float64,16 )
1718
1710
P = plan_fft (x)
1719
- @test typeof (oldgradient (x-> sum (abs2,ifft (fft (x) )),x)[1 ]) == Array{Complex{ Float64} ,1 }
1720
- @test typeof (oldgradient (x-> sum (abs2,P\ (P* x)),x)[1 ]) == Array{Complex{ Float64} ,1 }
1721
- @test typeof (oldgradient (x-> sum (abs2,irfft (rfft (x),16 )),x)[1 ]) == Array{Float64,1 }
1711
+ @test typeof (gradient (x-> sum (abs2,ifft (fft (x, 1 ), 1 )),x)[1 ]) == Array{Float64,1 }
1712
+ @test typeof (gradient (x-> sum (abs2,P\ (P* x)),x)[1 ]) == Array{Float64,1 }
1713
+ @test typeof (gradient (x-> sum (abs2,irfft (rfft (x, 1 ),16 , 1 )),x)[1 ]) == Array{Float64,1 }
1722
1714
1723
1715
x = randn (Float64,16 ,16 )
1724
- @test typeof (oldgradient (x-> sum (abs2,ifft (fft (x,1 ),1 )),x)[1 ]) == Array{Complex{ Float64} ,2 }
1725
- @test typeof (oldgradient (x-> sum (abs2,irfft (rfft (x,1 ),16 ,1 )),x)[1 ]) == Array{Float64,2 }
1716
+ @test typeof (gradient (x-> sum (abs2,ifft (fft (x,1 ),1 )),x)[1 ]) == Array{Float64,2 }
1717
+ @test typeof (gradient (x-> sum (abs2,irfft (rfft (x,1 ),16 ,1 )),x)[1 ]) == Array{Float64,2 }
1726
1718
1727
1719
x = randn (Float32,16 )
1728
1720
P = plan_fft (x)
1729
- @test typeof (oldgradient (x-> sum (abs2,ifft (fft (x) )),x)[1 ]) == Array{Complex{ Float32} ,1 }
1730
- @test typeof (oldgradient (x-> sum (abs2,P\ (P* x)),x)[1 ]) == Array{Complex{ Float32} ,1 }
1731
- @test typeof (oldgradient (x-> sum (abs2,irfft (rfft (x),16 )),x)[1 ]) == Array{Float32,1 }
1721
+ @test typeof (gradient (x-> sum (abs2,ifft (fft (x, 1 ), 1 )),x)[1 ]) == Array{Float32,1 }
1722
+ @test typeof (gradient (x-> sum (abs2,P\ (P* x)),x)[1 ]) == Array{Float32,1 }
1723
+ @test typeof (gradient (x-> sum (abs2,irfft (rfft (x, 1 ),16 , 1 )),x)[1 ]) == Array{Float32,1 }
1732
1724
1733
1725
x = randn (Float32,16 ,16 )
1734
- @test typeof (oldgradient (x-> sum (abs2,ifft (fft (x,1 ),1 )),x)[1 ]) == Array{Complex{ Float32} ,2 }
1735
- @test typeof (oldgradient (x-> sum (abs2,irfft (rfft (x,1 ),16 ,1 )),x)[1 ]) == Array{Float32,2 }
1726
+ @test typeof (gradient (x-> sum (abs2,ifft (fft (x,1 ),1 )),x)[1 ]) == Array{Float32,2 }
1727
+ @test typeof (gradient (x-> sum (abs2,irfft (rfft (x,1 ),16 ,1 )),x)[1 ]) == Array{Float32,2 }
1736
1728
end
1737
1729
1738
1730
@testset " FillArrays" begin
0 commit comments