summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-07-07 11:52:35 +0200
committertlacombe <lacombe1993@gmail.com>2020-07-07 11:52:35 +0200
commite0eba14109e02676825f8c24563872a5b49c6120 (patch)
tree760b8287044752b0f77fcc5a43feb295a9cbae93 /src/python/test/test_wasserstein_distance.py
parentfe3e6a3a47828841ba3cb4a0721e5d8c16ab126f (diff)
correction typo in test wdist
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 24be228b..e50091e9 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -55,10 +55,10 @@ def test_handle_essential_parts():
[np.inf, np.inf],
[-np.inf, np.inf], [-np.inf, np.inf]])
- c, m = _handle_essential_parts(diag1, diag2, matching=True, order=1)
+ c, m = _handle_essential_parts(diag1, diag2, order=1)
assert c == pytest.approx(3, 0.0001)
assert np.array_equal(m, [[0,0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]])
- c, m = _handle_essential_parts(diag1, diag3, matching=True, order=1)
+ c, m = _handle_essential_parts(diag1, diag3, order=1)
assert c == np.inf
assert (m is None)
@@ -68,11 +68,11 @@ def test_get_essential_parts():
[np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]])
res = _get_essential_parts(diag)
- assert res[0] = [4, 5]
- assert res[1] = [2, 3]
- assert res[2] = [8, 9]
- assert res[3] = [6]
- assert res[4] = [7]
+ assert res[0] == [4, 5]
+ assert res[1] == [2, 3]
+ assert res[2] == [8, 9]
+ assert res[3] == [6]
+ assert res[4] == [7]
def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_matching=True):