summaryrefslogtreecommitdiff
path: root/src/python/gudhi/wasserstein
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/wasserstein')
-rw-r--r--src/python/gudhi/wasserstein/barycenter.py49
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py67
2 files changed, 75 insertions, 41 deletions
diff --git a/src/python/gudhi/wasserstein/barycenter.py b/src/python/gudhi/wasserstein/barycenter.py
index de7aea81..d67bcde7 100644
--- a/src/python/gudhi/wasserstein/barycenter.py
+++ b/src/python/gudhi/wasserstein/barycenter.py
@@ -18,8 +18,7 @@ from gudhi.wasserstein import wasserstein_distance
def _mean(x, m):
'''
:param x: a list of 2D-points, off diagonal, x_0... x_{k-1}
- :param m: total amount of points taken into account,
- that is we have (m-k) copies of diagonal
+ :param m: total amount of points taken into account, that is we have (m-k) copies of diagonal
:returns: the weighted mean of x with (m-k) copies of the diagonal
'''
k = len(x)
@@ -33,37 +32,26 @@ def _mean(x, m):
def lagrangian_barycenter(pdiagset, init=None, verbose=False):
'''
- :param pdiagset: a list of ``numpy.array`` of shape `(n x 2)`
- (`n` can variate), encoding a set of
- persistence diagrams with only finite coordinates.
+ :param pdiagset: a list of ``numpy.array`` of shape `(n x 2)` (`n` can variate), encoding a set of persistence
+ diagrams with only finite coordinates.
:param init: The initial value for barycenter estimate.
- If ``None``, init is made on a random diagram from the dataset.
- Otherwise, it can be an ``int``
- (then initialization is made on ``pdiagset[init]``)
- or a `(n x 2)` ``numpy.array`` enconding
- a persistence diagram with `n` points.
+ If ``None``, init is made on a random diagram from the dataset.
+ Otherwise, it can be an ``int`` (then initialization is made on ``pdiagset[init]``)
+ or a `(n x 2)` ``numpy.array`` enconding a persistence diagram with `n` points.
:type init: ``int``, or (n x 2) ``np.array``
- :param verbose: if ``True``, returns additional information about the
- barycenter.
+ :param verbose: if ``True``, returns additional information about the barycenter.
:type verbose: boolean
- :returns: If not verbose (default), a ``numpy.array`` encoding
- the barycenter estimate of pdiagset
- (local minimum of the energy function).
- If ``pdiagset`` is empty, returns ``None``.
- If verbose, returns a couple ``(Y, log)``
- where ``Y`` is the barycenter estimate,
- and ``log`` is a ``dict`` that contains additional informations:
-
- - `"groupings"`, a list of list of pairs ``(i,j)``.
- Namely, ``G[k] = [...(i, j)...]``, where ``(i,j)`` indicates
- that ``pdiagset[k][i]`` is matched to ``Y[j]``
- if ``i = -1`` or ``j = -1``, it means they
- represent the diagonal.
-
- - `"energy"`, ``float`` representing the Frechet energy value obtained.
- It is the mean of squared distances of observations to the output.
-
- - `"nb_iter"`, ``int`` number of iterations performed before convergence of the algorithm.
+ :returns: If not verbose (default), a ``numpy.array`` encoding the barycenter estimate of pdiagset
+ (local minimum of the energy function).
+ If ``pdiagset`` is empty, returns ``None``.
+ If verbose, returns a couple ``(Y, log)`` where ``Y`` is the barycenter estimate,
+ and ``log`` is a ``dict`` that contains additional informations:
+
+ - `"groupings"`, a list of list of pairs ``(i,j)``. Namely, ``G[k] = [...(i, j)...]``, where ``(i,j)`` indicates that `pdiagset[k][i]`` is matched to ``Y[j]`` if ``i = -1`` or ``j = -1``, it means they represent the diagonal.
+
+ - `"energy"`, ``float`` representing the Frechet energy value obtained. It is the mean of squared distances of observations to the output.
+
+ - `"nb_iter"`, ``int`` number of iterations performed before convergence of the algorithm.
'''
X = pdiagset # to shorten notations, not a copy
m = len(X) # number of diagrams we are averaging
@@ -156,4 +144,3 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
return Y, log
else:
return Y
-
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index efc851a0..89ecab1c 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -64,17 +64,31 @@ def _build_dist_matrix(X, Y, order, internal_p):
return Cf
-def _perstot(X, order, internal_p):
+def _perstot_autodiff(X, order, internal_p):
+ '''
+ Version of _perstot that works on eagerpy tensors.
+ '''
+ return _dist_to_diag(X, internal_p).norms.lp(order)
+
+def _perstot(X, order, internal_p, enable_autodiff):
'''
:param X: (n x 2) numpy.array (points of a given diagram).
:param order: exponent for Wasserstein. Default value is 2.
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm).
+ :param enable_autodiff: If X is torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
+ transparent to automatic differentiation.
+ :type enable_autodiff: bool
:returns: float, the total persistence of the diagram (that is, its distance to the empty diagram).
'''
- return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order)
+ if enable_autodiff:
+ import eagerpy as ep
+
+ return _perstot_autodiff(ep.astensor(X), order, internal_p).raw
+ else:
+ return np.linalg.norm(_dist_to_diag(X, internal_p), ord=order)
-def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
+def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2., enable_autodiff=False):
'''
:param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
(i.e. with infinite coordinate).
@@ -85,6 +99,13 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
:param order: exponent for Wasserstein; Default value is 2.
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
Default value is 2 (Euclidean norm).
+ :param enable_autodiff: If X and Y are torch.tensor, tensorflow.Tensor or jax.numpy.ndarray, make the computation
+ transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible
+ with `matching=True`.
+
+ .. note:: This considers the function defined on the coordinates of the off-diagonal points of X and Y
+ and lets the various frameworks compute its gradient. It never pulls new points from the diagonal.
+ :type enable_autodiff: bool
:returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with
respect to the internal_p-norm as ground metric.
If matching is set to True, also returns the optimal matching between X and Y.
@@ -93,23 +114,31 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
m = len(Y)
# handle empty diagrams
- if X.size == 0:
- if Y.size == 0:
+ if n == 0:
+ if m == 0:
if not matching:
+ # What if enable_autodiff?
return 0.
else:
return 0., np.array([])
else:
if not matching:
- return _perstot(Y, order, internal_p)
+ return _perstot(Y, order, internal_p, enable_autodiff)
else:
- return _perstot(Y, order, internal_p), np.array([[-1, j] for j in range(m)])
- elif Y.size == 0:
+ return _perstot(Y, order, internal_p, enable_autodiff), np.array([[-1, j] for j in range(m)])
+ elif m == 0:
if not matching:
- return _perstot(X, order, internal_p)
+ return _perstot(X, order, internal_p, enable_autodiff)
else:
- return _perstot(X, order, internal_p), np.array([[i, -1] for i in range(n)])
+ return _perstot(X, order, internal_p, enable_autodiff), np.array([[i, -1] for i in range(n)])
+ if enable_autodiff:
+ import eagerpy as ep
+
+ X_orig = ep.astensor(X)
+ Y_orig = ep.astensor(Y)
+ X = X_orig.numpy()
+ Y = Y_orig.numpy()
M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p)
a = np.ones(n+1) # weight vector of the input diagram. Uniform here.
a[-1] = m
@@ -117,6 +146,7 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
b[-1] = n
if matching:
+ assert not enable_autodiff, "matching and enable_autodiff are currently incompatible"
P = ot.emd(a=a,b=b,M=M, numItermax=2000000)
ot_cost = np.sum(np.multiply(P,M))
P[-1, -1] = 0 # Remove matching corresponding to the diagonal
@@ -126,6 +156,23 @@ def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
match[:,1][match[:,1] >= m] = -1
return ot_cost ** (1./order) , match
+ if enable_autodiff:
+ P = ot.emd(a=a, b=b, M=M, numItermax=2000000)
+ pairs_X_Y = np.argwhere(P[:-1, :-1])
+ pairs_X_diag = np.nonzero(P[:-1, -1])
+ pairs_Y_diag = np.nonzero(P[-1, :-1])
+ dists = []
+ # empty arrays are not handled properly by the helpers, so we avoid calling them
+ if len(pairs_X_Y):
+ dists.append((Y_orig[pairs_X_Y[:, 1]] - X_orig[pairs_X_Y[:, 0]]).norms.lp(internal_p, axis=-1).norms.lp(order))
+ if len(pairs_X_diag):
+ dists.append(_perstot_autodiff(X_orig[pairs_X_diag], order, internal_p))
+ if len(pairs_Y_diag):
+ dists.append(_perstot_autodiff(Y_orig[pairs_Y_diag], order, internal_p))
+ dists = [dist.reshape(1) for dist in dists]
+ return ep.concatenate(dists).norms.lp(order).raw
+ # We can also concatenate the 3 vectors to compute just one norm.
+
# Comptuation of the otcost using the ot.emd2 library.
# Note: it is the Wasserstein distance to the power q.
# The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value?