@@ -89,123 +89,35 @@ function build_convergence_cache(
8989 )
9090end
9191
92- # Sinkhorn algorithm
92+ # Sinkhorn algorithm steps (see solve!)
93+ function init_step! (solver:: SinkhornSolver )
94+ return A_batched_mul_B! (solver. cache. Kv, solver. cache. K, solver. cache. v)
95+ end
9396
94- function solve! (solver:: SinkhornSolver )
95- # unpack solver
97+ function step! (solver:: SinkhornSolver , iter:: Int )
9698 μ = solver. source
9799 ν = solver. target
98- atol = solver. atol
99- rtol = solver. rtol
100- maxiter = solver. maxiter
101- check_convergence = solver. check_convergence
102100 cache = solver. cache
103- convergence_cache = solver. convergence_cache
104-
105- # unpack cache
106101 u = cache. u
107102 v = cache. v
108- K = cache. K
109103 Kv = cache. Kv
104+ K = cache. K
110105
111- A_batched_mul_B! (Kv, K, v)
112-
113- isconverged = false
114- to_check_step = check_convergence
115- for iter in 1 : maxiter
116- # computations before the Sinkhorn iteration (e.g., absorption step)
117- prestep! (solver, iter)
118-
119- # perform Sinkhorn iteration
120- u .= μ ./ Kv
121- At_batched_mul_B! (v, K, u)
122- v .= ν ./ v
123- A_batched_mul_B! (Kv, K, v)
124-
125- # check source marginal
126- # always check convergence after the final iteration
127- to_check_step -= 1
128- if to_check_step == 0 || iter == maxiter
129- # reset counter
130- to_check_step = check_convergence
131-
132- isconverged, abserror = OptimalTransport. check_convergence (
133- μ, u, Kv, convergence_cache, atol, rtol
134- )
135- @debug string (solver. alg) *
136- " (" *
137- string (iter) *
138- " /" *
139- string (maxiter) *
140- " : absolute error of source marginal = " *
141- string (maximum (abserror))
142-
143- if isconverged
144- @debug " $(solver. alg) ($iter /$maxiter ): converged"
145- break
146- end
147- end
148- end
149-
150- if ! isconverged
151- @warn " $(solver. alg) ($maxiter /$maxiter ): not converged"
152- end
153-
154- return nothing
155- end
156-
157- # for single inputs
158- function check_convergence (
159- μ:: AbstractVector ,
160- u:: AbstractVector ,
161- Kv:: AbstractVector ,
162- cache:: SinkhornConvergenceCache ,
163- atol:: Real ,
164- rtol:: Real ,
165- )
166- # unpack
167- tmp = cache. tmp
168- norm_μ = cache. norm_source
169-
170- # do not overwrite `Kv` but reuse it for computing `u` if not converged
171- tmp .= u .* Kv
172- norm_uKv = sum (abs, tmp)
173- tmp .= abs .(μ .- tmp)
174- norm_diff = sum (tmp)
175-
176- isconverged = norm_diff < max (atol, rtol * max (norm_μ, norm_uKv))
177-
178- return isconverged, norm_diff
106+ u .= μ ./ Kv
107+ At_batched_mul_B! (v, K, u)
108+ v .= ν ./ v
109+ return A_batched_mul_B! (Kv, K, v)
179110end
180111
181- # for batches
182- function check_convergence (
183- μ:: AbstractVecOrMat ,
184- u:: AbstractMatrix ,
185- Kv:: AbstractMatrix ,
186- cache:: SinkhornBatchConvergenceCache ,
187- atol:: Real ,
188- rtol:: Real ,
189- )
190- # unpack
191- tmp = cache. tmp
192- tmp2 = cache. tmp2
193- norm_μ = cache. norm_source
194- norm_uKv = cache. norm_uKv
195- norm_diff = cache. norm_diff
196- isconverged = cache. isconverged
197-
198- # do not overwrite `Kv` but reuse it for computing `u` if not converged
199- tmp .= u .* Kv
200- tmp2 .= abs .(tmp)
201- sum! (norm_uKv, tmp2)
202- tmp .= abs .(μ .- tmp)
203- sum! (norm_diff, tmp)
204-
205- # check stopping criterion
206- @. isconverged = norm_diff < max (atol, rtol * max (norm_μ, norm_uKv))
207-
208- return all (isconverged), norm_diff
112+ function check_convergence (solver:: SinkhornSolver )
113+ return OptimalTransport. check_convergence (
114+ solver. source,
115+ solver. cache. u,
116+ solver. cache. Kv,
117+ solver. convergence_cache,
118+ solver. atol,
119+ solver. rtol,
120+ )
209121end
210122
211123# API
0 commit comments