@@ -22,15 +22,10 @@ export ot_cost, ot_plan, wasserstein, squared2wasserstein
2222
2323const MOI = MathOptInterface
2424
25+ include (" utils.jl" )
2526include (" exact.jl" )
2627include (" wasserstein.jl" )
2728
28- dot_matwise (x:: AbstractMatrix , y:: AbstractMatrix ) = dot (x, y)
29- function dot_matwise (x:: AbstractArray , y:: AbstractMatrix )
30- xmat = reshape (x, size (x, 1 ) * size (x, 2 ), :)
31- return reshape (reshape (y, 1 , :) * xmat, size (x)[3 : end ])
32- end
33-
3429"""
3530 sinkhorn_gibbs(
3631 μ, ν, K; atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000
@@ -58,11 +53,12 @@ isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
5853The default `rtol` depends on the types of `μ`, `ν`, and `K`. After `maxiter` iterations,
5954the computation is stopped.
6055
61- Note that for a common kernel `K`, multiple histograms may be provided for a batch computation by passing `μ` and `ν`
62- as matrices whose columns `μ[:, i]` and `ν[:, i]` correspond to pairs of histograms.
63- The output are then matrices `u` and `v` such that `u[:, i]` and `v[:, i]` are the dual variables for `μ[:, i]` and `ν[:, i]`.
64-
65- In addition, the case where one of `μ` or `ν` is a single histogram and the other a matrix of histograms is supported.
56+ Batch computations for multiple histograms with a common Gibbs kernel `K` can be performed
57+ by passing `μ` or `ν` as matrices whose columns correspond to histograms. It is required
58+ that the number of source and target marginals is equal or that a single source or single
59+ target marginal is provided (either as a matrix or as a vector). The dual variables are
60+ returned as matrices `u` and `v` where `u[:, i]` and `v[:, i]` are the dual variables for
61+ the `i`th pair of source and target marginals.
6662"""
6763function sinkhorn_gibbs (
6864 μ,
@@ -87,43 +83,66 @@ function sinkhorn_gibbs(
8783 :sinkhorn_gibbs ,
8884 )
8985 end
90- if (size (μ, 2 ) != size (ν, 2 )) && (min (size (μ, 2 ), size (ν, 2 )) > 1 )
91- throw (
92- DimensionMismatch (
93- " Error: number of columns in μ and ν must coincide, if both are matrix valued" ,
94- ),
95- )
96- end
97- all (sum (μ; dims= 1 ) .≈ sum (ν; dims= 1 )) ||
98- throw (ArgumentError (" source and target marginals must have the same mass" ))
86+
87+ # checks
88+ size2 = checksize2 (μ, ν)
89+ checkbalanced (μ, ν)
9990
10091 # set default values of tolerances
10192 T = float (Base. promote_eltype (μ, ν, K))
10293 _atol = atol === nothing ? 0 : atol
10394 _rtol = rtol === nothing ? (_atol > zero (_atol) ? zero (T) : sqrt (eps (T))) : rtol
10495
105- # initial iteration
106- u = if isequal (size (μ, 2 ), size (ν, 2 ))
107- similar (μ)
108- else
109- repeat (similar (μ[:, 1 ]); outer= (1 , max (size (μ, 2 ), size (ν, 2 ))))
96+ # initialize iterates
97+ u = similar (μ, T, size (μ, 1 ), size2... )
98+ v = similar (ν, T, size (ν, 1 ), size2... )
99+ fill! (v, one (T))
100+
101+ # arrays for convergence check
102+ Kv = similar (u)
103+ mul! (Kv, K, v)
104+ tmp = similar (u)
105+ norm_μ = μ isa AbstractVector ? sum (abs, μ) : sum (abs, μ; dims= 1 )
106+ if u isa AbstractMatrix
107+ tmp2 = similar (u)
108+ norm_uKv = similar (u, 1 , size2... )
109+ norm_diff = similar (u, 1 , size2... )
110+ _isconverged = similar (u, Bool, 1 , size2... )
110111 end
111- u .= μ ./ vec (sum (K; dims= 2 ))
112- v = ν ./ (K' * u)
113- tmp1 = K * v
114- tmp2 = similar (u)
115112
116- norm_μ = sum (abs, μ; dims= 1 ) # for convergence check
117113 isconverged = false
118114 check_step = check_convergence === nothing ? 10 : check_convergence
119- for iter in 0 : maxiter
120- if iter % check_step == 0
121- # check source marginal
122- # do not overwrite `tmp1` but reuse it for computing `u` if not converged
123- @. tmp2 = u * tmp1
124- norm_uKv = sum (abs, tmp2; dims= 1 )
125- @. tmp2 = μ - tmp2
126- norm_diff = sum (abs, tmp2; dims= 1 )
115+ to_check_step = check_step
116+ for iter in 1 : maxiter
117+ # reduce counter
118+ to_check_step -= 1
119+
120+ # compute next iterate
121+ u .= μ ./ Kv
122+ mul! (v, K' , u)
123+ v .= ν ./ v
124+ mul! (Kv, K, v)
125+
126+ # check source marginal
127+ # always check convergence after the final iteration
128+ if to_check_step <= 0 || iter == maxiter
129+ # reset counter
130+ to_check_step = check_step
131+
132+ # do not overwrite `Kv` but reuse it for computing `u` if not converged
133+ tmp .= u .* Kv
134+ if u isa AbstractMatrix
135+ tmp2 .= abs .(tmp)
136+ sum! (norm_uKv, tmp2)
137+ else
138+ norm_uKv = sum (abs, tmp)
139+ end
140+ tmp .= abs .(μ .- tmp)
141+ if u isa AbstractMatrix
142+ sum! (norm_diff, tmp)
143+ else
144+ norm_diff = sum (tmp)
145+ end
127146
128147 @debug " Sinkhorn algorithm (" *
129148 string (iter) *
@@ -133,20 +152,17 @@ function sinkhorn_gibbs(
133152 string (maximum (norm_diff))
134153
135154 # check stopping criterion
136- if all (@. norm_diff < max (_atol, _rtol * max (norm_μ, norm_uKv)))
155+ isconverged = if u isa AbstractMatrix
156+ @. _isconverged = norm_diff < max (_atol, _rtol * max (norm_μ, norm_uKv))
157+ all (_isconverged)
158+ else
159+ norm_diff < max (_atol, _rtol * max (norm_μ, norm_uKv))
160+ end
161+ if isconverged
137162 @debug " Sinkhorn algorithm ($iter /$maxiter ): converged"
138- isconverged = true
139163 break
140164 end
141165 end
142-
143- # perform next iteration
144- if iter < maxiter
145- @. u = μ / tmp1
146- mul! (v, K' , u)
147- @. v = ν / v
148- mul! (tmp1, K, v)
149- end
150166 end
151167
152168 if ! isconverged
@@ -156,13 +172,6 @@ function sinkhorn_gibbs(
156172 return u, v
157173end
158174
159- function add_singleton (x:: AbstractArray , :: Val{dim} ) where {dim}
160- shape = ntuple (ndims (x) + 1 ) do i
161- return i < dim ? size (x, i) : (i > dim ? size (x, i - 1 ) : 1 )
162- end
163- return reshape (x, shape)
164- end
165-
166175"""
167176 sinkhorn(
168177 μ, ν, C, ε; atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000
@@ -188,10 +197,12 @@ isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
188197The default `rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations,
189198the computation is stopped.
190199
191- Note that for a common cost `C`, multiple histograms may be provided for a batch computation by passing `μ` and `ν`
192- as matrices whose columns `μ[:, i]` and `ν[:, i]` correspond to pairs of histograms.
193-
194- The output in this case is an `Array` `γ` of coupling matrices such that `γ[:, :, i]` is a coupling of `μ[:, i]` and `ν[:, i]`.
200+ Batch computations for multiple histograms with a common cost matrix `C` can be performed by
201+ passing `μ` or `ν` as matrices whose columns correspond to histograms. It is required that
202+ the number of source and target marginals is equal or that a single source or single target
203+ marginal is provided (either as a matrix or as a vector). The optimal transport plans are
204+ returned as a three-dimensional array where `γ[:, :, i]` is the optimal transport plan for
205+ the `i`th pair of source and target marginals.
195206
196207See also: [`sinkhorn2`](@ref)
197208"""
0 commit comments