Diffstat (limited to 'ot/bregman.py')
-rw-r--r--  ot/bregman.py  228
1 file changed, 144 insertions, 84 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index c33c92c..192a9e2 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -24,9 +24,8 @@ from ot.utils import unif, dist, list_to_array
from .backend import get_backend
-def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, warn=True,
- **kwargs):
+def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9,
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -101,6 +100,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart : tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, these must be the dual
+ potentials themselves, i.e. the logarithms of the u, v Sinkhorn scaling vectors.
Returns
-------
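A minimal usage sketch of the new parameter (not part of the patch), assuming POT's public `ot.sinkhorn` API and that the solver's log dict exposes the scaling vectors as `log['u']`/`log['v']`; the warm start is given in the log domain:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    a, b = ot.unif(50), ot.unif(60)                   # uniform histograms
    M = ot.dist(rng.randn(50, 2), rng.randn(60, 2))   # squared Euclidean cost

    # first solve: recover the scaling vectors u, v from the log dict
    G0, log0 = ot.sinkhorn(a, b, M, 1e-1, log=True)

    # warm-start a second solve with the *logarithms* of u and v,
    # i.e. the dual potentials expected by warmstart
    ws = (np.log(log0['u']), np.log(log0['v']))
    G1 = ot.sinkhorn(a, b, M, 1e-1, warmstart=ws)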
@@ -156,34 +158,33 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn,
+ warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn,
+ warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'greenkhorn':
return greenkhorn(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn)
+ warn=warn, warmstart=warmstart)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
- return sinkhorn_epsilon_scaling(a, b, M, reg,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs):
+ stopThr=1e-9, verbose=False, log=False, warn=False, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the loss
@@ -260,6 +261,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart : tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, these must be the dual
+ potentials themselves, i.e. the logarithms of the u, v Sinkhorn scaling vectors.
Returns
-------
@@ -324,17 +328,17 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
@@ -348,25 +352,24 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
- verbose=False, log=False, warn=True,
- **kwargs):
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -415,6 +418,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart : tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, these must be the dual
+ potentials themselves, i.e. the logarithms of the u, v Sinkhorn scaling vectors.
Returns
-------
@@ -474,12 +480,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
# we assume that no distances are null except those of the diagonal of
# distances
- if n_hists:
- u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
- v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
+ if warmstart is None:
+ if n_hists:
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
+ else:
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
else:
- u = nx.ones(dim_a, type_as=M) / dim_a
- v = nx.ones(dim_b, type_as=M) / dim_b
+ u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
K = nx.exp(M / (-reg))
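For reference, an illustration (not part of the patch) of the convention used in the hunk above: `sinkhorn_knopp`, like `greenkhorn` further down, iterates on the scaling vectors, so the log-domain warm start is exponentiated once on entry. The relation between the two parametrizations, with illustrative values:

    import numpy as np

    # hypothetical log-domain potentials (f, g), as stored in warmstart
    f, g = np.zeros(5), np.zeros(7)
    u, v = np.exp(f), np.exp(g)      # scaling vectors used by the solver

    M = np.random.rand(5, 7)         # any cost matrix
    reg = 0.1
    K = np.exp(-M / reg)             # Gibbs kernel, as in the code above
    G = u[:, None] * K * v[None, :]  # transport plan parametrization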
@@ -547,7 +556,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
- log=False, warn=True, **kwargs):
+ log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem in log space
and return the OT matrix
@@ -596,6 +605,9 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart : tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, these must be the dual
+ potentials themselves, i.e. the logarithms of the u, v Sinkhorn scaling vectors.
Returns
-------
@@ -656,6 +668,10 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
else:
n_hists = 0
+ # in case of multiple histograms
+ if n_hists > 1 and warmstart is None:
+ warmstart = [None] * n_hists
+
if n_hists: # we do not want to use tensors so we do a loop
lst_loss = []
@@ -663,8 +679,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
lst_v = []
for k in range(n_hists):
- res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, stopThr=stopThr,
+ verbose=verbose, log=log, warmstart=warmstart[k], **kwargs)
if log:
lst_loss.append(nx.sum(M * res[0]))
@@ -691,9 +707,11 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
# we assume that no distances are null except those of the diagonal of
# distances
-
- u = nx.zeros(dim_a, type_as=M)
- v = nx.zeros(dim_b, type_as=M)
+ if warmstart is None:
+ u = nx.zeros(dim_a, type_as=M)
+ v = nx.zeros(dim_b, type_as=M)
+ else:
+ u, v = warmstart
def get_logT(u, v):
if n_hists:
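Since `sinkhorn_log` already works in the log domain, the warm start is consumed as-is, with no exponentiation. One natural use is reusing potentials along a decreasing regularization schedule; a hedged sketch, assuming the log dict exposes the log-domain potentials as `log_u`/`log_v` (as in POT), and noting the carried-over potentials are only a heuristic initialization, not an exact rescaling:

    import numpy as np
    import ot

    rng = np.random.RandomState(1)
    a, b = ot.unif(40), ot.unif(40)
    M = ot.dist(rng.randn(40, 2), rng.randn(40, 2))

    # anneal reg downwards, seeding each solve with the previous potentials
    ws = None
    for reg in [1.0, 0.5, 0.1]:
        G, log = ot.sinkhorn(a, b, M, reg, method='sinkhorn_log',
                             log=True, warmstart=ws)
        ws = (log['log_u'], log['log_v'])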
@@ -747,7 +765,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
- log=False, warn=True):
+ log=False, warn=True, warmstart=None):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -795,6 +813,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart : tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, these must be the dual
+ potentials themselves, i.e. the logarithms of the u, v Sinkhorn scaling vectors.
Returns
-------
@@ -853,8 +874,11 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
K = nx.exp(-M / reg)
- u = nx.full((dim_a,), 1. / dim_a, type_as=K)
- v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ if warmstart is None:
+ u = nx.full((dim_a,), 1. / dim_a, type_as=K)
+ v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ else:
+ u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
G = u[:, None] * K * v[None, :]
viol = nx.sum(G, axis=1) - a
@@ -1074,7 +1098,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# remove numerical problems and store them in K
if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
if n_hists:
- alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
+ alpha, beta = alpha + reg * \
+ nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
else:
alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v)
if n_hists:
@@ -1298,13 +1323,15 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
# we can speed up the process by checking for the error only all
# the 10th iterations
transp = G
- err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2
+ err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + \
+ nx.norm(nx.sum(transp, axis=1) - a) ** 2
if log:
log['err'].append(err)
if verbose:
if ii % (print_period * 10) == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err <= stopThr and ii > numItermin:
@@ -1648,8 +1675,10 @@ def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_ini
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
M_i = dist(X, measure_locations_i)
- T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs)
- T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)
+ T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg,
+ numItermax=numInnerItermax, **kwargs)
+ T_sum = T_sum + weight_i * 1. / \
+ b[:, None] * nx.dot(T_i, measure_locations_i)
displacement_square_norm = nx.sum((T_sum - X) ** 2)
if log:
@@ -1658,7 +1687,8 @@ def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_ini
X = T_sum
if verbose:
- print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)
+ print('iteration %d, displacement_square_norm=%f\n' %
+ (iter_count, displacement_square_norm))
iter_count += 1
@@ -2213,7 +2243,8 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr:
break
@@ -2291,7 +2322,8 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr:
break
@@ -2450,7 +2482,8 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
# debiased Sinkhorn does not converge monotonically
@@ -2530,7 +2563,8 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr and ii > 20:
break
@@ -2858,7 +2892,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
- log=False, warn=True, **kwargs):
+ log=False, warn=True, warmstart=None, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
OT matrix from empirical data
@@ -2911,6 +2945,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart : tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, these must be the dual
+ potentials themselves, i.e. the logarithms of the u, v Sinkhorn scaling vectors.
Returns
@@ -2961,14 +2998,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
dict_log = {"err": []}
log_a, log_b = nx.log(a), nx.log(b)
- f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+ if warmstart is None:
+ f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+ else:
+ f, g = warmstart
if isinstance(batchSize, int):
bs, bt = batchSize, batchSize
elif isinstance(batchSize, tuple) and len(batchSize) == 2:
bs, bt = batchSize[0], batchSize[1]
else:
- raise ValueError("Batch size must be in integer or a tuple of two integers")
+ raise ValueError(
+ "Batch size must be in integer or a tuple of two integers")
range_s, range_t = range(0, ns, bs), range(0, nt, bt)
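In the lazy (batched) path the warm start directly initializes the stochastic dual variables `f`, `g`. A hedged sketch with explicit zero potentials (the names `f0`, `g0` are illustrative, and lazy mode returns the potentials rather than the full plan):

    import numpy as np
    import ot

    rng = np.random.RandomState(2)
    X_s, X_t = rng.randn(500, 3), rng.randn(400, 3)

    f0 = np.zeros(500)   # dual potential on the source samples
    g0 = np.zeros(400)   # dual potential on the target samples

    f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, reg=1.0, isLazy=True,
                                         batchSize=100, warmstart=(f0, g0))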
@@ -3006,7 +3047,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
M = nx.from_numpy(M, type_as=a)
m1_cols.append(
- nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1)
+ nx.sum(nx.exp(f[i:i + bs, None] +
+ g[None, :] - M / reg), axis=1)
)
m1 = nx.concatenate(m1_cols, axis=0)
err = nx.sum(nx.abs(m1 - a))
@@ -3014,7 +3056,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
dict_log["err"].append(err)
if verbose and (i_ot + 1) % 100 == 0:
- print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
+ print("Error in marginal at iteration {} = {}".format(
+ i_ot + 1, err))
if err <= stopThr:
break
@@ -3034,17 +3077,17 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
M = dist(X_s, X_t, metric=metric)
if log:
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=True, **kwargs)
+ verbose=verbose, log=True, warmstart=warmstart, **kwargs)
return pi, log
else:
pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=False, **kwargs)
+ verbose=verbose, log=False, warmstart=warmstart, **kwargs)
return pi
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9, isLazy=False,
- batchSize=100, verbose=False, log=False, warn=True, **kwargs):
+ numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100,
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -3101,7 +3144,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
-
+ warmstart : tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, these must be the dual
+ potentials themselves, i.e. the logarithms of the u, v Sinkhorn scaling vectors.
Returns
-------
@@ -3157,13 +3202,16 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
isLazy=isLazy,
batchSize=batchSize,
verbose=verbose, log=log,
- warn=warn)
+ warn=warn,
+ warmstart=warmstart)
else:
f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
- numIterMax=numIterMax, stopThr=stopThr,
+ numIterMax=numIterMax,
+ stopThr=stopThr,
isLazy=isLazy, batchSize=batchSize,
verbose=verbose, log=log,
- warn=warn)
+ warn=warn,
+ warmstart=warmstart)
bs = batchSize if isinstance(batchSize, int) else batchSize[0]
range_s = range(0, ns, bs)
@@ -3190,19 +3238,18 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
if log:
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn, **kwargs)
+ warn=warn, warmstart=warmstart, **kwargs)
return sinkhorn_loss, log
else:
sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn, **kwargs)
+ warn=warn, warmstart=warmstart, **kwargs)
return sinkhorn_loss
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, warn=True,
- **kwargs):
+ numIterMax=10000, stopThr=1e-9, verbose=False,
+ log=False, warn=True, warmstart=None, **kwargs):
r'''
Compute the sinkhorn divergence loss from empirical data
@@ -3279,6 +3326,9 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart : tuple of arrays, shapes (dim_a,) and (dim_b,), optional
+ Initialization of the dual potentials. If provided, these must be the dual
+ potentials themselves, i.e. the logarithms of the u, v Sinkhorn scaling vectors.
Returns
-------
@@ -3308,24 +3358,31 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
X_s, X_t = list_to_array(X_s, X_t)
nx = get_backend(X_s, X_t)
+ if warmstart is None:
+ warmstart_a, warmstart_b = None, None
+ else:
+ u, v = warmstart
+ warmstart_a = (u, u)
+ warmstart_b = (v, v)
if log:
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
- numIterMax=numIterMax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart, **kwargs)
sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
- numIterMax=numIterMax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_a, **kwargs)
sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
- numIterMax=numIterMax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_b, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * \
+ (sinkhorn_loss_a + sinkhorn_loss_b)
log = {}
log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
@@ -3340,20 +3397,21 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
else:
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
numIterMax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart, **kwargs)
sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
numIterMax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_a, **kwargs)
sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
numIterMax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_b, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * \
+ (sinkhorn_loss_a + sinkhorn_loss_b)
return nx.maximum(0, sinkhorn_div)
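The `(u, u)` / `(v, v)` construction in the hunk above reflects that the two debiasing terms are self-transport problems: in the OT(a, a) term both marginals live on `X_s`, so the source potential seeds both sides, and likewise the target potential for OT(b, b) on `X_t`. A hedged usage sketch, with zero potentials purely for illustration:

    import numpy as np
    import ot

    rng = np.random.RandomState(3)
    X_s, X_t = rng.randn(30, 2), rng.randn(40, 2)

    u0, v0 = np.zeros(30), np.zeros(40)   # illustrative zero potentials
    div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg=1.0,
                                                   warmstart=(u0, v0))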
@@ -3521,7 +3579,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
else:
aK_sort = nx.from_numpy(
- bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1],
+ bottleneck.partition(nx.to_numpy(
+ K_sum_cols), ns_budget - 1)[ns_budget - 1],
type_as=M
)
epsilon_u_square = a[0] / aK_sort
@@ -3531,7 +3590,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
else:
bK_sort = nx.from_numpy(
- bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1],
+ bottleneck.partition(nx.to_numpy(
+ K_sum_rows), nt_budget - 1)[nt_budget - 1],
type_as=M
)
epsilon_v_square = b[0] / bK_sort