summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2021-04-12 17:02:34 +0200
committertlacombe <lacombe1993@gmail.com>2021-04-12 17:02:34 +0200
commitcdab3c9e32923f83d25d2cdf207f3cddbb3f94f6 (patch)
tree4e0eeb2e38b1b04244dd3c2a6e246c22507de82f
parent2a11e3651c2d66df8371a9aa1d23dff69ffbc31c (diff)
handle essential parts test
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py1
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py2
2 files changed, 2 insertions, 1 deletions
diff --git a/src/python/gudhi/wasserstein/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 2911f826..7cb9d5d9 100644
--- a/src/python/gudhi/wasserstein/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -113,6 +113,7 @@ def _get_essential_parts(a):
second_coord_infinite_positive = (a[:,1] == np.inf)
first_coord_infinite_negative = (a[:,0] == -np.inf)
second_coord_infinite_negative = (a[:,1] == -np.inf)
+
ess_first_type = np.where(second_coord_finite & first_coord_infinite_negative)[0] # coord (-inf, x)
ess_second_type = np.where(first_coord_finite & second_coord_infinite_positive)[0] # coord (x, +inf)
ess_third_type = np.where(first_coord_infinite_negative & second_coord_infinite_positive)[0] # coord (-inf, +inf)
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 14d5c2ca..df7acc91 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -31,7 +31,7 @@ def test_proj_on_diag():
def test_finite_part():
diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf],
[np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]])
- assert np.array_equal(_finite_part(diag, enable_autodiff=False), [[0, 1], [3, 5]])
+ assert np.array_equal(_finite_part(diag), [[0, 1], [3, 5]])
def test_handle_essential_parts():