summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/test_ot.py17
-rw-r--r--test/test_utils.py2
2 files changed, 18 insertions, 1 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 53edf4f..e8e2d97 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -302,6 +302,23 @@ def test_free_support_barycenter():
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+def test_free_support_barycenter_backends(nx):
+
+ measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ measures_weights = [np.array([1.]), np.array([1.])]
+ X_init = np.array([-12.]).reshape((1, 1))
+
+ X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
+
+ measures_locations2 = [nx.from_numpy(x) for x in measures_locations]
+ measures_weights2 = [nx.from_numpy(x) for x in measures_weights]
+ X_init2 = nx.from_numpy(X_init)
+
+ X2 = ot.lp.free_support_barycenter(measures_locations2, measures_weights2, X_init2)
+
+ np.testing.assert_allclose(X, nx.to_numpy(X2))
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
a1 = np.array([1.0, 0, 0])[:, None]
diff --git a/test/test_utils.py b/test/test_utils.py
index 6b476b2..8b23c22 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -122,7 +122,7 @@ def test_dist():
'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
'euclidean', 'hamming', 'jaccard', 'kulsinski',
'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
- 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'
+ 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule'
] # those that support weights
metrics = ['mahalanobis', 'seuclidean'] # do not support weights depending on scipy's version