diff options
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 228 |
1 files changed, 144 insertions, 84 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index c33c92c..192a9e2 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -24,9 +24,8 @@ from ot.utils import unif, dist, list_to_array from .backend import get_backend -def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, warn=True, - **kwargs): +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, + verbose=False, log=False, warn=True, warmstart=None, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -101,6 +100,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -156,34 +158,33 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn, + warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_log': return sinkhorn_log(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn, + warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'greenkhorn': return greenkhorn(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn) + warn=warn, warmstart=warmstart) elif method.lower() == 'sinkhorn_stabilized': return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_epsilon_scaling': - return sinkhorn_epsilon_scaling(a, b, M, reg, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs): + stopThr=1e-9, verbose=False, log=False, warn=False, warmstart=None, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the loss @@ -260,6 +261,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -324,17 +328,17 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, if method.lower() == 'sinkhorn': res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + log=log, warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_log': res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + log=log, warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_stabilized': res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) @@ -348,25 +352,24 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + log=log, warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_log': return sinkhorn_log(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + log=log, warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_stabilized': return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, - verbose=False, log=False, warn=True, - **kwargs): + verbose=False, log=False, warn=True, warmstart=None, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -415,6 +418,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -474,12 +480,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, # we assume that no distances are null except those of the diagonal of # distances - if n_hists: - u = nx.ones((dim_a, n_hists), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + else: + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) K = nx.exp(M / (-reg)) @@ -547,7 +556,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, - log=False, warn=True, **kwargs): + log=False, warn=True, warmstart=None, **kwargs): r""" Solve the entropic regularization optimal transport problem in log space and return the OT matrix @@ -596,6 +605,9 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -656,6 +668,10 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, else: n_hists = 0 + # in case of multiple historgrams + if n_hists > 1 and warmstart is None: + warmstart = [None] * n_hists + if n_hists: # we do not want to use tensors sor we do a loop lst_loss = [] @@ -663,8 +679,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, lst_v = [] for k in range(n_hists): - res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) + res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, stopThr=stopThr, + verbose=verbose, log=log, warmstart=warmstart[k], **kwargs) if log: lst_loss.append(nx.sum(M * res[0])) @@ -691,9 +707,11 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, # we assume that no distances are null except those of the diagonal of # distances - - u = nx.zeros(dim_a, type_as=M) - v = nx.zeros(dim_b, type_as=M) + if warmstart is None: + u = nx.zeros(dim_a, type_as=M) + v = nx.zeros(dim_b, type_as=M) + else: + u, v = warmstart def get_logT(u, v): if n_hists: @@ -747,7 +765,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, - log=False, warn=True): + log=False, warn=True, warmstart=None): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -795,6 +813,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -853,8 +874,11 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, K = nx.exp(-M / reg) - u = nx.full((dim_a,), 1. / dim_a, type_as=K) - v = nx.full((dim_b,), 1. / dim_b, type_as=K) + if warmstart is None: + u = nx.full((dim_a,), 1. / dim_a, type_as=K) + v = nx.full((dim_b,), 1. / dim_b, type_as=K) + else: + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) G = u[:, None] * K * v[None, :] viol = nx.sum(G, axis=1) - a @@ -1074,7 +1098,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, # remove numerical problems and store them in K if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: if n_hists: - alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v)) + alpha, beta = alpha + reg * \ + nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v)) else: alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v) if n_hists: @@ -1298,13 +1323,15 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, # we can speed up the process by checking for the error only all # the 10th iterations transp = G - err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2 + err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + \ + nx.norm(nx.sum(transp, axis=1) - a) ** 2 if log: log['err'].append(err) if verbose: if ii % (print_period * 10) == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) if err <= stopThr and ii > numItermin: @@ -1648,8 +1675,10 @@ def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_ini for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): M_i = dist(X, measure_locations_i) - T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs) - T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i) + T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, + numItermax=numInnerItermax, **kwargs) + T_sum = T_sum + weight_i * 1. / \ + b[:, None] * nx.dot(T_i, measure_locations_i) displacement_square_norm = nx.sum((T_sum - X) ** 2) if log: @@ -1658,7 +1687,8 @@ def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_ini X = T_sum if verbose: - print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm) + print('iteration %d, displacement_square_norm=%f\n', + iter_count, displacement_square_norm) iter_count += 1 @@ -2213,7 +2243,8 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) if err < stopThr: break @@ -2291,7 +2322,8 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) if err < stopThr: break @@ -2450,7 +2482,8 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) # debiased Sinkhorn does not converge monotonically @@ -2530,7 +2563,8 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10 if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) if err < stopThr and ii > 20: break @@ -2858,7 +2892,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, - log=False, warn=True, **kwargs): + log=False, warn=True, warmstart=None, **kwargs): r''' Solve the entropic regularization optimal transport problem and return the OT matrix from empirical data @@ -2911,6 +2945,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns @@ -2961,14 +2998,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', dict_log = {"err": []} log_a, log_b = nx.log(a), nx.log(b) - f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a) + if warmstart is None: + f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a) + else: + f, g = warmstart if isinstance(batchSize, int): bs, bt = batchSize, batchSize elif isinstance(batchSize, tuple) and len(batchSize) == 2: bs, bt = batchSize[0], batchSize[1] else: - raise ValueError("Batch size must be in integer or a tuple of two integers") + raise ValueError( + "Batch size must be in integer or a tuple of two integers") range_s, range_t = range(0, ns, bs), range(0, nt, bt) @@ -3006,7 +3047,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) M = nx.from_numpy(M, type_as=a) m1_cols.append( - nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1) + nx.sum(nx.exp(f[i:i + bs, None] + + g[None, :] - M / reg), axis=1) ) m1 = nx.concatenate(m1_cols, axis=0) err = nx.sum(nx.abs(m1 - a)) @@ -3014,7 +3056,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', dict_log["err"].append(err) if verbose and (i_ot + 1) % 100 == 0: - print("Error in marginal at iteration {} = {}".format(i_ot + 1, err)) + print("Error in marginal at iteration {} = {}".format( + i_ot + 1, err)) if err <= stopThr: break @@ -3034,17 +3077,17 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', M = dist(X_s, X_t, metric=metric) if log: pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, - verbose=verbose, log=True, **kwargs) + verbose=verbose, log=True, warmstart=warmstart, **kwargs) return pi, log else: pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, - verbose=verbose, log=False, **kwargs) + verbose=verbose, log=False, warmstart=warmstart, **kwargs) return pi def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, isLazy=False, - batchSize=100, verbose=False, log=False, warn=True, **kwargs): + numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, + verbose=False, log=False, warn=True, warmstart=None, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -3101,7 +3144,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. - + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -3157,13 +3202,16 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log, - warn=warn) + warn=warn, + warmstart=warmstart) else: f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, - numIterMax=numIterMax, stopThr=stopThr, + numIterMax=numIterMax, + stopThr=stopThr, isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log, - warn=warn) + warn=warn, + warmstart=warmstart) bs = batchSize if isinstance(batchSize, int) else batchSize[0] range_s = range(0, ns, bs) @@ -3190,19 +3238,18 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if log: sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn, **kwargs) + warn=warn, warmstart=warmstart, **kwargs) return sinkhorn_loss, log else: sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn, **kwargs) + warn=warn, warmstart=warmstart, **kwargs) return sinkhorn_loss def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, - verbose=False, log=False, warn=True, - **kwargs): + numIterMax=10000, stopThr=1e-9, verbose=False, + log=False, warn=True, warmstart=None, **kwargs): r''' Compute the sinkhorn divergence loss from empirical data @@ -3279,6 +3326,9 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -3308,24 +3358,31 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli X_s, X_t = list_to_array(X_s, X_t) nx = get_backend(X_s, X_t) + if warmstart is None: + warmstart_a, warmstart_b = None, None + else: + u, v = warmstart + warmstart_a = (u, u) + warmstart_b = (v, v) if log: sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart, **kwargs) sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_a, **kwargs) sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_b, **kwargs) - sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) + sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ + (sinkhorn_loss_a + sinkhorn_loss_b) log = {} log['sinkhorn_loss_ab'] = sinkhorn_loss_ab @@ -3340,20 +3397,21 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli else: sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, - warn=warn, **kwargs) + verbose=verbose, log=log, warn=warn, + warmstart=warmstart, **kwargs) sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, - warn=warn, **kwargs) + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_a, **kwargs) sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, - warn=warn, **kwargs) + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_b, **kwargs) - sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) + sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ + (sinkhorn_loss_a + sinkhorn_loss_b) return nx.maximum(0, sinkhorn_div) @@ -3521,7 +3579,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, epsilon_u_square = a[0] / aK_sort[ns_budget - 1] else: aK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1], + bottleneck.partition(nx.to_numpy( + K_sum_cols), ns_budget - 1)[ns_budget - 1], type_as=M ) epsilon_u_square = a[0] / aK_sort @@ -3531,7 +3590,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, epsilon_v_square = b[0] / bK_sort[nt_budget - 1] else: bK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1], + bottleneck.partition(nx.to_numpy( + K_sum_rows), nt_budget - 1)[nt_budget - 1], type_as=M ) epsilon_v_square = b[0] / bK_sort |