Skip to content

Commit 714a519

Browse files
wsmoses and penelopeysm
authored
Enzyme: migrate to easy_rule (#420)
* Enzyme: migrate to easy_rule * Bump patch * Add more links for the 1.11 LLVM issue --------- Co-authored-by: Penelope Yong <[email protected]>
1 parent 3334fc5 commit 714a519

File tree

4 files changed

+60
-254
lines changed

4 files changed

+60
-254
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# 0.15.12
2+
3+
Improved implementation of the Enzyme rule for `Bijectors.find_alpha`.
4+
15
# 0.15.11
26

37
Bijectors for ProductNamedTupleDistribution are now implemented.

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.15.11"
3+
version = "0.15.12"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -46,7 +46,7 @@ ChangesOfVariables = "0.1"
4646
Distributions = "0.25.33"
4747
DistributionsAD = "0.6"
4848
DocStringExtensions = "0.9"
49-
EnzymeCore = "0.8.4"
49+
EnzymeCore = "0.8.15"
5050
ForwardDiff = "0.10, 1.0.1"
5151
Functors = "0.1, 0.2, 0.3, 0.4, 0.5"
5252
InverseFunctions = "0.1"

ext/BijectorsEnzymeCoreExt.jl

Lines changed: 6 additions & 210 deletions
Original file line numberDiff line numberDiff line change
@@ -1,217 +1,13 @@
11
module BijectorsEnzymeCoreExt
22

3-
using EnzymeCore:
4-
Active,
5-
Const,
6-
Duplicated,
7-
DuplicatedNoNeed,
8-
BatchDuplicated,
9-
BatchDuplicatedNoNeed,
10-
EnzymeRules
11-
using Bijectors: find_alpha
12-
13-
# Compute a tuple of partial derivatives wrt non-`Const` arguments
14-
# and `nothing`s for `Const` arguments
15-
function ∂find_alpha(
16-
Ω::Real,
17-
wt_y::Union{Const,Active,Duplicated,BatchDuplicated},
18-
wt_u_hat::Union{Const,Active,Duplicated,BatchDuplicated},
19-
b::Union{Const,Active,Duplicated,BatchDuplicated},
20-
)
21-
# We reuse the following term in the computation of the derivatives
22-
Ωpb = Ω + b.val
23-
c = wt_u_hat.val * sech(Ωpb)^2
24-
cp1 = c + 1
25-
26-
∂Ω_∂wt_y = wt_y isa Const ? nothing : oneunit(wt_y.val) / cp1
27-
∂Ω_∂wt_u_hat = wt_u_hat isa Const ? nothing : -tanh(Ωpb) / cp1
28-
∂Ω_∂b = b isa Const ? nothing : -c / cp1
29-
30-
return (∂Ω_∂wt_y, ∂Ω_∂wt_u_hat, ∂Ω_∂b)
31-
end
32-
33-
# `muladd` for partial derivatives that can deal with `nothing` derivatives
34-
_muladd_partial(::Nothing, ::Const, x::Union{Real,Tuple{Vararg{Real}},Nothing}) = x
35-
_muladd_partial(x::Real, y::Duplicated, z::Real) = muladd(x, y.dval, z)
36-
_muladd_partial(x::Real, y::Duplicated, ::Nothing) = x * y.dval
37-
function _muladd_partial(x::Real, y::BatchDuplicated{<:Real,N}, z::NTuple{N,Real}) where {N}
38-
let x = x
39-
map((a, b) -> muladd(x, a, b), y.dval, z)
40-
end
41-
end
42-
_muladd_partial(x::Real, y::BatchDuplicated, ::Nothing) = map(Base.Fix1(*, x), y.dval)
43-
44-
function EnzymeRules.forward(
45-
config::EnzymeRules.FwdConfig,
46-
::Const{typeof(find_alpha)},
47-
::Type{RT},
48-
wt_y::Union{Const,Duplicated,BatchDuplicated},
49-
wt_u_hat::Union{Const,Duplicated,BatchDuplicated},
50-
b::Union{Const,Duplicated,BatchDuplicated},
51-
) where {RT<:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed}}
52-
# Check that the types of the activities are consistent
53-
if !(
54-
RT <: Union{Const,Duplicated,DuplicatedNoNeed} &&
55-
wt_y isa Union{Const,Duplicated} &&
56-
wt_u_hat isa Union{Const,Duplicated} &&
57-
b isa Union{Const,Duplicated}
58-
) && !(
59-
RT <: Union{Const,BatchDuplicated,BatchDuplicatedNoNeed} &&
60-
wt_y isa Union{Const,BatchDuplicated} &&
61-
wt_u_hat isa Union{Const,BatchDuplicated} &&
62-
b isa Union{Const,BatchDuplicated}
63-
)
64-
throw(ArgumentError("inconsistent activities"))
65-
end
66-
67-
# Early exit: Neither primal nor shadow needed
68-
if !EnzymeRules.needs_primal(config) && !EnzymeRules.needs_shadow(config)
69-
return nothing
70-
end
71-
72-
# Compute primal value
73-
Ω = find_alpha(wt_y.val, wt_u_hat.val, b.val)
74-
75-
# Early exit if no derivatives are requested
76-
if !EnzymeRules.needs_shadow(config)
77-
return Ω
78-
end
79-
80-
Ω̇ = if wt_y isa Const && wt_u_hat isa Const && b isa Const
81-
# Trivial case: All partial derivatives are 0
82-
if EnzymeRules.width(config) == 1
83-
zero(Ω)
84-
else
85-
ntuple(Zero(Ω), Val(EnzymeRules.width(config)))
86-
end
87-
else
88-
# In all other cases we have to compute the partial derivatives
89-
∂Ω_∂wt_y, ∂Ω_∂wt_u_hat, ∂Ω_∂b = ∂find_alpha(Ω, wt_y, wt_u_hat, b)
90-
_muladd_partial(
91-
∂Ω_∂wt_y,
92-
wt_y,
93-
_muladd_partial(∂Ω_∂wt_u_hat, wt_u_hat, _muladd_partial(∂Ω_∂b, b, nothing)),
94-
)
95-
end
96-
@assert (EnzymeRules.width(config) == 1 && Ω̇ isa Real) ||
97-
(EnzymeRules.width(config) > 1 && Ω̇ isa NTuple{EnzymeRules.width(config),Real})
98-
99-
if EnzymeRules.needs_primal(config)
100-
if EnzymeRules.width(config) == 1
101-
return Duplicated(Ω, Ω̇)
102-
else
103-
return BatchDuplicated(Ω, Ω̇)
104-
end
105-
else
106-
return Ω̇
107-
end
108-
end
3+
using EnzymeCore
1094

