summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
authorAdrienCorenflos <adrien.corenflos@gmail.com>2020-07-20 14:59:13 +0300
committerGitHub <noreply@github.com>2020-07-20 13:59:13 +0200
commit23db72c49465a1eeb2897d4c6dd9c189aec9cd6e (patch)
tree749b6ea94acbe3e6b6f6fc387cf3bb78d7e75bae /ot/lp
parentacfff525e043c8733936e8e99367543edd28822b (diff)
Correct documentation for support barycenter (#201)
* example for log treatment in bregman.py * Improve doc * Revert "example for log treatment in bregman.py" This reverts commit 9f51c14e * Add comments by Flamary * Delete repetitive description * Added raw string to avoid pbs with backslashes
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/__init__.py26
1 files changed, 18 insertions, 8 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 514a607..2a1b082 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -272,7 +272,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
-
+
result_code_string = check_result(result_code)
if log:
log = {}
@@ -389,7 +389,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if log or return_matrix:
def f(b):
bsel = b != 0
-
+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
if center_dual:
@@ -435,26 +435,36 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
stopThr=1e-7, verbose=False, log=None):
- """
- Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
+ r"""
+ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally:
+
+ .. math::
+ \min_X \sum_{i=1}^N w_i W_2^2(b, X, a_i, X_i)
+
+ where :
+
+ - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
+ - the :math:`a_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i`
+ - the :math:`X_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations
+ - :math:`b \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
- The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
+
- we do not optimize over the weights
- we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.
Parameters
----------
- measures_locations : list of (k_i,d) numpy.ndarray
+ measures_locations : list of N (k_i,d) numpy.ndarray
The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
- measures_weights : list of (k_i,) numpy.ndarray
+ measures_weights : list of N (k_i,) numpy.ndarray
Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure
X_init : (k,d) np.ndarray
Initialization of the support locations (on k atoms) of the barycenter
b : (k,) np.ndarray
Initialization of the weights of the barycenter (non-negatives, sum to 1)
- weights : (k,) np.ndarray
+ weights : (N,) np.ndarray
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
numItermax : int, optional