Skip to content

Commit fb23ca3

Browse files
Optimal Transport for Multivariate Gaussians (#85)
Co-authored-by: David Widmann <[email protected]>
1 parent 4a26ec9 commit fb23ca3

File tree

8 files changed

+243
-2
lines changed

8 files changed

+243
-2
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ docs/src/examples/
3232

3333
# Files generated by Jupyter Notebooks
3434
*.ipynb_checkpoints
35+
*.ipynb

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimalTransport"
22
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
33
authors = ["zsteve <[email protected]>"]
4-
version = "0.3.8"
4+
version = "0.3.9"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
@@ -10,6 +10,7 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1212
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
13+
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1314
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
1415
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1516
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -20,6 +21,7 @@ Distributions = "0.25"
2021
IterativeSolvers = "0.8.4, 0.9"
2122
LogExpFunctions = "0.2"
2223
MathOptInterface = "0.9"
24+
PDMats = "0.11"
2325
QuadGK = "2"
2426
StatsBase = "0.33.8"
2527
julia = "1"
@@ -32,6 +34,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3234
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3335
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3436
Tulip = "6dd1b50a-3aae-11e9-10b5-ef983d2400fa"
37+
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
3538

3639
[targets]
37-
test = ["ForwardDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip"]
40+
test = ["ForwardDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip", "HCubature"]

src/OptimalTransport.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using IterativeSolvers, SparseArrays
1010
using LogExpFunctions: LogExpFunctions
1111
using MathOptInterface
1212
using Distributions
13+
using PDMats
1314
using QuadGK
1415
using StatsBase: StatsBase
1516

@@ -22,6 +23,7 @@ export ot_cost, ot_plan, wasserstein, squared2wasserstein
2223

2324
const MOI = MathOptInterface
2425

26+
include("distances/bures.jl")
2527
include("utils.jl")
2628
include("exact.jl")
2729
include("wasserstein.jl")

src/distances/bures.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Code from @devmotion
2+
# https://github.com/devmotion/\
3+
# CalibrationErrorsDistributions.jl/blob/main/src/distances/bures.jl
4+
5+
"""
6+
tr_sqrt(A::AbstractMatrix)
7+
8+
Compute ``\\operatorname{tr}\\big(A^{1/2}\\big)``.
9+
"""
10+
tr_sqrt(A::AbstractMatrix) = LinearAlgebra.tr(sqrt(A))
11+
tr_sqrt(A::PDMats.PDMat) = tr_sqrt(A.mat)
12+
tr_sqrt(A::PDMats.PDiagMat) = sum(sqrt, A.diag)
13+
tr_sqrt(A::PDMats.ScalMat) = A.dim * sqrt(A.value)
14+
15+
"""
16+
_gaussian_ot_A(A::AbstractMatrix, B::AbstractMatrix)
17+
18+
Compute
19+
```math
20+
A^{1/2} B A^{1/2}.
21+
```
22+
"""
23+
function _gaussian_ot_A(A::AbstractMatrix, B::AbstractMatrix)
24+
sqrt_A = sqrt(A)
25+
return sqrt_A * B * sqrt_A
26+
end
27+
function _gaussian_ot_A(A::PDMats.PDiagMat, B::AbstractMatrix)
28+
return sqrt.(A.diag) .* B .* sqrt.(A.diag')
29+
end
30+
function _gaussian_ot_A(A::StridedMatrix, B::PDMats.PDMat)
31+
return PDMats.X_A_Xt(B, sqrt(A))
32+
end
33+
_gaussian_ot_A(A::PDMats.PDMat, B::PDMats.PDMat) = _gaussian_ot_A(A.mat, B)
34+
_gaussian_ot_A(A::AbstractMatrix, B::PDMats.PDiagMat) = _gaussian_ot_A(B, A)
35+
_gaussian_ot_A(A::PDMats.PDMat, B::StridedMatrix) = _gaussian_ot_A(B, A)
36+
37+
"""
38+
sqbures(A::AbstractMatrix, B::AbstractMatrix)
39+
40+
Compute the squared Bures metric
41+
```math
42+
\\operatorname{tr}(A) + \\operatorname{tr}(B)
43+
- \\operatorname{tr}\\Big({\\big(A^{1/2} B A^{1/2}\\big)}^{1/2}\\Big).
44+
```
45+
"""
46+
function sqbures(A::AbstractMatrix, B::AbstractMatrix)
47+
return LinearAlgebra.tr(A) + LinearAlgebra.tr(B) - 2 * tr_sqrt(_gaussian_ot_A(A, B))
48+
end
49+
50+
# diagonal matrix
51+
function sqbures(A::PDMats.PDiagMat, B::PDMats.PDiagMat)
52+
if !(A.dim == B.dim)
53+
throw(ArgumentError("matrices must have the same dimensions."))
54+
end
55+
return sum(zip(A.diag, B.diag)) do (x, y)
56+
abs2(sqrt(x) - sqrt(y))
57+
end
58+
end
59+
60+
# scaled identity matrix
61+
function sqbures(A::PDMats.ScalMat, B::AbstractMatrix)
62+
return LinearAlgebra.tr(A) + LinearAlgebra.tr(B) - 2 * sqrt(A.value) * tr_sqrt(B)
63+
end
64+
sqbures(A::AbstractMatrix, B::PDMats.ScalMat) = sqbures(B, A)
65+
sqbures(A::PDMats.ScalMat, B::PDMats.ScalMat) = A.dim * abs2(sqrt(A.value) - sqrt(B.value))
66+
67+
# combinations
68+
function sqbures(A::PDMats.PDiagMat, B::PDMats.ScalMat)
69+
sqrt_B = sqrt(B.value)
70+
return sum(A.diag) do x
71+
abs2(sqrt(x) - sqrt_B)
72+
end
73+
end
74+
sqbures(A::PDMats.ScalMat, B::PDMats.PDiagMat) = sqbures(B, A)

src/exact.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,82 @@ end
345345
function _ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, plan)
346346
return dot(plan, StatsBase.pairwise(c, support(μ), support(ν)))
347347
end
348+
349+
################
350+
# OT Gaussians
351+
################
352+
353+
"""
354+
ot_cost(::SqEuclidean, μ::MvNormal, ν::MvNormal)
355+
356+
Compute the squared 2-Wasserstein distance between normal distributions `μ` and `ν` as
357+
source and target marginals.
358+
359+
In this setting, the optimal transport cost can be computed as
360+
```math
361+
W_2^2(\\mu, \\nu) = \\|m_\\mu - m_\\nu \\|^2 + \\mathcal{B}(\\Sigma_\\mu, \\Sigma_\\nu)^2,
362+
```
363+
where ``\\mu = \\mathcal{N}(m_\\mu, \\Sigma_\\mu)``,
364+
``\\nu = \\mathcal{N}(m_\\nu, \\Sigma_\\nu)``, and ``\\mathcal{B}`` is the Bures metric.
365+
366+
See also: [`ot_plan`](@ref), [`emd2`](@ref)
367+
"""
368+
function ot_cost(::SqEuclidean, μ::MvNormal, ν::MvNormal)
369+
return sqeuclidean.μ, ν.μ) + sqbures.Σ, ν.Σ)
370+
end
371+
372+
"""
373+
ot_cost(::SqEuclidean, μ::Normal, ν::Normal)
374+
375+
Compute the squared 2-Wasserstein distance between univariate normal distributions `μ` and
376+
`ν` as source and target marginals.
377+
378+
See also: [`ot_plan`](@ref), [`emd2`](@ref)
379+
"""
380+
function ot_cost(::SqEuclidean, μ::Normal, ν::Normal)
381+
return.μ - ν.μ)^2 +.σ - ν.σ)^2
382+
end
383+
384+
"""
385+
ot_plan(::SqEuclidean, μ::MvNormal, ν::MvNormal)
386+
387+
Compute the optimal transport plan for the Monge-Kantorovich problem with multivariate
388+
normal distributions `μ` and `ν` as source and target marginals and cost function
389+
``c(x, y) = \\|x - y\\|_2^2``.
390+
391+
In this setting, for ``\\mu = \\mathcal{N}(m_\\mu, \\Sigma_\\mu)`` and
392+
``\\nu = \\mathcal{N}(m_\\nu, \\Sigma_\\nu)``, the optimal transport plan is the Monge
393+
map
394+
```math
395+
T \\colon x \\mapsto m_\\nu
396+
+ \\Sigma_\\mu^{-1/2}
397+
{\\big(\\Sigma_\\mu^{1/2} \\Sigma_\\nu \\Sigma_\\mu^{1/2}\\big)}^{1/2}\\Sigma_\\mu^{-1/2}
398+
(x - m_\\mu).
399+
400+
See also: [`ot_cost`](@ref), [`emd`](@ref)
401+
"""
402+
function ot_plan(::SqEuclidean, μ::MvNormal, ν::MvNormal)
403+
Σμsqrt = μ.Σ^(-1 / 2)
404+
A = Σμsqrt * sqrt(_gaussian_ot_A.Σ, ν.Σ)) * Σμsqrt
405+
= μ.μ
406+
= ν.μ
407+
T(x) =+ A * (x - mμ)
408+
return T
409+
end
410+
411+
"""
412+
ot_plan(::SqEuclidean, μ::Normal, ν::Normal)
413+
414+
Compute the optimal transport plan for the Monge-Kantorovich problem with
415+
normal distributions `μ` and `ν` as source and target marginals and cost function
416+
``c(x, y) = \\|x - y\\|_2^2``.
417+
418+
See also: [`ot_cost`](@ref), [`emd`](@ref)
419+
"""
420+
function ot_plan(::SqEuclidean, μ::Normal, ν::Normal)
421+
= μ.μ
422+
= ν.μ
423+
a = ν.σ / μ.σ
424+
T(x) =+ a * (x - mμ)
425+
return T
426+
end

