summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
Diffstat (limited to 'src/python')
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py16
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py2
2 files changed, 7 insertions, 11 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 495142c4..142385b1 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -178,13 +178,13 @@ def _handle_essential_parts(X, Y, order):
def _offdiag(X, enable_autodiff):
'''
:param X: (n x 2) numpy array encoding a persistence diagram.
+ :param enable_autodiff: boolean, to handle the case where X is a eagerpy tensor.
:returns: The off-diagonal part of a diagram `X` (points with finite coordinates).
'''
if enable_autodiff:
- import eagerpy as ep
-
- return ep.astensor(X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))])
-
+ # Assumes the diagrams only have finite coordinates. Thus, return X directly.
+ # TODO improve this to get rid of essential parts if there are any.
+ return X
else:
return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))]
@@ -218,11 +218,6 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
If matching is set to True, also returns the optimal matching between X and Y.
If cost is +inf, any matching is optimal and thus it returns `None` instead.
'''
- # Zeroth step: check compatibility of arguments
- if keep_essential_parts and enable_autodiff:
- import warnings
- warnings.warn("enable_autodiff does not handle essential parts yet. keep_essential_parts set to False.")
- keep_essential_parts = False
# First step: handle empty diagrams
n = len(X)
@@ -267,7 +262,8 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
essential_cost = 0
essential_matching = None
- # Extract off-diaognal points of the diagrams.
+ # Extract off-diaognal points of the diagrams. Note that if enable_autodiff is True, nothing is done here (X,Y are
+ # assumed to be tensors with only finite coordinates).
X, Y = _offdiag(X, enable_autodiff), _offdiag(Y, enable_autodiff)
n = len(X)
m = len(Y)
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 285b95c9..6701c7ba 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -31,7 +31,7 @@ def test_proj_on_diag():
def test_offdiag():
diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf],
[np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]])
- assert np.array_equal(_offdiag(diag), [[0, 1], [3, 5]])
+ assert np.array_equal(_offdiag(diag, enable_autodiff=False), [[0, 1], [3, 5]])
def test_handle_essential_parts():