From 8e4f3d151818b78a29d11cdc6ca171947bfd6dd9 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Tue, 3 Mar 2020 15:33:17 +0100 Subject: update wasserstein distance with pot so that it can return optimal matching now! --- src/python/doc/wasserstein_distance_user.rst | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'src/python/doc/wasserstein_distance_user.rst') diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index 94b454e2..d3daa318 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -47,3 +47,27 @@ The output is: .. testoutput:: Wasserstein distance value = 1.45 + +We can also have access to the optimal matching by letting `matching=True`. +It is encoded as a list of indices (i,j), meaning that the i-th point in X +is mapped to the j-th point in Y. +An index of -1 represents the diagonal. + +.. testcode:: + + import gudhi.wasserstein + import numpy as np + + diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) + diag2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]]) + cost, matching = gudhi.wasserstein.wasserstein_distance(diag1, diag2, matching=True, order=1., internal_p=2.) + + message = "Wasserstein distance value = %.2f, optimal matching: %s" %(cost, matching) + print(message) + +The output is: + +.. testoutput:: + + Wasserstein distance value = 2.15, optimal matching: [(0, 0), (1, 2), (2, -1), (-1, 1)] + -- cgit v1.2.3 From 4aea5deab6ce4cbb491f4c9c2b7e9f023efbbe01 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Tue, 10 Mar 2020 17:41:38 +0100 Subject: changed output of matching as a (n x 2) array, adapted tests and doc --- src/python/doc/wasserstein_distance_user.rst | 2 +- src/python/gudhi/wasserstein.py | 2 +- src/python/test/test_wasserstein_distance.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) (limited to 'src/python/doc/wasserstein_distance_user.rst') diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index d3daa318..9519caa6 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -69,5 +69,5 @@ The output is: .. testoutput:: - Wasserstein distance value = 2.15, optimal matching: [(0, 0), (1, 2), (2, -1), (-1, 1)] + Wasserstein distance value = 2.15, optimal matching: [[0, 0], [1, 2], [2, -1], [-1, 1]] diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py index 9efa946e..9e4dc7d5 100644 --- a/src/python/gudhi/wasserstein.py +++ b/src/python/gudhi/wasserstein.py @@ -88,7 +88,7 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.): if not matching: return 0. else: - return 0., [] + return 0., np.array([]) else: if not matching: return _perstot(Y, order, internal_p) diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index d0f0323c..ca9a4a61 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -61,15 +61,15 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat if test_matching: match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1] - assert match == [] + assert np.array_equal(match, np.array([])) match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1] - assert match == [] + assert np.array_equal(match, np.array([])) match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1] - assert match == [(-1, 0), (-1, 1)] + assert np.array_equal(match , np.array([[-1, 0], [-1, 1]])) match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1] - assert match == [(0, -1), (1, -1)] + assert np.array_equal(match , np.array([[0, -1], [1, -1]])) match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1] - assert match == [(0, 0), (1, 1), (2, -1)] + assert np.array_equal(match, np.array_equal([[0, 0], [1, 1], [2, -1]])) -- cgit v1.2.3 From 5c55e976606b4dd020bd4e21c93ae22143ef5348 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 16 Mar 2020 18:01:16 +0100 Subject: changed doc of matchings for a more explicit (and hopefully sphinx-valid) version --- src/python/doc/wasserstein_distance_user.rst | 29 ++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) (limited to 'src/python/doc/wasserstein_distance_user.rst') diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index 9519caa6..4c3b53dd 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -58,16 +58,29 @@ An index of -1 represents the diagonal. import gudhi.wasserstein import numpy as np - diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) - diag2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]]) - cost, matching = gudhi.wasserstein.wasserstein_distance(diag1, diag2, matching=True, order=1., internal_p=2.) - - message = "Wasserstein distance value = %.2f, optimal matching: %s" %(cost, matching) - print(message) + dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) + dgm2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]]) + cost, matchings = gudhi.wasserstein.wasserstein_distance(diag1, diag2, matching=True, order=1., internal_p=2.) + + message_cost = "Wasserstein distance value = %.2f" %cost + print(message_cost) + dgm1_to_diagonal = matchings[np.where(matchings[:,0] == -1)][:,1] + dgm2_to_diagonal = matchings[np.where(matchings[:,1] == -1)][:,0] + off_diagonal_match = np.delete(matchings, np.where(matchings == -1)[0], axis=0) + + for i,j in off_diagonal_match: + print("point %s in dgm1 is matched to point %s in dgm2" %(i,j)) + for i in dgm1_to_diagonal: + print("point %s in dgm1 is matched to the diagonal" %i) + for j in dgm2_to_diagonal: + print("point %s in dgm2 is matched to the diagonal" %j) The output is: .. testoutput:: - Wasserstein distance value = 2.15, optimal matching: [[0, 0], [1, 2], [2, -1], [-1, 1]] - + Wasserstein distance value = 2.15 + point 0 in dgm1 is matched to point 0 in dgm2 + point 1 in dgm1 is matched to point 2 in dgm2 + point 2 in dgm1 is matched to the diagonal + point 1 in dgm2 is matched to the diagonal -- cgit v1.2.3 From 66f0b08a8f8d5006f8d29352c169525cc53a22e6 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 16 Mar 2020 19:11:30 +0100 Subject: changed typo in doc (diag --> dgm), used integer for order and internal p, simplify th ecode --- src/python/doc/wasserstein_distance_user.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'src/python/doc/wasserstein_distance_user.rst') diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index 4c3b53dd..f43b2217 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -36,10 +36,10 @@ Note that persistence diagrams must be submitted as (n x 2) numpy arrays and mus import gudhi.wasserstein import numpy as np - diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) - diag2 = np.array([[2.8, 4.45],[9.5, 14.1]]) + dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) + dgm2 = np.array([[2.8, 4.45],[9.5, 14.1]]) - message = "Wasserstein distance value = " + '%.2f' % gudhi.wasserstein.wasserstein_distance(diag1, diag2, order=1., internal_p=2.) + message = "Wasserstein distance value = " + '%.2f' % gudhi.wasserstein.wasserstein_distance(dgm1, dgm2, order=1., internal_p=2.) print(message) The output is: @@ -60,12 +60,12 @@ An index of -1 represents the diagonal. dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) dgm2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]]) - cost, matchings = gudhi.wasserstein.wasserstein_distance(diag1, diag2, matching=True, order=1., internal_p=2.) + cost, matchings = gudhi.wasserstein.wasserstein_distance(dgm1, dgm2, matching=True, order=1, internal_p=2) message_cost = "Wasserstein distance value = %.2f" %cost print(message_cost) - dgm1_to_diagonal = matchings[np.where(matchings[:,0] == -1)][:,1] - dgm2_to_diagonal = matchings[np.where(matchings[:,1] == -1)][:,0] + dgm1_to_diagonal = matching[matching[:,0] == -1, 1] + dgm2_to_diagonal = matching[matching[:,1] == -1, 0] off_diagonal_match = np.delete(matchings, np.where(matchings == -1)[0], axis=0) for i,j in off_diagonal_match: -- cgit v1.2.3 From a253c0c4f54a9a148740ed9c20457df0ea43c842 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 16 Mar 2020 19:36:07 +0100 Subject: correction typo in usr.rst --- src/python/doc/wasserstein_distance_user.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src/python/doc/wasserstein_distance_user.rst') diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index f43b2217..25e51d68 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -64,8 +64,8 @@ An index of -1 represents the diagonal. message_cost = "Wasserstein distance value = %.2f" %cost print(message_cost) - dgm1_to_diagonal = matching[matching[:,0] == -1, 1] - dgm2_to_diagonal = matching[matching[:,1] == -1, 0] + dgm1_to_diagonal = matchings[matchings[:,0] == -1, 1] + dgm2_to_diagonal = matchings[matchings[:,1] == -1, 0] off_diagonal_match = np.delete(matchings, np.where(matchings == -1)[0], axis=0) for i,j in off_diagonal_match: -- cgit v1.2.3 From 60d11e3f06e08b66e49997f389c4dc01b00b793f Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 16 Mar 2020 21:17:38 +0100 Subject: correction of typo in usr.rst --- src/python/doc/wasserstein_distance_user.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src/python/doc/wasserstein_distance_user.rst') diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index 25e51d68..a9b21fa5 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -64,8 +64,8 @@ An index of -1 represents the diagonal. message_cost = "Wasserstein distance value = %.2f" %cost print(message_cost) - dgm1_to_diagonal = matchings[matchings[:,0] == -1, 1] - dgm2_to_diagonal = matchings[matchings[:,1] == -1, 0] + dgm1_to_diagonal = matchings[matchings[:,1] == -1, 0] + dgm2_to_diagonal = matchings[matchings[:,0] == -1, 1] off_diagonal_match = np.delete(matchings, np.where(matchings == -1)[0], axis=0) for i,j in off_diagonal_match: -- cgit v1.2.3