From 6492e95e7a3acd8e35844a0f974dc79d3e7aa349 Mon Sep 17 00:00:00 2001 From: vivienseguy Date: Thu, 5 Jul 2018 15:42:21 +0900 Subject: free support barycenter --- ot/lp/cvx.py | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) (limited to 'ot') diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index c8c75bc..3913ae5 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -10,6 +10,7 @@ LP solvers for optimal transport using cvxopt import numpy as np import scipy as sp import scipy.sparse as sps +import ot try: import cvxopt @@ -144,3 +145,82 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po return b, sol else: return b + + + + +def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda, numItermax=100, stopThr=1e-5, verbose=False, log=False, **kwargs): + + """ + Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) + + The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms. + This problem is considered in [1] (Algorithm 2). There are two differences with the following codes: + - we do not optimize over the weights + - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. + + Parameters + ---------- + data_positions : list of (k_i,d) np.ndarray + The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list) + data_weights : list of (k_i,) np.ndarray + Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure + + X_init : (k,d) np.ndarray + Initialization of the support locations (on k atoms) of the barycenter + b_init : (k,) np.ndarray + Initialization of the weights of the barycenter (non-negatives, sum to 1) + lambda : (k,) np.ndarray + Initialization of the coefficients of the barycenter (non-negatives, sum to 1) + + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + X : (k,d) np.ndarray + Support locations (on k atoms) of the barycenter + + References + ---------- + + .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + + """ + + iter_count = 0 + + d = X_init.shape[1] + k = b_init.size + N = len(data_positions) + + X = X_init + + displacement_square_norm = 1e3 + + while ( displacement_square_norm > stopThr and iter_count < numItermax ): + + T_sum = np.zeros((k, d)) + + for (data_positions_i, data_weights_i) in zip(data_positions, data_weights): + M_i = ot.dist(X, data_positions_i) + T_i = ot.emd(b_init, data_weights_i, M_i) + T_sum += np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, data_positions_i) + + X_previous = X + X = T_sum / N + + displacement_square_norm = np.sum(np.square(X-X_previous)) + + iter_count += 1 + + return X + -- cgit v1.2.3 From 3f23fa1a950ffde4a4224a6343a504a0c5b7851b Mon Sep 17 00:00:00 2001 From: vivienseguy Date: Thu, 5 Jul 2018 17:35:27 +0900 Subject: free support barycenter --- examples/plot_free_support_barycenter.py | 66 ++++++++++++++++++++++++++++++++ ot/lp/cvx.py | 29 ++++++++------ 2 files changed, 83 insertions(+), 12 deletions(-) create mode 100644 examples/plot_free_support_barycenter.py (limited to 'ot') diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py new file mode 100644 index 0000000..a733745 --- /dev/null +++ b/examples/plot_free_support_barycenter.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +2D Wasserstein barycenters between empirical distributions +==================================================== + +Illustration of 2D Wasserstein barycenters between discributions that are weighted +sum of diracs. + +""" + +# Author: Vivien Seguy +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot.plot + + +############################################################################## +# Generate data +# ------------- +#%% parameters and data generation +N = 4 +d = 2 +measures_locations = [] +measures_weights = [] + +for i in range(N): + + n = np.rand.int(low=1, high=20) # nb samples + + mu = np.random.normal(0., 1., (d,)) + cov = np.random.normal(0., 1., (d,d)) + + xs = ot.datasets.make_2D_samples_gauss(n, mu, cov) + b = np.random.uniform(0., 1., n) + b = b/np.sum(b) + + measures_locations.append(xs) + measures_weights.append(b) + +k = 10 +X_init = np.random.normal(0., 1., (k,d)) +b_init = np.ones((k,)) / k + + +############################################################################## +# Compute free support barycenter +# ------------- +X = ot.lp.barycenter(measures_locations, measures_weights, X_init, b_init) + + +############################################################################## +# Plot data +# --------- + +#%% plot samples + +pl.figure(1) +for (xs, b) in zip(measures_locations, measures_weights): + pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.rand(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures') +pl.scatter(xs[:, 0], xs[:, 1], s=b, c='black' , label='2-Wasserstein barycenter') +pl.legend(loc=0) +pl.title('Data measures and their barycenter') diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 3913ae5..91a5922 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -27,7 +27,7 @@ def scipy_sparse_to_spmatrix(A): def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'): - """Compute the entropic regularized wasserstein barycenter of distributions A + """Compute the Wasserstein barycenter of distributions A The function solves the following optimization problem [16]: @@ -149,7 +149,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po -def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda, numItermax=100, stopThr=1e-5, verbose=False, log=False, **kwargs): +def free_support_barycenter(measures_locations, measures_weights, X_init, b_init, weights=None, numItermax=100, stopThr=1e-6, verbose=False): """ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) @@ -170,7 +170,7 @@ def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda, Initialization of the support locations (on k atoms) of the barycenter b_init : (k,) np.ndarray Initialization of the weights of the barycenter (non-negatives, sum to 1) - lambda : (k,) np.ndarray + weights : (k,) np.ndarray Initialization of the coefficients of the barycenter (non-negatives, sum to 1) numItermax : int, optional @@ -200,25 +200,30 @@ def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda, d = X_init.shape[1] k = b_init.size - N = len(data_positions) + N = len(measures_locations) + + if not weights: + weights = np.ones((N,))/N X = X_init - displacement_square_norm = 1e3 + displacement_square_norm = stopThr+1. while ( displacement_square_norm > stopThr and iter_count < numItermax ): T_sum = np.zeros((k, d)) - for (data_positions_i, data_weights_i) in zip(data_positions, data_weights): - M_i = ot.dist(X, data_positions_i) - T_i = ot.emd(b_init, data_weights_i, M_i) - T_sum += np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, data_positions_i) + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): + + M_i = ot.dist(X, measure_locations_i) + T_i = ot.emd(b_init, measure_weights_i, M_i) + T_sum += np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, measure_locations_i) - X_previous = X - X = T_sum / N + displacement_square_norm = np.sum(np.square(X-T_sum)) + X = T_sum - displacement_square_norm = np.sum(np.square(X-X_previous)) + if verbose: + print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm) iter_count += 1 -- cgit v1.2.3 From 98ce4ccd3536d95c609ee1c5b737ced85d68f786 Mon Sep 17 00:00:00 2001 From: vivienseguy Date: Thu, 5 Jul 2018 18:26:55 +0900 Subject: free support barycenter --- examples/plot_free_support_barycenter.py | 9 ++++----- ot/lp/cvx.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) (limited to 'ot') diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py index a733745..274cf76 100644 --- a/examples/plot_free_support_barycenter.py +++ b/examples/plot_free_support_barycenter.py @@ -17,7 +17,6 @@ import numpy as np import matplotlib.pylab as pl import ot.plot - ############################################################################## # Generate data # ------------- @@ -29,10 +28,10 @@ measures_weights = [] for i in range(N): - n = np.rand.int(low=1, high=20) # nb samples + n = np.random.randint(low=1, high=20) # nb samples mu = np.random.normal(0., 1., (d,)) - cov = np.random.normal(0., 1., (d,d)) + cov = np.random.uniform(0., 1., (d,d)) xs = ot.datasets.make_2D_samples_gauss(n, mu, cov) b = np.random.uniform(0., 1., n) @@ -49,7 +48,7 @@ b_init = np.ones((k,)) / k ############################################################################## # Compute free support barycenter # ------------- -X = ot.lp.barycenter(measures_locations, measures_weights, X_init, b_init) +X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b_init) ############################################################################## @@ -60,7 +59,7 @@ X = ot.lp.barycenter(measures_locations, measures_weights, X_init, b_init) pl.figure(1) for (xs, b) in zip(measures_locations, measures_weights): - pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.rand(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures') + pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.random.uniform(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures') pl.scatter(xs[:, 0], xs[:, 1], s=b, c='black' , label='2-Wasserstein barycenter') pl.legend(loc=0) pl.title('Data measures and their barycenter') diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 91a5922..b74960f 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -217,7 +217,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b_init M_i = ot.dist(X, measure_locations_i) T_i = ot.emd(b_init, measure_weights_i, M_i) - T_sum += np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, measure_locations_i) + T_sum = T_sum + weight_i*np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, measure_locations_i) displacement_square_norm = np.sum(np.square(X-T_sum)) X = T_sum -- cgit v1.2.3 From 2c7b98009f33e278a2e7e95a035c6a6231bec44e Mon Sep 17 00:00:00 2001 From: Vivien Seguy Date: Fri, 6 Jul 2018 01:58:47 +0900 Subject: add free support barycenter algorithm --- README.md | 3 +++ examples/plot_free_support_barycenter.py | 35 ++++++++++++++++---------------- ot/lp/cvx.py | 22 ++++++++------------ 3 files changed, 30 insertions(+), 30 deletions(-) (limited to 'ot') diff --git a/README.md b/README.md index 677a23b..dded582 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ It provides the following solvers: * Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat). * Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17]. * Non regularized Wasserstein barycenters [16] with LP solver (only small scale). +* Non regularized free support Wasserstein barycenters [20]. * Bregman projections for Wasserstein barycenter [3] and unmixing [4]. * Optimal transport for domain adaptation with group lasso regularization [5] * Conditional gradient [6] and Generalized conditional gradient for regularized OT [7]. @@ -225,3 +226,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](arXiv preprint arxiv:1605.08527). Advances in Neural Information Processing Systems (2016). [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018) + +[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning \ No newline at end of file diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py index 61671cf..42e22fc 100644 --- a/examples/plot_free_support_barycenter.py +++ b/examples/plot_free_support_barycenter.py @@ -21,7 +21,7 @@ import ot.plot # Generate data # ------------- #%% parameters and data generation -N = 6 +N = 3 d = 2 measures_locations = [] measures_weights = [] @@ -33,24 +33,25 @@ for i in range(N): mu = np.random.normal(0., 4., (d,)) A = np.random.rand(d, d) - cov = np.dot(A,A.transpose()) + cov = np.dot(A, A.transpose()) - xs = ot.datasets.make_2D_samples_gauss(n, mu, cov) - b = np.random.uniform(0., 1., (n,)) - b = b/np.sum(b) + x_i = ot.datasets.make_2D_samples_gauss(n, mu, cov) + b_i = np.random.uniform(0., 1., (n,)) + b_i = b_i / np.sum(b_i) - measures_locations.append(xs) - measures_weights.append(b) - -k = 10 -X_init = np.random.normal(0., 1., (k,d)) -b_init = np.ones((k,)) / k + measures_locations.append(x_i) + measures_weights.append(b_i) ############################################################################## # Compute free support barycenter # ------------- -X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b_init) + +k = 10 +X_init = np.random.normal(0., 1., (k, d)) +b = np.ones((k,)) / k + +X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b) ############################################################################## @@ -60,10 +61,10 @@ X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_in #%% plot samples pl.figure(1) -for (xs, b) in zip(measures_locations, measures_weights): - color = np.random.randint(low=1, high=10*N) - pl.scatter(xs[:, 0], xs[:, 1], s=b*1000, label='input measure') -pl.scatter(X[:, 0], X[:, 1], s=b_init*1000, c='black' , marker='^', label='2-Wasserstein barycenter') +for (x_i, b_i) in zip(measures_locations, measures_weights): + color = np.random.randint(low=1, high=10 * N) + pl.scatter(x_i[:, 0], x_i[:, 1], s=b * 1000, label='input measure') +pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter') pl.title('Data measures and their barycenter') pl.legend(loc=0) -pl.show() \ No newline at end of file +pl.show() diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index b74960f..c097f58 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -147,10 +147,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po return b - - -def free_support_barycenter(measures_locations, measures_weights, X_init, b_init, weights=None, numItermax=100, stopThr=1e-6, verbose=False): - +def free_support_barycenter(measures_locations, measures_weights, X_init, b, weights=None, numItermax=100, stopThr=1e-6, verbose=False): """ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) @@ -168,7 +165,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b_init X_init : (k,d) np.ndarray Initialization of the support locations (on k atoms) of the barycenter - b_init : (k,) np.ndarray + b : (k,) np.ndarray Initialization of the weights of the barycenter (non-negatives, sum to 1) weights : (k,) np.ndarray Initialization of the coefficients of the barycenter (non-negatives, sum to 1) @@ -199,27 +196,27 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b_init iter_count = 0 d = X_init.shape[1] - k = b_init.size + k = b.size N = len(measures_locations) if not weights: - weights = np.ones((N,))/N + weights = np.ones((N,)) / N X = X_init - displacement_square_norm = stopThr+1. + displacement_square_norm = stopThr + 1. - while ( displacement_square_norm > stopThr and iter_count < numItermax ): + while (displacement_square_norm > stopThr and iter_count < numItermax): T_sum = np.zeros((k, d)) for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): M_i = ot.dist(X, measure_locations_i) - T_i = ot.emd(b_init, measure_weights_i, M_i) - T_sum = T_sum + weight_i*np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, measure_locations_i) + T_i = ot.emd(b, measure_weights_i, M_i) + T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) - displacement_square_norm = np.sum(np.square(X-T_sum)) + displacement_square_norm = np.sum(np.square(X - T_sum)) X = T_sum if verbose: @@ -228,4 +225,3 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b_init iter_count += 1 return X - -- cgit v1.2.3 From 46712790c1276f1ecb3496362a8117e153782ede Mon Sep 17 00:00:00 2001 From: Vivien Seguy Date: Mon, 9 Jul 2018 16:49:21 +0900 Subject: add test free support barycenter algorithm + cleaning --- examples/plot_free_support_barycenter.py | 29 +++++----- ot/lp/__init__.py | 92 +++++++++++++++++++++++++++++++- ot/lp/cvx.py | 82 +--------------------------- test/test_ot.py | 15 ++++++ 4 files changed, 121 insertions(+), 97 deletions(-) (limited to 'ot') diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py index 5b08507..b2e62c8 100644 --- a/examples/plot_free_support_barycenter.py +++ b/examples/plot_free_support_barycenter.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ ==================================================== -2D Wasserstein barycenters of distributions +2D free support Wasserstein barycenters of distributions ==================================================== Illustration of 2D Wasserstein barycenters if discributions that are weighted @@ -15,7 +15,8 @@ sum of diracs. import numpy as np import matplotlib.pylab as pl -import ot.plot +import ot + ############################################################################## # Generate data @@ -28,16 +29,16 @@ measures_weights = [] for i in range(N): - n = np.random.randint(low=1, high=20) # nb samples + n_i = np.random.randint(low=1, high=20) # nb samples - mu = np.random.normal(0., 4., (d,)) + mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean - A = np.random.rand(d, d) - cov = np.dot(A, A.transpose()) + A_i = np.random.rand(d, d) + cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix - x_i = ot.datasets.make_2D_samples_gauss(n, mu, cov) - b_i = np.random.uniform(0., 1., (n,)) - b_i = b_i / np.sum(b_i) + x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations + b_i = np.random.uniform(0., 1., (n_i,)) + b_i = b_i / np.sum(b_i) # Dirac weights measures_locations.append(x_i) measures_weights.append(b_i) @@ -47,19 +48,17 @@ for i in range(N): # Compute free support barycenter # ------------- -k = 10 -X_init = np.random.normal(0., 1., (k, d)) -b = np.ones((k,)) / k +k = 10 # number of Diracs of the barycenter +X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations +b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) -X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b) +X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b) ############################################################################## # Plot data # --------- -#%% plot samples - pl.figure(1) for (x_i, b_i) in zip(measures_locations, measures_weights): color = np.random.randint(low=1, high=10 * N) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 4c0d170..96bf6de 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -17,8 +17,9 @@ from .import cvx from .emd_wrap import emd_c, check_result from ..utils import parmap from .cvx import barycenter +from ..utils import dist -__all__=['emd', 'emd2', 'barycenter', 'cvx'] +__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx'] def emd(a, b, M, numItermax=100000, log=False): @@ -216,3 +217,92 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), res = parmap(f, [b[:, i] for i in range(nb)], processes) return res + + + +def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None): + """ + Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) + + The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms. + This problem is considered in [1] (Algorithm 2). There are two differences with the following codes: + - we do not optimize over the weights + - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. + + Parameters + ---------- + measures_locations : list of (k_i,d) np.ndarray + The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list) + measures_weights : list of (k_i,) np.ndarray + Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure + + X_init : (k,d) np.ndarray + Initialization of the support locations (on k atoms) of the barycenter + b : (k,) np.ndarray + Initialization of the weights of the barycenter (non-negatives, sum to 1) + weights : (k,) np.ndarray + Initialization of the coefficients of the barycenter (non-negatives, sum to 1) + + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + X : (k,d) np.ndarray + Support locations (on k atoms) of the barycenter + + References + ---------- + + .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + + """ + + iter_count = 0 + + N = len(measures_locations) + k = X_init.shape[0] + d = X_init.shape[1] + if b is None: + b = np.ones((k,))/k + if weights is None: + weights = np.ones((N,)) / N + + X = X_init + + displacement_square_norms = [] + displacement_square_norm = stopThr + 1. + + while ( displacement_square_norm > stopThr and iter_count < numItermax ): + + T_sum = np.zeros((k, d)) + + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): + + M_i = dist(X, measure_locations_i) + T_i = emd(b, measure_weights_i, M_i) + T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) + + displacement_square_norm = np.sum(np.square(T_sum-X)) + if log: + displacement_square_norms.append(displacement_square_norm) + + X = T_sum + + if verbose: + print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm) + + iter_count += 1 + + if log: + return X, displacement_square_norms + else: + return X \ No newline at end of file diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index c097f58..8e763be 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -10,7 +10,7 @@ LP solvers for optimal transport using cvxopt import numpy as np import scipy as sp import scipy.sparse as sps -import ot + try: import cvxopt @@ -145,83 +145,3 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po return b, sol else: return b - - -def free_support_barycenter(measures_locations, measures_weights, X_init, b, weights=None, numItermax=100, stopThr=1e-6, verbose=False): - """ - Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) - - The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms. - This problem is considered in [1] (Algorithm 2). There are two differences with the following codes: - - we do not optimize over the weights - - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. - - Parameters - ---------- - data_positions : list of (k_i,d) np.ndarray - The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list) - data_weights : list of (k_i,) np.ndarray - Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure - - X_init : (k,d) np.ndarray - Initialization of the support locations (on k atoms) of the barycenter - b : (k,) np.ndarray - Initialization of the weights of the barycenter (non-negatives, sum to 1) - weights : (k,) np.ndarray - Initialization of the coefficients of the barycenter (non-negatives, sum to 1) - - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshol on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - Returns - ------- - X : (k,d) np.ndarray - Support locations (on k atoms) of the barycenter - - References - ---------- - - .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - - .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. - - """ - - iter_count = 0 - - d = X_init.shape[1] - k = b.size - N = len(measures_locations) - - if not weights: - weights = np.ones((N,)) / N - - X = X_init - - displacement_square_norm = stopThr + 1. - - while (displacement_square_norm > stopThr and iter_count < numItermax): - - T_sum = np.zeros((k, d)) - - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): - - M_i = ot.dist(X, measure_locations_i) - T_i = ot.emd(b, measure_weights_i, M_i) - T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) - - displacement_square_norm = np.sum(np.square(X - T_sum)) - X = T_sum - - if verbose: - print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm) - - iter_count += 1 - - return X diff --git a/test/test_ot.py b/test/test_ot.py index 399e549..dafc03f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -135,6 +135,21 @@ def test_lp_barycenter(): np.testing.assert_allclose(bary.sum(), 1) +def test_free_support_barycenter(): + + measures_locations = [np.array([-1.]).reshape((1,1)), np.array([1.]).reshape((1,1))] + measures_weights = [np.array([1.]), np.array([1.])] + + X_init = np.array([-12.]).reshape((1,1)) + + # obvious barycenter location between two diracs + bar_locations = np.array([0.]).reshape((1,1)) + + X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init) + + np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) + + @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): -- cgit v1.2.3 From af57d90c83c860db5a2160e79aea407ae379f7b0 Mon Sep 17 00:00:00 2001 From: Vivien Seguy Date: Mon, 9 Jul 2018 17:40:41 +0900 Subject: return log dict in free support barycenter function --- ot/lp/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'ot') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 96bf6de..02cbd8c 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -278,7 +278,9 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None X = X_init + log_dict = {} displacement_square_norms = [] + displacement_square_norm = stopThr + 1. while ( displacement_square_norm > stopThr and iter_count < numItermax ): @@ -303,6 +305,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None iter_count += 1 if log: - return X, displacement_square_norms + log_dict['displacement_square_norms'] = displacement_square_norms + return X, log_dict else: return X \ No newline at end of file -- cgit v1.2.3