Skip to content

Commit d6ba08e

Browse files
authored
Merge pull request #43 from zsteve/dw/simplex
Simplify implementation of LP formulation
2 parents 9ad8f73 + 7530304 commit d6ba08e

File tree

3 files changed

+48
-100
lines changed

3 files changed

+48
-100
lines changed

Project.toml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
11
name = "OptimalTransport"
22
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
33
authors = ["zsteve <[email protected]>"]
4-
version = "0.2.0"
4+
version = "0.2.1"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
8-
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
98
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
10-
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
119
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1210
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
1311
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1412
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1513

1614
[compat]
1715
Distances = "0.9.0, 0.10"
18-
FillArrays = "0.9, 0.10, 0.11"
1916
IterativeSolvers = "0.8.4, 0.9"
20-
LazyArrays = "0.18, 0.19, 0.20, 0.21"
2117
MathOptInterface = "0.9"
2218
Requires = "1.1"
2319
julia = "1"

src/OptimalTransport.jl

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ using Distances
88
using LinearAlgebra
99
using IterativeSolvers, SparseArrays
1010
using Requires
11-
using FillArrays
12-
using LazyArrays
1311
using MathOptInterface
1412

1513
export sinkhorn, sinkhorn2
@@ -27,8 +25,6 @@ function __init__()
2725
end
2826
end
2927

30-
include("simplex.jl")
31-
3228
"""
3329
emd(μ, ν, C, optimizer)
3430
@@ -43,16 +39,54 @@ The corresponding linear programming problem is solved with the user-provided `o
4339
Possible choices are `Tulip.Optimizer()` and `Clp.Optimizer()` in the `Tulip` and `Clp`
4440
packages, respectively.
4541
"""
46-
function emd(μ, ν, C, optimizer)
42+
function emd(μ, ν, C, model::MOI.ModelLike)
4743
# check size of cost matrix
48-
m = length(μ)
49-
n = length(ν)
50-
size(C) == (m, n) || error("cost matrix `C` must be of size `(length(μ), length(ν))`")
51-
52-
# solve linear programming problem
53-
c, A, b = toSimplexFormat(μ, ν, C)
54-
p = solveLP(c, A, b, optimizer)
55-
γ = reshape(p, m, n)
44+
= length(μ)
45+
= length(ν)
46+
size(C) == (nμ, nν) || error("cost matrix `C` must be of size `(length(μ), length(ν))`")
47+
nC = length(C)
48+
49+
# define variables
50+
x = MOI.add_variables(model, nC)
51+
xmat = reshape(x, nμ, nν)
52+
53+
# define objective function
54+
T = eltype(C)
55+
zero_T = zero(T)
56+
MOI.set(
57+
model,
58+
MOI.ObjectiveFunction{MOI.ScalarAffineFunction{T}}(),
59+
MOI.ScalarAffineFunction(MOI.ScalarAffineTerm.(vec(C), x), zero_T),
60+
)
61+
MOI.set(model, MOI.ObjectiveSense(), MOI.MIN_SENSE)
62+
63+
# add non-negativity constraints
64+
for xi in x
65+
MOI.add_constraint(model, MOI.SingleVariable(xi), MOI.GreaterThan(zero_T))
66+
end
67+
68+
# add constraints for source
69+
for (xs, μi) in zip(eachrow(xmat), μ)
70+
f = MOI.ScalarAffineFunction(
71+
[MOI.ScalarAffineTerm(one(μi), xi) for xi in xs], zero(μi)
72+
)
73+
MOI.add_constraint(model, f, MOI.EqualTo(μi))
74+
end
75+
76+
# add constraints for target
77+
for (xs, νi) in zip(eachcol(xmat), ν)
78+
f = MOI.ScalarAffineFunction(
79+
[MOI.ScalarAffineTerm(one(νi), xi) for xi in xs], zero(νi)
80+
)
81+
MOI.add_constraint(model, f, MOI.EqualTo(νi))
82+
end
83+
84+
# compute optimal solution
85+
MOI.optimize!(model)
86+
status = MOI.get(model, MOI.TerminationStatus())
87+
status === MOI.OPTIMAL || error("failed to compute optimal transport map: ", status)
88+
p = MOI.get(model, MOI.VariablePrimal(), x)
89+
γ = reshape(p, nμ, nν)
5690

5791
return γ
5892
end

src/simplex.jl

Lines changed: 0 additions & 82 deletions
This file was deleted.

0 commit comments

Comments
 (0)