diff options
Diffstat (limited to 'src/python/gudhi/wasserstein/wasserstein.py')
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 283ecd9d..2a1dee7a 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -105,10 +105,10 @@ def _get_essential_parts(a): For instance, a[_get_essential_parts(a)[0]] returns the points in a of coordinates (-inf, x) for some finite x. ''' if len(a): - ess_first_type = np.where(np.isfinite(a[:,1]) & (a[:,0] == -np.inf))[0] # coord (-inf, x) + ess_first_type = np.where(np.isfinite(a[:,1]) & (a[:,0] == -np.inf))[0] # coord (-inf, x) ess_second_type = np.where(np.isfinite(a[:,0]) & (a[:,1] == np.inf))[0] # coord (x, +inf) - ess_third_type = np.where((a[:,0] == -np.inf) & (a[:,1] == np.inf))[0] # coord (-inf, +inf) - ess_fourth_type = np.where((a[:,0] == -np.inf) & (a[:,1] == -np.inf))[0] # coord (-inf, -inf) + ess_third_type = np.where((a[:,0] == -np.inf) & (a[:,1] == np.inf))[0] # coord (-inf, +inf) + ess_fourth_type = np.where((a[:,0] == -np.inf) & (a[:,1] == -np.inf))[0] # coord (-inf, -inf) ess_fifth_type = np.where((a[:,0] == np.inf) & (a[:,1] == np.inf))[0] # coord (+inf, +inf) return ess_first_type, ess_second_type, ess_third_type, ess_fourth_type, ess_fifth_type else: @@ -232,12 +232,20 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab if not matching: return _perstot(Y, order, internal_p, enable_autodiff) else: - return _perstot(Y, order, internal_p, enable_autodiff), np.array([[-1, j] for j in range(m)]) + cost = _perstot(Y, order, internal_p, enable_autodiff) + if cost == np.inf: # We had some essential part here. + return cost, None + else: + return cost, np.array([[-1, j] for j in range(m)]) elif m == 0: if not matching: return _perstot(X, order, internal_p, enable_autodiff) else: - return _perstot(X, order, internal_p, enable_autodiff), np.array([[i, -1] for i in range(n)]) + cost = _perstot(X, order, internal_p, enable_autodiff) + if cost == np.inf: + return cost, None + else: + return np.array([[i, -1] for i in range(n)]) # Second step: handle essential parts |