From 42a399c273fde7c76ec23d2993957fcbb492ee79 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Tue, 7 Jul 2020 12:37:51 +0200 Subject: correction mistake in tests --- src/python/gudhi/wasserstein/wasserstein.py | 4 ++-- src/python/test/test_wasserstein_distance.py | 19 +++++++++++-------- 2 files changed, 13 insertions(+), 10 deletions(-) (limited to 'src') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 009c1bf7..981bbf08 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -214,7 +214,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab # Zeroth step: check compatibility of arguments if keep_essential_parts and enable_autodiff: import warnings - warnings.warn("enable_autodiff does not handle essential parts yet. These will be ignored in the following computations") + warnings.warn("enable_autodiff does not handle essential parts yet. keep_essential_parts set to False.") keep_essential_parts = False # First step: handle empty diagrams @@ -256,11 +256,11 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab return np.inf, None else: return np.inf # avoid computing off-diagonal transport cost if essential parts do not match (saves time) - else: essential_cost = 0 essential_matching = None + # Extract off-diaognal points of the diagrams. X, Y = _offdiag(X), _offdiag(Y) n = len(X) m = len(Y) diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index e50091e9..285b95c9 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -9,12 +9,13 @@ - YYYY/MM Author: Description of the modification """ -from gudhi.wasserstein.wasserstein import _proj_on_diag, _offdiag, _handle_essential_parts +from gudhi.wasserstein.wasserstein import _proj_on_diag, _offdiag, _handle_essential_parts, _get_essential_parts from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np import pytest + __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" __license__ = "MIT" @@ -56,8 +57,10 @@ def test_handle_essential_parts(): [-np.inf, np.inf], [-np.inf, np.inf]]) c, m = _handle_essential_parts(diag1, diag2, order=1) - assert c == pytest.approx(3, 0.0001) - assert np.array_equal(m, [[0,0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]]) + assert c == pytest.approx(2, 0.0001) # Note: here c is only the cost due to essential part (thus 2, not 3) + # Similarly, the matching only corresponds to essential parts. + assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, 6], [7, 7]]) + c, m = _handle_essential_parts(diag1, diag3, order=1) assert c == np.inf assert (m is None) @@ -68,11 +71,11 @@ def test_get_essential_parts(): [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) res = _get_essential_parts(diag) - assert res[0] == [4, 5] - assert res[1] == [2, 3] - assert res[2] == [8, 9] - assert res[3] == [6] - assert res[4] == [7] + assert np.array_equal(res[0], [4, 5]) + assert np.array_equal(res[1], [2, 3]) + assert np.array_equal(res[2], [8, 9]) + assert np.array_equal(res[3], [6] ) + assert np.array_equal(res[4], [7] ) def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): -- cgit v1.2.3