From 50bc90058940645a13e2f3e41129bdc97161dc63 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 12 Jun 2019 17:52:02 +0200 Subject: add unbalanced barycenters --- ot/unbalanced.py | 118 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) (limited to 'ot/unbalanced.py') diff --git a/ot/unbalanced.py b/ot/unbalanced.py index f4208b5..a30fc18 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -380,3 +380,121 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, return u[:, None] * K * v[None, :], log else: return u[:, None] * K * v[None, :] + + +def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False): + """Compute the entropic regularized unbalanced wasserstein barycenter of distributions A + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - alpha is the marginal relaxation hyperparameter + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + + Parameters + ---------- + A : np.ndarray (d,n) + n training distributions a_i of size d + M : np.ndarray (d,d) + loss matrix for OT + reg : float + Regularization term > 0 + alpha : float + Regularization term > 0 + weights : np.ndarray (n,) + Weights of each histogram a_i on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (d,) ndarray + Unbalanced Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + + """ + p, n_hists = A.shape + if weights is None: + weights = np.ones(n_hists) / n_hists + else: + assert(len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + K = np.exp(- M / reg) + + fi = alpha / (alpha + reg) + + v = np.ones((p, n_hists)) / p + u = np.ones((p, 1)) / p + + cpt = 0 + err = 1. + + while (err > stopThr and cpt < numItermax): + uprev = u + vprev = v + + Kv = K.dot(v) + u = (A / Kv) ** fi + Ktu = K.T.dot(u) + q = ((Ktu ** (1 - fi)).dot(weights)) + q = q ** (1 / (1 - fi)) + Q = q[:, None] + v = (Q / Ktu) ** fi + + if (np.any(Ktu == 0.) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration', cpt) + u = uprev + v = vprev + break + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \ + np.sum((v - vprev) ** 2) / np.sum((v) ** 2) + if log: + log['err'].append(err) + if verbose: + if cpt % 50 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + if log: + log['niter'] = cpt + log['u'] = u + log['v'] = v + return q, log + else: + return q -- cgit v1.2.3