From fe3e6a3a47828841ba3cb4a0721e5d8c16ab126f Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 6 Jul 2020 18:27:52 +0200 Subject: update test including essential parts --- src/python/gudhi/wasserstein/wasserstein.py | 18 +++++-- src/python/test/test_wasserstein_distance.py | 72 +++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 283ecd9d..2a1dee7a 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -105,10 +105,10 @@ def _get_essential_parts(a): For instance, a[_get_essential_parts(a)[0]] returns the points in a of coordinates (-inf, x) for some finite x. ''' if len(a): - ess_first_type = np.where(np.isfinite(a[:,1]) & (a[:,0] == -np.inf))[0] # coord (-inf, x) + ess_first_type = np.where(np.isfinite(a[:,1]) & (a[:,0] == -np.inf))[0] # coord (-inf, x) ess_second_type = np.where(np.isfinite(a[:,0]) & (a[:,1] == np.inf))[0] # coord (x, +inf) - ess_third_type = np.where((a[:,0] == -np.inf) & (a[:,1] == np.inf))[0] # coord (-inf, +inf) - ess_fourth_type = np.where((a[:,0] == -np.inf) & (a[:,1] == -np.inf))[0] # coord (-inf, -inf) + ess_third_type = np.where((a[:,0] == -np.inf) & (a[:,1] == np.inf))[0] # coord (-inf, +inf) + ess_fourth_type = np.where((a[:,0] == -np.inf) & (a[:,1] == -np.inf))[0] # coord (-inf, -inf) ess_fifth_type = np.where((a[:,0] == np.inf) & (a[:,1] == np.inf))[0] # coord (+inf, +inf) return ess_first_type, ess_second_type, ess_third_type, ess_fourth_type, ess_fifth_type else: @@ -232,12 +232,20 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab if not matching: return _perstot(Y, order, internal_p, enable_autodiff) else: - return _perstot(Y, order, internal_p, enable_autodiff), np.array([[-1, j] for j in range(m)]) + cost = _perstot(Y, order, internal_p, enable_autodiff) + if cost == np.inf: # We had some essential part here. + return cost, None + else: + return cost, np.array([[-1, j] for j in range(m)]) elif m == 0: if not matching: return _perstot(X, order, internal_p, enable_autodiff) else: - return _perstot(X, order, internal_p, enable_autodiff), np.array([[i, -1] for i in range(n)]) + cost = _perstot(X, order, internal_p, enable_autodiff) + if cost == np.inf: + return cost, None + else: + return np.array([[i, -1] for i in range(n)]) # Second step: handle essential parts diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 90d26809..24be228b 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -5,10 +5,11 @@ Copyright (C) 2019 Inria Modification(s): + - 2020/07 Théo Lacombe: Added tests about handling essential parts in diagrams. - YYYY/MM Author: Description of the modification """ -from gudhi.wasserstein.wasserstein import _proj_on_diag +from gudhi.wasserstein.wasserstein import _proj_on_diag, _offdiag, _handle_essential_parts from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np @@ -18,12 +19,62 @@ __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" __license__ = "MIT" + def test_proj_on_diag(): dgm = np.array([[1., 1.], [1., 2.], [3., 5.]]) assert np.array_equal(_proj_on_diag(dgm), [[1., 1.], [1.5, 1.5], [4., 4.]]) empty = np.empty((0, 2)) assert np.array_equal(_proj_on_diag(empty), empty) + +def test_offdiag(): + diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf], + [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) + assert np.array_equal(_offdiag(diag), [[0, 1], [3, 5]]) + + +def test_handle_essential_parts(): + diag1 = np.array([[0, 1], [3, 5], + [2, np.inf], [3, np.inf], + [-np.inf, 8], [-np.inf, 12], + [-np.inf, -np.inf], + [np.inf, np.inf], + [-np.inf, np.inf], [-np.inf, np.inf]]) + + diag2 = np.array([[0, 2], [3, 5], + [2, np.inf], [4, np.inf], + [-np.inf, 8], [-np.inf, 11], + [-np.inf, -np.inf], + [np.inf, np.inf], + [-np.inf, np.inf], [-np.inf, np.inf]]) + + diag3 = np.array([[0, 2], [3, 5], + [2, np.inf], [4, np.inf], + [-np.inf, 8], [-np.inf, 11], + [-np.inf, -np.inf], [-np.inf, -np.inf], + [np.inf, np.inf], + [-np.inf, np.inf], [-np.inf, np.inf]]) + + c, m = _handle_essential_parts(diag1, diag2, matching=True, order=1) + assert c == pytest.approx(3, 0.0001) + assert np.array_equal(m, [[0,0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]]) + c, m = _handle_essential_parts(diag1, diag3, matching=True, order=1) + assert c == np.inf + assert (m is None) + + +def test_get_essential_parts(): + diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf], + [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) + + res = _get_essential_parts(diag) + assert res[0] = [4, 5] + assert res[1] = [2, 3] + assert res[2] = [8, 9] + assert res[3] = [6] + assert res[4] = [7] + + def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) @@ -64,7 +115,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat assert wasserstein_distance(diag4, diag5) == np.inf assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.) - + assert wasserstein_distance(diag5, emptydiag) == np.inf if test_matching: match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1] @@ -78,6 +129,13 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1] assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]]) + if test_matching and test_infinity: + diag7 = np.array([[0, 3], [4, np.inf], [5, np.inf]]) + + match = wasserstein_distance(diag5, diag6, matching=True, internal_p=2., order=2.)[1] + assert np.array_equal(match, [[0, -1], [-1,0], [-1, 1], [1, 2]]) + match = wasserstein_distance(diag5, diag7, matching=True, internal_p=2., order=2.)[1] + assert (match is None) def hera_wrap(**extra): @@ -92,7 +150,7 @@ def pot_wrap(**extra): def test_wasserstein_distance_pot(): _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) - _basic_wasserstein(pot_wrap(enable_autodiff=True), 1e-15, test_infinity=False, test_matching=False) + _basic_wasserstein(pot_wrap(enable_autodiff=True, keep_essential_parts=False), 1e-15, test_infinity=False, test_matching=False) def test_wasserstein_distance_hera(): _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) @@ -105,19 +163,19 @@ def test_wasserstein_distance_grad(): diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) assert diag1.grad is None and diag2.grad is None and diag3.grad is None - dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True) - dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True) + dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True, keep_essential_parts=False) + dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True, keep_essential_parts=False) dist12.backward() dist30.backward() assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any() diag4 = torch.tensor([[0., 10.]], requires_grad=True) diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True) - dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True) + dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True, keep_essential_parts=False) assert dist45 == 3. dist45.backward() assert np.array_equal(diag4.grad, [[-1., -1.]]) assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]]) diag6 = torch.tensor([[5., 10.]], requires_grad=True) - pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward() + pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True, keep_essential_parts=False).backward() # https://github.com/jonasrauber/eagerpy/issues/6 # assert np.array_equal(diag6.grad, [[0., 0.]]) -- cgit v1.2.3