diff options
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index 399e549..45e777a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -135,6 +135,21 @@ def test_lp_barycenter(): np.testing.assert_allclose(bary.sum(), 1) +def test_free_support_barycenter(): + + measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + measures_weights = [np.array([1.]), np.array([1.])] + + X_init = np.array([-12.]).reshape((1, 1)) + + # obvious barycenter location between two diracs + bar_locations = np.array([0.]).reshape((1, 1)) + + X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init) + + np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) + + @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): |