Skip to content

Commit 6775a52

Browse files
ncourtyrflamaryagramfort
authored
[MRG] Sliced and 1D Wasserstein distances : backend versions (#256)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <[email protected]> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <[email protected]> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <[email protected]> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <[email protected]> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <[email protected]> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <[email protected]> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <[email protected]> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <[email protected]> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <[email protected]> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend * new backend functions for sliced * small indent pb * Optimized backendversion of sliced W * error in sliced W * after master merge * error sliced * error sliced * pep8 * test_sliced pep8 * doctest + precision for sliced * doctest * type win test_backend gather * type win test_backend gather * Update sliced.py change argument of padding pad_width * Update backend.py update redefinition * Update backend.py pep8 * Update backend.py pep 8 again.... * pep8 * build docs * emd2_1D example * refectoring emd_1d and variants * remove unused previous wasserstein_1d * pep8 * upate example * move stuff * tesys should work + implemù random backend * test random generayor functions * correction * better random generation * update sliced * update sliced * proper tests sliced * max sliced * chae file nam * add stuff * example sliced flow and barycenter * correct typo + update readme * exemple sliced flow done * pep8 * solver1d works * pep8 Co-authored-by: Rémi Flamary <[email protected]> Co-authored-by: Alexandre Gramfort <[email protected]>
1 parent a335324 commit 6775a52

File tree

15 files changed

+1244
-445
lines changed

15 files changed

+1244
-445
lines changed

README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ POT provides the following generic OT solvers (links to examples):
3333
* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
3434
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
3535
formulations).
36-
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32].
36+
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
3737
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays.
3838

3939
POT provides the following Machine Learning related solvers:
@@ -285,4 +285,11 @@ You can also post bug reports and feature requests in Github issues. Make sure t
285285

