summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py4
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py19
2 files changed, 13 insertions, 10 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 009c1bf7..981bbf08 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -214,7 +214,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
# Zeroth step: check compatibility of arguments
if keep_essential_parts and enable_autodiff:
import warnings
- warnings.warn("enable_autodiff does not handle essential parts yet. These will be ignored in the following computations")
+ warnings.warn("enable_autodiff does not handle essential parts yet. keep_essential_parts set to False.")
keep_essential_parts = False
# First step: handle empty diagrams
@@ -256,11 +256,11 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
return np.inf, None
else:
return np.inf # avoid computing off-diagonal transport cost if essential parts do not match (saves time)
-
else:
essential_cost = 0
essential_matching = None
+ # Extract off-diaognal points of the diagrams.
X, Y = _offdiag(X), _offdiag(Y)
n = len(X)
m = len(Y)
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index e50091e9..285b95c9 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -9,12 +9,13 @@
- YYYY/MM Author: Description of the modification
"""
-from gudhi.wasserstein.wasserstein import _proj_on_diag, _offdiag, _handle_essential_parts
+from gudhi.wasserstein.wasserstein import _proj_on_diag, _offdiag, _handle_essential_parts, _get_essential_parts
from gudhi.wasserstein import wasserstein_distance as pot
from gudhi.hera import wasserstein_distance as hera
import numpy as np
import pytest
+
__author__ = "Theo Lacombe"
__copyright__ = "Copyright (C) 2019 Inria"
__license__ = "MIT"
@@ -56,8 +57,10 @@ def test_handle_essential_parts():
[-np.inf, np.inf], [-np.inf, np.inf]])
c, m = _handle_essential_parts(diag1, diag2, order=1)
- assert c == pytest.approx(3, 0.0001)
- assert np.array_equal(m, [[0,0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]])
+ assert c == pytest.approx(2, 0.0001) # Note: here c is only the cost due to essential part (thus 2, not 3)
+ # Similarly, the matching only corresponds to essential parts.
+ assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, 6], [7, 7]])
+
c, m = _handle_essential_parts(diag1, diag3, order=1)
assert c == np.inf
assert (m is None)
@@ -68,11 +71,11 @@ def test_get_essential_parts():
[np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]])
res = _get_essential_parts(diag)
- assert res[0] == [4, 5]
- assert res[1] == [2, 3]
- assert res[2] == [8, 9]
- assert res[3] == [6]
- assert res[4] == [7]
+ assert np.array_equal(res[0], [4, 5])
+ assert np.array_equal(res[1], [2, 3])
+ assert np.array_equal(res[2], [8, 9])
+ assert np.array_equal(res[3], [6] )
+ assert np.array_equal(res[4], [7] )
def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):