Skip to content

Commit dc4dd48

Browse files
authored
Improve sinkhorn_gibbs (#90)
1 parent 7e8cbe2 commit dc4dd48

File tree

7 files changed

+329
-109
lines changed

7 files changed

+329
-109
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimalTransport"
22
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
33
authors = ["zsteve <[email protected]>"]
4-
version = "0.3.7"
4+
version = "0.3.8"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
@@ -25,6 +25,7 @@ StatsBase = "0.33.8"
2525
julia = "1"
2626

2727
[extras]
28+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2829
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2930
PythonOT = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
3031
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -33,4 +34,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3334
Tulip = "6dd1b50a-3aae-11e9-10b5-ef983d2400fa"
3435

3536
[targets]
36-
test = ["Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip"]
37+
test = ["ForwardDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip"]

src/OptimalTransport.jl

Lines changed: 70 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,10 @@ export ot_cost, ot_plan, wasserstein, squared2wasserstein
2222

2323
const MOI = MathOptInterface
2424

25+
include("utils.jl")
2526
include("exact.jl")
2627
include("wasserstein.jl")
2728

28-
dot_matwise(x::AbstractMatrix, y::AbstractMatrix) = dot(x, y)
29-
function dot_matwise(x::AbstractArray, y::AbstractMatrix)
30-
xmat = reshape(x, size(x, 1) * size(x, 2), :)
31-
return reshape(reshape(y, 1, :) * xmat, size(x)[3:end])
32-
end
33-
3429
"""
3530
sinkhorn_gibbs(
3631
μ, ν, K; atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000
@@ -58,11 +53,12 @@ isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
5853
The default `rtol` depends on the types of `μ`, `ν`, and `K`. After `maxiter` iterations,
5954
the computation is stopped.
6055
61-
Note that for a common kernel `K`, multiple histograms may be provided for a batch computation by passing `μ` and `ν`
62-
as matrices whose columns `μ[:, i]` and `ν[:, i]` correspond to pairs of histograms.
63-
The output are then matrices `u` and `v` such that `u[:, i]` and `v[:, i]` are the dual variables for `μ[:, i]` and `ν[:, i]`.
64-
65-
In addition, the case where one of `μ` or `ν` is a single histogram and the other a matrix of histograms is supported.
56+
Batch computations for multiple histograms with a common Gibbs kernel `K` can be performed
57+
by passing `μ` or `ν` as matrices whose columns correspond to histograms. It is required
58+
that the number of source and target marginals is equal or that a single source or single
59+
target marginal is provided (either as matrix or as vector). The optimal transport plans are
60+
returned as a three-dimensional array where `γ[:, :, i]` is the optimal transport plan for the
61+
`i`th pair of source and target marginals.
6662
"""
6763
function sinkhorn_gibbs(
6864
μ,
@@ -87,43 +83,66 @@ function sinkhorn_gibbs(
8783
:sinkhorn_gibbs,
8884
)
8985
end
90-
if (size(μ, 2) != size(ν, 2)) && (min(size(μ, 2), size(ν, 2)) > 1)
91-
throw(
92-
DimensionMismatch(
93-
"Error: number of columns in μ and ν must coincide, if both are matrix valued",
94-
),
95-
)
96-
end
97-
all(sum(μ; dims=1) .≈ sum(ν; dims=1)) ||
98-
throw(ArgumentError("source and target marginals must have the same mass"))
86+
87+
# checks
88+
size2 = checksize2(μ, ν)
89+
checkbalanced(μ, ν)
9990

10091
# set default values of tolerances
10192
T = float(Base.promote_eltype(μ, ν, K))
10293
_atol = atol === nothing ? 0 : atol
10394
_rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol
10495

105-
# initial iteration
106-
u = if isequal(size(μ, 2), size(ν, 2))
107-
similar(μ)
108-
else
109-
repeat(similar(μ[:, 1]); outer=(1, max(size(μ, 2), size(ν, 2))))
96+
# initialize iterates
97+
u = similar(μ, T, size(μ, 1), size2...)
98+
v = similar(ν, T, size(ν, 1), size2...)
99+
fill!(v, one(T))
100+
101+
# arrays for convergence check
102+
Kv = similar(u)
103+
mul!(Kv, K, v)
104+
tmp = similar(u)
105+
norm_μ = μ isa AbstractVector ? sum(abs, μ) : sum(abs, μ; dims=1)
106+
if u isa AbstractMatrix
107+
tmp2 = similar(u)
108+
norm_uKv = similar(u, 1, size2...)
109+
norm_diff = similar(u, 1, size2...)
110+
_isconverged = similar(u, Bool, 1, size2...)
110111
end
111-
u .= μ ./ vec(sum(K; dims=2))
112-
v = ν ./ (K' * u)
113-
tmp1 = K * v
114-
tmp2 = similar(u)
115112

116-
norm_μ = sum(abs, μ; dims=1) # for convergence check
117113
isconverged = false
118114
check_step = check_convergence === nothing ? 10 : check_convergence
119-
for iter in 0:maxiter
120-
if iter % check_step == 0
121-
# check source marginal
122-
# do not overwrite `tmp1` but reuse it for computing `u` if not converged
123-
@. tmp2 = u * tmp1
124-
norm_uKv = sum(abs, tmp2; dims=1)
125-
@. tmp2 = μ - tmp2
126-
norm_diff = sum(abs, tmp2; dims=1)
115+
to_check_step = check_step
116+
for iter in 1:maxiter
117+
# reduce counter
118+
to_check_step -= 1
119+
120+
# compute next iterate
121+
u .= μ ./ Kv
122+
mul!(v, K', u)
123+
v .= ν ./ v
124+
mul!(Kv, K, v)
125+
126+
# check source marginal
127+
# always check convergence after the final iteration
128+
if to_check_step <= 0 || iter == maxiter
129+
# reset counter
130+
to_check_step = check_step
131+
132+
# do not overwrite `Kv` but reuse it for computing `u` if not converged
133+
tmp .= u .* Kv
134+
if u isa AbstractMatrix
135+
tmp2 .= abs.(tmp)
136+
sum!(norm_uKv, tmp2)
137+
else
138+
norm_uKv = sum(abs, tmp)
139+
end
140+
tmp .= abs.(μ .- tmp)
141+
if u isa AbstractMatrix
142+
sum!(norm_diff, tmp)
143+
else
144+
norm_diff = sum(tmp)
145+
end
127146

128147
@debug "Sinkhorn algorithm (" *
129148
string(iter) *
@@ -133,20 +152,17 @@ function sinkhorn_gibbs(
133152
string(maximum(norm_diff))
134153

135154
# check stopping criterion
136-
if all(@. norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv)))
155+
isconverged = if u isa AbstractMatrix
156+
@. _isconverged = norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv))
157+
all(_isconverged)
158+
else
159+
norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv))
160+
end
161+
if isconverged
137162
@debug "Sinkhorn algorithm ($iter/$maxiter): converged"
138-
isconverged = true
139163
break
140164
end
141165
end
142-
143-
# perform next iteration
144-
if iter < maxiter
145-
@. u = μ / tmp1
146-
mul!(v, K', u)
147-
@. v = ν / v
148-
mul!(tmp1, K, v)
149-
end
150166
end
151167

152168
if !isconverged
@@ -156,13 +172,6 @@ function sinkhorn_gibbs(
156172
return u, v
157173
end
158174

159-
function add_singleton(x::AbstractArray, ::Val{dim}) where {dim}
160-
shape = ntuple(ndims(x) + 1) do i
161-
return i < dim ? size(x, i) : (i > dim ? size(x, i - 1) : 1)
162-
end
163-
return reshape(x, shape)
164-
end
165-
166175
"""
167176
sinkhorn(
168177
μ, ν, C, ε; atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000
@@ -188,10 +197,12 @@ isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
188197
The default `rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations,
189198
the computation is stopped.
190199
191-
Note that for a common cost `C`, multiple histograms may be provided for a batch computation by passing `μ` and `ν`
192-
as matrices whose columns `μ[:, i]` and `ν[:, i]` correspond to pairs of histograms.
193-
194-
The output in this case is an `Array` `γ` of coupling matrices such that `γ[:, :, i]` is a coupling of `μ[:, i]` and `ν[:, i]`.
200+
Batch computations for multiple histograms with a common cost matrix `C` can be performed by
201+
passing `μ` or `ν` as matrices whose columns correspond to histograms. It is required that
202+
the number of source and target marginals is equal or that a single source or single target
203+
marginal is provided (either as matrix or as vector). The optimal transport plans are
204+
returned as a three-dimensional array where `γ[:, :, i]` is the optimal transport plan for the
205+
`i`th pair of source and target marginals.
195206
196207
See also: [`sinkhorn2`](@ref)
197208
"""

src/utils.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""
2+
add_singleton(x::AbstractArray, ::Val{dim}) where {dim}
3+
4+
Add an additional dimension `dim` of size 1 to array `x`.
5+
"""
6+
function add_singleton(x::AbstractArray, ::Val{dim}) where {dim}
7+
shape = ntuple(max(ndims(x) + 1, dim)) do i
8+
return i < dim ? size(x, i) : (i > dim ? size(x, i - 1) : 1)
9+
end
10+
return reshape(x, shape)
11+
end
12+
13+
"""
14+
dot_matwise(x::AbstractArray, y::AbstractArray)
15+
16+
Compute the inner product of all matrices in `x` and `y`.
17+
18+
At least one of `x` and `y` has to be a matrix.
19+
"""
20+
dot_matwise(x::AbstractMatrix, y::AbstractMatrix) = dot(x, y)
21+
function dot_matwise(x::AbstractArray, y::AbstractMatrix)
22+
xmat = reshape(x, size(x, 1) * size(x, 2), :)
23+
return reshape(reshape(y, 1, :) * xmat, size(x)[3:end])
24+
end
25+
dot_matwise(x::AbstractMatrix, y::AbstractArray) = dot_matwise(y, x)
26+
27+
"""
28+
checksize2(x::AbstractVecOrMat, y::AbstractVecOrMat)
29+
30+
Check if arrays `x` and `y` are compatible, then return a tuple of their broadcasted second
31+
dimension.
32+
"""
33+
checksize2(::AbstractVector, ::AbstractVector) = ()
34+
function checksize2(μ::AbstractVecOrMat, ν::AbstractVecOrMat)
35+
size_μ_2 = size(μ, 2)
36+
size_ν_2 = size(ν, 2)
37+
if size_μ_2 > 1 && size_ν_2 > 1 && size_μ_2 != size_ν_2
38+
throw(DimensionMismatch("size of source and target marginals is not compatible"))
39+
end
40+
return (max(size_μ_2, size_ν_2),)
41+
end
42+
43+
"""
44+
checkbalanced(μ::AbstractVecOrMat, ν::AbstractVecOrMat)
45+
46+
Check that source and target marginals `μ` and `ν` are balanced.
47+
"""
48+
function checkbalanced(μ::AbstractVector, ν::AbstractVector)
49+
sum(μ) ≈ sum(ν) || throw(ArgumentError("source and target marginals are not balanced"))
50+
return nothing
51+
end
52+
function checkbalanced(x::AbstractVecOrMat, y::AbstractVecOrMat)
53+
all(isapprox.(sum(x; dims=1), sum(y; dims=1))) ||
54+
throw(ArgumentError("source and target marginals are not balanced"))
55+
return nothing
56+
end

0 commit comments

Comments
 (0)