summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md3
-rw-r--r--examples/plot_free_support_barycenter.py69
-rw-r--r--ot/lp/__init__.py95
-rw-r--r--ot/lp/cvx.py3
-rw-r--r--test/test_ot.py15
5 files changed, 183 insertions, 2 deletions
diff --git a/README.md b/README.md
index 677a23b..dded582 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@ It provides the following solvers:
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat).
* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
+* Non regularized free support Wasserstein barycenters [20].
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
* Optimal transport for domain adaptation with group lasso regularization [5]
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
@@ -225,3 +226,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](arXiv preprint arxiv:1605.08527). Advances in Neural Information Processing Systems (2016).
[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018)
+
+[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning \ No newline at end of file
diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py
new file mode 100644
index 0000000..b6efc59
--- /dev/null
+++ b/examples/plot_free_support_barycenter.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+2D free support Wasserstein barycenters of distributions
+====================================================
+
+Illustration of 2D Wasserstein barycenters if discributions that are weighted
+sum of diracs.
+
+"""
+
+# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# Generate data
+# -------------
+#%% parameters and data generation
+N = 3
+d = 2
+measures_locations = []
+measures_weights = []
+
+for i in range(N):
+
+ n_i = np.random.randint(low=1, high=20) # nb samples
+
+ mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean
+
+ A_i = np.random.rand(d, d)
+ cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix
+
+ x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations
+ b_i = np.random.uniform(0., 1., (n_i,))
+ b_i = b_i / np.sum(b_i) # Dirac weights
+
+ measures_locations.append(x_i)
+ measures_weights.append(b_i)
+
+
+##############################################################################
+# Compute free support barycenter
+# -------------
+
+k = 10 # number of Diracs of the barycenter
+X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
+b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
+
+X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
+
+
+##############################################################################
+# Plot data
+# ---------
+
+pl.figure(1)
+for (x_i, b_i) in zip(measures_locations, measures_weights):
+ color = np.random.randint(low=1, high=10 * N)
+ pl.scatter(x_i[:, 0], x_i[:, 1], s=b * 1000, label='input measure')
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
+pl.title('Data measures and their barycenter')
+pl.legend(loc=0)
+pl.show()
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 4c0d170..02cbd8c 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -17,8 +17,9 @@ from .import cvx
from .emd_wrap import emd_c, check_result
from ..utils import parmap
from .cvx import barycenter
+from ..utils import dist
-__all__=['emd', 'emd2', 'barycenter', 'cvx']
+__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx']
def emd(a, b, M, numItermax=100000, log=False):
@@ -216,3 +217,95 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
res = parmap(f, [b[:, i] for i in range(nb)], processes)
return res
+
+
+
+def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
+ """
+ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
+
+ The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
+ This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
+ - we do not optimize over the weights
+ - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.
+
+ Parameters
+ ----------
+ measures_locations : list of (k_i,d) np.ndarray
+ The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
+ measures_weights : list of (k_i,) np.ndarray
+ Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure
+
+ X_init : (k,d) np.ndarray
+ Initialization of the support locations (on k atoms) of the barycenter
+ b : (k,) np.ndarray
+ Initialization of the weights of the barycenter (non-negatives, sum to 1)
+ weights : (k,) np.ndarray
+ Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
+
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ X : (k,d) np.ndarray
+ Support locations (on k atoms) of the barycenter
+
+ References
+ ----------
+
+ .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+ """
+
+ iter_count = 0
+
+ N = len(measures_locations)
+ k = X_init.shape[0]
+ d = X_init.shape[1]
+ if b is None:
+ b = np.ones((k,))/k
+ if weights is None:
+ weights = np.ones((N,)) / N
+
+ X = X_init
+
+ log_dict = {}
+ displacement_square_norms = []
+
+ displacement_square_norm = stopThr + 1.
+
+ while ( displacement_square_norm > stopThr and iter_count < numItermax ):
+
+ T_sum = np.zeros((k, d))
+
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
+
+ M_i = dist(X, measure_locations_i)
+ T_i = emd(b, measure_weights_i, M_i)
+ T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
+
+ displacement_square_norm = np.sum(np.square(T_sum-X))
+ if log:
+ displacement_square_norms.append(displacement_square_norm)
+
+ X = T_sum
+
+ if verbose:
+ print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)
+
+ iter_count += 1
+
+ if log:
+ log_dict['displacement_square_norms'] = displacement_square_norms
+ return X, log_dict
+ else:
+ return X \ No newline at end of file
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index c8c75bc..8e763be 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -11,6 +11,7 @@ import numpy as np
import scipy as sp
import scipy.sparse as sps
+
try:
import cvxopt
from cvxopt import solvers, matrix, spmatrix
@@ -26,7 +27,7 @@ def scipy_sparse_to_spmatrix(A):
def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
- """Compute the entropic regularized wasserstein barycenter of distributions A
+ """Compute the Wasserstein barycenter of distributions A
The function solves the following optimization problem [16]:
diff --git a/test/test_ot.py b/test/test_ot.py
index 399e549..45e777a 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -135,6 +135,21 @@ def test_lp_barycenter():
np.testing.assert_allclose(bary.sum(), 1)
+def test_free_support_barycenter():
+
+ measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ measures_weights = [np.array([1.]), np.array([1.])]
+
+ X_init = np.array([-12.]).reshape((1, 1))
+
+ # obvious barycenter location between two diracs
+ bar_locations = np.array([0.]).reshape((1, 1))
+
+ X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
+
+ np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():