summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
authorVivien Seguy <vivienseguy@h40.57.229.10.1016746.vlan.kuins.net>2018-07-09 16:49:21 +0900
committerVivien Seguy <vivienseguy@h40.57.229.10.1016746.vlan.kuins.net>2018-07-09 16:49:21 +0900
commit46712790c1276f1ecb3496362a8117e153782ede (patch)
tree15c66c04b24396078b60b8f0a89b7a9cc6476b50 /test/test_ot.py
parent67ddb92e28d6bb44eb65686419e255c2ce3311eb (diff)
add test free support barycenter algorithm + cleaning
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 399e549..dafc03f 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -135,6 +135,21 @@ def test_lp_barycenter():
np.testing.assert_allclose(bary.sum(), 1)
+def test_free_support_barycenter():
+
+ measures_locations = [np.array([-1.]).reshape((1,1)), np.array([1.]).reshape((1,1))]
+ measures_weights = [np.array([1.]), np.array([1.])]
+
+ X_init = np.array([-12.]).reshape((1,1))
+
+ # obvious barycenter location between two diracs
+ bar_locations = np.array([0.]).reshape((1,1))
+
+ X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
+
+ np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():