summaryrefslogtreecommitdiff
path: root/src/python/gudhi/wasserstein/wasserstein.py
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-07-07 18:15:17 +0200
committertlacombe <lacombe1993@gmail.com>2020-07-07 18:15:17 +0200
commit107f8e6668509f5fd36e179f9a538b460d3941a9 (patch)
tree34eff3e267d6355d5200e3f56c7e04032f928b8c /src/python/gudhi/wasserstein/wasserstein.py
parent42a399c273fde7c76ec23d2993957fcbb492ee79 (diff)
added enable autodiff management in _offdiag utils function
Diffstat (limited to 'src/python/gudhi/wasserstein/wasserstein.py')
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 981bbf08..495142c4 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -70,6 +70,7 @@ def _perstot_autodiff(X, order, internal_p):
'''
return _dist_to_diag(X, internal_p).norms.lp(order)
+
def _perstot(X, order, internal_p, enable_autodiff):
'''
:param X: (n x 2) numpy.array (points of a given diagram).
@@ -174,12 +175,18 @@ def _handle_essential_parts(X, Y, order):
return c, np.array(m)
-def _offdiag(X):
+def _offdiag(X, enable_autodiff):
'''
:param X: (n x 2) numpy array encoding a persistence diagram.
:returns: The off-diagonal part of a diagram `X` (points with finite coordinates).
'''
- return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))]
+ if enable_autodiff:
+ import eagerpy as ep
+
+ return ep.astensor(X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))])
+
+ else:
+ return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))]
def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False,
@@ -261,7 +268,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
essential_matching = None
# Extract off-diaognal points of the diagrams.
- X, Y = _offdiag(X), _offdiag(Y)
+ X, Y = _offdiag(X, enable_autodiff), _offdiag(Y, enable_autodiff)
n = len(X)
m = len(Y)