
Commit a5cdd5b

[MRG] Gaussian Gromov wasserstein solvers (#498)
* gaussian gromov distance
* debug mapping function
* cleanup debug and pep8
* add test for different source sizes
* add transport classes with GW
* update linear mapping example
* pep8
* gaussian gromov with skew sign alignment
* add sign_eigs to DA class
* change in release file
* debug code coverage
* documentation fix
1 parent f98698e commit a5cdd5b

File tree

10 files changed: +626 / -20 lines


.github/workflows/build_tests.yml

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ jobs:
         pip install pytest pytest-cov
     - name: Run tests
       run: |
-        python -m pytest --durations=20 -v test/ ot/ --doctest-modules --color=yes --cov-report=xml
+        python -m pytest --durations=20 -v test/ ot/ --doctest-modules --color=yes --cov=./ --cov-report=xml
     - name: Upload coverage reports to Codecov with GitHub Action
       uses: codecov/codecov-action@v3

README.md

Lines changed: 5 additions & 0 deletions
@@ -324,3 +324,8 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
 [55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations (ICLR).

 [56] Jeffery Kline. [Properties of the d-dimensional earth mover’s problem](https://www.sciencedirect.com/science/article/pii/S0166218X19301441). Discrete Applied Mathematics, 265: 128–141, 2019.
+
+[57] Delon, J., Desolneux, A., & Salmona, A. (2022). [Gromov–Wasserstein distances between Gaussian distributions](https://hal.science/hal-03197398v2/file/main.pdf). Journal of Applied Probability, 59(4), 1178-1198.
+

RELEASES.md

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,7 @@

 This new release contains several new features and bug fixes.

-New features include a new submodule `ot.gnn` that contains two new Graph neural network layers (compatible with [Pytorch Geometric](https://pytorch-geometric.readthedocs.io/)) for template-based pooling of graphs with an example on [graph classification](https://pythonot.github.io/master/auto_examples/gromov/plot_gnn_TFGW.html). Related to this, we also now provide FGW and semi-relaxed FGW solvers for which the resulting loss is differentiable w.r.t. the parameter `alpha`. Other contributions on the (F)GW front include a new solver for the Proximal Point algorithm [that can be used to solve entropic GW problems](https://pythonot.github.io/master/auto_examples/gromov/plot_fgw_solvers.html) (using the parameter `solver="PPA"`), novel Sinkhorn-based solvers for entropic semi-relaxed (F)GW, the possibility to provide a warm-start to the solvers, and optional marginal weights of the samples (uniform weights are used by default).
+New features include a new submodule `ot.gnn` that contains two new Graph neural network layers (compatible with [Pytorch Geometric](https://pytorch-geometric.readthedocs.io/)) for template-based pooling of graphs with an example on [graph classification](https://pythonot.github.io/master/auto_examples/gromov/plot_gnn_TFGW.html). Related to this, we also now provide FGW and semi-relaxed FGW solvers for which the resulting loss is differentiable w.r.t. the parameter `alpha`. Other contributions on the (F)GW front include a new solver for the Proximal Point algorithm [that can be used to solve entropic GW problems](https://pythonot.github.io/master/auto_examples/gromov/plot_fgw_solvers.html) (using the parameter `solver="PPA"`), new solvers for entropic FGW barycenters, novel Sinkhorn-based solvers for entropic semi-relaxed (F)GW, the possibility to provide a warm-start to the solvers, and optional marginal weights of the samples (uniform weights are used by default). Finally, we added to the submodules `ot.gaussian` and `ot.da` new loss and mapping estimators for the Gaussian Gromov-Wasserstein distance, which can be used as a fast alternative to GW and estimate linear mappings between unregistered spaces that can have different dimensions (see the updated [linear mapping example](https://pythonot.github.io/master/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) for an illustration).

 We also provide a new solver for the [Entropic Wasserstein Component Analysis](https://pythonot.github.io/master/auto_examples/others/plot_EWCA.html) that is a generalization of the celebrated PCA taking into account the local neighborhood of the samples. We also now have a new solver in `ot.smooth` for the [sparsity-constrained OT (last plot)](https://pythonot.github.io/master/auto_examples/plot_OT_1D_smooth.html) that can be used to find regularized OT plans with sparsity constraints. Finally we have a first multi-marginal solver for regular 1D distributions with a Monge loss (see [here](https://pythonot.github.io/master/auto_examples/others/plot_dmmot.html)).

@@ -15,6 +15,7 @@ Many other bugs and issues have been fixed and we want to thank all the contribu


 #### New features
+- Gaussian Gromov Wasserstein loss and mapping (PR #498)
 - Template-based Fused Gromov Wasserstein GNN layer in `ot.gnn` (PR #488)
 - Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483)
 - Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
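To make the release note above concrete, here is a minimal usage sketch of the new Gaussian GW estimators. `empirical_gaussian_gromov_wasserstein_mapping` and `LinearGWTransport` appear verbatim in the updated example further down this diff; the distance function called at the end is an assumed name based on the release note and may differ from the merged API.

```python
# Hedged usage sketch (not part of this commit's diff).
import numpy as np
import ot

rng = np.random.RandomState(0)
xs = rng.randn(300, 2)                                          # source samples
xt = rng.randn(200, 2) @ np.array([[2., 0.], [0., .5]]) + 4.    # target samples

# Closed-form linear GW mapping between the Gaussian approximations
A, b = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(xs, xt)
xst = xs.dot(A) + b              # source samples mapped onto the target

# Same mapping through the scikit-learn style class added to ot.da
mapping = ot.da.LinearGWTransport()
mapping.fit(Xs=xs, Xt=xt)
xst2 = mapping.transform(Xs=xs)

# Assumed name for the new Gaussian GW loss; check the merged ot.gaussian API
gw_value = ot.gaussian.empirical_gaussian_gromov_wasserstein_distance(xs, xt)
```

Unlike the Bures-Wasserstein (Monge) mapping, no registration between the two spaces is assumed, which is why the release note mentions spaces of possibly different dimensions.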

examples/domain-adaptation/plot_otda_linear_mapping.py

Lines changed: 51 additions & 11 deletions
@@ -13,6 +13,8 @@
 # License: MIT License

 # sphinx_gallery_thumbnail_number = 2
+
+#%%
 import os
 from pathlib import Path

@@ -55,27 +57,43 @@
 plt.figure(1, (5, 5))
 plt.plot(xs[:, 0], xs[:, 1], '+')
 plt.plot(xt[:, 0], xt[:, 1], 'o')
-
+plt.legend(('Source', 'Target'))
+plt.title('Source and target distributions')
+plt.show()

 ##############################################################################
 # Estimate linear mapping and transport
 # -------------------------------------

+
+# Gaussian (linear) Monge mapping estimation
 Ae, be = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt)

 xst = xs.dot(Ae) + be

+# Gaussian (linear) GW mapping estimation
+Agw, bgw = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(xs, xt)
+
+xstgw = xs.dot(Agw) + bgw

 ##############################################################################
 # Plot transported samples
 # ------------------------

-plt.figure(1, (5, 5))
+plt.figure(2, (10, 5))
 plt.clf()
+plt.subplot(1, 2, 1)
 plt.plot(xs[:, 0], xs[:, 1], '+')
 plt.plot(xt[:, 0], xt[:, 1], 'o')
 plt.plot(xst[:, 0], xst[:, 1], '+')
-
+plt.legend(('Source', 'Target', 'Transp. Monge'), loc=0)
+plt.title('Transported samples with Monge')
+plt.subplot(1, 2, 2)
+plt.plot(xs[:, 0], xs[:, 1], '+')
+plt.plot(xt[:, 0], xt[:, 1], 'o')
+plt.plot(xstgw[:, 0], xstgw[:, 1], '+')
+plt.legend(('Source', 'Target', 'Transp. GW'), loc=0)
+plt.title('Transported samples with Gaussian GW')
 plt.show()

 ##############################################################################

@@ -112,8 +130,8 @@ def minmax(img):
 # Estimate mapping and adapt
 # ----------------------------

+# Monge mapping
 mapping = ot.da.LinearTransport()
-
 mapping.fit(Xs=X1, Xt=X2)


@@ -123,31 +141,53 @@ def minmax(img):
 I1t = minmax(mat2im(xst, I1.shape))
 I2t = minmax(mat2im(xts, I2.shape))

+# gaussian GW mapping
+
+mapping = ot.da.LinearGWTransport()
+mapping.fit(Xs=X1, Xt=X2)
+
+
+xstgw = mapping.transform(Xs=X1)
+xtsgw = mapping.inverse_transform(Xt=X2)
+
+I1tgw = minmax(mat2im(xstgw, I1.shape))
+I2tgw = minmax(mat2im(xtsgw, I2.shape))
+
 # %%


 ##############################################################################
 # Plot transformed images
 # -----------------------

-plt.figure(2, figsize=(10, 7))
+plt.figure(3, figsize=(14, 7))

-plt.subplot(2, 2, 1)
+plt.subplot(2, 3, 1)
 plt.imshow(I1)
 plt.axis('off')
 plt.title('Im. 1')

-plt.subplot(2, 2, 2)
+plt.subplot(2, 3, 4)
 plt.imshow(I2)
 plt.axis('off')
 plt.title('Im. 2')

-plt.subplot(2, 2, 3)
+plt.subplot(2, 3, 2)
 plt.imshow(I1t)
 plt.axis('off')
-plt.title('Mapping Im. 1')
+plt.title('Monge mapping Im. 1')

-plt.subplot(2, 2, 4)
+plt.subplot(2, 3, 5)
 plt.imshow(I2t)
 plt.axis('off')
-plt.title('Inverse mapping Im. 2')
+plt.title('Inverse Monge mapping Im. 2')
+
+plt.subplot(2, 3, 3)
+plt.imshow(I1tgw)
+plt.axis('off')
+plt.title('Gaussian GW mapping Im. 1')
+
+plt.subplot(2, 3, 6)
+plt.imshow(I2tgw)
+plt.axis('off')
+plt.title('Inverse Gaussian GW mapping Im. 2')
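The updated example computes the linear GW map in closed form. As a rough sketch of the construction behind it (following reference [57] added to the README, assuming equal dimensions and a fixed choice of eigenvector signs; this is not the library's implementation, whose sign handling is what the `sign_eigs` / "sign alignment" items in the commit message refer to):

```python
# Rough NumPy sketch only; for real use call
# ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping as the example does.
import numpy as np

def gaussian_gw_mapping_sketch(xs, xt, reg=1e-10):
    d = xs.shape[1]                            # assumes xs and xt share dimension d
    ms, mt = xs.mean(0), xt.mean(0)
    Cs = np.cov(xs.T) + reg * np.eye(d)        # source covariance
    Ct = np.cov(xt.T) + reg * np.eye(d)        # target covariance

    # np.linalg.eigh returns eigenvalues in ascending order; reverse so that the
    # largest source direction is matched with the largest target direction.
    ls, Vs = np.linalg.eigh(Cs)
    lt, Vt = np.linalg.eigh(Ct)
    Vs, Vt = Vs[:, ::-1], Vt[:, ::-1]
    scale = np.sqrt(lt[::-1] / ls[::-1])

    # Affine map x -> x A + b sending the source principal axes onto the target
    # ones (eigenvector signs are arbitrary here; the new sign_eigs option of the
    # POT estimators controls how they are aligned).
    A = Vs @ np.diag(scale) @ Vt.T
    b = mt - ms @ A
    return A, b
```

With `A, b` from this sketch, `xs.dot(A) + b` plays the same role as `xstgw` in the example above. The eigendecompositions and sign choices involved are exactly why the backend gains `eigh` and `sign` methods in the next file.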

ot/backend.py

Lines changed: 49 additions & 0 deletions
@@ -338,6 +338,15 @@ def minimum(self, a, b):
         """
         raise NotImplementedError()

+    def sign(self, a):
+        r""" Returns an element-wise indication of the sign of a number.
+
+        This function follows the api from :any:`numpy.sign`
+
+        See: https://numpy.org/doc/stable/reference/generated/numpy.sign.html
+        """
+        raise NotImplementedError()
+
     def dot(self, a, b):
         r"""
         Returns the dot product of two tensors.

@@ -858,6 +867,16 @@ def sqrtm(self, a):
         """
         raise NotImplementedError()

+    def eigh(self, a):
+        r"""
+        Computes the eigenvalues and eigenvectors of a symmetric tensor.
+
+        This function follows the api from :any:`scipy.linalg.eigh`.
+
+        See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.eigh.html
+        """
+        raise NotImplementedError()
+
     def kl_div(self, p, q, eps=1e-16):
         r"""
         Computes the Kullback-Leibler divergence.

@@ -1047,6 +1066,9 @@ def maximum(self, a, b):
     def minimum(self, a, b):
         return np.minimum(a, b)

+    def sign(self, a):
+        return np.sign(a)
+
     def dot(self, a, b):
         return np.dot(a, b)

@@ -1253,6 +1275,9 @@ def sqrtm(self, a):
         L, V = np.linalg.eigh(a)
         return (V * np.sqrt(L)[None, :]) @ V.T

+    def eigh(self, a):
+        return np.linalg.eigh(a)
+
     def kl_div(self, p, q, eps=1e-16):
         return np.sum(p * np.log(p / q + eps))

@@ -1415,6 +1440,9 @@ def maximum(self, a, b):
     def minimum(self, a, b):
         return jnp.minimum(a, b)

+    def sign(self, a):
+        return jnp.sign(a)
+
     def dot(self, a, b):
         return jnp.dot(a, b)

@@ -1631,6 +1659,9 @@ def sqrtm(self, a):
         L, V = jnp.linalg.eigh(a)
         return (V * jnp.sqrt(L)[None, :]) @ V.T

+    def eigh(self, a):
+        return jnp.linalg.eigh(a)
+
     def kl_div(self, p, q, eps=1e-16):
         return jnp.sum(p * jnp.log(p / q + eps))

@@ -1829,6 +1860,9 @@ def minimum(self, a, b):
         else:
             return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]

+    def sign(self, a):
+        return torch.sign(a)
+
     def dot(self, a, b):
         return torch.matmul(a, b)

@@ -2106,6 +2140,9 @@ def sqrtm(self, a):
         L, V = torch.linalg.eigh(a)
         return (V * torch.sqrt(L)[None, :]) @ V.T

+    def eigh(self, a):
+        return torch.linalg.eigh(a)
+
     def kl_div(self, p, q, eps=1e-16):
         return torch.sum(p * torch.log(p / q + eps))

@@ -2248,6 +2285,9 @@ def maximum(self, a, b):
     def minimum(self, a, b):
         return cp.minimum(a, b)

+    def sign(self, a):
+        return cp.sign(a)
+
     def abs(self, a):
         return cp.abs(a)

@@ -2495,6 +2535,9 @@ def sqrtm(self, a):
         L, V = cp.linalg.eigh(a)
         return (V * cp.sqrt(L)[None, :]) @ V.T

+    def eigh(self, a):
+        return cp.linalg.eigh(a)
+
     def kl_div(self, p, q, eps=1e-16):
         return cp.sum(p * cp.log(p / q + eps))

@@ -2642,6 +2685,9 @@ def maximum(self, a, b):
     def minimum(self, a, b):
         return tnp.minimum(a, b)

+    def sign(self, a):
+        return tnp.sign(a)
+
     def dot(self, a, b):
         if len(b.shape) == 1:
             if len(a.shape) == 1:

@@ -2902,6 +2948,9 @@ def sqrtm(self, a):
         L, V = tf.linalg.eigh(a)
         return (V * tf.sqrt(L)[None, :]) @ V.T

+    def eigh(self, a):
+        return tf.linalg.eigh(a)
+
     def kl_div(self, p, q, eps=1e-16):
         return tnp.sum(p * tnp.log(p / q + eps))
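The `sign` and `eigh` methods added above follow the usual `ot.backend` pattern: solver code fetches the backend matching its inputs with `get_backend` and calls `nx.sign` / `nx.eigh` instead of a framework-specific function, so the same code runs on NumPy, JAX, PyTorch, CuPy and TensorFlow tensors. A small hypothetical helper, runnable with NumPy inputs once this commit is merged:

```python
# Illustrative helper (not from this commit) using the two new backend methods.
import numpy as np
from ot.backend import get_backend

def eigh_with_sign_convention(C):
    # Eigendecompose a symmetric matrix and flip each eigenvector so its first
    # coordinate is non-negative (assumes non-zero first coordinates) -- one
    # simple way to resolve the sign ambiguity the commit deals with.
    nx = get_backend(C)              # backend matching the type of C
    evals, evecs = nx.eigh(C)        # new backend-agnostic eigendecomposition
    signs = nx.sign(evecs[0, :])     # new backend-agnostic sign
    return evals, evecs * signs[None, :]

C = np.cov(np.random.RandomState(0).randn(100, 3).T)
evals, evecs = eigh_with_sign_convention(C)
```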
