Skip to content

Commit 7e8cbe2

Browse files
authored
Add wasserstein and squared2wasserstein (#91)
1 parent 3572ba3 commit 7e8cbe2

File tree

10 files changed

+307
-10
lines changed

10 files changed

+307
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimalTransport"
22
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
33
authors = ["zsteve <[email protected]>"]
4-
version = "0.3.6"
4+
version = "0.3.7"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

docs/Manifest.toml

Lines changed: 129 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,22 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1111

1212
[[BenchmarkTools]]
1313
deps = ["JSON", "Logging", "Printf", "Statistics", "UUIDs"]
14-
git-tree-sha1 = "068fda9b756e41e6c75da7b771e6f89fa8a43d15"
14+
git-tree-sha1 = "01ca3823217f474243cc2c8e6e1d1f45956fe872"
1515
uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
16-
version = "0.7.0"
16+
version = "1.0.0"
1717

1818
[[Bzip2_jll]]
1919
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
2020
git-tree-sha1 = "19a35467a82e236ff51bc17a3a44b69ef35185a2"
2121
uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0"
2222
version = "1.0.8+0"
2323

24+
[[ChainRulesCore]]
25+
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
26+
git-tree-sha1 = "4b28f88cecf5d9a07c85b9ce5209a361ecaff34a"
27+
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
28+
version = "0.9.45"
29+
2430
[[CodecBzip2]]
2531
deps = ["Bzip2_jll", "Libdl", "TranscodingStreams"]
2632
git-tree-sha1 = "2e62a725210ce3c3c2e1a3080190e7ca491f18d7"
@@ -33,16 +39,51 @@ git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da"
3339
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
3440
version = "0.7.0"
3541

42+
[[Compat]]
43+
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
44+
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
45+
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
46+
version = "3.30.0"
47+
48+
[[CompilerSupportLibraries_jll]]
49+
deps = ["Artifacts", "Libdl"]
50+
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
51+
52+
[[DataAPI]]
53+
git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d"
54+
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
55+
version = "1.6.0"
56+
57+
[[DataStructures]]
58+
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
59+
git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677"
60+
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
61+
version = "0.18.9"
62+
3663
[[Dates]]
3764
deps = ["Printf"]
3865
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
3966

67+
[[DelimitedFiles]]
68+
deps = ["Mmap"]
69+
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
70+
4071
[[Distances]]
4172
deps = ["LinearAlgebra", "Statistics", "StatsAPI"]
4273
git-tree-sha1 = "abe4ad222b26af3337262b8afb28fab8d215e9f8"
4374
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
4475
version = "0.10.3"
4576

77+
[[Distributed]]
78+
deps = ["Random", "Serialization", "Sockets"]
79+
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
80+
81+
[[Distributions]]
82+
deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
83+
git-tree-sha1 = "64a3e756c44dcf33bd33e7f500113d9992a02e92"
84+
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
85+
version = "0.25.2"
86+
4687
[[DocStringExtensions]]
4788
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
4889
git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
@@ -59,11 +100,17 @@ version = "0.26.3"
59100
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
60101
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
61102

103+
[[FillArrays]]
104+
deps = ["LinearAlgebra", "Random", "SparseArrays"]
105+
git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a"
106+
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
107+
version = "0.11.7"
108+
62109
[[HTTP]]
63110
deps = ["Base64", "Dates", "IniFile", "MbedTLS", "NetworkOptions", "Sockets", "URIs"]
64-
git-tree-sha1 = "b855bf8247d6e946c75bb30f593bfe7fe591058d"
111+
git-tree-sha1 = "86ed84701fbfd1142c9786f8e53c595ff5a4def9"
65112
uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3"
66-
version = "0.9.8"
113+
version = "0.9.10"
67114

68115
[[IOCapture]]
69116
deps = ["Logging"]
@@ -134,6 +181,12 @@ git-tree-sha1 = "32b517d4d8219d3bbab199de3416ace45010bdb3"
134181
uuid = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
135182
version = "2.8.0"
136183

184+
[[LogExpFunctions]]
185+
deps = ["DocStringExtensions", "LinearAlgebra"]
186+
git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885"
187+
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
188+
version = "0.2.4"
189+
137190
[[Logging]]
138191
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
139192

@@ -143,9 +196,9 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
143196

144197
[[MathOptInterface]]
145198
deps = ["BenchmarkTools", "CodecBzip2", "CodecZlib", "JSON", "JSONSchema", "LinearAlgebra", "MutableArithmetics", "OrderedCollections", "SparseArrays", "Test", "Unicode"]
146-
git-tree-sha1 = "cd3057ca89a9ab83ce37ec42324523b8db0c60dc"
199+
git-tree-sha1 = "575644e3c05b258250bb599e57cf73bbf1062901"
147200
uuid = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
148-
version = "0.9.21"
201+
version = "0.9.22"
149202

150203
[[MbedTLS]]
151204
deps = ["Dates", "MbedTLS_jll", "Random", "Sockets"]
@@ -157,6 +210,12 @@ version = "1.0.3"
157210
deps = ["Artifacts", "Libdl"]
158211
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
159212

213+
[[Missings]]
214+
deps = ["DataAPI"]
215+
git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7"
216+
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
217+
version = "1.0.0"
218+
160219
[[Mmap]]
161220
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
162221

@@ -172,17 +231,29 @@ version = "0.2.19"
172231
[[NetworkOptions]]
173232
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
174233

234+
[[OpenSpecFun_jll]]
235+
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
236+
git-tree-sha1 = "b9b8b8ed236998f91143938a760c2112dceeb2b4"
237+
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
238+
version = "0.5.4+0"
239+
175240
[[OptimalTransport]]
176-
deps = ["Distances", "IterativeSolvers", "LinearAlgebra", "MathOptInterface", "SparseArrays"]
241+
deps = ["Distances", "Distributions", "IterativeSolvers", "LinearAlgebra", "LogExpFunctions", "MathOptInterface", "QuadGK", "SparseArrays", "StatsBase"]
177242
path = ".."
178243
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
179-
version = "0.3.0"
244+
version = "0.3.6"
180245

181246
[[OrderedCollections]]
182247
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
183248
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
184249
version = "1.4.1"
185250

251+
[[PDMats]]
252+
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
253+
git-tree-sha1 = "f82a0e71f222199de8e9eb9a09977bd0767d52a0"
254+
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
255+
version = "0.11.0"
256+
186257
[[Parsers]]
187258
deps = ["Dates"]
188259
git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
@@ -203,6 +274,12 @@ version = "1.2.2"
203274
deps = ["Unicode"]
204275
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
205276

277+
[[QuadGK]]
278+
deps = ["DataStructures", "LinearAlgebra"]
279+
git-tree-sha1 = "12fbe86da16df6679be7521dfb39fbc861e1dc7b"
280+
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
281+
version = "2.4.1"
282+
206283
[[REPL]]
207284
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
208285
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
@@ -216,19 +293,47 @@ git-tree-sha1 = "b3fb709f3c97bfc6e948be68beeecb55a0b340ae"
216293
uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
217294
version = "1.1.1"
218295

296+
[[Rmath]]
297+
deps = ["Random", "Rmath_jll"]
298+
git-tree-sha1 = "bf3188feca147ce108c76ad82c2792c57abe7b1f"
299+
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
300+
version = "0.7.0"
301+
302+
[[Rmath_jll]]
303+
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
304+
git-tree-sha1 = "68db32dff12bb6127bac73c209881191bf0efbb7"
305+
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
306+
version = "0.3.0+0"
307+
219308
[[SHA]]
220309
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
221310

222311
[[Serialization]]
223312
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
224313

314+
[[SharedArrays]]
315+
deps = ["Distributed", "Mmap", "Random", "Serialization"]
316+
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
317+
225318
[[Sockets]]
226319
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
227320

321+
[[SortingAlgorithms]]
322+
deps = ["DataStructures"]
323+
git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96"
324+
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
325+
version = "1.0.0"
326+
228327
[[SparseArrays]]
229328
deps = ["LinearAlgebra", "Random"]
230329
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
231330

331+
[[SpecialFunctions]]
332+
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
333+
git-tree-sha1 = "371204984184315ed7228bcc604d08e1bbc18f31"
334+
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
335+
version = "1.4.2"
336+
232337
[[Statistics]]
233338
deps = ["LinearAlgebra", "SparseArrays"]
234339
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -238,6 +343,22 @@ git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510"
238343
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
239344
version = "1.0.0"
240345

346+
[[StatsBase]]
347+
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
348+
git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d"
349+
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
350+
version = "0.33.8"
351+
352+
[[StatsFuns]]
353+
deps = ["LogExpFunctions", "Rmath", "SpecialFunctions"]
354+
git-tree-sha1 = "30cd8c360c54081f806b1ee14d2eecbef3c04c49"
355+
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
356+
version = "0.9.8"
357+
358+
[[SuiteSparse]]
359+
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
360+
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
361+
241362
[[TOML]]
242363
deps = ["Dates"]
243364
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

docs/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
[deps]
2+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
23
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
34
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
45
OptimalTransport = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
56

67
[compat]
8+
Distributions = "0.25"
79
Documenter = "0.26"
810
Literate = "2.8"
911
OptimalTransport = "0.3"

docs/make.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using OptimalTransport
2+
using Distributions
3+
24
using Literate: Literate
35
using Pkg: Pkg
46

docs/src/index.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77
emd
88
emd2
99
ot_plan
10+
ot_plan(::Any, ::ContinuousUnivariateDistribution, ::UnivariateDistribution)
11+
ot_plan(::Any, ::DiscreteNonParametric, ::DiscreteNonParametric)
1012
ot_cost
13+
ot_cost(::Any, ::ContinuousUnivariateDistribution, ::UnivariateDistribution)
14+
ot_cost(::Any, ::DiscreteNonParametric, ::DiscreteNonParametric)
15+
wasserstein
16+
squared2wasserstein
1117
```
1218

1319
## Entropically regularised optimal transport

src/OptimalTransport.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ export emd, emd2
1818
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
1919
export sinkhorn_unbalanced, sinkhorn_unbalanced2
2020
export quadreg
21-
export ot_cost, ot_plan
21+
export ot_cost, ot_plan, wasserstein, squared2wasserstein
2222

2323
const MOI = MathOptInterface
2424

2525
include("exact.jl")
26+
include("wasserstein.jl")
2627

2728
dot_matwise(x::AbstractMatrix, y::AbstractMatrix) = dot(x, y)
2829
function dot_matwise(x::AbstractArray, y::AbstractMatrix)

src/exact.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,35 @@
1+
"""
2+
ot_plan(c, μ, ν; kwargs...)
3+
4+
Compute the optimal transport plan for the Monge-Kantorovich problem with source and target
5+
marginals `μ` and `ν` and cost `c`.
6+
7+
The optimal transport plan solves
8+
```math
9+
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\int c(x, y) \\, \\mathrm{d}\\gamma(x, y)
10+
```
11+
where ``\\Pi(\\mu, \\nu)`` denotes the couplings of ``\\mu`` and ``\\nu``.
12+
13+
See also: [`ot_cost`](@ref)
14+
"""
15+
function ot_plan end
16+
17+
"""
18+
ot_cost(c, μ, ν; kwargs...)
19+
20+
Compute the optimal transport cost for the Monge-Kantorovich problem with source and target
21+
marginals `μ` and `ν` and cost `c`.
22+
23+
The optimal transport cost is the scalar value
24+
```math
25+
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\int c(x, y) \\, \\mathrm{d}\\gamma(x, y)
26+
```
27+
where ``\\Pi(\\mu, \\nu)`` denotes the couplings of ``\\mu`` and ``\\nu``.
28+
29+
See also: [`ot_plan`](@ref)
30+
"""
31+
function ot_cost end
32+
133
#############
234
# Discrete OT
335
#############

src/wasserstein.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""
2+
wasserstein(μ, ν; metric=Euclidean(), p=Val(1), kwargs...)
3+
4+
Compute the `p`-Wasserstein distance with respect to the `metric` between measures `μ` and
5+
`ν`.
6+
7+
Order `p` can be provided as a scalar of type `Real` or as a parameter of a value type
8+
`Val(p)`. For certain combinations of `metric` and `p`, such as `metric=Euclidean()` and
9+
`p=Val(2)`, the computations are more efficient if `p` is specified as a value type. The
10+
remaining keyword arguments are forwarded to [`ot_cost`](@ref).
11+
12+
See also: [`squared2wasserstein`](@ref), [`ot_cost`](@ref)
13+
"""
14+
function wasserstein(μ, ν; metric=Euclidean(), p::Union{Real,Val}=Val(1), kwargs...)
15+
cost = ot_cost(p2distance(metric, p), μ, ν; kwargs...)
16+
return prt(cost, p)
17+
end
18+
19+
# compute the cost function corresponding to a metric and exponent `p`
20+
p2distance(metric, ::Val{1}) = metric
21+
p2distance(metric, ::Val{P}) where {P} = (x, y) -> metric(x, y)^P
22+
p2distance(d::Euclidean, ::Val{2}) = SqEuclidean(d.thresh)
23+
p2distance(metric, p) = (x, y) -> metric(x, y)^p
24+
25+
# compute the `p` root
26+
prt(x, ::Val{1}) = x
27+
prt(x, ::Val{2}) = sqrt(x)
28+
prt(x, ::Val{3}) = cbrt(x)
29+
prt(x, ::Val{P}) where {P} = x^(1 / P)
30+
prt(x, p) = x^(1 / p)
31+
32+
"""
33+
squared2wasserstein(μ, ν; metric=Euclidean(), kwargs...)
34+
35+
Compute the squared 2-Wasserstein distance with respect to the `metric` between measures `μ`
36+
and `ν`.
37+
38+
The remaining keyword arguments are forwarded to [`ot_cost`](@ref).
39+
40+
See also: [`wasserstein`](@ref), [`ot_cost`](@ref)
41+
"""
42+
function squared2wasserstein(μ, ν; metric=Euclidean(), kwargs...)
43+
return ot_cost(p2distance(metric, Val(2)), μ, ν; kwargs...)
44+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ const GROUP = get(ENV, "GROUP", "All")
2020
@safetestset "Unbalanced OT" begin
2121
include("unbalanced.jl")
2222
end
23+
@safetestset "Wasserstein distance" begin
24+
include("wasserstein.jl")
25+
end
2326
end
2427

2528
# CUDA requires Julia >= 1.6

0 commit comments

Comments
 (0)