diff options
author | tlacombe <lacombe1993@gmail.com> | 2020-07-20 11:41:13 +0200 |
---|---|---|
committer | tlacombe <lacombe1993@gmail.com> | 2020-07-20 11:41:13 +0200 |
commit | e94892f972357283e70c7534f84662dfaa21cc3e (patch) | |
tree | 5f5a69625f2c61201aaf9c31ecfeee95ad7b02ab /src | |
parent | 107f8e6668509f5fd36e179f9a538b460d3941a9 (diff) |
update test enable_autodiff and _offdiag
Diffstat (limited to 'src')
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 16 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 2 |
2 files changed, 7 insertions, 11 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 495142c4..142385b1 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -178,13 +178,13 @@ def _handle_essential_parts(X, Y, order): def _offdiag(X, enable_autodiff): ''' :param X: (n x 2) numpy array encoding a persistence diagram. + :param enable_autodiff: boolean, to handle the case where X is a eagerpy tensor. :returns: The off-diagonal part of a diagram `X` (points with finite coordinates). ''' if enable_autodiff: - import eagerpy as ep - - return ep.astensor(X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))]) - + # Assumes the diagrams only have finite coordinates. Thus, return X directly. + # TODO improve this to get rid of essential parts if there are any. + return X else: return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))] @@ -218,11 +218,6 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab If matching is set to True, also returns the optimal matching between X and Y. If cost is +inf, any matching is optimal and thus it returns `None` instead. ''' - # Zeroth step: check compatibility of arguments - if keep_essential_parts and enable_autodiff: - import warnings - warnings.warn("enable_autodiff does not handle essential parts yet. keep_essential_parts set to False.") - keep_essential_parts = False # First step: handle empty diagrams n = len(X) @@ -267,7 +262,8 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab essential_cost = 0 essential_matching = None - # Extract off-diaognal points of the diagrams. + # Extract off-diaognal points of the diagrams. Note that if enable_autodiff is True, nothing is done here (X,Y are + # assumed to be tensors with only finite coordinates). X, Y = _offdiag(X, enable_autodiff), _offdiag(Y, enable_autodiff) n = len(X) m = len(Y) diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 285b95c9..6701c7ba 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -31,7 +31,7 @@ def test_proj_on_diag(): def test_offdiag(): diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf], [np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]]) - assert np.array_equal(_offdiag(diag), [[0, 1], [3, 5]]) + assert np.array_equal(_offdiag(diag, enable_autodiff=False), [[0, 1], [3, 5]]) def test_handle_essential_parts(): |