From f93c403b81b4ccb98bfad8e4ef30cdf0e7333f6c Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Sat, 18 Apr 2020 23:52:12 +0200 Subject: enable_autodiff for POT wasserstein_distance --- src/python/gudhi/wasserstein/wasserstein.py | 64 +++++++++++++++++++++++----- src/python/test/test_wasserstein_distance.py | 14 ++++-- 2 files changed, 63 insertions(+), 15 deletions(-) (limited to 'src') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 5df66cf9..9660b99b 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -53,17 +53,30 @@ def _build_dist_matrix(X, Y, order, internal_p): return Cf -def _perstot(X, order, internal_p): +def _perstot_autodiff(X, order, internal_p): + ''' + Version of _perstot that works on eagerpy tensors. + ''' + return _dist_to_diag(X, internal_p).norms.lp(order) + +def _perstot(X, order, internal_p, enable_autodiff): ''' :param X: (n x 2) numpy.array (points of a given diagram). :param order: exponent for Wasserstein. Default value is 2. :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). + :param enable_autodiff: If X is torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation + transparent to automatic differentiation. + :type enable_autodiff: bool :returns: float, the total persistence of the diagram (that is, its distance to the empty diagram). ''' - return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order) + if enable_autodiff: + import eagerpy as ep + return _perstot_autodiff(ep.astensor(X), order, internal_p).raw + else: + return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order) -def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.): +def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_autodiff=False): ''' :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate). @@ -74,6 +87,9 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.): :param order: exponent for Wasserstein; Default value is 2. :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). + :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation + transparent to automatic differentiation. + :type enable_autodiff: bool :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with respect to the internal_p-norm as ground metric. If matching is set to True, also returns the optimal matching between X and Y. @@ -82,23 +98,30 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.): m = len(Y) # handle empty diagrams - if X.size == 0: - if Y.size == 0: + if n == 0: + if m == 0: if not matching: + # What if enable_autodiff? return 0. else: return 0., np.array([]) else: if not matching: - return _perstot(Y, order, internal_p) + return _perstot(Y, order, internal_p, enable_autodiff) else: - return _perstot(Y, order, internal_p), np.array([[-1, j] for j in range(m)]) - elif Y.size == 0: + return _perstot(Y, order, internal_p, enable_autodiff), np.array([[-1, j] for j in range(m)]) + elif m == 0: if not matching: - return _perstot(X, order, internal_p) + return _perstot(X, order, internal_p, enable_autodiff) else: - return _perstot(X, order, internal_p), np.array([[i, -1] for i in range(n)]) - + return _perstot(X, order, internal_p, enable_autodiff), np.array([[i, -1] for i in range(n)]) + + if enable_autodiff: + import eagerpy as ep + X_orig = ep.astensor(X) + Y_orig = ep.astensor(Y) + X = X_orig.numpy() + Y = Y_orig.numpy() M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p) a = np.ones(n+1) # weight vector of the input diagram. Uniform here. a[-1] = m @@ -106,6 +129,7 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.): b[-1] = n if matching: + assert not enable_autodiff, "matching and enable_autodiff are currently incompatible" P = ot.emd(a=a,b=b,M=M, numItermax=2000000) ot_cost = np.sum(np.multiply(P,M)) P[-1, -1] = 0 # Remove matching corresponding to the diagonal @@ -115,6 +139,24 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.): match[:,1][match[:,1] >= m] = -1 return ot_cost ** (1./order) , match + if enable_autodiff: + P = ot.emd(a=a,b=b,M=M, numItermax=2000000) + pairs = np.argwhere(P[:-1, :-1]) + diag2 = np.nonzero(P[-1, :-1]) + diag1 = np.nonzero(P[:-1, -1]) + dists = [] + # empty arrays are not handled properly by the helpers, so we avoid calling them + if len(pairs): + dists.append((Y_orig[pairs[:, 1]] - X_orig[pairs[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order)) + if len(diag1): + dists.append(_perstot_autodiff(X_orig[diag1], order, internal_p)) + if len(diag2): + dists.append(_perstot_autodiff(Y_orig[diag2], order, internal_p)) + dists = [ dist.reshape(1) for dist in dists ] + return ep.concatenate(dists).norms.lp(order) + # Should just compute the L^order norm manually? + # We can also concatenate the 3 vectors to compute just one norm. + # Comptuation of the otcost using the ot.emd2 library. # Note: it is the Wasserstein distance to the power q. # The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value? diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 7e0d0f5f..5bec5bd3 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -73,14 +73,20 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat -def hera_wrap(delta): +def hera_wrap(**extra): def fun(*kargs,**kwargs): - return hera(*kargs,**kwargs,delta=delta) + return hera(*kargs,**kwargs,**extra) + return fun + +def pot_wrap(**extra): + def fun(*kargs,**kwargs): + return pot(*kargs,**kwargs,**extra) return fun def test_wasserstein_distance_pot(): _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) + _basic_wasserstein(pot_wrap(enable_autodiff=True), 1e-15, test_infinity=False, test_matching=False) def test_wasserstein_distance_hera(): - _basic_wasserstein(hera_wrap(1e-12), 1e-12, test_matching=False) - _basic_wasserstein(hera_wrap(.1), .1, test_matching=False) + _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) + _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False) -- cgit v1.2.3 From b2a9ba18ce33778abdd9f5032af4bfff04e8bbd2 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Sun, 19 Apr 2020 09:06:08 +0200 Subject: Unwrap the result --- src/python/gudhi/wasserstein/wasserstein.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 9660b99b..f0c82962 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -71,6 +71,7 @@ def _perstot(X, order, internal_p, enable_autodiff): ''' if enable_autodiff: import eagerpy as ep + return _perstot_autodiff(ep.astensor(X), order, internal_p).raw else: return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order) @@ -118,6 +119,7 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a if enable_autodiff: import eagerpy as ep + X_orig = ep.astensor(X) Y_orig = ep.astensor(Y) X = X_orig.numpy() @@ -140,10 +142,10 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a return ot_cost ** (1./order) , match if enable_autodiff: - P = ot.emd(a=a,b=b,M=M, numItermax=2000000) + P = ot.emd(a=a, b=b, M=M, numItermax=2000000) pairs = np.argwhere(P[:-1, :-1]) - diag2 = np.nonzero(P[-1, :-1]) diag1 = np.nonzero(P[:-1, -1]) + diag2 = np.nonzero(P[-1, :-1]) dists = [] # empty arrays are not handled properly by the helpers, so we avoid calling them if len(pairs): @@ -152,8 +154,8 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a dists.append(_perstot_autodiff(X_orig[diag1], order, internal_p)) if len(diag2): dists.append(_perstot_autodiff(Y_orig[diag2], order, internal_p)) - dists = [ dist.reshape(1) for dist in dists ] - return ep.concatenate(dists).norms.lp(order) + dists = [dist.reshape(1) for dist in dists] + return ep.concatenate(dists).norms.lp(order).raw # Should just compute the L^order norm manually? # We can also concatenate the 3 vectors to compute just one norm. -- cgit v1.2.3 From 1086b8cad7c1ea2a02742dfc44aef036a674f5d3 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Sun, 19 Apr 2020 12:17:42 +0200 Subject: Test gradient --- src/python/test/test_wasserstein_distance.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'src') diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 5bec5bd3..c6d6b346 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -90,3 +90,16 @@ def test_wasserstein_distance_pot(): def test_wasserstein_distance_hera(): _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False) + +def test_wasserstein_distance_grad(): + import torch + + diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True) + diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) + diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) + assert diag1.grad is None and diag2.grad is None and diag3.grad is None + dist1 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True) + dist2 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True) + dist1.backward() + dist2.backward() + assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any() -- cgit v1.2.3 From bac284bf7f65c40f03ec8e47316d4f0fd0059c91 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Mon, 20 Apr 2020 19:12:35 +0200 Subject: Check that dependencies are present before testing --- src/python/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 10dcd161..5ab63e5d 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -401,7 +401,9 @@ if(PYTHONINTERP_FOUND) # Wasserstein if(OT_FOUND AND PYBIND11_FOUND) - add_gudhi_py_test(test_wasserstein_distance) + if(TORCH_FOUND AND EAGERPY_FOUND) + add_gudhi_py_test(test_wasserstein_distance) + endif() add_gudhi_py_test(test_wasserstein_barycenter) endif() -- cgit v1.2.3 From 4ad650bc3184f57e1dda91f6b0a6358830f0562f Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Mon, 20 Apr 2020 19:42:34 +0200 Subject: Drop one comment --- src/python/gudhi/wasserstein/wasserstein.py | 1 - 1 file changed, 1 deletion(-) (limited to 'src') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 5b61d176..42c8dc2d 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -167,7 +167,6 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a dists.append(_perstot_autodiff(Y_orig[diag2], order, internal_p)) dists = [dist.reshape(1) for dist in dists] return ep.concatenate(dists).norms.lp(order).raw - # Should just compute the L^order norm manually? # We can also concatenate the 3 vectors to compute just one norm. # Comptuation of the otcost using the ot.emd2 library. -- cgit v1.2.3 From da2a7a68f8f57495080af37cf981f64228d165a2 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Wed, 22 Apr 2020 14:06:02 +0200 Subject: Rename local variables --- src/python/gudhi/wasserstein/wasserstein.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) (limited to 'src') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 42c8dc2d..3d1caeb3 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -154,17 +154,17 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a if enable_autodiff: P = ot.emd(a=a, b=b, M=M, numItermax=2000000) - pairs = np.argwhere(P[:-1, :-1]) - diag1 = np.nonzero(P[:-1, -1]) - diag2 = np.nonzero(P[-1, :-1]) + pairs_X_Y = np.argwhere(P[:-1, :-1]) + pairs_X_diag = np.nonzero(P[:-1, -1]) + pairs_Y_diag = np.nonzero(P[-1, :-1]) dists = [] # empty arrays are not handled properly by the helpers, so we avoid calling them - if len(pairs): - dists.append((Y_orig[pairs[:, 1]] - X_orig[pairs[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order)) - if len(diag1): - dists.append(_perstot_autodiff(X_orig[diag1], order, internal_p)) - if len(diag2): - dists.append(_perstot_autodiff(Y_orig[diag2], order, internal_p)) + if len(pairs_X_Y): + dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order)) + if len(pairs_X_diag): + dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p)) + if len(pairs_Y_diag): + dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p)) dists = [dist.reshape(1) for dist in dists] return ep.concatenate(dists).norms.lp(order).raw # We can also concatenate the 3 vectors to compute just one norm. -- cgit v1.2.3 From 51f7b5bb15f351d08af4c26bd1ffdfe979199976 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Wed, 22 Apr 2020 16:29:26 +0200 Subject: Test value of computed gradient --- src/python/test/test_wasserstein_distance.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 6bfcb2ee..90d26809 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -105,8 +105,19 @@ def test_wasserstein_distance_grad(): diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) assert diag1.grad is None and diag2.grad is None and diag3.grad is None - dist1 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True) - dist2 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True) - dist1.backward() - dist2.backward() + dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True) + dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True) + dist12.backward() + dist30.backward() assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any() + diag4 = torch.tensor([[0., 10.]], requires_grad=True) + diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True) + dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True) + assert dist45 == 3. + dist45.backward() + assert np.array_equal(diag4.grad, [[-1., -1.]]) + assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]]) + diag6 = torch.tensor([[5., 10.]], requires_grad=True) + pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward() + # https://github.com/jonasrauber/eagerpy/issues/6 + # assert np.array_equal(diag6.grad, [[0., 0.]]) -- cgit v1.2.3 From ba17759cf922d246a0a74ac5cf99f67d48a7d8c3 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Wed, 22 Apr 2020 16:52:27 +0200 Subject: Clarify the doc of enable_autodiff --- src/python/gudhi/wasserstein/wasserstein.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 3d1caeb3..0d164eda 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -100,7 +100,10 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation - transparent to automatic differentiation. + transparent to automatic differentiation. This requires the package EagerPy. + + .. note:: This considers the function defined on the coordinates of the off-diagonal points of X and Y + and lets the various frameworks compute its gradient. It never pulls new points from the diagonal. :type enable_autodiff: bool :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with respect to the internal_p-norm as ground metric. -- cgit v1.2.3 From a643583a4740fc40cf1e06e6cc1b4d17ca14000f Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Wed, 22 Apr 2020 17:39:52 +0200 Subject: Document incompatibility of matching=True and enable_autodiff --- src/python/gudhi/wasserstein/wasserstein.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 0d164eda..89ecab1c 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -100,7 +100,8 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_a :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation - transparent to automatic differentiation. This requires the package EagerPy. + transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible + with `matching=True`. .. note:: This considers the function defined on the coordinates of the off-diagonal points of X and Y and lets the various frameworks compute its gradient. It never pulls new points from the diagonal. -- cgit v1.2.3