test/bures.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Code from @devmotion
2+
# https://github.com/devmotion/\
3+
# CalibrationErrorsDistributions.jl/blob/main/src/distances/bures.jl
4+
using OptimalTransport
5+
6+
using LinearAlgebra
7+
using Random
8+
using PDMats
9+
10+
@testset "bures.jl" begin
11+
function _sqbures(A, B)
12+
sqrt_A = sqrt(A)
13+
return tr(A) + tr(B) - 2 * tr(sqrt(sqrt_A * B * sqrt_A'))
14+
end
15+
16+
function rand_matrices(n)
17+
A = randn(n, n)
18+
B = A' * A + I
19+
return B, PDMat(B), PDiagMat(diag(B)), ScalMat(n, B[1])
20+
end
21+
22+
for (x, y) in Iterators.product(rand_matrices(10), rand_matrices(10))
23+
xfull = Matrix(x)
24+
yfull = Matrix(y)
25+
@test OptimalTransport.sqbures(x, y) _sqbures(xfull, yfull)
26+
end
27+
end

test/exact.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using PythonOT: PythonOT
55
using Tulip
66
using MathOptInterface
77
using Distributions
8+
using HCubature
89

910
using LinearAlgebra
1011
using Random
@@ -164,4 +165,55 @@ Random.seed!(100)
164165
@test c2 c
165166
end
166167
end
168+
169+
@testset "Multivariate Gaussians" begin
170+
@testset "translation with constant covariance" begin
171+
m = randn(100)
172+
τ = rand(100)
173+
Σ = Matrix(Hermitian(rand(100, 100) + 100I))
174+
μ = MvNormal(m, Σ)
175+
ν = MvNormal(m .+ τ, Σ)
176+
@test ot_cost(SqEuclidean(), μ, ν) norm(τ)^2
177+
178+
x = rand(100, 10)
179+
T = ot_plan(SqEuclidean(), μ, ν)
180+
@test pdf(ν, mapslices(T, x; dims=1)) pdf(μ, x)
181+
end
182+
183+
@testset "comparison to grid approximation" begin
184+
μ = MvNormal([0, 0], [1 0; 0 2])
185+
ν = MvNormal([10, 10], [2 0; 0 1])
186+
# Constructing circular grid approximation
187+
# Angular grid step
188+
θ = collect(0:0.2:(2π))
189+
θx = cos.(θ)
190+
θy = sin.(θ)
191+
# Radius grid step
192+
δ = collect(0:0.2:1)
193+
μsupp = [0.0 0.0]
194+
νsupp = [10.0 10.0]
195+
for i in δ[2:end]
196+
a = [θx .* i θy .* i * 2]
197+
b = [θx .* i * 2 θy .* i] .+ [10 10]
198+
μsupp = vcat(μsupp, a)
199+
νsupp = vcat(νsupp, b)
200+
end
201+
202+
# Create discretized distribution
203+
μprobs = pdf(μ, μsupp')
204+
μprobs = μprobs ./ sum(μprobs)
205+
νprobs = pdf(ν, νsupp')
206+
νprobs = νprobs ./ sum(νprobs)
207+
C = pairwise(SqEuclidean(), μsupp', νsupp')
208+
@test emd2(μprobs, νprobs, C, Tulip.Optimizer()) ot_cost(SqEuclidean(), μ, ν) rtol =
209+
1e-3
210+
211+
# Use hcubature integration to perform ``\\int c(x,T(x)) d\\mu``
212+
T = ot_plan(SqEuclidean(), μ, ν)
213+
c_hcubature, _ = hcubature([-10, -10], [10, 10]) do x
214+
return sqeuclidean(x, T(x)) * pdf(μ, x)
215+
end
216+
@test ot_cost(SqEuclidean(), μ, ν) c_hcubature rtol = 1e-3
217+
end
218+
end
167219
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ const GROUP = get(ENV, "GROUP", "All")
2626
@safetestset "Wasserstein distance" begin
2727
include("wasserstein.jl")
2828
end
29+
@safetestset "Bures distance" begin
30+
include("bures.jl")
31+
end
2932
end
3033

3134
# CUDA requires Julia >= 1.6

0 commit comments

Comments
 (0)