summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorvivienseguy <vivienseguy@gmail.com>2018-07-05 15:42:21 +0900
committervivienseguy <vivienseguy@gmail.com>2018-07-05 15:42:21 +0900
commit6492e95e7a3acd8e35844a0f974dc79d3e7aa349 (patch)
tree7660aad8dc985f5544cb7bc35500e1acac2bf09f
parent39cbcd302c1d1e275c628d3bac073ec1f89596c6 (diff)
free support barycenter
-rw-r--r--ot/lp/cvx.py80
1 files changed, 80 insertions, 0 deletions
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index c8c75bc..3913ae5 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -10,6 +10,7 @@ LP solvers for optimal transport using cvxopt
import numpy as np
import scipy as sp
import scipy.sparse as sps
+import ot
try:
import cvxopt
@@ -144,3 +145,82 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
return b, sol
else:
return b
+
+
+
+
+def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda, numItermax=100, stopThr=1e-5, verbose=False, log=False, **kwargs):
+
+ """
+ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
+
+ The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
+ This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
+ - we do not optimize over the weights
+ - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.
+
+ Parameters
+ ----------
+ data_positions : list of (k_i,d) np.ndarray
+ The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
+ data_weights : list of (k_i,) np.ndarray
+ Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure
+
+ X_init : (k,d) np.ndarray
+ Initialization of the support locations (on k atoms) of the barycenter
+ b_init : (k,) np.ndarray
+ Initialization of the weights of the barycenter (non-negatives, sum to 1)
+ lambda : (k,) np.ndarray
+ Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
+
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ X : (k,d) np.ndarray
+ Support locations (on k atoms) of the barycenter
+
+ References
+ ----------
+
+ .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+ """
+
+ iter_count = 0
+
+ d = X_init.shape[1]
+ k = b_init.size
+ N = len(data_positions)
+
+ X = X_init
+
+ displacement_square_norm = 1e3
+
+ while ( displacement_square_norm > stopThr and iter_count < numItermax ):
+
+ T_sum = np.zeros((k, d))
+
+ for (data_positions_i, data_weights_i) in zip(data_positions, data_weights):
+ M_i = ot.dist(X, data_positions_i)
+ T_i = ot.emd(b_init, data_weights_i, M_i)
+ T_sum += np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, data_positions_i)
+
+ X_previous = X
+ X = T_sum / N
+
+ displacement_square_norm = np.sum(np.square(X-X_previous))
+
+ iter_count += 1
+
+ return X
+