diff options
author | Vivien Seguy <vivienseguy@h40.57.229.10.1016746.vlan.kuins.net> | 2018-07-09 16:49:21 +0900 |
---|---|---|
committer | Vivien Seguy <vivienseguy@h40.57.229.10.1016746.vlan.kuins.net> | 2018-07-09 16:49:21 +0900 |
commit | 46712790c1276f1ecb3496362a8117e153782ede (patch) | |
tree | 15c66c04b24396078b60b8f0a89b7a9cc6476b50 /ot/lp | |
parent | 67ddb92e28d6bb44eb65686419e255c2ce3311eb (diff) |
add test free support barycenter algorithm + cleaning
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/__init__.py | 92 | ||||
-rw-r--r-- | ot/lp/cvx.py | 82 |
2 files changed, 92 insertions, 82 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 4c0d170..96bf6de 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -17,8 +17,9 @@ from .import cvx from .emd_wrap import emd_c, check_result from ..utils import parmap from .cvx import barycenter +from ..utils import dist -__all__=['emd', 'emd2', 'barycenter', 'cvx'] +__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx'] def emd(a, b, M, numItermax=100000, log=False): @@ -216,3 +217,92 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), res = parmap(f, [b[:, i] for i in range(nb)], processes) return res + + + +def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None): + """ + Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) + + The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms. + This problem is considered in [1] (Algorithm 2). There are two differences with the following codes: + - we do not optimize over the weights + - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. + + Parameters + ---------- + measures_locations : list of (k_i,d) np.ndarray + The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list) + measures_weights : list of (k_i,) np.ndarray + Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure + + X_init : (k,d) np.ndarray + Initialization of the support locations (on k atoms) of the barycenter + b : (k,) np.ndarray + Initialization of the weights of the barycenter (non-negatives, sum to 1) + weights : (k,) np.ndarray + Initialization of the coefficients of the barycenter (non-negatives, sum to 1) + + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + X : (k,d) np.ndarray + Support locations (on k atoms) of the barycenter + + References + ---------- + + .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + + """ + + iter_count = 0 + + N = len(measures_locations) + k = X_init.shape[0] + d = X_init.shape[1] + if b is None: + b = np.ones((k,))/k + if weights is None: + weights = np.ones((N,)) / N + + X = X_init + + displacement_square_norms = [] + displacement_square_norm = stopThr + 1. + + while ( displacement_square_norm > stopThr and iter_count < numItermax ): + + T_sum = np.zeros((k, d)) + + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): + + M_i = dist(X, measure_locations_i) + T_i = emd(b, measure_weights_i, M_i) + T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) + + displacement_square_norm = np.sum(np.square(T_sum-X)) + if log: + displacement_square_norms.append(displacement_square_norm) + + X = T_sum + + if verbose: + print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm) + + iter_count += 1 + + if log: + return X, displacement_square_norms + else: + return X
\ No newline at end of file diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index c097f58..8e763be 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -10,7 +10,7 @@ LP solvers for optimal transport using cvxopt import numpy as np import scipy as sp import scipy.sparse as sps -import ot + try: import cvxopt @@ -145,83 +145,3 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po return b, sol else: return b - - -def free_support_barycenter(measures_locations, measures_weights, X_init, b, weights=None, numItermax=100, stopThr=1e-6, verbose=False): - """ - Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) - - The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms. - This problem is considered in [1] (Algorithm 2). There are two differences with the following codes: - - we do not optimize over the weights - - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. - - Parameters - ---------- - data_positions : list of (k_i,d) np.ndarray - The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list) - data_weights : list of (k_i,) np.ndarray - Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure - - X_init : (k,d) np.ndarray - Initialization of the support locations (on k atoms) of the barycenter - b : (k,) np.ndarray - Initialization of the weights of the barycenter (non-negatives, sum to 1) - weights : (k,) np.ndarray - Initialization of the coefficients of the barycenter (non-negatives, sum to 1) - - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshol on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - Returns - ------- - X : (k,d) np.ndarray - Support locations (on k atoms) of the barycenter - - References - ---------- - - .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - - .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. - - """ - - iter_count = 0 - - d = X_init.shape[1] - k = b.size - N = len(measures_locations) - - if not weights: - weights = np.ones((N,)) / N - - X = X_init - - displacement_square_norm = stopThr + 1. - - while (displacement_square_norm > stopThr and iter_count < numItermax): - - T_sum = np.zeros((k, d)) - - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): - - M_i = ot.dist(X, measure_locations_i) - T_i = ot.emd(b, measure_weights_i, M_i) - T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) - - displacement_square_norm = np.sum(np.square(X - T_sum)) - X = T_sum - - if verbose: - print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm) - - iter_count += 1 - - return X |