|
1 | 1 | module BijectorsEnzymeCoreExt |
2 | 2 |
|
3 | | -using EnzymeCore: |
4 | | - Active, |
5 | | - Const, |
6 | | - Duplicated, |
7 | | - DuplicatedNoNeed, |
8 | | - BatchDuplicated, |
9 | | - BatchDuplicatedNoNeed, |
10 | | - EnzymeRules |
11 | | -using Bijectors: find_alpha |
12 | | - |
13 | | -# Compute a tuple of partial derivatives wrt non-`Const` arguments |
14 | | -# and `nothing`s for `Const` arguments |
15 | | -function ∂find_alpha( |
16 | | - Ω::Real, |
17 | | - wt_y::Union{Const,Active,Duplicated,BatchDuplicated}, |
18 | | - wt_u_hat::Union{Const,Active,Duplicated,BatchDuplicated}, |
19 | | - b::Union{Const,Active,Duplicated,BatchDuplicated}, |
20 | | -) |
21 | | - # We reuse the following term in the computation of the derivatives |
22 | | - Ωpb = Ω + b.val |
23 | | - c = wt_u_hat.val * sech(Ωpb)^2 |
24 | | - cp1 = c + 1 |
25 | | - |
26 | | - ∂Ω_∂wt_y = wt_y isa Const ? nothing : oneunit(wt_y.val) / cp1 |
27 | | - ∂Ω_∂wt_u_hat = wt_u_hat isa Const ? nothing : -tanh(Ωpb) / cp1 |
28 | | - ∂Ω_∂b = b isa Const ? nothing : -c / cp1 |
29 | | - |
30 | | - return (∂Ω_∂wt_y, ∂Ω_∂wt_u_hat, ∂Ω_∂b) |
31 | | -end |
32 | | - |
33 | | -# `muladd` for partial derivatives that can deal with `nothing` derivatives |
34 | | -_muladd_partial(::Nothing, ::Const, x::Union{Real,Tuple{Vararg{Real}},Nothing}) = x |
35 | | -_muladd_partial(x::Real, y::Duplicated, z::Real) = muladd(x, y.dval, z) |
36 | | -_muladd_partial(x::Real, y::Duplicated, ::Nothing) = x * y.dval |
37 | | -function _muladd_partial(x::Real, y::BatchDuplicated{<:Real,N}, z::NTuple{N,Real}) where {N} |
38 | | - let x = x |
39 | | - map((a, b) -> muladd(x, a, b), y.dval, z) |
40 | | - end |
41 | | -end |
42 | | -_muladd_partial(x::Real, y::BatchDuplicated, ::Nothing) = map(Base.Fix1(*, x), y.dval) |
43 | | - |
44 | | -function EnzymeRules.forward( |
45 | | - config::EnzymeRules.FwdConfig, |
46 | | - ::Const{typeof(find_alpha)}, |
47 | | - ::Type{RT}, |
48 | | - wt_y::Union{Const,Duplicated,BatchDuplicated}, |
49 | | - wt_u_hat::Union{Const,Duplicated,BatchDuplicated}, |
50 | | - b::Union{Const,Duplicated,BatchDuplicated}, |
51 | | -) where {RT<:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed}} |
52 | | - # Check that the types of the activities are consistent |
53 | | - if !( |
54 | | - RT <: Union{Const,Duplicated,DuplicatedNoNeed} && |
55 | | - wt_y isa Union{Const,Duplicated} && |
56 | | - wt_u_hat isa Union{Const,Duplicated} && |
57 | | - b isa Union{Const,Duplicated} |
58 | | - ) && !( |
59 | | - RT <: Union{Const,BatchDuplicated,BatchDuplicatedNoNeed} && |
60 | | - wt_y isa Union{Const,BatchDuplicated} && |
61 | | - wt_u_hat isa Union{Const,BatchDuplicated} && |
62 | | - b isa Union{Const,BatchDuplicated} |
63 | | - ) |
64 | | - throw(ArgumentError("inconsistent activities")) |
65 | | - end |
66 | | - |
67 | | - # Early exit: Neither primal nor shadow needed |
68 | | - if !EnzymeRules.needs_primal(config) && !EnzymeRules.needs_shadow(config) |
69 | | - return nothing |
70 | | - end |
71 | | - |
72 | | - # Compute primal value |
73 | | - Ω = find_alpha(wt_y.val, wt_u_hat.val, b.val) |
74 | | - |
75 | | - # Early exit if no derivatives are requested |
76 | | - if !EnzymeRules.needs_shadow(config) |
77 | | - return Ω |
78 | | - end |
79 | | - |
80 | | - Ω̇ = if wt_y isa Const && wt_u_hat isa Const && b isa Const |
81 | | - # Trivial case: All partial derivatives are 0 |
82 | | - if EnzymeRules.width(config) == 1 |
83 | | - zero(Ω) |
84 | | - else |
85 | | - ntuple(Zero(Ω), Val(EnzymeRules.width(config))) |
86 | | - end |
87 | | - else |
88 | | - # In all other cases we have to compute the partial derivatives |
89 | | - ∂Ω_∂wt_y, ∂Ω_∂wt_u_hat, ∂Ω_∂b = ∂find_alpha(Ω, wt_y, wt_u_hat, b) |
90 | | - _muladd_partial( |
91 | | - ∂Ω_∂wt_y, |
92 | | - wt_y, |
93 | | - _muladd_partial(∂Ω_∂wt_u_hat, wt_u_hat, _muladd_partial(∂Ω_∂b, b, nothing)), |
94 | | - ) |
95 | | - end |
96 | | - @assert (EnzymeRules.width(config) == 1 && Ω̇ isa Real) || |
97 | | - (EnzymeRules.width(config) > 1 && Ω̇ isa NTuple{EnzymeRules.width(config),Real}) |
98 | | - |
99 | | - if EnzymeRules.needs_primal(config) |
100 | | - if EnzymeRules.width(config) == 1 |
101 | | - return Duplicated(Ω, Ω̇) |
102 | | - else |
103 | | - return BatchDuplicated(Ω, Ω̇) |
104 | | - end |
105 | | - else |
106 | | - return Ω̇ |
107 | | - end |
108 | | -end |
| 3 | +using EnzymeCore |
109 | 4 |
|
110 | | -struct Zero{T} |
111 | | - x::T |
112 | | -end |
113 | | -(f::Zero)(_) = zero(f.x) |
114 | | - |
115 | | -function EnzymeRules.augmented_primal( |
116 | | - config::EnzymeRules.RevConfig, |
117 | | - ::Const{typeof(find_alpha)}, |
118 | | - ::Type{RT}, |
119 | | - wt_y::Union{Const,Active}, |
120 | | - wt_u_hat::Union{Const,Active}, |
121 | | - b::Union{Const,Active}, |
122 | | -) where {RT<:Union{Const,Active}} |
123 | | - # Only compute the original return value if it is actually needed |
124 | | - Ω = |
125 | | - if EnzymeRules.needs_primal(config) || |
126 | | - EnzymeRules.needs_shadow(config) || |
127 | | - !(RT <: Const || (wt_y isa Const && wt_u_hat isa Const && b isa Const)) |
128 | | - find_alpha(wt_y.val, wt_u_hat.val, b.val) |
129 | | - else |
130 | | - nothing |
131 | | - end |
132 | | - |
133 | | - tape = if RT <: Const || (wt_y isa Const && wt_u_hat isa Const && b isa Const) |
134 | | - # Trivial case: No differentiation or all derivatives are 0 |
135 | | - # Thus no tape is needed |
136 | | - nothing |
137 | | - else |
138 | | - # Derivatives with respect to at least one argument needed |
139 | | - # They are computed in the reverse pass, and therefore the original return is cached |
140 | | - # In principle, the partial derivatives could be computed here and be cached |
141 | | - # But Enzyme only executes the reverse pass once, |
142 | | - # thus this would not increase efficiency but instead more values would have to be cached |
143 | | - Ω |
144 | | - end |
145 | | - |
146 | | - # Ensure that we follow the interface requirements of `augmented_primal` |
147 | | - primal = EnzymeRules.needs_primal(config) ? Ω : nothing |
148 | | - shadow = if EnzymeRules.needs_shadow(config) |
149 | | - if EnzymeRules.width(config) === 1 |
150 | | - zero(Ω) |
151 | | - else |
152 | | - ntuple(Zero(Ω), Val(EnzymeRules.width(config))) |
153 | | - end |
154 | | - else |
155 | | - nothing |
156 | | - end |
157 | | - |
158 | | - return EnzymeRules.AugmentedReturn(primal, shadow, tape) |
159 | | -end |
160 | | - |
161 | | -struct ZeroOrNothing{N} end |
162 | | -(::ZeroOrNothing)(::Const) = nothing |
163 | | -(::ZeroOrNothing{1})(x::Active) = zero(x.val) |
164 | | -(::ZeroOrNothing{N})(x::Active) where {N} = ntuple(Zero(x.val), Val{N}()) |
165 | | - |
166 | | -function EnzymeRules.reverse( |
167 | | - config::EnzymeRules.RevConfig, |
168 | | - ::Const{typeof(find_alpha)}, |
169 | | - ::Type{<:Const}, |
170 | | - ::Nothing, |
171 | | - wt_y::Union{Const,Active}, |
172 | | - wt_u_hat::Union{Const,Active}, |
173 | | - b::Union{Const,Active}, |
174 | | -) |
175 | | - # Trivial case: Nothing to be differentiated (return activity is `Const`) |
176 | | - return map(ZeroOrNothing{EnzymeRules.width(config)}(), (wt_y, wt_u_hat, b)) |
177 | | -end |
178 | | -function EnzymeRules.reverse( |
179 | | - ::EnzymeRules.RevConfig, |
180 | | - ::Const{typeof(find_alpha)}, |
181 | | - ::Active, |
182 | | - ::Nothing, |
183 | | - ::Const, |
184 | | - ::Const, |
185 | | - ::Const, |
186 | | -) |
187 | | - # Trivial case: Tape does not exist since all partial derivatives are 0 |
188 | | - return (nothing, nothing, nothing) |
189 | | -end |
190 | | - |
191 | | -struct MulPartialOrNothing{T<:Union{Real,Tuple{Vararg{Real}}}} |
192 | | - x::T |
193 | | -end |
194 | | -(::MulPartialOrNothing)(::Nothing) = nothing |
195 | | -(f::MulPartialOrNothing{<:Real})(∂f_∂x::Real) = ∂f_∂x * f.x |
196 | | -function (f::MulPartialOrNothing{<:NTuple{N,Real}})(∂f_∂x::Real) where {N} |
197 | | - return map(Base.Fix1(*, ∂f_∂x), f.x) |
198 | | -end |
| 5 | +using Bijectors: find_alpha |
199 | 6 |
|
200 | | -function EnzymeRules.reverse( |
201 | | - ::EnzymeRules.RevConfig, |
202 | | - ::Const{typeof(find_alpha)}, |
203 | | - ΔΩ::Active, |
204 | | - Ω::Real, |
205 | | - wt_y::Union{Const,Active}, |
206 | | - wt_u_hat::Union{Const,Active}, |
207 | | - b::Union{Const,Active}, |
| 7 | +EnzymeCore.EnzymeRules.@easy_rule( |
| 8 | + find_alpha(wt_y::Real, wt_u_hat::Real, b::Real), |
| 9 | + @setup(x = inv(1 + wt_u_hat * sech(Ω + b)^2),), |
| 10 | + (x, -tanh(Ω + b) * x, x - 1), |
208 | 11 | ) |
209 | | - # Tape must be `nothing` if all arguments are `Const` |
210 | | - @assert !(wt_y isa Const && wt_u_hat isa Const && b isa Const) |
211 | | - |
212 | | - # Compute partial derivatives |
213 | | - ∂Ω_∂xs = ∂find_alpha(Ω, wt_y, wt_u_hat, b) |
214 | | - return map(MulPartialOrNothing(ΔΩ.val), ∂Ω_∂xs) |
215 | | -end |
216 | 12 |
|
217 | 13 | end # module |
0 commit comments