Skip to content

Commit 3674ceb

Browse files
authored
Simplify stabilized Sinkhorn algorithm (#99)
* Move file * Move test file * Improve stabilized Sinkhorn algorithm * Add test and remove redundant computations * Bump version
1 parent 24ecf5d commit 3674ceb

File tree

8 files changed

+451
-202
lines changed

8 files changed

+451
-202
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimalTransport"
22
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
33
authors = ["zsteve <[email protected]>"]
4-
version = "0.3.10"
4+
version = "0.3.11"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

src/OptimalTransport.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ include("distances/bures.jl")
2727
include("utils.jl")
2828
include("exact.jl")
2929
include("wasserstein.jl")
30-
include("entropic.jl")
30+
include("entropic/sinkhorn.jl")
31+
include("entropic/sinkhorn_stabilized.jl")
3132
include("quadratic.jl")
3233

3334
end

src/entropic.jl renamed to src/entropic/sinkhorn.jl

Lines changed: 0 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -461,136 +461,6 @@ function sinkhorn_unbalanced2(
461461
return dot(γ, C)
462462
end
463463

464-
"""
465-
sinkhorn_stabilized_epsscaling(μ, ν, C, ε; lambda = 0.5, k = 5, kwargs...)
466-
467-
Compute the optimal transport plan for the entropically regularized optimal transport problem
468-
with source and target marginals `μ` and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, and entropic regularisation parameter `ε`. Employs the log-domain stabilized algorithm of Schmitzer et al. [^S19] with ε-scaling.
469-
470-
`k` ε-scaling steps are used with scaling factor `lambda`, i.e. sequentially solve Sinkhorn using `sinkhorn_stabilized` with regularisation parameters
471-
``ε_i \\in [λ^{1-k}, \\ldots, λ^{-1}, 1] \\times ε``.
472-
473-
See also: [`sinkhorn_stabilized`](@ref), [`sinkhorn`](@ref)
474-
"""
475-
function sinkhorn_stabilized_epsscaling(μ, ν, C, ε; lambda=0.5, k=5, kwargs...)
476-
α = zero(μ)
477-
β = zero(ν)
478-
for ε_i in* lambda^(1 - j) for j in k:-1:1)
479-
@debug "Epsilon-scaling Sinkhorn algorithm: ε = $ε_i"
480-
α, β = sinkhorn_stabilized(
481-
μ, ν, C, ε_i; alpha=α, beta=β, return_duals=true, kwargs...
482-
)
483-
end
484-
gamma = similar(C)
485-
getK!(gamma, C, α, β, ε, μ, ν)
486-
return gamma
487-
end
488-
489-
function getK!(K, C, α, β, ε, μ, ν)
490-
@. K = exp(-(C - α - β') / ε) * μ * ν'
491-
return K
492-
end
493-
494-
"""
495-
sinkhorn_stabilized(μ, ν, C, ε; absorb_tol = 1e3, alpha_0 = zero(μ), beta = zero(ν), maxiter = 1_000, atol = tol, rtol=nothing, return_duals = false)
496-
497-
Compute the optimal transport plan for the entropically regularized optimal transport problem
498-
with source and target marginals `μ` and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, and entropic regularisation parameter `ε`. Employs the log-domain stabilized algorithm of Schmitzer et al. [^S19]
499-
500-
`alpha` and `beta` are initial scalings for the stabilized Gibbs kernel. If not specified, `alpha` and `beta` are initialised to zero.
501-
502-
If `return_duals = true`, then the optimal dual variables `(u, v)` corresponding to `(μ, ν)` are returned. Otherwise, the coupling `γ` is returned.
503-
504-
[^S19]: Schmitzer, B., 2019. Stabilized sparse scaling algorithms for entropy regularized transport problems. SIAM Journal on Scientific Computing, 41(3), pp.A1443-A1481.
505-
506-
See also: [`sinkhorn`](@ref)
507-
"""
508-
function sinkhorn_stabilized(
509-
μ,
510-
ν,
511-
C,
512-
ε;
513-
absorb_tol=1e3,
514-
maxiter=1_000,
515-
tol=nothing,
516-
atol=tol,
517-
rtol=nothing,
518-
check_convergence=10,
519-
alpha=zero(μ),
520-
beta=zero(ν),
521-
return_duals=false,
522-
)
523-
if tol !== nothing
524-
Base.depwarn(
525-
"keyword argument `tol` is deprecated, please use `atol` and `rtol`",
526-
:sinkhorn_stabilized,
527-
)
528-
end
529-
sum(μ) sum(ν) ||
530-
throw(ArgumentError("source and target marginals must have the same mass"))
531-
532-
T = float(Base.promote_eltype(μ, ν, C))
533-
_atol = atol === nothing ? 0 : atol
534-
_rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol
535-
536-
norm_μ = sum(abs, μ)
537-
isconverged = false
538-
539-
K = similar(C)
540-
gamma = similar(C)
541-
542-
getK!(K, C, alpha, beta, ε, μ, ν)
543-
u = μ ./ sum(K; dims=2)
544-
v = ν ./ (K' * u)
545-
tmp_u = similar(u)
546-
for iter in 0:maxiter
547-
if (max(norm(u, Inf), norm(v, Inf)) > absorb_tol)
548-
@debug "Absorbing (u, v) into (alpha, beta)"
549-
# absorb into α, β
550-
alpha += ε * log.(u)
551-
beta += ε * log.(v)
552-
u .= 1
553-
v .= 1
554-
getK!(K, C, alpha, beta, ε, μ, ν)
555-
end
556-
if iter % check_convergence == 0
557-
# check marginal
558-
getK!(gamma, C, alpha, beta, ε, μ, ν)
559-
@. gamma *= u * v'
560-
norm_diff = sum(abs, gamma * ones(size(ν)) - μ)
561-
norm_uKv = sum(abs, gamma)
562-
@debug "Stabilized Sinkhorn algorithm (" *
563-
string(iter) *
564-
"/" *
565-
string(maxiter) *
566-
": error of source marginal = " *
567-
string(norm_diff)
568-
569-
if norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv))
570-
@debug "Stabilized Sinkhorn algorithm ($iter/$maxiter): converged"
571-
isconverged = true
572-
break
573-
end
574-
end
575-
mul!(tmp_u, K, v)
576-
u = μ ./ tmp_u
577-
mul!(v, K', u)
578-
v = ν ./ v
579-
end
580-
581-
if !isconverged
582-
@warn "Stabilized Sinkhorn algorithm ($maxiter/$maxiter): not converged"
583-
end
584-
585-
alpha = alpha + ε * log.(u)
586-
beta = beta + ε * log.(v)
587-
if return_duals
588-
return alpha, beta
589-
end
590-
getK!(gamma, C, alpha, beta, ε, μ, ν)
591-
return gamma
592-
end
593-
594464
"""
595465
sinkhorn_barycenter(μ, C, ε, w; tol=1e-9, check_marginal_step=10, max_iter=1000)
596466

0 commit comments

Comments
 (0)