diff options
author | AdrienCorenflos <adrien.corenflos@gmail.com> | 2020-07-20 14:59:13 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-07-20 13:59:13 +0200 |
commit | 23db72c49465a1eeb2897d4c6dd9c189aec9cd6e (patch) | |
tree | 749b6ea94acbe3e6b6f6fc387cf3bb78d7e75bae /ot/lp | |
parent | acfff525e043c8733936e8e99367543edd28822b (diff) |
Correct documentation for support barycenter (#201)
* example for log treatment in bregman.py
* Improve doc
* Revert "example for log treatment in bregman.py"
This reverts commit 9f51c14e
* Add comments by Flamary
* Delete repetitive description
* Added raw string to avoid pbs with backslashes
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/__init__.py | 26 |
1 files changed, 18 insertions, 8 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 514a607..2a1b082 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -272,7 +272,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) - + result_code_string = check_result(result_code) if log: log = {} @@ -389,7 +389,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if log or return_matrix: def f(b): bsel = b != 0 - + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) if center_dual: @@ -435,26 +435,36 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None): - """ - Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) + r""" + Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: + + .. math:: + \min_X \sum_{i=1}^N w_i W_2^2(b, X, a_i, X_i) + + where : + + - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one + - the :math:`a_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i` + - the :math:`X_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations + - :math:`b \in \mathbb{R}^{k}` is the desired weights vector of the barycenter - The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms. This problem is considered in [1] (Algorithm 2). There are two differences with the following codes: + - we do not optimize over the weights - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. Parameters ---------- - measures_locations : list of (k_i,d) numpy.ndarray + measures_locations : list of N (k_i,d) numpy.ndarray The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list) - measures_weights : list of (k_i,) numpy.ndarray + measures_weights : list of N (k_i,) numpy.ndarray Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure X_init : (k,d) np.ndarray Initialization of the support locations (on k atoms) of the barycenter b : (k,) np.ndarray Initialization of the weights of the barycenter (non-negatives, sum to 1) - weights : (k,) np.ndarray + weights : (N,) np.ndarray Initialization of the coefficients of the barycenter (non-negatives, sum to 1) numItermax : int, optional |