summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 572781d..17411d0 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -230,6 +230,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
If this behaviour is unwanted, please make sure to provide a
floating point input.
+ .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
+
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
Parameters
@@ -389,6 +391,8 @@ def emd2(a, b, M, processes=1,
If this behaviour is unwanted, please make sure to provide a
floating point input.
+ .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
+
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
Parameters
@@ -481,6 +485,11 @@ def emd2(a, b, M, processes=1,
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
+ # ensure that same mass
+ np.testing.assert_almost_equal(a.sum(0),
+ b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum')
+ b = b * a.sum(0) / b.sum(0,keepdims=True)
+
asel = a != 0
numThreads = check_number_threads(numThreads)