From fe3e6a3a47828841ba3cb4a0721e5d8c16ab126f Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 6 Jul 2020 18:27:52 +0200 Subject: update test including essential parts --- src/python/test/test_wasserstein_distance.py | 72 +++++++++++++++++++++++++--- 1 file changed, 65 insertions(+), 7 deletions(-) (limited to 'src/python/test') diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 90d26809..24be228b 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -5,10 +5,11 @@ Copyright (C) 2019 Inria Modification(s): + - 2020/07 Théo Lacombe: Added tests about handling essential parts in diagrams. - YYYY/MM Author: Description of the modification """ -from gudhi.wasserstein.wasserstein import _proj_on_diag +from gudhi.wasserstein.wasserstein import _proj_on_diag, _offdiag, _handle_essential_parts from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np @@ -18,12 +19,62 @@ __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" __license__ = "MIT" + def test_proj_on_diag(): dgm = np.array([[1., 1.], [1., 2.], [3., 5.]]) assert np.array_equal(_proj_on_diag(dgm), [[1., 1.], [1.5, 1.5], [4., 4.]]) empty = np.empty((0, 2)) assert np.array_equal(_proj_on_diag(empty), empty) + +def test_offdiag(): + diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf], + [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) + assert np.array_equal(_offdiag(diag), [[0, 1], [3, 5]]) + + +def test_handle_essential_parts(): + diag1 = np.array([[0, 1], [3, 5], + [2, np.inf], [3, np.inf], + [-np.inf, 8], [-np.inf, 12], + [-np.inf, -np.inf], + [np.inf, np.inf], + [-np.inf, np.inf], [-np.inf, np.inf]]) + + diag2 = np.array([[0, 2], [3, 5], + [2, np.inf], [4, np.inf], + [-np.inf, 8], [-np.inf, 11], + [-np.inf, -np.inf], + [np.inf, np.inf], + [-np.inf, np.inf], [-np.inf, np.inf]]) + + diag3 = np.array([[0, 2], [3, 5], + [2, np.inf], [4, np.inf], + [-np.inf, 8], [-np.inf, 11], + [-np.inf, -np.inf], [-np.inf, -np.inf], + [np.inf, np.inf], + [-np.inf, np.inf], [-np.inf, np.inf]]) + + c, m = _handle_essential_parts(diag1, diag2, matching=True, order=1) + assert c == pytest.approx(3, 0.0001) + assert np.array_equal(m, [[0,0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]]) + c, m = _handle_essential_parts(diag1, diag3, matching=True, order=1) + assert c == np.inf + assert (m is None) + + +def test_get_essential_parts(): + diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf], + [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) + + res = _get_essential_parts(diag) + assert res[0] = [4, 5] + assert res[1] = [2, 3] + assert res[2] = [8, 9] + assert res[3] = [6] + assert res[4] = [7] + + def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) @@ -64,7 +115,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat assert wasserstein_distance(diag4, diag5) == np.inf assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.) - + assert wasserstein_distance(diag5, emptydiag) == np.inf if test_matching: match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1] @@ -78,6 +129,13 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1] assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]]) + if test_matching and test_infinity: + diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]]) + + match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1] + assert np.array_equal(match, [[0, -1], [-1,0], [-1, 1], [1, 2]]) + match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1] + assert (match is None) def hera_wrap(**extra): @@ -92,7 +150,7 @@ def pot_wrap(**extra): def test_wasserstein_distance_pot(): _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) - _basic_wasserstein(pot_wrap(enable_autodiff=True), 1e-15, test_infinity=False, test_matching=False) + _basic_wasserstein(pot_wrap(enable_autodiff=True, keep_essential_parts=False), 1e-15, test_infinity=False, test_matching=False) def test_wasserstein_distance_hera(): _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) @@ -105,19 +163,19 @@ def test_wasserstein_distance_grad(): diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) assert diag1.grad is None and diag2.grad is None and diag3.grad is None - dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True) - dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True) + dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True, keep_essential_parts=False) + dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True, keep_essential_parts=False) dist12.backward() dist30.backward() assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any() diag4 = torch.tensor([[0., 10.]], requires_grad=True) diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True) - dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True) + dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True, keep_essential_parts=False) assert dist45 == 3. dist45.backward() assert np.array_equal(diag4.grad, [[-1., -1.]]) assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]]) diag6 = torch.tensor([[5., 10.]], requires_grad=True) - pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward() + pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True, keep_essential_parts=False).backward() # https://github.com/jonasrauber/eagerpy/issues/6 # assert np.array_equal(diag6.grad, [[0., 0.]]) -- cgit v1.2.3 From e0eba14109e02676825f8c24563872a5b49c6120 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Tue, 7 Jul 2020 11:52:35 +0200 Subject: correction typo in test wdist --- src/python/gudhi/wasserstein/wasserstein.py | 2 +- src/python/test/test_wasserstein_distance.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) (limited to 'src/python/test') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 2a1dee7a..009c1bf7 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -245,7 +245,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab if cost == np.inf: return cost, None else: - return np.array([[i, -1] for i in range(n)]) + return cost, np.array([[i, -1] for i in range(n)]) # Second step: handle essential parts diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 24be228b..e50091e9 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -55,10 +55,10 @@ def test_handle_essential_parts(): [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) - c, m = _handle_essential_parts(diag1, diag2, matching=True, order=1) + c, m = _handle_essential_parts(diag1, diag2, order=1) assert c == pytest.approx(3, 0.0001) assert np.array_equal(m, [[0,0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]]) - c, m = _handle_essential_parts(diag1, diag3, matching=True, order=1) + c, m = _handle_essential_parts(diag1, diag3, order=1) assert c == np.inf assert (m is None) @@ -68,11 +68,11 @@ def test_get_essential_parts(): [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) res = _get_essential_parts(diag) - assert res[0] = [4, 5] - assert res[1] = [2, 3] - assert res[2] = [8, 9] - assert res[3] = [6] - assert res[4] = [7] + assert res[0] == [4, 5] + assert res[1] == [2, 3] + assert res[2] == [8, 9] + assert res[3] == [6] + assert res[4] == [7] def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): -- cgit v1.2.3 From 42a399c273fde7c76ec23d2993957fcbb492ee79 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Tue, 7 Jul 2020 12:37:51 +0200 Subject: correction mistake in tests --- src/python/gudhi/wasserstein/wasserstein.py | 4 ++-- src/python/test/test_wasserstein_distance.py | 19 +++++++++++-------- 2 files changed, 13 insertions(+), 10 deletions(-) (limited to 'src/python/test') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 009c1bf7..981bbf08 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -214,7 +214,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab # Zeroth step: check compatibility of arguments if keep_essential_parts and enable_autodiff: import warnings - warnings.warn("enable_autodiff does not handle essential parts yet. These will be ignored in the following computations") + warnings.warn("enable_autodiff does not handle essential parts yet. keep_essential_parts set to False.") keep_essential_parts = False # First step: handle empty diagrams @@ -256,11 +256,11 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab return np.inf, None else: return np.inf # avoid computing off-diagonal transport cost if essential parts do not match (saves time) - else: essential_cost = 0 essential_matching = None + # Extract off-diaognal points of the diagrams. X, Y = _offdiag(X), _offdiag(Y) n = len(X) m = len(Y) diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index e50091e9..285b95c9 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -9,12 +9,13 @@ - YYYY/MM Author: Description of the modification """ -from gudhi.wasserstein.wasserstein import _proj_on_diag, _offdiag, _handle_essential_parts +from gudhi.wasserstein.wasserstein import _proj_on_diag, _offdiag, _handle_essential_parts, _get_essential_parts from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np import pytest + __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" __license__ = "MIT" @@ -56,8 +57,10 @@ def test_handle_essential_parts(): [-np.inf, np.inf], [-np.inf, np.inf]]) c, m = _handle_essential_parts(diag1, diag2, order=1) - assert c == pytest.approx(3, 0.0001) - assert np.array_equal(m, [[0,0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]]) + assert c == pytest.approx(2, 0.0001) # Note: here c is only the cost due to essential part (thus 2, not 3) + # Similarly, the matching only corresponds to essential parts. + assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, 6], [7, 7]]) + c, m = _handle_essential_parts(diag1, diag3, order=1) assert c == np.inf assert (m is None) @@ -68,11 +71,11 @@ def test_get_essential_parts(): [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) res = _get_essential_parts(diag) - assert res[0] == [4, 5] - assert res[1] == [2, 3] - assert res[2] == [8, 9] - assert res[3] == [6] - assert res[4] == [7] + assert np.array_equal(res[0], [4, 5]) + assert np.array_equal(res[1], [2, 3]) + assert np.array_equal(res[2], [8, 9]) + assert np.array_equal(res[3], [6] ) + assert np.array_equal(res[4], [7] ) def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): -- cgit v1.2.3 From e94892f972357283e70c7534f84662dfaa21cc3e Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 20 Jul 2020 11:41:13 +0200 Subject: update test enable_autodiff and _offdiag --- src/python/gudhi/wasserstein/wasserstein.py | 16 ++++++---------- src/python/test/test_wasserstein_distance.py | 2 +- 2 files changed, 7 insertions(+), 11 deletions(-) (limited to 'src/python/test') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 495142c4..142385b1 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -178,13 +178,13 @@ def _handle_essential_parts(X, Y, order): def _offdiag(X, enable_autodiff): ''' :param X: (n x 2) numpy array encoding a persistence diagram. + :param enable_autodiff: boolean, to handle the case where X is a eagerpy tensor. :returns: The off-diagonal part of a diagram `X` (points with finite coordinates). ''' if enable_autodiff: - import eagerpy as ep - - return ep.astensor(X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))]) - + # Assumes the diagrams only have finite coordinates. Thus, return X directly. + # TODO improve this to get rid of essential parts if there are any. + return X else: return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))] @@ -218,11 +218,6 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab If matching is set to True, also returns the optimal matching between X and Y. If cost is +inf, any matching is optimal and thus it returns `None` instead. ''' - # Zeroth step: check compatibility of arguments - if keep_essential_parts and enable_autodiff: - import warnings - warnings.warn("enable_autodiff does not handle essential parts yet. keep_essential_parts set to False.") - keep_essential_parts = False # First step: handle empty diagrams n = len(X) @@ -267,7 +262,8 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab essential_cost = 0 essential_matching = None - # Extract off-diaognal points of the diagrams. + # Extract off-diaognal points of the diagrams. Note that if enable_autodiff is True, nothing is done here (X,Y are + # assumed to be tensors with only finite coordinates). X, Y = _offdiag(X, enable_autodiff), _offdiag(Y, enable_autodiff) n = len(X) m = len(Y) diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 285b95c9..6701c7ba 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -31,7 +31,7 @@ def test_proj_on_diag(): def test_offdiag(): diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf], [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) - assert np.array_equal(_offdiag(diag), [[0, 1], [3, 5]]) + assert np.array_equal(_offdiag(diag, enable_autodiff=False), [[0, 1], [3, 5]]) def test_handle_essential_parts(): -- cgit v1.2.3 From 01bd9eef85b0d93eb1629f1a0c5a28a359e4e7b9 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 12 Apr 2021 10:47:18 +0200 Subject: change name _offdiag to _finite_part in test file --- src/python/test/test_wasserstein_distance.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'src/python/test') diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 6701c7ba..12bf71df 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -9,7 +9,7 @@ - YYYY/MM Author: Description of the modification """ -from gudhi.wasserstein.wasserstein import _proj_on_diag, _offdiag, _handle_essential_parts, _get_essential_parts +from gudhi.wasserstein.wasserstein import _proj_on_diag, _finite_part, _handle_essential_parts, _get_essential_parts from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np @@ -28,10 +28,10 @@ def test_proj_on_diag(): assert np.array_equal(_proj_on_diag(empty), empty) -def test_offdiag(): +def test_finite_part(): diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf], [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) - assert np.array_equal(_offdiag(diag, enable_autodiff=False), [[0, 1], [3, 5]]) + assert np.array_equal(_finite_part(diag, enable_autodiff=False), [[0, 1], [3, 5]]) def test_handle_essential_parts(): -- cgit v1.2.3 From 2a11e3651c2d66df8371a9aa1d23dff69ffbc31c Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 12 Apr 2021 15:54:26 +0200 Subject: removed test_wasserstein_distance_grad to be consistent with master --- src/python/test/test_wasserstein_distance.py | 23 ----------------------- 1 file changed, 23 deletions(-) (limited to 'src/python/test') diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 12bf71df..14d5c2ca 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -159,26 +159,3 @@ def test_wasserstein_distance_hera(): _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False) -def test_wasserstein_distance_grad(): - import torch - - diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True) - diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) - diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) - assert diag1.grad is None and diag2.grad is None and diag3.grad is None - dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True, keep_essential_parts=False) - dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True, keep_essential_parts=False) - dist12.backward() - dist30.backward() - assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any() - diag4 = torch.tensor([[0., 10.]], requires_grad=True) - diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True) - dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True, keep_essential_parts=False) - assert dist45 == 3. - dist45.backward() - assert np.array_equal(diag4.grad, [[-1., -1.]]) - assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]]) - diag6 = torch.tensor([[5., 10.]], requires_grad=True) - pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True, keep_essential_parts=False).backward() - # https://github.com/jonasrauber/eagerpy/issues/6 - # assert np.array_equal(diag6.grad, [[0., 0.]]) -- cgit v1.2.3 From cdab3c9e32923f83d25d2cdf207f3cddbb3f94f6 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 12 Apr 2021 17:02:34 +0200 Subject: handle essential parts test --- src/python/gudhi/wasserstein/wasserstein.py | 1 + src/python/test/test_wasserstein_distance.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'src/python/test') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 2911f826..7cb9d5d9 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -113,6 +113,7 @@ def _get_essential_parts(a): second_coord_infinite_positive = (a[:,1] == np.inf) first_coord_infinite_negative = (a[:,0] == -np.inf) second_coord_infinite_negative = (a[:,1] == -np.inf) + ess_first_type = np.where(second_coord_finite & first_coord_infinite_negative)[0] # coord (-inf, x) ess_second_type = np.where(first_coord_finite & second_coord_infinite_positive)[0] # coord (x, +inf) ess_third_type = np.where(first_coord_infinite_negative & second_coord_infinite_positive)[0] # coord (-inf, +inf) diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 14d5c2ca..df7acc91 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -31,7 +31,7 @@ def test_proj_on_diag(): def test_finite_part(): diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf], [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) - assert np.array_equal(_finite_part(diag, enable_autodiff=False), [[0, 1], [3, 5]]) + assert np.array_equal(_finite_part(diag), [[0, 1], [3, 5]]) def test_handle_essential_parts(): -- cgit v1.2.3 From 604b2cde0c7951c81d1c510f3038e2c65c19e6fe Mon Sep 17 00:00:00 2001 From: tlacombe Date: Tue, 20 Apr 2021 19:06:56 +0200 Subject: update doc and tests --- src/python/doc/wasserstein_distance_user.rst | 1 + src/python/test/test_wasserstein_distance.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) (limited to 'src/python/test') diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index 091c9fd9..76eb1469 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -92,6 +92,7 @@ any matching has a cost +inf and thus can be considered to be optimal. In such a for j in dgm2_to_diagonal: print("point %s in dgm2 is matched to the diagonal" %j) + # An example where essential part cardinalities differ dgm3 = np.array([[1, 2], [0, np.inf]]) dgm4 = np.array([[1, 2], [0, np.inf], [1, np.inf]]) cost, matchings = gudhi.wasserstein.wasserstein_distance(dgm3, dgm4, matching=True, order=1, internal_p=2) diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index df7acc91..121ba065 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -67,16 +67,25 @@ def test_handle_essential_parts(): def test_get_essential_parts(): - diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf], + diag1 = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf], [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) - res = _get_essential_parts(diag) + diag2 = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf]]) + + res = _get_essential_parts(diag1) + res2 = _get_essential_parts(diag2) assert np.array_equal(res[0], [4, 5]) assert np.array_equal(res[1], [2, 3]) assert np.array_equal(res[2], [8, 9]) assert np.array_equal(res[3], [6] ) assert np.array_equal(res[4], [7] ) + assert np.array_equal(res2[0], [] ) + assert np.array_equal(res2[1], [2, 3]) + assert np.array_equal(res2[2], [] ) + assert np.array_equal(res2[3], [] ) + assert np.array_equal(res2[4], [] ) + def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) @@ -152,7 +161,7 @@ def pot_wrap(**extra): return fun def test_wasserstein_distance_pot(): - _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) + _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) # pot with its standard args _basic_wasserstein(pot_wrap(enable_autodiff=True, keep_essential_parts=False), 1e-15, test_infinity=False, test_matching=False) def test_wasserstein_distance_hera(): -- cgit v1.2.3 From c1ab7c43d4797da93aa74ba823dd1a6b28fb2cfd Mon Sep 17 00:00:00 2001 From: tlacombe Date: Tue, 27 Apr 2021 12:16:22 +0200 Subject: now consider (inf,inf) as belonging to the diagonal ; more tests --- src/python/gudhi/wasserstein/wasserstein.py | 18 ++++++++++---- src/python/test/test_wasserstein_distance.py | 36 +++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 8 deletions(-) (limited to 'src/python/test') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 3abecfe6..5095e672 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -106,6 +106,8 @@ def _get_essential_parts(a): .. note:: For instance, a[_get_essential_parts(a)[0]] returns the points in a of coordinates (-inf, x) for some finite x. Note also that points with (+inf, -inf) are not handled (points (x,y) in dgm satisfy by assumption (y >= x)). + + Finally, we consider that points with coordinates (-inf,-inf) and (+inf, +inf) belong to the diagonal. ''' if len(a): first_coord_finite = np.isfinite(a[:,0]) @@ -118,6 +120,7 @@ def _get_essential_parts(a): ess_first_type = np.where(second_coord_finite & first_coord_infinite_negative)[0] # coord (-inf, x) ess_second_type = np.where(first_coord_finite & second_coord_infinite_positive)[0] # coord (x, +inf) ess_third_type = np.where(first_coord_infinite_negative & second_coord_infinite_positive)[0] # coord (-inf, +inf) + ess_fourth_type = np.where(first_coord_infinite_negative & second_coord_infinite_negative)[0] # coord (-inf, -inf) ess_fifth_type = np.where(first_coord_infinite_positive & second_coord_infinite_positive)[0] # coord (+inf, +inf) return ess_first_type, ess_second_type, ess_third_type, ess_fourth_type, ess_fifth_type @@ -162,7 +165,7 @@ def _handle_essential_parts(X, Y, order): ess_parts_Y = _get_essential_parts(Y) # Treats the case of infinite cost (cardinalities of essential parts differ). - for u, v in zip(ess_parts_X, ess_parts_Y): + for u, v in list(zip(ess_parts_X, ess_parts_Y))[:3]: # ignore types 4 and 5 as they belong to the diagonal if len(u) != len(v): return np.inf, None @@ -174,9 +177,14 @@ def _handle_essential_parts(X, Y, order): c = c1 + c2 m = m1 + m2 - # Handle type >= 2 (both coordinates are infinite, so we essentially just align points) - for u, v in zip(ess_parts_X[2:], ess_parts_Y[2:]): - m += list(zip(u, v)) # cost is 0 + # Handle type3 (coordinates (-inf,+inf), so we just align points) + m += list(zip(ess_parts_X[2], ess_parts_Y[2])) + + # Handle type 4 and 5, considered as belonging to the diagonal so matched to (-1) with cost 0. + for z in ess_parts_X[3:]: + m += [(u, -1) for u in z] # points in X are matched to -1 + for z in ess_parts_Y[3:]: + m += [(-1, v) for v in z] # -1 is match to points in Y return c, np.array(m) @@ -334,7 +342,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab return ep.concatenate(dists).norms.lp(order).raw # We can also concatenate the 3 vectors to compute just one norm. - # Comptuation of the otcost using the ot.emd2 library. + # Comptuation of the ot cost using the ot.emd2 library. # Note: it is the Wasserstein distance to the power q. # The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value? ot_cost = ot.emd2(a, b, M, numItermax=2000000) diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 121ba065..3a004d77 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -10,6 +10,7 @@ """ from gudhi.wasserstein.wasserstein import _proj_on_diag, _finite_part, _handle_essential_parts, _get_essential_parts +from gudhi.wasserstein.wasserstein import _warn_infty from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np @@ -50,16 +51,17 @@ def test_handle_essential_parts(): [-np.inf, np.inf], [-np.inf, np.inf]]) diag3 = np.array([[0, 2], [3, 5], - [2, np.inf], [4, np.inf], + [2, np.inf], [4, np.inf], [6, np.inf], [-np.inf, 8], [-np.inf, 11], - [-np.inf, -np.inf], [-np.inf, -np.inf], + [-np.inf, -np.inf], [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) c, m = _handle_essential_parts(diag1, diag2, order=1) assert c == pytest.approx(2, 0.0001) # Note: here c is only the cost due to essential part (thus 2, not 3) # Similarly, the matching only corresponds to essential parts. - assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, 6], [7, 7]]) + # Note that (-inf,-inf) and (+inf,+inf) coordinates are matched to the diagonal. + assert np.array_equal(m, [[4, 4], [5, 5], [2, 2], [3, 3], [8, 8], [9, 9], [6, -1], [7, -1], [-1, 6], [-1, 7]]) c, m = _handle_essential_parts(diag1, diag3, order=1) assert c == np.inf @@ -87,6 +89,13 @@ def test_get_essential_parts(): assert np.array_equal(res2[4], [] ) +def test_warn_infty(): + assert _warn_infty(matching=False)==np.inf + c, m = _warn_infty(matching=True) + assert (c == np.inf) + assert (m is None) + + def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) @@ -143,11 +152,29 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat if test_matching and test_infinity: diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]]) + diag8 = np.array([[0,1], [0, np.inf], [-np.inf, -np.inf], [np.inf, np.inf]]) + diag9 = np.array([[-np.inf, -np.inf], [np.inf, np.inf]]) + diag10 = np.array([[0,1], [-np.inf, -np.inf], [np.inf, np.inf]]) match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1] assert np.array_equal(match, [[0, -1], [-1,0], [-1, 1], [1, 2]]) match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1] assert (match is None) + cost, match = wasserstein_distance(diag7, emptydiag, matching=True, internal_p=2., order=2.3) + assert (cost == np.inf) + assert (match is None) + cost, match = wasserstein_distance(emptydiag, diag7, matching=True, internal_p=2.42, order=2.) + assert (cost == np.inf) + assert (match is None) + cost, match = wasserstein_distance(diag8, diag9, matching=True, internal_p=2., order=2.) + assert (cost == np.inf) + assert (match is None) + cost, match = wasserstein_distance(diag9, diag10, matching=True, internal_p=1., order=1.) + assert (cost == 1) + assert (match == [[0, -1],[1, -1],[-1, 0], [-1, 1], [-1, 2]]) # type 4 and 5 are match to the diag anyway. + cost, match = wasserstein_distance(diag9, emptydiag, matching=True, internal_p=2., order=2.) + assert (cost == 0.) + assert (match == [[0, -1], [1, -1]]) def hera_wrap(**extra): @@ -155,15 +182,18 @@ def hera_wrap(**extra): return hera(*kargs,**kwargs,**extra) return fun + def pot_wrap(**extra): def fun(*kargs,**kwargs): return pot(*kargs,**kwargs,**extra) return fun + def test_wasserstein_distance_pot(): _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) # pot with its standard args _basic_wasserstein(pot_wrap(enable_autodiff=True, keep_essential_parts=False), 1e-15, test_infinity=False, test_matching=False) + def test_wasserstein_distance_hera(): _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False) -- cgit v1.2.3 From 8b3c55502718e4c184d828151ee6f75fd2cfc9eb Mon Sep 17 00:00:00 2001 From: Hind-M Date: Tue, 25 May 2021 18:01:54 +0200 Subject: Add a separator argument that goes with the rips_complex_diagram_persistence_from_distance_matrix_file_example input file Specify explicitly the separator when using a specific input file --- src/common/test/test_distance_matrix_reader.cpp | 2 +- src/python/CMakeLists.txt | 2 +- ..._complex_diagram_persistence_from_distance_matrix_file_example.py | 5 +++-- src/python/test/test_reader_utils.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) (limited to 'src/python/test') diff --git a/src/common/test/test_distance_matrix_reader.cpp b/src/common/test/test_distance_matrix_reader.cpp index 73be8104..92e899b8 100644 --- a/src/common/test/test_distance_matrix_reader.cpp +++ b/src/common/test/test_distance_matrix_reader.cpp @@ -57,7 +57,7 @@ BOOST_AUTO_TEST_CASE( full_square_distance_matrix ) { Distance_matrix from_full_square; // Read full_square_distance_matrix.csv file where the separator is the default one ';' - from_full_square = Gudhi::read_lower_triangular_matrix_from_csv_file("full_square_distance_matrix.csv"); + from_full_square = Gudhi::read_lower_triangular_matrix_from_csv_file("full_square_distance_matrix.csv", ';'); for (auto& i : from_full_square) { for (auto j : i) { std::clog << j << " "; diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index a1440cbc..bc9a3b7b 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -457,7 +457,7 @@ if(PYTHONINTERP_FOUND) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} COMMAND ${CMAKE_COMMAND} -E env "${GUDHI_PYTHON_PATH_ENV}" ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py" - --no-diagram -f ${CMAKE_SOURCE_DIR}/data/distance_matrix/lower_triangular_distance_matrix.csv -e 12.0 -d 3) + --no-diagram -f ${CMAKE_SOURCE_DIR}/data/distance_matrix/lower_triangular_distance_matrix.csv -s , -e 12.0 -d 3) add_test(NAME rips_complex_diagram_persistence_from_off_file_example_py_test WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} diff --git a/src/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py b/src/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py index 9320d904..8a9cc857 100755 --- a/src/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py +++ b/src/python/example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py @@ -21,11 +21,12 @@ parser = argparse.ArgumentParser( description="RipsComplex creation from " "a distance matrix read in a csv file.", epilog="Example: " "example/rips_complex_diagram_persistence_from_distance_matrix_file_example.py " - "-f ../data/distance_matrix/lower_triangular_distance_matrix.csv -e 12.0 -d 3" + "-f ../data/distance_matrix/lower_triangular_distance_matrix.csv -s , -e 12.0 -d 3" "- Constructs a Rips complex with the " "distance matrix from the given csv file.", ) parser.add_argument("-f", "--file", type=str, required=True) +parser.add_argument("-s", "--separator", type=str, required=True) parser.add_argument("-e", "--max_edge_length", type=float, default=0.5) parser.add_argument("-d", "--max_dimension", type=int, default=1) parser.add_argument("-b", "--band", type=float, default=0.0) @@ -44,7 +45,7 @@ print("RipsComplex creation from distance matrix read in a csv file") message = "RipsComplex with max_edge_length=" + repr(args.max_edge_length) print(message) -distance_matrix = gudhi.read_lower_triangular_matrix_from_csv_file(csv_file=args.file, separator=',') +distance_matrix = gudhi.read_lower_triangular_matrix_from_csv_file(csv_file=args.file, separator=args.separator) rips_complex = gudhi.RipsComplex( distance_matrix=distance_matrix, max_edge_length=args.max_edge_length ) diff --git a/src/python/test/test_reader_utils.py b/src/python/test/test_reader_utils.py index 90da6651..e96e0569 100755 --- a/src/python/test/test_reader_utils.py +++ b/src/python/test/test_reader_utils.py @@ -30,7 +30,7 @@ def test_full_square_distance_matrix_csv_file(): test_file.write("0;1;2;3;\n1;0;4;5;\n2;4;0;6;\n3;5;6;0;") test_file.close() matrix = gudhi.read_lower_triangular_matrix_from_csv_file( - csv_file="full_square_distance_matrix.csv" + csv_file="full_square_distance_matrix.csv", separator=";" ) assert matrix == [[], [1.0], [2.0, 4.0], [3.0, 5.0, 6.0]] -- cgit v1.2.3