Better handling for nesting Params #823
```diff
@@ -163,6 +163,7 @@ end
 Base.show(io::IO, ps::Grads) = print(io, "Grads(...)")

 @forward Grads.grads Base.getindex, Base.haskey, Base.iterate, Base.keys
+@forward Grads.grads Base.setindex!
 @forward Grads.params Base.length
```

```diff
@@ -171,15 +172,15 @@ const ADictOrGrads = Union{AbstractDict, Grads}
 # Dictionary interface.
 # Don't use the IdDict directly since it may contain some spurious pairs.
 Base.haskey(gs::Grads, x) = x ∈ gs.params
-Base.keys(gs::Grads) = gs.params
+# Base.keys(gs::Grads) = gs.params
 Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)
```
|
Member
Forwarding: the comments above suggest that we should think changes through thoroughly or leave things as they are (which seems the best option to me).

Member (Author)
What is the spurious stuff you refer to? It contains references to the objects that had gradients along the way.

Member
I think Carlo is referring to the existing comment at https://github.com/FluxML/Zygote.jl/pull/823/files#diff-7511b224d7f3ebb56465690de8e307422e3c9798a22bdd4e960d5c86ba6528aaR173. My understanding of that is that

Member
The comment is not outdated:

```julia
julia> using Flux

julia> m = Chain(Dense(2,2), x->relu.(x), BatchNorm(2))
Chain(Dense(2, 2), #1, BatchNorm(2))

julia> gs = gradient(() -> sum(m(rand(2,2))), Flux.params(m))
Grads(...)

julia> gs.grads
IdDict{Any, Any} with 8 entries:
  Float32[0.0, 0.0] => [0.0, 0.0]
  BatchNorm(2) => RefValue{Any}((λ = nothing, β = nothing, γ = nothing, μ = nothing, σ² = nothing, ϵ = 0.0, momentum = nothing, affi…
  Float32[-0.63824 0.222623; -0.785237 0.536415] => [0.0 0.0; 0.0 0.0]
  :(Main.m) => (layers = (nothing, nothing, RefValue{Any}((λ = nothing, β = nothing, γ = nothing, μ = nothing, σ² = nothing, ϵ = …
  Box([0.0; 0.0]) => RefValue{Any}((contents = nothing,))
  Float32[0.0, 0.0] => [2.0, 2.0]
  Box([0.0; 0.0]) => RefValue{Any}((contents = nothing,))
  Float32[1.0, 1.0] => [0.0, 0.0]
```

Member
we have to keep the current dict interface based on

Member (Author)
None of this is spurious, though; there is no prior knowledge of what needs to be tracked at the beginning of the differentiation. The grads dictionary returns all the objects it needed to track, even if those entities weren't present in the params. They may have been indirectly needed to get the grads of the params. What we can guarantee is that the grads dictionary will always have the params as keys. So the defensive thing is to return the entire dict, so that these values for the intermediaries are available to multiple levels of differentiation.

Member
Then per Carlo's point,

Member (Author)
Yes, I'll add that.
```diff
-function Base.iterate(gs::Grads, state...)
-  res = iterate(gs.params, state...)
-  isnothing(res) && return nothing
-  p, next_state = res
-  return gs[p], next_state
-end
+# function Base.iterate(gs::Grads, state...)
+#   res = iterate(gs.params, state...)
+#   isnothing(res) && return nothing
+#   p, next_state = res
+#   return gs[p], next_state
+# end

 function Base.getindex(gs::Grads, x)
   isbits(x) && error("Only reference types can be differentiated with `Params`.")
```
```diff
@@ -552,6 +552,7 @@ end
 @adjoint convert(::Type{R}, A::LinearAlgebra.HermOrSym{T,S}) where {T,S,R<:Array} = convert(R, A),
   Δ -> (nothing, convert(S, Δ),)

 @adjoint Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} = Matrix(A),
   Δ -> (convert(S, Δ),)
```

```diff
@@ -731,7 +732,6 @@ end
     return ((uplo=nothing, info=nothing, factors=Δ_factors),)
   end
 end
-
 @adjoint function logdet(C::Cholesky)
   return logdet(C), function(Δ)
     return ((uplo=nothing, info=nothing, factors=Diagonal(2 .* Δ ./ diag(C.factors))),)
```

```diff
@@ -756,14 +756,11 @@ end
 @adjoint function -(S::UniformScaling, A::AbstractMatrix)
   return S - A, Δ->((λ=tr(Δ),), -Δ)
 end
-
 @adjoint +(A::AbstractArray, B::AbstractArray) = A + B, Δ->(Δ, Δ)
 @adjoint -(A::AbstractArray, B::AbstractArray) = A - B, Δ->(Δ, -Δ)
 @adjoint -(A::AbstractArray) = -A, Δ->(-Δ,)
-
 # Abstract FFT
 # ===================
-
 # AbstractFFTs functions do not work with FillArrays, which are needed
 # for some functionality of Zygote. To make it work with FillArrays
 # as well, overload the relevant functions
```

```diff
@@ -773,56 +770,47 @@ AbstractFFTs.ifft(x::Fill, dims...) = AbstractFFTs.ifft(collect(x), dims...)
 AbstractFFTs.rfft(x::Fill, dims...) = AbstractFFTs.rfft(collect(x), dims...)
 AbstractFFTs.irfft(x::Fill, d, dims...) = AbstractFFTs.irfft(collect(x), d, dims...)
 AbstractFFTs.brfft(x::Fill, d, dims...) = AbstractFFTs.brfft(collect(x), d, dims...)
-
 # the adjoint jacobian of an FFT with respect to its input is the reverse FFT of the
 # gradient of its inputs, but with different normalization factor
 @adjoint function fft(xs)
   return AbstractFFTs.fft(xs), function(Δ)
     return (AbstractFFTs.bfft(Δ),)
   end
 end
```
Member
? these empty lines shouldn't be removed
```diff
 @adjoint function *(P::AbstractFFTs.Plan, xs)
   return P * xs, function(Δ)
     N = prod(size(xs)[[P.region...]])
     return (nothing, N * (P \ Δ))
   end
 end
-
 @adjoint function \(P::AbstractFFTs.Plan, xs)
   return P \ xs, function(Δ)
     N = prod(size(Δ)[[P.region...]])
     return (nothing, (P * Δ)/N)
   end
 end
-
 # all of the plans normalize their inverse, while we need the unnormalized one.
 @adjoint function ifft(xs)
   return AbstractFFTs.ifft(xs), function(Δ)
     N = length(xs)
     return (AbstractFFTs.fft(Δ)/N,)
   end
 end
-
 @adjoint function bfft(xs)
   return AbstractFFTs.bfft(xs), function(Δ)
     return (AbstractFFTs.fft(Δ),)
   end
 end
-
 @adjoint function fftshift(x)
   return fftshift(x), function(Δ)
     return (ifftshift(Δ),)
   end
 end
-
 @adjoint function ifftshift(x)
   return ifftshift(x), function(Δ)
     return (fftshift(Δ),)
   end
 end
-
 # to actually use rfft, one needs to insure that everything
 # that happens in the Fourier domain could've been done in
 # the space domain with real numbers. This means enforcing
```
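The normalization relationships encoded in these adjoints can be checked numerically; a quick sketch (assumes Zygote and FFTW are available):

```julia
using Zygote, FFTW

x = rand(ComplexF64, 8)
Δ = rand(ComplexF64, 8)

# vjp of fft: per the adjoint above, it is the unnormalized inverse DFT
_, back = Zygote.pullback(fft, x)
@assert back(Δ)[1] ≈ bfft(Δ)

# vjp of ifft: picks up the extra 1/N factor relative to fft
_, back = Zygote.pullback(ifft, x)
@assert back(Δ)[1] ≈ fft(Δ) / length(x)
```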
```diff
@@ -44,6 +44,27 @@ end
   end
 end

+@adjoint function Base._oidd_nextind(a, i)
```
Member
do we need to define an adjoint for an internal function of Base?

Member (Author)
Unfortunately I couldn't see a different way at the time. I'm with you on internal functions. We ended up dropping some grads without it, which hopefully shouldn't be the case.
```diff
+  Base._oidd_nextind(a, i), Δ -> begin
+    (nothing, nothing)
+  end
+end
+@adjoint! function get(d::AbstractDict, k, default)
+  hk = Ref{Bool}()
+  val = if haskey(d, k)
+    hk[] = true
+    d[k]
+  else
+    hk[] = false
+    d[k] = default
```
Member (Author)
We aren't mutating anything in the user-defined objects, so it should be fine.

Member
you are defining an adjoint for

Member (Author)
Ok, I see what you mean.
```diff
+  end
+  function back(Δ)
+    Δ2 = setindex!(grad_mut(__context__, d), Δ, k)
+    (Δ2, nothing, nothing)
+  end
+  val, back
+end

 # Channels

 @nograd Channel
```
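As a usage sketch of the new `get` adjoint (the `hyper` Dict and `f` below are made up for illustration, not part of the PR): differentiation can now flow through code that reads values out of a dictionary:

```julia
using Zygote

# a hypothetical model whose scale factor is looked up in a Dict;
# the lookup hits the @adjoint! defined for get(d::AbstractDict, k, default)
hyper = Dict(:scale => 3.0)
f(x) = get(hyper, :scale, 1.0) * sum(abs2, x)

g = gradient(f, [1.0, 2.0])[1]
@assert g ≈ [6.0, 12.0]   # ∂f/∂x = scale .* 2x
```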