Skip to content

Commit 9738946

Browse files
committed
Fix LKJ numerical stability with PDMats
1 parent b441777 commit 9738946

File tree

4 files changed

+24
-13
lines changed

4 files changed

+24
-13
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.15.7"
3+
version = "0.15.8"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -14,6 +14,7 @@ IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1616
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
17+
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1718
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1819
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1920
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
@@ -37,8 +38,8 @@ BijectorsEnzymeCoreExt = "EnzymeCore"
3738
BijectorsForwardDiffExt = "ForwardDiff"
3839
BijectorsLazyArraysExt = "LazyArrays"
3940
BijectorsMooncakeExt = "Mooncake"
40-
BijectorsReverseDiffExt = "ReverseDiff"
4141
BijectorsReverseDiffChainRulesExt = ["ChainRules", "ReverseDiff"]
42+
BijectorsReverseDiffExt = "ReverseDiff"
4243
BijectorsTrackerExt = "Tracker"
4344
BijectorsZygoteExt = "Zygote"
4445

@@ -59,6 +60,7 @@ LazyArrays = "2"
5960
LogExpFunctions = "0.3.3"
6061
MappedArrays = "0.2.2, 0.3, 0.4"
6162
Mooncake = "0.4.95"
63+
PDMats = "0.11.35"
6264
Reexport = "0.2, 1"
6365
ReverseDiff = "1"
6466
Roots = "1.3.15, 2"

src/bijectors/corr.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,7 @@ function logabsdetjac(b::VecCorrBijector, x)
136136
end
137137

138138
function with_logabsdet_jacobian(::Inverse{VecCorrBijector}, y)
139-
U_logJ = _inv_link_chol_lkj(y)
140-
# workaround for `Tracker.TrackedTuple` not supporting iteration
141-
U, logJ = U_logJ[1], U_logJ[2]
139+
U, logJ = _inv_link_chol_lkj(y)
142140
K = size(U, 1)
143141
for j in 2:(K - 1)
144142
logJ += (K - j) * log(U[j, j])

src/utils.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using PDMats: PDMat
2+
13
# `permutedims` seems to work better with AD (cf. KernelFunctions.jl)
24
aT_b(a::AbstractVector{<:Real}, b::AbstractMatrix{<:Real}) = permutedims(a) * b
35
# `permutedims` can't be used here since scalar output is desired
@@ -11,14 +13,8 @@ _vec(x::Real) = x
1113
lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A))
1214
upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A))
1315

14-
function pd_from_lower(X)
15-
L = lower_triangular(X)
16-
return L * L'
17-
end
18-
function pd_from_upper(X)
19-
U = upper_triangular(X)
20-
return U' * U
21-
end
16+
pd_from_lower(X) = PDMat(Cholesky(LowerTriangular(X)))
17+
pd_from_upper(X) = PDMat(Cholesky(UpperTriangular(X)))
2218

2319
# HACK: Allows us to define custom chain rules while we wait for upstream fixes.
2420
transpose_eager(X::AbstractMatrix) = permutedims(X)

test/bijectors/corr.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Bijectors, DistributionsAD, LinearAlgebra, Test
22
using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector
3+
using Random: Xoshiro
34

45
@testset "CorrBijector & VecCorrBijector" begin
56
for d in [1, 2, 5]
@@ -43,6 +44,20 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector
4344
@test size(dist_unconstrained) == size(x)
4445
@test dist_unconstrained isa MatrixDistribution
4546
end
47+
48+
@testset "Pathological samples for invlink" begin
49+
# see https://github.com/TuringLang/Bijectors.jl/issues/387
50+
d = LKJ(3, 3.0)
51+
for i in 1:100
52+
rng = Xoshiro(i)
53+
y = randn(rng, 3) * 15
54+
f_inv = inverse(bijector(d))
55+
x = f_inv(y)
56+
@test logpdf(d, x) isa Float64 # used to crash.
57+
x, _ = with_logabsdet_jacobian(f_inv, y)
58+
@test logpdf(d, x) isa Float64
59+
end
60+
end
4661
end
4762

4863
@testset "VecCholeskyBijector" begin

0 commit comments

Comments
 (0)