summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index ec4388d..6aa4e08 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -57,6 +57,9 @@ def test_sinkhorn_empty():
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
+ # test empty weights greenkhorn
+ ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+
def test_sinkhorn_variants():
# test sinkhorn
@@ -124,7 +127,7 @@ def test_barycenter(method):
# wasserstein
reg = 1e-2
- bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method)
+ bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
np.testing.assert_allclose(1, np.sum(bary_wass))
@@ -152,9 +155,9 @@ def test_barycenter_stabilization():
reg = 1e-2
bar_stable = ot.bregman.barycenter(A, M, reg, weights,
method="sinkhorn_stabilized",
- stopThr=1e-8)
+ stopThr=1e-8, verbose=True)
bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
- stopThr=1e-8)
+ stopThr=1e-8, verbose=True)
np.testing.assert_allclose(bar, bar_stable)