summaryrefslogtreecommitdiff
path: root/src/python/doc
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-03-03 15:33:17 +0100
committertlacombe <lacombe1993@gmail.com>2020-03-03 15:33:17 +0100
commit8e4f3d151818b78a29d11cdc6ca171947bfd6dd9 (patch)
treef7f2d562332cfa22b90628e0e95dd22739322f9e /src/python/doc
parentd2943b9e7311c8a3d8a4fb379c39b15497481b9c (diff)
update wasserstein distance with pot so that it can return optimal matching now!
Diffstat (limited to 'src/python/doc')
-rw-r--r--src/python/doc/wasserstein_distance_user.rst24
1 files changed, 24 insertions, 0 deletions
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst
index 94b454e2..d3daa318 100644
--- a/src/python/doc/wasserstein_distance_user.rst
+++ b/src/python/doc/wasserstein_distance_user.rst
@@ -47,3 +47,27 @@ The output is:
.. testoutput::
Wasserstein distance value = 1.45
+
+We can also have access to the optimal matching by letting `matching=True`.
+It is encoded as a list of indices (i,j), meaning that the i-th point in X
+is mapped to the j-th point in Y.
+An index of -1 represents the diagonal.
+
+.. testcode::
+
+ import gudhi.wasserstein
+ import numpy as np
+
+ diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]])
+ diag2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]])
+ cost, matching = gudhi.wasserstein.wasserstein_distance(diag1, diag2, matching=True, order=1., internal_p=2.)
+
+ message = "Wasserstein distance value = %.2f, optimal matching: %s" %(cost, matching)
+ print(message)
+
+The output is:
+
+.. testoutput::
+
+ Wasserstein distance value = 2.15, optimal matching: [(0, 0), (1, 2), (2, -1), (-1, 1)]
+