summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVivien Seguy <vivienseguy@h40.57.229.10.1016746.vlan.kuins.net>2018-07-09 16:49:21 +0900
committerVivien Seguy <vivienseguy@h40.57.229.10.1016746.vlan.kuins.net>2018-07-09 16:49:21 +0900
commit46712790c1276f1ecb3496362a8117e153782ede (patch)
tree15c66c04b24396078b60b8f0a89b7a9cc6476b50
parent67ddb92e28d6bb44eb65686419e255c2ce3311eb (diff)
add test free support barycenter algorithm + cleaning
-rw-r--r--examples/plot_free_support_barycenter.py29
-rw-r--r--ot/lp/__init__.py92
-rw-r--r--ot/lp/cvx.py82
-rw-r--r--test/test_ot.py15
4 files changed, 121 insertions, 97 deletions
diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py
index 5b08507..b2e62c8 100644
--- a/examples/plot_free_support_barycenter.py
+++ b/examples/plot_free_support_barycenter.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
====================================================
-2D Wasserstein barycenters of distributions
+2D free support Wasserstein barycenters of distributions
====================================================
Illustration of 2D Wasserstein barycenters if discributions that are weighted
@@ -15,7 +15,8 @@ sum of diracs.
import numpy as np
import matplotlib.pylab as pl
-import ot.plot
+import ot
+
##############################################################################
# Generate data
@@ -28,16 +29,16 @@ measures_weights = []
for i in range(N):
- n = np.random.randint(low=1, high=20) # nb samples
+ n_i = np.random.randint(low=1, high=20) # nb samples
- mu = np.random.normal(0., 4., (d,))
+ mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean
- A = np.random.rand(d, d)
- cov = np.dot(A, A.transpose())
+ A_i = np.random.rand(d, d)
+ cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix
- x_i = ot.datasets.make_2D_samples_gauss(n, mu, cov)
- b_i = np.random.uniform(0., 1., (n,))
- b_i = b_i / np.sum(b_i)
+ x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations
+ b_i = np.random.uniform(0., 1., (n_i,))
+ b_i = b_i / np.sum(b_i) # Dirac weights
measures_locations.append(x_i)
measures_weights.append(b_i)
@@ -47,19 +48,17 @@ for i in range(N):
# Compute free support barycenter
# -------------
-k = 10
-X_init = np.random.normal(0., 1., (k, d))
-b = np.ones((k,)) / k
+k = 10 # number of Diracs of the barycenter
+X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
+b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
-X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b)
+X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
##############################################################################
# Plot data
# ---------
-#%% plot samples
-
pl.figure(1)
for (x_i, b_i) in zip(measures_locations, measures_weights):
color = np.random.randint(low=1, high=10 * N)
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 4c0d170..96bf6de 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -17,8 +17,9 @@ from .import cvx
from .emd_wrap import emd_c, check_result
from ..utils import parmap
from .cvx import barycenter
+from ..utils import dist
-__all__=['emd', 'emd2', 'barycenter', 'cvx']
+__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx']
def emd(a, b, M, numItermax=100000, log=False):
@@ -216,3 +217,92 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
res = parmap(f, [b[:, i] for i in range(nb)], processes)
return res
+
+
+
+def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
+ """
+ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
+
+ The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
+ This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
+ - we do not optimize over the weights
+ - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.
+
+ Parameters
+ ----------
+ measures_locations : list of (k_i,d) np.ndarray
+ The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
+ measures_weights : list of (k_i,) np.ndarray
+ Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure
+
+ X_init : (k,d) np.ndarray
+ Initialization of the support locations (on k atoms) of the barycenter
+ b : (k,) np.ndarray
+ Initialization of the weights of the barycenter (non-negatives, sum to 1)
+ weights : (k,) np.ndarray
+ Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
+
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ X : (k,d) np.ndarray
+ Support locations (on k atoms) of the barycenter
+
+ References
+ ----------
+
+ .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+ """
+
+ iter_count = 0
+
+ N = len(measures_locations)
+ k = X_init.shape[0]
+ d = X_init.shape[1]
+ if b is None:
+ b = np.ones((k,))/k
+ if weights is None:
+ weights = np.ones((N,)) / N
+
+ X = X_init
+
+ displacement_square_norms = []
+ displacement_square_norm = stopThr + 1.
+
+ while ( displacement_square_norm > stopThr and iter_count < numItermax ):
+
+ T_sum = np.zeros((k, d))
+
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
+
+ M_i = dist(X, measure_locations_i)
+ T_i = emd(b, measure_weights_i, M_i)
+ T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
+
+ displacement_square_norm = np.sum(np.square(T_sum-X))
+ if log:
+ displacement_square_norms.append(displacement_square_norm)
+
+ X = T_sum
+
+ if verbose:
+ print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)
+
+ iter_count += 1
+
+ if log:
+ return X, displacement_square_norms
+ else:
+ return X \ No newline at end of file
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index c097f58..8e763be 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -10,7 +10,7 @@ LP solvers for optimal transport using cvxopt
import numpy as np
import scipy as sp
import scipy.sparse as sps
-import ot
+
try:
import cvxopt
@@ -145,83 +145,3 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
return b, sol
else:
return b
-
-
-def free_support_barycenter(measures_locations, measures_weights, X_init, b, weights=None, numItermax=100, stopThr=1e-6, verbose=False):
- """
- Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
-
- The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
- This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
- - we do not optimize over the weights
- - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.
-
- Parameters
- ----------
- data_positions : list of (k_i,d) np.ndarray
- The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
- data_weights : list of (k_i,) np.ndarray
- Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure
-
- X_init : (k,d) np.ndarray
- Initialization of the support locations (on k atoms) of the barycenter
- b : (k,) np.ndarray
- Initialization of the weights of the barycenter (non-negatives, sum to 1)
- weights : (k,) np.ndarray
- Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
-
- numItermax : int, optional
- Max number of iterations
- stopThr : float, optional
- Stop threshol on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- record log if True
-
- Returns
- -------
- X : (k,d) np.ndarray
- Support locations (on k atoms) of the barycenter
-
- References
- ----------
-
- .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
-
- .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
-
- """
-
- iter_count = 0
-
- d = X_init.shape[1]
- k = b.size
- N = len(measures_locations)
-
- if not weights:
- weights = np.ones((N,)) / N
-
- X = X_init
-
- displacement_square_norm = stopThr + 1.
-
- while (displacement_square_norm > stopThr and iter_count < numItermax):
-
- T_sum = np.zeros((k, d))
-
- for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
-
- M_i = ot.dist(X, measure_locations_i)
- T_i = ot.emd(b, measure_weights_i, M_i)
- T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
-
- displacement_square_norm = np.sum(np.square(X - T_sum))
- X = T_sum
-
- if verbose:
- print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)
-
- iter_count += 1
-
- return X
diff --git a/test/test_ot.py b/test/test_ot.py
index 399e549..dafc03f 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -135,6 +135,21 @@ def test_lp_barycenter():
np.testing.assert_allclose(bary.sum(), 1)
+def test_free_support_barycenter():
+
+ measures_locations = [np.array([-1.]).reshape((1,1)), np.array([1.]).reshape((1,1))]
+ measures_weights = [np.array([1.]), np.array([1.])]
+
+ X_init = np.array([-12.]).reshape((1,1))
+
+ # obvious barycenter location between two diracs
+ bar_locations = np.array([0.]).reshape((1,1))
+
+ X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
+
+ np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():