summaryrefslogtreecommitdiff
path: root/ot/unbalanced.py
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-06-12 17:52:02 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-06-12 17:52:02 +0200
commit50bc90058940645a13e2f3e41129bdc97161dc63 (patch)
tree24031123549ee349344c83875903d5d313e26292 /ot/unbalanced.py
parent12ed1581225f70c7c8777b6ce31710453fda7f51 (diff)
add unbalanced barycenters
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r--ot/unbalanced.py118
1 files changed, 118 insertions, 0 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index f4208b5..a30fc18 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -380,3 +380,121 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
return u[:, None] * K * v[None, :], log
else:
return u[:, None] * K * v[None, :]
+
+
+def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
+ """Compute the entropic regularized unbalanced wasserstein barycenter of distributions A
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - alpha is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (d,n)
+ n training distributions a_i of size d
+ M : np.ndarray (d,d)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ alpha : float
+ Regularization term > 0
+ weights : np.ndarray (n,)
+ Weights of each histogram a_i on the simplex (barycentric coodinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (d,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+
+ """
+ p, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert(len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ K = np.exp(- M / reg)
+
+ fi = alpha / (alpha + reg)
+
+ v = np.ones((p, n_hists)) / p
+ u = np.ones((p, 1)) / p
+
+ cpt = 0
+ err = 1.
+
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ Kv = K.dot(v)
+ u = (A / Kv) ** fi
+ Ktu = K.T.dot(u)
+ q = ((Ktu ** (1 - fi)).dot(weights))
+ q = q ** (1 / (1 - fi))
+ Q = q[:, None]
+ v = (Q / Ktu) ** fi
+
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration', cpt)
+ u = uprev
+ v = vprev
+ break
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \
+ np.sum((v - vprev) ** 2) / np.sum((v) ** 2)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if log:
+ log['niter'] = cpt
+ log['u'] = u
+ log['v'] = v
+ return q, log
+ else:
+ return q