Skip to content

Commit 5ead79b

Browse files
xzyu02rflamaryronakrm
authored
[MRG] Efficient Discrete Multi Marginal Optimal Transport (#454)
* add demd.py to ot, add plot_demd_*.py to examples, updated init.py in ot, build failed need to fix * update REAMDME.md with citation to iclr23 paper and example link * chaneg directory of examples, build successful * fix small latex bug * update all.rst, examples and demd have passed pep8 and pyflake * add more detailed comments for examples * TODO: test module for demd, wrong demd index after build * add test module * add contributors * pass pyflake checks, pass pep8 * added the PR to the RELEASES.md file * temporal changes with logs * init changes * merge examples, demd -> lp.dmmot * bug fix in plot_dmmot, some commenting/documenting edits * dmmot example cleanup, some comments/plotting edits * add dist_monge method * all dmmot methods takes (n, d) shape A as input (follows POT style) * passed pep8 and pyflake checks * resolve test fail issue * fix pep8 error * resolve issues from last review, pyflake and pep8 checked * add lr decay * add more examples, ground cost options, test for uniqueness * remove additional experiment setting, not needed in this PR * fixed line 14 1 blank line * fix gradient computation link * Update ot/lp/dmmot.py Store input variable instead of copying it --------- Co-authored-by: Rémi Flamary <[email protected]> Co-authored-by: Ronak <[email protected]>
1 parent a879690 commit 5ead79b

File tree

7 files changed

+590
-1
lines changed

7 files changed

+590
-1
lines changed

CONTRIBUTORS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ The contributors to this library are:
4242
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
4343
* [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
4444
* [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein)
45+
* [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
46+
* [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
4547
* [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)
4648

4749
## Acknowledgments

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ POT provides the following generic OT solvers (links to examples):
4343
* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
4444
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
4545
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
46+
* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50].
4647
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
4748

4849
POT provides the following Machine Learning related solvers:
@@ -319,3 +320,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
319320
[53] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022). [Template based graph neural network with optimal transport distances](https://papers.nips.cc/paper_files/paper/2022/file/4d3525bc60ba1adc72336c0392d3d902-Paper-Conference.pdf). Advances in Neural Information Processing Systems, 35.
320321

321322
[54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks](https://arxiv.org/pdf/2006.04804).
323+
324+
[55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations (ICLR).
325+
326+
[56] Jeffery Kline. [Properties of the d-dimensional earth mover’s problem](https://www.sciencedirect.com/science/article/pii/S0166218X19301441). Discrete Applied Mathematics, 265: 128–141, 2019.

RELEASES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
- Make marginal parameters optional for (F)GW solvers in `._gw`, `._bregman` and `._semirelaxed` (PR #455)
1919
- Add Entropic Wasserstein Component Analysis (ECWA) in ot.dr (PR #486)
2020

21+
- Added feature Efficient Discrete Multi Marginal Optimal Transport Regularization + examples (PR #454)
22+
2123
#### Closed issues
2224

2325
- Fix change in scipy API for `cdist` (PR #487)

examples/others/plot_dmmot.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# -*- coding: utf-8 -*-
2+
r"""
3+
===============================================================================
4+
Computing d-dimensional Barycenters via d-MMOT
5+
===============================================================================
6+
7+
When the cost is discretized (Monge), the d-MMOT solver can more quickly
8+
compute and minimize the distance between many distributions without the need
9+
for intermediate barycenter computations. This example compares the time to
10+
identify, and the quality of, solutions for the d-MMOT problem using a
11+
primal/dual algorithm and classical LP barycenter approaches.
12+
"""
13+
14+
# Author: Ronak Mehta <[email protected]>
15+
# Xizheng Yu <[email protected]>
16+
#
17+
# License: MIT License
18+
19+
# %%
20+
# Generating 2 distributions
21+
# -----
22+
import numpy as np
23+
import matplotlib.pyplot as pl
24+
import ot
25+
26+
np.random.seed(0)
27+
28+
n = 100
29+
d = 2
30+
# Gaussian distributions
31+
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m=mean, s=std
32+
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
33+
A = np.vstack((a1, a2)).T
34+
x = np.arange(n, dtype=np.float64)
35+
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')
36+
37+
pl.figure(1, figsize=(6.4, 3))
38+
pl.plot(x, a1, 'b', label='Source distribution')
39+
pl.plot(x, a2, 'r', label='Target distribution')
40+
pl.legend()
41+
42+
# %%
43+
# Minimize the distances among distributions, identify the Barycenter
44+
# -----
45+
# The objective being minimized is different for both methods, so the objective
46+
# values cannot be compared.
47+
48+
# L2 Iteration
49+
weights = np.ones(d) / d
50+
l2_bary = A.dot(weights)
51+
52+
print('LP Iterations:')
53+
weights = np.ones(d) / d
54+
lp_bary, lp_log = ot.lp.barycenter(
55+
A, M, weights, solver='interior-point', verbose=False, log=True)
56+
print('Time\t: ', ot.toc(''))
57+
print('Obj\t: ', lp_log['fun'])
58+
59+
print('')
60+
print('Discrete MMOT Algorithm:')
61+
ot.tic()
62+
barys, log = ot.lp.dmmot_monge_1dgrid_optimize(
63+
A, niters=4000, lr_init=1e-5, lr_decay=0.997, log=True)
64+
dmmot_obj = log['primal objective']
65+
print('Time\t: ', ot.toc(''))
66+
print('Obj\t: ', dmmot_obj)
67+
68+
# %%
69+
# Compare Barycenters in both methods
70+
# -----
71+
pl.figure(1, figsize=(6.4, 3))
72+
for i in range(len(barys)):
73+
if i == 0:
74+
pl.plot(x, barys[i], 'g-*', label='Discrete MMOT')
75+
else:
76+
continue
77+
# pl.plot(x, barys[i], 'g-*')
78+
pl.plot(x, lp_bary, label='LP Barycenter')
79+
pl.plot(x, l2_bary, label='L2 Barycenter')
80+
pl.plot(x, a1, 'b', label='Source distribution')
81+
pl.plot(x, a2, 'r', label='Target distribution')
82+
pl.title('Monge Cost: Barycenters from LP Solver and dmmot solver')
83+
pl.legend()
84+
85+
86+
# %%
87+
# More than 2 distributions
88+
# --------------------------------------------------
89+
# Generate 7 pseudorandom gaussian distributions with 50 bins.
90+
n = 50 # nb bins
91+
d = 7
92+
vecsize = n * d
93+
94+
data = []
95+
for i in range(d):
96+
m = n * (0.5 * np.random.rand(1)) * float(np.random.randint(2) + 1)
97+
a = ot.datasets.make_1D_gauss(n, m=m, s=5)
98+
data.append(a)
99+
100+
x = np.arange(n, dtype=np.float64)
101+
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')
102+
A = np.vstack(data).T
103+
104+
pl.figure(1, figsize=(6.4, 3))
105+
for i in range(len(data)):
106+
pl.plot(x, data[i])
107+
108+
pl.title('Distributions')
109+
pl.legend()
110+
111+
# %%
112+
# Minimizing Distances Among Many Distributions
113+
# ---------------
114+
# The objective being minimized is different for both methods, so the objective
115+
# values cannot be compared.
116+
117+
# Perform gradient descent optimization using the d-MMOT method.
118+
barys = ot.lp.dmmot_monge_1dgrid_optimize(
119+
A, niters=3000, lr_init=1e-4, lr_decay=0.997)
120+
121+
# after minimization, any distribution can be used as a estimate of barycenter.
122+
bary = barys[0]
123+
124+
# Compute 1D Wasserstein barycenter using the L2/LP method
125+
weights = ot.unif(d)
126+
l2_bary = A.dot(weights)
127+
lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point',
128+
verbose=False, log=True)
129+
130+
# %%
131+
# Compare Barycenters in both methods
132+
# ---------
133+
pl.figure(1, figsize=(6.4, 3))
134+
pl.plot(x, bary, 'g-*', label='Discrete MMOT')
135+
pl.plot(x, l2_bary, 'k', label='L2 Barycenter')
136+
pl.plot(x, lp_bary, 'k-', label='LP Wasserstein')
137+
pl.title('Barycenters')
138+
pl.legend()
139+
140+
# %%
141+
# Compare with original distributions
142+
# ---------
143+
pl.figure(1, figsize=(6.4, 3))
144+
for i in range(len(data)):
145+
pl.plot(x, data[i])
146+
for i in range(len(barys)):
147+
if i == 0:
148+
pl.plot(x, barys[i], 'g-*', label='Discrete MMOT')
149+
else:
150+
continue
151+
# pl.plot(x, barys[i], 'g')
152+
pl.plot(x, l2_bary, 'k^', label='L2')
153+
pl.plot(x, lp_bary, 'o', color='grey', label='LP')
154+
pl.title('Barycenters')
155+
pl.legend()
156+
pl.show()
157+
158+
# %%

ot/lp/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from . import cvx
1919
from .cvx import barycenter
20+
from .dmmot import *
2021

2122
# import compiled emd
2223
from .emd_wrap import emd_c, check_result, emd_1d_sorted
@@ -30,7 +31,8 @@
3031

3132
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
3233
'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter',
33-
'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle']
34+
'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle',
35+
'discrete_mmot', 'discrete_mmot_converge']
3436

3537

3638
def check_number_threads(numThreads):

0 commit comments

Comments
 (0)