summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorvivienseguy <vivienseguy@gmail.com>2018-07-05 17:35:27 +0900
committervivienseguy <vivienseguy@gmail.com>2018-07-05 17:35:27 +0900
commit3f23fa1a950ffde4a4224a6343a504a0c5b7851b (patch)
treea6bb8430703046150a4f92a0942e0cc90abd34b3
parent6492e95e7a3acd8e35844a0f974dc79d3e7aa349 (diff)
free support barycenter
-rw-r--r--examples/plot_free_support_barycenter.py66
-rw-r--r--ot/lp/cvx.py29
2 files changed, 83 insertions, 12 deletions
diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py
new file mode 100644
index 0000000..a733745
--- /dev/null
+++ b/examples/plot_free_support_barycenter.py
@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+"""
+==========================================================
+2D Wasserstein barycenters between empirical distributions
+==========================================================
+
+Illustration of 2D Wasserstein barycenters between distributions that are weighted
+sums of Diracs.
+
+"""
+
+# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot.plot
+
+
+##############################################################################
+# Generate data
+# -------------
##############################################################################
# Generate data
# -------------
#%% parameters and data generation
N = 4  # number of input empirical measures
d = 2  # dimension of the ambient space
measures_locations = []
measures_weights = []

for i in range(N):

    # np.rand.int does not exist; np.random.randint is the NumPy API.
    n = np.random.randint(low=1, high=20)  # nb samples of measure i

    mu = np.random.normal(0., 1., (d,))

    # A covariance matrix must be symmetric positive semi-definite;
    # a raw normal draw is generally neither, so build one as A A^T.
    A = np.random.rand(d, d)
    cov = np.dot(A, A.T)

    xs = ot.datasets.make_2D_samples_gauss(n, mu, cov)
    b = np.random.uniform(0., 1., n)
    b = b / np.sum(b)  # normalize to a probability vector

    measures_locations.append(xs)
    measures_weights.append(b)

k = 10  # number of atoms in the barycenter support
X_init = np.random.normal(0., 1., (k, d))
b_init = np.ones((k,)) / k  # uniform weights on the barycenter atoms


##############################################################################
# Compute free support barycenter
# -------------------------------

# ot.lp.barycenter solves the *fixed-support* problem and takes (A, M, ...);
# the free-support solver introduced with this example is the right call.
X = ot.lp.free_support_barycenter(measures_locations, measures_weights,
                                  X_init, b_init)


##############################################################################
# Plot data
# ---------

#%% plot samples

pl.figure(1)
for (xs, b) in zip(measures_locations, measures_weights):
    # One random RGB color per input measure (np.rand / b.size(0) were not
    # valid NumPy calls); scale marker sizes so weights ~1/n stay visible.
    color = np.random.rand(1, 3)
    pl.scatter(xs[:, 0], xs[:, 1], s=b * 1000, c=color, label='Data measures')
# Plot the computed barycenter X, not the last input measure xs.
pl.scatter(X[:, 0], X[:, 1], s=b_init * 1000, c='black',
           label='2-Wasserstein barycenter')
pl.legend(loc=0)
pl.title('Data measures and their barycenter')
pl.show()
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index 3913ae5..91a5922 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -27,7 +27,7 @@ def scipy_sparse_to_spmatrix(A):
def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
- """Compute the entropic regularized wasserstein barycenter of distributions A
+ """Compute the Wasserstein barycenter of distributions A
The function solves the following optimization problem [16]:
@@ -149,7 +149,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
-def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda, numItermax=100, stopThr=1e-5, verbose=False, log=False, **kwargs):
+def free_support_barycenter(measures_locations, measures_weights, X_init, b_init, weights=None, numItermax=100, stopThr=1e-6, verbose=False):
"""
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -170,7 +170,7 @@ def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda,
Initialization of the support locations (on k atoms) of the barycenter
b_init : (k,) np.ndarray
Initialization of the weights of the barycenter (non-negatives, sum to 1)
- lambda : (k,) np.ndarray
+ weights : (k,) np.ndarray
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
numItermax : int, optional
@@ -200,25 +200,30 @@ def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda,
d = X_init.shape[1]
k = b_init.size
- N = len(data_positions)
+ N = len(measures_locations)
+
+ if not weights:
+ weights = np.ones((N,))/N
X = X_init
- displacement_square_norm = 1e3
+ displacement_square_norm = stopThr+1.
while ( displacement_square_norm > stopThr and iter_count < numItermax ):
T_sum = np.zeros((k, d))
- for (data_positions_i, data_weights_i) in zip(data_positions, data_weights):
- M_i = ot.dist(X, data_positions_i)
- T_i = ot.emd(b_init, data_weights_i, M_i)
- T_sum += np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, data_positions_i)
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
+
+ M_i = ot.dist(X, measure_locations_i)
+ T_i = ot.emd(b_init, measure_weights_i, M_i)
+ T_sum += np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, measure_locations_i)
- X_previous = X
- X = T_sum / N
+ displacement_square_norm = np.sum(np.square(X-T_sum))
+ X = T_sum
- displacement_square_norm = np.sum(np.square(X-X_previous))
+ if verbose:
+ print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)
iter_count += 1