110-
struct Zero{T}
111-
x::T
112-
end
113-
(f::Zero)(_) = zero(f.x)
114-
115-
function EnzymeRules.augmented_primal(
116-
config::EnzymeRules.RevConfig,
117-
::Const{typeof(find_alpha)},
118-
::Type{RT},
119-
wt_y::Union{Const,Active},
120-
wt_u_hat::Union{Const,Active},
121-
b::Union{Const,Active},
122-
) where {RT<:Union{Const,Active}}
123-
# Only compute the original return value if it is actually needed
124-
Ω =
125-
if EnzymeRules.needs_primal(config) ||
126-
EnzymeRules.needs_shadow(config) ||
127-
!(RT <: Const || (wt_y isa Const && wt_u_hat isa Const && b isa Const))
128-
find_alpha(wt_y.val, wt_u_hat.val, b.val)
129-
else
130-
nothing
131-
end
132-
133-
tape = if RT <: Const || (wt_y isa Const && wt_u_hat isa Const && b isa Const)
134-
# Trivial case: No differentiation or all derivatives are 0
135-
# Thus no tape is needed
136-
nothing
137-
else
138-
# Derivatives with respect to at least one argument needed
139-
# They are computed in the reverse pass, and therefore the original return is cached
140-
# In principle, the partial derivatives could be computed here and be cached
141-
# But Enzyme only executes the reverse pass once,
142-
# thus this would not increase efficiency but instead more values would have to be cached
143-
Ω
144-
end
145-
146-
# Ensure that we follow the interface requirements of `augmented_primal`
147-
primal = EnzymeRules.needs_primal(config) ? Ω : nothing
148-
shadow = if EnzymeRules.needs_shadow(config)
149-
if EnzymeRules.width(config) === 1
150-
zero(Ω)
151-
else
152-
ntuple(Zero(Ω), Val(EnzymeRules.width(config)))
153-
end
154-
else
155-
nothing
156-
end
157-
158-
return EnzymeRules.AugmentedReturn(primal, shadow, tape)
159-
end
160-
161-
struct ZeroOrNothing{N} end
162-
(::ZeroOrNothing)(::Const) = nothing
163-
(::ZeroOrNothing{1})(x::Active) = zero(x.val)
164-
(::ZeroOrNothing{N})(x::Active) where {N} = ntuple(Zero(x.val), Val{N}())
165-
166-
function EnzymeRules.reverse(
167-
config::EnzymeRules.RevConfig,
168-
::Const{typeof(find_alpha)},
169-
::Type{<:Const},
170-
::Nothing,
171-
wt_y::Union{Const,Active},
172-
wt_u_hat::Union{Const,Active},
173-
b::Union{Const,Active},
174-
)
175-
# Trivial case: Nothing to be differentiated (return activity is `Const`)
176-
return map(ZeroOrNothing{EnzymeRules.width(config)}(), (wt_y, wt_u_hat, b))
177-
end
178-
function EnzymeRules.reverse(
179-
::EnzymeRules.RevConfig,
180-
::Const{typeof(find_alpha)},
181-
::Active,
182-
::Nothing,
183-
::Const,
184-
::Const,
185-
::Const,
186-
)
187-
# Trivial case: Tape does not exist since all partial derivatives are 0
188-
return (nothing, nothing, nothing)
189-
end
190-
191-
struct MulPartialOrNothing{T<:Union{Real,Tuple{Vararg{Real}}}}
192-
x::T
193-
end
194-
(::MulPartialOrNothing)(::Nothing) = nothing
195-
(f::MulPartialOrNothing{<:Real})(∂f_∂x::Real) = ∂f_∂x * f.x
196-
function (f::MulPartialOrNothing{<:NTuple{N,Real}})(∂f_∂x::Real) where {N}
197-
return map(Base.Fix1(*, ∂f_∂x), f.x)
198-
end
5+
using Bijectors: find_alpha
1996

