diff --git a/Project.toml b/Project.toml index 7177c71..5e2cdea 100644 --- a/Project.toml +++ b/Project.toml @@ -20,8 +20,8 @@ EllipsisNotation = "1" FFTW = "1" FourierTools = "0.4" IndexFunArrays = "0.2" -NDTools = "0.5, 0.6, 0.7" -Zygote = "0.6.60" +NDTools = "0.5, 0.6, 0.7, 0.8" +Zygote = "0.6.60, 0.7" julia = "1.9" [extras] diff --git a/src/angular_spectrum.jl b/src/angular_spectrum.jl index b3ccfc9..4dd6b51 100644 --- a/src/angular_spectrum.jl +++ b/src/angular_spectrum.jl @@ -28,7 +28,7 @@ function _prepare_angular_spectrum(field::AbstractArray{CT}, z, λ, _L; fieldp = padding ? pad(field, pad_factor2) : field # helpful propagation variables - (; k, f_x, f_y, x, y) = Zygote.@ignore _propagation_variables(fieldp, λ, Lp) + (; k, f_x, f_y, x, y) = ChainRulesCore.@ignore_derivatives _propagation_variables(fieldp, λ, Lp) # transfer function kernel of angular spectrum H = exp.(1im .* k .* abs.(z) .* sqrt.(CT(1) .- abs2.(f_x .* λ) .- abs2.(f_y .* λ))) @@ -40,13 +40,13 @@ function _prepare_angular_spectrum(field::AbstractArray{CT}, z, λ, _L; # as addition we introduce a smooth bandlimit with a Hann window # and fuzzy logic Δu = 1 ./ Lp - u_limit = Zygote.@ignore 1 ./ (sqrt.((2 .* Δu .* z).^2 .+ 1) .* λ) + u_limit = ChainRulesCore.@ignore_derivatives 1 ./ (sqrt.((2 .* Δu .* z).^2 .+ 1) .* λ) # y and x positions in real space, use correct spacing -> fftpos y1 = similar(field, real(eltype(field)), (size(field, 1), 1)) - Zygote.@ignore y1 .= (fftpos(L[1], size(field, 1), CenterFT)) + ChainRulesCore.@ignore_derivatives y1 .= (fftpos(L[1], size(field, 1), CenterFT)) x1 = similar(field, real(eltype(field)), (1, size(field, 2))) - Zygote.@ignore x1 .= (fftpos(L[2], size(field, 2), CenterFT))' + ChainRulesCore.@ignore_derivatives x1 .= (fftpos(L[2], size(field, 2), CenterFT))' params = Params(y1, x1, y1, x1, L, L) diff --git a/src/fraunhofer.jl b/src/fraunhofer.jl index 62ba299..1dc5d28 100644 --- a/src/fraunhofer.jl +++ b/src/fraunhofer.jl @@ -10,7 +10,7 @@ function fraunhofer(U, z, λ, L; skip_final_phase=true) L_new = λ * z / L * size(U, 1) Ns = size(U)[1:2] - p = Zygote.@ignore plan_fft(U, (1,2)) + p = ChainRulesCore.@ignore_derivatives plan_fft(U, (1,2)) if skip_final_phase out = fftshift(p * ifftshift(U)) ./ √(size(U, 1) * size(U, 2)) @@ -18,9 +18,9 @@ function fraunhofer(U, z, λ, L; skip_final_phase=true) k = eltype(U)(2π) / λ # output coordinates y = similar(U, real(eltype(U)), (Ns[1], 1)) - Zygote.@ignore y .= (fftpos(L, Ns[1], CenterFT)) + ChainRulesCore.@ignore_derivatives y .= (fftpos(L, Ns[1], CenterFT)) x = similar(U, real(eltype(U)), (1, Ns[2])) - Zygote.@ignore x .= (fftpos(L, Ns[2], CenterFT))' + ChainRulesCore.@ignore_derivatives x .= (fftpos(L, Ns[2], CenterFT))' phasefactor = (-1im) .* exp.(1im * k / (2 * z) .* (x.^2 .+ y.^2)) out = phasefactor .* fftshift(p * ifftshift(U)) ./ √(size(U, 1) * size(U, 2)) end diff --git a/src/shifted_angular_spectrum.jl b/src/shifted_angular_spectrum.jl index 6bca0bf..87c025e 100644 --- a/src/shifted_angular_spectrum.jl +++ b/src/shifted_angular_spectrum.jl @@ -26,7 +26,7 @@ function _prepare_shifted_angular_spectrum(field::AbstractArray{CT}, z, λ, L, field_new = padding ? pad(field, pad_factor2) : field # helpful propagation variables - (; k, f_x, f_y, x, y) = Zygote.@ignore _propagation_variables(field_new, λ, L_new) + (; k, f_x, f_y, x, y) = ChainRulesCore.@ignore_derivatives _propagation_variables(field_new, λ, L_new) H = exp.(1im .* k .* z .* (sqrt.(CT(1) .- abs2.(f_x .* λ .+ sxy[2]) .- abs2.(f_y .* λ .+ sxy[1])) @@ -52,9 +52,9 @@ function _prepare_shifted_angular_spectrum(field::AbstractArray{CT}, z, λ, L, shift = txy .* z ya = similar(field_new, real(eltype(field)), (size(field_new, 1), 1)) - Zygote.@ignore ya .= (fftpos(L_new[1], size(field_new, 1), CenterFT)) .+ shift[1] + ChainRulesCore.@ignore_derivatives ya .= (fftpos(L_new[1], size(field_new, 1), CenterFT)) .+ shift[1] xa = similar(field_new, real(eltype(field)), (1, size(field_new, 2))) - Zygote.@ignore xa .= (fftpos(L_new[2], size(field_new, 2), CenterFT))' .+ shift[2] + ChainRulesCore.@ignore_derivatives xa .= (fftpos(L_new[2], size(field_new, 2), CenterFT))' .+ shift[2] ramp_before = ifftshift(exp.(1im .* 2 .* T(π) ./ λ .* (sxy[2] .* x .+ sxy[1] .* y)), (1,2)) ramp_after = ifftshift(exp.(1im .* 2 .* T(π) ./ λ .* (sxy[2] .* xa .+ sxy[1] .* ya)), (1,2))