diff options
author | tlacombe <lacombe1993@gmail.com> | 2020-07-07 18:15:17 +0200 |
---|---|---|
committer | tlacombe <lacombe1993@gmail.com> | 2020-07-07 18:15:17 +0200 |
commit | 107f8e6668509f5fd36e179f9a538b460d3941a9 (patch) | |
tree | 34eff3e267d6355d5200e3f56c7e04032f928b8c /src/python/gudhi/wasserstein/wasserstein.py | |
parent | 42a399c273fde7c76ec23d2993957fcbb492ee79 (diff) |
added enable autodiff management in _offdiag utils function
Diffstat (limited to 'src/python/gudhi/wasserstein/wasserstein.py')
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 981bbf08..495142c4 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -70,6 +70,7 @@ def _perstot_autodiff(X, order, internal_p): ''' return _dist_to_diag(X, internal_p).norms.lp(order) + def _perstot(X, order, internal_p, enable_autodiff): ''' :param X: (n x 2) numpy.array (points of a given diagram). @@ -174,12 +175,18 @@ def _handle_essential_parts(X, Y, order): return c, np.array(m) -def _offdiag(X): +def _offdiag(X, enable_autodiff): ''' :param X: (n x 2) numpy array encoding a persistence diagram. :returns: The off-diagonal part of a diagram `X` (points with finite coordinates). ''' - return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))] + if enable_autodiff: + import eagerpy as ep + + return ep.astensor(X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))]) + + else: + return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))] def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False, @@ -261,7 +268,7 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab essential_matching = None # Extract off-diaognal points of the diagrams. - X, Y = _offdiag(X), _offdiag(Y) + X, Y = _offdiag(X, enable_autodiff), _offdiag(Y, enable_autodiff) n = len(X) m = len(Y) |