diff options
author | MathieuCarriere <mathieu.carriere3@gmail.com> | 2020-03-17 12:18:11 -0400 |
---|---|---|
committer | MathieuCarriere <mathieu.carriere3@gmail.com> | 2020-03-17 12:18:11 -0400 |
commit | 910f29401e053f668e4d277af098295ab05e6022 (patch) | |
tree | f2c19f2d0c01ccc52fea2c09286907ed6777f7c6 /src/python/doc/wasserstein_distance_user.rst | |
parent | aa035c480ba8c1e229f7135257b6a6ae8ad57032 (diff) | |
parent | 5fdbad3fdd350edc3a5340110f4c5332c73517b3 (diff) |
Merge branch 'master' of https://github.com/GUDHI/gudhi-devel into wasserstein_representations
Diffstat (limited to 'src/python/doc/wasserstein_distance_user.rst')
-rw-r--r-- | src/python/doc/wasserstein_distance_user.rst | 43 |
1 files changed, 40 insertions, 3 deletions
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index 94b454e2..a9b21fa5 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -36,10 +36,10 @@ Note that persistence diagrams must be submitted as (n x 2) numpy arrays and mus import gudhi.wasserstein import numpy as np - diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) - diag2 = np.array([[2.8, 4.45],[9.5, 14.1]]) + dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) + dgm2 = np.array([[2.8, 4.45],[9.5, 14.1]]) - message = "Wasserstein distance value = " + '%.2f' % gudhi.wasserstein.wasserstein_distance(diag1, diag2, order=1., internal_p=2.) + message = "Wasserstein distance value = " + '%.2f' % gudhi.wasserstein.wasserstein_distance(dgm1, dgm2, order=1., internal_p=2.) print(message) The output is: @@ -47,3 +47,40 @@ The output is: .. testoutput:: Wasserstein distance value = 1.45 + +We can also have access to the optimal matching by letting `matching=True`. +It is encoded as a list of indices (i,j), meaning that the i-th point in X +is mapped to the j-th point in Y. +An index of -1 represents the diagonal. + +.. testcode:: + + import gudhi.wasserstein + import numpy as np + + dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) + dgm2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]]) + cost, matchings = gudhi.wasserstein.wasserstein_distance(dgm1, dgm2, matching=True, order=1, internal_p=2) + + message_cost = "Wasserstein distance value = %.2f" %cost + print(message_cost) + dgm1_to_diagonal = matchings[matchings[:,1] == -1, 0] + dgm2_to_diagonal = matchings[matchings[:,0] == -1, 1] + off_diagonal_match = np.delete(matchings, np.where(matchings == -1)[0], axis=0) + + for i,j in off_diagonal_match: + print("point %s in dgm1 is matched to point %s in dgm2" %(i,j)) + for i in dgm1_to_diagonal: + print("point %s in dgm1 is matched to the diagonal" %i) + for j in dgm2_to_diagonal: + print("point %s in dgm2 is matched to the diagonal" %j) + +The output is: + +.. testoutput:: + + Wasserstein distance value = 2.15 + point 0 in dgm1 is matched to point 0 in dgm2 + point 1 in dgm1 is matched to point 2 in dgm2 + point 2 in dgm1 is matched to the diagonal + point 1 in dgm2 is matched to the diagonal |