Diffstat (limited to 'ot/bregman.py')
-rw-r--r--  ot/bregman.py  378
1 file changed, 290 insertions(+), 88 deletions(-)
diff --git a/ot/bregman.py b/ot/bregman.py
index c06af2f..20bef7e 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -24,9 +24,8 @@ from ot.utils import unif, dist, list_to_array
from .backend import get_backend
-def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, warn=True,
- **kwargs):
+def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9,
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -101,6 +100,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, the dual potentials should be given
+ (that is, the logarithm of the u, v Sinkhorn scaling vectors)
Returns
-------
@@ -156,34 +158,33 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn,
+ warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn,
+ warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'greenkhorn':
return greenkhorn(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn)
+ warn=warn, warmstart=warmstart)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
- return sinkhorn_epsilon_scaling(a, b, M, reg,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
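
For illustration, a minimal sketch of how the threaded-through warmstart
argument can be used from the top-level API (a sketch only; it assumes the
log dict of the default solver exposes the scaling vectors under 'u' and 'v',
as sinkhorn_knopp records them):

    import numpy as np
    import ot

    a, b = ot.unif(50), ot.unif(50)
    M = ot.dist(np.random.randn(50, 2), np.random.randn(50, 2))

    # first solve, keeping the log to recover the scaling vectors
    G0, log0 = ot.sinkhorn(a, b, M, reg=1e-1, log=True)

    # warm-start a second solve (e.g. a smaller reg) from the dual
    # potentials, i.e. the logarithms of the u, v scaling vectors
    warmstart = (np.log(log0['u']), np.log(log0['v']))
    G1 = ot.sinkhorn(a, b, M, reg=5e-2, warmstart=warmstart)
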
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs):
+ stopThr=1e-9, verbose=False, log=False, warn=False, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the loss
@@ -207,6 +208,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
weights (histograms, both sum to 1)
+ and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
+ the entropic contribution).
+
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
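
To make the documented return value concrete, a small check sketch (assuming
the usual POT top-level wrappers; the tolerance is arbitrary):

    import numpy as np
    import ot

    a, b = ot.unif(20), ot.unif(20)
    M = ot.dist(np.random.randn(20, 2), np.random.randn(20, 2))

    G = ot.sinkhorn(a, b, M, reg=1e-1)       # OT matrix gamma*
    loss = ot.sinkhorn2(a, b, M, reg=1e-1)   # <gamma*, M>_F, no entropy term

    assert np.allclose(loss, np.sum(G * M), atol=1e-6)
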
@@ -257,6 +261,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, the dual potentials should be given
+ (that is, the logarithm of the u, v Sinkhorn scaling vectors)
Returns
-------
@@ -320,15 +327,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if len(b.shape) < 2:
if method.lower() == 'sinkhorn':
res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
@@ -341,23 +351,25 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
- verbose=False, log=False, warn=True,
- **kwargs):
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -406,6 +418,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, the dual potentials should be given
+ (that is, the logarithm of the u, v Sinkhorn scaling vectors)
Returns
-------
@@ -465,12 +480,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
# we assume that no distances are null except those of the diagonal of
# distances
- if n_hists:
- u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
- v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
+ if warmstart is None:
+ if n_hists:
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
+ else:
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
else:
- u = nx.ones(dim_a, type_as=M) / dim_a
- v = nx.ones(dim_b, type_as=M) / dim_b
+ u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
K = nx.exp(M / (-reg))
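
Note the convention made explicit by this hunk: warmstart always carries
log-domain potentials, so scaling-space solvers such as sinkhorn_knopp and
greenkhorn exponentiate them, while sinkhorn_log consumes them as-is. A tiny
round-trip sketch of that convention:

    import numpy as np

    u = np.full(50, 1. / 50)          # default scaling-vector initialization
    f = np.log(u)                     # the corresponding dual potential
    assert np.allclose(np.exp(f), u)  # exp() recovers the scaling vector
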
@@ -538,7 +556,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
- log=False, warn=True, **kwargs):
+ log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem in log space
and return the OT matrix
@@ -587,6 +605,9 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, the dual potentials should be given
+ (that is, the logarithm of the u, v Sinkhorn scaling vectors)
Returns
-------
@@ -647,6 +668,10 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
else:
n_hists = 0
+ # in case of multiple histograms
+ if n_hists > 1 and warmstart is None:
+ warmstart = [None] * n_hists
+
if n_hists: # we do not want to use tensors, so we do a loop
lst_loss = []
@@ -654,8 +679,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
lst_v = []
for k in range(n_hists):
- res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, stopThr=stopThr,
+ verbose=verbose, log=log, warmstart=warmstart[k], **kwargs)
if log:
lst_loss.append(nx.sum(M * res[0]))
@@ -682,9 +707,11 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
# we assume that no distances are null except those of the diagonal of
# distances
-
- u = nx.zeros(dim_a, type_as=M)
- v = nx.zeros(dim_b, type_as=M)
+ if warmstart is None:
+ u = nx.zeros(dim_a, type_as=M)
+ v = nx.zeros(dim_b, type_as=M)
+ else:
+ u, v = warmstart
def get_logT(u, v):
if n_hists:
@@ -738,7 +765,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
- log=False, warn=True):
+ log=False, warn=True, warmstart=None):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -786,6 +813,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, the dual potentials should be given
+ (that is, the logarithm of the u, v Sinkhorn scaling vectors)
Returns
-------
@@ -844,8 +874,11 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
K = nx.exp(-M / reg)
- u = nx.full((dim_a,), 1. / dim_a, type_as=K)
- v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ if warmstart is None:
+ u = nx.full((dim_a,), 1. / dim_a, type_as=K)
+ v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ else:
+ u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
G = u[:, None] * K * v[None, :]
viol = nx.sum(G, axis=1) - a
@@ -1065,7 +1098,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# remove numerical problems and store them in K
if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
if n_hists:
- alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
+ alpha, beta = alpha + reg * \
+ nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
else:
alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v)
if n_hists:
@@ -1278,7 +1312,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
regi = get_reg(ii)
G, logi = sinkhorn_stabilized(a, b, M, regi,
- numItermax=numInnerItermax, stopThr=1e-9,
+ numItermax=numInnerItermax, stopThr=stopThr,
warmstart=(alpha, beta), verbose=False,
print_period=20, tau=tau, log=True)
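
The inner stabilized solver now honors the caller's stopThr instead of a
hard-coded 1e-9. A user-level analogue of this epsilon-scaling loop can be
built on the new warmstart argument (a sketch, not the library's
implementation; it assumes sinkhorn_log records the potentials in its log
dict as 'log_u' and 'log_v'):

    import numpy as np
    import ot

    a, b = ot.unif(30), ot.unif(30)
    M = ot.dist(np.random.randn(30, 2), np.random.randn(30, 2))

    warmstart = None
    for reg in [1.0, 0.5, 0.1, 0.05]:  # decreasing regularization
        G, log = ot.sinkhorn(a, b, M, reg, method='sinkhorn_log',
                             log=True, warmstart=warmstart)
        # reuse the log-domain potentials to initialize the next solve
        warmstart = (log['log_u'], log['log_v'])
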
@@ -1289,13 +1323,15 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
# we can speed up the process by checking for the error only all
# the 10th iterations
transp = G
- err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2
+ err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + \
+ nx.norm(nx.sum(transp, axis=1) - a) ** 2
if log:
log['err'].append(err)
if verbose:
if ii % (print_period * 10) == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err <= stopThr and ii > numItermin:
@@ -1511,7 +1547,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
for ii in range(numItermax):
- UKv = u * nx.dot(K, A / nx.dot(K, u))
+ UKv = u * nx.dot(K.T, A / nx.dot(K, u))
u = (u.T * geometricBar(weights, UKv)).T / UKv
if ii % 10 == 1:
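
The K -> K.T change in this fixed-point update is invisible whenever the cost
matrix M (and hence K) is symmetric, which is presumably why it went
unnoticed; a quick check of that observation:

    import numpy as np

    M = np.random.rand(5, 5)
    M = 0.5 * (M + M.T)         # symmetric cost on a shared support
    K = np.exp(-M / 0.1)
    assert np.allclose(K, K.T)  # so dot(K, x) == dot(K.T, x) in this case
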
@@ -1540,6 +1576,129 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
+def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None,
+ numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None,
+ **kwargs):
+ r"""
+ Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally:
+
+ .. math::
+ \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)
+
+ where :
+
+ - :math:`\mathbf{w} \in (0, 1)^{N}` are the barycenter weights, which sum to one
+ - `measures_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measure weights (on the simplex)
+ - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i \times d}`: locations of the atoms of each empirical measure
+ - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
+
+ This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
+ The code below differs from that algorithm in three ways:
+
+ - we do not optimize over the weights
+ - we do not do a line search for the location updates, i.e. we use :math:`\theta = 1` in
+ :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
+ implementation of the fixed-point algorithm of
+ :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
+ - at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the
+ transport plan in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
+
+ Parameters
+ ----------
+ measures_locations : list of N (k_i,d) array-like
+ The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
+ (:math:`k_i` can be different for each element of the list)
+ measures_weights : list of N (k_i,) array-like
+ Numpy arrays where each array has :math:`k_i` non-negative values summing to one,
+ representing the weights of each discrete input measure
+
+ X_init : (k,d) array-like
+ Initialization of the support locations (on `k` atoms) of the barycenter
+ reg : float
+ Regularization term >0
+ b : (k,) array-like
+ Initialization of the weights of the barycenter (non-negative, summing to 1)
+ weights : (N,) array-like
+ Initialization of the coefficients of the barycenter (non-negative, summing to 1)
+
+ numItermax : int, optional
+ Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterations when calculating the transport plans with Sinkhorn
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ X : (k,d) array-like
+ Support locations (on k atoms) of the barycenter
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT solver
+ ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming
+
+ .. _references-free-support-barycenter:
+ References
+ ----------
+ .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+ """
+ nx = get_backend(*measures_locations, *measures_weights, X_init)
+
+ iter_count = 0
+
+ N = len(measures_locations)
+ k = X_init.shape[0]
+ d = X_init.shape[1]
+ if b is None:
+ b = nx.ones((k,), type_as=X_init) / k
+ if weights is None:
+ weights = nx.ones((N,), type_as=X_init) / N
+
+ X = X_init
+
+ log_dict = {}
+ displacement_square_norms = []
+
+ displacement_square_norm = stopThr + 1.
+
+ while (displacement_square_norm > stopThr and iter_count < numItermax):
+
+ T_sum = nx.zeros((k, d), type_as=X_init)
+
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
+ M_i = dist(X, measure_locations_i)
+ T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg,
+ numItermax=numInnerItermax, **kwargs)
+ T_sum = T_sum + weight_i * 1. / \
+ b[:, None] * nx.dot(T_i, measure_locations_i)
+
+ displacement_square_norm = nx.sum((T_sum - X) ** 2)
+ if log:
+ displacement_square_norms.append(displacement_square_norm)
+
+ X = T_sum
+
+ if verbose:
+ print('iteration %d, displacement_square_norm=%f\n' %
+ (iter_count, displacement_square_norm))
+
+ iter_count += 1
+
+ if log:
+ log_dict['displacement_square_norms'] = displacement_square_norms
+ return X, log_dict
+ else:
+ return X
+
+
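
A minimal usage sketch for the new function (synthetic data; it assumes the
function is importable from ot.bregman as defined above):

    import numpy as np
    from ot.bregman import free_support_sinkhorn_barycenter

    rng = np.random.RandomState(0)
    measures_locations = [rng.randn(7, 2), rng.randn(9, 2)]
    measures_weights = [np.ones(7) / 7, np.ones(9) / 9]
    X_init = rng.randn(5, 2)  # barycenter supported on 5 free atoms

    X = free_support_sinkhorn_barycenter(measures_locations, measures_weights,
                                         X_init, reg=1e-1, numItermax=50)
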
def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
stopThr=1e-4, verbose=False, log=False, warn=True):
r"""Compute the entropic wasserstein barycenter in log-domain
@@ -2084,7 +2243,8 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr:
break
@@ -2162,7 +2322,8 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr:
break
@@ -2321,7 +2482,8 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
# debiased Sinkhorn does not converge monotonically
@@ -2401,7 +2563,8 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr and ii > 20:
break
@@ -2729,7 +2892,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
- log=False, warn=True, **kwargs):
+ log=False, warn=True, warmstart=None, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
OT matrix from empirical data
@@ -2782,6 +2945,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, the dual potentials should be given
+ (that is, the logarithm of the u, v Sinkhorn scaling vectors)
Returns
@@ -2832,14 +2998,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
dict_log = {"err": []}
log_a, log_b = nx.log(a), nx.log(b)
- f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+ if warmstart is None:
+ f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+ else:
+ f, g = warmstart
if isinstance(batchSize, int):
bs, bt = batchSize, batchSize
elif isinstance(batchSize, tuple) and len(batchSize) == 2:
bs, bt = batchSize[0], batchSize[1]
else:
- raise ValueError("Batch size must be in integer or a tuple of two integers")
+ raise ValueError(
+ "Batch size must be in integer or a tuple of two integers")
range_s, range_t = range(0, ns, bs), range(0, nt, bt)
@@ -2877,7 +3047,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
M = nx.from_numpy(M, type_as=a)
m1_cols.append(
- nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1)
+ nx.sum(nx.exp(f[i:i + bs, None] +
+ g[None, :] - M / reg), axis=1)
)
m1 = nx.concatenate(m1_cols, axis=0)
err = nx.sum(nx.abs(m1 - a))
@@ -2885,7 +3056,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
dict_log["err"].append(err)
if verbose and (i_ot + 1) % 100 == 0:
- print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
+ print("Error in marginal at iteration {} = {}".format(
+ i_ot + 1, err))
if err <= stopThr:
break
@@ -2905,17 +3077,17 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
M = dist(X_s, X_t, metric=metric)
if log:
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=True, **kwargs)
+ verbose=verbose, log=True, warmstart=warmstart, **kwargs)
return pi, log
else:
pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=False, **kwargs)
+ verbose=verbose, log=False, warmstart=warmstart, **kwargs)
return pi
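
A sketch of the lazy path, which never materializes the full cost matrix, and
of feeding its output back through the new warmstart argument (assuming the
signature shown in this diff; with isLazy=True and log=False the function
returns the dual potentials):

    import numpy as np
    from ot.bregman import empirical_sinkhorn

    X_s, X_t = np.random.randn(500, 3), np.random.randn(500, 3)

    # lazy mode returns the dual potentials (f, g) instead of the plan
    f, g = empirical_sinkhorn(X_s, X_t, reg=1e-1, isLazy=True,
                              batchSize=128, numIterMax=2000)

    # the potentials can seed a later non-lazy solve
    G = empirical_sinkhorn(X_s, X_t, reg=1e-1, warmstart=(f, g))
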
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9, isLazy=False,
- batchSize=100, verbose=False, log=False, warn=True, **kwargs):
+ numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100,
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -2939,6 +3111,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+ and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
+ the entropic contribution).
+
Parameters
----------
@@ -2969,7 +3144,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
-
+ warmstart: tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, the dual potentials should be given
+ (that is, the logarithm of the u, v Sinkhorn scaling vectors)
Returns
-------
@@ -3025,13 +3202,16 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
isLazy=isLazy,
batchSize=batchSize,
verbose=verbose, log=log,
- warn=warn)
+ warn=warn,
+ warmstart=warmstart)
else:
f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
- numIterMax=numIterMax, stopThr=stopThr,
+ numIterMax=numIterMax,
+ stopThr=stopThr,
isLazy=isLazy, batchSize=batchSize,
verbose=verbose, log=log,
- warn=warn)
+ warn=warn,
+ warmstart=warmstart)
bs = batchSize if isinstance(batchSize, int) else batchSize[0]
range_s = range(0, ns, bs)
@@ -3053,25 +3233,23 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
return loss
else:
- M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
- M = nx.from_numpy(M, type_as=a)
+ M = dist(X_s, X_t, metric=metric)
if log:
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn, **kwargs)
+ warn=warn, warmstart=warmstart, **kwargs)
return sinkhorn_loss, log
else:
sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn, **kwargs)
+ warn=warn, warmstart=warmstart, **kwargs)
return sinkhorn_loss
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, warn=True,
- **kwargs):
+ numIterMax=10000, stopThr=1e-9, verbose=False,
+ log=False, warn=True, warmstart=None, **kwargs):
r'''
Compute the sinkhorn divergence loss from empirical data
@@ -3118,6 +3296,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+ and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F - (\langle \gamma^*_a, \mathbf{M}_a \rangle_F + \langle
+ \gamma^*_b, \mathbf{M}_b \rangle_F)/2`.
+
+ .. note:: The current implementation does not account for the entropic contributions and thus differs from the
+ Sinkhorn divergence as introduced in the literature. The possibility to account for the entropic contributions
+ will be provided in a future release.
+
Parameters
----------
@@ -3141,6 +3326,9 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, the dual potentials should be given
+ (that is, the logarithm of the u, v Sinkhorn scaling vectors)
Returns
-------
@@ -3167,23 +3355,34 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
International Conference on Artificial Intelligence and Statistics,
(AISTATS) 21, 2018
'''
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ nx = get_backend(X_s, X_t)
+ if warmstart is None:
+ warmstart_a, warmstart_b = None, None
+ else:
+ u, v = warmstart
+ warmstart_a = (u, u)
+ warmstart_b = (v, v)
+
if log:
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
- numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart, **kwargs)
sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
- numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_a, **kwargs)
sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
- numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_b, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * \
+ (sinkhorn_loss_a + sinkhorn_loss_b)
log = {}
log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
@@ -3193,26 +3392,27 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
log['log_sinkhorn_a'] = log_a
log['log_sinkhorn_b'] = log_b
- return max(0, sinkhorn_div), log
+ return nx.maximum(0, sinkhorn_div), log
else:
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart, **kwargs)
sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_a, **kwargs)
sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_b, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
- return max(0, sinkhorn_div)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * \
+ (sinkhorn_loss_a + sinkhorn_loss_b)
+ return nx.maximum(0, sinkhorn_div)
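
The warmstart splitting above reuses u on both sides of the (a, a) problem
and v on both sides of the (b, b) problem, since those auxiliary problems are
symmetric. A short usage sketch (the non-negativity follows from the
nx.maximum clipping introduced above):

    import numpy as np
    from ot.bregman import empirical_sinkhorn_divergence

    X_s, X_t = np.random.randn(40, 2), np.random.randn(40, 2)
    div = empirical_sinkhorn_divergence(X_s, X_t, reg=1e-1)
    assert div >= 0
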
def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
@@ -3379,7 +3579,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
else:
aK_sort = nx.from_numpy(
- bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1],
+ bottleneck.partition(nx.to_numpy(
+ K_sum_cols), ns_budget - 1)[ns_budget - 1],
type_as=M
)
epsilon_u_square = a[0] / aK_sort
@@ -3389,7 +3590,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
else:
bK_sort = nx.from_numpy(
- bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1],
+ bottleneck.partition(nx.to_numpy(
+ K_sum_rows), nt_budget - 1)[nt_budget - 1],
type_as=M
)
epsilon_v_square = b[0] / bK_sort