summaryrefslogtreecommitdiff
path: root/src/python/doc/wasserstein_distance_user.rst
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-03-16 18:01:16 +0100
committertlacombe <lacombe1993@gmail.com>2020-03-16 18:01:16 +0100
commit5c55e976606b4dd020bd4e21c93ae22143ef5348 (patch)
treeda6753884d25f82293a072da93295fbe14ca8da6 /src/python/doc/wasserstein_distance_user.rst
parentcc2d51d32c5e546c10046adae04ad14d38930566 (diff)
changed doc of matchings for a more explicit (and hopefully sphinx-valid) version
Diffstat (limited to 'src/python/doc/wasserstein_distance_user.rst')
-rw-r--r--src/python/doc/wasserstein_distance_user.rst29
1 files changed, 21 insertions, 8 deletions
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst
index 9519caa6..4c3b53dd 100644
--- a/src/python/doc/wasserstein_distance_user.rst
+++ b/src/python/doc/wasserstein_distance_user.rst
@@ -58,16 +58,29 @@ An index of -1 represents the diagonal.
import gudhi.wasserstein
import numpy as np
- diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]])
- diag2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]])
- cost, matching = gudhi.wasserstein.wasserstein_distance(diag1, diag2, matching=True, order=1., internal_p=2.)
-
- message = "Wasserstein distance value = %.2f, optimal matching: %s" %(cost, matching)
- print(message)
+ dgm1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]])
+ dgm2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]])
+ cost, matchings = gudhi.wasserstein.wasserstein_distance(diag1, diag2, matching=True, order=1., internal_p=2.)
+
+ message_cost = "Wasserstein distance value = %.2f" %cost
+ print(message_cost)
+ dgm1_to_diagonal = matchings[np.where(matchings[:,0] == -1)][:,1]
+ dgm2_to_diagonal = matchings[np.where(matchings[:,1] == -1)][:,0]
+ off_diagonal_match = np.delete(matchings, np.where(matchings == -1)[0], axis=0)
+
+ for i,j in off_diagonal_match:
+ print("point %s in dgm1 is matched to point %s in dgm2" %(i,j))
+ for i in dgm1_to_diagonal:
+ print("point %s in dgm1 is matched to the diagonal" %i)
+ for j in dgm2_to_diagonal:
+ print("point %s in dgm2 is matched to the diagonal" %j)
The output is:
.. testoutput::
- Wasserstein distance value = 2.15, optimal matching: [[0, 0], [1, 2], [2, -1], [-1, 1]]
-
+ Wasserstein distance value = 2.15
+ point 0 in dgm1 is matched to point 0 in dgm2
+ point 1 in dgm1 is matched to point 2 in dgm2
+ point 2 in dgm1 is matched to the diagonal
+ point 1 in dgm2 is matched to the diagonal