summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 14:24:05 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 14:24:05 +0200
commit63bbeb34e48f02c97a762dab5232158d90a5cffc (patch)
tree853026b5854b6e4b01fdf750db139985b3dd596f /ot/bregman.py
parentf70aabfcc11f92181e0dc987b341bad8ec030d75 (diff)
parentf66ab58c7c895011fd37bafd3e848828399c56c4 (diff)
Merge remote-tracking branch 'rflamary/master'
merge pot
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py611
1 files changed, 593 insertions, 18 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index ffa6202..321712b 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -11,6 +11,7 @@ Bregman projections for regularized OT
# License: MIT License
import numpy as np
+from .utils import unif, dist
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -49,7 +50,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
+ method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
'sinkhorn_epsilon_scaling', see those function for specific parameters
numItermax : int, optional
Max number of iterations
@@ -105,6 +106,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
def sink():
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ elif method.lower() == 'greenkhorn':
+ def sink():
+ return greenkhorn(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log)
elif method.lower() == 'sinkhorn_stabilized':
def sink():
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
@@ -118,7 +123,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
print('Warning : unknown method using classic Sinkhorn Knopp')
def sink():
- return sinkhorn_knopp(a, b, M, reg, **kwargs)
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
return sink()
@@ -199,6 +205,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+
See Also
@@ -206,6 +214,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
+ ot.bregman.greenkhorn : Greenkhorn [21]
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
@@ -346,8 +355,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# print(reg)
- K = np.exp(-M / reg)
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
# print(np.min(K))
+ tmp2 = np.empty(b.shape, dtype=M.dtype)
Kp = (1 / a).reshape(-1, 1) * K
cpt = 0
@@ -355,13 +369,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
+
KtransposeU = np.dot(K.T, u)
v = np.divide(b, KtransposeU)
u = 1. / np.dot(Kp, v)
- if (np.any(KtransposeU == 0) or
- np.any(np.isnan(u)) or np.any(np.isnan(v)) or
- np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (np.any(KtransposeU == 0)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
print('Warning: numerical errors at iteration', cpt)
@@ -375,8 +390,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
np.sum((v - vprev)**2) / np.sum((v)**2)
else:
- transp = u.reshape(-1, 1) * (K * v)
- err = np.linalg.norm((np.sum(transp, axis=0) - b))**2
+ # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
+ np.einsum('i,ij,j->j', u, K, v, out=tmp2)
+ err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
if log:
log['err'].append(err)
@@ -391,10 +407,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log['v'] = v
if nbb: # return only loss
- res = np.zeros((nbb))
- for i in range(nbb):
- res[i] = np.sum(
- u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M)
+ res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
@@ -408,6 +421,159 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
+def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
+ """
+ Solve the entropic regularization optimal transport problem and return the OT matrix
+
+ The algorithm used is based on the paper
+
+ Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
+ by Jason Altschuler, Jonathan Weed, Philippe Rigollet
+ appeared at NIPS 2017
+
+ which is a stochastic version of the Sinkhorn-Knopp algorithm [2].
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ samples in the target domain, compute sinkhorn with multiple targets
+ and fixed M if b is a matrix (return OT loss + dual variables in log)
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.bregman.greenkhorn(a,b,M,1)
+ array([[ 0.36552929, 0.13447071],
+ [ 0.13447071, 0.36552929]])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ if len(a) == 0:
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
+ n = a.shape[0]
+ m = b.shape[0]
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty_like(M)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ u = np.full(n, 1. / n)
+ v = np.full(m, 1. / m)
+ G = u[:, np.newaxis] * K * v[np.newaxis, :]
+
+ viol = G.sum(1) - a
+ viol_2 = G.sum(0) - b
+ stopThr_val = 1
+
+ if log:
+ log = dict()
+ log['u'] = u
+ log['v'] = v
+
+ for i in range(numItermax):
+ i_1 = np.argmax(np.abs(viol))
+ i_2 = np.argmax(np.abs(viol_2))
+ m_viol_1 = np.abs(viol[i_1])
+ m_viol_2 = np.abs(viol_2[i_2])
+ stopThr_val = np.maximum(m_viol_1, m_viol_2)
+
+ if m_viol_1 > m_viol_2:
+ old_u = u[i_1]
+ u[i_1] = a[i_1] / (K[i_1, :].dot(v))
+ G[i_1, :] = u[i_1] * K[i_1, :] * v
+
+ viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1]
+ viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v)
+
+ else:
+ old_v = v[i_2]
+ v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
+ G[:, i_2] = u * K[:, i_2] * v[i_2]
+ #aviol = (G@one_m - a)
+ #aviol_2 = (G.T@one_n - b)
+ viol += (-old_v + v[i_2]) * K[:, i_2] * u
+ viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
+
+ #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
+
+ if stopThr_val <= stopThr:
+ break
+ else:
+ print('Warning: Algorithm did not converge')
+
+ if log:
+ log['u'] = u
+ log['v'] = v
+
+ if log:
+ return G, log
+ else:
+ return G
+
+
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
"""
@@ -532,13 +698,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) -
- beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((na, 1))
+ - beta.reshape((1, nb))) / reg)
def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) /
- reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
+ return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb)))
+ / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
# print(np.min(K))
@@ -748,8 +914,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) -
- beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((na, 1))
+ - beta.reshape((1, nb))) / reg)
# print(np.min(K))
def get_reg(n): # exponential decreasing
@@ -917,6 +1083,116 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
+def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
+ """Compute the entropic regularized wasserstein barycenter of distributions A
+ where A is a collection of 2D images.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}`
+ - reg is the regularization strength scalar value
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_
+
+ Parameters
+ ----------
+ A : np.ndarray (n,w,h)
+ n distributions (2D images) of size w x h
+ reg : float
+ Regularization term >0
+ weights : np.ndarray (n,)
+ Weights of each image on the simplex (barycentric coodinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ stabThr : float, optional
+ Stabilization threshold to avoid numerical precision issue
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (w,h) ndarray
+ 2D Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
+ Convolutional wasserstein distances: Efficient optimal transportation on geometric domains
+ ACM Transactions on Graphics (TOG), 34(4), 66
+
+
+ """
+
+ if weights is None:
+ weights = np.ones(A.shape[0]) / A.shape[0]
+ else:
+ assert(len(weights) == A.shape[0])
+
+ if log:
+ log = {'err': []}
+
+ b = np.zeros_like(A[0, :, :])
+ U = np.ones_like(A)
+ KV = np.ones_like(A)
+
+ cpt = 0
+ err = 1
+
+ # build the convolution operator
+ t = np.linspace(0, 1, A.shape[1])
+ [Y, X] = np.meshgrid(t, t)
+ xi1 = np.exp(-(X - Y)**2 / reg)
+
+ def K(x):
+ return np.dot(np.dot(xi1, x), xi1)
+
+ while (err > stopThr and cpt < numItermax):
+
+ bold = b
+ cpt = cpt + 1
+
+ b = np.zeros_like(A[0, :, :])
+ for r in range(A.shape[0]):
+ KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
+ b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
+ b = np.exp(b)
+ for r in range(A.shape[0]):
+ U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])
+
+ if cpt % 10 == 1:
+ err = np.sum(np.abs(bold - b))
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ if log:
+ log['niter'] = cpt
+ log['U'] = U
+ return b, log
+ else:
+ return b
+
+
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
stopThr=1e-3, verbose=False, log=False):
"""
@@ -1023,3 +1299,302 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
return np.sum(K0, axis=1), log
else:
return np.sum(K0, axis=1)
+
+
+def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+ '''
+ Solve the entropic regularization optimal transport problem and return the
+ OT matrix from empirical data
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - :math:`M` is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+ Parameters
+ ----------
+ X_s : np.ndarray (ns, d)
+ samples in the source domain
+ X_t : np.ndarray (nt, d)
+ samples in the target domain
+ reg : float
+ Regularization term >0
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,)
+ samples weights in the target domain
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Regularized optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> n_s = 2
+ >>> n_t = 2
+ >>> reg = 0.1
+ >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+ >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> emp_sinkhorn = empirical_sinkhorn(X_s, X_t, reg, verbose=False)
+ >>> print(emp_sinkhorn)
+ >>> [[4.99977301e-01 2.26989344e-05]
+ [2.26989344e-05 4.99977301e-01]]
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ '''
+
+ if a is None:
+ a = unif(np.shape(X_s)[0])
+ if b is None:
+ b = unif(np.shape(X_t)[0])
+
+ M = dist(X_s, X_t, metric=metric)
+
+ if log:
+ pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+ return pi, log
+ else:
+ pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+ return pi
+
+
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+ '''
+ Solve the entropic regularization optimal transport problem from empirical
+ data and return the OT loss
+
+
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - :math:`M` is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+ Parameters
+ ----------
+ X_s : np.ndarray (ns, d)
+ samples in the source domain
+ X_t : np.ndarray (nt, d)
+ samples in the target domain
+ reg : float
+ Regularization term >0
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,)
+ samples weights in the target domain
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Regularized optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> n_s = 2
+ >>> n_t = 2
+ >>> reg = 0.1
+ >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+ >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> loss_sinkhorn = empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
+ >>> print(loss_sinkhorn)
+ >>> [4.53978687e-05]
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ '''
+
+ if a is None:
+ a = unif(np.shape(X_s)[0])
+ if b is None:
+ b = unif(np.shape(X_t)[0])
+
+ M = dist(X_s, X_t, metric=metric)
+
+ if log:
+ sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_loss, log
+ else:
+ sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_loss
+
+
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+ '''
+ Compute the sinkhorn divergence loss from empirical data
+
+ The function solves the following optimization problems and return the
+ sinkhorn divergence :math:`S`:
+
+ .. math::
+
+ W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a)
+
+ W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b)
+
+ S &= W - 1/2 * (W_a + W_b)
+
+ .. math::
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+
+ \gamma_a 1 = a
+
+ \gamma_a^T 1= a
+
+ \gamma_a\geq 0
+
+ \gamma_b 1 = b
+
+ \gamma_b^T 1= b
+
+ \gamma_b\geq 0
+ where :
+
+ - :math:`M` (resp. :math:`M_a, M_b`) is the (ns,nt) metric cost matrix (resp (ns, ns) and (nt, nt))
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+ Parameters
+ ----------
+ X_s : np.ndarray (ns, d)
+ samples in the source domain
+ X_t : np.ndarray (nt, d)
+ samples in the target domain
+ reg : float
+ Regularization term >0
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,)
+ samples weights in the target domain
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Regularized optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> n_s = 2
+ >>> n_t = 4
+ >>> reg = 0.1
+ >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+ >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> emp_sinkhorn_div = empirical_sinkhorn_divergence(X_s, X_t, reg)
+ >>> print(emp_sinkhorn_div)
+ >>> [2.99977435]
+
+
+ References
+ ----------
+
+ .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
+ '''
+ if log:
+ sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+
+ log = {}
+ log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
+ log['sinkhorn_loss_a'] = sinkhorn_loss_a
+ log['sinkhorn_loss_b'] = sinkhorn_loss_b
+ log['log_sinkhorn_ab'] = log_ab
+ log['log_sinkhorn_a'] = log_a
+ log['log_sinkhorn_b'] = log_b
+
+ return max(0, sinkhorn_div), log
+
+ else:
+ sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ return max(0, sinkhorn_div)