diff options
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 125 |
1 files changed, 115 insertions, 10 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index c8e69ce..35e51f8 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -350,7 +350,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, np.exp(K, out=K) # print(np.min(K)) - tmp = np.empty(K.shape, dtype=M.dtype) tmp2 = np.empty(b.shape, dtype=M.dtype) Kp = (1 / a).reshape(-1, 1) * K @@ -359,6 +358,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, while (err > stopThr and cpt < numItermax): uprev = u vprev = v + KtransposeU = np.dot(K.T, u) v = np.divide(b, KtransposeU) u = 1. / np.dot(Kp, v) @@ -379,11 +379,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ np.sum((v - vprev)**2) / np.sum((v)**2) else: - np.multiply(u.reshape(-1, 1), K, out=tmp) - np.multiply(tmp, v.reshape(1, -1), out=tmp) - np.sum(tmp, axis=0, out=tmp2) - tmp2 -= b - err = np.linalg.norm(tmp2)**2 + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + np.einsum('i,ij,j->j', u, K, v, out=tmp2) + err = np.linalg.norm(tmp2 - b)**2 # violation of marginal if log: log['err'].append(err) @@ -398,10 +396,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, log['v'] = v if nbb: # return only loss - res = np.zeros((nbb)) - for i in range(nbb): - res[i] = np.sum( - u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M) + res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: @@ -924,6 +919,116 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, return geometricBar(weights, UKv) +def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False): + """Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images. + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` + - reg is the regularization strength scalar value + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_ + + Parameters + ---------- + A : np.ndarray (n,w,h) + n distributions (2D images) of size w x h + reg : float + Regularization term >0 + weights : np.ndarray (n,) + Weights of each image on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + stabThr : float, optional + Stabilization threshold to avoid numerical precision issue + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (w,h) ndarray + 2D Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). + Convolutional wasserstein distances: Efficient optimal transportation on geometric domains + ACM Transactions on Graphics (TOG), 34(4), 66 + + + """ + + if weights is None: + weights = np.ones(A.shape[0]) / A.shape[0] + else: + assert(len(weights) == A.shape[0]) + + if log: + log = {'err': []} + + b = np.zeros_like(A[0, :, :]) + U = np.ones_like(A) + KV = np.ones_like(A) + + cpt = 0 + err = 1 + + # build the convolution operator + t = np.linspace(0, 1, A.shape[1]) + [Y, X] = np.meshgrid(t, t) + xi1 = np.exp(-(X - Y)**2 / reg) + + def K(x): + return np.dot(np.dot(xi1, x), xi1) + + while (err > stopThr and cpt < numItermax): + + bold = b + cpt = cpt + 1 + + b = np.zeros_like(A[0, :, :]) + for r in range(A.shape[0]): + KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :]))) + b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :])) + b = np.exp(b) + for r in range(A.shape[0]): + U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :]) + + if cpt % 10 == 1: + err = np.sum(np.abs(bold - b)) + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + if log: + log['niter'] = cpt + log['U'] = U + return b, log + else: + return b + + def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False): """ |