diff options
author | tlacombe <lacombe1993@gmail.com> | 2020-03-03 15:33:17 +0100 |
---|---|---|
committer | tlacombe <lacombe1993@gmail.com> | 2020-03-03 15:33:17 +0100 |
commit | 8e4f3d151818b78a29d11cdc6ca171947bfd6dd9 (patch) | |
tree | f7f2d562332cfa22b90628e0e95dd22739322f9e /src/python/doc | |
parent | d2943b9e7311c8a3d8a4fb379c39b15497481b9c (diff) |
update wasserstein distance with pot so that it can return optimal matching now!
Diffstat (limited to 'src/python/doc')
-rw-r--r-- | src/python/doc/wasserstein_distance_user.rst | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index 94b454e2..d3daa318 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -47,3 +47,27 @@ The output is: .. testoutput:: Wasserstein distance value = 1.45 + +We can also have access to the optimal matching by letting `matching=True`. +It is encoded as a list of indices (i,j), meaning that the i-th point in X +is mapped to the j-th point in Y. +An index of -1 represents the diagonal. + +.. testcode:: + + import gudhi.wasserstein + import numpy as np + + diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) + diag2 = np.array([[2.8, 4.45], [5, 6], [9.5, 14.1]]) + cost, matching = gudhi.wasserstein.wasserstein_distance(diag1, diag2, matching=True, order=1., internal_p=2.) + + message = "Wasserstein distance value = %.2f, optimal matching: %s" %(cost, matching) + print(message) + +The output is: + +.. testoutput:: + + Wasserstein distance value = 2.15, optimal matching: [(0, 0), (1, 2), (2, -1), (-1, 1)] + |