diff options
author | vivienseguy <vivienseguy@gmail.com> | 2018-07-05 15:42:21 +0900 |
---|---|---|
committer | vivienseguy <vivienseguy@gmail.com> | 2018-07-05 15:42:21 +0900 |
commit | 6492e95e7a3acd8e35844a0f974dc79d3e7aa349 (patch) | |
tree | 7660aad8dc985f5544cb7bc35500e1acac2bf09f /ot/lp/cvx.py | |
parent | 39cbcd302c1d1e275c628d3bac073ec1f89596c6 (diff) |
free support barycenter
Diffstat (limited to 'ot/lp/cvx.py')
-rw-r--r-- | ot/lp/cvx.py | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index c8c75bc..3913ae5 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -10,6 +10,7 @@ LP solvers for optimal transport using cvxopt import numpy as np import scipy as sp import scipy.sparse as sps +import ot try: import cvxopt @@ -144,3 +145,82 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po return b, sol else: return b + + + + +def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda, numItermax=100, stopThr=1e-5, verbose=False, log=False, **kwargs): + + """ + Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) + + The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms. + This problem is considered in [1] (Algorithm 2). There are two differences with the following codes: + - we do not optimize over the weights + - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. + + Parameters + ---------- + data_positions : list of (k_i,d) np.ndarray + The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list) + data_weights : list of (k_i,) np.ndarray + Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure + + X_init : (k,d) np.ndarray + Initialization of the support locations (on k atoms) of the barycenter + b_init : (k,) np.ndarray + Initialization of the weights of the barycenter (non-negatives, sum to 1) + lambda : (k,) np.ndarray + Initialization of the coefficients of the barycenter (non-negatives, sum to 1) + + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + X : (k,d) np.ndarray + Support locations (on k atoms) of the barycenter + + References + ---------- + + .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + + """ + + iter_count = 0 + + d = X_init.shape[1] + k = b_init.size + N = len(data_positions) + + X = X_init + + displacement_square_norm = 1e3 + + while ( displacement_square_norm > stopThr and iter_count < numItermax ): + + T_sum = np.zeros((k, d)) + + for (data_positions_i, data_weights_i) in zip(data_positions, data_weights): + M_i = ot.dist(X, data_positions_i) + T_i = ot.emd(b_init, data_weights_i, M_i) + T_sum += np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, data_positions_i) + + X_previous = X + X = T_sum / N + + displacement_square_norm = np.sum(np.square(X-X_previous)) + + iter_count += 1 + + return X + |