summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py125
1 files changed, 115 insertions, 10 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index c8e69ce..35e51f8 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -350,7 +350,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
np.exp(K, out=K)
# print(np.min(K))
- tmp = np.empty(K.shape, dtype=M.dtype)
tmp2 = np.empty(b.shape, dtype=M.dtype)
Kp = (1 / a).reshape(-1, 1) * K
@@ -359,6 +358,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
+
KtransposeU = np.dot(K.T, u)
v = np.divide(b, KtransposeU)
u = 1. / np.dot(Kp, v)
@@ -379,11 +379,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
np.sum((v - vprev)**2) / np.sum((v)**2)
else:
- np.multiply(u.reshape(-1, 1), K, out=tmp)
- np.multiply(tmp, v.reshape(1, -1), out=tmp)
- np.sum(tmp, axis=0, out=tmp2)
- tmp2 -= b
- err = np.linalg.norm(tmp2)**2
+ # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
+ np.einsum('i,ij,j->j', u, K, v, out=tmp2)
+ err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
if log:
log['err'].append(err)
@@ -398,10 +396,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log['v'] = v
if nbb: # return only loss
- res = np.zeros((nbb))
- for i in range(nbb):
- res[i] = np.sum(
- u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M)
+ res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
@@ -924,6 +919,116 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
+def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
+ """Compute the entropic regularized wasserstein barycenter of distributions A
+ where A is a collection of 2D images.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}`
+ - reg is the regularization strength scalar value
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_
+
+ Parameters
+ ----------
+ A : np.ndarray (n,w,h)
+ n distributions (2D images) of size w x h
+ reg : float
+ Regularization term >0
+ weights : np.ndarray (n,)
+ Weights of each image on the simplex (barycentric coodinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ stabThr : float, optional
+ Stabilization threshold to avoid numerical precision issue
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (w,h) ndarray
+ 2D Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
+ Convolutional wasserstein distances: Efficient optimal transportation on geometric domains
+ ACM Transactions on Graphics (TOG), 34(4), 66
+
+
+ """
+
+ if weights is None:
+ weights = np.ones(A.shape[0]) / A.shape[0]
+ else:
+ assert(len(weights) == A.shape[0])
+
+ if log:
+ log = {'err': []}
+
+ b = np.zeros_like(A[0, :, :])
+ U = np.ones_like(A)
+ KV = np.ones_like(A)
+
+ cpt = 0
+ err = 1
+
+ # build the convolution operator
+ t = np.linspace(0, 1, A.shape[1])
+ [Y, X] = np.meshgrid(t, t)
+ xi1 = np.exp(-(X - Y)**2 / reg)
+
+ def K(x):
+ return np.dot(np.dot(xi1, x), xi1)
+
+ while (err > stopThr and cpt < numItermax):
+
+ bold = b
+ cpt = cpt + 1
+
+ b = np.zeros_like(A[0, :, :])
+ for r in range(A.shape[0]):
+ KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
+ b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
+ b = np.exp(b)
+ for r in range(A.shape[0]):
+ U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])
+
+ if cpt % 10 == 1:
+ err = np.sum(np.abs(bold - b))
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ if log:
+ log['niter'] = cpt
+ log['U'] = U
+ return b, log
+ else:
+ return b
+
+
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
stopThr=1e-3, verbose=False, log=False):
"""