summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-03-10 17:41:38 +0100
committertlacombe <lacombe1993@gmail.com>2020-03-10 17:41:38 +0100
commit4aea5deab6ce4cbb491f4c9c2b7e9f023efbbe01 (patch)
treece370e35fee402bf0f1bf334e40eed978c56306f
parent967ceab26b09ad74e0cff0d84429a766af267f6b (diff)
changed output of matching as a (n x 2) array, adapted tests and doc
-rw-r--r--src/python/doc/wasserstein_distance_user.rst2
-rw-r--r--src/python/gudhi/wasserstein.py2
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py10
3 files changed, 7 insertions, 7 deletions
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst
index d3daa318..9519caa6 100644
--- a/src/python/doc/wasserstein_distance_user.rst
+++ b/src/python/doc/wasserstein_distance_user.rst
@@ -69,5 +69,5 @@ The output is:
.. testoutput::
- Wasserstein distance value = 2.15, optimal matching: [(0, 0), (1, 2), (2, -1), (-1, 1)]
+ Wasserstein distance value = 2.15, optimal matching: [[0, 0], [1, 2], [2, -1], [-1, 1]]
diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py
index 9efa946e..9e4dc7d5 100644
--- a/src/python/gudhi/wasserstein.py
+++ b/src/python/gudhi/wasserstein.py
@@ -88,7 +88,7 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
if not matching:
return 0.
else:
- return 0., []
+ return 0., np.array([])
else:
if not matching:
return _perstot(Y, order, internal_p)
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index d0f0323c..ca9a4a61 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -61,15 +61,15 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
if test_matching:
match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1]
- assert match == []
+ assert np.array_equal(match, np.array([]))
match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
- assert match == []
+ assert np.array_equal(match, np.array([]))
match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1]
- assert match == [(-1, 0), (-1, 1)]
+ assert np.array_equal(match , np.array([[-1, 0], [-1, 1]]))
match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
- assert match == [(0, -1), (1, -1)]
+ assert np.array_equal(match , np.array([[0, -1], [1, -1]]))
match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
- assert match == [(0, 0), (1, 1), (2, -1)]
+ assert np.array_equal(match, np.array_equal([[0, 0], [1, 1], [2, -1]]))