diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2022-01-28 17:40:16 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-01-28 17:40:16 +0100 |
commit | 71a57c68ea9eb2bc948c4dd1cce9928f34bf20e8 (patch) | |
tree | 1d07299ff3e99003642a8eb72537abe2bc6eb8b3 /test/test_ot.py | |
parent | d7c709e2bae3bafec9efad87e758919c8db61933 (diff) |
[MRG] Backend implementation of the free support barycenter (#340)
* backend version barycenter
* new tests
* cleanup release file and doc
* f*ing pep8
* remove unused variable
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index 53edf4f..e8e2d97 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -302,6 +302,23 @@ def test_free_support_barycenter(): np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) +def test_free_support_barycenter_backends(nx): + + measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + measures_weights = [np.array([1.]), np.array([1.])] + X_init = np.array([-12.]).reshape((1, 1)) + + X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init) + + measures_locations2 = [nx.from_numpy(x) for x in measures_locations] + measures_weights2 = [nx.from_numpy(x) for x in measures_weights] + X_init2 = nx.from_numpy(X_init) + + X2 = ot.lp.free_support_barycenter(measures_locations2, measures_weights2, X_init2) + + np.testing.assert_allclose(X, nx.to_numpy(X2)) + + @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): a1 = np.array([1.0, 0, 0])[:, None] |