200-
function EnzymeRules.reverse(
201-
::EnzymeRules.RevConfig,
202-
::Const{typeof(find_alpha)},
203-
ΔΩ::Active,
204-
Ω::Real,
205-
wt_y::Union{Const,Active},
206-
wt_u_hat::Union{Const,Active},
207-
b::Union{Const,Active},
7+
EnzymeCore.EnzymeRules.@easy_rule(
8+
find_alpha(wt_y::Real, wt_u_hat::Real, b::Real),
9+
@setup(x = inv(1 + wt_u_hat * sech(Ω + b)^2),),
10+
(x, -tanh(Ω + b) * x, x - 1),
20811
)
209-
# Tape must be `nothing` if all arguments are `Const`
210-
@assert !(wt_y isa Const && wt_u_hat isa Const && b isa Const)
211-
212-
# Compute partial derivatives
213-
∂Ω_∂xs = ∂find_alpha(Ω, wt_y, wt_u_hat, b)
214-
return map(MulPartialOrNothing(ΔΩ.val), ∂Ω_∂xs)
215-
end
21612

21713
end # module

test/ad/enzyme.jl

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,51 +9,57 @@ using Test
99
#
1010
# https://github.com/EnzymeAD/Enzyme.jl/issues/2121
1111
# https://github.com/TuringLang/Bijectors.jl/pull/350#issuecomment-2470766968
12+
#
13+
# The fix to this needs to be made in Julia itself: it seems that this has already been done
14+
# in https://github.com/JuliaLang/llvm-project/pull/49 although whether this will be
15+
# incorporated into the built Julia version itself seems unclear. See
16+
# https://github.com/JuliaLang/julia/pull/59521#issuecomment-3300480633.
1217
#
13-
# Ideally we'd use `@test_throws`. However, that doesn't work because
14-
# `test_forward` itself calls `@test`, and the error is captured by that
15-
# `@test`, not our `@test_throws`. Consequently `@test_throws` doesn't actually
16-
# see any error. Weird Julia behaviour.
17-
18-
@static if VERSION < v"1.11"
19-
@testset "Enzyme: Bijectors.find_alpha" begin
20-
x = randn()
21-
y = expm1(randn())
22-
z = randn()
23-
24-
@testset "forward" begin
25-
# No batches
26-
@testset for RT in (Const, Duplicated, DuplicatedNoNeed),
27-
Tx in (Const, Duplicated),
28-
Ty in (Const, Duplicated),
29-
Tz in (Const, Duplicated)
30-
31-
test_forward(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz))
32-
end
33-
34-
# Batches
35-
@testset for RT in (Const, BatchDuplicated, BatchDuplicatedNoNeed),
36-
Tx in (Const, BatchDuplicated),
37-
Ty in (Const, BatchDuplicated),
38-
Tz in (Const, BatchDuplicated)
39-
40-
test_forward(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz))
41-
end
18+
# If this does not end up being backported to 1.11, then we may have to permanently skip
19+
# these tests.
20+
#
21+
# On another note: Ideally we'd use `@test_throws`. However, that doesn't work because
22+
# `test_forward` itself calls `@test`, and the error is captured by that `@test`, not our
23+
# `@test_throws`. Consequently `@test_throws` doesn't actually see any error. Weird Julia
24+
# behaviour.
25+
26+
@testset "Enzyme: Bijectors.find_alpha" begin
27+
x = randn()
28+
y = expm1(randn())
29+
z = randn()
30+
31+
@testset "forward" begin
32+
# No batches
33+
@testset for RT in (Const, Duplicated, DuplicatedNoNeed),
34+
Tx in (Const, Duplicated),
35+
Ty in (Const, Duplicated),
36+
Tz in (Const, Duplicated)
37+
38+
test_forward(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz))
39+
end
40+
41+
# Batches
42+
@testset for RT in (Const, BatchDuplicated, BatchDuplicatedNoNeed),
43+
Tx in (Const, BatchDuplicated),
44+
Ty in (Const, BatchDuplicated),
45+
Tz in (Const, BatchDuplicated)
46+
47+
test_forward(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz))
4248
end
43-
@testset "reverse" begin
44-
# No batches
45-
@testset for RT in (Const, Active),
46-
Tx in (Const, Active),
47-
Ty in (Const, Active),
48-
Tz in (Const, Active)
49-
50-
test_reverse(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz))
51-
end
52-
53-
# TODO: Test batch mode
54-
# This is a bit problematic since Enzyme does not support all combinations of activities currently
55-
# https://github.com/TuringLang/Bijectors.jl/pull/350#issuecomment-2480468728
49+
end
50+
@testset "reverse" begin
51+
# No batches
52+
@testset for RT in (Const, Active),
53+
Tx in (Const, Active),
54+
Ty in (Const, Active),
55+
Tz in (Const, Active)
56+
57+
test_reverse(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz))
5658
end
59+
60+
# TODO: Test batch mode
61+
# This is a bit problematic since Enzyme does not support all combinations of activities currently
62+
# https://github.com/TuringLang/Bijectors.jl/pull/350#issuecomment-2480468728
5763
end
5864
end
5965

0 commit comments

Comments
 (0)