summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2021-04-12 15:52:36 +0200
committertlacombe <lacombe1993@gmail.com>2021-04-12 15:52:36 +0200
commit777522b82bde16b55f15c21471bad06038849fd1 (patch)
tree9e6e0c78d893c0c7b23953bfd51dade4ce10bf08 /src/python
parent01bd9eef85b0d93eb1629f1a0c5a28a359e4e7b9 (diff)
improved essential part and enable autodiff management
Diffstat (limited to 'src/python')
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py75
1 files changed, 41 insertions, 34 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index d64d433e..2911f826 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -95,7 +95,7 @@ def _perstot(X, order, internal_p, enable_autodiff):
def _get_essential_parts(a):
'''
:param a: (n x 2) numpy.array (points of a diagram)
- :retuns: five lists of indices (between 0 and len(a)) accounting for the five types of points with infinite
+ :returns: five lists of indices (between 0 and len(a)) accounting for the five types of points with infinite
coordinates that can occur in a diagram, namely:
type0 : (-inf, finite)
type1 : (finite, +inf)
@@ -104,13 +104,20 @@ def _get_essential_parts(a):
type4 : (+inf, +inf)
.. note::
For instance, a[_get_essential_parts(a)[0]] returns the points in a of coordinates (-inf, x) for some finite x.
+ Note also that points with (+inf, -inf) are not handled (points (x,y) in dgm satisfy by assumption (y >= x)).
'''
if len(a):
- ess_first_type = np.where(np.isfinite(a[:,1]) & (a[:,0] == -np.inf))[0] # coord (-inf, x)
- ess_second_type = np.where(np.isfinite(a[:,0]) & (a[:,1] == np.inf))[0] # coord (x, +inf)
- ess_third_type = np.where((a[:,0] == -np.inf) & (a[:,1] == np.inf))[0] # coord (-inf, +inf)
- ess_fourth_type = np.where((a[:,0] == -np.inf) & (a[:,1] == -np.inf))[0] # coord (-inf, -inf)
- ess_fifth_type = np.where((a[:,0] == np.inf) & (a[:,1] == np.inf))[0] # coord (+inf, +inf)
+ first_coord_finite = np.isfinite(a[:,0])
+ second_coord_finite = np.isfinite(a[:,1])
+ first_coord_infinite_positive = (a[:,0] == np.inf)
+ second_coord_infinite_positive = (a[:,1] == np.inf)
+ first_coord_infinite_negative = (a[:,0] == -np.inf)
+ second_coord_infinite_negative = (a[:,1] == -np.inf)
+ ess_first_type = np.where(second_coord_finite & first_coord_infinite_negative)[0] # coord (-inf, x)
+ ess_second_type = np.where(first_coord_finite & second_coord_infinite_positive)[0] # coord (x, +inf)
+ ess_third_type = np.where(first_coord_infinite_negative & second_coord_infinite_positive)[0] # coord (-inf, +inf)
+ ess_fourth_type = np.where(first_coord_infinite_negative & second_coord_infinite_negative)[0] # coord (-inf, -inf)
+ ess_fifth_type = np.where(first_coord_infinite_positive & second_coord_infinite_positive)[0] # coord (+inf, +inf)
return ess_first_type, ess_second_type, ess_third_type, ess_fourth_type, ess_fifth_type
else:
return [], [], [], [], []
@@ -136,7 +143,7 @@ def _cost_and_match_essential_parts(X, Y, idX, idY, order, axis):
sortidX = idX[np.argsort(u)]
sortidY = idY[np.argsort(v)]
- # We return [i,j] sorted per value, and then [i, -1] (or [-1, j]) to account for essential points matched to the diagonal
+ # We return [i,j] sorted per value
match = list(zip(sortidX, sortidY))
return cost, match
@@ -149,9 +156,6 @@ def _handle_essential_parts(X, Y, order):
:param order: Wasserstein order for cost computation.
:returns: cost and matching due to essential parts. If cost is +inf, matching will be set to None.
'''
- c = 0
- m = []
-
ess_parts_X = _get_essential_parts(X)
ess_parts_Y = _get_essential_parts(Y)
@@ -165,8 +169,8 @@ def _handle_essential_parts(X, Y, order):
c1, m1 = _cost_and_match_essential_parts(X, Y, ess_parts_X[0], ess_parts_Y[0], axis=1, order=order)
c2, m2 = _cost_and_match_essential_parts(X, Y, ess_parts_X[1], ess_parts_Y[1], axis=0, order=order)
- c += c1 + c2
- m += m1 + m2
+ c = c1 + c2
+ m = m1 + m2
# Handle type >= 2 (both coordinates are infinite, so we essentially just align points)
for u, v in zip(ess_parts_X[2:], ess_parts_Y[2:]):
@@ -175,24 +179,18 @@ def _handle_essential_parts(X, Y, order):
return c, np.array(m)
-def _finite_part(X, enable_autodiff):
+def _finite_part(X):
'''
:param X: (n x 2) numpy array encoding a persistence diagram.
- :param enable_autodiff: boolean, to handle the case where X is a eagerpy tensor.
:returns: The finite part of a diagram `X` (points with finite coordinates).
'''
- if enable_autodiff:
- # Assumes the diagrams only have finite coordinates. Thus, return X directly.
- # TODO improve this to get rid of essential parts if there are any.
- return X
- else:
- return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))]
+ return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))]
def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False,
keep_essential_parts=True):
'''
- :param X: (n x 2) numpy.array encoding the first diagram. Can now contain essential parts (points with infinite
+ :param X: (n x 2) numpy.array encoding the first diagram. Can contain essential parts (points with infinite
coordinates).
:param Y: (m x 2) numpy.array encoding the second diagram.
:param matching: if True, computes and returns the optimal matching between X and Y, encoded as
@@ -200,17 +198,17 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
the j-th point in Y, with the convention (-1) represents the diagonal.
Note that if the cost is +inf (essential parts have different number of points),
then the optimal matching will be set to `None`.
- :param order: exponent for Wasserstein; Default value is 1.
- :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
+ :param order: exponent for Wasserstein. Default value is 1.
+ :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2).
Default value is `np.inf`.
:param enable_autodiff: If X and Y are torch.tensor or tensorflow.Tensor, make the computation
transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible
with `matching=True` and with `keep_essential_parts=True`.
- .. note:: This considers the function defined on the coordinates of the off-diagonal points of X and Y
+ .. note:: This considers the function defined on the coordinates of the off-diagonal finite points of X and Y
and lets the various frameworks compute its gradient. It never pulls new points from the diagonal.
:type enable_autodiff: bool
- :param keep_essential_parts: If False, only considers the off-diagonal points in the diagrams.
+ :param keep_essential_parts: If False, only considers the finite points in the diagrams.
Otherwise, computes the distance between the essential parts separately.
:type keep_essential_parts: bool
:returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with
@@ -235,7 +233,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
return _perstot(Y, order, internal_p, enable_autodiff)
else:
cost = _perstot(Y, order, internal_p, enable_autodiff)
- if cost == np.inf: # We had some essential part here.
+ if cost == np.inf: # We had some essential part in Y.
return cost, None
else:
return cost, np.array([[-1, j] for j in range(m)])
@@ -250,24 +248,28 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
return cost, np.array([[i, -1] for i in range(n)])
- # Second step: handle essential parts
+ # Check essential part and enable autodiff together
+ if enable_autodiff and keep_essential_parts:
+ import warnings # should it be done at the top of the file?
+ warnings.warn('''enable_autodiff=True and keep_essential_parts=True are incompatible together.
+ keep_essential_parts is set to False: only points with finite coordiantes are considered
+ in the following.
+ ''')
+ keep_essential_parts = False
+
+ # Second step: handle essential parts if needed.
if keep_essential_parts:
essential_cost, essential_matching = _handle_essential_parts(X, Y, order=order)
if (essential_cost == np.inf):
if matching:
return np.inf, None
else:
- return np.inf # avoid computing off-diagonal transport cost if essential parts do not match (saves time)
+ return np.inf # avoid computing transport cost between the finite parts if essential parts
+ # cardinalities do not match (saves time)
else:
essential_cost = 0
essential_matching = None
- # Extract finite points of the diagrams. Note that if enable_autodiff is True, nothing is done here (X,Y are
- # assumed to be tensors with only finite coordinates).
- X, Y = _finite_part(X, enable_autodiff), _finite_part(Y, enable_autodiff)
- n = len(X)
- m = len(Y)
-
# Now the standard pipeline for finite parts
if enable_autodiff:
import eagerpy as ep
@@ -277,6 +279,11 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
X = X_orig.numpy()
Y = Y_orig.numpy()
+ # Extract finite points of the diagrams.
+ X, Y = _finite_part(X), _finite_part(Y)
+ n = len(X)
+ m = len(Y)
+
M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p)
a = np.ones(n+1) # weight vector of the input diagram. Uniform here.
a[-1] = m