Skip to content

Commit 0b520f1

Browse files
authored
[Fix] Change threshold for checking marginals and making check optional (#496)
* first shot release file + tryu build full doc * last chance * update release file * update realease and and correct import of mmot methods * try again last time * change thresholfl for checking marginals and add option to skip it * update release file
1 parent ec01d41 commit 0b520f1

File tree

3 files changed

+31
-14
lines changed

3 files changed

+31
-14
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ Many other bugs and issues have been fixed and we want to thank all the contribu
4343
- Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471)
4444
- Fix issue with ot.barycenter_stabilized when used with PyTorch tensors and log=True (PR #474)
4545
- Fix `utils.cost_normalization` function issue to work with multiple backends (PR #472)
46+
- Fix pression error on marginal sums and (Issue #429, PR #496)
4647

4748
#### New Contributors
4849
* @kachayev made their first contribution in PR #462

ot/lp/__init__.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
202202
return center_ot_dual(alpha, beta, a, b)
203203

204204

205-
def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
205+
def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, check_marginals=True):
206206
r"""Solves the Earth Movers distance problem and returns the OT matrix
207207
208208
@@ -259,6 +259,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
259259
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
260260
If compiled with OpenMP, chooses the number of threads to parallelize.
261261
"max" selects the highest number possible.
262+
check_marginals: bool, optional (default=True)
263+
If True, checks that the marginals mass are equal. If False, skips the
264+
check.
265+
262266
263267
Returns
264268
-------
@@ -328,9 +332,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
328332
"Dimension mismatch, check dimensions of M with a and b"
329333

330334
# ensure that same mass
331-
np.testing.assert_almost_equal(a.sum(0),
332-
b.sum(0), err_msg='a and b vector must have the same sum',
333-
decimal=6)
335+
if check_marginals:
336+
np.testing.assert_almost_equal(a.sum(0),
337+
b.sum(0), err_msg='a and b vector must have the same sum',
338+
decimal=6)
334339
b = b * a.sum() / b.sum()
335340

336341
asel = a != 0
@@ -368,7 +373,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
368373

369374
def emd2(a, b, M, processes=1,
370375
numItermax=100000, log=False, return_matrix=False,
371-
center_dual=True, numThreads=1):
376+
center_dual=True, numThreads=1, check_marginals=True):
372377
r"""Solves the Earth Movers distance problem and returns the loss
373378
374379
.. math::
@@ -425,7 +430,11 @@ def emd2(a, b, M, processes=1,
425430
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
426431
If compiled with OpenMP, chooses the number of threads to parallelize.
427432
"max" selects the highest number possible.
428-
433+
check_marginals: bool, optional (default=True)
434+
If True, checks that the marginals mass are equal. If False, skips the
435+
check.
436+
437+
429438
Returns
430439
-------
431440
W: float, array-like
@@ -492,8 +501,10 @@ def emd2(a, b, M, processes=1,
492501
"Dimension mismatch, check dimensions of M with a and b"
493502

494503
# ensure that same mass
495-
np.testing.assert_almost_equal(a.sum(0),
496-
b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum')
504+
if check_marginals:
505+
np.testing.assert_almost_equal(a.sum(0),
506+
b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum',
507+
decimal=6)
497508
b = b * a.sum(0) / b.sum(0,keepdims=True)
498509

499510
asel = a != 0

ot/lp/solver_1d.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
134134

135135

136136
def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
137-
log=False):
137+
log=False, check_marginals=True):
138138
r"""Solves the Earth Movers distance problem between 1d measures and returns
139139
the OT matrix
140140
@@ -181,6 +181,9 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
181181
log: boolean, optional (default=False)
182182
If True, returns a dictionary containing the cost.
183183
Otherwise returns only the optimal transportation matrix.
184+
check_marginals: bool, optional (default=True)
185+
If True, checks that the marginals mass are equal. If False, skips the
186+
check.
184187
185188
Returns
186189
-------
@@ -235,11 +238,13 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
235238
b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]
236239

237240
# ensure that same mass
238-
np.testing.assert_almost_equal(
239-
nx.to_numpy(nx.sum(a, axis=0)),
240-
nx.to_numpy(nx.sum(b, axis=0)),
241-
err_msg='a and b vector must have the same sum'
242-
)
241+
if check_marginals:
242+
np.testing.assert_almost_equal(
243+
nx.to_numpy(nx.sum(a, axis=0)),
244+
nx.to_numpy(nx.sum(b, axis=0)),
245+
err_msg='a and b vector must have the same sum',
246+
decimal=6
247+
)
243248
b = b * nx.sum(a) / nx.sum(b)
244249

245250
x_a_1d = nx.reshape(x_a, (-1,))

0 commit comments

Comments
 (0)