|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +r""" |
| 3 | +===================================== |
| 4 | +OT Barycenter with Generic Costs Demo |
| 5 | +===================================== |
| 6 | +
|
| 7 | +This example illustrates the computation of an Optimal Transport Barycenter for |
| 8 | +a ground cost that is not a power of a norm. We take the example of ground costs |
| 9 | +:math:`c_k(x, y) = \lambda_k\|P_k(x)-y\|_2^2`, where :math:`P_k` is the |
| 10 | +(non-linear) projection onto a circle k, and :math:`(\lambda_k)` are weights. A |
| 11 | +barycenter is defined ([77]) as a minimiser of the energy :math:`V(\mu) = \sum_k |
| 12 | +\mathcal{T}_{c_k}(\mu, \nu_k)` where :math:`\mu` is a candidate barycenter |
| 13 | +measure, the measures :math:`\nu_k` are the target measures and |
| 14 | +:math:`\mathcal{T}_{c_k}` is the OT cost for ground cost :math:`c_k`. This is an |
| 15 | +example of the fixed-point barycenter solver introduced in [77] which |
| 16 | +generalises [20] and [43]. |
| 17 | +
|
| 18 | +The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in |
| 19 | +\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over |
| 20 | +:math:`x` with PyTorch. |
| 21 | +
|
| 22 | +We compare two algorithms from [77]: the first ([77], Algorithm 2, |
| 23 | +'true_fixed_point' in POT) has convergence guarantees but the iterations may |
| 24 | +increase in support size and thus require more computational resources. The |
| 25 | +second ([77], Algorithm 3, 'L2_barycentric_proj' in POT) is a simplified |
| 26 | +heuristic that imposes a fixed support size for the barycenter and fixed |
| 27 | +weights. |
| 28 | +
|
| 29 | +We initialise both algorithms with a support size of 136, computing a barycenter |
| 30 | +between four measures with uniform weights and about 50 points each. |
| 31 | +
|
| 32 | +[77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing |
| 33 | +Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 |
| 34 | +(2024) |
| 35 | +
|
| 36 | +[20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein |
| 37 | +Barycenters. International Conference on Machine Learning |
| 38 | +
|
| 39 | +[43] Álvarez-Esteban, Pedro C., et al. A fixed-point approach to barycenters in |
| 40 | +Wasserstein space. Journal of Mathematical Analysis and Applications 441.2 |
| 41 | +(2016): 744-762. |
| 42 | +
|
| 43 | +""" |
| 44 | + |
| 45 | +# Author: Eloi Tanguy <[email protected]> |
| 46 | +# |
| 47 | +# License: MIT License |
| 48 | + |
| 49 | +# sphinx_gallery_thumbnail_number = 1 |
| 50 | + |
| 51 | +# %% |
| 52 | +# Generate data |
| 53 | +import torch |
| 54 | +from torch.optim import Adam |
| 55 | +from ot.utils import dist |
| 56 | +import numpy as np |
| 57 | +from ot.lp import free_support_barycenter_generic_costs |
| 58 | +import matplotlib.pyplot as plt |
| 59 | +from time import time |
| 60 | + |
| 61 | + |
torch.manual_seed(42)  # fix the RNG so the sampled data is reproducible

# problem sizes: barycentre support size, dimension of the original measure,
# and number of measures to barycentre
n, d, K = 136, 2, 4
m_list = [49, 50, 51, 51]  # number of points of each of the 4 measures

# uniform weights: one weight vector per measure, then the barycentre weights
b_list = [torch.ones(size) / size for size in m_list]
weights = torch.ones(K) / K

# stop threshold for B and for fixed-point algo
stop_threshold = 1e-20
| 71 | + |
| 72 | + |
# map R^2 -> R^2 projection onto circle
def proj_circle(X, origin, radius):
    """
    Radially project the rows of X onto the circle of centre `origin` and
    radius `radius`.

    Parameters
    ----------
    X : torch.Tensor, shape (n, d)
        Points to project (floating-point tensor).
    origin : torch.Tensor, shape (d,)
        Centre of the circle.
    radius : float
        Radius of the circle.

    Returns
    -------
    torch.Tensor, shape (n, d)
        ``origin + radius * (X - origin) / ||X - origin||`` for every row.
        A row equal to `origin` has no well-defined projection; it is mapped
        to `origin` itself rather than to NaN.
    """
    diffs = X - origin[None, :]
    norms = torch.norm(diffs, dim=1, keepdim=True)
    # A point sitting exactly on the centre would give 0 / 0 = NaN, which then
    # contaminates every downstream cost. Clamping to the smallest positive
    # normal float leaves all other rows unchanged.
    safe_norms = norms.clamp_min(torch.finfo(diffs.dtype).tiny)
    return origin[None, :] + radius * diffs / safe_norms
