summaryrefslogtreecommitdiff
path: root/src/python/gudhi/wasserstein
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/gudhi/wasserstein')
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 981bbf08..495142c4 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -70,6 +70,7 @@ def _perstot_autodiff(X, order, internal_p):
'''
return _dist_to_diag(X, internal_p).norms.lp(order)
+
def _perstot(X, order, internal_p, enable_autodiff):
'''
:param X: (n x 2) numpy.array (points of a given diagram).
@@ -174,12 +175,18 @@ def _handle_essential_parts(X, Y, order):
return c, np.array(m)
-def _offdiag(X):
+def _offdiag(X, enable_autodiff):
'''
:param X: (n x 2) numpy array encoding a persistence diagram.
:returns: The off-diagonal part of a diagram `X` (points with finite coordinates).
'''
- return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))]
+ if enable_autodiff:
+ import eagerpy as ep
+
+ return ep.astensor(X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))])
+
+ else:
+ return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))]
def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False,
@@ -261,7 +268,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
essential_matching = None
# Extract off-diaognal points of the diagrams.
- X, Y = _offdiag(X), _offdiag(Y)
+ X, Y = _offdiag(X, enable_autodiff), _offdiag(Y, enable_autodiff)
n = len(X)
m = len(Y)