Skip to content

Commit caec540

Browse files
zstevedevmotion
andauthored
Enable batch computation of OT via Sinkhorn (#67)
* allow barycenter to be computed with batch kernel reduction (changes calling convention) * sinkhorn batch computation * update example * sinkhorn batch computation * fix tests * format * fix type instability as per review * type instability fix attempt 2 * type instability fix attempt 3 * formatting * implement common output type for sinkhorn and sinkhorn2 * formatting * Update src/OptimalTransport.jl Co-authored-by: David Widmann <[email protected]> * removed multiple cost matrix from sinkhorn_barycenter, updated docs * Update examples/basic/script.jl Co-authored-by: David Widmann <[email protected]> * Update examples/basic/script.jl Co-authored-by: David Widmann <[email protected]> * allow one to many sinkhorn computation * rebase * fix output dimensions * format * update docstrings * increment version * update tests * formatting Co-authored-by: David Widmann <[email protected]>
1 parent 845edda commit caec540

File tree

4 files changed

+141
-76
lines changed

4 files changed

+141
-76
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimalTransport"
22
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
33
authors = ["zsteve <[email protected]>"]
4-
version = "0.3.4"
4+
version = "0.3.5"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

examples/basic/script.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,10 @@ heatmap(
206206
# the entropically regularised **barycenter** in $\mathcal{P}$ is the discrete probability
207207
# measure $\mu$ that solves
208208
# ```math
209-
# \inf_{\mu \in \mathcal{P}} \sum_{i = 1}^N \lambda_i \mathrm{entropicOT}^{\epsilon}_{C_i}(\mu, \mu_i)
209+
# \inf_{\mu \in \mathcal{P}} \sum_{i = 1}^N \lambda_i \operatorname{OT}_{\epsilon}(\mu, \mu_i)
210210
# ```
211-
# where $\mathrm{entropicOT}^\epsilon_{C_i}(\mu, \mu_i)$ denotes the entropically regularised
212-
# optimal transport cost with marginals $\mu$ and $\mu_i$, cost matrix $C_i$, and entropic
211+
# where $\operatorname{OT}_\epsilon(\mu, \mu_i)$ denotes the entropically regularised
212+
# optimal transport cost with marginals $\mu$ and $\mu_i$, cost matrix $C$, and entropic
213213
# regularisation parameter $\epsilon$.
214214
#
215215
# We set up two measures and compute the weighted barycenters. We choose weights
@@ -225,9 +225,8 @@ plt = plot(; size=(800, 400), legend=:outertopright)
225225
plot!(plt, support, mu1; label=raw"$\mu_1$")
226226
plot!(plt, support, mu2; label=raw"$\mu_2$")
227227

228-
mu = hcat(mu1, mu2)'
229-
C1 = C2 = pairwise(SqEuclidean(), support'; dims=2)
230-
C = [C1, C2]
228+
mu = hcat(mu1, mu2)
229+
C = pairwise(SqEuclidean(), support'; dims=2)
231230
for λ1 in (0.25, 0.5, 0.75)
232231
λ2 = 1 - λ1
233232
a = sinkhorn_barycenter(mu, C, 0.01, [λ1, λ2]; max_iter=1000)

src/OptimalTransport.jl

Lines changed: 69 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ export ot_cost, ot_plan
2121

2222
const MOI = MathOptInterface
2323

24+
"""
    dot_matwise(x, y)

Compute the inner product of the matrix `y` with every matrix slice of `x`.

If `x` is a matrix, the scalar `dot(x, y)` is returned. Otherwise the first two
dimensions of `x` are treated as matrix dimensions and all remaining dimensions as batch
dimensions: the result is an array of size `size(x)[3:end]` whose entries are
`sum(x[:, :, I...] .* y)`.

Throws a `DimensionMismatch` if the first two dimensions of a batched `x` do not match
`size(y)`.
"""
dot_matwise(x::AbstractMatrix, y::AbstractMatrix) = dot(x, y)
function dot_matwise(x::AbstractArray{<:Any,N}, y::AbstractMatrix) where {N}
    # Without this check an `x` whose slices merely have the same number of elements as
    # `y` (e.g. 4×1 slices and a 2×2 `y`) would silently produce a meaningless result.
    (size(x, 1), size(x, 2)) == size(y) || throw(
        DimensionMismatch(
            "the first two dimensions of `x` $((size(x, 1), size(x, 2))) must match the size of `y` $(size(y))",
        ),
    )
    # Flatten every matrix slice into one column so that all inner products are obtained
    # with a single matrix product.
    xmat = reshape(x, length(y), :)
    # Build the batch shape with a statically known length so the return type is inferred.
    batchdims = ntuple(i -> size(x, i + 2), Val(max(N - 2, 0)))
    return reshape(reshape(y, 1, :) * xmat, batchdims)
end
29+
2430
"""
2531
emd(μ, ν, C, optimizer)
2632
@@ -138,14 +144,19 @@ and ``v`` as
138144
```math
139145
\\gamma = \\operatorname{diag}(u) K \\operatorname{diag}(v).
140146
```
141-
142147
Every `check_convergence` steps it is assessed if the algorithm is converged by checking if
143148
the iterate of the transport plan `G` satisfies
144149
```julia
145150
isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
146151
```
147152
The default `rtol` depends on the types of `μ`, `ν`, and `K`. After `maxiter` iterations,
148153
the computation is stopped.
154+
155+
Note that for a common kernel `K`, multiple histograms may be provided for a batch computation by passing `μ` and `ν`
156+
as matrices whose columns `μ[:, i]` and `ν[:, i]` correspond to pairs of histograms.
157+
The outputs are then matrices `u` and `v` such that `u[:, i]` and `v[:, i]` are the dual variables for `μ[:, i]` and `ν[:, i]`.
158+
159+
In addition, the case where one of `μ` or `ν` is a single histogram and the other a matrix of histograms is supported.
149160
"""
150161
function sinkhorn_gibbs(
151162
μ,
@@ -170,7 +181,14 @@ function sinkhorn_gibbs(
170181
:sinkhorn_gibbs,
171182
)
172183
end
173-
sum(μ) ≈ sum(ν) ||
184+
if (size(μ, 2) != size(ν, 2)) && (min(size(μ, 2), size(ν, 2)) > 1)
185+
throw(
186+
DimensionMismatch(
187+
"Error: number of columns in μ and ν must coincide, if both are matrix valued",
188+
),
189+
)
190+
end
191+
all(sum(μ; dims=1) .≈ sum(ν; dims=1)) ||
174192
throw(ArgumentError("source and target marginals must have the same mass"))
175193

176194
# set default values of tolerances
@@ -179,32 +197,37 @@ function sinkhorn_gibbs(
179197
_rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol
180198

181199
# initial iteration
182-
u = μ ./ sum(K; dims=2)
200+
u = if isequal(size(μ, 2), size(ν, 2))
201+
similar(μ)
202+
else
203+
repeat(similar(μ[:, 1]); outer=(1, max(size(μ, 2), size(ν, 2))))
204+
end
205+
u .= μ ./ vec(sum(K; dims=2))
183206
v = ν ./ (K' * u)
184207
tmp1 = K * v
185208
tmp2 = similar(u)
186209

187-
norm_μ = sum(abs, μ) # for convergence check
210+
norm_μ = sum(abs, μ; dims=1) # for convergence check
188211
isconverged = false
189212
check_step = check_convergence === nothing ? 10 : check_convergence
190213
for iter in 0:maxiter
191214
if iter % check_step == 0
192215
# check source marginal
193216
# do not overwrite `tmp1` but reuse it for computing `u` if not converged
194217
@. tmp2 = u * tmp1
195-
norm_uKv = sum(abs, tmp2)
218+
norm_uKv = sum(abs, tmp2; dims=1)
196219
@. tmp2 = μ - tmp2
197-
norm_diff = sum(abs, tmp2)
220+
norm_diff = sum(abs, tmp2; dims=1)
198221

199222
@debug "Sinkhorn algorithm (" *
200223
string(iter) *
201224
"/" *
202225
string(maxiter) *
203226
": absolute error of source marginal = " *
204-
string(norm_diff)
227+
string(maximum(norm_diff))
205228

206229
# check stopping criterion
207-
if norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv))
230+
if all(@. norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv)))
208231
@debug "Sinkhorn algorithm ($iter/$maxiter): converged"
209232
isconverged = true
210233
break
@@ -227,6 +250,13 @@ function sinkhorn_gibbs(
227250
return u, v
228251
end
229252

253+
"""
    add_singleton(x::AbstractArray, ::Val{dim})

Return `x` reshaped so that a new dimension of length 1 is inserted at position `dim`.

The result has `ndims(x) + 1` dimensions: dimensions before `dim` keep their size, and
the dimensions of `x` from `dim` onwards are shifted one position to the right. For
example, a vector of length `n` becomes an `n × 1` array for `Val(2)` and a `1 × n` array
for `Val(1)`.

Throws an `ArgumentError` if `dim` is not an integer in `1:(ndims(x) + 1)`.
"""
function add_singleton(x::AbstractArray{<:Any,N}, ::Val{dim}) where {N,dim}
    # A `dim` beyond `N + 1` would otherwise silently yield an array whose singleton
    # dimension is not at the requested position.
    (dim isa Integer && 1 <= dim <= N + 1) || throw(
        ArgumentError("`dim` must be an integer between 1 and $(N + 1), got $dim")
    )
    # `Val(N + 1)` makes the length of the shape tuple known to the compiler.
    shape = ntuple(Val(N + 1)) do i
        return i < dim ? size(x, i) : (i > dim ? size(x, i - 1) : 1)
    end
    return reshape(x, shape)
end
259+
230260
"""
231261
sinkhorn(
232262
μ, ν, C, ε; atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000
@@ -252,6 +282,11 @@ isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
252282
The default `rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations,
253283
the computation is stopped.
254284
285+
Note that for a common cost `C`, multiple histograms may be provided for a batch computation by passing `μ` and `ν`
286+
as matrices whose columns `μ[:, i]` and `ν[:, i]` correspond to pairs of histograms.
287+
288+
The output in this case is an `Array` `γ` of coupling matrices such that `γ[:, :, i]` is a coupling of `μ[:, i]` and `ν[:, i]`.
289+
255290
See also: [`sinkhorn2`](@ref)
256291
"""
257292
function sinkhorn(μ, ν, C, ε; kwargs...)
@@ -260,8 +295,7 @@ function sinkhorn(μ, ν, C, ε; kwargs...)
260295

261296
# compute dual potentials
262297
u, v = sinkhorn_gibbs(μ, ν, K; kwargs...)
263-
264-
return K .* u .* v'
298+
return K .* add_singleton(u, Val(2)) .* add_singleton(v, Val(1))
265299
end
266300

267301
"""
@@ -286,18 +320,19 @@ function sinkhorn2(μ, ν, C, ε; regularization=false, plan=nothing, kwargs...)
286320
sinkhorn(μ, ν, C, ε; kwargs...)
287321
else
288322
# check dimensions
289-
size(C) == (length(μ), length(ν)) ||
290-
error("cost matrix `C` must be of size `(length(μ), length(ν))`")
291-
size(plan) == size(C) || error(
323+
size(C) == (size(μ, 1), size(ν, 1)) || error(
324+
"cost matrix `C` must be of size `(size(μ, dims = 1), size(ν, dims = 1))`",
325+
)
326+
(size(plan, 1), size(plan, 2)) == size(C) || error(
292327
"optimal transport plan `plan` and cost matrix `C` must be of the same size",
293328
)
294329
plan
295330
end
296-
297331
cost = if regularization
298-
dot(γ, C) + ε * sum(LogExpFunctions.xlogx, γ)
332+
dot_matwise(γ, C) .+
333+
ε * reshape(sum(LogExpFunctions.xlogx, γ; dims=(1, 2)), size(γ)[3:end])
299334
else
300-
dot(γ, C)
335+
dot_matwise(γ, C)
301336
end
302337

303338
return cost
@@ -668,54 +703,36 @@ function sinkhorn_stabilized(
668703
end
669704

670705
"""
671-
sinkhorn_barycenter(mu_all, C_all, eps, lambda_all; tol = 1e-9, check_marginal_step = 10, max_iter = 1000)
672-
673-
Compute the entropically regularised (i.e. Sinkhorn) barycenter for a collection of `N`
674-
histograms `mu_all` with respective cost matrices `C_all`, relative weights `lambda_all`,
675-
and entropic regularisation parameter `eps`.
706+
sinkhorn_barycenter(μ, C, ε, w; tol=1e-9, check_marginal_step=10, max_iter=1000)
676707
677-
- `mu_all` is taken to contain `N` histograms `mu_all[i, :]` for `math i = 1, \\ldots, N`.
678-
- `C_all` is taken to be a list of `N` cost matrices corresponding to the `mu_all[i, :]`.
679-
- `eps` is the scalar regularisation parameter.
680-
- `lambda_all` are positive weights.
681-
682-
Returns the entropically regularised barycenter of the `mu_all`, i.e. the distribution that minimises
708+
Compute the Sinkhorn barycenter for a collection of `N` histograms contained in the columns of `μ`, for a cost matrix `C` of size `(size(μ, 1), size(μ, 1))`, relative weights `w` of size `N`, and entropic regularisation parameter `ε`.
709+
Returns the entropically regularised barycenter of the `μ`, i.e. the histogram `ρ` of length `size(μ, 1)` that solves
683710
684711
```math
685-
\\min_{\\mu \\in \\Sigma} \\sum_{i = 1}^N \\lambda_i \\mathrm{entropicOT}^{\\epsilon}_{C_i}(\\mu, \\mu_i)
712+
\\min_{\\rho \\in \\Sigma} \\sum_{i = 1}^N w_i \\operatorname{OT}_{\\varepsilon}(\\mu_i, \\rho)
686713
```
687714
688-
where ``\\mathrm{entropicOT}^{\\epsilon}_{C}`` denotes the entropic optimal transport cost with cost ``C`` and entropic regularisation level ``\\epsilon``.
715+
where ``\\operatorname{OT}_{\\varepsilon}(\\mu, \\nu) = \\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\langle \\gamma, C \\rangle + \\varepsilon \\Omega(\\gamma)``
716+
is the entropic optimal transport loss with cost ``C`` and regularisation ``\\varepsilon``.
689717
"""
690-
function sinkhorn_barycenter(
691-
mu_all, C_all, eps, lambda_all; tol=1e-9, check_marginal_step=10, max_iter=1000
692-
)
693-
sums = sum(mu_all; dims=2)
718+
function sinkhorn_barycenter(μ, C, ε, w; tol=1e-9, check_marginal_step=10, max_iter=1000)
719+
sums = sum(μ; dims=1)
694720
if !isapprox(extrema(sums)...)
695721
throw(ArgumentError("Error: marginals are unbalanced"))
696722
end
697-
K_all = [exp.(-C_all[i] / eps) for i in 1:length(C_all)]
723+
K = exp.(-C / ε)
698724
converged = false
699-
v_all = ones(size(mu_all))
700-
u_all = ones(size(mu_all))
701-
N = size(mu_all, 1)
725+
v = ones(size(μ))
726+
u = ones(size(μ))
727+
N = size(μ, 2)
702728
for n in 1:max_iter
703-
for i in 1:N
704-
v_all[i, :] = mu_all[i, :] ./ (K_all[i]' * u_all[i, :])
705-
end
706-
a = ones(size(u_all, 2))
707-
for i in 1:N
708-
a = a .* ((K_all[i] * v_all[i, :]) .^ (lambda_all[i]))
709-
end
710-
for i in 1:N
711-
u_all[i, :] = a ./ (K_all[i] * v_all[i, :])
712-
end
729+
v = μ ./ (K' * u)
730+
a = ones(size(u, 1))
731+
a = prod((K * v)' .^ w; dims=1)'
732+
u = a ./ (K * v)
713733
if n % check_marginal_step == 0
714734
# check marginal errors
715-
err = maximum([
716-
maximum(abs.(mu_all[i, :] .- v_all[i, :] .* (K_all[i]' * u_all[i, :]))) for
717-
i in 1:N
718-
])
735+
err = maximum(abs.(μ .- v .* (K' * u)))
719736
@debug "Sinkhorn algorithm: iteration $n" err
720737
if err < tol
721738
converged = true
@@ -726,7 +743,7 @@ function sinkhorn_barycenter(
726743
if !converged
727744
@warn "Sinkhorn did not converge"
728745
end
729-
return u_all[1, :] .* (K_all[1] * v_all[1, :])
746+
return u[:, 1] .* (K * v[:, 1])
730747
end
731748

732749
"""

test/entropic.jl

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,53 @@ Random.seed!(100)
7979
# compare with POT
8080
c_pot = POT.sinkhorn2(μ, ν, C, eps; numItermax=5_000, stopThr=1e-6)[1]
8181
@test Float32(c_pot) ≈ c rtol = 1e-3
82+
83+
# batch
84+
d = 10
85+
μ = fill(Float32(1 / M), (M, d))
86+
ν = fill(Float32(1 / N), N)
87+
88+
γ_all = sinkhorn(μ, ν, C, eps; maxiter=5_000, rtol=1e-6)
89+
γ_pot = [
90+
POT.sinkhorn(μ[:, i], vec(ν), C, eps; numItermax=5_000, stopThr=1e-6) for
91+
i in 1:d
92+
]
93+
@test all([
94+
isapprox(Float32.(γ_pot[i]), γ_all[:, :, i]; rtol=1e-3) for i in 1:d
95+
])
96+
@test eltype(γ_all) == Float32
97+
end
98+
99+
@testset "batch" begin
100+
# create two sets of batch histograms
101+
d = 10
102+
μ = rand(Float64, (M, d))
103+
μ = μ ./ sum(μ; dims=1)
104+
ν = rand(Float64, (N, d))
105+
ν = ν ./ sum(ν; dims=1)
106+
107+
# create random cost matrix
108+
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
109+
110+
# compute optimal transport map (Julia implementation + POT)
111+
eps = 0.01
112+
γ_all = sinkhorn(μ, ν, C, eps; maxiter=5_000)
113+
γ_pot = [POT.sinkhorn(μ[:, i], ν[:, i], C, eps; numItermax=5_000) for i in 1:d]
114+
@test all([isapprox(γ_all[:, :, i], γ_pot[i]; rtol=1e-6) for i in 1:d])
115+
116+
c_all = sinkhorn2(μ, ν, C, eps; maxiter=5_000)
117+
c_pot = [
118+
POT.sinkhorn2(μ[:, i], ν[:, i], C, eps; numItermax=5_000)[1] for i in 1:d
119+
]
120+
@test c_all ≈ c_pot rtol = 1e-6
121+
122+
γ_all = sinkhorn(μ[:, 1], ν, C, eps; maxiter=5_000)
123+
γ_pot = [POT.sinkhorn(μ[:, 1], ν[:, i], C, eps; numItermax=5_000) for i in 1:d]
124+
@test all([isapprox(γ_all[:, :, i], γ_pot[i]; rtol=1e-6) for i in 1:d])
125+
126+
γ_all = sinkhorn(μ, ν[:, 1], C, eps; maxiter=5_000)
127+
γ_pot = [POT.sinkhorn(μ[:, i], ν[:, 1], C, eps; numItermax=5_000) for i in 1:d]
128+
@test all([isapprox(γ_all[:, :, i], γ_pot[i]; rtol=1e-6) for i in 1:d])
82129
end
83130

84131
@testset "deprecations" begin
@@ -146,22 +193,24 @@ Random.seed!(100)
146193
end
147194
end
148195

149-
@testset "sinkhorn_barycenter" begin
150-
# set up support
151-
support = range(-1; stop=1, length=250)
152-
μ1 = exp.(-(support .+ 0.5) .^ 2 ./ 0.1^2)
153-
μ1 ./= sum(μ1)
154-
μ2 = exp.(-(support .- 0.5) .^ 2 ./ 0.1^2)
155-
μ2 ./= sum(μ2)
156-
μ_all = hcat(μ1, μ2)'
157-
158-
# create cost matrix
159-
C = pairwise(SqEuclidean(), support'; dims=2)
160-
161-
# compute Sinkhorn barycenter (Julia implementation + POT)
162-
eps = 0.01
163-
μ_interp = sinkhorn_barycenter(μ_all, [C, C], eps, [0.5, 0.5])
164-
μ_interp_pot = POT.barycenter(μ_all', C, eps; weights=[0.5, 0.5], stopThr=1e-9)
165-
@test μ_interp ≈ μ_interp_pot
196+
@testset "sinkhorn barycenter" begin
197+
@testset "example" begin
198+
# set up support
199+
support = range(-1; stop=1, length=250)
200+
μ1 = exp.(-(support .+ 0.5) .^ 2 ./ 0.1^2)
201+
μ1 ./= sum(μ1)
202+
μ2 = exp.(-(support .- 0.5) .^ 2 ./ 0.1^2)
203+
μ2 ./= sum(μ2)
204+
μ_all = hcat(μ1, μ2)
205+
# create cost matrix
206+
C = pairwise(SqEuclidean(), support'; dims=2)
207+
208+
# compute Sinkhorn barycenter (Julia implementation + POT)
209+
eps = 0.01
210+
μ_interp = sinkhorn_barycenter(μ_all, C, eps, [0.5, 0.5])
211+
μ_interp_pot = POT.barycenter(μ_all, C, eps; weights=[0.5, 0.5], stopThr=1e-9)
212+
# need to use a larger tolerance here because of a quirk with the POT solver
213+
@test μ_interp ≈ μ_interp_pot rtol = 1e-6
214+
end
166215
end
167216
end

0 commit comments

Comments
 (0)