Skip to content

Commit 1f6d2df

Browse files
eloitanguyrflamarycedricvincentcuaz
authored
[MRG] OT barycenters for generic transport costs (#715)
* ot.lp reorganise to avoid def in __init__ * pr number + enabled pre-commit * added barycenter.py imports * fixed wrong import in ot.gmm * ruff fix attempt * removed ot bar contribs -> only o.lp reorganisation in this PR * add check_number_threads to ot/lp/__init__.py __all__ * update releases * made barycenter_solvers and network_simplex hidden + deprecated ot.lp.cvx * fix ref to lp.cvx in test * lp.cvx now imports barycenter and gives a warnings.warning * cvx import barycenter * added fixed-point barycenter function to ot.lp._barycenter_solvers_ * ot bar demo * ot bar doc * doc fixes + ot bar coverage * python 3.13 in test workflow + added ggmot barycenter (WIP) * fixed github action file * ot bar doc + test coverage * examples: ot bar with projections onto circles + gmm ot bar * releases + readme + docs update * ref fix * implementation comments * (WIP) added true barycenter fixed-point algorithm with updated tests and examples * test and fixes * no jax or tf support for free_support_generic_costs due to array assignment * updated gmm bar colours * applying PR comments (still some doc rendering issues to fix) * doc fixes * changed stop threshold method in groundBary and changed the example to have different support sizes * fix ground bary stop threshold formula * uncommited ground bar fix * merge master into dev * email fix * added gmm barycenter example to gaussian-gmm minigallery * more bar iterations for auto ground bary test --------- Co-authored-by: Rémi Flamary <[email protected]> Co-authored-by: Cédric Vincent-Cuaz <[email protected]>
1 parent d121bf8 commit 1f6d2df

18 files changed

+1662
-18
lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ The contributors to this library are:
4747
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW,
4848
semi-relaxed FGW, quantized FGW, partial FGW)
4949
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein
50-
Barycenters, GMMOT)
50+
Barycenters, GMMOT, Barycenters for General Transport Costs)
5151
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
5252
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
5353
* [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ POT provides the following generic OT solvers:
7070
* [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and
7171
[unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71].
7272
* Fused unbalanced Gromov-Wasserstein [70].
73+
* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [77]
74+
* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 77]
7375

7476
POT provides the following Machine Learning related solvers:
7577

@@ -436,3 +438,5 @@ Artificial Intelligence.
436438
[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145.
437439

438440
[76] Chapel, L., Tavenard, R. (2025). [One for all and all for one: Efficient computation of partial Wasserstein distances on the line](https://iclr.cc/virtual/2025/poster/28547). In International Conference on Learning Representations.
441+
442+
[77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)

RELEASES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
- Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)
88
- Automatic PR labeling and release file update check (PR #704)
99
- Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714)
10+
- Implement fixed-point solver for OT barycenters with generic cost functions
11+
(generalizes `ot.lp.free_support_barycenter`), with example. (PR #715)
12+
- Implement fixed-point solver for barycenters between GMMs (PR #715), with example.
1013
- Fix warning raise when import the library (PR #716)
1114
- Implement projected gradient descent solvers for entropic partial FGW (PR #702)
1215
- Fix documentation in the module `ot.gaussian` (PR #718)
Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
# -*- coding: utf-8 -*-
2+
r"""
3+
=====================================
4+
OT Barycenter with Generic Costs Demo
5+
=====================================
6+
7+
This example illustrates the computation of an Optimal Transport Barycenter for
8+
a ground cost that is not a power of a norm. We take the example of ground costs
9+
:math:`c_k(x, y) = \lambda_k\|P_k(x)-y\|_2^2`, where :math:`P_k` is the
10+
(non-linear) projection onto a circle k, and :math:`(\lambda_k)` are weights. A
11+
barycenter is defined ([77]) as a minimiser of the energy :math:`V(\mu) = \sum_k
12+
\mathcal{T}_{c_k}(\mu, \nu_k)` where :math:`\mu` is a candidate barycenter
13+
measure, the measures :math:`\nu_k` are the target measures and
14+
:math:`\mathcal{T}_{c_k}` is the OT cost for ground cost :math:`c_k`. This is an
15+
example of the fixed-point barycenter solver introduced in [77] which
16+
generalises [20] and [43].
17+
18+
The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in
19+
\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over
20+
:math:`x` with Pytorch.
21+
22+
We compare two algorithms from [77]: the first ([77], Algorithm 2,
23+
'true_fixed_point' in POT) has convergence guarantees but the iterations may
24+
increase in support size and thus require more computational resources. The
25+
second ([77], Algorithm 3, 'L2_barycentric_proj' in POT) is a simplified
26+
heuristic that imposes a fixed support size for the barycenter and fixed
27+
weights.
28+
29+
We initialise both algorithms with a support size of 136, computing a barycenter
30+
between measures with uniform weights and 50 points.
31+
32+
[77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
33+
Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016
34+
(2024)
35+
36+
[20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein
37+
Barycenters. InternationalConference in Machine Learning
38+
39+
[43] Álvarez-Esteban, Pedro C., et al. A fixed-point approach to barycenters in
40+
Wasserstein space. Journal of Mathematical Analysis and Applications 441.2
41+
(2016): 744-762.
42+
43+
"""
44+
45+
# Author: Eloi Tanguy <[email protected]>
46+
#
47+
# License: MIT License
48+
49+
# sphinx_gallery_thumbnail_number = 1
50+
51+
# %%
52+
# Generate data
53+
import torch
54+
from torch.optim import Adam
55+
from ot.utils import dist
56+
import numpy as np
57+
from ot.lp import free_support_barycenter_generic_costs
58+
import matplotlib.pyplot as plt
59+
from time import time
60+
61+
62+
torch.manual_seed(42)
63+
64+
n = 136 # number of points of the barycentre
65+
d = 2 # dimensions of the original measure
66+
K = 4 # number of measures to barycentre
67+
m_list = [49, 50, 51, 51] # number of points of the measures
68+
b_list = [torch.ones(m) / m for m in m_list] # weights of the 4 measures
69+
weights = torch.ones(K) / K # weights for the barycentre
70+
stop_threshold = 1e-20 # stop threshold for B and for fixed-point algo
71+
72+
73+
# map R^2 -> R^2 projection onto circle
74+
def proj_circle(X, origin, radius):
75+
diffs = X - origin[None, :]
76+
norms = torch.norm(diffs, dim=1)
77+
return origin[None, :] + radius * diffs / norms[:, None]
78+
79+
80+
# circles on which to project
81+
origin1 = torch.tensor([-1.0, -1.0])
82+
origin2 = torch.tensor([-1.0, 2.0])
83+
origin3 = torch.tensor([2.0, 2.0])
84+
origin4 = torch.tensor([2.0, -1.0])
85+
r = np.sqrt(2)
86+
P_list = [
87+
lambda X: proj_circle(X, origin1, r),
88+
lambda X: proj_circle(X, origin2, r),
89+
lambda X: proj_circle(X, origin3, r),
90+
lambda X: proj_circle(X, origin4, r),
91+
]
92+
93+
# measures to barycentre are projections of different random circles
94+
# onto the K circles
95+
Y_list = []
96+
for k in range(K):
97+
t = torch.rand(m_list[k]) * 2 * np.pi
98+
X_temp = 0.5 * torch.stack([torch.cos(t), torch.sin(t)], axis=1)
99+
X_temp = X_temp + torch.tensor([0.5, 0.5])[None, :]
100+
Y_list.append(P_list[k](X_temp))
101+
102+
103+
# %%
104+
# Define costs and ground barycenter function
105+
# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a
106+
# (n, n_k) matrix of costs
107+
def c1(x, y):
108+
return dist(P_list[0](x), y)
109+
110+
111+
def c2(x, y):
112+
return dist(P_list[1](x), y)
113+
114+
115+
def c3(x, y):
116+
return dist(P_list[2](x), y)
117+
118+
119+
def c4(x, y):
120+
return dist(P_list[3](x), y)
121+
122+
123+
cost_list = [c1, c2, c3, c4]
124+
125+
126+
# batched total ground cost function for candidate points x (n, d)
127+
# for computation of the ground barycenter B with gradient descent
128+
def C(x, y):
129+
"""
130+
Computes the barycenter cost for candidate points x (n, d) and
131+
measure supports y: List(n, d_k).
132+
"""
133+
n = x.shape[0]
134+
K = len(y)
135+
out = torch.zeros(n)
136+
for k in range(K):
137+
out += (1 / K) * torch.sum((P_list[k](x) - y[k]) ** 2, axis=1)
138+
return out
139+
140+
141+
# ground barycenter function
142+
def B(y, its=150, lr=1, stop_threshold=stop_threshold):
143+
"""
144+
Computes the ground barycenter for measure supports y: List(n, d_k).
145+
Output: (n, d) array
146+
"""
147+
x = torch.randn(y[0].shape[0], d)
148+
x.requires_grad_(True)
149+
opt = Adam([x], lr=lr)
150+
for _ in range(its):
151+
x_prev = x.data.clone()
152+
opt.zero_grad()
153+
loss = torch.sum(C(x, y))
154+
loss.backward()
155+
opt.step()
156+
diff = torch.sum((x.data - x_prev) ** 2)
157+
if diff < stop_threshold:
158+
break
159+
return x
160+
161+
162+
# %%
163+
# Compute the barycenter measure with the true fixed-point algorithm
164+
fixed_point_its = 5
165+
torch.manual_seed(42)
166+
X_init = torch.rand(n, d)
167+
t0 = time()
168+
X_bar, a_bar, log_dict = free_support_barycenter_generic_costs(
169+
Y_list,
170+
b_list,
171+
X_init,
172+
cost_list,
173+
B,
174+
numItermax=fixed_point_its,
175+
stopThr=stop_threshold,
176+
method="true_fixed_point",
177+
log=True,
178+
clean_measure=True,
179+
)
180+
dt_true_fixed_point = time() - t0
181+
182+
# %%
183+
# Compute the barycenter measure with the barycentric (default) algorithm
184+
fixed_point_its = 5
185+
torch.manual_seed(42)
186+
X_init = torch.rand(n, d)
187+
t0 = time()
188+
X_bar2, log_dict2 = free_support_barycenter_generic_costs(
189+
Y_list,
190+
b_list,
191+
X_init,
192+
cost_list,
193+
B,
194+
numItermax=fixed_point_its,
195+
stopThr=stop_threshold,
196+
log=True,
197+
)
198+
dt_barycentric = time() - t0
199+
200+
# %%
201+
# Plot Barycenters (Iteration 3)
202+
alpha = 0.4
203+
s = 80
204+
labels = ["circle 1", "circle 2", "circle 3", "circle 4"]
205+
206+
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
207+
208+
# Plot for the true fixed-point algorithm
209+
for Y, label in zip(Y_list, labels):
210+
axes[0].scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s)
211+
axes[0].scatter(
212+
*(X_bar.detach().numpy()).T,
213+
label="Barycenter",
214+
c="black",
215+
alpha=alpha * a_bar.numpy() / np.max(a_bar.numpy()),
216+
s=s,
217+
)
218+
axes[0].set_title(
219+
"True Fixed-Point Algorithm\n"
220+
f"Support size: {a_bar.shape[0]}\n"
221+
f"Barycenter cost: {log_dict['V_list'][-1].item():.6f}\n"
222+
f"Computation time {dt_true_fixed_point:.4f}s"
223+
)
224+
axes[0].axis("equal")
225+
axes[0].axis("off")
226+
axes[0].legend()
227+
228+
# Plot for the heuristic algorithm
229+
for Y, label in zip(Y_list, labels):
230+
axes[1].scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s)
231+
axes[1].scatter(
232+
*(X_bar2.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha, s=s
233+
)
234+
axes[1].set_title(
235+
"Heuristic Barycentric Algorithm\n"
236+
f"Support size: {X_bar2.shape[0]}\n"
237+
f"Barycenter cost: {log_dict2['V_list'][-1].item():.6f}\n"
238+
f"Computation time {dt_barycentric:.4f}s"
239+
)
240+
axes[1].axis("equal")
241+
axes[1].axis("off")
242+
axes[1].legend()
243+
244+
plt.tight_layout()
245+
246+
# %%
247+
# Plot energy convergence and support sizes
248+
size = 3
249+
n_plots = 4
250+
fig, axes = plt.subplots(1, n_plots, figsize=(size * n_plots, size))
251+
V_list = [V.item() for V in log_dict["V_list"]]
252+
V_list2 = [V.item() for V in log_dict2["V_list"]]
253+
diff = np.array(V_list2) - np.array(V_list)
254+
255+
# Plot for True Fixed-Point Algorithm
256+
axes[0].plot(V_list, lw=5, alpha=0.6)
257+
axes[0].scatter(range(len(V_list)), V_list, color="blue", alpha=0.8, s=100)
258+
axes[0].set_title("True Fixed-Point Algorithm")
259+
axes[0].set_xlabel("Iteration")
260+
axes[0].set_ylabel("Barycenter Energy")
261+
axes[0].set_yscale("log")
262+
axes[0].xaxis.set_major_locator(plt.MaxNLocator(integer=True))
263+
264+
# Plot for Heuristic Barycentric Algorithm
265+
axes[1].plot(V_list2, lw=5, alpha=0.6)
266+
axes[1].scatter(range(len(V_list2)), V_list2, color="blue", alpha=0.8, s=100)
267+
axes[1].set_title("Heuristic Barycentric Algorithm")
268+
axes[1].set_xlabel("Iteration")
269+
axes[1].set_ylabel("Barycenter Energy")
270+
axes[1].set_yscale("log")
271+
axes[1].xaxis.set_major_locator(plt.MaxNLocator(integer=True))
272+
273+
# Plot difference between the two
274+
axes[2].plot(diff, lw=5, alpha=0.6)
275+
axes[2].scatter(range(len(diff)), diff, color="blue", alpha=0.8, s=100)
276+
axes[2].set_title("Heuristic Fixed-Point Energy - True")
277+
axes[2].set_xlabel("Iteration")
278+
axes[2].set_ylabel("$V_{\\mathrm{heuristic}} - V_{\\mathrm{true}}$")
279+
axes[2].set_yscale("log")
280+
axes[2].xaxis.set_major_locator(plt.MaxNLocator(integer=True))
281+
282+
# plot support sizes
283+
support_sizes = [Xi.shape[0] for Xi in log_dict["X_list"]]
284+
support_sizes2 = [Xi.shape[0] for Xi in log_dict2["X_list"]]
285+
286+
axes[3].plot(support_sizes, color="C0", lw=5, alpha=0.6, label="True FP")
287+
axes[3].scatter(
288+
range(len(support_sizes)), support_sizes, color="blue", alpha=0.8, s=100
289+
)
290+
axes[3].plot(support_sizes2, color="red", lw=5, alpha=0.6, label="Heur. FP")
291+
axes[3].scatter(
292+
range(len(support_sizes2)), support_sizes2, color="red", alpha=0.8, s=100
293+
)
294+
axes[3].legend(loc="best")
295+
axes[3].set_xlabel("Iteration")
296+
axes[3].xaxis.set_major_locator(plt.MaxNLocator(integer=True))
297+
axes[3].set_title("Support Sizes")
298+
299+
plt.tight_layout()
300+
plt.show()
301+
302+
# %%

examples/barycenters/plot_generalized_free_support_barycenter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
1818
"""
1919

20-
# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu>
20+
# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
2121
#
2222
# License: MIT License
2323

0 commit comments

Comments
 (0)