Skip to content

Commit 3572ba3

Browse files
Add implementation of discrete 1D optimal transport (#88)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent ea089f2 commit 3572ba3

File tree

4 files changed

+377
-157
lines changed

4 files changed

+377
-157
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimalTransport"
22
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
33
authors = ["zsteve <[email protected]>"]
4-
version = "0.3.5"
4+
version = "0.3.6"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
@@ -12,6 +12,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1212
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
1313
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
1414
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
15+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1516

1617
[compat]
1718
Distances = "0.9.0, 0.10"
@@ -20,6 +21,7 @@ IterativeSolvers = "0.8.4, 0.9"
2021
LogExpFunctions = "0.2"
2122
MathOptInterface = "0.9"
2223
QuadGK = "2"
24+
StatsBase = "0.33.8"
2325
julia = "1"
2426

2527
[extras]

src/OptimalTransport.jl

Lines changed: 3 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using LogExpFunctions: LogExpFunctions
1111
using MathOptInterface
1212
using Distributions
1313
using QuadGK
14+
using StatsBase: StatsBase
1415

1516
export sinkhorn, sinkhorn2
1617
export emd, emd2
@@ -21,110 +22,14 @@ export ot_cost, ot_plan
2122

2223
const MOI = MathOptInterface
2324

25+
include("exact.jl")
26+
2427
dot_matwise(x::AbstractMatrix, y::AbstractMatrix) = dot(x, y)
2528
function dot_matwise(x::AbstractArray, y::AbstractMatrix)
2629
xmat = reshape(x, size(x, 1) * size(x, 2), :)
2730
return reshape(reshape(y, 1, :) * xmat, size(x)[3:end])
2831
end
2932

30-
"""
31-
emd(μ, ν, C, optimizer)
32-
33-
Compute the optimal transport plan `γ` for the Monge-Kantorovich problem with source
34-
histogram `μ`, target histogram `ν`, and cost matrix `C` of size `(length(μ), length(ν))`
35-
which solves
36-
```math
37-
\\inf_{γ ∈ Π(μ, ν)} \\langle γ, C \\rangle.
38-
```
39-
40-
The corresponding linear programming problem is solved with the user-provided `optimizer`.
41-
Possible choices are `Tulip.Optimizer()` and `Clp.Optimizer()` in the `Tulip` and `Clp`
42-
packages, respectively.
43-
"""
44-
function emd(μ, ν, C, model::MOI.ModelLike)
45-
# check size of cost matrix
46-
nμ = length(μ)
47-
nν = length(ν)
48-
size(C) == (nμ, nν) || error("cost matrix `C` must be of size `(length(μ), length(ν))`")
49-
nC = length(C)
50-
51-
# define variables
52-
x = MOI.add_variables(model, nC)
53-
xmat = reshape(x, nμ, nν)
54-
55-
# define objective function
56-
T = float(eltype(C))
57-
zero_T = zero(T)
58-
MOI.set(
59-
model,
60-
MOI.ObjectiveFunction{MOI.ScalarAffineFunction{T}}(),
61-
MOI.ScalarAffineFunction(MOI.ScalarAffineTerm.(float.(vec(C)), x), zero_T),
62-
)
63-
MOI.set(model, MOI.ObjectiveSense(), MOI.MIN_SENSE)
64-
65-
# add non-negativity constraints
66-
for xi in x
67-
MOI.add_constraint(model, MOI.SingleVariable(xi), MOI.GreaterThan(zero_T))
68-
end
69-
70-
# add constraints for source
71-
for (i, μi) in zip(axes(xmat, 1), μ) # eachrow(xmat) is not available on Julia 1.0
72-
f = MOI.ScalarAffineFunction(
73-
[MOI.ScalarAffineTerm(one(μi), xi) for xi in view(xmat, i, :)], zero(μi)
74-
)
75-
MOI.add_constraint(model, f, MOI.EqualTo(μi))
76-
end
77-
78-
# add constraints for target
79-
for (i, νi) in zip(axes(xmat, 2), ν) # eachcol(xmat) is not available on Julia 1.0
80-
f = MOI.ScalarAffineFunction(
81-
[MOI.ScalarAffineTerm(one(νi), xi) for xi in view(xmat, :, i)], zero(νi)
82-
)
83-
MOI.add_constraint(model, f, MOI.EqualTo(νi))
84-
end
85-
86-
# compute optimal solution
87-
MOI.optimize!(model)
88-
status = MOI.get(model, MOI.TerminationStatus())
89-
status === MOI.OPTIMAL || error("failed to compute optimal transport plan: ", status)
90-
p = MOI.get(model, MOI.VariablePrimal(), x)
91-
γ = reshape(p, nμ, nν)
92-
93-
return γ
94-
end
95-
96-
"""
97-
emd2(μ, ν, C, optimizer; plan=nothing)
98-
99-
Compute the optimal transport cost (a scalar) for the Monge-Kantorovich problem with source
100-
histogram `μ`, target histogram `ν`, and cost matrix `C` of size `(length(μ), length(ν))`
101-
which is given by
102-
```math
103-
\\inf_{γ ∈ Π(μ, ν)} \\langle γ, C \\rangle.
104-
```
105-
106-
The corresponding linear programming problem is solved with the user-provided `optimizer`.
107-
Possible choices are `Tulip.Optimizer()` and `Clp.Optimizer()` in the `Tulip` and `Clp`
108-
packages, respectively.
109-
110-
A pre-computed optimal transport `plan` may be provided.
111-
"""
112-
function emd2(μ, ν, C, optimizer; plan=nothing)
113-
γ = if plan === nothing
114-
# compute optimal transport plan
115-
emd(μ, ν, C, optimizer)
116-
else
117-
# check dimensions
118-
size(C) == (length(μ), length(ν)) ||
119-
error("cost matrix `C` must be of size `(length(μ), length(ν))`")
120-
size(plan) == size(C) || error(
121-
"optimal transport plan `plan` and cost matrix `C` must be of the same size",
122-
)
123-
plan
124-
end
125-
return dot(γ, C)
126-
end
127-
12833
"""
12934
sinkhorn_gibbs(
13035
μ, ν, K; atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000
@@ -879,62 +784,4 @@ function quadreg(mu, nu, C, ϵ; θ=0.1, tol=1e-5, maxiter=50, κ=0.5, δ=1e-5)
879784
return sparse(γ')
880785
end
881786

882-
"""
883-
ot_cost(
884-
c, μ::ContinuousUnivariateDistribution, ν::UnivariateDistribution; plan=nothing
885-
)
886-
887-
Compute the optimal transport cost for the Monge-Kantorovich problem with univariate
888-
distributions `μ` and `ν` as source and target marginals and cost function `c` of
889-
the form ``c(x, y) = h(|x - y|)`` where ``h`` is a convex function.
890-
891-
In this setting, the optimal transport cost can be computed as
892-
```math
893-
\\int_0^1 c(F_\\mu^{-1}(x), F_\\nu^{-1}(x)) \\mathrm{d}x
894-
```
895-
where ``F_\\mu^{-1}`` and ``F_\\nu^{-1}`` are the quantile functions of `μ` and `ν`,
896-
respectively.
897-
898-
A pre-computed optimal transport `plan` may be provided.
899-
900-
See also: [`ot_plan`](@ref), [`emd2`](@ref)
901-
"""
902-
function ot_cost(
903-
c, μ::ContinuousUnivariateDistribution, ν::UnivariateDistribution; plan=nothing
904-
)
905-
cost, _ = if plan === nothing
906-
quadgk(0, 1) do q
907-
return c(quantile(μ, q), quantile(ν, q))
908-
end
909-
else
910-
quadgk(0, 1) do q
911-
x = quantile(μ, q)
912-
return c(x, plan(x))
913-
end
914-
end
915-
return cost
916-
end
917-
918-
"""
919-
ot_plan(c, μ::ContinuousUnivariateDistribution, ν::UnivariateDistribution)
920-
921-
Compute the optimal transport plan for the Monge-Kantorovich problem with univariate
922-
distributions `μ` and `ν` as source and target marginals and cost function `c` of
923-
the form ``c(x, y) = h(|x - y|)`` where ``h`` is a convex function.
924-
925-
In this setting, the optimal transport plan is the Monge map
926-
```math
927-
T = F_\\nu^{-1} \\circ F_\\mu
928-
```
929-
where ``F_\\mu`` is the cumulative distribution function of `μ` and ``F_\\nu^{-1}`` is the
930-
quantile function of `ν`.
931-
932-
See also: [`ot_cost`](@ref), [`emd`](@ref)
933-
"""
934-
function ot_plan(c, μ::ContinuousUnivariateDistribution, ν::UnivariateDistribution)
935-
# Use T instead of γ to indicate that this is a Monge map.
936-
T(x) = quantile(ν, cdf(μ, x))
937-
return T
938-
end
939-
940787
end

0 commit comments

Comments
 (0)