summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py59
1 files changed, 59 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index bf832f6..f2338ac 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -29,9 +29,12 @@ def test_emd_dimension_and_mass_mismatch():
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
+ # test emd and emd2 for mass mismatch
+ a = ot.utils.unif(n_samples)
b = a.copy()
a[0] = 100
np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
+ np.testing.assert_raises(AssertionError, ot.emd2, a, b, M)
def test_emd_backends(nx):
@@ -201,6 +204,22 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
+def test_omp_emd2():
+ # test emd2 and emd2 with openmp for simple identity
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ w = ot.emd2(u, u, M)
+ w2 = ot.emd2(u, u, M, numThreads=2)
+
+ np.testing.assert_allclose(w, w2)
+
+
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100
@@ -320,6 +339,46 @@ def test_free_support_barycenter_backends(nx):
np.testing.assert_allclose(X, nx.to_numpy(X2))
+def test_generalised_free_support_barycenter():
+ np.random.seed(42) # random inits
+ X = [np.array([-1., -1.]).reshape((1, 2)), np.array([1., 1.]).reshape((1, 2))] # two 2D points bar is obviously 0
+ a = [np.array([1.]), np.array([1.])]
+
+ P = [np.eye(2), np.eye(2)]
+
+ Y_init = np.array([-12., 7.]).reshape((1, 2))
+
+ # obvious barycenter location between two 2D diracs
+ Y_true = np.array([0., .0]).reshape((1, 2))
+
+ # test without log and no init
+ Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1)
+ np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7)
+
+ # test with log and init
+ Y, _ = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init, b=np.array([1.]), log=True)
+ np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7)
+
+
+def test_generalised_free_support_barycenter_backends(nx):
+ np.random.seed(42)
+ X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ a = [np.array([1.]), np.array([1.])]
+ P = [np.array([1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ Y_init = np.array([-12.]).reshape((1, 1))
+
+ Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init)
+
+ X2 = nx.from_numpy(*X)
+ a2 = nx.from_numpy(*a)
+ P2 = nx.from_numpy(*P)
+ Y_init2 = nx.from_numpy(Y_init)
+
+ Y2 = ot.lp.generalized_free_support_barycenter(X2, a2, P2, 1, Y_init=Y_init2)
+
+ np.testing.assert_allclose(Y, nx.to_numpy(Y2))
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
a1 = np.array([1.0, 0, 0])[:, None]