summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author clecoz <camille.lecoz@laposte.net> 2022-06-21 17:36:22 +0200
committer GitHub <noreply@github.com> 2022-06-21 17:36:22 +0200
commit 7c2a9523747c90aebfef711fdf34b5bbdb6f2f4d (patch)
tree 453c481aaea14c3b5c915fd2cab53cec5d996043
parent e547fe30c59be72ae93c9f017786477b2652776f (diff)
[MRG] raise error if mass mismatch in emd2 (#386)
* Two lines added in the function emd2 to ensure that the distributions have the same mass (same as it already was in the function emd).
* The same mass test has been moved inside the function f(b) to be compatible with emd2 with multiple b.
* Test added. The function test_emd_dimension_and_mass_mismatch (in test/test_ot.py) has been modified to check for mass mismatch with emd2.
* Add PR in releases.md
* Merge and add PR in releases.md
* Add name in contributors.md
* Correction contribution in contributors.md
* Move test on mass outside of functions f(b)
* Update doc of emd and emd2

Co-authored-by: Camille Le Coz <clecoz@camelot.ipsl.polytechnique.fr>
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r-- CONTRIBUTORS.md | 1
-rw-r--r-- RELEASES.md | 1
-rw-r--r-- ot/lp/__init__.py | 9
-rw-r--r-- test/test_ot.py | 3
4 files changed, 14 insertions, 0 deletions
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 0909b14..c535c09 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -38,6 +38,7 @@ The contributors to this library are:
* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
+* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
## Acknowledgments
diff --git a/RELEASES.md b/RELEASES.md
index b384617..78a7d9e 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -17,6 +17,7 @@
- Fixed an issue where pointers would overflow in the EMD solver, returning an
incomplete transport plan above a certain size (slightly above 46k, its square being
roughly 2^31) (PR #381)
+- Error raised when mass mismatch in emd2 (PR #386)
## 0.8.2
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 572781d..17411d0 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -230,6 +230,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
If this behaviour is unwanted, please make sure to provide a
floating point input.
+ .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
+
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
Parameters
@@ -389,6 +391,8 @@ def emd2(a, b, M, processes=1,
If this behaviour is unwanted, please make sure to provide a
floating point input.
+ .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
+
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
Parameters
@@ -481,6 +485,11 @@ def emd2(a, b, M, processes=1,
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
+ # ensure that same mass
+ np.testing.assert_almost_equal(a.sum(0),
+ b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum')
+ b = b * a.sum(0) / b.sum(0,keepdims=True)
+
asel = a != 0
numThreads = check_number_threads(numThreads)
diff --git a/test/test_ot.py b/test/test_ot.py
index ba3ef6a..9a4e175 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -29,9 +29,12 @@ def test_emd_dimension_and_mass_mismatch():
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
+ # test emd and emd2 for mass mismatch
+ a = ot.utils.unif(n_samples)
b = a.copy()
a[0] = 100
np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
+ np.testing.assert_raises(AssertionError, ot.emd2, a, b, M)
def test_emd_backends(nx):