diff options
author | tlacombe <lacombe1993@gmail.com> | 2020-07-06 18:27:52 +0200 |
---|---|---|
committer | tlacombe <lacombe1993@gmail.com> | 2020-07-06 18:27:52 +0200 |
commit | fe3e6a3a47828841ba3cb4a0721e5d8c16ab126f (patch) | |
tree | 0e65e5edfa38c23413d738ce27eb9b0ce13e2cf1 /src/python/gudhi | |
parent | 91a9d77ed48847a8859e6bdd759390001910d411 (diff) |
update test including essential parts
Diffstat (limited to 'src/python/gudhi')
-rw-r--r-- | src/python/gudhi/wasserstein/wasserstein.py | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py index 283ecd9d..2a1dee7a 100644 --- a/src/python/gudhi/wasserstein/wasserstein.py +++ b/src/python/gudhi/wasserstein/wasserstein.py @@ -105,10 +105,10 @@ def _get_essential_parts(a): For instance, a[_get_essential_parts(a)[0]] returns the points in a of coordinates (-inf, x) for some finite x. ''' if len(a): - ess_first_type = np.where(np.isfinite(a[:,1]) & (a[:,0] == -np.inf))[0] # coord (-inf, x) + ess_first_type = np.where(np.isfinite(a[:,1]) & (a[:,0] == -np.inf))[0] # coord (-inf, x) ess_second_type = np.where(np.isfinite(a[:,0]) & (a[:,1] == np.inf))[0] # coord (x, +inf) - ess_third_type = np.where((a[:,0] == -np.inf) & (a[:,1] == np.inf))[0] # coord (-inf, +inf) - ess_fourth_type = np.where((a[:,0] == -np.inf) & (a[:,1] == -np.inf))[0] # coord (-inf, -inf) + ess_third_type = np.where((a[:,0] == -np.inf) & (a[:,1] == np.inf))[0] # coord (-inf, +inf) + ess_fourth_type = np.where((a[:,0] == -np.inf) & (a[:,1] == -np.inf))[0] # coord (-inf, -inf) ess_fifth_type = np.where((a[:,0] == np.inf) & (a[:,1] == np.inf))[0] # coord (+inf, +inf) return ess_first_type, ess_second_type, ess_third_type, ess_fourth_type, ess_fifth_type else: @@ -232,12 +232,20 @@ def wasserstein_distance(X, Y, matching=False, order=1., internal_p=np.inf, enab if not matching: return _perstot(Y, order, internal_p, enable_autodiff) else: - return _perstot(Y, order, internal_p, enable_autodiff), np.array([[-1, j] for j in range(m)]) + cost = _perstot(Y, order, internal_p, enable_autodiff) + if cost == np.inf: # We had some essential part here. + return cost, None + else: + return cost, np.array([[-1, j] for j in range(m)]) elif m == 0: if not matching: return _perstot(X, order, internal_p, enable_autodiff) else: - return _perstot(X, order, internal_p, enable_autodiff), np.array([[i, -1] for i in range(n)]) + cost = _perstot(X, order, internal_p, enable_autodiff) + if cost == np.inf: + return cost, None + else: + return np.array([[i, -1] for i in range(n)]) # Second step: handle essential parts |