summaryrefslogtreecommitdiff
path: root/src/python/gudhi/wasserstein
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2021-04-12 19:48:57 +0200
committertlacombe <lacombe1993@gmail.com>2021-04-12 19:48:57 +0200
commitbb0792ed7bfe9d718be3e8039e8fb89af6d160e5 (patch)
tree087c453546b5259a1621e13c0ad590eede97d996 /src/python/gudhi/wasserstein
parentcdab3c9e32923f83d25d2cdf207f3cddbb3f94f6 (diff)
added warning when cost is infty and matching is None
Diffstat (limited to 'src/python/gudhi/wasserstein')
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py44
1 files changed, 28 insertions, 16 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 7cb9d5d9..8ccbe12e 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -9,6 +9,7 @@
import numpy as np
import scipy.spatial.distance as sc
+import warnings
try:
import ot
@@ -188,6 +189,20 @@ def _finite_part(X):
return X[np.where(np.isfinite(X[:,0]) & np.isfinite(X[:,1]))]
+def _warn_infty(matching):
+ '''
+ Handle essential parts with different cardinalities. Warn the user about cost being infinite and (if
+ `matching=True`) about the returned matching being `None`.
+ '''
+ if matching:
+ warnings.warn('Cardinality of essential parts differs. Distance (cost) is +infty, and the returned matching is None.')
+ return np.inf, None
+ else:
+ warnings.warn('Cardinality of essential parts diffes. Distance (cost) is +infty.')
+ return np.inf
+
+
+
def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enable_autodiff=False,
keep_essential_parts=True):
'''
@@ -230,28 +245,27 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
else:
return 0., np.array([])
else:
- if not matching:
- return _perstot(Y, order, internal_p, enable_autodiff)
+ cost = _perstot(Y, order, internal_p, enable_autodiff)
+ if cost == np.inf:
+ return _warn_infty(matching)
else:
- cost = _perstot(Y, order, internal_p, enable_autodiff)
- if cost == np.inf: # We had some essential part in Y.
- return cost, None
+ if not matching:
+ return cost
else:
return cost, np.array([[-1, j] for j in range(m)])
elif m == 0:
- if not matching:
- return _perstot(X, order, internal_p, enable_autodiff)
+ cost = _perstot(X, order, internal_p, enable_autodiff)
+ if cost == np.inf:
+ return _warn_infty(matching)
else:
- cost = _perstot(X, order, internal_p, enable_autodiff)
- if cost == np.inf:
- return cost, None
+ if not matching:
+ return cost
else:
return cost, np.array([[i, -1] for i in range(n)])
# Check essential part and enable autodiff together
if enable_autodiff and keep_essential_parts:
- import warnings # should it be done at the top of the file?
warnings.warn('''enable_autodiff=True and keep_essential_parts=True are incompatible together.
keep_essential_parts is set to False: only points with finite coordiantes are considered
in the following.
@@ -262,11 +276,9 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
if keep_essential_parts:
essential_cost, essential_matching = _handle_essential_parts(X, Y, order=order)
if (essential_cost == np.inf):
- if matching:
- return np.inf, None
- else:
- return np.inf # avoid computing transport cost between the finite parts if essential parts
- # cardinalities do not match (saves time)
+ return _warn_infty(matching) # Tells the user that cost is infty and matching (if True) is None.
+ # avoid computing transport cost between the finite parts if essential parts
+ # cardinalities do not match (saves time)
else:
essential_cost = 0
essential_matching = None