@@ -8,8 +8,6 @@ using Distances
88using LinearAlgebra
99using IterativeSolvers, SparseArrays
1010using Requires
11- using FillArrays
12- using LazyArrays
1311using MathOptInterface
1412
1513export sinkhorn, sinkhorn2
@@ -27,8 +25,6 @@ function __init__()
2725 end
2826end
2927
30- include (" simplex.jl" )
31-
3228"""
3329 emd(μ, ν, C, optimizer)
3430
@@ -43,16 +39,54 @@ The corresponding linear programming problem is solved with the user-provided `o
4339Possible choices are `Tulip.Optimizer()` and `Clp.Optimizer()` in the `Tulip` and `Clp`
4440packages, respectively.
4541"""
46- function emd (μ, ν, C, optimizer )
42+ function emd (μ, ν, C, model :: MOI.ModelLike )
4743 # check size of cost matrix
48- m = length (μ)
49- n = length (ν)
50- size (C) == (m, n) || error (" cost matrix `C` must be of size `(length(μ), length(ν))`" )
51-
52- # solve linear programming problem
53- c, A, b = toSimplexFormat (μ, ν, C)
54- p = solveLP (c, A, b, optimizer)
55- γ = reshape (p, m, n)
44+ nμ = length (μ)
45+ nν = length (ν)
46+ size (C) == (nμ, nν) || error (" cost matrix `C` must be of size `(length(μ), length(ν))`" )
47+ nC = length (C)
48+
49+ # define variables
50+ x = MOI. add_variables (model, nC)
51+ xmat = reshape (x, nμ, nν)
52+
53+ # define objective function
54+ T = eltype (C)
55+ zero_T = zero (T)
56+ MOI. set (
57+ model,
58+ MOI. ObjectiveFunction {MOI.ScalarAffineFunction{T}} (),
59+ MOI. ScalarAffineFunction (MOI. ScalarAffineTerm .(vec (C), x), zero_T),
60+ )
61+ MOI. set (model, MOI. ObjectiveSense (), MOI. MIN_SENSE)
62+
63+ # add non-negativity constraints
64+ for xi in x
65+ MOI. add_constraint (model, MOI. SingleVariable (xi), MOI. GreaterThan (zero_T))
66+ end
67+
68+ # add constraints for source
69+ for (xs, μi) in zip (eachrow (xmat), μ)
70+ f = MOI. ScalarAffineFunction (
71+ [MOI. ScalarAffineTerm (one (μi), xi) for xi in xs], zero (μi)
72+ )
73+ MOI. add_constraint (model, f, MOI. EqualTo (μi))
74+ end
75+
76+ # add constraints for target
77+ for (xs, νi) in zip (eachcol (xmat), ν)
78+ f = MOI. ScalarAffineFunction (
79+ [MOI. ScalarAffineTerm (one (νi), xi) for xi in xs], zero (νi)
80+ )
81+ MOI. add_constraint (model, f, MOI. EqualTo (νi))
82+ end
83+
84+ # compute optimal solution
85+ MOI. optimize! (model)
86+ status = MOI. get (model, MOI. TerminationStatus ())
87+ status === MOI. OPTIMAL || error (" failed to compute optimal transport map: " , status)
88+ p = MOI. get (model, MOI. VariablePrimal (), x)
89+ γ = reshape (p, nμ, nν)
5690
5791 return γ
5892end
0 commit comments