Skip to content

Commit f03b05c

Browse files
Refactor sinkhorn and sinkhorn2 (#100)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent a2224ee commit f03b05c

21 files changed

+1643
-1094
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimalTransport"
22
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
33
authors = ["zsteve <[email protected]>"]
4-
version = "0.3.11"
4+
version = "0.3.12"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
@@ -10,6 +10,7 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1212
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
13+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1314
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1415
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
1516
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -21,6 +22,7 @@ Distributions = "0.24, 0.25"
2122
IterativeSolvers = "0.8.4, 0.9"
2223
LogExpFunctions = "0.2"
2324
MathOptInterface = "0.9"
25+
NNlib = "0.6, 0.7"
2426
PDMats = "0.11"
2527
QuadGK = "2"
2628
StatsBase = "0.33.8"

docs/src/index.md

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,33 @@ squared2wasserstein
2121
```@docs
2222
sinkhorn
2323
sinkhorn2
24-
sinkhorn_stabilized_epsscaling
25-
sinkhorn_stabilized
2624
sinkhorn_barycenter
2725
```
2826

27+
Currently the following variants of the Sinkhorn algorithm are supported:
28+
29+
```@docs
30+
SinkhornGibbs
31+
SinkhornStabilized
32+
SinkhornEpsilonScaling
33+
```
34+
35+
The following methods are deprecated and will be removed:
36+
37+
```@docs
38+
sinkhorn_stabilized
39+
sinkhorn_stabilized_epsscaling
40+
```
41+
2942
## Unbalanced optimal transport
43+
3044
```@docs
3145
sinkhorn_unbalanced
3246
sinkhorn_unbalanced2
3347
```
3448

3549
## Quadratically regularised optimal transport
50+
3651
```@docs
3752
quadreg
3853
```

src/OptimalTransport.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@ using MathOptInterface
1212
using Distributions
1313
using PDMats
1414
using QuadGK
15+
using NNlib: NNlib
1516
using StatsBase: StatsBase
1617

18+
export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling
19+
1720
export sinkhorn, sinkhorn2
1821
export emd, emd2
1922
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
@@ -27,8 +30,14 @@ include("distances/bures.jl")
2730
include("utils.jl")
2831
include("exact.jl")
2932
include("wasserstein.jl")
33+
3034
include("entropic/sinkhorn.jl")
35+
include("entropic/sinkhorn_gibbs.jl")
3136
include("entropic/sinkhorn_stabilized.jl")
37+
include("entropic/sinkhorn_epsscaling.jl")
38+
include("entropic/sinkhorn_unbalanced.jl")
39+
include("entropic/sinkhorn_barycenter.jl")
40+
3241
include("quadratic.jl")
3342

3443
end

0 commit comments

Comments
 (0)