| 78 | + |
| 79 | + |
# circles on which to project: four centres placed on the corners of a square,
# all sharing the same radius r
origin1 = torch.tensor([-1.0, -1.0])
origin2 = torch.tensor([-1.0, 2.0])
origin3 = torch.tensor([2.0, 2.0])
origin4 = torch.tensor([2.0, -1.0])
r = np.sqrt(2)
# P_list[k] maps an (n, 2) tensor to its projection onto circle k; the centre
# is bound as a default argument so each lambda keeps its own circle
P_list = [
    lambda X, centre=centre: proj_circle(X, centre, r)
    for centre in (origin1, origin2, origin3, origin4)
]
| 92 | + |
# measures to barycentre are projections of different random circles
# onto the K circles
Y_list = []
for k in range(K):
    # m_list[k] random angles in [0, 2*pi)
    t = torch.rand(m_list[k]) * 2 * np.pi
    # points on the circle of radius 0.5 centred at (0.5, 0.5) ...
    unit_ring = torch.stack([torch.cos(t), torch.sin(t)], dim=1)
    X_temp = 0.5 * unit_ring + torch.tensor([0.5, 0.5])[None, :]
    # ... pushed onto the k-th target circle
    Y_list.append(P_list[k](X_temp))
| 101 | + |
| 102 | + |
| 103 | +# %% |
# Define costs and ground barycenter function
# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a
# (n, n_k) matrix of costs
def c1(x, y):
    """Cost matrix between x projected by P_list[0] and y."""
    projected = P_list[0](x)
    return dist(projected, y)


def c2(x, y):
    """Cost matrix between x projected by P_list[1] and y."""
    projected = P_list[1](x)
    return dist(projected, y)


def c3(x, y):
    """Cost matrix between x projected by P_list[2] and y."""
    projected = P_list[2](x)
    return dist(projected, y)


def c4(x, y):
    """Cost matrix between x projected by P_list[3] and y."""
    projected = P_list[3](x)
    return dist(projected, y)


# one cost per measure, in the same order as Y_list and P_list
cost_list = [c1, c2, c3, c4]
| 124 | + |
| 125 | + |
# batched total ground cost function for candidate points x (n, d)
# for computation of the ground barycenter B with gradient descent
def C(x, y):
    """
    Computes the barycenter cost for candidate points x (n, d) and
    measure supports y.

    Parameters
    ----------
    x : torch.Tensor, shape (n, d)
        Candidate barycenter points.
    y : list of torch.Tensor
        One tensor per measure; y[k] must broadcast against P_list[k](x)
        (in this example every y[k] has shape (n, d)).

    Returns
    -------
    torch.Tensor, shape (n,)
        For each row i, the average over k of
        ``||P_list[k](x)[i] - y[k][i]||^2``.
    """
    n_points = x.shape[0]
    n_measures = len(y)
    # Allocate the accumulator with the dtype and device of x. A bare
    # torch.zeros(n) would silently downcast float64 inputs and would fail
    # for tensors living on another device.
    out = x.new_zeros(n_points)
    for k in range(n_measures):
        sq_dists = torch.sum((P_list[k](x) - y[k]) ** 2, dim=1)
        out = out + (1 / n_measures) * sq_dists
    return out
| 139 | + |
| 140 | + |
# ground barycenter function
def B(y, its=150, lr=1, stop_threshold=stop_threshold):
    """
    Computes the ground barycenter for measure supports y by minimising
    sum(C(x, y)) over x with Adam, starting from a random Gaussian point cloud.

    Parameters
    ----------
    y : list of torch.Tensor
        One floating-point tensor per measure; y[0].shape[0] gives the number
        n of barycenter points.
    its : int
        Maximum number of Adam steps.
    lr : float
        Adam learning rate.
    stop_threshold : float
        The loop stops early once the squared norm of the last update of x
        falls below this value.

    Returns
    -------
    torch.Tensor, shape (n, d)
        The optimised points. This is the leaf tensor being optimised, so it
        still has requires_grad=True; callers detach it before converting to
        numpy.
    """
    ref = y[0]
    # Create the iterate with the dtype and device of the data instead of the
    # global defaults, so float64 or non-CPU inputs are handled consistently.
    x = torch.randn(ref.shape[0], d, dtype=ref.dtype, device=ref.device)
    x.requires_grad_(True)
    opt = Adam([x], lr=lr)
    for _ in range(its):
        # detach() rather than .data: same values, but keeps autograd's
        # in-place modification checks intact
        x_prev = x.detach().clone()
        opt.zero_grad()
        loss = torch.sum(C(x, y))
        loss.backward()
        opt.step()
        # squared size of the step just taken
        diff = torch.sum((x.detach() - x_prev) ** 2)
        if diff < stop_threshold:
            break
    return x
