@@ -202,7 +202,7 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
202
202
return center_ot_dual (alpha , beta , a , b )
203
203
204
204
205
- def emd (a , b , M , numItermax = 100000 , log = False , center_dual = True , numThreads = 1 ):
205
+ def emd (a , b , M , numItermax = 100000 , log = False , center_dual = True , numThreads = 1 , check_marginals = True ):
206
206
r"""Solves the Earth Movers distance problem and returns the OT matrix
207
207
208
208
@@ -259,6 +259,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
259
259
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
260
260
If compiled with OpenMP, chooses the number of threads to parallelize.
261
261
"max" selects the highest number possible.
262
+ check_marginals: bool, optional (default=True)
263
+ If True, checks that the marginals mass are equal. If False, skips the
264
+ check.
265
+
262
266
263
267
Returns
264
268
-------
@@ -328,9 +332,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
328
332
"Dimension mismatch, check dimensions of M with a and b"
329
333
330
334
# ensure that same mass
331
- np .testing .assert_almost_equal (a .sum (0 ),
332
- b .sum (0 ), err_msg = 'a and b vector must have the same sum' ,
333
- decimal = 6 )
335
+ if check_marginals :
336
+ np .testing .assert_almost_equal (a .sum (0 ),
337
+ b .sum (0 ), err_msg = 'a and b vector must have the same sum' ,
338
+ decimal = 6 )
334
339
b = b * a .sum () / b .sum ()
335
340
336
341
asel = a != 0
@@ -368,7 +373,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
368
373
369
374
def emd2 (a , b , M , processes = 1 ,
370
375
numItermax = 100000 , log = False , return_matrix = False ,
371
- center_dual = True , numThreads = 1 ):
376
+ center_dual = True , numThreads = 1 , check_marginals = True ):
372
377
r"""Solves the Earth Movers distance problem and returns the loss
373
378
374
379
.. math::
@@ -425,7 +430,11 @@ def emd2(a, b, M, processes=1,
425
430
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
426
431
If compiled with OpenMP, chooses the number of threads to parallelize.
427
432
"max" selects the highest number possible.
428
-
433
+ check_marginals: bool, optional (default=True)
434
+ If True, checks that the marginals mass are equal. If False, skips the
435
+ check.
436
+
437
+
429
438
Returns
430
439
-------
431
440
W: float, array-like
@@ -492,8 +501,10 @@ def emd2(a, b, M, processes=1,
492
501
"Dimension mismatch, check dimensions of M with a and b"
493
502
494
503
# ensure that same mass
495
- np .testing .assert_almost_equal (a .sum (0 ),
496
- b .sum (0 ,keepdims = True ), err_msg = 'a and b vector must have the same sum' )
504
+ if check_marginals :
505
+ np .testing .assert_almost_equal (a .sum (0 ),
506
+ b .sum (0 ,keepdims = True ), err_msg = 'a and b vector must have the same sum' ,
507
+ decimal = 6 )
497
508
b = b * a .sum (0 ) / b .sum (0 ,keepdims = True )
498
509
499
510
asel = a != 0
0 commit comments