From 66f0b08a8f8d5006f8d29352c169525cc53a22e6 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 16 Mar 2020 19:11:30 +0100 Subject: changed typo in doc (diag --> dgm), used integer for order and internal p, simplify th ecode --- src/python/doc/wasserstein_distance_user.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index 4c3b53dd..f43b2217 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -36,10 +36,10 @@ Note that persistence diagrams must be submitted as (n x 2) numpy arrays and mus import gudhi.wasserstein import numpy as np - diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) - diag2 = np.array([[2.8, 4.45],[9.5, 14.1]]) + dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) + dgm2 = np.array([[2.8, 4.45],[9.5, 14.1]]) - message = "Wasserstein distance value = " + '%.2f' % gudhi.wasserstein.wasserstein_distance(diag1, diag2, order=1., internal_p=2.) + message = "Wasserstein distance value = " + '%.2f' % gudhi.wasserstein.wasserstein_distance(dgm1, dgm2, order=1., internal_p=2.) print(message) The output is: @@ -60,12 +60,12 @@ An index of -1 represents the diagonal. dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) dgm2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]]) - cost, matchings = gudhi.wasserstein.wasserstein_distance(diag1, diag2, matching=True, order=1., internal_p=2.) + cost, matchings = gudhi.wasserstein.wasserstein_distance(dgm1, dgm2, matching=True, order=1, internal_p=2) message_cost = "Wasserstein distance value = %.2f" %cost print(message_cost) - dgm1_to_diagonal = matchings[np.where(matchings[:,0] == -1)][:,1] - dgm2_to_diagonal = matchings[np.where(matchings[:,1] == -1)][:,0] + dgm1_to_diagonal = matching[matching[:,0] == -1, 1] + dgm2_to_diagonal = matching[matching[:,1] == -1, 0] off_diagonal_match = np.delete(matchings, np.where(matchings == -1)[0], axis=0) for i,j in off_diagonal_match: -- cgit v1.2.3