summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py27
1 files changed, 21 insertions, 6 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 399e549..7652394 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -70,7 +70,7 @@ def test_emd_empty():
def test_emd2_multi():
- n = 1000 # nb bins
+ n = 500 # nb bins
# bin positions
x = np.arange(n, dtype=np.float64)
@@ -78,7 +78,7 @@ def test_emd2_multi():
# Gaussian distributions
a = gauss(n, m=20, s=5) # m= mean, s= std
- ls = np.arange(20, 1000, 20)
+ ls = np.arange(20, 500, 20)
nb = len(ls)
b = np.zeros((n, nb))
for i in range(nb):
@@ -135,6 +135,21 @@ def test_lp_barycenter():
np.testing.assert_allclose(bary.sum(), 1)
+def test_free_support_barycenter():
+
+ measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ measures_weights = [np.array([1.]), np.array([1.])]
+
+ X_init = np.array([-12.]).reshape((1, 1))
+
+ # obvious barycenter location between two diracs
+ bar_locations = np.array([0.]).reshape((1, 1))
+
+ X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
+
+ np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
@@ -192,11 +207,11 @@ def test_warnings():
def test_dual_variables():
- n = 5000 # nb bins
- m = 6000 # nb bins
+ n = 500 # nb bins
+ m = 600 # nb bins
- mean1 = 1000
- mean2 = 1100
+ mean1 = 300
+ mean2 = 400
# bin positions
x = np.arange(n, dtype=np.float64)