From bb0792ed7bfe9d718be3e8039e8fb89af6d160e5 Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 12 Apr 2021 19:48:57 +0200 Subject: added warning when cost is infty and matching is None --- src/python/gudhi/wasserstein/wasserstein.py | 44 ++++++++++++++++++----------- 1 file changed, 28 insertions(+), 16 deletions(-) (limited to 'src/python/gudhi/wasserstein') diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 7cb9d5d9..8ccbe12e 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -9,6 +9,7 @@ import numpy as np import scipy.spatial.distance as sc +import warnings try: import ot @@ -188,6 +189,20 @@ def _finite_part(X): return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))] +def _warn_infty(matching): + ''' + Handle essential parts with different cardinalities. Warn the user about cost being infinite and (if + `matching=True`) about the returned matching being `None`. + ''' + if matching: + warnings.warn('Cardinality of essential parts differs. Distance (cost) is +infty, and the returned matching is None.') + return np.inf, None + else: + warnings.warn('Cardinality of essential parts diffes. Distance (cost) is +infty.') + return np.inf + + + def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False, keep_essential_parts=True): ''' @@ -230,28 +245,27 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab else: return 0., np.array([]) else: - if not matching: - return _perstot(Y, order, internal_p, enable_autodiff) + cost = _perstot(Y, order, internal_p, enable_autodiff) + if cost == np.inf: + return _warn_infty(matching) else: - cost = _perstot(Y, order, internal_p, enable_autodiff) - if cost == np.inf: # We had some essential part in Y. - return cost, None + if not matching: + return cost else: return cost, np.array([[-1, j] for j in range(m)]) elif m == 0: - if not matching: - return _perstot(X, order, internal_p, enable_autodiff) + cost = _perstot(X, order, internal_p, enable_autodiff) + if cost == np.inf: + return _warn_infty(matching) else: - cost = _perstot(X, order, internal_p, enable_autodiff) - if cost == np.inf: - return cost, None + if not matching: + return cost else: return cost, np.array([[i, -1] for i in range(n)]) # Check essential part and enable autodiff together if enable_autodiff and keep_essential_parts: - import warnings # should it be done at the top of the file? warnings.warn('''enable_autodiff=True and keep_essential_parts=True are incompatible together. keep_essential_parts is set to False: only points with finite coordiantes are considered in the following. @@ -262,11 +276,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab if keep_essential_parts: essential_cost, essential_matching = _handle_essential_parts(X, Y, order=order) if (essential_cost == np.inf): - if matching: - return np.inf, None - else: - return np.inf # avoid computing transport cost between the finite parts if essential parts - # cardinalities do not match (saves time) + return _warn_infty(matching) # Tells the user that cost is infty and matching (if True) is None. + # avoid computing transport cost between the finite parts if essential parts + # cardinalities do not match (saves time) else: essential_cost = 0 essential_matching = None -- cgit v1.2.3