diff options
-rw-r--r-- | CONTRIBUTORS.md | 1 |
-rw-r--r-- | RELEASES.md | 1 |
-rw-r--r-- | ot/lp/__init__.py | 9 |
-rw-r--r-- | test/test_ot.py | 3 |
4 files changed, 14 insertions, 0 deletions
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 0909b14..c535c09 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -38,6 +38,7 @@ The contributors to this library are: * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) +* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) ## Acknowledgments diff --git a/RELEASES.md b/RELEASES.md index b384617..78a7d9e 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -17,6 +17,7 @@ - Fixed an issue where pointers would overflow in the EMD solver, returning an incomplete transport plan above a certain size (slightly above 46k, its square being roughly 2^31) (PR #381) +- Error raised when mass mismatch in emd2 (PR #386) ## 0.8.2 diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 572781d..17411d0 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -230,6 +230,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): If this behaviour is unwanted, please make sure to provide a floating point input. + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + Uses the algorithm proposed in :ref:`[1] <references-emd>`. Parameters @@ -389,6 +391,8 @@ def emd2(a, b, M, processes=1, If this behaviour is unwanted, please make sure to provide a floating point input. + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + Uses the algorithm proposed in :ref:`[1] <references-emd2>`. 
Parameters @@ -481,6 +485,11 @@ def emd2(a, b, M, processes=1, assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0), + b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum') + b = b * a.sum(0) / b.sum(0,keepdims=True) + asel = a != 0 numThreads = check_number_threads(numThreads) diff --git a/test/test_ot.py b/test/test_ot.py index ba3ef6a..9a4e175 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -29,9 +29,12 @@ def test_emd_dimension_and_mass_mismatch(): np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) + # test emd and emd2 for mass mismatch + a = ot.utils.unif(n_samples) b = a.copy() a[0] = 100 np.testing.assert_raises(AssertionError, ot.emd, a, b, M) + np.testing.assert_raises(AssertionError, ot.emd2, a, b, M) def test_emd_backends(nx): |