From 7c2a9523747c90aebfef711fdf34b5bbdb6f2f4d Mon Sep 17 00:00:00 2001 From: clecoz Date: Tue, 21 Jun 2022 17:36:22 +0200 Subject: [MRG] raise error if mass mismatch in emd2 (#386) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Two lines added in the function emd2 to ensure that the distributions have the same mass (same as it already was in the function emd). * The same mass test has been moved inside the function f(b) to be compatible with emd2 with multiple b. * Test added. The function test_emd_dimension_and_mass_mismatch (in test/test_ot.py) has been modified to check for mass mismatch with emd2. * Add PR in releases.md * Merge and add PR in releases.md * Add name in contributors.md * Correction contribution in contributors.md * Move test on mass outside of functions f(b) * Update doc of emd and emd2 Co-authored-by: Camille Le Coz Co-authored-by: RĂ©mi Flamary --- ot/lp/__init__.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'ot') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 572781d..17411d0 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -230,6 +230,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): If this behaviour is unwanted, please make sure to provide a floating point input. + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + Uses the algorithm proposed in :ref:`[1] `. Parameters @@ -389,6 +391,8 @@ def emd2(a, b, M, processes=1, If this behaviour is unwanted, please make sure to provide a floating point input. + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + Uses the algorithm proposed in :ref:`[1] `. Parameters @@ -481,6 +485,11 @@ def emd2(a, b, M, processes=1, assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0), + b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum') + b = b * a.sum(0) / b.sum(0,keepdims=True) + asel = a != 0 numThreads = check_number_threads(numThreads) -- cgit v1.2.3