summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-30 23:45:57 +0200
committerGitHub <noreply@github.com>2020-04-30 23:45:57 +0200
commitebee08bfaa36603170aaa0799923c84896652f66 (patch)
tree0f21b7863ee06091aa47e14d6fee35675603515e /src
parent95d8363bd5f723db598d77876ba65cdfde479797 (diff)
parenta643583a4740fc40cf1e06e6cc1b4d17ca14000f (diff)
Merge pull request #285 from mglisse/wass-autodiff
Automatic differentiation for Wasserstein distance
Diffstat (limited to 'src')
-rw-r--r--src/python/CMakeLists.txt4
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py67
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py38
3 files changed, 94 insertions, 15 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 4771cef9..d712e189 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -467,7 +467,9 @@ if(PYTHONINTERP_FOUND)
# Wasserstein
if(OT_FOUND AND PYBIND11_FOUND)
- add_gudhi_py_test(test_wasserstein_distance)
+ if(TORCH_FOUND AND EAGERPY_FOUND)
+ add_gudhi_py_test(test_wasserstein_distance)
+ endif()
add_gudhi_py_test(test_wasserstein_barycenter)
endif()
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index efc851a0..89ecab1c 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -64,17 +64,31 @@ def _build_dist_matrix(X, Y, order, internal_p):
return Cf
-def _perstot(X, order, internal_p):
+def _perstot_autodiff(X, order, internal_p):
+ '''
+ Version of _perstot that works on eagerpy tensors.
+ '''
+ return _dist_to_diag(X, internal_p).norms.lp(order)
+
+def _perstot(X, order, internal_p, enable_autodiff):
'''
:param X: (n x 2) numpy.array (points of a given diagram).
:param order: exponent for Wasserstein. Default value is 2.
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm).
+ :param enable_autodiff: If X is torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
+ transparent to automatic differentiation.
+ :type enable_autodiff: bool
:returns: float, the total persistence of the diagram (that is, its distance to the empty diagram).
'''
- return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order)
+ if enable_autodiff:
+ import eagerpy as ep
+
+ return _perstot_autodiff(ep.astensor(X), order, internal_p).raw
+ else:
+ return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order)
-def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
+def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_autodiff=False):
'''
:param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
(i.e. with infinite coordinate).
@@ -85,6 +99,13 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
:param order: exponent for Wasserstein; Default value is 2.
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
Default value is 2 (Euclidean norm).
+ :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
+ transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible
+ with `matching=True`.
+
+ .. note:: This considers the function defined on the coordinates of the off-diagonal points of X and Y
+ and lets the various frameworks compute its gradient. It never pulls new points from the diagonal.
+ :type enable_autodiff: bool
:returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with
respect to the internal_p-norm as ground metric.
If matching is set to True, also returns the optimal matching between X and Y.
@@ -93,23 +114,31 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
m = len(Y)
# handle empty diagrams
- if X.size == 0:
- if Y.size == 0:
+ if n == 0:
+ if m == 0:
if not matching:
+ # What if enable_autodiff?
return 0.
else:
return 0., np.array([])
else:
if not matching:
- return _perstot(Y, order, internal_p)
+ return _perstot(Y, order, internal_p, enable_autodiff)
else:
- return _perstot(Y, order, internal_p), np.array([[-1, j] for j in range(m)])
- elif Y.size == 0:
+ return _perstot(Y, order, internal_p, enable_autodiff), np.array([[-1, j] for j in range(m)])
+ elif m == 0:
if not matching:
- return _perstot(X, order, internal_p)
+ return _perstot(X, order, internal_p, enable_autodiff)
else:
- return _perstot(X, order, internal_p), np.array([[i, -1] for i in range(n)])
+ return _perstot(X, order, internal_p, enable_autodiff), np.array([[i, -1] for i in range(n)])
+ if enable_autodiff:
+ import eagerpy as ep
+
+ X_orig = ep.astensor(X)
+ Y_orig = ep.astensor(Y)
+ X = X_orig.numpy()
+ Y = Y_orig.numpy()
M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p)
a = np.ones(n+1) # weight vector of the input diagram. Uniform here.
a[-1] = m
@@ -117,6 +146,7 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
b[-1] = n
if matching:
+ assert not enable_autodiff, "matching and enable_autodiff are currently incompatible"
P = ot.emd(a=a,b=b,M=M, numItermax=2000000)
ot_cost = np.sum(np.multiply(P,M))
P[-1, -1] = 0 # Remove matching corresponding to the diagonal
@@ -126,6 +156,23 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
match[:,1][match[:,1] >= m] = -1
return ot_cost ** (1./order) , match
+ if enable_autodiff:
+ P = ot.emd(a=a, b=b, M=M, numItermax=2000000)
+ pairs_X_Y = np.argwhere(P[:-1, :-1])
+ pairs_X_diag = np.nonzero(P[:-1, -1])
+ pairs_Y_diag = np.nonzero(P[-1, :-1])
+ dists = []
+ # empty arrays are not handled properly by the helpers, so we avoid calling them
+ if len(pairs_X_Y):
+ dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order))
+ if len(pairs_X_diag):
+ dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p))
+ if len(pairs_Y_diag):
+ dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p))
+ dists = [dist.reshape(1) for dist in dists]
+ return ep.concatenate(dists).norms.lp(order).raw
+ # We can also concatenate the 3 vectors to compute just one norm.
+
# Comptuation of the otcost using the ot.emd2 library.
# Note: it is the Wasserstein distance to the power q.
# The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value?
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 1a4acc1d..90d26809 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -80,14 +80,44 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
-def hera_wrap(delta):
+def hera_wrap(**extra):
def fun(*kargs,**kwargs):
- return hera(*kargs,**kwargs,delta=delta)
+ return hera(*kargs,**kwargs,**extra)
+ return fun
+
+def pot_wrap(**extra):
+ def fun(*kargs,**kwargs):
+ return pot(*kargs,**kwargs,**extra)
return fun
def test_wasserstein_distance_pot():
_basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True)
+ _basic_wasserstein(pot_wrap(enable_autodiff=True), 1e-15, test_infinity=False, test_matching=False)
def test_wasserstein_distance_hera():
- _basic_wasserstein(hera_wrap(1e-12), 1e-12, test_matching=False)
- _basic_wasserstein(hera_wrap(.1), .1, test_matching=False)
+ _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
+ _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)
+
+def test_wasserstein_distance_grad():
+ import torch
+
+ diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True)
+ diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True)
+ assert diag1.grad is None and diag2.grad is None and diag3.grad is None
+ dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True)
+ dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True)
+ dist12.backward()
+ dist30.backward()
+ assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any()
+ diag4 = torch.tensor([[0., 10.]], requires_grad=True)
+ diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True)
+ dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True)
+ assert dist45 == 3.
+ dist45.backward()
+ assert np.array_equal(diag4.grad, [[-1., -1.]])
+ assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]])
+ diag6 = torch.tensor([[5., 10.]], requires_grad=True)
+ pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward()
+ # https://github.com/jonasrauber/eagerpy/issues/6
+ # assert np.array_equal(diag6.grad, [[0., 0.]])