15 changes: 8 additions & 7 deletions src/compiler/interface.jl
@@ -163,6 +163,7 @@ end

Base.show(io::IO, ps::Grads) = print(io, "Grads(...)")

@forward Grads.grads Base.getindex, Base.haskey, Base.iterate, Base.keys
@forward Grads.grads Base.setindex!
@forward Grads.params Base.length

@@ -171,15 +172,15 @@ const ADictOrGrads = Union{AbstractDict, Grads}
# Dictionary interface.
# Don't use the IdDict directly since it may contain some spurious pairs.
Base.haskey(gs::Grads, x) = x ∈ gs.params
Base.keys(gs::Grads) = gs.params
# Base.keys(gs::Grads) = gs.params
Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)

Member

Forwarding keys to gs.grads while relying on gs.params for the values is not good. Either we base the dictionary interface entirely on gs.params or entirely on gs.grads; we can't mix the two.

The comment above,

# Don't use the IdDict directly since it may contain some spurious pairs

suggests that we should either think these changes through thoroughly or leave things as they are (which seems the best option to me).

Member Author

What is the spurious stuff you refer to? It contains references to the objects that had gradients along the way.

Member

I think Carlo is referring to the existing comment at https://github.com/FluxML/Zygote.jl/pull/823/files#diff-7511b224d7f3ebb56465690de8e307422e3c9798a22bdd4e960d5c86ba6528aaR173. My understanding of that is that Base.keys(Grads.grads) may contain items not in Base.keys(Grads.params). This seems like it should never happen though, so is the comment out of date or am I missing some scenario where it could?

Member

The comment is not outdated:

julia> using Flux

julia> m = Chain(Dense(2,2), x->relu.(x), BatchNorm(2)) 
Chain(Dense(2, 2), #1, BatchNorm(2))

julia> gs = gradient(() -> sum(m(rand(2,2))), Flux.params(m))
Grads(...)

julia> gs.grads
IdDict{Any, Any} with 8 entries:
  Float32[0.0, 0.0]                              => [0.0, 0.0]
  BatchNorm(2)                                   => RefValue{Any}((λ = nothing, β = nothing, γ = nothing, μ = nothing, σ² = nothing, ϵ = 0.0, momentum = nothing, affi
  Float32[-0.63824 0.222623; -0.785237 0.536415] => [0.0 0.0; 0.0 0.0]
  :(Main.m)                                      => (layers = (nothing, nothing, RefValue{Any}((λ = nothing, β = nothing, γ = nothing, μ = nothing, σ² = nothing, ϵ = 
  Box([0.0; 0.0])                                => RefValue{Any}((contents = nothing,))
  Float32[0.0, 0.0]                              => [2.0, 2.0]
  Box([0.0; 0.0])                                => RefValue{Any}((contents = nothing,))
  Float32[1.0, 1.0]                              => [0.0, 0.0]

Member

We have to keep the current dict interface based on gs.params.

Member Author

None of this is spurious, though: there is no prior knowledge of what needs to be tracked at the beginning of differentiation. The grads dictionary holds everything it needed to track, even entities that weren't present in the params; they may have been indirectly needed to get the grads of the params. What we can guarantee is that the grads dictionary will always have the params as keys. So the defensive thing is to return the entire dict, so that the values for the intermediaries are available to multiple levels of differentiation.

Member

Then, per Carlo's point, Base.values(gs::Grads) should forward to .grads as well. Having the two be differently sized is unexpected (i.e. potentially subtly breaking), and arguably breaks the contract between keys and values.

Member Author

Yes, I'll add that
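
To make the contract being discussed concrete, here is a minimal illustrative sketch of the params-based option, where keys, values, and iteration all derive from gs.params and therefore always agree in length and order (the grads-based alternative would instead forward all three to gs.grads). This is only a sketch for the discussion, not the code in this diff:

Base.keys(gs::Grads)   = gs.params
Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)

function Base.iterate(gs::Grads, state...)
  res = iterate(gs.params, state...)
  isnothing(res) && return nothing
  p, next_state = res
  return gs[p], next_state  # yields the gradient of each param, in params order
end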


function Base.iterate(gs::Grads, state...)
res = iterate(gs.params, state...)
isnothing(res) && return nothing
p, next_state = res
return gs[p], next_state
end
# function Base.iterate(gs::Grads, state...)
# res = iterate(gs.params, state...)
# isnothing(res) && return nothing
# p, next_state = res
# return gs[p], next_state
# end

function Base.getindex(gs::Grads, x)
isbits(x) && error("Only reference types can be differentiated with `Params`.")
14 changes: 1 addition & 13 deletions src/lib/array.jl
@@ -552,6 +552,7 @@ end

@adjoint convert(::Type{R}, A::LinearAlgebra.HermOrSym{T,S}) where {T,S,R<:Array} = convert(R, A),
Δ -> (nothing, convert(S, Δ),)

@adjoint Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} = Matrix(A),
Δ -> (convert(S, Δ),)

@@ -731,7 +732,6 @@ end
return ((uplo=nothing, info=nothing, factors=Δ_factors),)
end
end

@adjoint function logdet(C::Cholesky)
return logdet(C), function(Δ)
return ((uplo=nothing, info=nothing, factors=Diagonal(2 .* Δ ./ diag(C.factors))),)
@@ -756,14 +756,11 @@ end
@adjoint function -(S::UniformScaling, A::AbstractMatrix)
return S - A, Δ->((λ=tr(Δ),), -Δ)
end

@adjoint +(A::AbstractArray, B::AbstractArray) = A + B, Δ->(Δ, Δ)
@adjoint -(A::AbstractArray, B::AbstractArray) = A - B, Δ->(Δ, -Δ)
@adjoint -(A::AbstractArray) = -A, Δ->(-Δ,)

# Abstract FFT
# ===================

# AbstractFFTs functions do not work with FillArrays, which are needed
# for some functionality of Zygote. To make it work with FillArrays
# as well, overload the relevant functions
@@ -773,56 +770,47 @@ AbstractFFTs.ifft(x::Fill, dims...) = AbstractFFTs.ifft(collect(x), dims...)
AbstractFFTs.rfft(x::Fill, dims...) = AbstractFFTs.rfft(collect(x), dims...)
AbstractFFTs.irfft(x::Fill, d, dims...) = AbstractFFTs.irfft(collect(x), d, dims...)
AbstractFFTs.brfft(x::Fill, d, dims...) = AbstractFFTs.brfft(collect(x), d, dims...)
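
As a quick illustration of why these overloads matter (a sketch, assuming FFTW is loaded as the AbstractFFTs backend, with output as printed on a recent Julia version): the lazy Fill is materialized via collect before the transform runs.

julia> using Zygote, FFTW, FillArrays   # Zygote provides the Fill overloads above

julia> fft(Fill(1.0, 4))                # dispatches to AbstractFFTs.fft(x::Fill, dims...)
4-element Vector{ComplexF64}:
 4.0 + 0.0im
 0.0 + 0.0im
 0.0 + 0.0im
 0.0 + 0.0im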

# the adjoint jacobian of an FFT with respect to its input is the reverse FFT of the
# gradient of its inputs, but with different normalization factor
@adjoint function fft(xs)
return AbstractFFTs.fft(xs), function(Δ)
return (AbstractFFTs.bfft(Δ),)
end
end

Member (@CarloLucibello, Jun 6, 2021)

? these empty lines shouldn't be removed

@adjoint function *(P::AbstractFFTs.Plan, xs)
return P * xs, function(Δ)
N = prod(size(xs)[[P.region...]])
return (nothing, N * (P \ Δ))
end
end

@adjoint function \(P::AbstractFFTs.Plan, xs)
return P \ xs, function(Δ)
N = prod(size(Δ)[[P.region...]])
return (nothing, (P * Δ)/N)
end
end

# all of the plans normalize their inverse, while we need the unnormalized one.
@adjoint function ifft(xs)
return AbstractFFTs.ifft(xs), function(Δ)
N = length(xs)
return (AbstractFFTs.fft(Δ)/N,)
end
end

@adjoint function bfft(xs)
return AbstractFFTs.bfft(xs), function(Δ)
return (AbstractFFTs.fft(Δ),)
end
end

@adjoint function fftshift(x)
return fftshift(x), function(Δ)
return (ifftshift(Δ),)
end
end

@adjoint function ifftshift(x)
return ifftshift(x), function(Δ)
return (fftshift(Δ),)
end
end
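
As a sanity check on the fft adjoint above (an illustrative sketch, not part of this diff): for a real vector x, Parseval's theorem gives sum(abs2, fft(x)) == length(x) * sum(abs2, x), so the gradient should be 2 * length(x) * x, which the bfft-based pullback reproduces.

julia> using Zygote, FFTW

julia> x = rand(8);

julia> g, = gradient(x -> sum(abs2, fft(x)), x);

julia> g ≈ 2 * length(x) * x    # Parseval: sum(abs2, fft(x)) == length(x) * sum(abs2, x)
true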


# to actually use rfft, one needs to insure that everything
# that happens in the Fourier domain could've been done in
# the space domain with real numbers. This means enforcing
21 changes: 21 additions & 0 deletions src/lib/base.jl
@@ -44,6 +44,27 @@ end
end
end

@adjoint function Base._oidd_nextind(a, i)

Member

do we need to define an adjoint for an internal function of Base?

Member Author

Unfortunately I couldn't see a different way at the time. I'm with you on avoiding internal functions, but without this adjoint we ended up dropping some grads, which shouldn't happen.

Base._oidd_nextind(a, i), Δ -> begin
(nothing, nothing)
end
end
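If the only goal is to mark this internal helper as non-differentiable, Zygote's @nograd macro (already used for Channel further down in this file) would be an equivalent, terser way to express the same thing; shown here only as a hedged suggestion, not what the diff does:

@nograd Base._oidd_nextind  # equivalent in effect: no gradient flows through the IdDict index helper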
@adjoint! function get(d::AbstractDict, k, default)
hk = Ref{Bool}()
val = if haskey(d, k)
hk[] = true
d[k]
else
hk[] = false
d[k] = default

Member

get should not mutate; maybe you want to define the adjoint for get! instead?

Member Author

We aren't mutating anything in the user-defined objects, so it should be fine. get! would still mutate the gradient dictionary, so it's not any better.

Member

You are defining an adjoint for get(d::AbstractDict, k, default) that mutates d; this is not fine at all.

Member Author

Ok, I see what you mean

end
function back(Δ)
Δ2 = setindex!(grad_mut(__context__, d), Δ, k)
(Δ2, nothing, nothing)
end
val, back
end
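
Following up on the mutation concern in the thread above, a minimal non-mutating sketch (an illustration under assumptions, not the PR's final code) would return default on a miss instead of writing it into d, and route the gradient to the default argument on that path:

@adjoint! function get(d::AbstractDict, k, default)
  hk = haskey(d, k)
  val = hk ? d[k] : default            # never write into the user's dict
  function back(Δ)
    if hk
      Δd = setindex!(grad_mut(__context__, d), Δ, k)
      (Δd, nothing, nothing)
    else
      (nothing, nothing, Δ)            # the miss path returned `default`, so the gradient flows there
    end
  end
  val, back
end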

# Channels

@nograd Channel
30 changes: 30 additions & 0 deletions test/interface.jl
@@ -164,3 +164,33 @@ end
@test all(abs.(gs[b]) .<= 1e-5)
end
end

@testset "Params nesting" begin
struct Dense{F,T,S}
W::T
b::S
σ::F
end

(d::Dense)(x) = d.σ.(d.W * x .+ d.b)
d = Dense(ones(Float32, 3,3), zeros(Float32, 3), identity)
ps = Zygote.Params([d.W, d.b])
r = ones(Float32, 3,3)

gs = gradient(ps) do
p, pb = pullback(ps) do
sum(d(r))
end
g = pb(p)
sum(g[d.W]) # + sum(g[d.b])
end

@test gs[d.W] ≈ fill(81f0, (3,3))

# Test L2
l2g = gradient(ps) do
sum(sum(x .^ 2) for x in ps)
end
@test l2g[d.W] ≈ fill(2.f0, size(d.W))
@test l2g[d.b] ≈ fill(0.f0, size(d.b))
end