diff options
author | tlacombe <lacombe1993@gmail.com> | 2020-07-07 11:52:35 +0200 |
---|---|---|
committer | tlacombe <lacombe1993@gmail.com> | 2020-07-07 11:52:35 +0200 |
commit | e0eba14109e02676825f8c24563872a5b49c6120 (patch) | |
tree | 760b8287044752b0f77fcc5a43feb295a9cbae93 | |
parent | fe3e6a3a47828841ba3cb4a0721e5d8c16ab126f (diff) |
correction typo in test wdist
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 2 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 14 |
2 files changed, 8 insertions, 8 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 2a1dee7a..009c1bf7 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -245,7 +245,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab if cost == np.inf: return cost, None else: - return np.array([[i, -1] for i in range(n)]) + return cost, np.array([[i, -1] for i in range(n)]) # Second step: handle essential parts diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 24be228b..e50091e9 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -55,10 +55,10 @@ def test_handle_essential_parts(): [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) - c, m = _handle_essential_parts(diag1, diag2, matching=True, order=1) + c, m = _handle_essential_parts(diag1, diag2, order=1) assert c == pytest.approx(3, 0.0001) assert np.array_equal(m, [[0,0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]]) - c, m = _handle_essential_parts(diag1, diag3, matching=True, order=1) + c, m = _handle_essential_parts(diag1, diag3, order=1) assert c == np.inf assert (m is None) @@ -68,11 +68,11 @@ def test_get_essential_parts(): [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) res = _get_essential_parts(diag) - assert res[0] = [4, 5] - assert res[1] = [2, 3] - assert res[2] = [8, 9] - assert res[3] = [6] - assert res[4] = [7] + assert res[0] == [4, 5] + assert res[1] == [2, 3] + assert res[2] == [8, 9] + assert res[3] == [6] + assert res[4] == [7] def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): |