summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-03-03 15:33:17 +0100
committertlacombe <lacombe1993@gmail.com>2020-03-03 15:33:17 +0100
commit8e4f3d151818b78a29d11cdc6ca171947bfd6dd9 (patch)
treef7f2d562332cfa22b90628e0e95dd22739322f9e /src/python/test/test_wasserstein_distance.py
parentd2943b9e7311c8a3d8a4fb379c39b15497481b9c (diff)
update wasserstein distance with pot so that it can return optimal matching now!
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py31
1 files changed, 22 insertions, 9 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 6a6b217b..02a1d2c9 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -51,14 +51,27 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True):
assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == approx(np.sqrt(5))
assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == approx(np.sqrt(5))
- if(not test_infinity):
- return
+ if test_infinity:
+ diag5 = np.array([[0, 3], [4, np.inf]])
+ diag6 = np.array([[7, 8], [4, 6], [3, np.inf]])
- diag5 = np.array([[0, 3], [4, np.inf]])
- diag6 = np.array([[7, 8], [4, 6], [3, np.inf]])
+ assert wasserstein_distance(diag4, diag5) == np.inf
+ assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.)
+
+
+ if test_matching:
+ match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=1., order=2)[1]
+ assert match == []
+ match = wasserstein_distance(emptydiag, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
+ assert match == []
+ match = wasserstein_distance(emptydiag, diag2, matching=True, internal_p=np.inf, order=2.)[1]
+ assert match == [(-1, 0), (-1, 1)]
+ match = wasserstein_distance(diag2, emptydiag, matching=True, internal_p=np.inf, order=2.24)[1]
+ assert match == [(0, -1), (1, -1)]
+ match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
+ assert match == [(0, 0), (1, 1), (2, -1)]
+
- assert wasserstein_distance(diag4, diag5) == np.inf
- assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.)
def hera_wrap(delta):
def fun(*kargs,**kwargs):
@@ -66,8 +79,8 @@ def hera_wrap(delta):
return fun
def test_wasserstein_distance_pot():
- _basic_wasserstein(pot, 1e-15, test_infinity=False)
+ _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True)
def test_wasserstein_distance_hera():
- _basic_wasserstein(hera_wrap(1e-12), 1e-12)
- _basic_wasserstein(hera_wrap(.1), .1)
+ _basic_wasserstein(hera_wrap(1e-12), 1e-12, test_matching=False)
+ _basic_wasserstein(hera_wrap(.1), .1, test_matching=False)