summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 3fa1bc4..16fd510 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -31,9 +31,11 @@ def test_emd_emd2():
# check G is identity
assert np.allclose(G, np.eye(n) / n)
+ # check constratints
+ assert np.allclose(u, G.sum(1)) # cf convergence sinkhorn
+ assert np.allclose(u, G.sum(0)) # cf convergence sinkhorn
w = ot.emd2(u, u, M)
-
# check loss=0
assert np.allclose(w, 0)