diff options
Diffstat (limited to 'src/python/gudhi/wasserstein/wasserstein.py')
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 981bbf08..495142c4 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -70,6 +70,7 @@ def _perstot_autodiff(X, order, internal_p): ''' return _dist_to_diag(X, internal_p).norms.lp(order) + def _perstot(X, order, internal_p, enable_autodiff): ''' :param X: (n x 2) numpy.array (points of a given diagram). @@ -174,12 +175,18 @@ def _handle_essential_parts(X, Y, order): return c, np.array(m) -def _offdiag(X): +def _offdiag(X, enable_autodiff): ''' :param X: (n x 2) numpy array encoding a persistence diagram. :returns: The off-diagonal part of a diagram `X` (points with finite coordinates). ''' - return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))] + if enable_autodiff: + import eagerpy as ep + + return ep.astensor(X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))]) + + else: + return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))] def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False, @@ -261,7 +268,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab essential_matching = None # Extract off-diaognal points of the diagrams. - X, Y = _offdiag(X), _offdiag(Y) + X, Y = _offdiag(X, enable_autodiff), _offdiag(Y, enable_autodiff) n = len(X) m = len(Y) |