summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2018-05-29 16:16:41 +0200
committerGitHub <noreply@github.com>2018-05-29 16:16:41 +0200
commit90efa5a8b189214d1aeb81920b2bb04ce0c261ca (patch)
tree62e2f1a3cca2f4885e8c0e2a0b135a5f574d6a8c /test
parentec79b791f4f4a62f7c04b7bbf14fe2f5dcbb4c75 (diff)
parent54f0b47e55c966d5492e4ce19ec4e704ef3278d6 (diff)
Merge pull request #47 from rflamary/bary
LP Wasserstein barycenter with scipy linear solver and/or cvxopt
Diffstat (limited to 'test')
-rw-r--r--test/test_gpu.py2
-rw-r--r--test/test_ot.py36
2 files changed, 37 insertions, 1 deletions
diff --git a/test/test_gpu.py b/test/test_gpu.py
index 615c2a7..1e97c45 100644
--- a/test/test_gpu.py
+++ b/test/test_gpu.py
@@ -76,4 +76,4 @@ def test_gpu_sinkhorn_lpl1():
time3 - time2))
describe_res(G2)
- np.testing.assert_allclose(G1, G2, rtol=1e-5, atol=1e-5)
+ np.testing.assert_allclose(G1, G2, rtol=1e-3, atol=1e-3)
diff --git a/test/test_ot.py b/test/test_ot.py
index ea6d9dc..cc25bf4 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -10,6 +10,7 @@ import numpy as np
import ot
from ot.datasets import get_1D_gauss as gauss
+import pytest
def test_doctest():
@@ -117,6 +118,41 @@ def test_emd2_multi():
np.testing.assert_allclose(emd1, emdn)
+def test_lp_barycenter():
+
+ a1 = np.array([1.0, 0, 0])[:, None]
+ a2 = np.array([0, 0, 1.0])[:, None]
+
+ A = np.hstack((a1, a2))
+ M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]])
+
+ # obvious barycenter between two diracs
+ bary0 = np.array([0, 1.0, 0])
+
+ bary = ot.lp.barycenter(A, M, [.5, .5])
+
+ np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7)
+ np.testing.assert_allclose(bary.sum(), 1)
+
+
+@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
+def test_lp_barycenter_cvxopt():
+
+ a1 = np.array([1.0, 0, 0])[:, None]
+ a2 = np.array([0, 0, 1.0])[:, None]
+
+ A = np.hstack((a1, a2))
+ M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]])
+
+ # obvious barycenter between two diracs
+ bary0 = np.array([0, 1.0, 0])
+
+ bary = ot.lp.barycenter(A, M, [.5, .5], solver=None)
+
+ np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7)
+ np.testing.assert_allclose(bary.sum(), 1)
+
+
def test_warnings():
n = 100 # nb bins
m = 100 # nb bins