diff options
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index bf832f6..f2338ac 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -29,9 +29,12 @@ def test_emd_dimension_and_mass_mismatch(): np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) + # test emd and emd2 for mass mismatch + a = ot.utils.unif(n_samples) b = a.copy() a[0] = 100 np.testing.assert_raises(AssertionError, ot.emd, a, b, M) + np.testing.assert_raises(AssertionError, ot.emd2, a, b, M) def test_emd_backends(nx): @@ -201,6 +204,22 @@ def test_emd_emd2(): np.testing.assert_allclose(w, 0) +def test_omp_emd2(): + # test emd2 and emd2 with openmp for simple identity + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + w = ot.emd2(u, u, M) + w2 = ot.emd2(u, u, M, numThreads=2) + + np.testing.assert_allclose(w, w2) + + def test_emd_empty(): # test emd and emd2 for simple identity n = 100 @@ -320,6 +339,46 @@ def test_free_support_barycenter_backends(nx): np.testing.assert_allclose(X, nx.to_numpy(X2)) +def test_generalised_free_support_barycenter(): + np.random.seed(42) # random inits + X = [np.array([-1., -1.]).reshape((1, 2)), np.array([1., 1.]).reshape((1, 2))] # two 2D points bar is obviously 0 + a = [np.array([1.]), np.array([1.])] + + P = [np.eye(2), np.eye(2)] + + Y_init = np.array([-12., 7.]).reshape((1, 2)) + + # obvious barycenter location between two 2D diracs + Y_true = np.array([0., .0]).reshape((1, 2)) + + # test without log and no init + Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1) + np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7) + + # test with log and init + Y, _ = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init, b=np.array([1.]), log=True) + np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7) + + +def test_generalised_free_support_barycenter_backends(nx): + np.random.seed(42) + X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + a = [np.array([1.]), np.array([1.])] + P = [np.array([1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + Y_init = np.array([-12.]).reshape((1, 1)) + + Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init) + + X2 = nx.from_numpy(*X) + a2 = nx.from_numpy(*a) + P2 = nx.from_numpy(*P) + Y_init2 = nx.from_numpy(Y_init) + + Y2 = ot.lp.generalized_free_support_barycenter(X2, a2, P2, 1, Y_init=Y_init2) + + np.testing.assert_allclose(Y, nx.to_numpy(Y2)) + + @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): a1 = np.array([1.0, 0, 0])[:, None] |