diff options
author | tlacombe <lacombe1993@gmail.com> | 2021-04-12 15:52:36 +0200 |
---|---|---|
committer | tlacombe <lacombe1993@gmail.com> | 2021-04-12 15:52:36 +0200 |
commit | 777522b82bde16b55f15c21471bad06038849fd1 (patch) | |
tree | 9e6e0c78d893c0c7b23953bfd51dade4ce10bf08 /src/python/gudhi | |
parent | 01bd9eef85b0d93eb1629f1a0c5a28a359e4e7b9 (diff) |
improved essential part and enable autodiff management
Diffstat (limited to 'src/python/gudhi')
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 75 |
1 files changed, 41 insertions, 34 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index d64d433e..2911f826 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -95,7 +95,7 @@ def _perstot(X, order, internal_p, enable_autodiff): def _get_essential_parts(a): ''' :param a: (n x 2) numpy.array (point of a diagram) - :retuns: five lists of indices (between 0 and len(a)) accounting for the five types of points with infinite + :returns: five lists of indices (between 0 and len(a)) accounting for the five types of points with infinite coordinates that can occur in a diagram, namely: type0 : (-inf, finite) type1 : (finite, +inf) @@ -104,13 +104,20 @@ def _get_essential_parts(a): type4 : (+inf, +inf) .. note:: For instance, a[_get_essential_parts(a)[0]] returns the points in a of coordinates (-inf, x) for some finite x. + Note also that points with (+inf, -inf) are not handled (points (x,y) in dgm satisfy by assumption (y >= x)). ''' if len(a): - ess_first_type = np.where(np.isfinite(a[:,1]) & (a[:,0] == -np.inf))[0] # coord (-inf, x) - ess_second_type = np.where(np.isfinite(a[:,0]) & (a[:,1] == np.inf))[0] # coord (x, +inf) - ess_third_type = np.where((a[:,0] == -np.inf) & (a[:,1] == np.inf))[0] # coord (-inf, +inf) - ess_fourth_type = np.where((a[:,0] == -np.inf) & (a[:,1] == -np.inf))[0] # coord (-inf, -inf) - ess_fifth_type = np.where((a[:,0] == np.inf) & (a[:,1] == np.inf))[0] # coord (+inf, +inf) + first_coord_finite = np.isfinite(a[:,0]) + second_coord_finite = np.isfinite(a[:,1]) + first_coord_infinite_positive = (a[:,0] == np.inf) + second_coord_infinite_positive = (a[:,1] == np.inf) + first_coord_infinite_negative = (a[:,0] == -np.inf) + second_coord_infinite_negative = (a[:,1] == -np.inf) + ess_first_type = np.where(second_coord_finite & first_coord_infinite_negative)[0] # coord (-inf, x) + ess_second_type = np.where(first_coord_finite & second_coord_infinite_positive)[0] # coord (x, +inf) + ess_third_type = np.where(first_coord_infinite_negative & second_coord_infinite_positive)[0] # coord (-inf, +inf) + ess_fourth_type = np.where(first_coord_infinite_negative & second_coord_infinite_negative)[0] # coord (-inf, -inf) + ess_fifth_type = np.where(first_coord_infinite_positive & second_coord_infinite_positive)[0] # coord (+inf, +inf) return ess_first_type, ess_second_type, ess_third_type, ess_fourth_type, ess_fifth_type else: return [], [], [], [], [] @@ -136,7 +143,7 @@ def _cost_and_match_essential_parts(X, Y, idX, idY, order, axis): sortidX = idX[np.argsort(u)] sortidY = idY[np.argsort(v)] - # We return [i,j] sorted per value, and then [i, -1] (or [-1, j]) to account for essential points matched to the diagonal + # We return [i,j] sorted per value match = list(zip(sortidX, sortidY)) return cost, match @@ -149,9 +156,6 @@ def _handle_essential_parts(X, Y, order): :order: Wasserstein order for cost computation. :returns: cost and matching due to essential parts. If cost is +inf, matching will be set to None. ''' - c = 0 - m = [] - ess_parts_X = _get_essential_parts(X) ess_parts_Y = _get_essential_parts(Y) @@ -165,8 +169,8 @@ def _handle_essential_parts(X, Y, order): c1, m1 = _cost_and_match_essential_parts(X, Y, ess_parts_X[0], ess_parts_Y[0], axis=1, order=order) c2, m2 = _cost_and_match_essential_parts(X, Y, ess_parts_X[1], ess_parts_Y[1], axis=0, order=order) - c += c1 + c2 - m += m1 + m2 + c = c1 + c2 + m = m1 + m2 # Handle type >= 2 (both coordinates are infinite, so we essentially just align points) for u, v in zip(ess_parts_X[2:], ess_parts_Y[2:]): @@ -175,24 +179,18 @@ def _handle_essential_parts(X, Y, order): return c, np.array(m) -def _finite_part(X, enable_autodiff): +def _finite_part(X): ''' :param X: (n x 2) numpy array encoding a persistence diagram. - :param enable_autodiff: boolean, to handle the case where X is a eagerpy tensor. :returns: The finite part of a diagram `X` (points with finite coordinates). ''' - if enable_autodiff: - # Assumes the diagrams only have finite coordinates. Thus, return X directly. - # TODO improve this to get rid of essential parts if there are any. - return X - else: - return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))] + return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))] def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False, keep_essential_parts=True): ''' - :param X: (n x 2) numpy.array encoding the first diagram. Can now contain essential parts (points with infinite + :param X: (n x 2) numpy.array encoding the first diagram. Can contain essential parts (points with infinite coordinates). :param Y: (m x 2) numpy.array encoding the second diagram. :param matching: if True, computes and returns the optimal matching between X and Y, encoded as @@ -200,17 +198,17 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab the j-th point in Y, with the convention (-1) represents the diagonal. Note that if the cost is +inf (essential parts have different number of points, then the optimal matching will be set to `None`. - :param order: exponent for Wasserstein; Default value is 1. - :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); + :param order: exponent for Wasserstein. Default value is 1. + :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2). Default value is `np.inf`. :param enable_autodiff: If X and Y are torch.tensor or tensorflow.Tensor, make the computation transparent to automatic differentiation. This requires the package EagerPy and is currently incompatible with `matching=True` and with `keep_essential_parts=True`. - .. note:: This considers the function defined on the coordinates of the off-diagonal points of X and Y + .. note:: This considers the function defined on the coordinates of the off-diagonal finite points of X and Y and lets the various frameworks compute its gradient. It never pulls new points from the diagonal. :type enable_autodiff: bool - :param keep_essential_parts: If False, only considers the off-diagonal points in the diagrams. + :param keep_essential_parts: If False, only considers the finite points in the diagrams. Otherwise, computes the distance between the essential parts separately. :type keep_essential_parts: bool :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with @@ -235,7 +233,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab return _perstot(Y, order, internal_p, enable_autodiff) else: cost = _perstot(Y, order, internal_p, enable_autodiff) - if cost == np.inf: # We had some essential part here. + if cost == np.inf: # We had some essential part in Y. return cost, None else: return cost, np.array([[-1, j] for j in range(m)]) @@ -250,24 +248,28 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab return cost, np.array([[i, -1] for i in range(n)]) - # Second step: handle essential parts + # Check essential part and enable autodiff together + if enable_autodiff and keep_essential_parts: + import warnings # should it be done at the top of the file? + warnings.warn('''enable_autodiff=True and keep_essential_parts=True are incompatible together. + keep_essential_parts is set to False: only points with finite coordiantes are considered + in the following. + ''') + keep_essential_parts = False + + # Second step: handle essential parts if needed. if keep_essential_parts: essential_cost, essential_matching = _handle_essential_parts(X, Y, order=order) if (essential_cost == np.inf): if matching: return np.inf, None else: - return np.inf # avoid computing off-diagonal transport cost if essential parts do not match (saves time) + return np.inf # avoid computing transport cost between the finite parts if essential parts + # cardinalities do not match (saves time) else: essential_cost = 0 essential_matching = None - # Extract finite points of the diagrams. Note that if enable_autodiff is True, nothing is done here (X,Y are - # assumed to be tensors with only finite coordinates). - X, Y = _finite_part(X, enable_autodiff), _finite_part(Y, enable_autodiff) - n = len(X) - m = len(Y) - # Now the standard pipeline for finite parts if enable_autodiff: import eagerpy as ep @@ -277,6 +279,11 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab X = X_orig.numpy() Y = Y_orig.numpy() + # Extract finite points of the diagrams. + X, Y = _finite_part(X), _finite_part(Y) + n = len(X) + m = len(Y) + M = _build_dist_matrix(X, Y, order=order, internal_p=internal_p) a = np.ones(n+1) # weight vector of the input diagram. Uniform here. a[-1] = m |