286286
[33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021
287287

288-
[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
288+
[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). [Interpolating between optimal transport and MMD using Sinkhorn divergences](http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
289+
290+
[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).
291+
292+
[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
293+
(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
294+
via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
295+
Machine Learning (pp. 4104-4113). PMLR.

docs/source/readme.rst

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ POT provides the following generic OT solvers (links to examples):
2424
for regularized OT [7].
2525
- Entropic regularization OT solver with `Sinkhorn Knopp
2626
Algorithm <auto_examples/plot_OT_1D.html>`__
27-
[2] , stabilized version [9] [10], greedy Sinkhorn [22] and
27+
[2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and
2828
`Screening Sinkhorn
2929
[26] <auto_examples/plot_screenkhorn_1D.html>`__.
3030
- Bregman projections for `Wasserstein
@@ -54,6 +54,9 @@ POT provides the following generic OT solvers (links to examples):
5454
solver <auto_examples/plot_stochastic.html>`__
5555
for Large-scale Optimal Transport (semi-dual problem [18] and dual
5656
problem [19])
57+
- `Stochastic solver of Gromov
58+
Wasserstein <auto_examples/gromov/plot_gromov.html>`__
59+
for large-scale problem with any loss functions [33]
5760
- Non regularized `free support Wasserstein
5861
barycenters <auto_examples/barycenters/plot_free_support_barycenter.html>`__
5962
[20].
@@ -137,19 +140,12 @@ following Python modules:
137140

138141
- Numpy (>=1.16)
139142
- Scipy (>=1.0)
140-
- Cython (>=0.23) (build only, not necessary when installing wheels
141-
from pip or conda)
143+
- Cython (>=0.23) (build only, not necessary when installing from pip
144+
or conda)
142145

143146
Pip installation
144147
^^^^^^^^^^^^^^^^
145148

146-
Note that due to a limitation of pip, ``cython`` and ``numpy`` need to
147-
be installed prior to installing POT. This can be done easily with
148-
149-
.. code:: console
150-
151-
pip install numpy cython
152-
153149
You can install the toolbox through PyPI with:
154150

155151
.. code:: console
@@ -183,7 +179,8 @@ without errors:
183179
184180
import ot
185181
186-
Note that for easier access the module is name ot instead of pot.
182+
Note that for easier access the module is named ``ot`` instead of
183+
``pot``.
187184

188185
Dependencies
189186
~~~~~~~~~~~~
@@ -222,7 +219,7 @@ Short examples
222219

223220
.. code:: python
224221
225-
# a and b are 1D histograms (sum to 1 and positive)
222+
# a,b are 1D histograms (sum to 1 and positive)
226223
# M is the ground cost matrix
227224
Wd = ot.emd2(a, b, M) # exact linear program
228225
Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT
@@ -232,7 +229,7 @@ Short examples
232229

233230
.. code:: python
234231
235-
# a and b are 1D histograms (sum to 1 and positive)
232+
# a,b are 1D histograms (sum to 1 and positive)
236233
# M is the ground cost matrix
237234
T = ot.emd(a, b, M) # exact linear program
238235
T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT
@@ -287,6 +284,10 @@ The contributors to this library are
287284
- `Ievgen Redko <https://ievred.github.io/>`__ (Laplacian DA, JCPOT)
288285
- `Adrien Corenflos <https://adriencorenflos.github.io/>`__ (Sliced
289286
Wasserstein Distance)
287+
- `Tanguy Kerdoncuff <https://hv0nnus.github.io/>`__ (Sampled Gromov
288+
Wasserstein)
289+
- `Minhui Huang <https://mhhuang95.github.io>`__ (Projection Robust
290+
Wasserstein Distance)
290291

291292
This toolbox benefit a lot from open source research and we would like
292293
to thank the following persons for providing some code (in various
@@ -476,6 +477,30 @@ of
476477
measures <https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf>`__,
477478
Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
478479

480+
[32] Huang, M., Ma S., Lai, L. (2021). `A Riemannian Block Coordinate
481+
Descent Method for Computing the Projection Robust Wasserstein
482+
Distance <http://proceedings.mlr.press/v139/huang21e.html>`__,
483+
Proceedings of the 38th International Conference on Machine Learning
484+
(ICML).
485+
486+
[33] Kerdoncuff T., Emonet R., Marc S. `Sampled Gromov
487+
Wasserstein <https://hal.archives-ouvertes.fr/hal-03232509/document>`__,
488+
Machine Learning Journal (MJL), 2021
489+
490+
[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A.,
491+
& Peyré, G. (2019, April). `Interpolating between optimal transport and
492+
MMD using Sinkhorn
493+
divergences <http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf>`__.
494+
In The 22nd International Conference on Artificial Intelligence and
495+
Statistics (pp. 2681-2690). PMLR.
496+
497+
[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N.,
498+
Koyejo, S., ... & Schwing, A. G. (2019). `Max-sliced wasserstein
499+
distance and its use for
500+
gans <https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf>`__.
501+
In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
502+
Recognition (pp. 10648-10656).
503+
479504
.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
480505
:target: https://badge.fury.io/py/POT
481506
.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
r"""
2+
=================================
3+
Sliced Wasserstein barycenter and gradient flow with PyTorch
4+
=================================
5+
6+
In this exemple we use the pytorch backend to optimize the sliced Wasserstein
7+
loss between two empirical distributions [31].
8+
9+
In the first example one we perform a
10+
gradient flow on the support of a distribution that minimize the sliced
11+
Wassersein distance as poposed in [36].
12+
13+
In the second exemple we optimize with a gradient descent the sliced
14+
Wasserstein barycenter between two distributions as in [31].
15+
16+
[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of
17+
measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
18+
19+
[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
20+
(2019, May). Sliced-Wasserstein flows: Nonparametric generative modeling
21+
via optimal transport and diffusions. In International Conference on
22+
Machine Learning (pp. 4104-4113). PMLR.
23+
24+
25+
"""
26+
# Author: Rémi Flamary <[email protected]>
27+
#
28+
# License: MIT License
29+
30+
31+
# %%
32+
# Loading the data
33+
34+
35+
import numpy as np
36+
import matplotlib.pylab as pl
37+
import torch
38+
import ot
39+
import matplotlib.animation as animation
40+
41+
I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
42+
I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2]
43+
44+
sz = I2.shape[0]
45+
XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
46+
47+
x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0
48+
x2 = np.stack((XX[I2 == 0] + 60, -YY[I2 == 0] + 32), 1) * 1.0
49+
x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0
50+
51+
pl.figure(1, (8, 4))
52+
pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
53+
pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
54+
55+
# %%
56+
# Sliced Wasserstein gradient flow with Pytorch
57+
# ---------------------------------------------
58+
59+
60+
device = "cuda" if torch.cuda.is_available() else "cpu"
61+
62+
# use pyTorch for our data
63+
x1_torch = torch.tensor(x1).to(device=device).requires_grad_(True)
64+
x2_torch = torch.tensor(x2).to(device=device)
65+
66+
67+
lr = 1e3
68+
nb_iter_max = 100
69+
70+
x_all = np.zeros((nb_iter_max, x1.shape[0], 2))
71+
72+
loss_iter = []
73+
74+
# generator for random permutations
75+
gen = torch.Generator()
76+
gen.manual_seed(42)
77+
78+
for i in range(nb_iter_max):
79+
80+
loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen)
81+
82+
loss_iter.append(loss.clone().detach().cpu().numpy())
83+
loss.backward()
84+
85+
# performs a step of projected gradient descent
86+
with torch.no_grad():
87+
grad = x1_torch.grad
88+
x1_torch -= grad * lr / (1 + i / 5e1) # step
89+
x1_torch.grad.zero_()
90+
x_all[i, :, :] = x1_torch.clone().detach().cpu().numpy()
91+
92+
xb = x1_torch.clone().detach().cpu().numpy()
93+
94+
pl.figure(2, (8, 4))
95+
pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$')
96+
pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
97+
pl.scatter(xb[:, 0], xb[:, 1], alpha=0.5, label='$\mu^{(100)}$')
98+
pl.title('Sliced Wasserstein gradient flow')
99+
pl.legend()
100+
ax = pl.axis()
101+
102+
# %%
103+
# Animate trajectories of the gradient flow along iteration
104+
# -------------------------------------------------------
105+
106+
pl.figure(3, (8, 4))
107+
108+
109+
def _update_plot(i):
110+
pl.clf()
111+
pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$')
112+
pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
113+
pl.scatter(x_all[i, :, 0], x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$')
114+
pl.title('Sliced Wasserstein gradient flow Iter. {}'.format(i))
115+
pl.axis(ax)
116+
return 1
117+
118+
119+
ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000)
120+
121+
# %%
122+
# Compute the Sliced Wasserstein Barycenter
123+
#
124+
x1_torch = torch.tensor(x1).to(device=device)
125+
x3_torch = torch.tensor(x3).to(device=device)
126+
xbinit = np.random.randn(500, 2) * 10 + 16
127+
xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True)
128+
129+
lr = 1e3
130+
nb_iter_max = 100
131+
132+
x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2))
133+
134+
loss_iter = []
135+
136+
# generator for random permutations
137+
gen = torch.Generator()
138+
gen.manual_seed(42)
139+
140+
alpha = 0.5
141+
142+
for i in range(nb_iter_max):
143+
144+
loss = alpha * ot.sliced_wasserstein_distance(xbary_torch, x3_torch, n_projections=50, seed=gen) \
145+
+ (1 - alpha) * ot.sliced_wasserstein_distance(xbary_torch, x1_torch, n_projections=50, seed=gen)
146+
147+
loss_iter.append(loss.clone().detach().cpu().numpy())
148+
loss.backward()
149+
150+
# performs a step of projected gradient descent
151+
with torch.no_grad():
152+
grad = xbary_torch.grad
153+
xbary_torch -= grad * lr # / (1 + i / 5e1) # step
154+
xbary_torch.grad.zero_()
155+
x_all[i, :, :] = xbary_torch.clone().detach().cpu().numpy()
156+
157+
xb = xbary_torch.clone().detach().cpu().numpy()
158+
159+
pl.figure(4, (8, 4))
160+
pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu$')
161+
pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
162+
pl.scatter(xb[:, 0] + 30, xb[:, 1], alpha=0.5, label='Barycenter')
163+
pl.title('Sliced Wasserstein barycenter')
164+
pl.legend()
165+
ax = pl.axis()
166+
167+
168+
# %%
169+
# Animate trajectories of the barycenter along gradient descent
170+
# -------------------------------------------------------
171+
172+
pl.figure(5, (8, 4))
173+
174+
175+
def _update_plot(i):
176+
pl.clf()
177+
pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$')
178+
pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
179+
pl.scatter(x_all[i, :, 0] + 30, x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$')
180+
pl.title('Sliced Wasserstein barycenter Iter. {}'.format(i))
181+
pl.axis(ax)
182+
return 1
183+
184+
185+
ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000)

0 commit comments

Comments
 (0)