| 160 | + |
| 161 | + |
| 162 | +# %% |
# Compute the barycenter measure with the true fixed-point algorithm
fixed_point_its = 5  # iteration budget, passed to the solver as numItermax
torch.manual_seed(42)  # same seed as the heuristic run, reset before drawing X_init
X_init = torch.rand(n, d)  # random initial barycenter support, shape (n, d)
t0 = time()
# with method="true_fixed_point" and log=True the call is unpacked into the
# support X_bar, its weights a_bar and a log dictionary (the plots below read
# its "V_list" and "X_list" entries)
X_bar, a_bar, log_dict = free_support_barycenter_generic_costs(
    Y_list,
    b_list,
    X_init,
    cost_list,
    B,
    numItermax=fixed_point_its,
    stopThr=stop_threshold,
    method="true_fixed_point",
    log=True,
    clean_measure=True,
)
dt_true_fixed_point = time() - t0  # elapsed solver time, in seconds
| 181 | + |
| 182 | +# %% |
# Compute the barycenter measure with the barycentric (default) algorithm
fixed_point_its = 5  # iteration budget, passed to the solver as numItermax
torch.manual_seed(42)  # same seed as the true fixed-point run, reset before drawing X_init
X_init = torch.rand(n, d)  # random initial barycenter support, shape (n, d)
t0 = time()
# no `method` argument: the solver's default method is used. With log=True the
# call is unpacked into the support X_bar2 and a log dictionary only (no
# weights are returned here)
X_bar2, log_dict2 = free_support_barycenter_generic_costs(
    Y_list,
    b_list,
    X_init,
    cost_list,
    B,
    numItermax=fixed_point_its,
    stopThr=stop_threshold,
    log=True,
)
dt_barycentric = time() - t0  # elapsed solver time, in seconds
| 199 | + |
| 200 | +# %% |
# Plot the barycenters returned by both algorithms
alpha = 0.4
s = 80
labels = ["circle 1", "circle 2", "circle 3", "circle 4"]

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# One entry per panel: (barycenter support, alpha of its points, title).
# For the true fixed-point result the opacity of each point is scaled by its
# weight relative to the largest weight; the heuristic result has a single
# opacity for all points.
a_bar_np = a_bar.numpy()
panels = [
    (
        X_bar,
        alpha * a_bar_np / np.max(a_bar_np),
        "True Fixed-Point Algorithm\n"
        f"Support size: {a_bar.shape[0]}\n"
        f"Barycenter cost: {log_dict['V_list'][-1].item():.6f}\n"
        f"Computation time {dt_true_fixed_point:.4f}s",
    ),
    (
        X_bar2,
        alpha,
        "Heuristic Barycentric Algorithm\n"
        f"Support size: {X_bar2.shape[0]}\n"
        f"Barycenter cost: {log_dict2['V_list'][-1].item():.6f}\n"
        f"Computation time {dt_barycentric:.4f}s",
    ),
]

for ax, (bary, bary_alpha, title) in zip(axes, panels):
    # input measures first, then the barycenter on top in black
    for Y, label in zip(Y_list, labels):
        ax.scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s)
    ax.scatter(
        *(bary.detach().numpy()).T,
        label="Barycenter",
        c="black",
        alpha=bary_alpha,
        s=s,
    )
    ax.set_title(title)
    ax.axis("equal")
    ax.axis("off")
    ax.legend()

plt.tight_layout()
| 245 | + |
| 246 | +# %% |
# Plot energy convergence and support sizes
size = 3
n_plots = 4
fig, axes = plt.subplots(1, n_plots, figsize=(size * n_plots, size))
V_list = [V.item() for V in log_dict["V_list"]]
V_list2 = [V.item() for V in log_dict2["V_list"]]
diff = np.array(V_list2) - np.array(V_list)

# First three panels share the same layout: (values, title, y-axis label).
# NOTE(review): the y-axis is logarithmic, so any non-positive entry of `diff`
# will not be visible in the third panel.
energy_panels = [
    (V_list, "True Fixed-Point Algorithm", "Barycenter Energy"),
    (V_list2, "Heuristic Barycentric Algorithm", "Barycenter Energy"),
    (
        diff,
        "Heuristic Fixed-Point Energy - True",
        "$V_{\\mathrm{heuristic}} - V_{\\mathrm{true}}$",
    ),
]
for ax, (values, title, ylabel) in zip(axes, energy_panels):
    ax.plot(values, lw=5, alpha=0.6)
    ax.scatter(range(len(values)), values, color="blue", alpha=0.8, s=100)
    ax.set_title(title)
    ax.set_xlabel("Iteration")
    ax.set_ylabel(ylabel)
    ax.set_yscale("log")
    ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

# plot support sizes
support_sizes = [Xi.shape[0] for Xi in log_dict["X_list"]]
support_sizes2 = [Xi.shape[0] for Xi in log_dict2["X_list"]]

ax_sizes = axes[3]
for sizes, line_color, marker_color, label in (
    (support_sizes, "C0", "blue", "True FP"),
    (support_sizes2, "red", "red", "Heur. FP"),
):
    ax_sizes.plot(sizes, color=line_color, lw=5, alpha=0.6, label=label)
    ax_sizes.scatter(range(len(sizes)), sizes, color=marker_color, alpha=0.8, s=100)
ax_sizes.legend(loc="best")
ax_sizes.set_xlabel("Iteration")
ax_sizes.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
ax_sizes.set_title("Support Sizes")

plt.tight_layout()
plt.show()
| 301 | + |
| 302 | +# %% |
0 commit comments