# -*- coding: utf-8 -*-
"""
====================================================
2D free-support Wasserstein barycenter of empirical distributions
====================================================

Illustration of the 2D Wasserstein barycenter between distributions that are
weighted sums of Diracs.

"""

# Author: Vivien Seguy
#
# License: MIT License

import numpy as np
import matplotlib.pylab as pl
import ot.plot


##############################################################################
# Generate data
# -------------
N = 4  # number of input measures
d = 2  # dimension of the ambient space
measures_locations = []
measures_weights = []

for i in range(N):

    # np.rand.int does not exist: np.random.randint draws the sample count
    n_i = np.random.randint(low=1, high=20)  # nb samples

    mu_i = np.random.normal(0., 1., (d,))  # mean of the i-th Gaussian

    # A covariance matrix must be positive semi-definite; a raw (d, d)
    # normal draw is not, so build one as A A^T
    A_i = np.random.normal(0., 1., (d, d))
    cov_i = np.dot(A_i, A_i.T)

    xs = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i)
    b_i = np.random.uniform(0., 1., (n_i,))
    b_i = b_i / np.sum(b_i)  # normalize to a probability vector

    measures_locations.append(xs)
    measures_weights.append(b_i)

k = 10  # number of atoms of the barycenter
X_init = np.random.normal(0., 1., (k, d))  # initial support locations
b_init = np.ones((k,)) / k  # fixed, uniform barycenter weights


##############################################################################
# Compute free support barycenter
# -------------------------------

# NOTE: ot.lp.barycenter solves the *fixed*-support problem and takes
# (A, M, ...); the free-support solver introduced alongside this example
# is the correct entry point for optimized atom locations.
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b_init)


##############################################################################
# Plot data
# ---------

pl.figure(1)
for (xs_i, b_i) in zip(measures_locations, measures_weights):
    # np.rand(...) and b.size(0) are not valid numpy: draw one random RGB
    # color per measure and scale marker sizes by the atom weights
    color_i = np.random.rand(1, 3)
    pl.scatter(xs_i[:, 0], xs_i[:, 1], s=b_i * 1000, c=color_i, label='input measure')
# plot the computed barycenter X (the original re-plotted the last measure)
pl.scatter(X[:, 0], X[:, 1], s=b_init * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
pl.title('Data measures and their barycenter')
pl.legend(loc=0)
pl.show()
def free_support_barycenter(measures_locations, measures_weights, X_init, b_init, weights=None, numItermax=100, stopThr=1e-6, verbose=False):
    """Solve the free-support Wasserstein barycenter problem.

    Only the support locations (atoms) of the barycenter are optimized; its
    weights ``b_init`` stay fixed, i.e. this computes the weighted Frechet
    mean for the 2-Wasserstein distance by fixed-point iteration on the
    barycentric projections of the optimal transport plans.

    Parameters
    ----------
    measures_locations : list of (n_i, d) np.ndarray
        Locations of the atoms of each of the N input measures.
    measures_weights : list of (n_i,) np.ndarray
        Weights of the atoms of each input measure (non-negative, sum to 1).
    X_init : (k, d) np.ndarray
        Initialization of the support locations (on k atoms) of the barycenter.
    b_init : (k,) np.ndarray
        Fixed weights of the barycenter (non-negative, sum to 1).
    weights : (N,) np.ndarray, optional
        Weight of each input measure in the barycenter objective
        (non-negative, sum to 1). Defaults to uniform weights.
    numItermax : int, optional
        Maximum number of fixed-point iterations.
    stopThr : float, optional
        Stop when the squared displacement of the support between two
        iterations falls below this threshold.
    verbose : bool, optional
        Print the displacement at each iteration.

    Returns
    -------
    X : (k, d) np.ndarray
        Support locations of the barycenter.
    """
    iter_count = 0

    d = X_init.shape[1]
    k = b_init.size
    N = len(measures_locations)

    # `weights` may be an ndarray whose truth value is ambiguous:
    # compare against None explicitly instead of `if not weights`.
    if weights is None:
        weights = np.ones((N,)) / N

    X = X_init

    # ensure the loop body runs at least once
    displacement_square_norm = stopThr + 1.

    while (displacement_square_norm > stopThr and iter_count < numItermax):

        T_sum = np.zeros((k, d))

        for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):

            M_i = ot.dist(X, measure_locations_i)
            T_i = ot.emd(b_init, measure_weights_i, M_i)
            # weight_i was previously unpacked but never used, so the
            # contributions were summed unweighted: each measure's
            # barycentric projection must be scaled by its weight.
            T_sum = T_sum + weight_i * np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, measure_locations_i)

        displacement_square_norm = np.sum(np.square(X - T_sum))
        X = T_sum

        if verbose:
            # %-interpolate the values: passing them after a comma made
            # print emit the raw format string followed by the tuple.
            print('iteration %d, displacement_square_norm=%f' % (iter_count, displacement_square_norm))

        iter_count += 1

    return X