summaryrefslogtreecommitdiff
path: root/src/python/gudhi
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-07-06 18:27:52 +0200
committertlacombe <lacombe1993@gmail.com>2020-07-06 18:27:52 +0200
commitfe3e6a3a47828841ba3cb4a0721e5d8c16ab126f (patch)
tree0e65e5edfa38c23413d738ce27eb9b0ce13e2cf1 /src/python/gudhi
parent91a9d77ed48847a8859e6bdd759390001910d411 (diff)
update test including essential parts
Diffstat (limited to 'src/python/gudhi')
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py18
1 files changed, 13 insertions, 5 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 283ecd9d..2a1dee7a 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -105,10 +105,10 @@ def _get_essential_parts(a):
For instance, a[_get_essential_parts(a)[0]] returns the points in a of coordinates (-inf, x) for some finite x.
'''
if len(a):
- ess_first_type = np.where(np.isfinite(a[:,1]) & (a[:,0] == -np.inf))[0] # coord (-inf, x)
+ ess_first_type = np.where(np.isfinite(a[:,1]) & (a[:,0] == -np.inf))[0] # coord (-inf, x)
ess_second_type = np.where(np.isfinite(a[:,0]) & (a[:,1] == np.inf))[0] # coord (x, +inf)
- ess_third_type = np.where((a[:,0] == -np.inf) & (a[:,1] == np.inf))[0] # coord (-inf, +inf)
- ess_fourth_type = np.where((a[:,0] == -np.inf) & (a[:,1] == -np.inf))[0] # coord (-inf, -inf)
+ ess_third_type = np.where((a[:,0] == -np.inf) & (a[:,1] == np.inf))[0] # coord (-inf, +inf)
+ ess_fourth_type = np.where((a[:,0] == -np.inf) & (a[:,1] == -np.inf))[0] # coord (-inf, -inf)
ess_fifth_type = np.where((a[:,0] == np.inf) & (a[:,1] == np.inf))[0] # coord (+inf, +inf)
return ess_first_type, ess_second_type, ess_third_type, ess_fourth_type, ess_fifth_type
else:
@@ -232,12 +232,20 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab
if not matching:
return _perstot(Y, order, internal_p, enable_autodiff)
else:
- return _perstot(Y, order, internal_p, enable_autodiff), np.array([[-1, j] for j in range(m)])
+ cost = _perstot(Y, order, internal_p, enable_autodiff)
+ if cost == np.inf: # We had some essential part here.
+ return cost, None
+ else:
+ return cost, np.array([[-1, j] for j in range(m)])
elif m == 0:
if not matching:
return _perstot(X, order, internal_p, enable_autodiff)
else:
- return _perstot(X, order, internal_p, enable_autodiff), np.array([[i, -1] for i in range(n)])
+ cost = _perstot(X, order, internal_p, enable_autodiff)
+ if cost == np.inf:
+ return cost, None
+ else:
+ return np.array([[i, -1] for i in range(n)])
# Second step: handle essential parts