# -*- coding: utf-8 -*-
r"""
===============================================================================
Computing d-dimensional Barycenters via d-MMOT
===============================================================================

When the cost is discretized (Monge), the d-MMOT solver can more quickly
compute and minimize the distance between many distributions without the need
for intermediate barycenter computations. This example compares the time to
identify, and the quality of, solutions for the d-MMOT problem using a
primal/dual algorithm and classical LP barycenter approaches.
"""

# Author: Ronak Mehta <[email protected]>
#
# License: MIT License

# %%
# Generating 2 distributions
# --------------------------
import numpy as np
import matplotlib.pyplot as pl
import ot

# Fix the seed so that every run of the example draws the same random numbers.
np.random.seed(0)

n = 100  # number of bins
d = 2  # number of distributions

# Two 1D Gaussian histograms defined on the same n-bin grid.
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)  # m=mean, s=std
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)

# One histogram per column.
A = np.vstack((a1, a2)).T

# Bin positions, and the pairwise cost matrix between bins.
x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape(-1, 1), metric='minkowski')

pl.figure(1, figsize=(6.4, 3))
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.legend()

# %%
# Minimize the distances among distributions, identify the Barycenter
# -------------------------------------------------------------------
# The objective being minimized is different for both methods, so the objective
# values cannot be compared.

# L2 barycenter: the uniformly weighted average of the histograms.
weights = np.ones(d) / d
l2_bary = A.dot(weights)

print('LP Iterations:')
# Start the timer right before the solve. Without this call, ot.toc() below
# would not measure the duration of the LP barycenter computation.
ot.tic()
lp_bary, lp_log = ot.lp.barycenter(
    A, M, weights, solver='interior-point', verbose=False, log=True)
print('Time\t: ', ot.toc(''))
print('Obj\t: ', lp_log['fun'])

print('')
print('Discrete MMOT Algorithm:')
ot.tic()
barys, log = ot.lp.dmmot_monge_1dgrid_optimize(
    A, niters=4000, lr_init=1e-5, lr_decay=0.997, log=True)
dmmot_obj = log['primal objective']
print('Time\t: ', ot.toc(''))
print('Obj\t: ', dmmot_obj)

# %%
# Compare Barycenters in both methods
# -----------------------------------
pl.figure(1, figsize=(6.4, 3))
# Only the first of the distributions returned by the d-MMOT solver is drawn.
pl.plot(x, barys[0], 'g-*', label='Discrete MMOT')
pl.plot(x, lp_bary, label='LP Barycenter')
pl.plot(x, l2_bary, label='L2 Barycenter')
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.title('Monge Cost: Barycenters from LP Solver and dmmot solver')
pl.legend()


# %%
# More than 2 distributions
# --------------------------------------------------
# Generate 7 pseudorandom gaussian distributions with 50 bins.
n = 50  # nb bins
d = 7

data = []
for _ in range(d):
    # Random mean: uniform in [0, n/2), then multiplied by 1 or 2 at random.
    m = n * (0.5 * np.random.rand(1)) * float(np.random.randint(2) + 1)
    a = ot.datasets.make_1D_gauss(n, m=m, s=5)
    data.append(a)

x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')
A = np.vstack(data).T  # one histogram per column

pl.figure(1, figsize=(6.4, 3))
for hist in data:
    pl.plot(x, hist)

# No legend here: the curves carry no labels, so a legend would be empty.
pl.title('Distributions')

# %%
# Minimizing Distances Among Many Distributions
# ---------------------------------------------
# The objective being minimized is different for both methods, so the objective
# values cannot be compared.

# Run the gradient-descent optimization of the d-MMOT method.
barys = ot.lp.dmmot_monge_1dgrid_optimize(
    A,
    niters=3000,
    lr_init=1e-4,
    lr_decay=0.997,
)

# After minimization, any of the returned distributions can be used as an
# estimate of the barycenter; take the first one.
bary = barys[0]

# Reference barycenters: the L2 (weighted average) one and the LP one.
weights = ot.unif(d)
l2_bary = A.dot(weights)
lp_bary, bary_log = ot.lp.barycenter(
    A, M, weights, solver='interior-point', verbose=False, log=True)

# %%
# Compare Barycenters in both methods
# -----------------------------------
pl.figure(1, figsize=(6.4, 3))
pl.plot(x, bary, 'g-*', label='Discrete MMOT')
pl.plot(x, l2_bary, 'k', label='L2 Barycenter')
# Dashed, so the LP curve can be told apart from the solid black L2 curve.
pl.plot(x, lp_bary, 'k--', label='LP Wasserstein')
pl.title('Barycenters')
pl.legend()

# %%
# Compare with original distributions
# -----------------------------------
pl.figure(1, figsize=(6.4, 3))
# Input distributions, drawn unlabeled so the legend lists only barycenters.
for hist in data:
    pl.plot(x, hist)
# Only the first of the distributions returned by the d-MMOT solver is drawn.
pl.plot(x, barys[0], 'g-*', label='Discrete MMOT')
pl.plot(x, l2_bary, 'k^', label='L2')
pl.plot(x, lp_bary, 'o', color='grey', label='LP')
pl.title('Barycenters')
pl.legend()
pl.show()

# %%