Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AbstractFFTs"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.1.0"
version = "1.2.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ To define a new FFT implementation in your own module, you should
inverse plan.

* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
`x` and some set of dimensions `region`.
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `region(p::MyPlan)` (which defaults to `p.region`).

* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to
0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`.
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ AbstractFFTs.brfft
AbstractFFTs.plan_rfft
AbstractFFTs.plan_brfft
AbstractFFTs.plan_irfft
AbstractFFTs.region
AbstractFFTs.fftshift
AbstractFFTs.ifftshift
AbstractFFTs.fftfreq
Expand Down
2 changes: 1 addition & 1 deletion src/AbstractFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import ChainRulesCore
export fft, ifft, bfft, fft!, ifft!, bfft!,
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq
region, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq

include("definitions.jl")
include("chainrules.jl")
Expand Down
14 changes: 14 additions & 0 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ size(p::Plan, d) = size(p)[d]
ndims(p::Plan) = length(size(p))
length(p::Plan) = prod(size(p))::Int

"""
region(p::Plan)

Return an iterable of the dimensions that are transformed by the FFT plan `p`.

# Implementation

The default definition of `region` returns `p.region`.
Hence this method should be implemented only for types of `Plan`s that do not store the transformed region in a field of name `region`.
"""
region(p::Plan) = p.region

fftfloat(x) = _fftfloat(float(x))
_fftfloat(::Type{T}) where {T<:BlasReal} = T
_fftfloat(::Type{Float16}) = Float32
Expand Down Expand Up @@ -243,6 +255,8 @@ ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)

size(p::ScaledPlan) = size(p.p)

region(p::ScaledPlan) = region(p.p)

show(io::IO, p::ScaledPlan) = print(io, p.scale, " * ", p.p)
summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p))

Expand Down
8 changes: 7 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,21 @@ end
@test eltype(P) === ComplexF64
@test P * x ≈ fftw_fft
@test P \ (P * x) ≈ x
@test AbstractFFTs.region(P) == dims

fftw_bfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft
P = plan_bfft(x, dims)
@test P * y ≈ fftw_bfft
@test P \ (P * y) ≈ y
@test AbstractFFTs.region(P) == dims

fftw_ifft = complex.(x)
@test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft
P = plan_ifft(x, dims)
@test P * y ≈ fftw_ifft
@test P \ (P * y) ≈ y
@test AbstractFFTs.region(P) == dims

# real FFT
fftw_rfft = fftw_fft[
Expand All @@ -84,18 +87,21 @@ end
@test eltype(P) === Int
@test P * x ≈ fftw_rfft
@test P \ (P * x) ≈ x
@test AbstractFFTs.region(P) == dims

fftw_brfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.brfft(ry, size(x, dims), dims) ≈ fftw_brfft
P = plan_brfft(ry, size(x, dims), dims)
@test P * ry ≈ fftw_brfft
@test P \ (P * ry) ≈ ry
@test AbstractFFTs.region(P) == dims

fftw_irfft = complex.(x)
@test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft
P = plan_irfft(ry, size(x, dims), dims)
@test P * ry ≈ fftw_irfft
@test P \ (P * ry) ≈ ry
@test AbstractFFTs.region(P) == dims
end
end

Expand Down Expand Up @@ -187,7 +193,7 @@ end
# normalization should be inferable even if region is only inferred as ::Any,
# need to wrap in another function to test this (note that p.region::Any for
# p::TestPlan)
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, p.region)
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, AbstractFFTs.region(p))
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
end

Expand Down