summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-01-28 17:40:16 +0100
committerGitHub <noreply@github.com>2022-01-28 17:40:16 +0100
commit71a57c68ea9eb2bc948c4dd1cce9928f34bf20e8 (patch)
tree1d07299ff3e99003642a8eb72537abe2bc6eb8b3
parentd7c709e2bae3bafec9efad87e758919c8db61933 (diff)
[MRG] Backend implementation of the free support barycenter (#340)
* backend version barycenter * new tests * cleanup release file and doc * f*ing pep8 * remove unused variable
-rw-r--r--RELEASES.md5
-rw-r--r--ot/lp/__init__.py28
-rw-r--r--test/test_ot.py17
-rw-r--r--test/test_utils.py2
4 files changed, 37 insertions, 15 deletions
diff --git a/RELEASES.md b/RELEASES.md
index a5fcbe1..94c853b 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -1,5 +1,6 @@
# Releases
+
## 0.8.2dev Development
#### New features
@@ -7,10 +8,12 @@
- Better list of related examples in quick start guide with `minigallery` (PR #334)
- Add optional log-domain Sinkhorn implementation in WDA to support smaller values
of the regularization parameter (PR #336)
+- Backend implementation for `ot.lp.free_support_barycenter` (PR #340)
#### Closed issues
-- Bug in instantiating an `autograd` function (`ValFunction`, Issue #337, PR #338)
+- Bug in instantiating an `autograd` function (`ValFunction`, Issue #337, PR
+ #338)
## 0.8.1.0
*December 2021*
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 5da897d..2ff7c1f 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -535,18 +535,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
Parameters
----------
- measures_locations : list of N (k_i,d) numpy.ndarray
+ measures_locations : list of N (k_i,d) array-like
The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
(:math:`k_i` can be different for each element of the list)
- measures_weights : list of N (k_i,) numpy.ndarray
+ measures_weights : list of N (k_i,) array-like
Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
representing the weights of each discrete input measure
- X_init : (k,d) np.ndarray
+ X_init : (k,d) array-like
Initialization of the support locations (on `k` atoms) of the barycenter
- b : (k,) np.ndarray
+ b : (k,) array-like
Initialization of the weights of the barycenter (non-negatives, sum to 1)
- weights : (N,) np.ndarray
+ weights : (N,) array-like
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
numItermax : int, optional
@@ -564,7 +564,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
Returns
-------
- X : (k,d) np.ndarray
+ X : (k,d) array-like
Support locations (on k atoms) of the barycenter
@@ -577,15 +577,17 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
"""
+ nx = get_backend(*measures_locations,*measures_weights,X_init)
+
iter_count = 0
N = len(measures_locations)
k = X_init.shape[0]
d = X_init.shape[1]
if b is None:
- b = np.ones((k,)) / k
+ b = nx.ones((k,),type_as=X_init) / k
if weights is None:
- weights = np.ones((N,)) / N
+ weights = nx.ones((N,),type_as=X_init) / N
X = X_init
@@ -596,15 +598,15 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
while (displacement_square_norm > stopThr and iter_count < numItermax):
- T_sum = np.zeros((k, d))
+ T_sum = nx.zeros((k, d),type_as=X_init)
+
- for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
- weights.tolist()):
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
M_i = dist(X, measure_locations_i)
T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads)
- T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
+ T_sum = T_sum + weight_i * 1. / b[:,None] * nx.dot(T_i, measure_locations_i)
- displacement_square_norm = np.sum(np.square(T_sum - X))
+ displacement_square_norm = nx.sum((T_sum - X)**2)
if log:
displacement_square_norms.append(displacement_square_norm)
diff --git a/test/test_ot.py b/test/test_ot.py
index 53edf4f..e8e2d97 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -302,6 +302,23 @@ def test_free_support_barycenter():
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+def test_free_support_barycenter_backends(nx):
+
+ measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ measures_weights = [np.array([1.]), np.array([1.])]
+ X_init = np.array([-12.]).reshape((1, 1))
+
+ X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
+
+ measures_locations2 = [nx.from_numpy(x) for x in measures_locations]
+ measures_weights2 = [nx.from_numpy(x) for x in measures_weights]
+ X_init2 = nx.from_numpy(X_init)
+
+ X2 = ot.lp.free_support_barycenter(measures_locations2, measures_weights2, X_init2)
+
+ np.testing.assert_allclose(X, nx.to_numpy(X2))
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
a1 = np.array([1.0, 0, 0])[:, None]
diff --git a/test/test_utils.py b/test/test_utils.py
index 6b476b2..8b23c22 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -122,7 +122,7 @@ def test_dist():
'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
'euclidean', 'hamming', 'jaccard', 'kulsinski',
'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
- 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'
+ 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule'
] # those that support weights
metrics = ['mahalanobis', 'seuclidean'] # do not support weights depending on scipy's version