From fdb2f3af19d04872bafa0d9ec5563732e1d6209b Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 11 May 2018 17:24:09 +0200 Subject: add test for barycenter --- test/test_gpu.py | 2 +- test/test_ot.py | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) (limited to 'test') diff --git a/test/test_gpu.py b/test/test_gpu.py index 615c2a7..1e97c45 100644 --- a/test/test_gpu.py +++ b/test/test_gpu.py @@ -76,4 +76,4 @@ def test_gpu_sinkhorn_lpl1(): time3 - time2)) describe_res(G2) - np.testing.assert_allclose(G1, G2, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(G1, G2, rtol=1e-3, atol=1e-3) diff --git a/test/test_ot.py b/test/test_ot.py index ea6d9dc..bf23e8c 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -10,7 +10,7 @@ import numpy as np import ot from ot.datasets import get_1D_gauss as gauss - +import pytest def test_doctest(): import doctest @@ -117,6 +117,39 @@ def test_emd2_multi(): np.testing.assert_allclose(emd1, emdn) +def test_lp_barycenter(): + + a1 = np.array([1.0, 0, 0])[:, None] + a2 = np.array([0, 0, 1.0])[:, None] + + A = np.hstack((a1, a2)) + M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]]) + + # obvious barycenter between two diracs + bary0 = np.array([0, 1.0, 0]) + + bary = ot.lp.barycenter(A, M, [.5, .5]) + + np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(bary.sum(), 1) + +@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") +def test_lp_barycenter_cvxopt(): + + a1 = np.array([1.0, 0, 0])[:, None] + a2 = np.array([0, 0, 1.0])[:, None] + + A = np.hstack((a1, a2)) + M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]]) + + # obvious barycenter between two diracs + bary0 = np.array([0, 1.0, 0]) + + bary = ot.lp.barycenter(A, M, [.5, .5],solver=None) + + np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(bary.sum(), 1) + def test_warnings(): n = 100 # nb bins m = 100 # nb bins -- cgit v1.2.3 From bd1af44ea0a819d5df0ccffbea4d05ed7547960b Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 11 May 2018 17:25:32 +0200 Subject: add test barycenter cvxopt --- test/test_ot.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_ot.py b/test/test_ot.py index bf23e8c..cc25bf4 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,6 +12,7 @@ import ot from ot.datasets import get_1D_gauss as gauss import pytest + def test_doctest(): import doctest @@ -133,6 +134,7 @@ def test_lp_barycenter(): np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) np.testing.assert_allclose(bary.sum(), 1) + @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): @@ -145,11 +147,12 @@ def test_lp_barycenter_cvxopt(): # obvious barycenter between two diracs bary0 = np.array([0, 1.0, 0]) - bary = ot.lp.barycenter(A, M, [.5, .5],solver=None) + bary = ot.lp.barycenter(A, M, [.5, .5], solver=None) np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) np.testing.assert_allclose(bary.sum(), 1) + def test_warnings(): n = 100 # nb bins m = 100 # nb bins -- cgit v1.2.3