summaryrefslogtreecommitdiff
path: root/test/test_partial.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_partial.py')
-rwxr-xr-xtest/test_partial.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/test/test_partial.py b/test/test_partial.py
index 510e081..97c611b 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -51,10 +51,12 @@ def test_raise_errors():
ot.partial.partial_gromov_wasserstein(M, M, p, q, m=-1, log=True)
with pytest.raises(ValueError):
- ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, log=True)
+ ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2,
+ log=True)
with pytest.raises(ValueError):
- ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, log=True)
+ ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1,
+ log=True)
def test_partial_wasserstein_lagrange():
@@ -102,7 +104,7 @@ def test_partial_wasserstein():
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
log=True, verbose=True)
- # check constratints
+ # check constraints
np.testing.assert_equal(
w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
@@ -125,11 +127,11 @@ def test_partial_wasserstein():
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
np.testing.assert_equal(
- G.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
+ G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
- G.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
+ G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
np.testing.assert_allclose(
np.sum(G), m, atol=1e-04)
@@ -192,7 +194,7 @@ def test_partial_gromov_wasserstein():
100, m=m,
log=True)
- # check constratints
+ # check constraints
np.testing.assert_equal(
res0.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(