Skip to content
4 changes: 4 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ AbstractFFTs.plan_brfft
AbstractFFTs.plan_irfft
AbstractFFTs.fftdims
Base.adjoint
AbstractFFTs.FFTAdjointStyle
AbstractFFTs.RFFTAdjointStyle
AbstractFFTs.IRFFTAdjointStyle
AbstractFFTs.UnitaryAdjointStyle
AbstractFFTs.fftshift
AbstractFFTs.fftshift!
AbstractFFTs.ifftshift
Expand Down
9 changes: 5 additions & 4 deletions docs/src/implementations.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ To define a new FFT implementation in your own module, you should

* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.

* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return:
* `AbstractFFTs.NoProjectionStyle()`,
* `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref),
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.
* We offer an experimental `AdjointStyle` trait to enable automatic computation of adjoint plans via [`Base.adjoint`](@ref).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the trait experimental? And what exactly is experimental? I guess downstream packages would be interested in what parts (if any) are considered stable, if every patch release may break things etc.

To support adjoints in a new plan, define the trait `AbstractFFTs.AdjointStyle(::MyPlan)`. This should return a subtype of `AS <: AbstractFFTs.AdjointStyle` supporting `AbstractFFTs.adjoint_mul(::Plan, ::AbstractArray, ::AS)` and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I.e., adjoint_mul is part of the API?

`AbstractFFTs._output_size(::Plan, ::AS)`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume _output_size should be implemented as well? Could we just rename it to output_size? To me the underscore suggests that it is an internal method.


`AbstractFFTs` pre-implements the following adjoint styles: [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), [`AbstractFFTs.IRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is meant with pre-implements? AFAICT the package implements these styles? Or is there anything missing?


The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.
93 changes: 68 additions & 25 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -583,17 +583,58 @@ plan_brfft

##############################################################################

struct NoProjectionStyle end
struct RealProjectionStyle end
struct RealInverseProjectionStyle
abstract type AdjointStyle end

"""
FFTAdjointStyle()

Projection style for complex to complex discrete Fourier transforms that normalize
the output analogously to [`fft`](@ref).

Since the Fourier transform is unitary up to a scaling, the adjoint simply applies
the transform's inverse with an appropriate scaling.
"""
struct FFTAdjointStyle <: AdjointStyle end

"""
RFFTAdjointStyle()

Projection style for real to complex discrete Fourier transforms that halve
one of the output's dimensions and normalize the output analogously to [`rfft`](@ref).

Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with additional logic to handle the fact that the output is projected
to exploit its conjugate symmetry (see [`rfft`](@ref)).
"""
struct RFFTAdjointStyle <: AdjointStyle end

"""
IRFFTAdjointStyle(d::Dim)

Projection style for complex to real discrete Fourier transforms that expect
an input with a halved dimension and normalize the output analogously to [`irfft`](@ref),
where `d` is the original length of the dimension.

Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with additional logic to handle the fact that the input is projected
to exploit its conjugate symmetry (see [`irfft`](@ref)).
"""
struct IRFFTAdjointStyle <: AdjointStyle
dim::Int
end
const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}

output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
"""
UnitaryAdjointStyle()

Projection style for unitary transforms, whose adjoint equals their inverse.
"""
struct UnitaryAdjointStyle <: AdjointStyle end

output_size(p::Plan) = _output_size(p, AdjointStyle(p))
_output_size(p::Plan, ::FFTAdjointStyle) = size(p)
_output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p))
_output_size(p::Plan, s::IRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
_output_size(p::Plan, ::UnitaryAdjointStyle) = size(p)

struct AdjointPlan{T,P<:Plan} <: Plan{T}
p::P
Expand All @@ -620,40 +661,42 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
size(p::AdjointPlan) = output_size(p.p)
output_size(p::AdjointPlan) = size(p.p)

Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x, AdjointStyle(p.p))

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
return (p.p \ x) / N
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(T, size(p), dims)
return (p \ x) / N
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
dims = fftdims(p)
N = normalization(T, size(p), dims)
halfdim = first(dims)
d = size(p.p, halfdim)
n = output_size(p.p, halfdim)
d = size(p, halfdim)
n = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return p.p \ (x ./ convert(typeof(x), scale))
return p \ (x ./ convert(typeof(x), scale))
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(real(T), output_size(p.p), dims)
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(real(T), output_size(p), dims)
halfdim = first(dims)
n = size(p.p, halfdim)
d = output_size(p.p, halfdim)
n = size(p, halfdim)
d = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return (convert(typeof(x), scale) ./ N) .* (p.p \ x)
return (convert(typeof(x), scale) ./ N) .* (p \ x)
end

adjoint_mul(p::Plan, x::AbstractArray, ::UnitaryAdjointStyle) = p \ x

# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
inv(p::AdjointPlan) = adjoint(inv(p.p))
10 changes: 5 additions & 5 deletions test/testplans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N
Base.size(p::InverseTestPlan) = p.sz
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N

AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle()
AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle()

function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
return TestPlan{T}(region, size(x))
Expand Down Expand Up @@ -110,8 +110,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}}
end
end

AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle()
AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d)
AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle()
AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.IRFFTAdjointStyle(p.d)

function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real}
return TestRPlan{T}(region, size(x))
Expand Down Expand Up @@ -241,7 +241,7 @@ end

Base.size(p::InplaceTestPlan) = size(p.plan)
Base.ndims(p::InplaceTestPlan) = ndims(p.plan)
AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan)
AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan)

function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...)
return InplaceTestPlan(plan_fft(x, region; kwargs...))
Expand Down