diff options
author | tlacombe <lacombe1993@gmail.com> | 2021-04-12 19:48:57 +0200 |
---|---|---|
committer | tlacombe <lacombe1993@gmail.com> | 2021-04-12 19:48:57 +0200 |
commit | bb0792ed7bfe9d718be3e8039e8fb89af6d160e5 (patch) | |
tree | 087c453546b5259a1621e13c0ad590eede97d996 /src | |
parent | cdab3c9e32923f83d25d2cdf207f3cddbb3f94f6 (diff) |
added warning when cost is infty and matching is None
Diffstat (limited to 'src')
-rw-r--r-- | src/python/doc/wasserstein_distance_user.rst | 4 | ||||
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 44 |
2 files changed, 30 insertions, 18 deletions
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index b3d17495..091c9fd9 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -100,7 +100,7 @@ any matching has a cost +inf and thus can be considered to be optimal. In such a print("matchings:", matchings) -The output is: +The output is: .. testoutput:: @@ -197,4 +197,4 @@ Tutorial This `notebook <https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-Barycenters-of-persistence-diagrams.ipynb>`_ -presents the concept of barycenter, or Fréchet mean, of a family of persistence diagrams.
\ No newline at end of file +presents the concept of barycenter, or Fréchet mean, of a family of persistence diagrams. diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 7cb9d5d9..8ccbe12e 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -9,6 +9,7 @@ import numpy as np import scipy.spatial.distance as sc +import warnings try: import ot @@ -188,6 +189,20 @@ def _finite_part(X): return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))] +def _warn_infty(matching): + ''' + Handle essential parts with different cardinalities. Warn the user about cost being infinite and (if + `matching=True`) about the returned matching being `None`. + ''' + if matching: + warnings.warn('Cardinality of essential parts differs. Distance (cost) is +infty, and the returned matching is None.') + return np.inf, None + else: + warnings.warn('Cardinality of essential parts diffes. Distance (cost) is +infty.') + return np.inf + + + def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False, keep_essential_parts=True): ''' @@ -230,28 +245,27 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab else: return 0., np.array([]) else: - if not matching: - return _perstot(Y, order, internal_p, enable_autodiff) + cost = _perstot(Y, order, internal_p, enable_autodiff) + if cost == np.inf: + return _warn_infty(matching) else: - cost = _perstot(Y, order, internal_p, enable_autodiff) - if cost == np.inf: # We had some essential part in Y. - return cost, None + if not matching: + return cost else: return cost, np.array([[-1, j] for j in range(m)]) elif m == 0: - if not matching: - return _perstot(X, order, internal_p, enable_autodiff) + cost = _perstot(X, order, internal_p, enable_autodiff) + if cost == np.inf: + return _warn_infty(matching) else: - cost = _perstot(X, order, internal_p, enable_autodiff) - if cost == np.inf: - return cost, None + if not matching: + return cost else: return cost, np.array([[i, -1] for i in range(n)]) # Check essential part and enable autodiff together if enable_autodiff and keep_essential_parts: - import warnings # should it be done at the top of the file? warnings.warn('''enable_autodiff=True and keep_essential_parts=True are incompatible together. keep_essential_parts is set to False: only points with finite coordiantes are considered in the following. @@ -262,11 +276,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab if keep_essential_parts: essential_cost, essential_matching = _handle_essential_parts(X, Y, order=order) if (essential_cost == np.inf): - if matching: - return np.inf, None - else: - return np.inf # avoid computing transport cost between the finite parts if essential parts - # cardinalities do not match (saves time) + return _warn_infty(matching) # Tells the user that cost is infty and matching (if True) is None. + # avoid computing transport cost between the finite parts if essential parts + # cardinalities do not match (saves time) else: essential_cost = 0 essential_matching = None |