diff options
author | Gard Spreemann <gspr@nonempty.org> | 2023-06-14 16:51:31 +0200 |
---|---|---|
committer | Gard Spreemann <gspr@nonempty.org> | 2023-06-14 16:51:31 +0200 |
commit | 96788a3fe5601e4c3f49b592aa0d9c034247862e (patch) | |
tree | 5ee3ebcdea05f6766fc9858344913e40487e9067 /ot/bregman.py | |
parent | 35bd2c98b642df78638d7d733bc1a89d873db1de (diff) | |
parent | 89f1613861152432807077fbb146578611dc5888 (diff) |
Merge tag '0.9.0' into dfsg/latest (dfsg/latest)
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 378 |
1 files changed, 290 insertions, 88 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index c06af2f..20bef7e 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -24,9 +24,8 @@ from ot.utils import unif, dist, list_to_array from .backend import get_backend -def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, warn=True, - **kwargs): +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, + verbose=False, log=False, warn=True, warmstart=None, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -101,6 +100,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -156,34 +158,33 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn, + warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_log': return sinkhorn_log(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn, + warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'greenkhorn': return greenkhorn(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn) + warn=warn, warmstart=warmstart) elif method.lower() == 'sinkhorn_stabilized': return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_epsilon_scaling': - return sinkhorn_epsilon_scaling(a, b, M, reg, 
- numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs): + stopThr=1e-9, verbose=False, log=False, warn=False, warmstart=None, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the loss @@ -207,6 +208,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without + the entropic contribution). + .. note:: This function is backend-compatible and will work on arrays from all compatible backends. @@ -257,6 +261,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. 
If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -320,15 +327,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, if len(b.shape) < 2: if method.lower() == 'sinkhorn': res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_log': res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_stabilized': res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) @@ -341,23 +351,25 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_log': return sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_stabilized': return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." 
% method) def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, - verbose=False, log=False, warn=True, - **kwargs): + verbose=False, log=False, warn=True, warmstart=None, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -406,6 +418,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -465,12 +480,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, # we assume that no distances are null except those of the diagonal of # distances - if n_hists: - u = nx.ones((dim_a, n_hists), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + else: + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) K = nx.exp(M / (-reg)) @@ -538,7 +556,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, - log=False, warn=True, **kwargs): + log=False, warn=True, warmstart=None, **kwargs): r""" Solve the entropic regularization optimal transport problem in log space and return the OT matrix @@ -587,6 +605,9 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. 
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -647,6 +668,10 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, else: n_hists = 0 + # in case of multiple historgrams + if n_hists > 1 and warmstart is None: + warmstart = [None] * n_hists + if n_hists: # we do not want to use tensors sor we do a loop lst_loss = [] @@ -654,8 +679,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, lst_v = [] for k in range(n_hists): - res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) + res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, stopThr=stopThr, + verbose=verbose, log=log, warmstart=warmstart[k], **kwargs) if log: lst_loss.append(nx.sum(M * res[0])) @@ -682,9 +707,11 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, # we assume that no distances are null except those of the diagonal of # distances - - u = nx.zeros(dim_a, type_as=M) - v = nx.zeros(dim_b, type_as=M) + if warmstart is None: + u = nx.zeros(dim_a, type_as=M) + v = nx.zeros(dim_b, type_as=M) + else: + u, v = warmstart def get_logT(u, v): if n_hists: @@ -738,7 +765,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, - log=False, warn=True): + log=False, warn=True, warmstart=None): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -786,6 +813,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. 
If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -844,8 +874,11 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, K = nx.exp(-M / reg) - u = nx.full((dim_a,), 1. / dim_a, type_as=K) - v = nx.full((dim_b,), 1. / dim_b, type_as=K) + if warmstart is None: + u = nx.full((dim_a,), 1. / dim_a, type_as=K) + v = nx.full((dim_b,), 1. / dim_b, type_as=K) + else: + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) G = u[:, None] * K * v[None, :] viol = nx.sum(G, axis=1) - a @@ -1065,7 +1098,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, # remove numerical problems and store them in K if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: if n_hists: - alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v)) + alpha, beta = alpha + reg * \ + nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v)) else: alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v) if n_hists: @@ -1278,7 +1312,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, regi = get_reg(ii) G, logi = sinkhorn_stabilized(a, b, M, regi, - numItermax=numInnerItermax, stopThr=1e-9, + numItermax=numInnerItermax, stopThr=stopThr, warmstart=(alpha, beta), verbose=False, print_period=20, tau=tau, log=True) @@ -1289,13 +1323,15 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, # we can speed up the process by checking for the error only all # the 10th iterations transp = G - err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2 + err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + \ + nx.norm(nx.sum(transp, axis=1) - a) ** 2 if log: log['err'].append(err) if verbose: if ii % (print_period * 10) == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) if 
err <= stopThr and ii > numItermin: @@ -1511,7 +1547,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, for ii in range(numItermax): - UKv = u * nx.dot(K, A / nx.dot(K, u)) + UKv = u * nx.dot(K.T, A / nx.dot(K, u)) u = (u.T * geometricBar(weights, UKv)).T / UKv if ii % 10 == 1: @@ -1540,6 +1576,129 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, return geometricBar(weights, UKv) +def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None, + numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None, + **kwargs): + r""" + Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally: + + .. math:: + \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i) + + where : + + - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one + - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) + - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations + - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter + + This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). + There are two differences with the following codes: + + - we do not optimize over the weights + - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in + :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete + implementation of the fixed-point algorithm of + :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting. 
+ - at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the + transport plan in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). + + Parameters + ---------- + measures_locations : list of N (k_i,d) array-like + The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space + (:math:`k_i` can be different for each element of the list) + measures_weights : list of N (k_i,) array-like + Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one + representing the weights of each discrete input measure + + X_init : (k,d) array-like + Initialization of the support locations (on `k` atoms) of the barycenter + reg : float + Regularization term >0 + b : (k,) array-like + Initialization of the weights of the barycenter (non-negatives, sum to 1) + weights : (N,) array-like + Initialization of the coefficients of the barycenter (non-negatives, sum to 1) + + numItermax : int, optional + Max number of iterations + numInnerItermax : int, optional + Max number of iterations when calculating the transport plans with Sinkhorn + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + X : (k,d) array-like + Support locations (on k atoms) of the barycenter + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT solver + ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming + + .. _references-free-support-barycenter: + References + ---------- + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. 
+ + """ + nx = get_backend(*measures_locations, *measures_weights, X_init) + + iter_count = 0 + + N = len(measures_locations) + k = X_init.shape[0] + d = X_init.shape[1] + if b is None: + b = nx.ones((k,), type_as=X_init) / k + if weights is None: + weights = nx.ones((N,), type_as=X_init) / N + + X = X_init + + log_dict = {} + displacement_square_norms = [] + + displacement_square_norm = stopThr + 1. + + while (displacement_square_norm > stopThr and iter_count < numItermax): + + T_sum = nx.zeros((k, d), type_as=X_init) + + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): + M_i = dist(X, measure_locations_i) + T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, + numItermax=numInnerItermax, **kwargs) + T_sum = T_sum + weight_i * 1. / \ + b[:, None] * nx.dot(T_i, measure_locations_i) + + displacement_square_norm = nx.sum((T_sum - X) ** 2) + if log: + displacement_square_norms.append(displacement_square_norm) + + X = T_sum + + if verbose: + print('iteration %d, displacement_square_norm=%f\n', + iter_count, displacement_square_norm) + + iter_count += 1 + + if log: + log_dict['displacement_square_norms'] = displacement_square_norms + return X, log_dict + else: + return X + + def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the entropic wasserstein barycenter in log-domain @@ -2084,7 +2243,8 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) if err < stopThr: break @@ -2162,7 +2322,8 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' 
* 19) print('{:5d}|{:8e}|'.format(ii, err)) if err < stopThr: break @@ -2321,7 +2482,8 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) # debiased Sinkhorn does not converge monotonically @@ -2401,7 +2563,8 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10 if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) if err < stopThr and ii > 20: break @@ -2729,7 +2892,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, - log=False, warn=True, **kwargs): + log=False, warn=True, warmstart=None, **kwargs): r''' Solve the entropic regularization optimal transport problem and return the OT matrix from empirical data @@ -2782,6 +2945,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. 
If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns @@ -2832,14 +2998,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', dict_log = {"err": []} log_a, log_b = nx.log(a), nx.log(b) - f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a) + if warmstart is None: + f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a) + else: + f, g = warmstart if isinstance(batchSize, int): bs, bt = batchSize, batchSize elif isinstance(batchSize, tuple) and len(batchSize) == 2: bs, bt = batchSize[0], batchSize[1] else: - raise ValueError("Batch size must be in integer or a tuple of two integers") + raise ValueError( + "Batch size must be in integer or a tuple of two integers") range_s, range_t = range(0, ns, bs), range(0, nt, bt) @@ -2877,7 +3047,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) M = nx.from_numpy(M, type_as=a) m1_cols.append( - nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1) + nx.sum(nx.exp(f[i:i + bs, None] + + g[None, :] - M / reg), axis=1) ) m1 = nx.concatenate(m1_cols, axis=0) err = nx.sum(nx.abs(m1 - a)) @@ -2885,7 +3056,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', dict_log["err"].append(err) if verbose and (i_ot + 1) % 100 == 0: - print("Error in marginal at iteration {} = {}".format(i_ot + 1, err)) + print("Error in marginal at iteration {} = {}".format( + i_ot + 1, err)) if err <= stopThr: break @@ -2905,17 +3077,17 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', M = dist(X_s, X_t, metric=metric) if log: pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, - verbose=verbose, log=True, **kwargs) + verbose=verbose, log=True, warmstart=warmstart, **kwargs) return pi, log else: pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, - verbose=verbose, 
log=False, **kwargs) + verbose=verbose, log=False, warmstart=warmstart, **kwargs) return pi def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, isLazy=False, - batchSize=100, verbose=False, log=False, warn=True, **kwargs): + numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, + verbose=False, log=False, warn=True, warmstart=None, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -2939,6 +3111,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without + the entropic contribution). + Parameters ---------- @@ -2969,7 +3144,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. - + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. 
If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -3025,13 +3202,16 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log, - warn=warn) + warn=warn, + warmstart=warmstart) else: f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, - numIterMax=numIterMax, stopThr=stopThr, + numIterMax=numIterMax, + stopThr=stopThr, isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log, - warn=warn) + warn=warn, + warmstart=warmstart) bs = batchSize if isinstance(batchSize, int) else batchSize[0] range_s = range(0, ns, bs) @@ -3053,25 +3233,23 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', return loss else: - M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric) - M = nx.from_numpy(M, type_as=a) + M = dist(X_s, X_t, metric=metric) if log: sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn, **kwargs) + warn=warn, warmstart=warmstart, **kwargs) return sinkhorn_loss, log else: sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn, **kwargs) + warn=warn, warmstart=warmstart, **kwargs) return sinkhorn_loss def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, - verbose=False, log=False, warn=True, - **kwargs): + numIterMax=10000, stopThr=1e-9, verbose=False, + log=False, warn=True, warmstart=None, **kwargs): r''' Compute the sinkhorn divergence loss from empirical data @@ -3118,6 +3296,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F 
-(\langle \gamma^*_a, \mathbf{M_a} \rangle_F + \langle + \gamma^*_b , \mathbf{M_b} \rangle_F)/2`. + + .. note: The current implementation does not account for the entropic contributions and thus differs from the + Sinkhorn divergence as introduced in the literature. The possibility to account for the entropic contributions + will be provided in a future release. + Parameters ---------- @@ -3141,6 +3326,9 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -3167,23 +3355,34 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 ''' + X_s, X_t = list_to_array(X_s, X_t) + + nx = get_backend(X_s, X_t) + if warmstart is None: + warmstart_a, warmstart_b = None, None + else: + u, v = warmstart + warmstart_a = (u, u) + warmstart_b = (v, v) + if log: sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, - log=log, warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart, **kwargs) sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, - log=log, warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_a, **kwargs) sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, - log=log, warn=warn, **kwargs) + 
numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_b, **kwargs) - sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) + sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ + (sinkhorn_loss_a + sinkhorn_loss_b) log = {} log['sinkhorn_loss_ab'] = sinkhorn_loss_ab @@ -3193,26 +3392,27 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli log['log_sinkhorn_a'] = log_a log['log_sinkhorn_b'] = log_b - return max(0, sinkhorn_div), log + return nx.maximum(0, sinkhorn_div), log else: sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, stopThr=1e-9, - verbose=verbose, log=log, - warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart, **kwargs) sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, stopThr=1e-9, - verbose=verbose, log=log, - warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_a, **kwargs) sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, stopThr=1e-9, - verbose=verbose, log=log, - warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_b, **kwargs) - sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) - return max(0, sinkhorn_div) + sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ + (sinkhorn_loss_a + sinkhorn_loss_b) + return nx.maximum(0, sinkhorn_div) def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, @@ -3379,7 +3579,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, epsilon_u_square = a[0] / aK_sort[ns_budget - 1] else: aK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1], + 
bottleneck.partition(nx.to_numpy( + K_sum_cols), ns_budget - 1)[ns_budget - 1], type_as=M ) epsilon_u_square = a[0] / aK_sort @@ -3389,7 +3590,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, epsilon_v_square = b[0] / bK_sort[nt_budget - 1] else: bK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1], + bottleneck.partition(nx.to_numpy( + K_sum_rows), nt_budget - 1)[nt_budget - 1], type_as=M ) epsilon_v_square = b[0] / bK_sort |